Understanding Decision Trees

Learn what decision trees are, how they work, and how to implement them using Python and scikit-learn for classification and regression tasks.

🔰 beginner
⏱️ 20 minutes
👤 SuperML Team

· Machine Learning · 2 min read

📋 Prerequisites

  • Basic Python knowledge
  • Understanding of supervised learning
  • Familiarity with pandas and scikit-learn

🎯 What You'll Learn

  • Understand how decision trees work for classification and regression
  • Visualize and interpret decision tree splits
  • Implement decision trees using scikit-learn
  • Evaluate decision tree models

Introduction

Decision Trees are a fundamental algorithm in machine learning used for both classification and regression tasks. They work by splitting the data into subsets based on feature values, creating a tree-like structure that is easy to interpret and visualize.
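For regression, the tree predicts a number instead of a class. The minimal sketch below assumes scikit-learn's `DecisionTreeRegressor` with its default squared-error criterion. Under that criterion, each leaf predicts the mean target of the training samples that reach it. The four-point dataset is made up for illustration.

```python
from sklearn.tree import DecisionTreeRegressor

# Hypothetical 1-D toy data: small targets for x <= 2, large targets for x >= 3
X = [[1], [2], [3], [4]]
y = [1.0, 2.0, 9.0, 11.0]

# A depth-1 tree (a "decision stump") makes exactly one split, here at x <= 2.5
reg = DecisionTreeRegressor(max_depth=1, random_state=0)
reg.fit(X, y)

# Each leaf predicts the mean of its training targets: (1+2)/2 and (9+11)/2
preds = reg.predict([[1.0], [4.0]])
print(preds)  # leaf means: 1.5 and 10.0
```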


Why Use Decision Trees?

✅ Easy to understand and interpret.
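✅ Can handle both numerical and categorical data (conceptually; scikit-learn's implementation expects categorical features to be encoded as numbers first).
✅ Require little data preprocessing (for example, no feature scaling or normalization).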
✅ Can capture non-linear relationships.

However, decision trees are prone to overfitting, which can be addressed using pruning or ensemble methods like Random Forest.
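To see overfitting and two common remedies side by side, the sketch below compares three trees on synthetic data: an unconstrained tree, a depth-limited tree (pre-pruning), and a cost-complexity-pruned tree (post-pruning via `ccp_alpha`). The dataset, `max_depth=3` and `ccp_alpha=0.01` are arbitrary illustrative choices. In practice you might choose `ccp_alpha` by combining `cost_complexity_pruning_path` with cross-validation. Typically the unconstrained tree scores perfectly on the training set but does worse on the test set. The exact numbers depend on the data and the library version.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic data where ~10% of labels are assigned at random (label noise to overfit on)
X, y = make_classification(n_samples=400, n_features=6, flip_y=0.1, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

# Unconstrained tree: keeps splitting until every leaf is pure
deep = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

# Pre-pruning: stop growing at a fixed depth
shallow = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_train, y_train)

# Post-pruning: grow the full tree, then prune it with cost-complexity pruning
pruned = DecisionTreeClassifier(ccp_alpha=0.01, random_state=0).fit(X_train, y_train)

for name, clf in [("unconstrained", deep), ("max_depth=3", shallow), ("ccp_alpha=0.01", pruned)]:
    print(f"{name:15s} leaves={clf.get_n_leaves():3d} "
          f"train={clf.score(X_train, y_train):.2f} test={clf.score(X_test, y_test):.2f}")
```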


How Do Decision Trees Work?

At each node in the tree:

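  • The algorithm selects the feature and threshold whose split most reduces impurity. Impurity is measured, for example, by Gini impurity or entropy for classification, or by mean squared error for regression.
  • Each resulting subset is split the same way, recursively, until a stopping criterion is met (e.g., the node is pure, the maximum depth is reached, or too few samples remain per leaf).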
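The split search for a single feature can be sketched from scratch in a few lines. This is a simplified illustration, not scikit-learn's actual implementation. It computes Gini impurity (1 − Σ p_k²) and entropy (−Σ p_k log₂ p_k), tries the midpoint between each pair of consecutive sorted values as a threshold, and keeps the threshold with the lowest weighted child impurity. The helper names `gini`, `entropy` and `best_threshold` are hypothetical.

```python
import numpy as np

def gini(labels):
    """Gini impurity: 1 - sum_k p_k^2."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(1.0 - np.sum(p ** 2))

def entropy(labels):
    """Entropy in bits: -sum_k p_k log2 p_k, written as sum_k p_k log2(1/p_k)."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(np.sum(p * np.log2(1.0 / p)))

def best_threshold(x, y, impurity=gini):
    """Return (threshold, weighted child impurity) for the best split of one feature.
    Assumes the feature has at least two distinct values."""
    x, y = np.asarray(x), np.asarray(y)
    values = np.unique(x)  # sorted distinct values
    best_t, best_score = None, np.inf
    for t in (values[:-1] + values[1:]) / 2:  # candidate midpoints
        left, right = y[x <= t], y[x > t]
        # Weight each child's impurity by its share of the samples
        score = (len(left) * impurity(left) + len(right) * impurity(right)) / len(y)
        if score < best_score:
            best_t, best_score = t, score
    return float(best_t), float(best_score)

# Feature1 and Label from the toy dataset in the step-by-step section
feature1 = [2, 3, 10, 19, 25, 30, 40, 50]
label = [0, 0, 0, 1, 1, 1, 1, 1]

print(gini(label))                      # → 0.46875  (parent node: 1 - (3/8)^2 - (5/8)^2)
print(best_threshold(feature1, label))  # → (14.5, 0.0)  (perfect split between 10 and 19)
```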

Step-by-Step Implementation

1️⃣ Import Libraries

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

2️⃣ Load Dataset

We’ll use a simple dataset for binary classification:

data = {
    'Feature1': [2, 3, 10, 19, 25, 30, 40, 50],
    'Feature2': [5, 7, 15, 20, 28, 35, 40, 55],
    'Label': [0, 0, 0, 1, 1, 1, 1, 1]
}
df = pd.DataFrame(data)
print(df)  # show all 8 rows (df.head() would show only the first 5)

3️⃣ Prepare Data

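# Features (X) and target (y)
X = df[['Feature1', 'Feature2']]
y = df['Label']

# Hold out 25% (2 of the 8 rows) for testing; stratify=y keeps both classes in each split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42
)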

4️⃣ Train the Decision Tree Model

# max_depth=3 caps tree growth to limit overfitting; random_state makes results reproducible
model = DecisionTreeClassifier(max_depth=3, random_state=42)
model.fit(X_train, y_train)
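Once trained, the model can classify new samples. The self-contained sketch below rebuilds the toy dataset and, for simplicity, fits on all 8 rows. It then predicts two hypothetical new points. `predict_proba` returns the class fractions of the training samples in the leaf each point lands in. Because a single split separates this data perfectly, the tree stops at depth 1 even though `max_depth=3` would allow more.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Same toy data as in step 2, fit on all rows for this illustration
df = pd.DataFrame({
    'Feature1': [2, 3, 10, 19, 25, 30, 40, 50],
    'Feature2': [5, 7, 15, 20, 28, 35, 40, 55],
    'Label': [0, 0, 0, 1, 1, 1, 1, 1],
})
model = DecisionTreeClassifier(max_depth=3, random_state=42)
model.fit(df[['Feature1', 'Feature2']], df['Label'])

# Hypothetical new samples, with the same column names used during training
new_points = pd.DataFrame({'Feature1': [5, 45], 'Feature2': [6, 50]})

print(model.predict(new_points))        # → [0 1]
print(model.predict_proba(new_points))  # class fractions in each point's leaf
print("depth:", model.get_depth(), "leaves:", model.get_n_leaves())
```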

5️⃣ Visualize the Decision Tree

plt.figure(figsize=(12,8))
plot_tree(model, feature_names=['Feature1', 'Feature2'], class_names=['0', '1'], filled=True)
plt.show()
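If you prefer a text view (for example, in a terminal or a log), scikit-learn's `export_text` prints the same splits as nested if/else rules. The sketch below rebuilds the toy model so it runs on its own. Unlike the tutorial's model, it fits on all 8 rows, so the printed thresholds may differ from your plot.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text

# Toy data from step 2, fit on all rows for this illustration
df = pd.DataFrame({
    'Feature1': [2, 3, 10, 19, 25, 30, 40, 50],
    'Feature2': [5, 7, 15, 20, 28, 35, 40, 55],
    'Label': [0, 0, 0, 1, 1, 1, 1, 1],
})
model = DecisionTreeClassifier(max_depth=3, random_state=42)
model.fit(df[['Feature1', 'Feature2']], df['Label'])

# Each line shows a threshold test; leaves show the predicted class
rules = export_text(model, feature_names=['Feature1', 'Feature2'])
print(rules)
```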

6️⃣ Evaluate the Model

y_pred = model.predict(X_test)
# With only 2 test samples, this accuracy can only be 0.0, 0.5, or 1.0
print("Accuracy:", accuracy_score(y_test, y_pred))
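A 2-sample test set gives a very coarse estimate. A common alternative on small datasets is k-fold cross-validation. It trains and scores the model several times on different splits and averages the results. The sketch below uses 3 stratified folds on the same toy data. The fold count is an arbitrary choice here: with only 3 samples of class 0, 3 folds is the most that keeps that class in every fold.

```python
import pandas as pd
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Toy data from step 2
df = pd.DataFrame({
    'Feature1': [2, 3, 10, 19, 25, 30, 40, 50],
    'Feature2': [5, 7, 15, 20, 28, 35, 40, 55],
    'Label': [0, 0, 0, 1, 1, 1, 1, 1],
})
X = df[['Feature1', 'Feature2']]
y = df['Label']

# Stratified folds preserve the class ratio in each fold; shuffle with a fixed seed
cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
scores = cross_val_score(
    DecisionTreeClassifier(max_depth=3, random_state=42), X, y, cv=cv, scoring='accuracy'
)
print("Fold accuracies:", scores)
print("Mean accuracy:", scores.mean())
```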

Conclusion

🎉 You have successfully:

✅ Understood how decision trees work.
✅ Implemented a decision tree classifier using scikit-learn.
✅ Visualized your decision tree for interpretability.
✅ Evaluated your model’s accuracy.


What’s Next?

  • Try using a real-world dataset like the Iris dataset to practice.
  • Explore hyperparameter tuning (max_depth, min_samples_split) to prevent overfitting.
  • Learn about ensemble methods like Random Forest and Gradient Boosting for improved performance.
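As a starting point for the first two suggestions, the sketch below tunes `max_depth` and `min_samples_split` on the Iris dataset (bundled with scikit-learn) using `GridSearchCV` with 5-fold cross-validation. The grid values are arbitrary examples, not recommended settings.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42
)

# Illustrative grid: every combination is scored with 5-fold cross-validation
param_grid = {
    'max_depth': [2, 3, 4, 5, None],
    'min_samples_split': [2, 5, 10],
}
search = GridSearchCV(DecisionTreeClassifier(random_state=42), param_grid, cv=5, scoring='accuracy')
search.fit(X_train, y_train)  # refits the best combination on all training data

print("Best parameters:", search.best_params_)
print("Mean CV accuracy:", round(search.best_score_, 3))
print("Test accuracy:", search.score(X_test, y_test))
```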

For questions or to share your results, join our SuperML Community and continue your learning journey!
