📖 Lesson ⏱️ 120 minutes

Decision Trees and Random Forest

Tree-based algorithms using SuperML Java

Decision Trees and Random Forest with SuperML Java

Tree-based algorithms are among the most popular and powerful machine learning techniques. They capture non-linear relationships and feature interactions without requiring feature scaling, and they can work with both numerical and categorical features. A single tree is also easy to interpret. This tutorial covers decision trees and random forests using SuperML Java.

Understanding Decision Trees

Decision trees work by recursively splitting the data based on feature values to create a tree-like model of decisions. Each internal node represents a test on a feature, each branch represents the outcome of the test, and each leaf node represents a class label or numerical value.
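Here is a minimal sketch of how a classification tree can choose a threshold for one numeric feature. It sorts the samples and tries the midpoint between each pair of consecutive distinct values. It then keeps the threshold with the lowest sample-weighted Gini impurity across the two children. `GiniSplit` and its methods are hypothetical helpers written for illustration, not part of SuperML. A real tree repeats this search over every candidate feature at every node.

```java
import java.util.Arrays;

/** Hypothetical helper (not SuperML API): best threshold on one numeric feature by Gini impurity. */
public class GiniSplit {

    /** Gini impurity of a node from its per-class counts: 1 - sum of squared class proportions. */
    static double gini(int[] counts) {
        int total = Arrays.stream(counts).sum();
        if (total == 0) return 0.0;
        double sumSq = 0.0;
        for (int c : counts) {
            double p = (double) c / total;
            sumSq += p * p;
        }
        return 1.0 - sumSq;
    }

    /**
     * Tries "x <= t goes left" for every midpoint t between consecutive distinct sorted values
     * and returns the t with the lowest weighted child impurity. Labels must be 0..numClasses-1.
     */
    static double bestThreshold(double[] x, int[] y, int numClasses) {
        Integer[] order = new Integer[x.length];
        for (int i = 0; i < x.length; i++) order[i] = i;
        Arrays.sort(order, (a, b) -> Double.compare(x[a], x[b]));

        int[] left = new int[numClasses];
        int[] right = new int[numClasses];
        for (int label : y) right[label]++;          // all samples start in the right child

        double bestImpurity = Double.POSITIVE_INFINITY;
        double bestT = Double.NaN;
        for (int k = 0; k < order.length - 1; k++) {
            int i = order[k];
            left[y[i]]++;                             // move sample i to the left child
            right[y[i]]--;
            double xi = x[i];
            double xNext = x[order[k + 1]];
            if (xi == xNext) continue;                // cannot cut between equal values
            int nLeft = k + 1;
            int nRight = x.length - nLeft;
            double weighted = (nLeft * gini(left) + nRight * gini(right)) / x.length;
            if (weighted < bestImpurity) {
                bestImpurity = weighted;
                bestT = (xi + xNext) / 2.0;
            }
        }
        return bestT;
    }

    public static void main(String[] args) {
        double[] supportCalls = {0, 1, 1, 2, 5, 6, 7, 8};
        int[] satisfied       = {1, 1, 1, 1, 0, 0, 0, 0};
        System.out.println(bestThreshold(supportCalls, satisfied, 2)); // 3.5
    }
}
```

In the toy data, every customer with two or fewer support calls is satisfied. The search therefore lands on 3.5, where both children are pure.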

Key Advantages

  • Interpretability: Easy to visualize and understand
  • No distributional assumptions: No assumptions about the data distribution, and no feature scaling needed
  • Mixed data types: Handles numerical and categorical features
  • Feature interactions: Automatically captures feature interactions
  • Non-linear patterns: Can model complex non-linear relationships

Key Disadvantages

  • Overfitting: Tendency to overfit, especially with deep trees
  • Instability: Small changes in data can lead to very different trees
  • Split-selection bias: Impurity-based splitting favors features with many distinct values (e.g., high-cardinality categoricals)

Decision Tree Classification

Let’s start with a classification example using a customer satisfaction dataset:

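import org.superml.tree_models.DecisionTree;
import org.superml.tree_models.RandomForest;
import org.superml.datasets.Datasets;
import org.superml.metrics.Metrics;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
// Dataset, DataLoader, DataSplit, DecisionTreeClassifier, Criterion, Strategy, F1Average,
// ClassificationReport and ConfusionMatrix also need imports; their packages depend on your SuperML version.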

public class DecisionTreeClassificationExample {
    
    public static void main(String[] args) {
        DecisionTreeClassificationExample example = new DecisionTreeClassificationExample();
        example.customerSatisfactionPrediction();
    }
    
    public void customerSatisfactionPrediction() {
        // Load customer satisfaction data
        Dataset data = DataLoader.fromCSV("data/customer_satisfaction.csv");
        
        System.out.println("=== Customer Satisfaction Prediction ===");
        exploreData(data);
        
        // Preprocess data
        data = preprocessData(data);
        
        // Train and evaluate decision tree
        trainDecisionTree(data);
        
        // Compare different max-depth settings
        compareTreeParameters(data);
    }
    
    private void exploreData(Dataset data) {
        System.out.println("Dataset shape: " + data.getShape());
        System.out.println("Features: " + data.getFeatureNames());
        
        // Check target distribution
        Map<Object, Integer> targetDist = data.getValueCounts("satisfaction");
        System.out.println("Satisfaction levels: " + targetDist);
        
        // Check for missing values
        System.out.println("Missing values: " + data.getMissingCounts());
    }
    
    private Dataset preprocessData(Dataset data) {
        System.out.println("\n=== Data Preprocessing ===");
        
        // Handle missing values
        data = data.fillMissing(Map.of(
            "age", Strategy.MEDIAN,
            "income", Strategy.MEDIAN,
            "support_calls", Strategy.MODE
        ));
        
        // Note: trees need neither scaling nor one-hot encoding. Whether string columns
        // can be passed directly depends on the library (later examples label-encode them).
        
        return data;
    }
    
    private void trainDecisionTree(Dataset data) {
        System.out.println("\n=== Decision Tree Training ===");
        
        // Split data
        DataSplit split = data.split(0.8, true);   // 80% train, stratified by the target
        
        // Create decision tree classifier
        DecisionTreeClassifier tree = new DecisionTreeClassifier()
            .setCriterion(Criterion.GINI)           // or ENTROPY
            .setMaxDepth(10)                        // Prevent overfitting
            .setMinSamplesSplit(20)                 // Minimum samples to split
            .setMinSamplesLeaf(10)                  // Minimum samples in leaf
            .setMaxFeatures("sqrt")                 // Feature selection strategy
            .setRandomState(42);                    // For reproducibility
        
        // Train the model
        tree.fit(split.getTrainX(), split.getTrainY());
        
        // Make predictions
        double[] predictions = tree.predict(split.getTestX());
        double[][] probabilities = tree.predictProba(split.getTestX());
        
        // Evaluate performance
        evaluateClassifier(split.getTestY(), predictions, probabilities);
        
        // Analyze the tree
        analyzeTree(tree, data.getFeatureNames());
    }
    
    private void evaluateClassifier(double[] yTrue, double[] yPred, double[][] probabilities) {
        System.out.println("\n--- Classification Performance ---");
        
        double accuracy = Metrics.accuracy(yTrue, yPred);
        double f1 = Metrics.f1Score(yTrue, yPred, F1Average.WEIGHTED);
        
        System.out.printf("Accuracy: %.3f%n", accuracy);
        System.out.printf("F1 Score: %.3f%n", f1);
        
        // Detailed classification report
        ClassificationReport report = Metrics.classificationReport(yTrue, yPred);
        System.out.println(report);
        
        // Confusion matrix
        ConfusionMatrix cm = Metrics.confusionMatrix(yTrue, yPred);
        System.out.println("Confusion Matrix:");
        System.out.println(cm);
    }
    
    private void analyzeTree(DecisionTreeClassifier tree, List<String> featureNames) {
        System.out.println("\n--- Tree Analysis ---");
        
        // Tree statistics
        System.out.println("Tree depth: " + tree.getDepth());
        System.out.println("Number of leaves: " + tree.getNumLeaves());
        System.out.println("Number of nodes: " + tree.getNumNodes());
        
        // Feature importance
        double[] importance = tree.getFeatureImportances();
        System.out.println("\nFeature Importances:");
        
        // Sort features by importance
        List<FeatureImportance> importanceList = new ArrayList<>();
        for (int i = 0; i < featureNames.size(); i++) {
            importanceList.add(new FeatureImportance(featureNames.get(i), importance[i]));
        }
        
        importanceList.sort((a, b) -> Double.compare(b.importance, a.importance));
        importanceList.forEach(fi -> 
            System.out.printf("%-20s: %.4f%n", fi.name, fi.importance));
        
        // Tree visualization (text representation)
        System.out.println("\nTree Structure (first few levels):");
        System.out.println(tree.toTextTree(3));   // print at most 3 levels
    }
    
    private void compareTreeParameters(Dataset data) {
        System.out.println("\n=== Parameter Comparison ===");
        
        DataSplit split = data.split(0.8, true);   // 80% train, stratified by the target
        
        // Test different max_depth values
        Integer[] depths = {3, 5, 10, 15, null}; // Integer (not int) so null can mean "no depth limit"
        
        System.out.printf("%-10s %-10s %-10s %-10s%n", "Max Depth", "Train Acc", "Test Acc", "Overfitting");
        
        for (Integer depth : depths) {
            DecisionTreeClassifier tree = new DecisionTreeClassifier()
                .setMaxDepth(depth)
                .setRandomState(42);
            
            tree.fit(split.getTrainX(), split.getTrainY());
            
            double trainAcc = Metrics.accuracy(split.getTrainY(), 
                tree.predict(split.getTrainX()));
            double testAcc = Metrics.accuracy(split.getTestY(), 
                tree.predict(split.getTestX()));
            
            double overfitting = trainAcc - testAcc;
            String depthStr = depth != null ? depth.toString() : "None";
            
            System.out.printf("%-10s %-10.3f %-10.3f %-10.3f%n", 
                depthStr, trainAcc, testAcc, overfitting);
        }
    }
    
    private static class FeatureImportance {
        String name;
        double importance;
        
        FeatureImportance(String name, double importance) {
            this.name = name;
            this.importance = importance;
        }
    }
}
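The `Criterion.GINI` and `Criterion.ENTROPY` options above correspond to two standard impurity measures. The formulas below assume SuperML uses the textbook definitions, which this lesson does not confirm. They give node impurity, the gain from a split, and the usual mean-decrease-in-impurity feature importance:

```latex
% Node t with class proportions p_{t,k}, k = 1..K
G(t) = 1 - \sum_{k=1}^{K} p_{t,k}^{2}
\qquad
H(t) = -\sum_{k=1}^{K} p_{t,k} \log_2 p_{t,k}

% Impurity decrease when node t (N_t samples) splits into children t_L and t_R; I is G or H
\Delta I(t) = I(t) - \frac{N_{t_L}}{N_t}\, I(t_L) - \frac{N_{t_R}}{N_t}\, I(t_R)

% Mean-decrease-in-impurity importance of feature j (N = number of training samples)
\mathrm{Imp}(j) \propto \sum_{t \,:\, t \text{ splits on } j} \frac{N_t}{N}\, \Delta I(t)
```

Two reference points help when reading these values. A 50/50 two-class node has Gini impurity 0.5 and entropy 1 bit, and both measures are 0 for a pure node. Importances are typically normalized so they sum to 1 across features.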

Decision Tree Regression

Decision trees can also be used for regression problems:

import java.util.List;
import java.util.stream.IntStream;

public class DecisionTreeRegressionExample {
    
    public void housePricePrediction() {
        // Load housing data
        Dataset data = DataLoader.fromCSV("data/house_prices.csv");
        
        System.out.println("=== House Price Prediction with Decision Tree ===");
        
        // Preprocess data
        data = preprocessRegressionData(data);
        
        // Train regression tree
        trainRegressionTree(data);
        
        // Compare with linear regression baseline
        compareWithLinearRegression(data);
    }
    
    private Dataset preprocessRegressionData(Dataset data) {
        // Handle missing values
        data = data.fillMissing(Strategy.MEDIAN);
        
        // Create a categorical feature for the tree to work with.
        // Don't derive features from the target ("price"): a price bucket would leak the answer into the inputs.
        
        data = data.withNewColumn("size_category", row -> {
            double size = row.getDouble("size");
            if (size < 1200) return "small";
            else if (size < 2000) return "medium";
            else return "large";
        });
        
        return data;
    }
    
    private void trainRegressionTree(Dataset data) {
        System.out.println("\n=== Regression Tree Training ===");
        
        DataSplit split = data.split(0.8);
        
        // Create regression tree
        DecisionTreeRegressor tree = new DecisionTreeRegressor()
            .setCriterion(Criterion.MSE)            // or MAE
            .setMaxDepth(10)
            .setMinSamplesSplit(20)
            .setMinSamplesLeaf(10)
            .setRandomState(42);
        
        tree.fit(split.getTrainX(), split.getTrainY());
        
        // Make predictions
        double[] predictions = tree.predict(split.getTestX());
        
        // Evaluate performance
        evaluateRegressor(split.getTestY(), predictions);
        
        // Analyze tree
        analyzeRegressionTree(tree, data.getFeatureNames());
    }
    
    private void evaluateRegressor(double[] yTrue, double[] yPred) {
        System.out.println("\n--- Regression Performance ---");
        
        double rmse = Metrics.rmse(yTrue, yPred);
        double mae = Metrics.mae(yTrue, yPred);
        double r2 = Metrics.r2Score(yTrue, yPred);
        
        System.out.printf("RMSE: %.2f%n", rmse);
        System.out.printf("MAE:  %.2f%n", mae);
        System.out.printf("R²:   %.3f%n", r2);
    }
    
    private void analyzeRegressionTree(DecisionTreeRegressor tree, List<String> featureNames) {
        System.out.println("\n--- Regression Tree Analysis ---");
        
        System.out.println("Tree depth: " + tree.getDepth());
        System.out.println("Number of leaves: " + tree.getNumLeaves());
        
        // Feature importance
        double[] importance = tree.getFeatureImportances();
        System.out.println("\nFeature Importances:");
        
        IntStream.range(0, featureNames.size())
            .boxed()
            .sorted((i, j) -> Double.compare(importance[j], importance[i]))
            .limit(10)
            .forEach(i -> System.out.printf("%-20s: %.4f%n", 
                featureNames.get(i), importance[i]));
    }
    
    private void compareWithLinearRegression(Dataset data) {
        System.out.println("\n=== Comparison with Linear Regression ===");
        
        DataSplit split = data.split(0.8);
        
        // Decision Tree
        DecisionTreeRegressor tree = new DecisionTreeRegressor()
            .setMaxDepth(10)
            .setRandomState(42);
        tree.fit(split.getTrainX(), split.getTrainY());
        double[] treePreds = tree.predict(split.getTestX());
        
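        // Linear Regression (needs numeric, scaled inputs)
        // Split before scaling and scale only the features. Fitting the scaler on all rows leaks test
        // statistics, and scaling the target would put this RMSE in different units from the tree's.
        // (Assumes StandardScaler offers fitTransform/transform on feature matrices.)
        DataSplit numericSplit = data.selectNumericColumns().split(0.8);
        StandardScaler scaler = new StandardScaler();
        double[][] trainXScaled = scaler.fitTransform(numericSplit.getTrainX());
        double[][] testXScaled = scaler.transform(numericSplit.getTestX());
        
        LinearRegression lr = new LinearRegression();
        lr.fit(trainXScaled, numericSplit.getTrainY());
        double[] lrPreds = lr.predict(testXScaled);
        
        // Compare performance (note: the two models are evaluated on separate random splits)
        double treeRmse = Metrics.rmse(split.getTestY(), treePreds);
        double lrRmse = Metrics.rmse(numericSplit.getTestY(), lrPreds);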
        
        System.out.printf("Decision Tree RMSE: %.2f%n", treeRmse);
        System.out.printf("Linear Regression RMSE: %.2f%n", lrRmse);
        
        if (treeRmse < lrRmse) {
            System.out.println("Decision Tree performs better!");
        } else {
            System.out.println("Linear Regression performs better!");
        }
    }
}
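Regression trees replace class impurity with squared error. Each leaf predicts the mean target of its samples, and the MSE criterion picks the threshold that minimizes the summed squared error of the two children. The sketch below shows this for one pre-sorted feature. `MseSplit` is a hypothetical illustration, not SuperML code. It recomputes sums naively (quadratic time), whereas real implementations keep running sums.

```java
/** Hypothetical helper (not SuperML API): best regression-tree threshold under the MSE criterion. */
public class MseSplit {

    /** Sum of squared deviations from the mean over values[from..to). */
    static double sse(double[] values, int from, int to) {
        int n = to - from;
        if (n == 0) return 0.0;
        double mean = 0.0;
        for (int i = from; i < to; i++) mean += values[i];
        mean /= n;
        double s = 0.0;
        for (int i = from; i < to; i++) s += (values[i] - mean) * (values[i] - mean);
        return s;
    }

    /** x must be sorted ascending, with y holding the matching targets; returns the best midpoint. */
    static double bestThreshold(double[] x, double[] y) {
        double best = Double.POSITIVE_INFINITY;
        double bestT = Double.NaN;
        for (int k = 1; k < x.length; k++) {
            if (x[k - 1] == x[k]) continue;                       // cannot cut between equal values
            double total = sse(y, 0, k) + sse(y, k, y.length);    // left child = first k samples
            if (total < best) {
                best = total;
                bestT = (x[k - 1] + x[k]) / 2.0;
            }
        }
        return bestT;
    }

    public static void main(String[] args) {
        double[] size  = {800, 1000, 1100, 2500, 2700, 3000};     // sorted square footage
        double[] price = {150_000, 160_000, 155_000, 420_000, 450_000, 430_000};
        System.out.println(bestThreshold(size, price)); // 1800.0
    }
}
```

Here the cheap and expensive houses separate cleanly. The best threshold therefore falls at 1800 sq ft, midway between 1100 and 2500.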

Random Forest

Random Forest is an ensemble method that combines multiple decision trees to create a more robust and accurate model:

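import java.util.Arrays;
import java.util.HashMap;
import java.util.IntSummaryStatistics;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.IntStream;

public class RandomForestExample {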
    
    public static void main(String[] args) {
        RandomForestExample example = new RandomForestExample();
        example.comprehensiveRandomForestExample();
    }
    
    public void comprehensiveRandomForestExample() {
        // Load a complex dataset
        Dataset data = DataLoader.fromCSV("data/complex_classification.csv");
        
        System.out.println("=== Random Forest Comprehensive Example ===");
        
        // Preprocess data
        data = preprocessComplexData(data);
        
        // Train Random Forest
        trainRandomForest(data);
        
        // Compare individual trees vs forest
        compareTreeVsForest(data);
        
        // Hyperparameter tuning
        tuneRandomForest(data);
    }
    
    private Dataset preprocessComplexData(Dataset data) {
        System.out.println("Preprocessing complex dataset...");
        
        // Handle missing values
        data = data.fillMissing(Map.of(
            "numeric_features", Strategy.MEDIAN,
            "categorical_features", Strategy.MODE
        ));
        
        // Encode categorical variables
        LabelEncoder encoder = new LabelEncoder();
        String[] categoricalCols = data.getCategoricalColumns().toArray(new String[0]);
        data = encoder.fitTransform(data, categoricalCols);
        
        return data;
    }
    
    private void trainRandomForest(Dataset data) {
        System.out.println("\n=== Random Forest Training ===");
        
        DataSplit split = data.split(0.8, true);   // 80% train, stratified by the target
        
        // Create Random Forest classifier
        RandomForestClassifier rf = new RandomForestClassifier()
            .setNumTrees(100)                       // Number of trees
            .setMaxDepth(10)                        // Maximum depth of trees
            .setMinSamplesSplit(5)                  // Minimum samples to split
            .setMinSamplesLeaf(2)                   // Minimum samples in leaf
            .setMaxFeatures("sqrt")                 // Features considered per split (standard RF)
            .setBootstrap(true)                     // Bootstrap sampling
            .setOobScore(true)                      // Out-of-bag scoring
            .setRandomState(42)                     // For reproducibility
            .setNJobs(-1);                          // Use all processors
        
        // Train the model
        rf.fit(split.getTrainX(), split.getTrainY());
        
        // Make predictions
        double[] predictions = rf.predict(split.getTestX());
        double[][] probabilities = rf.predictProba(split.getTestX());
        
        // Evaluate performance
        evaluateRandomForest(rf, split, predictions, probabilities);
        
        // Analyze the forest
        analyzeRandomForest(rf, data.getFeatureNames());
    }
    
    private void evaluateRandomForest(RandomForestClassifier rf, DataSplit split, 
                                    double[] predictions, double[][] probabilities) {
        System.out.println("\n--- Random Forest Performance ---");
        
        // Basic metrics
        double accuracy = Metrics.accuracy(split.getTestY(), predictions);
        double f1 = Metrics.f1Score(split.getTestY(), predictions, F1Average.WEIGHTED);
        
        System.out.printf("Test Accuracy: %.3f%n", accuracy);
        System.out.printf("Test F1 Score: %.3f%n", f1);
        
        // Out-of-bag score
        double oobScore = rf.getOobScore();
        System.out.printf("OOB Score: %.3f%n", oobScore);
        
        // Training accuracy (to check overfitting)
        double[] trainPreds = rf.predict(split.getTrainX());
        double trainAccuracy = Metrics.accuracy(split.getTrainY(), trainPreds);
        System.out.printf("Train Accuracy: %.3f%n", trainAccuracy);
        
        double overfitting = trainAccuracy - accuracy;
        System.out.printf("Overfitting: %.3f%n", overfitting);
        
        if (overfitting > 0.1) {
            System.out.println("Warning: Significant overfitting detected!");
        }
    }
    
    private void analyzeRandomForest(RandomForestClassifier rf, List<String> featureNames) {
        System.out.println("\n--- Random Forest Analysis ---");
        
        // Forest statistics
        System.out.println("Number of trees: " + rf.getNumTrees());
        
        // Feature importance
        double[] importance = rf.getFeatureImportances();
        System.out.println("\nFeature Importances:");
        
        // Sort and display top features
        IntStream.range(0, featureNames.size())
            .boxed()
            .sorted((i, j) -> Double.compare(importance[j], importance[i]))
            .limit(15)
            .forEach(i -> System.out.printf("%-25s: %.4f%n", 
                featureNames.get(i), importance[i]));
        
        // Tree diversity analysis
        analyzeTreeDiversity(rf);
    }
    
    private void analyzeTreeDiversity(RandomForestClassifier rf) {
        System.out.println("\n--- Tree Diversity Analysis ---");
        
        List<DecisionTreeClassifier> trees = rf.getTrees();
        
        // Calculate average tree depth
        double avgDepth = trees.stream()
            .mapToInt(DecisionTreeClassifier::getDepth)
            .average()
            .orElse(0.0);
        
        // Calculate tree depth statistics
        IntSummaryStatistics depthStats = trees.stream()
            .mapToInt(DecisionTreeClassifier::getDepth)
            .summaryStatistics();
        
        System.out.printf("Average tree depth: %.1f%n", avgDepth);
        System.out.printf("Min tree depth: %d%n", depthStats.getMin());
        System.out.printf("Max tree depth: %d%n", depthStats.getMax());
        
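        // Feature usage diversity
        Map<Integer, Integer> featureUsage = new HashMap<>();   // feature index -> number of trees splitting on it
        int totalUsed = 0;
        for (DecisionTreeClassifier tree : trees) {
            Set<Integer> usedFeatures = tree.getUsedFeatures();
            totalUsed += usedFeatures.size();
            for (Integer feature : usedFeatures) {
                featureUsage.merge(feature, 1, Integer::sum);
            }
        }
        
        // Average over trees (the map values count trees per feature, not features per tree)
        System.out.printf("Average features used per tree: %.1f%n", 
            trees.isEmpty() ? 0.0 : (double) totalUsed / trees.size());
        System.out.printf("Distinct features used by the forest: %d%n", featureUsage.size());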
    }
    
    private void compareTreeVsForest(Dataset data) {
        System.out.println("\n=== Single Tree vs Random Forest Comparison ===");
        
        DataSplit split = data.split(0.8, true);   // 80% train, stratified by the target
        
        // Single Decision Tree
        DecisionTreeClassifier singleTree = new DecisionTreeClassifier()
            .setMaxDepth(10)
            .setRandomState(42);
        
        singleTree.fit(split.getTrainX(), split.getTrainY());
        double[] singleTreePreds = singleTree.predict(split.getTestX());
        
        // Random Forest
        RandomForestClassifier forest = new RandomForestClassifier()
            .setNumTrees(100)
            .setMaxDepth(10)
            .setRandomState(42);
        
        forest.fit(split.getTrainX(), split.getTrainY());
        double[] forestPreds = forest.predict(split.getTestX());
        
        // Compare performance
        double singleTreeAcc = Metrics.accuracy(split.getTestY(), singleTreePreds);
        double forestAcc = Metrics.accuracy(split.getTestY(), forestPreds);
        
        System.out.printf("Single Tree Accuracy: %.3f%n", singleTreeAcc);
        System.out.printf("Random Forest Accuracy: %.3f%n", forestAcc);
        System.out.printf("Improvement: %.3f%n", forestAcc - singleTreeAcc);
        
        // Cross-validation comparison
        compareWithCrossValidation(data);
    }
    
    private void compareWithCrossValidation(Dataset data) {
        System.out.println("\n--- Cross-Validation Comparison ---");
        
        // Single Tree CV
        DecisionTreeClassifier tree = new DecisionTreeClassifier().setMaxDepth(10);
        CrossValidationResult treeCv = CrossValidator.validate(
            tree, data.getFeatures(), data.getTargets(), 5, Scoring.ACCURACY);
        
        // Random Forest CV
        RandomForestClassifier forest = new RandomForestClassifier()
            .setNumTrees(50)  // Smaller for faster CV
            .setMaxDepth(10);
        CrossValidationResult forestCv = CrossValidator.validate(
            forest, data.getFeatures(), data.getTargets(), 5, Scoring.ACCURACY);
        
        System.out.printf("Single Tree CV: %.3f (±%.3f)%n", 
            treeCv.getMean(), treeCv.getStd());
        System.out.printf("Random Forest CV: %.3f (±%.3f)%n", 
            forestCv.getMean(), forestCv.getStd());
    }
    
    private void tuneRandomForest(Dataset data) {
        System.out.println("\n=== Random Forest Hyperparameter Tuning ===");
        
        // Grid search: 3 × 4 × 3 × 3 = 108 combinations, each scored with 5-fold CV (540 fits)
        GridSearchCV gridSearch = new GridSearchCV()
            .setEstimator(new RandomForestClassifier())
            .setParamGrid(Map.of(
                "n_estimators", Arrays.asList(50, 100, 200),
                "max_depth", Arrays.asList(5, 10, 15, null),
                "min_samples_split", Arrays.asList(2, 5, 10),
                "max_features", Arrays.asList("sqrt", "log2", null)
            ))
            .setScoring(Scoring.F1_WEIGHTED)
            .setCv(5)
            .setVerbose(true)
            .setNJobs(-1);
        
        // Hold out a test set first. Tuning on all rows and then testing on a subset
        // of those same rows would give an optimistic accuracy estimate.
        DataSplit split = data.split(0.8, true);   // stratified 80/20 split
        gridSearch.fit(split.getTrainX(), split.getTrainY());
        
        System.out.println("Best parameters: " + gridSearch.getBestParams());
        System.out.printf("Best CV score: %.3f%n", gridSearch.getBestScore());
        
        // Evaluate the best model on the held-out test set
        RandomForestClassifier bestModel = gridSearch.getBestEstimator();
        double[] predictions = bestModel.predict(split.getTestX());
        double testAccuracy = Metrics.accuracy(split.getTestY(), predictions);
        
        System.out.printf("Best model test accuracy: %.3f%n", testAccuracy);
    }
}
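Two mechanisms make the forest above more stable than a single tree. First, each tree is trained on a bootstrap sample (rows drawn with replacement). Second, predictions are combined by majority vote. Rows a tree never saw are its out-of-bag (OOB) rows, which `setOobScore(true)` evaluates on. With n rows, the chance that a given row is left out of one bootstrap sample is (1 − 1/n)^n, which approaches 1/e ≈ 0.368 as n grows. The `BaggingSketch` class below is a hypothetical illustration of these pieces (plus the `"sqrt"` feature-count rule), not SuperML's internals.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;

/** Hypothetical sketch of bagging: bootstrap samples, out-of-bag rows, and majority voting. */
public class BaggingSketch {

    /** Draws n row indices with replacement (one bootstrap sample). */
    static int[] bootstrap(int n, Random rng) {
        int[] idx = new int[n];
        for (int i = 0; i < n; i++) idx[i] = rng.nextInt(n);
        return idx;
    }

    /** Rows never drawn into the sample: the tree's out-of-bag rows. */
    static Set<Integer> outOfBag(int n, int[] sample) {
        Set<Integer> oob = new TreeSet<>();
        for (int i = 0; i < n; i++) oob.add(i);
        for (int i : sample) oob.remove(i);
        return oob;
    }

    /** Majority vote over the trees' predicted labels; ties go to the smallest label. */
    static int majorityVote(int[] treePredictions) {
        Map<Integer, Integer> votes = new TreeMap<>();      // sorted by label
        for (int p : treePredictions) votes.merge(p, 1, Integer::sum);
        int best = -1;
        int bestCount = -1;
        for (Map.Entry<Integer, Integer> e : votes.entrySet()) {
            if (e.getValue() > bestCount) {                 // strict '>' keeps the smaller label on ties
                best = e.getKey();
                bestCount = e.getValue();
            }
        }
        return best;
    }

    /** Number of features considered at each split under the "sqrt" rule. */
    static int sqrtFeatures(int numFeatures) {
        return Math.max(1, (int) Math.sqrt(numFeatures));
    }

    public static void main(String[] args) {
        Random rng = new Random(42);
        int n = 10;
        int[] sample = bootstrap(n, rng);
        System.out.println("Bootstrap sample: " + Arrays.toString(sample));
        System.out.println("OOB rows: " + outOfBag(n, sample));
        System.out.println("Vote of {1, 0, 1}: " + majorityVote(new int[]{1, 0, 1}));  // 1
        System.out.println("sqrt rule, 20 features: " + sqrtFeatures(20));           // 4
    }
}
```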

Random Forest for Regression

Random Forest also handles regression. The setup mirrors the classifier, with regression metrics in place of accuracy:

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class RandomForestRegressionExample {
    
    public void advancedHousePricePrediction() {
        // Load comprehensive housing dataset
        Dataset data = DataLoader.fromCSV("data/house_prices_full.csv");
        
        System.out.println("=== Advanced House Price Prediction ===");
        
        // Feature engineering for better tree performance
        data = engineerFeaturesForTrees(data);
        
        // Train Random Forest Regressor
        trainRandomForestRegressor(data);
        
        // Analyze feature interactions
        analyzeFeatureInteractions(data);
    }
    
    private Dataset engineerFeaturesForTrees(Dataset data) {
        System.out.println("Engineering features for tree algorithms...");
        
        // Trees work well with categorical features and interactions.
        // Avoid features computed from the target, such as price / sqft: they leak "price" into the inputs.
        
        data = data.withNewColumn("total_rooms", 
            row -> row.getDouble("bedrooms") + row.getDouble("bathrooms"));
        
        data = data.withNewColumn("age_group", row -> {
            double age = row.getDouble("age");
            if (age < 5) return "new";
            else if (age < 15) return "recent";
            else if (age < 30) return "established";
            else return "old";
        });
        
        data = data.withNewColumn("size_category", row -> {
            double sqft = row.getDouble("sqft");
            if (sqft < 1200) return "small";
            else if (sqft < 2000) return "medium";
            else if (sqft < 3000) return "large";
            else return "mansion";
        });
        
        // Encode categorical variables for the tree
        LabelEncoder encoder = new LabelEncoder();
        data = encoder.fitTransform(data, "neighborhood", "age_group", "size_category");
        
        return data;
    }
    
    private void trainRandomForestRegressor(Dataset data) {
        System.out.println("\n=== Random Forest Regression Training ===");
        
        DataSplit split = data.split(0.8);
        
        // Create Random Forest Regressor
        RandomForestRegressor rf = new RandomForestRegressor()
            .setNumTrees(200)
            .setMaxDepth(15)
            .setMinSamplesSplit(5)
            .setMinSamplesLeaf(3)
            .setMaxFeatures("sqrt")
            .setBootstrap(true)
            .setOobScore(true)
            .setRandomState(42)
            .setNJobs(-1);
        
        rf.fit(split.getTrainX(), split.getTrainY());
        
        // Make predictions
        double[] predictions = rf.predict(split.getTestX());
        
        // Evaluate performance
        evaluateRegressionForest(rf, split, predictions);
        
        // Analyze the forest
        analyzeRegressionForest(rf, data.getFeatureNames());
    }
    
    private void evaluateRegressionForest(RandomForestRegressor rf, DataSplit split, 
                                        double[] predictions) {
        System.out.println("\n--- Regression Forest Performance ---");
        
        // Regression metrics
        double rmse = Metrics.rmse(split.getTestY(), predictions);
        double mae = Metrics.mae(split.getTestY(), predictions);
        double r2 = Metrics.r2Score(split.getTestY(), predictions);
        double mape = Metrics.meanAbsolutePercentageError(split.getTestY(), predictions);
        
        System.out.printf("Test RMSE: %.2f%n", rmse);
        System.out.printf("Test MAE:  %.2f%n", mae);
        System.out.printf("Test R²:   %.3f%n", r2);
        System.out.printf("Test MAPE: %.2f%%%n", mape);
        
        // Out-of-bag score
        double oobScore = rf.getOobScore();
        System.out.printf("OOB R²:    %.3f%n", oobScore);
        
        // Training performance
        double[] trainPreds = rf.predict(split.getTrainX());
        double trainR2 = Metrics.r2Score(split.getTrainY(), trainPreds);
        System.out.printf("Train R²:  %.3f%n", trainR2);
        
        // Overfitting check
        double overfitting = trainR2 - r2;
        System.out.printf("Overfitting: %.3f%n", overfitting);
    }
    
    private void analyzeRegressionForest(RandomForestRegressor rf, List<String> featureNames) {
        System.out.println("\n--- Regression Forest Analysis ---");
        
        // Feature importance
        double[] importance = rf.getFeatureImportances();
        System.out.println("Feature Importances:");
        
        IntStream.range(0, featureNames.size())
            .boxed()
            .sorted((i, j) -> Double.compare(importance[j], importance[i]))
            .limit(10)
            .forEach(i -> System.out.printf("%-25s: %.4f%n", 
                featureNames.get(i), importance[i]));
        
        // Partial dependence analysis for top features
        analyzePartialDependence(rf, featureNames, importance);
    }
    
    private void analyzePartialDependence(RandomForestRegressor rf, List<String> featureNames, 
                                        double[] importance) {
        System.out.println("\n--- Partial Dependence Analysis ---");
        
        // Find top 3 most important features
        List<Integer> topFeatures = IntStream.range(0, importance.length)
            .boxed()
            .sorted((i, j) -> Double.compare(importance[j], importance[i]))
            .limit(3)
            .collect(Collectors.toList());
        
        for (Integer featureIdx : topFeatures) {
            String featureName = featureNames.get(featureIdx);
            System.out.printf("\nPartial dependence for '%s':%n", featureName);
            
            // This would typically involve calculating partial dependence plots
            // For demonstration, we'll show the concept
            System.out.printf("  Feature importance: %.4f%n", importance[featureIdx]);
            System.out.println("  [Partial dependence plot would be generated here]");
        }
    }
    
    private void analyzeFeatureInteractions(Dataset data) {
        System.out.println("\n=== Feature Interaction Analysis ===");
        
        // Random Forest naturally captures feature interactions
        // Let's create some explicit interaction features and compare
        
        Dataset originalData = data.copy();
        
        // Add interaction features
        data = data.withNewColumn("sqft_bedrooms_interaction",
            row -> row.getDouble("sqft") * row.getDouble("bedrooms"));
        
        // Note: "neighborhood" is label-encoded, so this product depends on the arbitrary code order
        data = data.withNewColumn("age_neighborhood_interaction",
            row -> row.getDouble("age") * row.getDouble("neighborhood"));
        
        // Compare models with and without interaction features
        compareWithInteractions(originalData, data);
    }
    
    private void compareWithInteractions(Dataset originalData, Dataset dataWithInteractions) {
        System.out.println("\n--- Interaction Features Comparison ---");
        
        // Model without interactions
        RandomForestRegressor rfOriginal = new RandomForestRegressor()
            .setNumTrees(100)
            .setRandomState(42);
        
        CrossValidationResult originalCv = CrossValidator.validate(
            rfOriginal, originalData.getFeatures(), originalData.getTargets(), 
            5, Scoring.R2);
        
        // Model with interactions
        RandomForestRegressor rfInteractions = new RandomForestRegressor()
            .setNumTrees(100)
            .setRandomState(42);
        
        CrossValidationResult interactionsCv = CrossValidator.validate(
            rfInteractions, dataWithInteractions.getFeatures(), dataWithInteractions.getTargets(), 
            5, Scoring.R2);
        
        System.out.printf("Original features CV R²: %.3f (±%.3f)%n", 
            originalCv.getMean(), originalCv.getStd());
        System.out.printf("With interactions CV R²: %.3f (±%.3f)%n", 
            interactionsCv.getMean(), interactionsCv.getStd());
        
        double improvement = interactionsCv.getMean() - originalCv.getMean();
        System.out.printf("Improvement: %.3f%n", improvement);
        
        if (improvement > 0.01) {
            System.out.println("Interaction features provide meaningful improvement!");
        } else {
            System.out.println("No meaningful gain: the forest likely captures these interactions already.");
        }
    }
}
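The regression examples print RMSE, MAE, R², and MAPE, and a regression forest's prediction is the average of its trees' outputs. The hypothetical `RegressionMath` class spells out these standard formulas so the printed numbers are easier to interpret. SuperML's `Metrics` methods are assumed, not confirmed, to compute the same quantities.

```java
/** Hypothetical reference formulas (not SuperML API) for the regression numbers printed above. */
public class RegressionMath {

    /** A regression forest's prediction for one row: the mean of its trees' predictions. */
    static double forestPrediction(double[] treePredictions) {
        double sum = 0.0;
        for (double p : treePredictions) sum += p;
        return sum / treePredictions.length;
    }

    /** Root mean squared error, in the target's units. */
    static double rmse(double[] y, double[] yHat) {
        double s = 0.0;
        for (int i = 0; i < y.length; i++) s += (y[i] - yHat[i]) * (y[i] - yHat[i]);
        return Math.sqrt(s / y.length);
    }

    /** Mean absolute error, in the target's units. */
    static double mae(double[] y, double[] yHat) {
        double s = 0.0;
        for (int i = 0; i < y.length; i++) s += Math.abs(y[i] - yHat[i]);
        return s / y.length;
    }

    /** R² = 1 - SS_res / SS_tot: 1 is perfect, 0 matches always predicting the mean. */
    static double r2(double[] y, double[] yHat) {
        double mean = 0.0;
        for (double v : y) mean += v;
        mean /= y.length;
        double ssRes = 0.0;
        double ssTot = 0.0;
        for (int i = 0; i < y.length; i++) {
            ssRes += (y[i] - yHat[i]) * (y[i] - yHat[i]);
            ssTot += (y[i] - mean) * (y[i] - mean);
        }
        return 1.0 - ssRes / ssTot;
    }

    /** Mean absolute percentage error in percent; undefined if any true value is 0. */
    static double mape(double[] y, double[] yHat) {
        double s = 0.0;
        for (int i = 0; i < y.length; i++) s += Math.abs((y[i] - yHat[i]) / y[i]);
        return 100.0 * s / y.length;
    }

    public static void main(String[] args) {
        double[] y    = {200, 300, 400};
        double[] yHat = {210, 290, 400};
        System.out.println(forestPrediction(new double[]{190, 210, 200})); // 200.0
        System.out.printf("RMSE=%.3f MAE=%.3f R2=%.4f%n",
            rmse(y, yHat), mae(y, yHat), r2(y, yHat));                    // RMSE=8.165 MAE=6.667 R2=0.9900
    }
}
```

RMSE penalizes large misses more than MAE does. That is why the two can diverge noticeably on house prices with a few expensive outliers.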

Advanced Ensemble Techniques

Extra Trees (Extremely Randomized Trees)

public class ExtraTreesExample {
    
    public void extraTreesComparison() {
        Dataset data = DataLoader.fromCSV("data/classification_data.csv");
        data = preprocessData(data);   // not defined in this class; reuse the preprocessing from the earlier examples
        
        System.out.println("=== Extra Trees vs Random Forest ===");
        
        DataSplit split = data.split(0.8, true);  // 80/20 split, stratified by class (stratify = true)
        
        // Random Forest
        RandomForestClassifier rf = new RandomForestClassifier()
            .setNumTrees(100)
            .setRandomState(42);
        
        rf.fit(split.getTrainX(), split.getTrainY());
        double[] rfPreds = rf.predict(split.getTestX());
        
        // Extra Trees
        ExtraTreesClassifier et = new ExtraTreesClassifier()
            .setNumTrees(100)
            .setRandomState(42);
        
        et.fit(split.getTrainX(), split.getTrainY());
        double[] etPreds = et.predict(split.getTestX());
        
        // Compare performance
        double rfAccuracy = Metrics.accuracy(split.getTestY(), rfPreds);
        double etAccuracy = Metrics.accuracy(split.getTestY(), etPreds);
        
        System.out.printf("Random Forest Accuracy: %.3f%n", rfAccuracy);
        System.out.printf("Extra Trees Accuracy: %.3f%n", etAccuracy);
        
        // Training time comparison
        long rfTime = measureTrainingTime(() -> rf.fit(split.getTrainX(), split.getTrainY()));
        long etTime = measureTrainingTime(() -> et.fit(split.getTrainX(), split.getTrainY()));
        
        System.out.printf("Random Forest training time: %d ms%n", rfTime);
        System.out.printf("Extra Trees training time: %d ms%n", etTime);
        
        System.out.printf("Extra Trees speedup: %.1fx%n", (double) rfTime / etTime);
    }
    
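    private long measureTrainingTime(Runnable training) {
        // System.nanoTime() is monotonic, so it is better suited to elapsed-time measurement
        long start = System.nanoTime();
        training.run();
        return (System.nanoTime() - start) / 1_000_000;  // nanoseconds -> milliseconds
    }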
}
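Extra Trees differ from Random Forest mainly in how a node chooses its split. A Random Forest tree searches every candidate threshold of each sampled feature for the best one, while the standard Extremely Randomized Trees algorithm draws one cut-point at random between the feature's minimum and maximum in the node and keeps the best of these random candidates across the sampled features; it also trains each tree on the full training set instead of a bootstrap sample. Skipping the threshold search is where the training-time savings come from. The sketch below contrasts the two threshold rules for a single feature using Gini impurity. It is illustrative only and assumes, without confirmation from the SuperML documentation, that `ExtraTreesClassifier` follows the standard algorithm; `SplitSelectionSketch` and its methods are hypothetical names.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical illustration of exhaustive vs. random threshold selection; not SuperML code.
public class SplitSelectionSketch {

    /** Gini impurity of binary labels (0/1), given the number of positives and the node size. */
    static double gini(int positives, int total) {
        if (total == 0) {
            return 0.0;
        }
        double p = (double) positives / total;
        return 1.0 - p * p - (1.0 - p) * (1.0 - p);
    }

    /** Weighted Gini impurity of the two children produced by the split x <= threshold. */
    public static double splitImpurity(double[] x, int[] y, double threshold) {
        int leftN = 0, leftPos = 0, rightN = 0, rightPos = 0;
        for (int i = 0; i < x.length; i++) {
            if (x[i] <= threshold) {
                leftN++;
                leftPos += y[i];
            } else {
                rightN++;
                rightPos += y[i];
            }
        }
        return (leftN * gini(leftPos, leftN) + rightN * gini(rightPos, rightN)) / x.length;
    }

    /** Random Forest style: try every midpoint between sorted distinct values and keep the best. */
    public static double bestThreshold(double[] x, int[] y) {
        double[] sorted = x.clone();
        Arrays.sort(sorted);
        double best = sorted[0];
        double bestImpurity = Double.POSITIVE_INFINITY;
        for (int i = 0; i + 1 < sorted.length; i++) {
            if (sorted[i] == sorted[i + 1]) {
                continue;  // identical values give no new candidate threshold
            }
            double t = (sorted[i] + sorted[i + 1]) / 2.0;
            double impurity = splitImpurity(x, y, t);
            if (impurity < bestImpurity) {
                bestImpurity = impurity;
                best = t;
            }
        }
        return best;
    }

    /** Extra Trees style: draw one cut-point uniformly between the feature's min and max. */
    public static double randomThreshold(double[] x, Random rng) {
        double min = Arrays.stream(x).min().getAsDouble();
        double max = Arrays.stream(x).max().getAsDouble();
        return min + rng.nextDouble() * (max - min);
    }

    public static void main(String[] args) {
        double[] x = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
        int[] y = {0, 0, 0, 1, 1, 1};
        double tBest = bestThreshold(x, y);
        double tRandom = randomThreshold(x, new Random(42));
        System.out.printf("Exhaustive threshold %.2f, impurity %.3f%n", tBest, splitImpurity(x, y, tBest));
        System.out.printf("Random threshold %.2f, impurity %.3f%n", tRandom, splitImpurity(x, y, tRandom));
    }
}
```

On this toy data the exhaustive search finds the perfect split at 3.5, while a random cut-point will generally not land on it. Each Extra Trees tree is therefore slightly weaker, but the trees are less correlated, and averaging many of them recovers much of the accuracy.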

Production Deployment

Model Serving with Random Forest

@RestController
public class TreeModelController {
    
    private final RandomForestClassifier model;
    private final Pipeline preprocessor;
    
    public TreeModelController() {
        this.model = ModelSerializer.load("random_forest_model.pkl");
        this.preprocessor = Pipeline.load("tree_preprocessing.pkl");
    }
    
    @PostMapping("/predict")
    public PredictionResponse predict(@RequestBody FeatureRequest features) {
        try {
            // Convert to dataset
            Dataset input = Dataset.fromMap(features.toMap());
            
            // Preprocess
            Dataset processed = preprocessor.transform(input);
            
            // Predict with probabilities
            double prediction = model.predict(processed.getFeatures())[0];
            double[] probabilities = model.predictProba(processed.getFeatures())[0];
            
            // Per-prediction feature contributions (SHAP-like values), not global importances
            double[] contributions = model.getFeatureContributions(processed.getFeatures()[0]);
            
            return new PredictionResponse(prediction, probabilities, contributions);
            
        } catch (Exception e) {
            throw new PredictionException("Prediction failed: " + e.getMessage());
        }
    }
    
    @GetMapping("/model/info")
    public ModelInfo getModelInfo() {
        return new ModelInfo(
            model.getNumTrees(),
            model.getFeatureImportances(),
            model.getOobScore(),
            preprocessor.getSteps()
        );
    }
}
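The `/predict` endpoint returns the output of `predictProba`. In common Random Forest implementations such as scikit-learn's, class probabilities are the average of the individual trees' probability estimates (soft voting) rather than a count of hard votes; whether SuperML's `predictProba` works exactly this way is an assumption. A minimal sketch of soft voting, with hypothetical class and method names and made-up per-tree probabilities:

```java
// Hypothetical illustration of soft voting across trees; not SuperML code.
public class SoftVotingSketch {

    /** Averages per-tree class probabilities: treeProbas[t][c] is tree t's probability for class c. */
    public static double[] averageProbabilities(double[][] treeProbas) {
        int numClasses = treeProbas[0].length;
        double[] avg = new double[numClasses];
        for (double[] tree : treeProbas) {
            for (int c = 0; c < numClasses; c++) {
                avg[c] += tree[c];
            }
        }
        for (int c = 0; c < numClasses; c++) {
            avg[c] /= treeProbas.length;
        }
        return avg;
    }

    /** Index of the largest averaged probability, i.e. the predicted class. */
    public static int argmax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Made-up leaf probabilities from three trees for a 2-class problem
        double[][] treeProbas = {
            {0.9, 0.1},
            {0.4, 0.6},
            {0.3, 0.7}
        };
        double[] proba = averageProbabilities(treeProbas);
        System.out.printf("P(class 0)=%.3f, P(class 1)=%.3f, predicted=%d%n",
            proba[0], proba[1], argmax(proba));
    }
}
```

Here two of the three trees favor class 1, so a hard majority vote would predict class 1, but the averaged probabilities favor class 0 (about 0.533 vs 0.467) because the first tree is very confident.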

Best Practices and Tips

Avoiding Overfitting

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class TreeBestPractices {
    
    public void demonstrateBestPractices() {
        Dataset data = DataLoader.fromLargeCSV("data/large_dataset.csv");
        
        // 1. Proper train/validation/test split
        DataSplit split = data.split(0.6, 0.2, 0.2);  // 60/20/20 split
        
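        // 2. Stop adding trees once validation performance plateaus. Unlike boosting,
        //    a Random Forest does not start overfitting as trees are added; early
        //    stopping here only saves training time.
        RandomForestClassifier rf = new RandomForestClassifier()
            .setNumTrees(1000)  // Upper bound on the number of trees
            .setValidationFraction(0.2)  // Hold out 20% of the training data to monitor the score
            .setEarlyStopping(true)
            .setPatience(10);  // Stop if no improvement for 10 consecutive checks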
        
        rf.fit(split.getTrainX(), split.getTrainY());
        
        System.out.println("Optimal number of trees: " + rf.getOptimalNumTrees());
        
        // 3. Feature importance stability
        checkFeatureImportanceStability(data);
        
        // 4. Proper cross-validation
        properCrossValidation(data);
    }
    
    private void checkFeatureImportanceStability(Dataset data) {
        System.out.println("\n=== Feature Importance Stability ===");
        
        List<double[]> importanceScores = new ArrayList<>();
        
        // Train multiple models with different random seeds
        for (int seed = 0; seed < 10; seed++) {
            RandomForestClassifier rf = new RandomForestClassifier()
                .setNumTrees(100)
                .setRandomState(seed);
            
            rf.fit(data.getFeatures(), data.getTargets());
            importanceScores.add(rf.getFeatureImportances());
        }
        
        // Calculate stability metrics
        List<String> featureNames = data.getFeatureNames();
        for (int i = 0; i < featureNames.size(); i++) {
            final int featureIdx = i;
            double[] scores = importanceScores.stream()
                .mapToDouble(importance -> importance[featureIdx])
                .toArray();
            
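            double mean = Arrays.stream(scores).average().orElse(0.0);
            double std = calculateStandardDeviation(scores);
            if (mean == 0.0) {
                continue;  // Zero importance in every run: coefficient of variation is undefined
            }
            double cv = std / mean;  // Coefficient of variation

            if (cv < 0.3) {  // Stable feature
                System.out.printf("%-20s: stable (CV=%.3f)%n", featureNames.get(i), cv);
            } else {
                System.out.printf("%-20s: unstable (CV=%.3f)%n", featureNames.get(i), cv);
            }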
        }
    }
    
    private void properCrossValidation(Dataset data) {
        System.out.println("\n=== Proper Cross-Validation ===");
        
        // Stratified cross-validation for classification
        StratifiedKFoldValidator validator = new StratifiedKFoldValidator(5);  // k = 5 folds
        
        RandomForestClassifier rf = new RandomForestClassifier()
            .setNumTrees(100)
            .setMaxDepth(10);
        
        CrossValidationResult result = validator.validate(
            rf, data.getFeatures(), data.getTargets(), Scoring.F1_WEIGHTED);
        
        System.out.printf("CV Score: %.3f (±%.3f)%n", result.getMean(), result.getStd());
        
        // Check for high variance
        if (result.getStd() > 0.05) {
            System.out.println("Warning: High variance in CV scores. Consider:");
            System.out.println("- More regularization (lower maxDepth)");
            System.out.println("- More data");
            System.out.println("- Repeated cross-validation to average over different fold assignments");
        }
    }
    
    private double calculateStandardDeviation(double[] values) {
        double mean = Arrays.stream(values).average().orElse(0.0);
        double variance = Arrays.stream(values)
            .map(x -> Math.pow(x - mean, 2))
            .average()
            .orElse(0.0);
        return Math.sqrt(variance);
    }
}
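`StratifiedKFoldValidator` is used above so that each fold keeps roughly the same class proportions as the full dataset, which matters for imbalanced classification. One simple way to achieve this, sketched here under the assumption that labels are integer class ids, is to group the samples by class, shuffle within each class, and deal them round-robin across the folds. This is not necessarily how SuperML implements it; `StratifiedFoldsSketch` and `assignFolds` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;

// Hypothetical illustration of stratified fold assignment; not SuperML code.
public class StratifiedFoldsSketch {

    /** Returns fold[i] in 0..k-1 for each sample so that every class is spread evenly across folds. */
    public static int[] assignFolds(int[] labels, int k, long seed) {
        // Group sample indices by class label (TreeMap keeps the class order deterministic)
        Map<Integer, List<Integer>> byClass = new TreeMap<>();
        for (int i = 0; i < labels.length; i++) {
            byClass.computeIfAbsent(labels[i], c -> new ArrayList<>()).add(i);
        }
        Random rng = new Random(seed);
        int[] fold = new int[labels.length];
        int next = 0;  // continue the round-robin across classes so fold sizes stay balanced
        for (List<Integer> members : byClass.values()) {
            Collections.shuffle(members, rng);
            for (int idx : members) {
                fold[idx] = next % k;
                next++;
            }
        }
        return fold;
    }

    public static void main(String[] args) {
        // Imbalanced toy labels: 15 samples of class 0, 5 of class 1
        int[] labels = new int[20];
        for (int i = 15; i < 20; i++) {
            labels[i] = 1;
        }
        int k = 5;
        int[] fold = assignFolds(labels, k, 42L);
        for (int f = 0; f < k; f++) {
            int zeros = 0, ones = 0;
            for (int i = 0; i < labels.length; i++) {
                if (fold[i] == f) {
                    if (labels[i] == 0) zeros++; else ones++;
                }
            }
            System.out.printf("Fold %d: %d of class 0, %d of class 1%n", f, zeros, ones);
        }
    }
}
```

With 15 samples of class 0 and 5 of class 1, every fold receives 3 and 1 respectively, preserving the 3:1 ratio; a plain shuffled split could leave some folds with no minority samples at all.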

Summary

In this comprehensive tutorial, we covered:

  • Decision Trees: Understanding the algorithm, classification and regression
  • Random Forest: Ensemble methods, hyperparameter tuning, feature importance
  • Advanced Techniques: Extra Trees, feature interactions, ensemble comparison
  • Practical Applications: Real-world examples with comprehensive preprocessing
  • Model Analysis: Feature importance, tree visualization, overfitting detection
  • Production Deployment: Model serving and monitoring
  • Best Practices: Cross-validation, stability analysis, overfitting prevention

Key Takeaways:

  1. Decision trees are interpretable but prone to overfitting
  2. Random Forest reduces overfitting through ensemble averaging
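  3. Tree-based algorithms need less feature engineering: they are insensitive to feature scaling and capture many interactions on their own
  4. Feature importance provides valuable insights, but check that it is stable across random seeds
  5. Out-of-bag scoring provides an honest performance estimate without a separate validation set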
  6. Hyperparameter tuning is crucial for optimal performance
  7. Cross-validation is essential for reliable evaluation
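The out-of-bag estimate rests on bootstrap sampling: each tree is trained on n rows drawn with replacement from the n training rows, so a given row is left out of a tree's sample with probability (1 - 1/n)^n, which approaches 1/e ≈ 0.368 for large n. The OOB score then evaluates each row using only the trees that never saw it. A minimal sketch of the sampling step, with a hypothetical class name; it does not call SuperML, and we assume `getOobScore()` is computed along these lines:

```java
import java.util.Random;

// Hypothetical illustration of bootstrap sampling and out-of-bag rows; not SuperML code.
public class OutOfBagSketch {

    /** Draws a bootstrap sample of size n and marks which rows were never drawn (out-of-bag). */
    public static boolean[] outOfBagMask(int n, Random rng) {
        boolean[] inBag = new boolean[n];
        for (int draw = 0; draw < n; draw++) {
            inBag[rng.nextInt(n)] = true;  // sampling with replacement
        }
        boolean[] oob = new boolean[n];
        for (int i = 0; i < n; i++) {
            oob[i] = !inBag[i];
        }
        return oob;
    }

    /** Fraction of rows that are out-of-bag for one tree. */
    public static double oobFraction(boolean[] oob) {
        int count = 0;
        for (boolean b : oob) {
            if (b) count++;
        }
        return (double) count / oob.length;
    }

    public static void main(String[] args) {
        int n = 100_000;
        double fraction = oobFraction(outOfBagMask(n, new Random(42)));
        System.out.printf("Out-of-bag fraction: %.4f (theory: 1/e = %.4f)%n", fraction, Math.exp(-1));
    }
}
```

Because roughly a third of the rows are unseen by any given tree, every row is out-of-bag for many trees in a 100-tree forest, which is what makes the OOB score usable as a held-out estimate.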

Tree-based algorithms are among the most practical and powerful tools in machine learning. They work well out of the box, handle mixed data types naturally, and a single tree is easy to interpret. Random Forest in particular is a go-to algorithm for many practitioners because of its robustness and strong default performance, at the cost of some of a single tree's interpretability.
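Much of that robustness comes from averaging. A standard result for an average of B identically distributed trees T_b(x), each with prediction variance σ² and pairwise correlation ρ, gives the variance of the ensemble prediction:

```latex
\operatorname{Var}\!\left(\frac{1}{B}\sum_{b=1}^{B} T_b(x)\right)
  = \frac{1}{B^2}\left(B\sigma^2 + B(B-1)\rho\sigma^2\right)
  = \rho\sigma^2 + \frac{1-\rho}{B}\,\sigma^2
```

As B grows the second term vanishes, which is why adding trees does not cause overfitting, but the first term remains. Random Forest's random feature subsets (and Extra Trees' random thresholds) lower ρ, and with it this variance floor.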

In the next tutorial, we’ll explore neural networks, which can capture even more complex patterns but require more careful tuning and larger datasets.