🌳 Cracking Decision Trees: A Visual Guide to Classification in Python

“A Decision Tree is not just a model — it’s a flowchart of logic and learning.”
— Tilak Savani
🧠 Introduction
Decision Trees are one of the most intuitive and explainable machine learning algorithms. They mimic human decision-making and are used for both classification and regression tasks.
In this blog, we'll break down how they work, from the logic to the math, and then implement one with scikit-learn.
🤔 What Is a Decision Tree?
A Decision Tree splits your dataset into smaller subsets based on feature values, forming a tree-like structure.
Each internal node asks a question, and the branches represent the answers.
Example:

        [Age > 25?]
         /       \
       Yes        No
        |          |
[Income > 50K?]  Reject
⚙️ How It Works (Step-by-Step)
1. Start with all the data.
2. For each feature, calculate how "pure" each possible split would be.
3. Choose the best feature (and threshold) to split on: the one with the highest information gain or the lowest Gini impurity.
4. Repeat the process recursively on each subset until a stopping criterion is met (see the small code sketch below).
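To make those steps concrete, here is a minimal, illustrative sketch of the greedy split search. This is not scikit-learn's internal implementation (the library uses optimized Cython), and the gini helper simply applies the impurity formula explained in the next section.

import numpy as np

def gini(y):
    # Gini impurity of one node (formula covered in the next section)
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_split(X, y):
    # Try every feature and every observed value as a threshold,
    # and keep the split with the lowest weighted child impurity.
    best = None  # (feature index, threshold, weighted impurity)
    for j in range(X.shape[1]):
        for t in np.unique(X[:, j]):
            mask = X[:, j] <= t
            left, right = y[mask], y[~mask]
            if len(left) == 0 or len(right) == 0:
                continue
            score = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
            if best is None or score < best[2]:
                best = (j, t, score)
    return best

X = np.array([[22, 15000], [25, 29000], [47, 48000], [52, 60000]])
y = np.array([0, 0, 1, 1])
j, t, score = best_split(X, y)
print(j, t, score)  # 0 25 0.0 -> splitting on Age <= 25 gives two pure children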
🧮 Math Behind Decision Trees
Decision Trees use criteria like Gini Impurity or Information Gain (Entropy) to decide the best split.
✳️ 1. Gini Impurity
Used by sklearn by default. It measures how "mixed" the class labels are in a node.
Gini = 1 − Σ(pᵢ²)
Where pᵢ is the probability (proportion) of class i in the node.
A pure node (all same class) has Gini = 0.
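As a quick sanity check of the formula (a tiny, self-contained example):

# Node with labels [0, 0, 1, 1]: p0 = p1 = 0.5
print(1 - (0.5**2 + 0.5**2))  # 0.5 -> maximally mixed for two classes
# Pure node [1, 1, 1, 1]: p1 = 1
print(1 - 1.0**2)             # 0.0 -> pure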
🔍 2. Entropy & Information Gain
Entropy quantifies disorder:
Entropy = − Σ(pᵢ * log₂(pᵢ))
Information Gain is the reduction in entropy after a split:
Gain = Entropy(parent) − [Weighted avg. Entropy(children)]
Decision Trees using entropy try to maximize gain at each node.
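Here is a small sketch of both quantities in code (a toy example for a single binary split, not library code):

import numpy as np

def entropy(y):
    # Entropy of one node: -Σ p_i * log2(p_i)
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

def information_gain(y_parent, y_left, y_right):
    # Parent entropy minus the size-weighted average entropy of the children
    n = len(y_parent)
    weighted = (len(y_left) / n) * entropy(y_left) + (len(y_right) / n) * entropy(y_right)
    return entropy(y_parent) - weighted

parent = np.array([0, 0, 1, 1, 1, 1])
left, right = np.array([0, 0]), np.array([1, 1, 1, 1])
print(information_gain(parent, left, right))  # ~0.918: a perfect split removes all disorder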
🧪 Python Code: Classification Example
Let’s build a tree to classify whether a person buys a product based on age and salary.
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
import matplotlib.pyplot as plt
# Sample data
data = {
    'Age': [22, 25, 47, 52, 46, 56],
    'Salary': [15000, 29000, 48000, 60000, 52000, 61000],
    'Buys': [0, 0, 1, 1, 1, 1]
}
df = pd.DataFrame(data)

X = df[['Age', 'Salary']]
y = df['Buys']

# Train the model
clf = DecisionTreeClassifier(criterion='entropy', max_depth=3)
clf.fit(X, y)
📊 Visualize the Tree
plt.figure(figsize=(10, 6))
tree.plot_tree(clf, filled=True, feature_names=['Age', 'Salary'], class_names=['No', 'Yes'])
plt.title("Decision Tree for Product Purchase")
plt.show()
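If you prefer a plain-text view of the same tree (handy in a terminal or a log file), scikit-learn's export_text works too; a small optional addition:

from sklearn.tree import export_text
print(export_text(clf, feature_names=['Age', 'Salary']))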
🧪 Predict
# Predict for someone aged 30 with a 40K salary
new_person = pd.DataFrame([[30, 40000]], columns=['Age', 'Salary'])
print(clf.predict(new_person))  # Expected output: [0] -> not likely to buy
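If you also want the tree's confidence rather than just the label, predict_proba returns the class probabilities of the leaf the sample lands in (a small optional addition):

print(clf.predict_proba(new_person))  # e.g. [[1. 0.]] if that leaf contains only "No" samples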
🌍 Real-World Applications
| Industry | Use Case |
| --- | --- |
| Finance | Loan approval (yes/no) |
| Health | Disease diagnosis (benign/malignant) |
| Retail | Predict customer churn |
| HR | Employee attrition prediction |
✅ Advantages
Easy to understand and interpret
No need for feature scaling
Works with both numerical and categorical data (though scikit-learn needs categorical features encoded as numbers)
⚠️ Limitations
Prone to overfitting, which can be mitigated by pruning or by ensembles like Random Forest (see the sketch below)
Often less accurate than ensemble or boosted models on complex datasets
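Both fixes are one-liners in scikit-learn. A minimal sketch, reusing X and y from the example above (the parameter values here are illustrative, not tuned):

from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier

# Prune the single tree: larger ccp_alpha -> more aggressive pruning
pruned = DecisionTreeClassifier(ccp_alpha=0.01, random_state=0).fit(X, y)

# Or average many randomized trees to reduce variance
forest = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)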
🧩 Final Thoughts
Decision Trees are a powerful blend of logic, math, and machine learning. They're transparent, fast, and form the building blocks of ensemble models like Random Forest and Gradient Boosting.
Whether you're a beginner or building AI at scale, decision trees are a must-know algorithm.
📬 Subscribe
If you liked this blog, follow me on Hashnode for more posts on Machine Learning and Python.
Thanks for reading! 😊