Day 10: Decision Trees – Learning Rules from Data!

Saket Khopkar

Imagine you are heading out for a picnic in the park, but your plans depend on the weather forecast, since sudden heavy rain has been predicted.

Now, to decide, you may have to ask yourself the following:

Q1) Is it raining?

Yes → Take an umbrella

No → No umbrella needed

Q2) Is it too hot?

Yes → Wear light clothes

No → Dress normally

This is basically how a decision tree works: it splits the data using questions, or shall I say conditions. We check each condition in turn and work our way down to the final decision.

💡
Decision trees are used for both of the scenarios we have seen in earlier blogs: Classification (e.g., "Will a customer buy a product?" → Yes/No) & Regression (e.g., "What will be the price of a house?" → A continuous value)
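
As a quick illustration (not taken from the original example, and the toy numbers are made up), scikit-learn offers both flavours through DecisionTreeClassifier and DecisionTreeRegressor:

# Hypothetical toy data, only to show that the same tree API covers both tasks
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

X = [[25, 30000], [40, 80000], [35, 60000], [22, 25000]]  # [age, salary]

# Classification: will the customer buy? (0 = No, 1 = Yes)
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(X, [0, 1, 1, 0])

# Regression: predict a continuous value, e.g. yearly spend
reg = DecisionTreeRegressor(max_depth=2, random_state=0)
reg.fit(X, [1200.0, 5400.0, 4100.0, 900.0])

print(clf.predict([[30, 55000]]))  # a class label (0 or 1)
print(reg.predict([[30, 55000]]))  # a continuous estimate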

Understanding the Decision Tree

A Decision Tree is made up of:

  • Root Node: The first decision point (e.g., "Is the salary above 50K?")

  • Branches: Possible answers (e.g., "Yes" or "No")

  • Internal Nodes: More decision points (e.g., "Is the person’s age above 30?")

  • Leaf Nodes: The final decision (e.g., "Will buy" or "Won’t buy")

A simple example: Let's predict if a person will buy a car based on age and salary.

Q1) Is the salary above $50,000?

Yes → Proceed to next level / node

No → Won’t Buy

Q2) Is age above 30?

Yes → Likely to buy

No → Won’t Buy

If you sketch this scenario out as a diagram, it forms a decision tree: a typical tree-like structure.
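
Written out as plain code, those two questions are nothing more than nested if/else checks. This is just a hand-written sketch of the rules above, not something generated by a library:

def will_buy_car(salary, age):
    # Q1) Is the salary above $50,000?
    if salary > 50000:
        # Q2) Is age above 30?
        if age > 30:
            return "Likely to buy"
        return "Won't buy"
    return "Won't buy"

print(will_buy_car(salary=60000, age=35))  # Likely to buy
print(will_buy_car(salary=40000, age=45))  # Won't buy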


Gini Impurity & Information Gain (Simplified)

A Decision Tree splits the data at each step, but how does it choose the best split?

It uses two measures:

  • Gini Impurity (Measures "impurity" of a group; lower is better)

  • Entropy & Information Gain (Measures "disorder" in data; higher gain is better)

For a two-class (binary) problem, Gini impurity ranges from 0 to 0.5. A value of 0 means perfect purity, i.e. all instances in the node belong to the same class, while 0.5 means maximum impurity, i.e. the two classes are equally distributed within the node.

For two classes, the Gini impurity of a node is Gini = 1 − (p1² + p2²), where p1 and p2 are the probabilities of each class.

Consider 10 customers, of whom 7 bought the car and 3 did not, so p1 = 0.7 and p2 = 0.3. Plugging these values into the formula gives Gini = 1 − (0.7² + 0.3²) = 1 − (0.49 + 0.09) = 0.42.

The lower the Gini, the better the split! The goal is to reduce the overall impurity of the tree by splitting nodes in a way that creates more homogeneous subsets of data. In other words, Gini impurity helps the algorithm pick the splits that separate the classes most cleanly.
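
If you want to check the 7-buyers / 3-non-buyers example by hand, here is a small helper (my own sketch, not library code) that computes both Gini impurity and entropy from class counts:

import math

def gini(counts):
    total = sum(counts)
    return 1 - sum((c / total) ** 2 for c in counts)

def entropy(counts):
    total = sum(counts)
    return -sum((c / total) * math.log2(c / total) for c in counts if c > 0)

# 7 customers bought the car, 3 did not
print(round(gini([7, 3]), 2))     # 0.42
print(round(entropy([7, 3]), 3))  # 0.881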


Time to Code

Let's predict whether a customer will buy a car using Age and Salary.

import numpy as np  
import pandas as pd  
import matplotlib.pyplot as plt  
import seaborn as sns  
from sklearn.model_selection import train_test_split  
from sklearn.tree import DecisionTreeClassifier  
from sklearn import tree  
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix  

data = {'Age': [22, 30, 35, 40, 50, 60, 18, 27, 55, 45],
        'Salary': [25000, 60000, 70000, 50000, 80000, 95000, 20000, 48000, 85000, 72000],
        'Bought_Car': [0, 1, 1, 0, 1, 1, 0, 0, 1, 1]}

df = pd.DataFrame(data)
print(df)

Here is the sample dataset we are going to use. Let's split the data into training and testing sets.

X = df[['Age', 'Salary']]
y = df['Bought_Car']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

print(f"Training samples: {len(X_train)}, Testing samples: {len(X_test)}")
# Train Decision Tree Classifier
model = DecisionTreeClassifier(criterion='gini', max_depth=3, random_state=42)
model.fit(X_train, y_train)
# Visualizing the decision tree
plt.figure(figsize=(10,6))
tree.plot_tree(model, feature_names=['Age', 'Salary'], class_names=['No', 'Yes'], filled=True)
plt.show()
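
If a plot window is inconvenient (for example on a remote server), the same fitted tree can also be printed as plain text using export_text from sklearn.tree:

from sklearn.tree import export_text

# Text rendering of the fitted tree, one line per split
print(export_text(model, feature_names=['Age', 'Salary']))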

Our next step is to make predictions and evaluate the model.

y_pred = model.predict(X_test)
print("Predicted values:", y_pred)
# Accuracy Score
accuracy = accuracy_score(y_test, y_pred)
print(f"Model Accuracy: {accuracy * 100:.2f}%")

# Confusion Matrix
conf_matrix = confusion_matrix(y_test, y_pred)
sns.heatmap(conf_matrix, annot=True, cmap="Blues", fmt='d')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()

# Classification Report
print("Classification Report:\n", classification_report(y_test, y_pred))

Not spectacular, but acceptable for such a tiny dataset: the model got 2 out of 3 test samples right (about 67% accuracy). With only 3 test samples, though, take the score with a grain of salt.
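
Once trained, the model can also be queried for a brand-new customer. The numbers below are made-up inputs just to show the calls:

# Hypothetical new customer: 33 years old, earning 65,000
new_customer = pd.DataFrame({'Age': [33], 'Salary': [65000]})

print("Prediction:", model.predict(new_customer)[0])              # 0 = won't buy, 1 = will buy
print("Class probabilities:", model.predict_proba(new_customer)[0])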

So, here is what we learned about decision trees:

  • It mimics human decision-making by splitting data into rules.

  • It uses Gini Impurity & Information Gain to choose the best split.

  • They are easy to visualize using tree diagrams.

  • Works well for both Classification and Regression tasks.

  • Simple, but can overfit (commonly controlled using max_depth; see the sketch after this list).
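
For instance, a quick way to see how depth affects the fit is to retrain on the same split with different settings (an illustrative sweep, reusing X_train/X_test from above; min_samples_leaf is another common knob):

for depth in [1, 2, 3, None]:
    m = DecisionTreeClassifier(max_depth=depth, min_samples_leaf=1, random_state=42)
    m.fit(X_train, y_train)
    print(f"max_depth={depth}: train acc={m.score(X_train, y_train):.2f}, "
          f"test acc={m.score(X_test, y_test):.2f}")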


Closing the day

We learnt about decision trees today; in fact, I would say we mastered them by seeing how they actually split the data to arrive at correct decisions. We built a Decision Tree model in Python and visualized its decision-making process.

Well, you may run into more complex examples in practical implementations, so I would once again advise you to play with the data so that you understand things in more depth.

Well, I have to consult a decision tree on whether to take a day off or not; it says it's enough for the day ;) :)

Ciao!!
