Predicting Non-Linear Data with Decision Tree Regression | My Machine Learning Project


Introduction:
Hey everyone! 👋
I recently completed a hands-on project using Decision Tree Regression, and I was amazed by how simple yet powerful this model is. In this article, I’ll explain what Decision Tree Regression is, why it works so well for certain types of data, and how I built and visualized it using Python. If you're exploring ML regression models or want to understand how trees can predict continuous outcomes, this one’s for you!
→ What is Decision Tree Regression?
Decision Tree Regression is a non-linear, non-parametric supervised learning algorithm that predicts a target value by learning simple decision rules inferred from the data features.
It works by:
Splitting the data into branches based on feature thresholds
Making predictions by averaging the target values in the final leaf nodes
It looks like a flowchart where each node splits the data, and predictions are made at the leaves.
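To make the "split on a threshold, then average" idea concrete, here is a minimal sketch of a single split done by hand with NumPy. It uses the position-level and salary values from the dataset table in this article; the helper name best_split is my own for illustration and is not part of any library. A real tree repeats this search recursively inside each branch.

```python
import numpy as np

# Values taken from the position-level / salary table in this article.
levels = np.arange(1, 11, dtype=float)
salaries = np.array([45000, 50000, 60000, 80000, 110000,
                     150000, 200000, 300000, 500000, 1000000], dtype=float)

def best_split(x, y):
    """Hypothetical helper: find the one threshold that minimizes the
    total squared error when each side is predicted by its own mean."""
    order = np.argsort(x)
    x, y = x[order], y[order]
    best_threshold, best_sse = None, np.inf
    for i in range(1, len(x)):
        if x[i] == x[i - 1]:
            continue  # no threshold fits between two identical values
        left, right = y[:i], y[i:]
        sse = ((left - left.mean()) ** 2).sum() + ((right - right.mean()) ** 2).sum()
        if sse < best_sse:
            # Candidate threshold: midpoint between neighbouring feature values.
            best_threshold, best_sse = (x[i - 1] + x[i]) / 2, sse
    return best_threshold, best_sse

threshold, sse = best_split(levels, salaries)

# Each side of the split predicts the mean of its target values.
left_mean = salaries[levels <= threshold].mean()
right_mean = salaries[levels > threshold].mean()
print(f"split at {threshold}: left predicts {left_mean}, right predicts {right_mean}")
```

On this data the search settles on a threshold between levels 8 and 9, because separating the two largest salaries from the rest reduces the squared error the most.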
→ Why Use Decision Tree Regression?
Captures non-linear patterns easily
No need for feature scaling
Interpretable model (you can visualize the tree)
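Handles numerical features directly; categorical features also work once they are encoded as numbers (scikit-learn's DecisionTreeRegressor expects numeric input)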
Unlike Linear Regression, it doesn’t assume a linear relationship between features and target.
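A quick way to see why feature scaling is unnecessary: a tree only cares about the order of the feature values, so rescaling the feature moves the thresholds but not the predictions. The sketch below is an illustration under my own assumptions (the scale factor of 1000 and the query points are arbitrary choices), using the salary data from this article.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

levels = np.arange(1, 11, dtype=float).reshape(-1, 1)
salaries = np.array([45000, 50000, 60000, 80000, 110000,
                     150000, 200000, 300000, 500000, 1000000], dtype=float)

# Fit one tree on the raw feature and one on a rescaled copy.
scale = 1000.0  # arbitrary factor, chosen only for illustration
tree_raw = DecisionTreeRegressor(random_state=0).fit(levels, salaries)
tree_scaled = DecisionTreeRegressor(random_state=0).fit(levels * scale, salaries)

# Query points deliberately chosen away from the split thresholds.
queries = np.array([[2.2], [6.7], [9.9]])
pred_raw = tree_raw.predict(queries)
pred_scaled = tree_scaled.predict(queries * scale)

same = np.allclose(pred_raw, pred_scaled)
print(pred_raw, pred_scaled, same)
```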
🛠️ Tools & Libraries Used:
Python
NumPy
Pandas
Matplotlib / Seaborn
Scikit-learn
📊 Dataset:
I used a dataset that maps position level to salary, similar to what’s commonly used in regression tutorials.
Position Level | Salary |
1 | 45000 |
2 | 50000 |
3 | 60000 |
4 | 80000 |
5 | 110000 |
6 | 150000 |
7 | 200000 |
8 | 300000 |
9 | 500000 |
10 | 1000000 |
🔍 Step-by-Step Implementation:
1. Import Libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor
2. Load the Dataset
dataset = pd.read_csv('Position_Salaries.csv')
X = dataset.iloc[:, 1:2].values  # position level, kept 2-D: shape (n_samples, 1)
y = dataset.iloc[:, 2].values    # salary, 1-D: shape (n_samples,)
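If you don't have the CSV file at hand, the same arrays can be built in memory from the table above. This is only a sketch: the column names and the placeholder labels in the first column are my assumptions, chosen so that the iloc indexing from the step above (level in column 1, salary in column 2) still applies.

```python
import pandas as pd

# In-memory stand-in for the CSV. Column names and the first column's
# placeholder labels are assumptions; levels and salaries come from the table.
dataset = pd.DataFrame({
    "Position": [f"Position {i}" for i in range(1, 11)],  # hypothetical labels
    "Level": list(range(1, 11)),
    "Salary": [45000, 50000, 60000, 80000, 110000,
               150000, 200000, 300000, 500000, 1000000],
})

X = dataset.iloc[:, 1:2].values  # 1:2 keeps X two-dimensional, as scikit-learn expects
y = dataset.iloc[:, 2].values    # target stays one-dimensional

print(X.shape, y.shape)
```

The 1:2 slice matters: scikit-learn estimators expect a 2-D feature matrix even when there is only one feature.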
3. Fit Decision Tree Regressor
regressor = DecisionTreeRegressor(random_state=0)
regressor.fit(X, y)
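Once the tree is fitted, you can read the learned rules as plain text, which is a handy check on what the model actually learned. The sketch below rebuilds the data inline so it runs on its own; the feature name "Level" is just a label I picked for the printout.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor, export_text

# Same data as the table in this article, rebuilt inline.
X = np.arange(1, 11).reshape(-1, 1)
y = np.array([45000, 50000, 60000, 80000, 110000,
              150000, 200000, 300000, 500000, 1000000])

regressor = DecisionTreeRegressor(random_state=0)
regressor.fit(X, y)

# export_text renders the fitted tree as indented if/else style rules.
rules = export_text(regressor, feature_names=["Level"])
print(rules)

# With no depth limit and ten distinct levels, every sample gets its own leaf.
n_leaves = regressor.get_n_leaves()
depth = regressor.get_depth()
print(f"leaves: {n_leaves}, depth: {depth}")
```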
4. Predict a New Result
y_pred = regressor.predict([[6.5]])  # predict expects a 2-D array: one row, one feature
print(f"Predicted Salary for level 6.5: {y_pred[0]}")
5. Visualize the Results (Higher Resolution)
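# Use a fine grid so the plot shows the jumps between leaves.
# X.min() / X.max() return scalars; min(X) on a 2-D array returns an array, which np.arange may reject.
X_grid = np.arange(X.min(), X.max(), 0.01)
X_grid = X_grid.reshape((len(X_grid), 1))
plt.scatter(X, y, color='red')
plt.plot(X_grid, regressor.predict(X_grid), color='blue')
plt.title('Decision Tree Regression')
plt.xlabel('Position Level')
plt.ylabel('Salary')
plt.show()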
📈 Output & Results:
The decision tree model produces a step-like graph instead of a smooth curve. This shows how it splits the feature space into intervals and assigns a constant value within each.
The prediction for level 6.5 should come out as 150000, which is the salary associated with the interval where 6.5 falls.
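The step shape can also be checked numerically instead of by eye. The sketch below refits the same model on the table data, reads the split thresholds from the fitted tree, and probes a few levels around 6.5. The probe values are my own choice for illustration.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

X = np.arange(1, 11, dtype=float).reshape(-1, 1)
y = np.array([45000, 50000, 60000, 80000, 110000,
              150000, 200000, 300000, 500000, 1000000], dtype=float)
regressor = DecisionTreeRegressor(random_state=0).fit(X, y)

# Internal nodes have feature index >= 0; leaves are marked with a negative index.
tree = regressor.tree_
thresholds = np.sort(tree.threshold[tree.feature >= 0])
print("split thresholds:", thresholds)

# scikit-learn sends samples with value <= threshold to the left branch,
# so 6.5 lands in the same leaf as level 6.
probe = np.array([[6.1], [6.5], [6.6], [7.4]])
preds = regressor.predict(probe)
print(dict(zip(probe.ravel(), preds)))

# On a fine grid, the model only ever outputs the ten training salaries.
grid = np.arange(X.min(), X.max(), 0.01).reshape(-1, 1)
n_distinct = len(np.unique(regressor.predict(grid)))
print("distinct predicted values on the grid:", n_distinct)
```

The thresholds sit halfway between neighbouring levels, which is why the plotted line jumps at 1.5, 2.5, and so on.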
→ Key Takeaways:
Easy to implement, and decision trees work for both regression and classification tasks.
Doesn’t require feature scaling.
Prone to overfitting: with no depth limit, the tree can memorize every training point, which is what happens on a small dataset like this one.
Ideal for non-linear and discontinuous data.
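The overfitting point is easy to demonstrate on this dataset. The sketch below compares the training score of an unconstrained tree with its error under leave-one-out cross-validation. Leave-one-out is my choice here because there are only ten rows; it is not something from the original project.

```python
import numpy as np
from sklearn.model_selection import LeaveOneOut, cross_val_predict
from sklearn.tree import DecisionTreeRegressor

X = np.arange(1, 11, dtype=float).reshape(-1, 1)
y = np.array([45000, 50000, 60000, 80000, 110000,
              150000, 200000, 300000, 500000, 1000000], dtype=float)

# On the training data an unconstrained tree reproduces every salary exactly.
full_tree = DecisionTreeRegressor(random_state=0).fit(X, y)
train_r2 = full_tree.score(X, y)

# Leave-one-out: each level is predicted by a tree that never saw it.
loo_pred = cross_val_predict(DecisionTreeRegressor(random_state=0), X, y,
                             cv=LeaveOneOut())
loo_mae = np.mean(np.abs(loo_pred - y))

print(f"training R^2: {train_r2}")
print(f"leave-one-out mean absolute error: {loo_mae:.0f}")
```

A perfect training score combined with a clearly non-zero held-out error is the usual signature of a model that has memorized its data.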
→ What I Learned:
How decision trees make predictions by learning thresholds.
Visualization helped understand how trees predict in “jumps.”
Importance of setting random_state for reproducibility.
→ What’s Next?
Try using Random Forest Regression to improve accuracy and reduce overfitting.
Use GridSearchCV to tune hyperparameters such as max_depth, min_samples_split, etc.
Apply it to real-world datasets (e.g., crop yield prediction, stock prices, etc.)
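As a starting point for these next steps, here is a rough sketch of both ideas on the same ten-row dataset. The parameter grid, the number of trees, the fold setup, and the scoring metric are all assumptions I picked for illustration; with so little data the numbers mean little, and the aim is only to show how the pieces fit together.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.tree import DecisionTreeRegressor

X = np.arange(1, 11, dtype=float).reshape(-1, 1)
y = np.array([45000, 50000, 60000, 80000, 110000,
              150000, 200000, 300000, 500000, 1000000], dtype=float)

# Random forest: averages many trees trained on bootstrap samples,
# which smooths the single tree's hard jumps.
forest = RandomForestRegressor(n_estimators=100, random_state=0).fit(X, y)
forest_pred = forest.predict([[6.5]])[0]
print(f"Random forest prediction for level 6.5: {forest_pred:.0f}")

# Grid search over the two hyperparameters mentioned above (grid values are illustrative).
param_grid = {"max_depth": [1, 2, 3, None], "min_samples_split": [2, 3, 4]}
search = GridSearchCV(
    DecisionTreeRegressor(random_state=0),
    param_grid,
    cv=KFold(n_splits=5, shuffle=True, random_state=0),
    scoring="neg_mean_absolute_error",
)
search.fit(X, y)
best = search.best_params_
print("best parameters:", best)
```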
Written by

Lokesh Patidar
Hey, I'm Lokesh Patidar! I'm a 2nd-year student at SATI Vidisha, passionate about AI, Machine Learning, Full-Stack Development, and DSA. What I'm Learning: Currently Exploring Machine Learning 🤖 Completed DSA & Frontend Development 🌐 Now exploring Backend Development 💡 Interests: I love solving problems, building projects, and integrating AI into real-world applications. Excited to contribute to tech communities and share my learning journey! 📌 Follow my blog for insights on AI, ML, and Full-Stack projects!