Linear Regression in Java
Implementing linear regression using SuperML Java
Linear regression is one of the fundamental algorithms in machine learning, perfect for predicting numerical values. The SuperML Java framework makes implementing linear regression straightforward with its intuitive API.
Understanding Linear Regression
Linear regression finds the best line that fits through your data points, allowing you to predict new values based on the learned relationship between input features and target values.
When to Use Linear Regression
- Predicting continuous numerical values (house prices, stock prices, temperatures)
- Understanding relationships between variables
- Baseline models for comparison with more complex algorithms
- When interpretability is important
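For a single feature, the "best line" described above has a well-known closed-form solution: the slope is the covariance of x and y divided by the variance of x, and the intercept follows from the means. The plain-Java sketch below (class name `OlsSketch` is illustrative, not part of SuperML) shows the math that any linear-regression solver ultimately approximates:

```java
public class OlsSketch {
    // Closed-form least squares for one feature:
    // slope = cov(x, y) / var(x), intercept = mean(y) - slope * mean(x)
    public static double[] fit(double[] x, double[] y) {
        double mx = 0.0, my = 0.0;
        for (int i = 0; i < x.length; i++) { mx += x[i]; my += y[i]; }
        mx /= x.length;
        my /= y.length;
        double num = 0.0, den = 0.0;
        for (int i = 0; i < x.length; i++) {
            num += (x[i] - mx) * (y[i] - my);
            den += (x[i] - mx) * (x[i] - mx);
        }
        double slope = num / den;
        double intercept = my - slope * mx;
        return new double[] { intercept, slope };
    }
}
```

On the house-price data used below (price = 150 × size), this recovers a slope of 150 and an intercept of 0.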
Simple Linear Regression
Let's start with a simple example predicting house prices based on size:
import org.superml.data.DataLoader;
import org.superml.data.Dataset;
import org.superml.models.LinearRegression;
import org.superml.metrics.Metrics;
public class SimpleLinearRegression {
    public static void main(String[] args) {
        // Create sample data: house sizes and prices
        double[][] features = {
            {1000}, {1200}, {1400}, {1600}, {1800},
            {2000}, {2200}, {2400}, {2600}, {2800}
        };
        double[] labels = {
            150000, 180000, 210000, 240000, 270000,
            300000, 330000, 360000, 390000, 420000
        };

        // Create dataset
        Dataset data = DataLoader.fromArrays(features, labels);
        data.split(0.8); // 80% training, 20% testing

        // Create and configure model
        LinearRegression model = new LinearRegression();
        model.setLearningRate(0.01);
        model.setMaxIterations(1000);

        // Train the model
        model.fit(data.getTrainX(), data.getTrainY());

        // Make predictions
        double[] predictions = model.predict(data.getTestX());

        // Evaluate performance
        double rmse = Metrics.rmse(data.getTestY(), predictions);
        double mae = Metrics.mae(data.getTestY(), predictions);
        double r2 = Metrics.r2Score(data.getTestY(), predictions);
        System.out.println("Model Performance:");
        System.out.println("RMSE: " + rmse);
        System.out.println("MAE: " + mae);
        System.out.println("R² Score: " + r2);

        // Print model coefficients
        double[] coefficients = model.getCoefficients();
        double intercept = model.getIntercept();
        System.out.println("\nModel Equation:");
        System.out.println("Price = " + intercept + " + " + coefficients[0] + " * Size");
    }
}
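The three metrics printed above are simple formulas. A plain-Java sketch of what `Metrics.rmse`, `Metrics.mae`, and `Metrics.r2Score` compute, under the standard definitions (the class name `RegressionMetricsSketch` is illustrative):

```java
public class RegressionMetricsSketch {
    // Root mean squared error: penalizes large errors more heavily than MAE.
    public static double rmse(double[] y, double[] pred) {
        double sum = 0.0;
        for (int i = 0; i < y.length; i++) sum += (y[i] - pred[i]) * (y[i] - pred[i]);
        return Math.sqrt(sum / y.length);
    }

    // Mean absolute error: average magnitude of the errors.
    public static double mae(double[] y, double[] pred) {
        double sum = 0.0;
        for (int i = 0; i < y.length; i++) sum += Math.abs(y[i] - pred[i]);
        return sum / y.length;
    }

    // R^2: 1 - (residual sum of squares / total sum of squares).
    // 1.0 means perfect predictions; 0.0 means no better than predicting the mean.
    public static double r2(double[] y, double[] pred) {
        double mean = 0.0;
        for (double v : y) mean += v;
        mean /= y.length;
        double ssRes = 0.0, ssTot = 0.0;
        for (int i = 0; i < y.length; i++) {
            ssRes += (y[i] - pred[i]) * (y[i] - pred[i]);
            ssTot += (y[i] - mean) * (y[i] - mean);
        }
        return 1.0 - ssRes / ssTot;
    }
}
```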
Multiple Linear Regression
Now let's create a more complex model with multiple features:
import org.superml.data.DataLoader;
import org.superml.data.Dataset;
import org.superml.metrics.Metrics;
import org.superml.models.LinearRegression;
import org.superml.preprocessing.StandardScaler;

public class MultipleLinearRegression {
    public static void main(String[] args) {
        // Load real estate dataset
        // Features: size, bedrooms, bathrooms, age, location_score
        // Target: price
        Dataset data = DataLoader.fromCSV("real_estate.csv");

        // Preprocess the data
        StandardScaler scaler = new StandardScaler();
        double[][] scaledFeatures = scaler.fitTransform(data.getFeatures());
        data.setFeatures(scaledFeatures);

        // Split data
        data.split(0.7, 0.15, 0.15); // 70% train, 15% validation, 15% test

        // Create model with regularization
        LinearRegression model = new LinearRegression();
        model.setLearningRate(0.01);
        model.setMaxIterations(2000);
        model.setRegularization(0.001); // L2 regularization
        model.setConvergenceTolerance(1e-6);

        // Train the model
        model.fit(data.getTrainX(), data.getTrainY());

        // Validate on validation set
        double[] valPredictions = model.predict(data.getValidationX());
        double valRmse = Metrics.rmse(data.getValidationY(), valPredictions);
        System.out.println("Validation RMSE: " + valRmse);

        // Final evaluation on test set
        double[] testPredictions = model.predict(data.getTestX());
        double testRmse = Metrics.rmse(data.getTestY(), testPredictions);
        double testR2 = Metrics.r2Score(data.getTestY(), testPredictions);
        System.out.println("Test RMSE: " + testRmse);
        System.out.println("Test R²: " + testR2);

        // Coefficient magnitudes as a rough proxy for feature importance
        // (meaningful here because the features were standardized first)
        double[] coefficients = model.getCoefficients();
        String[] featureNames = {"size", "bedrooms", "bathrooms", "age", "location_score"};
        System.out.println("\nFeature Importance:");
        for (int i = 0; i < coefficients.length; i++) {
            System.out.println(featureNames[i] + ": " + Math.abs(coefficients[i]));
        }

        // Save the model and the scaler for future use
        model.save("house_price_model.superml");
        scaler.save("feature_scaler.superml");
    }
}
Advanced Configuration
Regularization
Prevent overfitting with L1 (Lasso) or L2 (Ridge) regularization:
LinearRegression model = new LinearRegression();
// L2 Regularization (Ridge)
model.setRegularization(0.01);
model.setRegularizationType(RegularizationType.L2);
// L1 Regularization (Lasso)
model.setRegularization(0.01);
model.setRegularizationType(RegularizationType.L1);
// Elastic Net (combination of L1 and L2)
model.setRegularization(0.01);
model.setRegularizationType(RegularizationType.ELASTIC_NET);
model.setL1Ratio(0.5); // 50% L1, 50% L2
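To see what the L2 penalty actually does to the coefficients, consider the single-feature case on centered data, where ridge regression has a closed form: slope = Σxy / (Σx² + λ). The sketch below (class name `RidgeSketch` is illustrative, not SuperML's internal solver) demonstrates the shrinkage effect: λ = 0 recovers ordinary least squares, and larger λ pulls the slope toward zero.

```java
public class RidgeSketch {
    // 1-D ridge slope on centered data: sum(x*y) / (sum(x*x) + lambda).
    // lambda = 0 recovers ordinary least squares; larger lambda shrinks the slope.
    public static double ridgeSlope(double[] x, double[] y, double lambda) {
        double mx = 0.0, my = 0.0;
        for (int i = 0; i < x.length; i++) { mx += x[i]; my += y[i]; }
        mx /= x.length;
        my /= y.length;
        double num = 0.0, den = lambda; // the penalty inflates the denominator
        for (int i = 0; i < x.length; i++) {
            num += (x[i] - mx) * (y[i] - my);
            den += (x[i] - mx) * (x[i] - mx);
        }
        return num / den;
    }
}
```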
Early Stopping
Prevent overfitting with early stopping:
model.setEarlyStoppingEnabled(true);
model.setValidationSplit(0.2); // Use 20% of training data for validation
model.setEarlyStoppingPatience(10); // Stop if no improvement for 10 iterations
model.setEarlyStoppingMinDelta(1e-4); // Minimum improvement threshold
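The patience/min-delta logic behind early stopping is straightforward: track the best validation loss so far, and stop once it has failed to improve by at least the threshold for `patience` consecutive iterations. A plain-Java sketch of that decision rule (the class and method names are illustrative, not SuperML API):

```java
public class EarlyStoppingSketch {
    // Returns the iteration index at which training would stop, given the
    // per-iteration validation losses, a patience, and a minimum-improvement delta.
    public static int stopIteration(double[] valLoss, int patience, double minDelta) {
        double best = Double.POSITIVE_INFINITY;
        int sinceImprove = 0;
        for (int i = 0; i < valLoss.length; i++) {
            if (best - valLoss[i] > minDelta) {
                best = valLoss[i];   // meaningful improvement: reset the counter
                sinceImprove = 0;
            } else if (++sinceImprove >= patience) {
                return i;            // patience exhausted: stop here
            }
        }
        return valLoss.length - 1;   // never triggered: ran to completion
    }
}
```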
Learning Rate Scheduling
Adjust learning rate during training:
model.setLearningRateScheduler(LearningRateScheduler.EXPONENTIAL_DECAY);
model.setLearningRateDecayRate(0.95);
model.setLearningRateDecaySteps(100);
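Exponential decay follows the standard schedule lr(t) = lr₀ · rate^(t / steps), so with the settings above the learning rate is multiplied by 0.95 every 100 iterations. A small sketch of the formula (illustrative class name, assuming this common definition of the schedule):

```java
public class LrScheduleSketch {
    // Exponential decay: lr0 * decayRate^(step / decaySteps).
    // The learning rate is scaled by decayRate once every decaySteps iterations.
    public static double decayedLr(double lr0, double decayRate, int decaySteps, int step) {
        return lr0 * Math.pow(decayRate, (double) step / decaySteps);
    }
}
```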
Model Evaluation and Diagnostics
Residual Analysis
// Get predictions and residuals
double[] predictions = model.predict(data.getTestX());
double[] residuals = Metrics.residuals(data.getTestY(), predictions);
// Check for patterns in residuals
double meanResidual = Arrays.stream(residuals).average().orElse(0.0);
double residualStd = Metrics.standardDeviation(residuals);
System.out.println("Mean residual: " + meanResidual);
System.out.println("Residual std: " + residualStd);
// Durbin-Watson test for autocorrelation
double dwTest = Metrics.durbinWatsonTest(residuals);
System.out.println("Durbin-Watson statistic: " + dwTest);
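The Durbin-Watson statistic itself is just the sum of squared successive residual differences divided by the sum of squared residuals; values near 2 suggest no autocorrelation, near 0 positive autocorrelation, near 4 negative autocorrelation. A self-contained sketch of the computation (illustrative class name, mirroring what a `durbinWatsonTest` would return under the standard definition):

```java
public class DurbinWatsonSketch {
    // DW = sum_t (e_t - e_{t-1})^2 / sum_t e_t^2, ranging from 0 to 4.
    public static double durbinWatson(double[] residuals) {
        double num = 0.0, den = 0.0;
        for (int i = 0; i < residuals.length; i++) {
            den += residuals[i] * residuals[i];
            if (i > 0) {
                double diff = residuals[i] - residuals[i - 1];
                num += diff * diff;
            }
        }
        return num / den;
    }
}
```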
Cross-Validation
import java.util.Arrays;

import org.superml.validation.CrossValidator;
// Perform 5-fold cross-validation
CrossValidator cv = new CrossValidator(5);
double[] cvScores = cv.crossValidateScore(model, data.getFeatures(), data.getLabels());
double meanScore = Arrays.stream(cvScores).average().orElse(0.0);
double stdScore = Metrics.standardDeviation(cvScores);
System.out.println("CV Mean Score: " + meanScore);
System.out.println("CV Std Score: " + stdScore);
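K-fold cross-validation partitions the samples into k disjoint folds, holding each one out as a test set in turn. The exact splitting strategy `CrossValidator` uses isn't shown here; the sketch below illustrates one common scheme, contiguous folds of near-equal size (illustrative class name):

```java
import java.util.ArrayList;
import java.util.List;

public class KFoldSketch {
    // Splits indices 0..n-1 into k contiguous, disjoint test folds.
    // The first (n % k) folds get one extra sample so sizes differ by at most 1.
    public static List<int[]> testFolds(int n, int k) {
        List<int[]> folds = new ArrayList<>();
        int start = 0;
        for (int f = 0; f < k; f++) {
            int size = n / k + (f < n % k ? 1 : 0);
            int[] fold = new int[size];
            for (int i = 0; i < size; i++) fold[i] = start + i;
            start += size;
            folds.add(fold);
        }
        return folds;
    }
}
```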
Polynomial Features
Create polynomial features for non-linear relationships:
import org.superml.preprocessing.PolynomialFeatures;
// Create polynomial features up to degree 2
PolynomialFeatures polyFeatures = new PolynomialFeatures(2);
double[][] polyData = polyFeatures.fitTransform(data.getFeatures());
// Use polynomial features with linear regression
Dataset polyDataset = new Dataset(polyData, data.getLabels());
polyDataset.split(0.8);
LinearRegression polyModel = new LinearRegression();
polyModel.setRegularization(0.01); // Important for polynomial features
polyModel.fit(polyDataset.getTrainX(), polyDataset.getTrainY());
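To make the transformation concrete: for a single feature, a degree-2 expansion maps x to {x, x²}, letting a linear model fit a parabola. The sketch below handles only the single-feature case and omits the interaction terms and bias column that full `PolynomialFeatures` implementations typically add (illustrative class name):

```java
public class PolyFeaturesSketch {
    // Expands a single feature x into {x, x^2, ..., x^degree}.
    public static double[][] expand(double[] x, int degree) {
        double[][] out = new double[x.length][degree];
        for (int i = 0; i < x.length; i++) {
            for (int d = 1; d <= degree; d++) {
                out[i][d - 1] = Math.pow(x[i], d);
            }
        }
        return out;
    }
}
```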
Production Deployment
Model Serving
import org.superml.models.LinearRegression;
import org.superml.preprocessing.StandardScaler;
import org.superml.serving.ModelServer;

public class HousePricePredictionService {
    private LinearRegression model;
    private StandardScaler scaler;

    public void initialize() {
        // Load trained model and scaler
        this.model = LinearRegression.load("house_price_model.superml");
        this.scaler = StandardScaler.load("feature_scaler.superml");
    }

    public double predictPrice(double size, int bedrooms, int bathrooms,
                               double age, double locationScore) {
        // Prepare features in the same order used during training
        double[][] features = {{size, bedrooms, bathrooms, age, locationScore}};

        // Scale features with the scaler fitted at training time
        double[][] scaledFeatures = scaler.transform(features);

        // Make prediction
        double[] prediction = model.predict(scaledFeatures);
        return prediction[0];
    }

    // PredictionResult is a simple value holder: (predicted price, confidence)
    public PredictionResult predictWithConfidence(double[] features) {
        double[][] featureArray = {features};
        double[][] scaledFeatures = scaler.transform(featureArray);
        double[] prediction = model.predict(scaledFeatures);
        double[] confidence = model.predictWithConfidence(scaledFeatures);
        return new PredictionResult(prediction[0], confidence[0]);
    }
}
Common Pitfalls and Solutions
1. Feature Scaling
Always scale features when they have different ranges:
// Bad: features with different scales
double[][] features = {{1000, 3}, {1200, 4}}; // size vs bedrooms
// Good: scaled features
StandardScaler scaler = new StandardScaler();
double[][] scaledFeatures = scaler.fitTransform(features);
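What `StandardScaler` computes is z-score standardization: each column is shifted to mean 0 and scaled to unit standard deviation. A self-contained sketch (illustrative class name, assuming the population standard deviation, a common scaler default):

```java
public class ScalerSketch {
    // z-score standardization per column: (x - columnMean) / columnStd.
    public static double[][] standardize(double[][] X) {
        int n = X.length, d = X[0].length;
        double[][] out = new double[n][d];
        for (int j = 0; j < d; j++) {
            double mean = 0.0;
            for (int i = 0; i < n; i++) mean += X[i][j];
            mean /= n;
            double var = 0.0;
            for (int i = 0; i < n; i++) var += (X[i][j] - mean) * (X[i][j] - mean);
            double std = Math.sqrt(var / n);
            for (int i = 0; i < n; i++) out[i][j] = (X[i][j] - mean) / std;
        }
        return out;
    }
}
```

After this transform, a 1000-square-foot house and a 3-bedroom count live on the same scale, so gradient descent treats the two features evenly.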
2. Multicollinearity
Check for highly correlated features:
// Calculate correlation matrix
double[][] correlationMatrix = Metrics.correlationMatrix(data.getFeatures());
// Remove highly correlated features (correlation > 0.9)
FeatureSelector selector = new FeatureSelector();
int[] selectedFeatures = selector.removeHighlyCorrelated(correlationMatrix, 0.9);
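Each entry of that correlation matrix is a Pearson correlation coefficient, which ranges from -1 (perfect inverse relationship) to 1 (perfect direct relationship). A plain-Java sketch of the pairwise computation (illustrative class name):

```java
public class CorrelationSketch {
    // Pearson correlation: cov(a, b) / (std(a) * std(b)), always in [-1, 1].
    public static double pearson(double[] a, double[] b) {
        double ma = 0.0, mb = 0.0;
        for (int i = 0; i < a.length; i++) { ma += a[i]; mb += b[i]; }
        ma /= a.length;
        mb /= b.length;
        double cov = 0.0, va = 0.0, vb = 0.0;
        for (int i = 0; i < a.length; i++) {
            cov += (a[i] - ma) * (b[i] - mb);
            va += (a[i] - ma) * (a[i] - ma);
            vb += (b[i] - mb) * (b[i] - mb);
        }
        return cov / Math.sqrt(va * vb);
    }
}
```

Feature pairs whose absolute correlation exceeds the chosen threshold (0.9 above) are candidates for removal, since near-duplicate columns make the fitted coefficients unstable.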
3. Overfitting
Use regularization and cross-validation:
// Use regularization
model.setRegularization(0.01);
// Validate with cross-validation
CrossValidator cv = new CrossValidator(5);
double[] scores = cv.crossValidateScore(model, features, labels);
Next Steps
- Logistic Regression - Learn classification techniques
- Polynomial Regression - Handle non-linear relationships
- Regularized Regression - Advanced regularization techniques
- Model Selection - Choose the best model
Linear regression with SuperML Java provides a solid foundation for predictive modeling. The framework's intuitive API makes it easy to implement, evaluate, and deploy regression models in production Java applications.