📖 Lesson ⏱️ 90 minutes

Decision Trees

Tree-based algorithms for classification and regression

Introduction

Decision Trees are a fundamental algorithm in machine learning used for both classification and regression tasks. They work by splitting the data into subsets based on feature values, creating a tree-like structure that is easy to interpret and visualize.
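The rest of this lesson focuses on classification. For regression, here is a minimal sketch using scikit-learn's `DecisionTreeRegressor`. The six data points are made up for illustration. Each leaf predicts the mean target value of the training samples that reach it.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Hypothetical 1-D data: two clusters of target values (illustrative only)
X = np.array([[1], [2], [3], [4], [5], [6]])
y = np.array([1.0, 1.2, 0.9, 5.0, 5.2, 4.8])

# A depth-1 tree makes a single split and predicts the mean target on each side
reg = DecisionTreeRegressor(max_depth=1, random_state=0)
reg.fit(X, y)

preds = reg.predict(np.array([[2], [5]]))
print(preds)  # approximately 1.033 (mean of 1.0, 1.2, 0.9) and 5.0
```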


Why Use Decision Trees?

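✅ Easy to understand and interpret: each prediction follows a readable path of if/else rules.
✅ Can handle both numerical and categorical data in principle (in the scikit-learn workflow used below, categorical features are typically encoded as numbers first).
✅ Require little data preprocessing; for example, feature scaling is unnecessary because splits depend only on the ordering of values.
✅ Can capture non-linear relationships and feature interactions.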

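However, decision trees are prone to overfitting: an unrestricted tree can keep splitting until it memorizes the training data, noise included. This can be mitigated by limiting tree growth (e.g., max_depth), by pruning, or by ensemble methods like Random Forest.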
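The sketch below illustrates this overfitting risk on synthetic data. The dataset and the specific `max_depth` and `ccp_alpha` values are arbitrary choices for illustration. An unrestricted tree typically reaches 100% training accuracy. A depth limit or cost-complexity pruning (`ccp_alpha`) yields a smaller tree. Exact test scores depend on the data, so compare the printed numbers rather than expecting particular values.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic data; flip_y=0.1 assigns about 10% of labels at random, adding noise to memorize
X, y = make_classification(n_samples=500, n_features=10, flip_y=0.1, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

candidates = {
    "unrestricted": DecisionTreeClassifier(random_state=0),
    "max_depth=4": DecisionTreeClassifier(max_depth=4, random_state=0),
    # ccp_alpha > 0 applies minimal cost-complexity pruning to the grown tree
    "ccp_alpha=0.01": DecisionTreeClassifier(ccp_alpha=0.01, random_state=0),
}

results = {}
for name, clf in candidates.items():
    clf.fit(X_train, y_train)
    results[name] = {
        "train": clf.score(X_train, y_train),
        "test": clf.score(X_test, y_test),
        "leaves": clf.get_n_leaves(),
    }
    r = results[name]
    print(f"{name:>15}: train={r['train']:.2f} test={r['test']:.2f} leaves={r['leaves']}")
```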


How Do Decision Trees Work?

At each node in the tree:

  • The algorithm selects the feature and threshold that best split the data, i.e., that most reduce impurity (measured, for example, by Gini impurity or entropy for classification, or by variance for regression).
  • The dataset is split into subsets recursively until a stopping criterion is met (e.g., max depth, min samples per leaf).
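To make the split criterion concrete, here is a from-scratch sketch in plain Python. It does not reproduce scikit-learn's internals. It scores every candidate threshold on one feature by the weighted Gini impurity of the two resulting children and keeps the lowest. The names `gini`, `entropy`, and `best_threshold` are illustrative helpers. Placing candidate thresholds midway between adjacent sorted values is an assumption of this sketch.

```python
from collections import Counter
from math import log2

def gini(labels):
    """Gini impurity: 1 - sum_k p_k^2 over class proportions p_k."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def entropy(labels):
    """Shannon entropy in bits: sum_k p_k * log2(1 / p_k)."""
    n = len(labels)
    return sum((c / n) * log2(n / c) for c in Counter(labels).values())

def best_threshold(values, labels, impurity=gini):
    """Try midpoints between adjacent sorted values; return (threshold, weighted impurity)."""
    pairs = sorted(zip(values, labels))
    n = len(pairs)
    best_thr, best_score = None, float("inf")
    for (v_prev, _), (v_next, _) in zip(pairs, pairs[1:]):
        if v_prev == v_next:
            continue  # equal values cannot be separated by a threshold
        thr = (v_prev + v_next) / 2
        left = [lab for v, lab in pairs if v <= thr]
        right = [lab for v, lab in pairs if v > thr]
        # Weight each child's impurity by the fraction of samples it receives
        score = len(left) / n * impurity(left) + len(right) / n * impurity(right)
        if score < best_score:
            best_thr, best_score = thr, score
    return best_thr, best_score

# Feature1 and Label from the toy dataset used in this lesson
feature1 = [2, 3, 10, 19, 25, 30, 40, 50]
label = [0, 0, 0, 1, 1, 1, 1, 1]
print(gini(label))                      # 0.46875 at the root
print(best_threshold(feature1, label))  # (14.5, 0.0): a perfectly pure split
```

A real tree repeats this search over every feature at every node and recurses into the two children.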

Step-by-Step Implementation

1️⃣ Import Libraries

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

2️⃣ Load Dataset

We’ll use a simple dataset for binary classification:

data = {
    'Feature1': [2, 3, 10, 19, 25, 30, 40, 50],
    'Feature2': [5, 7, 15, 20, 28, 35, 40, 55],
    'Label': [0, 0, 0, 1, 1, 1, 1, 1]
}
df = pd.DataFrame(data)
print(df.head())

3️⃣ Prepare Data

X = df[['Feature1', 'Feature2']]
y = df['Label']

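# With 8 rows and test_size=0.25, the test set has only 2 rows;
# stratify=y keeps the class proportions similar in both splits
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, stratify=y, random_state=42)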

4️⃣ Train the Decision Tree Model

model = DecisionTreeClassifier(max_depth=3, random_state=42)
model.fit(X_train, y_train)
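Once fitted, the model can classify unseen rows. In this small sketch the tree is fitted on all eight toy rows for simplicity, rather than on the training split above. The two new points are made up for illustration.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

data = {
    'Feature1': [2, 3, 10, 19, 25, 30, 40, 50],
    'Feature2': [5, 7, 15, 20, 28, 35, 40, 55],
    'Label': [0, 0, 0, 1, 1, 1, 1, 1]
}
df = pd.DataFrame(data)
X, y = df[['Feature1', 'Feature2']], df['Label']

# fit() returns the fitted estimator, so construction and fitting can be chained
model = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)

# Hypothetical new observations, using the same column names as the training data
new_points = pd.DataFrame({'Feature1': [5, 45], 'Feature2': [6, 50]})
print(model.predict(new_points))        # [0 1]
print(model.predict_proba(new_points))  # one column per class in model.classes_
print(model.get_depth(), model.get_n_leaves())  # 1 2: one split already separates the classes
```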

5️⃣ Visualize the Decision Tree

plt.figure(figsize=(12,8))
plot_tree(model, feature_names=['Feature1', 'Feature2'], class_names=['0', '1'], filled=True)
plt.show()
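As a text-only complement to the plot, scikit-learn's `export_text` prints the learned rules. The fitted tree's `feature_importances_` attribute shows how much each feature reduced impurity. This sketch refits on the full toy dataset so it runs on its own.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text

data = {
    'Feature1': [2, 3, 10, 19, 25, 30, 40, 50],
    'Feature2': [5, 7, 15, 20, 28, 35, 40, 55],
    'Label': [0, 0, 0, 1, 1, 1, 1, 1]
}
df = pd.DataFrame(data)
X, y = df[['Feature1', 'Feature2']], df['Label']
model = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)

# Text rendering of the learned rules, handy in terminals or logs without matplotlib
rules = export_text(model, feature_names=['Feature1', 'Feature2'])
print(rules)

# Impurity-based importance of each feature, normalized to sum to 1
for name, importance in zip(X.columns, model.feature_importances_):
    print(f"{name}: {importance:.2f}")
```

Both features separate the toy classes perfectly, so which one appears in the single split depends on scikit-learn's internal tie-breaking.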

6️⃣ Evaluate the Model

y_pred = model.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))
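With only two rows in the test set, a single accuracy number is very noisy: each misclassified row costs 50 percentage points. A common alternative is k-fold cross-validation, which scores the model on several different splits and reports each score. A minimal sketch on the same toy data:

```python
import pandas as pd
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

data = {
    'Feature1': [2, 3, 10, 19, 25, 30, 40, 50],
    'Feature2': [5, 7, 15, 20, 28, 35, 40, 55],
    'Label': [0, 0, 0, 1, 1, 1, 1, 1]
}
df = pd.DataFrame(data)
X, y = df[['Feature1', 'Feature2']], df['Label']

# For classifiers, an integer cv uses stratified folds; class 0 has only 3 rows,
# so cv=3 is the largest fold count that puts that class in every fold
scores = cross_val_score(DecisionTreeClassifier(max_depth=3, random_state=42), X, y, cv=3)
print("Fold accuracies:", scores)
print("Mean accuracy:", scores.mean())
```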

Conclusion

🎉 You have successfully:

✅ Understood how decision trees work.
✅ Implemented a decision tree classifier using scikit-learn.
✅ Visualized your decision tree for interpretability.
✅ Evaluated your model’s accuracy.


What’s Next?

  • Try using a real-world dataset like the Iris dataset to practice.
  • Explore hyperparameter tuning (max_depth, min_samples_split) to reduce overfitting.
  • Learn about ensemble methods like Random Forest and Gradient Boosting for improved performance.
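As a starting point for these next steps, here is a sketch that combines them on the Iris dataset. It tunes `max_depth` and `min_samples_split` with `GridSearchCV` and then compares the result against a Random Forest. The grid values, split ratio, and `n_estimators=100` are arbitrary choices, not recommendations.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42
)

# Search a small, arbitrary grid of tree hyperparameters with 5-fold cross-validation
param_grid = {"max_depth": [2, 3, 4, None], "min_samples_split": [2, 5, 10]}
search = GridSearchCV(DecisionTreeClassifier(random_state=42), param_grid, cv=5)
search.fit(X_train, y_train)
print("Best params:", search.best_params_)
print("Best CV accuracy:", round(search.best_score_, 3))

# With refit=True (the default), the best model is refit on the full training set
test_acc = search.score(X_test, y_test)
print("Tuned tree test accuracy:", round(test_acc, 3))

# For comparison: an ensemble of trees, scored with the same 5-fold CV on the training set
forest = RandomForestClassifier(n_estimators=100, random_state=42)
forest_scores = cross_val_score(forest, X_train, y_train, cv=5)
print("Random Forest CV accuracy:", round(forest_scores.mean(), 3))
```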

For questions or to share your results, join our SuperML Community and continue your learning journey!