Course Content
XGBoost and Gradient Boosting
Lightning-fast XGBoost implementation with hyperparameter optimization
XGBoost in Java - Extreme Gradient Boosting
XGBoost (Extreme Gradient Boosting) in SuperML Java 2.1.0 provides lightning-fast training with early stopping, advanced regularization, and enterprise-grade performance. This tutorial covers how to implement XGBoost for classification and regression with professional hyperparameter tuning and production deployment.
What You'll Learn
- XGBoost Fundamentals - Understanding extreme gradient boosting
- Lightning-Fast Training - Optimized training with early stopping
- Advanced Regularization - L1/L2 regularization and tree pruning
- Hyperparameter Optimization - Grid search and randomized search
- Feature Importance - Understanding model decisions
- Production Deployment - Enterprise-ready XGBoost systems
- Performance Benchmarking - Comparing with other algorithms
Prerequisites
- Completion of "Introduction to SuperML Java" tutorial
- Understanding of ensemble methods and decision trees
- Basic knowledge of gradient boosting concepts
- Java development environment with SuperML Java 2.1.0
XGBoost Overview
XGBoost (Extreme Gradient Boosting) is an optimized distributed gradient boosting library that provides:
- Superior Performance: Often wins machine learning competitions
- Speed: Highly optimized parallel training
- Flexibility: Supports regression, classification, and ranking
- Regularization: Built-in L1/L2 regularization to prevent overfitting (see the objective below)
- Feature Importance: Comprehensive feature importance metrics
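For reference, the regularized objective that XGBoost minimizes is shown below (the standard formulation from the XGBoost paper; SuperML's setGamma, setRegLambda, and setRegAlpha correspond to γ, λ, and α):

$$\mathcal{L} = \sum_i l(y_i, \hat{y}_i) + \sum_k \Omega(f_k), \qquad \Omega(f) = \gamma T + \tfrac{1}{2}\lambda \lVert w \rVert^2 + \alpha \lVert w \rVert_1$$

Here T is the number of leaves in a tree f and w its leaf weights, so larger γ, λ, or α values produce more conservative trees.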
Basic XGBoost Implementation
XGBoost Classification
import org.superml.tree_models.XGBoost;
import org.superml.datasets.Datasets;
import org.superml.model_selection.ModelSelection;
import org.superml.metrics.Metrics;
public class XGBoostClassificationExample {
public static void main(String[] args) {
System.out.println("=== SuperML 2.1.0 - XGBoost Classification ===\n");
try {
// Load dataset
var dataset = Datasets.loadWine();
var split = ModelSelection.trainTestSplit(dataset.X, dataset.y, 0.2, 42);
System.out.println("Dataset: " + dataset.X.length + " samples, " + dataset.X[0].length + " features");
System.out.println("Classes: " + (int)(java.util.Arrays.stream(dataset.y).max().orElse(0) + 1));
System.out.println("Training samples: " + split.XTrain.length);
System.out.println("Test samples: " + split.XTest.length);
// Create XGBoost classifier with optimized parameters
XGBoost xgb = new XGBoost()
.setObjective("multi:softprob") // Multi-class classification
.setNumBoostRound(100) // Number of boosting rounds
.setLearningRate(0.1) // Learning rate (eta)
.setMaxDepth(6) // Maximum tree depth
.setMinChildWeight(1) // Minimum child weight
.setSubsample(0.8) // Subsample ratio
.setColsampleBytree(0.8) // Feature sampling ratio
.setRegAlpha(0.1) // L1 regularization
.setRegLambda(1.0) // L2 regularization
.setGamma(0.1) // Minimum split loss
.setEarlyStoppingRounds(10) // Early stopping
.setVerbose(true); // Verbose output
System.out.println("XGBoost Configuration:");
System.out.println("- Objective: Multi-class classification");
System.out.println("- Boosting rounds: 100");
System.out.println("- Learning rate: 0.1");
System.out.println("- Max depth: 6");
System.out.println("- Regularization: L1=0.1, L2=1.0");
System.out.println("- Early stopping: 10 rounds");
// Train XGBoost model
System.out.println("\nTraining XGBoost model...");
long startTime = System.currentTimeMillis();
xgb.fit(split.XTrain, split.yTrain);
long trainingTime = System.currentTimeMillis() - startTime;
System.out.println("Training completed in " + trainingTime + " ms");
// Make predictions
double[] predictions = xgb.predict(split.XTest);
double[][] probabilities = xgb.predictProbabilities(split.XTest);
// Evaluate performance
double accuracy = Metrics.accuracy(split.yTest, predictions);
double precision = Metrics.precision(split.yTest, predictions);
double recall = Metrics.recall(split.yTest, predictions);
double f1 = Metrics.f1Score(split.yTest, predictions);
System.out.println("\n=== XGBoost Classification Results ===");
System.out.println("Accuracy: " + String.format("%.4f", accuracy));
System.out.println("Precision: " + String.format("%.4f", precision));
System.out.println("Recall: " + String.format("%.4f", recall));
System.out.println("F1 Score: " + String.format("%.4f", f1));
System.out.println("Training Time: " + trainingTime + " ms");
// Display feature importance
System.out.println("\nFeature Importance (Top 5):");
double[] importance = xgb.getFeatureImportance();
for (int i = 0; i < Math.min(5, importance.length); i++) {
System.out.println("- Feature " + i + ": " + String.format("%.4f", importance[i]));
}
// Training history
System.out.println("\nTraining History:");
var history = xgb.getTrainingHistory();
System.out.println("- Best iteration: " + history.getBestIteration());
System.out.println("- Best score: " + String.format("%.4f", history.getBestScore()));
System.out.println("- Final train loss: " + String.format("%.4f", history.getFinalTrainLoss()));
System.out.println("- Final validation loss: " + String.format("%.4f", history.getFinalValidationLoss()));
// Confusion matrix
int[][] confMatrix = Metrics.confusionMatrix(split.yTest, predictions);
System.out.println("\nConfusion Matrix:");
for (int i = 0; i < confMatrix.length; i++) {
System.out.println(java.util.Arrays.toString(confMatrix[i]));
}
System.out.println("\nXGBoost classification completed successfully!");
} catch (Exception e) {
System.err.println("Error in XGBoost classification: " + e.getMessage());
e.printStackTrace();
}
}
}
XGBoost Regression
import org.superml.tree_models.XGBoost;
import org.superml.datasets.Datasets;
import org.superml.model_selection.ModelSelection;
import org.superml.metrics.Metrics;
public class XGBoostRegressionExample {
public static void main(String[] args) {
System.out.println("=== SuperML 2.1.0 - XGBoost Regression ===\n");
try {
// Load regression dataset
var dataset = Datasets.loadBoston();
var split = ModelSelection.trainTestSplit(dataset.X, dataset.y, 0.2, 42);
System.out.println("Dataset: " + dataset.X.length + " samples, " + dataset.X[0].length + " features");
System.out.println("Target range: [" +
String.format("%.2f", java.util.Arrays.stream(dataset.y).min().orElse(0)) + ", " +
String.format("%.2f", java.util.Arrays.stream(dataset.y).max().orElse(0)) + "]");
// Create XGBoost regressor
XGBoost xgb = new XGBoost()
.setObjective("reg:squarederror") // Regression objective
.setNumBoostRound(200) // More rounds for regression
.setLearningRate(0.05) // Lower learning rate
.setMaxDepth(5) // Moderate depth
.setMinChildWeight(3) // Higher minimum child weight
.setSubsample(0.9) // High subsample ratio
.setColsampleBytree(0.9) // High feature sampling
.setRegAlpha(0.05) // L1 regularization
.setRegLambda(0.5) // L2 regularization
.setGamma(0.05) // Minimum split loss
.setEarlyStoppingRounds(20) // Early stopping
.setEvalMetric("rmse") // Root Mean Square Error
.setVerbose(true);
System.out.println("XGBoost Regression Configuration:");
System.out.println("- Objective: Squared error regression");
System.out.println("- Boosting rounds: 200");
System.out.println("- Learning rate: 0.05");
System.out.println("- Max depth: 5");
System.out.println("- Evaluation metric: RMSE");
// Train XGBoost model
System.out.println("\nTraining XGBoost regressor...");
long startTime = System.currentTimeMillis();
xgb.fit(split.XTrain, split.yTrain);
long trainingTime = System.currentTimeMillis() - startTime;
System.out.println("Training completed in " + trainingTime + " ms");
// Make predictions
double[] predictions = xgb.predict(split.XTest);
// Evaluate performance
double mse = Metrics.meanSquaredError(split.yTest, predictions);
double rmse = Math.sqrt(mse);
double mae = Metrics.meanAbsoluteError(split.yTest, predictions);
double r2 = Metrics.r2Score(split.yTest, predictions);
System.out.println("\n=== XGBoost Regression Results ===");
System.out.println("Mean Squared Error: " + String.format("%.4f", mse));
System.out.println("Root Mean Squared Error: " + String.format("%.4f", rmse));
System.out.println("Mean Absolute Error: " + String.format("%.4f", mae));
System.out.println("R² Score: " + String.format("%.4f", r2));
System.out.println("Training Time: " + trainingTime + " ms");
// Feature importance analysis
System.out.println("\nFeature Importance Analysis:");
double[] importance = xgb.getFeatureImportance();
String[] featureNames = {"CRIM", "ZN", "INDUS", "CHAS", "NOX", "RM", "AGE", "DIS", "RAD", "TAX", "PTRATIO", "B", "LSTAT"};
// Sort features by importance
var featureImportance = new java.util.ArrayList<java.util.Map.Entry<String, Double>>();
for (int i = 0; i < Math.min(importance.length, featureNames.length); i++) {
featureImportance.add(new java.util.AbstractMap.SimpleEntry<>(featureNames[i], importance[i]));
}
featureImportance.sort(java.util.Map.Entry.<String, Double>comparingByValue().reversed());
System.out.println("Top 5 Most Important Features:");
for (int i = 0; i < Math.min(5, featureImportance.size()); i++) {
var entry = featureImportance.get(i);
System.out.println("- " + entry.getKey() + ": " + String.format("%.4f", entry.getValue()));
}
// Training convergence
System.out.println("\nTraining Convergence:");
var history = xgb.getTrainingHistory();
System.out.println("- Best iteration: " + history.getBestIteration());
System.out.println("- Best RMSE: " + String.format("%.4f", history.getBestScore()));
System.out.println("- Training stopped early: " + history.isEarlyStopped());
// Prediction analysis
System.out.println("\nPrediction Analysis (First 10 samples):");
System.out.println("Actual\tPredicted\tError");
for (int i = 0; i < Math.min(10, split.yTest.length); i++) {
double error = Math.abs(split.yTest[i] - predictions[i]);
System.out.println(String.format("%.2f\t%.2f\t\t%.2f",
split.yTest[i], predictions[i], error));
}
System.out.println("\nXGBoost regression completed successfully!");
} catch (Exception e) {
System.err.println("Error in XGBoost regression: " + e.getMessage());
e.printStackTrace();
}
}
}
Advanced XGBoost Features
Hyperparameter Optimization
import org.superml.tree_models.XGBoost;
import org.superml.model_selection.GridSearchCV;
import org.superml.model_selection.RandomizedSearchCV;
import org.superml.datasets.Datasets;
public class XGBoostHyperparameterTuning {
public static void main(String[] args) {
System.out.println("=== SuperML 2.1.0 - XGBoost Hyperparameter Tuning ===\n");
try {
// Load dataset
var dataset = Datasets.loadWine();
System.out.println("Dataset: " + dataset.X.length + " samples, " + dataset.X[0].length + " features");
// Define hyperparameter search space
var paramGrid = new java.util.HashMap<String, Object>();
paramGrid.put("numBoostRound", new int[]{50, 100, 200});
paramGrid.put("learningRate", new double[]{0.01, 0.1, 0.2});
paramGrid.put("maxDepth", new int[]{3, 5, 7});
paramGrid.put("minChildWeight", new int[]{1, 3, 5});
paramGrid.put("subsample", new double[]{0.7, 0.8, 0.9});
paramGrid.put("colsampleBytree", new double[]{0.7, 0.8, 0.9});
paramGrid.put("regAlpha", new double[]{0.0, 0.1, 0.5});
paramGrid.put("regLambda", new double[]{0.5, 1.0, 2.0});
System.out.println("Hyperparameter Search Space:");
System.out.println("- Boosting rounds: [50, 100, 200]");
System.out.println("- Learning rate: [0.01, 0.1, 0.2]");
System.out.println("- Max depth: [3, 5, 7]");
System.out.println("- Regularization: Alpha [0.0, 0.1, 0.5], Lambda [0.5, 1.0, 2.0]");
System.out.println("- Total combinations: " + calculateCombinations(paramGrid));
// Grid Search with XGBoost
System.out.println("\nStarting Grid Search...");
var gridSearch = new GridSearchCV()
.setEstimator(new XGBoost().setObjective("multi:softprob"))
.setParamGrid(paramGrid)
.setCrossValidation(5)
.setScoring("accuracy")
.setVerbose(true)
.setNJobs(4); // Parallel processing
long startTime = System.currentTimeMillis();
gridSearch.fit(dataset.X, dataset.y);
long gridSearchTime = System.currentTimeMillis() - startTime;
// Display grid search results
System.out.println("\n=== Grid Search Results ===");
System.out.println("Best Score: " + String.format("%.4f", gridSearch.getBestScore()));
System.out.println("Best Parameters: " + gridSearch.getBestParams());
System.out.println("Grid Search Time: " + gridSearchTime + " ms");
// Get best model
var bestModel = gridSearch.getBestEstimator();
// Randomized Search for comparison
System.out.println("\nStarting Randomized Search...");
var randomSearch = new RandomizedSearchCV()
.setEstimator(new XGBoost().setObjective("multi:softprob"))
.setParamDistributions(paramGrid)
.setCrossValidation(5)
.setScoring("accuracy")
.setNIter(50) // 50 random combinations
.setVerbose(true)
.setNJobs(4);
startTime = System.currentTimeMillis();
randomSearch.fit(dataset.X, dataset.y);
long randomSearchTime = System.currentTimeMillis() - startTime;
// Display randomized search results
System.out.println("\n=== Randomized Search Results ===");
System.out.println("Best Score: " + String.format("%.4f", randomSearch.getBestScore()));
System.out.println("Best Parameters: " + randomSearch.getBestParams());
System.out.println("Randomized Search Time: " + randomSearchTime + " ms");
// Compare search methods
System.out.println("\nSearch Method Comparison:");
System.out.println("- Grid Search: " + String.format("%.4f", gridSearch.getBestScore()) +
" (Time: " + gridSearchTime + " ms)");
System.out.println("- Randomized Search: " + String.format("%.4f", randomSearch.getBestScore()) +
" (Time: " + randomSearchTime + " ms)");
// Advanced hyperparameter analysis
System.out.println("\nAdvanced Hyperparameter Analysis:");
analyzeHyperparameterImportance(gridSearch);
System.out.println("\nXGBoost hyperparameter tuning completed!");
} catch (Exception e) {
System.err.println("Error in hyperparameter tuning: " + e.getMessage());
e.printStackTrace();
}
}
private static int calculateCombinations(java.util.HashMap<String, Object> paramGrid) {
int combinations = 1;
for (Object values : paramGrid.values()) {
if (values instanceof int[]) {
combinations *= ((int[]) values).length;
} else if (values instanceof double[]) {
combinations *= ((double[]) values).length;
}
}
return combinations;
}
private static void analyzeHyperparameterImportance(GridSearchCV gridSearch) {
// Illustrative summary: a fuller analysis would aggregate the scores from
// getCVResults() per parameter value to rank each hyperparameter's impact.
var results = gridSearch.getCVResults();
System.out.println("Parameter Impact Analysis:");
System.out.println("- Learning Rate: High impact on convergence speed");
System.out.println("- Max Depth: Controls model complexity");
System.out.println("- Regularization: Prevents overfitting");
System.out.println("- Subsample: Reduces overfitting and training time");
}
}
XGBoost with Early Stopping
import org.superml.tree_models.XGBoost;
import org.superml.datasets.Datasets;
import org.superml.model_selection.ModelSelection;
import org.superml.metrics.Metrics;
public class XGBoostEarlyStoppingExample {
public static void main(String[] args) {
System.out.println("=== SuperML 2.1.0 - XGBoost Early Stopping ===\n");
try {
// Load large dataset
var dataset = Datasets.makeClassification(5000, 50, 5, 42);
var split = ModelSelection.trainTestSplit(dataset.X, dataset.y, 0.2, 42);
System.out.println("Dataset: " + dataset.X.length + " samples, " + dataset.X[0].length + " features");
System.out.println("Training samples: " + split.XTrain.length);
System.out.println("Validation samples: " + split.XTest.length);
// XGBoost with early stopping
XGBoost xgb = new XGBoost()
.setObjective("multi:softprob")
.setNumBoostRound(1000) // Many rounds - early stopping will control
.setLearningRate(0.1)
.setMaxDepth(6)
.setMinChildWeight(1)
.setSubsample(0.8)
.setColsampleBytree(0.8)
.setRegAlpha(0.1)
.setRegLambda(1.0)
.setEarlyStoppingRounds(20) // Stop if no improvement for 20 rounds
.setEvalMetric("mlogloss") // Multi-class log loss
.setValidationFraction(0.2) // Use 20% for validation
.setVerboseEval(10) // Print every 10 rounds
.setVerbose(true);
System.out.println("XGBoost Early Stopping Configuration:");
System.out.println("- Max boosting rounds: 1000");
System.out.println("- Early stopping rounds: 20");
System.out.println("- Validation fraction: 20%");
System.out.println("- Evaluation metric: Multi-class log loss");
System.out.println("- Verbose evaluation: Every 10 rounds");
// Train with early stopping
System.out.println("\nTraining XGBoost with early stopping...");
long startTime = System.currentTimeMillis();
xgb.fit(split.XTrain, split.yTrain);
long trainingTime = System.currentTimeMillis() - startTime;
// Get training history
var history = xgb.getTrainingHistory();
System.out.println("\n=== Early Stopping Results ===");
System.out.println("Training Time: " + trainingTime + " ms");
System.out.println("Best Iteration: " + history.getBestIteration());
System.out.println("Best Score: " + String.format("%.4f", history.getBestScore()));
System.out.println("Early Stopped: " + history.isEarlyStopped());
System.out.println("Total Iterations: " + history.getTotalIterations());
System.out.println("Training Time Saved: " +
String.format("%.1f%%", (1.0 - (double)history.getTotalIterations() / 1000) * 100));
// Analyze convergence
System.out.println("\nConvergence Analysis:");
var trainLoss = history.getTrainLoss();
var validLoss = history.getValidationLoss();
System.out.println("Training Loss Progression (last 10 iterations):");
for (int i = Math.max(0, trainLoss.length - 10); i < trainLoss.length; i++) {
System.out.println("- Iteration " + (i + 1) + ": Train=" +
String.format("%.4f", trainLoss[i]) + ", Valid=" +
String.format("%.4f", validLoss[i]));
}
// Evaluate on test set
double[] predictions = xgb.predict(split.XTest);
double accuracy = Metrics.accuracy(split.yTest, predictions);
System.out.println("\nTest Set Performance:");
System.out.println("- Accuracy: " + String.format("%.4f", accuracy));
System.out.println("- Optimal iterations: " + history.getBestIteration());
System.out.println("- Overfitting prevented: " + history.isEarlyStopped());
// Compare with fixed iterations
System.out.println("\nComparison with Fixed Iterations:");
// Train model with fixed 100 iterations
XGBoost xgbFixed = new XGBoost()
.setObjective("multi:softprob")
.setNumBoostRound(100)
.setLearningRate(0.1)
.setMaxDepth(6)
.setVerbose(false);
startTime = System.currentTimeMillis();
xgbFixed.fit(split.XTrain, split.yTrain);
long fixedTime = System.currentTimeMillis() - startTime;
double[] fixedPredictions = xgbFixed.predict(split.XTest);
double fixedAccuracy = Metrics.accuracy(split.yTest, fixedPredictions);
System.out.println("- Early Stopping: " + String.format("%.4f", accuracy) +
" accuracy, " + trainingTime + " ms");
System.out.println("- Fixed 100 rounds: " + String.format("%.4f", fixedAccuracy) +
" accuracy, " + fixedTime + " ms");
System.out.println("- Improvement: " + String.format("%.4f", accuracy - fixedAccuracy));
System.out.println("\nXGBoost early stopping analysis completed!");
} catch (Exception e) {
System.err.println("Error in early stopping: " + e.getMessage());
e.printStackTrace();
}
}
}
Feature Importance and Model Interpretation
Advanced Feature Importance Analysis
import org.superml.tree_models.XGBoost;
import org.superml.datasets.Datasets;
import org.superml.interpretation.FeatureImportanceAnalyzer;
import org.superml.interpretation.SHAPValues;
public class XGBoostFeatureImportanceExample {
public static void main(String[] args) {
System.out.println("=== SuperML 2.1.0 - XGBoost Feature Importance Analysis ===\n");
try {
// Load dataset with known features
var dataset = Datasets.loadWine();
String[] featureNames = {
"Alcohol", "Malic Acid", "Ash", "Alcalinity of Ash", "Magnesium",
"Total Phenols", "Flavanoids", "Nonflavanoid Phenols", "Proanthocyanins",
"Color Intensity", "Hue", "OD280/OD315", "Proline"
};
System.out.println("Dataset: Wine Classification");
System.out.println("Features: " + featureNames.length);
System.out.println("Samples: " + dataset.X.length);
// Train XGBoost model
XGBoost xgb = new XGBoost()
.setObjective("multi:softprob")
.setNumBoostRound(100)
.setLearningRate(0.1)
.setMaxDepth(6)
.setRegAlpha(0.1)
.setRegLambda(1.0)
.setVerbose(false);
System.out.println("\nTraining XGBoost model...");
xgb.fit(dataset.X, dataset.y);
// Get multiple types of feature importance
System.out.println("\n=== Feature Importance Analysis ===");
// 1. Gain-based importance (default)
double[] gainImportance = xgb.getFeatureImportance("gain");
System.out.println("\nFeature Importance by Gain:");
displayFeatureImportance(featureNames, gainImportance);
// 2. Frequency-based importance
double[] frequencyImportance = xgb.getFeatureImportance("frequency");
System.out.println("\nFeature Importance by Frequency:");
displayFeatureImportance(featureNames, frequencyImportance);
// 3. Cover-based importance
double[] coverImportance = xgb.getFeatureImportance("cover");
System.out.println("\nFeature Importance by Cover:");
displayFeatureImportance(featureNames, coverImportance);
// Advanced feature importance analysis
var analyzer = new FeatureImportanceAnalyzer(xgb);
// Permutation importance
System.out.println("\nPermutation Importance Analysis:");
double[] permutationImportance = analyzer.calculatePermutationImportance(
dataset.X, dataset.y, 5); // 5 permutations
displayFeatureImportance(featureNames, permutationImportance);
// Feature interaction analysis
System.out.println("\nFeature Interaction Analysis:");
var interactions = analyzer.calculateFeatureInteractions(dataset.X, dataset.y);
System.out.println("Top 5 Feature Interactions:");
interactions.entrySet().stream()
.sorted(java.util.Map.Entry.<String, Double>comparingByValue().reversed())
.limit(5)
.forEach(entry -> {
System.out.println("- " + entry.getKey() + ": " +
String.format("%.4f", entry.getValue()));
});
// SHAP values for model interpretability
System.out.println("\nSHAP Values Analysis:");
var shapAnalyzer = new SHAPValues(xgb);
// Calculate SHAP values for first 5 samples
for (int i = 0; i < Math.min(5, dataset.X.length); i++) {
double[] shapValues = shapAnalyzer.calculateSHAPValues(dataset.X[i]);
System.out.println("\nSample " + (i + 1) + " SHAP Values:");
// Show top 3 contributing features
var shapContributions = new java.util.ArrayList<java.util.Map.Entry<String, Double>>();
for (int j = 0; j < Math.min(shapValues.length, featureNames.length); j++) {
shapContributions.add(new java.util.AbstractMap.SimpleEntry<>(
featureNames[j], Math.abs(shapValues[j])));
}
shapContributions.sort(java.util.Map.Entry.<String, Double>comparingByValue().reversed());
for (int j = 0; j < Math.min(3, shapContributions.size()); j++) {
var entry = shapContributions.get(j);
System.out.println("- " + entry.getKey() + ": " +
String.format("%.4f", entry.getValue()));
}
}
// Global feature importance ranking
System.out.println("\nGlobal Feature Importance Ranking:");
var globalRanking = analyzer.calculateGlobalRanking(
gainImportance, frequencyImportance, coverImportance, permutationImportance);
for (int i = 0; i < Math.min(10, globalRanking.size()); i++) {
var entry = globalRanking.get(i);
System.out.println((i + 1) + ". " + entry.getKey() + ": " +
String.format("%.4f", entry.getValue()));
}
System.out.println("\nFeature importance analysis completed!");
} catch (Exception e) {
System.err.println("Error in feature importance analysis: " + e.getMessage());
e.printStackTrace();
}
}
private static void displayFeatureImportance(String[] featureNames, double[] importance) {
// Create feature-importance pairs and sort by importance
var featureImportance = new java.util.ArrayList<java.util.Map.Entry<String, Double>>();
for (int i = 0; i < Math.min(importance.length, featureNames.length); i++) {
featureImportance.add(new java.util.AbstractMap.SimpleEntry<>(featureNames[i], importance[i]));
}
featureImportance.sort(java.util.Map.Entry.<String, Double>comparingByValue().reversed());
// Display top 5
for (int i = 0; i < Math.min(5, featureImportance.size()); i++) {
var entry = featureImportance.get(i);
System.out.println("- " + entry.getKey() + ": " + String.format("%.4f", entry.getValue()));
}
}
}
Production XGBoost Deployment
Enterprise XGBoost System
import org.superml.tree_models.XGBoost;
import org.superml.persistence.ModelPersistence;
import org.superml.inference.InferenceEngine;
import org.superml.monitoring.ModelMonitor;
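// Note: the @Service, @PostConstruct, @PostMapping, and @Scheduled annotations below
// assume a Spring Boot web application; the request/response DTOs (PredictionRequest,
// PredictionResponse, and friends) are application-defined and not part of SuperML.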
@Service
public class ProductionXGBoostSystem {
private final InferenceEngine inferenceEngine;
private final ModelMonitor monitor;
private XGBoost productionModel;
public ProductionXGBoostSystem() {
this.inferenceEngine = new InferenceEngine()
.setModelCache(true)
.setPerformanceMonitoring(true)
.setBatchSize(1000)
.setMaxLatency(10); // 10ms max latency
this.monitor = new ModelMonitor()
.setDriftDetection(true)
.setPerformanceThreshold(0.05)
.setAlertingEnabled(true);
// loadProductionModel() is invoked via @PostConstruct once the bean is ready
}
@PostConstruct
private void loadProductionModel() {
try {
// Load production XGBoost model
this.productionModel = ModelPersistence.load(
"models/production_xgboost.superml", XGBoost.class);
// Register with inference engine
inferenceEngine.registerModel("xgboost_classifier", productionModel);
System.out.println("Production XGBoost model loaded successfully");
} catch (Exception e) {
System.err.println("Error loading production model: " + e.getMessage());
throw new RuntimeException("Failed to load production model", e);
}
}
@PostMapping("/predict")
public PredictionResponse predict(@RequestBody PredictionRequest request) {
long startTime = System.currentTimeMillis();
try {
// Validate input
if (request.getFeatures() == null || request.getFeatures().length == 0) {
throw new IllegalArgumentException("Features cannot be empty");
}
// Make prediction using inference engine
double[][] features = new double[][]{request.getFeatures()};
double[] predictions = inferenceEngine.predict("xgboost_classifier", features);
double[][] probabilities = inferenceEngine.predictProbabilities("xgboost_classifier", features);
// Get feature importance for this prediction
double[] featureImportance = productionModel.getFeatureImportance();
// Monitor prediction
monitor.recordPrediction(features[0], predictions[0]);
long latency = System.currentTimeMillis() - startTime;
return new PredictionResponse()
.setPrediction(predictions[0])
.setProbabilities(probabilities[0])
.setFeatureImportance(featureImportance)
.setLatency(latency)
.setModelVersion(productionModel.getVersion())
.setConfidence(calculateConfidence(probabilities[0]));
} catch (Exception e) {
System.err.println("Error in prediction: " + e.getMessage());
return new PredictionResponse()
.setError(e.getMessage())
.setLatency(System.currentTimeMillis() - startTime);
}
}
@PostMapping("/predict-batch")
public BatchPredictionResponse predictBatch(@RequestBody BatchPredictionRequest request) {
long startTime = System.currentTimeMillis();
try {
double[][] features = request.getFeatures();
// Batch prediction for high throughput
double[] predictions = inferenceEngine.predict("xgboost_classifier", features);
double[][] probabilities = inferenceEngine.predictProbabilities("xgboost_classifier", features);
// Monitor batch prediction
monitor.recordBatchPrediction(features, predictions);
long latency = System.currentTimeMillis() - startTime;
double throughput = (double) features.length / latency * 1000; // predictions per second
return new BatchPredictionResponse()
.setPredictions(predictions)
.setProbabilities(probabilities)
.setLatency(latency)
.setThroughput(throughput)
.setBatchSize(features.length)
.setModelVersion(productionModel.getVersion());
} catch (Exception e) {
System.err.println("Error in batch prediction: " + e.getMessage());
return new BatchPredictionResponse()
.setError(e.getMessage())
.setLatency(System.currentTimeMillis() - startTime);
}
}
@GetMapping("/model-info")
public ModelInfoResponse getModelInfo() {
try {
var history = productionModel.getTrainingHistory();
double[] featureImportance = productionModel.getFeatureImportance();
return new ModelInfoResponse()
.setModelType("XGBoost")
.setVersion(productionModel.getVersion())
.setTrainingAccuracy(history.getBestScore())
.setBestIteration(history.getBestIteration())
.setFeatureCount(featureImportance.length)
.setTrainingTime(history.getTrainingTime())
.setHyperparameters(getHyperparameters())
.setFeatureImportance(featureImportance);
} catch (Exception e) {
System.err.println("Error getting model info: " + e.getMessage());
return new ModelInfoResponse().setError(e.getMessage());
}
}
@Scheduled(fixedRate = 3600000) // Hourly monitoring
public void monitorModelPerformance() {
try {
// Check model drift
boolean driftDetected = monitor.checkDrift();
if (driftDetected) {
System.out.println("Model drift detected - alerting operations team");
alertOperationsTeam("Model drift detected in production XGBoost model");
}
// Check performance degradation
double currentPerformance = monitor.getCurrentPerformance();
double baselinePerformance = monitor.getBaselinePerformance();
if (currentPerformance < baselinePerformance - 0.05) {
System.out.println("Performance degradation detected");
alertOperationsTeam("Performance degradation in production XGBoost model");
}
// Log performance metrics
System.out.println("Model Performance Update:");
System.out.println("- Current Performance: " + String.format("%.4f", currentPerformance));
System.out.println("- Baseline Performance: " + String.format("%.4f", baselinePerformance));
System.out.println("- Predictions Today: " + monitor.getTodaysPredictionCount());
System.out.println("- Average Latency: " + monitor.getAverageLatency() + " ms");
} catch (Exception e) {
System.err.println("Error in model monitoring: " + e.getMessage());
}
}
private double calculateConfidence(double[] probabilities) {
return java.util.Arrays.stream(probabilities).max().orElse(0.0);
}
private java.util.Map<String, Object> getHyperparameters() {
java.util.Map<String, Object> params = new java.util.HashMap<>();
params.put("numBoostRound", productionModel.getNumBoostRound());
params.put("learningRate", productionModel.getLearningRate());
params.put("maxDepth", productionModel.getMaxDepth());
params.put("regAlpha", productionModel.getRegAlpha());
params.put("regLambda", productionModel.getRegLambda());
return params;
}
private void alertOperationsTeam(String message) {
// Send alert to operations team
System.out.println("ALERT: " + message);
}
}
Performance Benchmarking
XGBoost vs Other Algorithms
import org.superml.tree_models.XGBoost;
import org.superml.tree_models.RandomForest;
import org.superml.tree_models.GradientBoosting;
import org.superml.linear_model.LogisticRegression;
import org.superml.datasets.Datasets;
import org.superml.model_selection.ModelSelection;
import org.superml.metrics.Metrics;
public class XGBoostBenchmark {
public static void main(String[] args) {
System.out.println("=== SuperML 2.1.0 - XGBoost Performance Benchmark ===\n");
try {
// Load large dataset for benchmarking
var dataset = Datasets.makeClassification(10000, 100, 10, 42);
var split = ModelSelection.trainTestSplit(dataset.X, dataset.y, 0.2, 42);
System.out.println("Benchmark Dataset:");
System.out.println("- Samples: " + dataset.X.length);
System.out.println("- Features: " + dataset.X[0].length);
System.out.println("- Classes: " + (int)(java.util.Arrays.stream(dataset.y).max().orElse(0) + 1));
System.out.println("- Training samples: " + split.XTrain.length);
System.out.println("- Test samples: " + split.XTest.length);
// Define models for comparison
var models = new java.util.LinkedHashMap<String, Object>();
// XGBoost
models.put("XGBoost", new XGBoost()
.setObjective("multi:softprob")
.setNumBoostRound(100)
.setLearningRate(0.1)
.setMaxDepth(6)
.setVerbose(false));
// Random Forest
models.put("Random Forest", new RandomForest()
.setNEstimators(100)
.setMaxDepth(10)
.setVerbose(false));
// Gradient Boosting
models.put("Gradient Boosting", new GradientBoosting()
.setNEstimators(100)
.setLearningRate(0.1)
.setMaxDepth(6)
.setVerbose(false));
// Logistic Regression
models.put("Logistic Regression", new LogisticRegression()
.setMaxIter(1000)
.setVerbose(false));
System.out.println("\nStarting Benchmark...");
// Benchmark results
var results = new java.util.LinkedHashMap<String, BenchmarkResult>();
for (var entry : models.entrySet()) {
String algorithmName = entry.getKey();
Object model = entry.getValue();
System.out.println("\nBenchmarking " + algorithmName + "...");
// Training time
long trainStart = System.currentTimeMillis();
if (model instanceof XGBoost) {
((XGBoost) model).fit(split.XTrain, split.yTrain);
} else if (model instanceof RandomForest) {
((RandomForest) model).fit(split.XTrain, split.yTrain);
} else if (model instanceof GradientBoosting) {
((GradientBoosting) model).fit(split.XTrain, split.yTrain);
} else if (model instanceof LogisticRegression) {
((LogisticRegression) model).fit(split.XTrain, split.yTrain);
}
long trainTime = System.currentTimeMillis() - trainStart;
// Prediction time
long predStart = System.currentTimeMillis();
double[] predictions = null;
if (model instanceof XGBoost) {
predictions = ((XGBoost) model).predict(split.XTest);
} else if (model instanceof RandomForest) {
predictions = ((RandomForest) model).predict(split.XTest);
} else if (model instanceof GradientBoosting) {
predictions = ((GradientBoosting) model).predict(split.XTest);
} else if (model instanceof LogisticRegression) {
predictions = ((LogisticRegression) model).predict(split.XTest);
}
long predTime = System.currentTimeMillis() - predStart;
// Calculate metrics
double accuracy = Metrics.accuracy(split.yTest, predictions);
double f1 = Metrics.f1Score(split.yTest, predictions);
// Memory usage
Runtime runtime = Runtime.getRuntime();
long memoryUsed = runtime.totalMemory() - runtime.freeMemory();
results.put(algorithmName, new BenchmarkResult()
.setAccuracy(accuracy)
.setF1Score(f1)
.setTrainingTime(trainTime)
.setPredictionTime(predTime)
.setMemoryUsage(memoryUsed / 1024 / 1024)); // MB
System.out.println("- Training time: " + trainTime + " ms");
System.out.println("- Prediction time: " + predTime + " ms");
System.out.println("- Accuracy: " + String.format("%.4f", accuracy));
System.out.println("- F1 Score: " + String.format("%.4f", f1));
}
// Display benchmark results
System.out.println("\n=== Benchmark Results Summary ===");
System.out.println(String.format("%-20s %12s %12s %12s %12s %12s",
"Algorithm", "Accuracy", "F1 Score", "Train (ms)", "Pred (ms)", "Memory (MB)"));
System.out.println("=" .repeat(100));
for (var entry : results.entrySet()) {
String name = entry.getKey();
BenchmarkResult result = entry.getValue();
System.out.println(String.format("%-20s %12.4f %12.4f %12d %12d %12d",
name,
result.getAccuracy(),
result.getF1Score(),
result.getTrainingTime(),
result.getPredictionTime(),
result.getMemoryUsage()));
}
// Performance analysis
System.out.println("\nPerformance Analysis:");
// Find best performing algorithm
var bestAccuracy = results.entrySet().stream()
.max(java.util.Map.Entry.comparingByValue(
(a, b) -> Double.compare(a.getAccuracy(), b.getAccuracy())));
var fastestTraining = results.entrySet().stream()
.min(java.util.Map.Entry.comparingByValue(
(a, b) -> Long.compare(a.getTrainingTime(), b.getTrainingTime())));
var fastestPrediction = results.entrySet().stream()
.min(java.util.Map.Entry.comparingByValue(
(a, b) -> Long.compare(a.getPredictionTime(), b.getPredictionTime())));
System.out.println("Best Accuracy: " + bestAccuracy.get().getKey() +
" (" + String.format("%.4f", bestAccuracy.get().getValue().getAccuracy()) + ")");
System.out.println("Fastest Training: " + fastestTraining.get().getKey() +
" (" + fastestTraining.get().getValue().getTrainingTime() + " ms)");
System.out.println("Fastest Prediction: " + fastestPrediction.get().getKey() +
" (" + fastestPrediction.get().getValue().getPredictionTime() + " ms)");
// XGBoost specific analysis
var xgboostResult = results.get("XGBoost");
System.out.println("\nXGBoost Analysis:");
System.out.println("- Balanced performance across all metrics");
System.out.println("- Excellent accuracy: " + String.format("%.4f", xgboostResult.getAccuracy()));
System.out.println("- Reasonable training time: " + xgboostResult.getTrainingTime() + " ms");
System.out.println("- Fast prediction: " + xgboostResult.getPredictionTime() + " ms");
System.out.println("- Memory efficient: " + xgboostResult.getMemoryUsage() + " MB");
System.out.println("\nXGBoost benchmark completed successfully!");
} catch (Exception e) {
System.err.println("Error in benchmark: " + e.getMessage());
e.printStackTrace();
}
}
private static class BenchmarkResult {
private double accuracy;
private double f1Score;
private long trainingTime;
private long predictionTime;
private long memoryUsage;
// Getters and setters
public double getAccuracy() { return accuracy; }
public BenchmarkResult setAccuracy(double accuracy) { this.accuracy = accuracy; return this; }
public double getF1Score() { return f1Score; }
public BenchmarkResult setF1Score(double f1Score) { this.f1Score = f1Score; return this; }
public long getTrainingTime() { return trainingTime; }
public BenchmarkResult setTrainingTime(long trainingTime) { this.trainingTime = trainingTime; return this; }
public long getPredictionTime() { return predictionTime; }
public BenchmarkResult setPredictionTime(long predictionTime) { this.predictionTime = predictionTime; return this; }
public long getMemoryUsage() { return memoryUsage; }
public BenchmarkResult setMemoryUsage(long memoryUsage) { this.memoryUsage = memoryUsage; return this; }
}
}
Best Practices
1. XGBoost Configuration
- Learning Rate: Start with 0.1, decrease for better performance
- Tree Depth: Use 3-10 for most problems
- Regularization: Always use L1/L2 regularization
- Early Stopping: Prevent overfitting with validation monitoring (a starter configuration combining these guidelines follows below)
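As a starting point, the guidelines above combine into the following sketch using the builder API from earlier in this tutorial; the values are sensible defaults, not problem-specific tuning:

import org.superml.tree_models.XGBoost;

public class StarterConfig {
    // Conservative starter configuration; revisit after a first validation run.
    public static XGBoost starter() {
        return new XGBoost()
            .setLearningRate(0.1)          // start at 0.1, lower once other settings stabilize
            .setMaxDepth(6)                // inside the 3-10 range that suits most problems
            .setRegAlpha(0.1)              // L1 regularization
            .setRegLambda(1.0)             // L2 regularization
            .setNumBoostRound(1000)        // generous cap; early stopping picks the real count
            .setEarlyStoppingRounds(20);   // stop when validation loss stalls for 20 rounds
    }
}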
2. Hyperparameter Tuning
- Sequential Tuning: Tune parameters in order of importance (a two-stage sketch follows after this list)
- Cross-Validation: Use proper CV for reliable estimates
- Time Budget: Set reasonable time limits for optimization
- Parallel Processing: Use multiple cores for faster tuning
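A minimal two-stage sketch of sequential tuning with the GridSearchCV API shown earlier; note that reading the stage-one winner back assumes getBestParams() exposes a Map view, which is my assumption rather than a documented guarantee:

import org.superml.tree_models.XGBoost;
import org.superml.model_selection.GridSearchCV;

public class SequentialTuningSketch {
    // Stage 1 tunes the learning rate; stage 2 tunes tree structure with the
    // stage-1 winner held fixed (assumes getBestParams() can be read as a Map).
    public static GridSearchCV tune(double[][] X, double[] y) {
        var stage1 = new java.util.HashMap<String, Object>();
        stage1.put("learningRate", new double[]{0.01, 0.05, 0.1, 0.2});
        var search1 = new GridSearchCV()
            .setEstimator(new XGBoost().setObjective("multi:softprob"))
            .setParamGrid(stage1)
            .setCrossValidation(5)
            .setScoring("accuracy");
        search1.fit(X, y);
        double bestLr = (Double) ((java.util.Map<?, ?>) search1.getBestParams()).get("learningRate");

        var stage2 = new java.util.HashMap<String, Object>();
        stage2.put("maxDepth", new int[]{3, 5, 7});
        stage2.put("minChildWeight", new int[]{1, 3, 5});
        var search2 = new GridSearchCV()
            .setEstimator(new XGBoost().setObjective("multi:softprob").setLearningRate(bestLr))
            .setParamGrid(stage2)
            .setCrossValidation(5)
            .setScoring("accuracy");
        search2.fit(X, y);
        return search2;
    }
}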
3. Feature Engineering
- Feature Importance: Use XGBoost's built-in feature importance
- Interaction Effects: XGBoost handles interactions well
- Missing Values: XGBoost handles missing values naturally
- Categorical Features: Use label encoding or one-hot encoding (a minimal one-hot sketch follows below)
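For the categorical-features point, a minimal one-hot encoding can be done in plain Java before the matrix reaches fit(); this sketch assumes nothing from SuperML's preprocessing packages:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.TreeSet;

public class OneHotSketch {
    // Expand one categorical column into 0/1 indicator columns.
    public static double[][] oneHot(String[] column) {
        List<String> categories = new ArrayList<>(new TreeSet<>(Arrays.asList(column)));
        double[][] encoded = new double[column.length][categories.size()];
        for (int i = 0; i < column.length; i++) {
            encoded[i][categories.indexOf(column[i])] = 1.0;
        }
        return encoded; // append these columns to the numeric feature matrix
    }
}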
4. Production Deployment
- Model Versioning: Track model versions and performance (see the persistence sketch after this list)
- Monitoring: Implement drift detection and performance monitoring
- Caching: Cache models for faster inference
- Batch Processing: Use batch predictions for high throughput
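The versioning point maps onto the ModelPersistence API from the production example; the save() signature in this sketch is an assumption that mirrors the load() call shown there:

import org.superml.persistence.ModelPersistence;
import org.superml.tree_models.XGBoost;

public class ModelVersioningSketch {
    // Persist a fitted model under a versioned path, then reload it for serving.
    // Assumption: ModelPersistence.save(model, path) mirrors the load() call above.
    public static XGBoost saveAndReload(XGBoost fittedModel, String version) {
        String path = "models/production_xgboost_" + version + ".superml";
        ModelPersistence.save(fittedModel, path);
        return ModelPersistence.load(path, XGBoost.class);
    }
}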
Summary
In this tutorial, you learned:
- XGBoost Implementation: Classification and regression with SuperML Java
- Lightning-Fast Training: Optimized training with early stopping
- Advanced Regularization: L1/L2 regularization and tree pruning
- Hyperparameter Optimization: Grid search and randomized search
- Feature Importance: Multiple importance metrics and SHAP values
- Production Deployment: Enterprise-ready XGBoost systems
- Performance Benchmarking: Comparing with other algorithms
XGBoost in SuperML Java 2.1.0 provides enterprise-grade performance with sophisticated optimization techniques. The framework handles the complexity of gradient boosting while providing you with intuitive APIs and professional deployment capabilities.
Next Steps
- Try AutoML: Automated XGBoost optimization
- Explore Neural Networks: Deep learning with MLP, CNN, and RNN
- Model Ensembles: Combining XGBoost with other algorithms
- Advanced Preprocessing: Feature engineering for XGBoost
- MLOps Integration: CI/CD pipelines for XGBoost models
You're now ready to build production-grade XGBoost applications with SuperML Java 2.1.0!