Linear Regression with SuperML Java
Linear regression is one of the fundamental algorithms in machine learning and often the first algorithm students learn. It’s simple to understand, implement, and interpret, making it an excellent starting point for your machine learning journey with SuperML Java.
What is Linear Regression?
Linear regression models the relationship between a dependent variable (target) and one or more independent variables (features) by fitting a linear equation to observed data. The goal is to find the line (or, with several features, the hyperplane) that minimizes the sum of squared differences between predicted and actual values.
Mathematical Foundation
For simple linear regression:
y = β₀ + β₁x + ε
For multiple linear regression:
y = β₀ + β₁x₁ + β₂x₂ + ... + βₙxₙ + ε
Where:
- y is the target variable
- x₁, x₂, ..., xₙ are feature variables
- β₀ is the intercept (bias)
- β₁, β₂, ..., βₙ are coefficients (weights)
- ε is the error term
Simple Linear Regression
Let’s start with a basic example using house size to predict price:
import org.superml.linear_model.LinearRegression;
import org.superml.datasets.Datasets;
import org.superml.metrics.Metrics;
public class SimpleLinearRegressionExample {
public static void main(String[] args) {
// Create sample data: house size vs price
double[][] X = {{800}, {1000}, {1200}, {1400}, {1600}, {1800}, {2000}};
double[] y = {150000, 180000, 210000, 240000, 270000, 300000, 330000};
// Create and train the model
LinearRegression model = new LinearRegression();
model.fit(X, y);
// Make predictions
double[][] testX = {{900}, {1500}, {2200}};
double[] predictions = model.predict(testX);
// Display results
System.out.println("Model coefficients:");
System.out.println("Intercept: " + model.getIntercept());
System.out.println("Coefficient: " + model.getCoefficients()[0]);
System.out.println("\nPredictions:");
for (int i = 0; i < testX.length; i++) {
System.out.printf("House size: %.0f sq ft -> Predicted price: $%.2f%n",
testX[i][0], predictions[i]);
}
}
}
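For a single feature, the least-squares line has a closed form: the slope is the covariance of x and y divided by the variance of x, and the intercept follows from the two means. The sketch below applies that formula to the house-size data above in plain Java. `OlsSketch` and its method names are illustrative and not part of SuperML, whose internal solver may work differently.

```java
/** Closed-form simple linear regression (ordinary least squares) in plain Java. */
final class OlsSketch {

    /** Returns {intercept, slope} minimizing the sum of squared residuals. */
    static double[] fit(double[] x, double[] y) {
        int n = x.length;
        double meanX = 0.0;
        double meanY = 0.0;
        for (int i = 0; i < n; i++) {
            meanX += x[i];
            meanY += y[i];
        }
        meanX /= n;
        meanY /= n;

        double sxy = 0.0; // sum of (x - meanX) * (y - meanY)
        double sxx = 0.0; // sum of (x - meanX)^2
        for (int i = 0; i < n; i++) {
            sxy += (x[i] - meanX) * (y[i] - meanY);
            sxx += (x[i] - meanX) * (x[i] - meanX);
        }
        double slope = sxy / sxx;
        double intercept = meanY - slope * meanX;
        return new double[] {intercept, slope};
    }

    /** Evaluates the fitted line at a single x value. */
    static double predict(double[] params, double x) {
        return params[0] + params[1] * x;
    }

    public static void main(String[] args) {
        double[] size = {800, 1000, 1200, 1400, 1600, 1800, 2000};
        double[] price = {150000, 180000, 210000, 240000, 270000, 300000, 330000};
        double[] params = fit(size, price);
        System.out.printf("Intercept: %.2f, slope: %.2f%n", params[0], params[1]);
        System.out.printf("Predicted price for 1500 sq ft: %.2f%n", predict(params, 1500));
    }
}
```

Because the sample data lies exactly on a line, the fit recovers an intercept of 30000 and a slope of 150, so a 1,500 sq ft house is predicted at 255,000.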
Loading Real Data
Let’s work with a more realistic dataset:
public class HousePricePrediction {
public void predictHousePrices() {
// Load housing dataset
Dataset data = DataLoader.fromCSV("data/house_prices.csv");
// Display basic information
System.out.println("Dataset shape: " + data.getShape());
System.out.println("Features: " + data.getFeatureNames());
System.out.println(data.describe());
// Select features for simple regression
Dataset subset = data.selectColumns("size", "price");
// Split into features and target
double[][] X = subset.getFeatures("size");
double[] y = subset.getTarget("price");
// Split into train/test sets
DataSplit split = subset.split(0.8);
// Train the model
LinearRegression model = new LinearRegression();
model.fit(split.getTrainX(), split.getTrainY());
// Evaluate on test set
double[] predictions = model.predict(split.getTestX());
double rmse = Metrics.rmse(split.getTestY(), predictions);
double r2 = Metrics.r2Score(split.getTestY(), predictions);
System.out.printf("Test RMSE: %.2f%n", rmse);
System.out.printf("Test R²: %.3f%n", r2);
}
}
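The two evaluation metrics used above are easy to compute by hand, which helps when interpreting them. The following sketch implements RMSE and R² in plain Java; the class `RegressionMetricsSketch` is hypothetical and only mirrors the standard definitions, not SuperML's `Metrics` implementation.

```java
/** Plain-Java versions of RMSE and R² using their textbook definitions. */
final class RegressionMetricsSketch {

    /** Root mean squared error: sqrt(mean((actual - predicted)^2)). */
    static double rmse(double[] actual, double[] predicted) {
        double sumSquaredError = 0.0;
        for (int i = 0; i < actual.length; i++) {
            double error = actual[i] - predicted[i];
            sumSquaredError += error * error;
        }
        return Math.sqrt(sumSquaredError / actual.length);
    }

    /** Coefficient of determination: 1 - SS_res / SS_tot. */
    static double r2(double[] actual, double[] predicted) {
        double mean = 0.0;
        for (double a : actual) {
            mean += a;
        }
        mean /= actual.length;

        double ssRes = 0.0; // unexplained variation
        double ssTot = 0.0; // total variation around the mean
        for (int i = 0; i < actual.length; i++) {
            ssRes += (actual[i] - predicted[i]) * (actual[i] - predicted[i]);
            ssTot += (actual[i] - mean) * (actual[i] - mean);
        }
        return 1.0 - ssRes / ssTot;
    }

    public static void main(String[] args) {
        double[] actual = {3, 5, 7};
        double[] predicted = {2, 5, 9};
        System.out.printf("RMSE: %.4f%n", rmse(actual, predicted));
        System.out.printf("R²: %.3f%n", r2(actual, predicted));
    }
}
```

RMSE is expressed in the units of the target (dollars in the housing example), while R² is unitless and compares the model against always predicting the mean.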
Multiple Linear Regression
Real-world problems usually involve multiple features:
import java.util.HashMap;
import java.util.Map;
public class MultipleLinearRegressionExample {
public void multipleFeatureExample() {
// Load dataset with multiple features
Dataset data = DataLoader.fromCSV("data/house_prices_complete.csv");
// Preprocessing
data = preprocessData(data);
// Select relevant features
String[] features = {"size", "bedrooms", "bathrooms", "age", "garage"};
Dataset modelData = data.selectColumns(features, "price");
// Split data
DataSplit split = modelData.split(0.8);
// Create and configure model
LinearRegression model = new LinearRegression()
.setFitIntercept(true)
.setNormalize(false); // Already preprocessed
// Train model
model.fit(split.getTrainX(), split.getTrainY());
// Model interpretation
interpretModel(model, features);
// Evaluation
evaluateModel(model, split);
}
private Dataset preprocessData(Dataset data) {
// Handle missing values
data = data.fillMissing(Strategy.MEAN);
// Remove outliers (optional)
OutlierDetector detector = new IQRDetector(1.5); // IQR multiplier
boolean[] outliers = detector.detect(data, "price");
data = data.removeOutliers(outliers);
// Scale features
StandardScaler scaler = new StandardScaler();
String[] numericFeatures = {"size", "bedrooms", "bathrooms", "age", "garage"};
data = scaler.fitTransform(data, numericFeatures);
return data;
}
private void interpretModel(LinearRegression model, String[] features) {
System.out.println("Model Coefficients:");
System.out.printf("Intercept: %.2f%n", model.getIntercept());
double[] coefficients = model.getCoefficients();
for (int i = 0; i < features.length; i++) {
System.out.printf("%-12s: %.4f%n", features[i], coefficients[i]);
}
// Feature importance (absolute coefficient values; comparable across
// features only because the inputs were standardized during preprocessing)
System.out.println("\nFeature Importance (|coefficient|):");
Map<String, Double> importance = new HashMap<>();
for (int i = 0; i < features.length; i++) {
importance.put(features[i], Math.abs(coefficients[i]));
}
importance.entrySet().stream()
.sorted(Map.Entry.<String, Double>comparingByValue().reversed())
.forEach(entry ->
System.out.printf("%-12s: %.4f%n", entry.getKey(), entry.getValue()));
}
private void evaluateModel(LinearRegression model, DataSplit split) {
// Training metrics
double[] trainPreds = model.predict(split.getTrainX());
double trainRmse = Metrics.rmse(split.getTrainY(), trainPreds);
double trainR2 = Metrics.r2Score(split.getTrainY(), trainPreds);
// Test metrics
double[] testPreds = model.predict(split.getTestX());
double testRmse = Metrics.rmse(split.getTestY(), testPreds);
double testR2 = Metrics.r2Score(split.getTestY(), testPreds);
System.out.println("\nModel Performance:");
System.out.printf("Training RMSE: %.2f, R²: %.3f%n", trainRmse, trainR2);
System.out.printf("Test RMSE: %.2f, R²: %.3f%n", testRmse, testR2);
// Check for overfitting
if (trainR2 - testR2 > 0.1) {
System.out.println("Warning: Possible overfitting detected!");
}
}
}
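With several features, the least-squares coefficients solve the normal equations (XᵀX)β = Xᵀy, where X carries a leading column of ones for the intercept. The sketch below builds and solves that system with Gaussian elimination in plain Java. `NormalEquationSketch` is a hypothetical name, and this is one textbook way to obtain the coefficients, not necessarily how SuperML's `LinearRegression` does it.

```java
/** Multiple linear regression via the normal equations, plain Java. */
final class NormalEquationSketch {

    /** Returns {intercept, b1, ..., bp} solving (X^T X) b = X^T y. */
    static double[] fit(double[][] x, double[] y) {
        int n = x.length;
        int p = x[0].length + 1; // +1 for the intercept column of ones
        double[][] a = new double[p][p + 1]; // augmented matrix [X^T X | X^T y]

        for (int i = 0; i < n; i++) {
            double[] row = new double[p];
            row[0] = 1.0;
            System.arraycopy(x[i], 0, row, 1, p - 1);
            for (int r = 0; r < p; r++) {
                for (int c = 0; c < p; c++) {
                    a[r][c] += row[r] * row[c];
                }
                a[r][p] += row[r] * y[i];
            }
        }

        // Forward elimination with partial pivoting.
        for (int col = 0; col < p; col++) {
            int pivot = col;
            for (int r = col + 1; r < p; r++) {
                if (Math.abs(a[r][col]) > Math.abs(a[pivot][col])) {
                    pivot = r;
                }
            }
            double[] tmp = a[col];
            a[col] = a[pivot];
            a[pivot] = tmp;
            if (Math.abs(a[col][col]) < 1e-12) {
                throw new IllegalArgumentException("Singular system: features are collinear");
            }
            for (int r = col + 1; r < p; r++) {
                double factor = a[r][col] / a[col][col];
                for (int c = col; c <= p; c++) {
                    a[r][c] -= factor * a[col][c];
                }
            }
        }

        // Back substitution.
        double[] beta = new double[p];
        for (int r = p - 1; r >= 0; r--) {
            double sum = a[r][p];
            for (int c = r + 1; c < p; c++) {
                sum -= a[r][c] * beta[c];
            }
            beta[r] = sum / a[r][r];
        }
        return beta;
    }

    public static void main(String[] args) {
        // Synthetic data generated from y = 1 + 2*x1 + 3*x2.
        double[][] x = {{0, 0}, {1, 0}, {0, 1}, {1, 1}, {2, 1}};
        double[] y = {1, 3, 4, 6, 8};
        double[] beta = fit(x, y);
        System.out.printf("Intercept: %.3f, b1: %.3f, b2: %.3f%n", beta[0], beta[1], beta[2]);
    }
}
```

The explicit singularity check reflects a practical limitation: when features are perfectly collinear, XᵀX cannot be inverted, which is one motivation for the regularized variants in the next section.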
Regularization Techniques
Regularization helps prevent overfitting by penalizing large coefficients:
Ridge Regression (L2 Regularization)
public class RidgeRegressionExample {
public void ridgeRegressionDemo() {
Dataset data = DataLoader.fromCSV("data/house_prices.csv");
data = preprocessData(data);
DataSplit split = data.split(0.8);
// Ridge regression with different alpha values
double[] alphas = {0.1, 1.0, 10.0, 100.0};
for (double alpha : alphas) {
RidgeRegression model = new RidgeRegression(alpha);
model.fit(split.getTrainX(), split.getTrainY());
double[] predictions = model.predict(split.getTestX());
double rmse = Metrics.rmse(split.getTestY(), predictions);
double r2 = Metrics.r2Score(split.getTestY(), predictions);
System.out.printf("Alpha: %.1f, RMSE: %.2f, R²: %.3f%n", alpha, rmse, r2);
}
}
public void findOptimalAlpha() {
Dataset data = preprocessData(DataLoader.fromCSV("data/house_prices.csv"));
// Cross-validation for hyperparameter tuning
RidgeRegressionCV ridgeCV = new RidgeRegressionCV()
.setAlphas(0.1, 1.0, 10.0, 100.0, 1000.0)
.setCv(5); // 5-fold cross-validation
ridgeCV.fit(data.getFeatures(), data.getTargets());
System.out.println("Optimal alpha: " + ridgeCV.getBestAlpha());
System.out.println("Best CV score: " + ridgeCV.getBestScore());
}
}
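To see what the alpha parameter does, it helps to look at the one-feature case. For centered x and y, minimizing Σ(y − βx)² + αβ² gives β = Sxy / (Sxx + α), so a larger alpha pulls the slope toward zero without ever reaching it. The sketch below is plain Java with a hypothetical class name, and it assumes this unscaled form of the objective; a library may scale the penalty differently.

```java
/** Ridge shrinkage for a single centered feature, plain Java. */
final class RidgeShrinkageSketch {

    /** Slope minimizing sum((y - b*x)^2) + alpha * b^2, for centered x and y. */
    static double ridgeSlope(double[] x, double[] y, double alpha) {
        double sxy = 0.0;
        double sxx = 0.0;
        for (int i = 0; i < x.length; i++) {
            sxy += x[i] * y[i];
            sxx += x[i] * x[i];
        }
        return sxy / (sxx + alpha);
    }

    public static void main(String[] args) {
        // Centered data generated from y = 2x.
        double[] x = {-2, -1, 0, 1, 2};
        double[] y = {-4, -2, 0, 2, 4};
        for (double alpha : new double[] {0.0, 10.0, 90.0}) {
            System.out.printf("alpha=%.1f -> slope=%.3f%n", alpha, ridgeSlope(x, y, alpha));
        }
    }
}
```

On this data the slope shrinks from 2.0 (alpha = 0) to 1.0 (alpha = 10) and 0.2 (alpha = 90), which is the behavior the alpha sweep above is probing.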
Lasso Regression (L1 Regularization)
import java.util.Arrays;
public class LassoRegressionExample {
public void lassoRegressionDemo() {
Dataset data = preprocessData(DataLoader.fromCSV("data/house_prices.csv"));
DataSplit split = data.split(0.8);
// Lasso regression for feature selection
LassoRegression model = new LassoRegression(1.0); // alpha = 1.0
model.fit(split.getTrainX(), split.getTrainY());
// Check which features were selected (non-zero coefficients)
double[] coefficients = model.getCoefficients();
String[] featureNames = data.getFeatureNames().toArray(new String[0]);
System.out.println("Selected features:");
for (int i = 0; i < coefficients.length; i++) {
if (Math.abs(coefficients[i]) > 1e-6) {
System.out.printf("%-15s: %.4f%n", featureNames[i], coefficients[i]);
}
}
// Evaluation
double[] predictions = model.predict(split.getTestX());
double rmse = Metrics.rmse(split.getTestY(), predictions);
System.out.printf("Test RMSE: %.2f%n", rmse);
}
public void lassoPath() {
Dataset data = preprocessData(DataLoader.fromCSV("data/house_prices.csv"));
// Compute regularization path
LassoPath path = new LassoPath();
path.fit(data.getFeatures(), data.getTargets());
// Get coefficients for different alpha values
double[] alphas = path.getAlphas();
double[][] coefficientPaths = path.getCoefficientPaths();
System.out.println("Regularization path:");
for (int i = 0; i < alphas.length; i += 10) { // Sample every 10th alpha
System.out.printf("Alpha: %.4f, Non-zero coefficients: %d%n",
alphas[i], countNonZero(coefficientPaths[i]));
}
}
private int countNonZero(double[] coefficients) {
return (int) Arrays.stream(coefficients)
.map(Math::abs)
.filter(c -> c > 1e-6)
.count();
}
}
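Lasso can set coefficients to exactly zero because its one-feature solution is a soft-thresholding step. For centered data and the objective ½Σ(y − βx)² + α|β|, the slope is S(Sxy, α) / Sxx, where S shrinks its argument toward zero by α and clips at zero. The sketch below shows this in plain Java; the class name is hypothetical and the ½ scaling of the loss is an assumption, since libraries differ on how they normalize it.

```java
/** Lasso solution for a single centered feature via soft-thresholding. */
final class LassoSoftThresholdSketch {

    /** Shrinks z toward zero by t; returns exactly 0 when |z| <= t. */
    static double softThreshold(double z, double t) {
        if (z > t) {
            return z - t;
        }
        if (z < -t) {
            return z + t;
        }
        return 0.0;
    }

    /** Slope minimizing 0.5 * sum((y - b*x)^2) + alpha * |b|, for centered x and y. */
    static double lassoSlope(double[] x, double[] y, double alpha) {
        double sxy = 0.0;
        double sxx = 0.0;
        for (int i = 0; i < x.length; i++) {
            sxy += x[i] * y[i];
            sxx += x[i] * x[i];
        }
        return softThreshold(sxy, alpha) / sxx;
    }

    public static void main(String[] args) {
        double[] x = {-2, -1, 0, 1, 2};
        double[] y = {-4, -2, 0, 2, 4};
        for (double alpha : new double[] {0.0, 5.0, 20.0, 25.0}) {
            System.out.printf("alpha=%.1f -> slope=%.3f%n", alpha, lassoSlope(x, y, alpha));
        }
    }
}
```

Here Sxy is 20, so any alpha of 20 or more drives the slope to exactly zero. This is the mechanism behind the "selected features" check above: features whose correlation with the target is too weak to survive the threshold drop out.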
Elastic Net (L1 + L2 Regularization)
public class ElasticNetExample {
public void elasticNetDemo() {
Dataset data = preprocessData(DataLoader.fromCSV("data/house_prices.csv"));
DataSplit split = data.split(0.8);
// Elastic Net combines Ridge and Lasso
ElasticNet model = new ElasticNet()
.setAlpha(1.0) // Overall regularization strength
.setL1Ratio(0.5); // Mix between L1 (1.0) and L2 (0.0)
model.fit(split.getTrainX(), split.getTrainY());
// Cross-validation for optimal parameters
ElasticNetCV modelCV = new ElasticNetCV()
.setAlphas(0.1, 1.0, 10.0)
.setL1Ratios(0.1, 0.5, 0.7, 0.9)
.setCv(5);
modelCV.fit(split.getTrainX(), split.getTrainY());
System.out.println("Best alpha: " + modelCV.getBestAlpha());
System.out.println("Best l1_ratio: " + modelCV.getBestL1Ratio());
double[] predictions = modelCV.predict(split.getTestX());
double rmse = Metrics.rmse(split.getTestY(), predictions);
System.out.printf("Test RMSE: %.2f%n", rmse);
}
}
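The two Elastic Net parameters are easier to reason about when the penalty is written out. A common form is α · (ρ‖β‖₁ + ½(1 − ρ)‖β‖₂²), where ρ is the L1 ratio. The sketch below evaluates that penalty in plain Java; the class name is hypothetical, and the ½ factor follows the scikit-learn convention as an assumption, since this tutorial does not state SuperML's exact scaling.

```java
/** Evaluates an Elastic Net penalty for a given coefficient vector. */
final class ElasticNetPenaltySketch {

    /** alpha * (l1Ratio * ||beta||_1 + 0.5 * (1 - l1Ratio) * ||beta||_2^2). */
    static double penalty(double[] beta, double alpha, double l1Ratio) {
        double l1 = 0.0;        // sum of absolute values
        double l2Squared = 0.0; // sum of squares
        for (double b : beta) {
            l1 += Math.abs(b);
            l2Squared += b * b;
        }
        return alpha * (l1Ratio * l1 + 0.5 * (1.0 - l1Ratio) * l2Squared);
    }

    public static void main(String[] args) {
        double[] beta = {3.0, -4.0};
        for (double ratio : new double[] {0.0, 0.5, 1.0}) {
            System.out.printf("l1Ratio=%.1f -> penalty=%.2f%n", ratio, penalty(beta, 1.0, ratio));
        }
    }
}
```

With an L1 ratio of 1.0 only the absolute-value term remains (pure Lasso), and with 0.0 only the squared term remains (pure Ridge); intermediate values blend sparsity with smooth shrinkage.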
Advanced Features
Polynomial Regression
public class PolynomialRegressionExample {
public void polynomialRegressionDemo() {
// Generate sample data with non-linear relationship
double[][] X = generateNonLinearData();
double[] y = generateNonLinearTargets(X);
DataSplit split = splitData(X, y, 0.8);
// Try different polynomial degrees
int[] degrees = {1, 2, 3, 4, 5};
for (int degree : degrees) {
// Create polynomial features
PolynomialFeatures polyFeatures = new PolynomialFeatures(degree);
double[][] XPoly = polyFeatures.fitTransform(split.getTrainX());
double[][] XTestPoly = polyFeatures.transform(split.getTestX());
// Fit linear regression on polynomial features
LinearRegression model = new LinearRegression();
model.fit(XPoly, split.getTrainY());
// Evaluate
double[] predictions = model.predict(XTestPoly);
double rmse = Metrics.rmse(split.getTestY(), predictions);
double r2 = Metrics.r2Score(split.getTestY(), predictions);
System.out.printf("Degree: %d, RMSE: %.3f, R²: %.3f%n", degree, rmse, r2);
}
}
// Use with regularization to prevent overfitting
public void regularizedPolynomialRegression() {
double[][] X = generateNonLinearData();
double[] y = generateNonLinearTargets(X);
// High-degree polynomial with regularization
PolynomialFeatures polyFeatures = new PolynomialFeatures(8); // degree 8
double[][] XPoly = polyFeatures.fitTransform(X);
// Ridge regression to control overfitting
RidgeRegression model = new RidgeRegression(100.0); // alpha = 100.0
model.fit(XPoly, y);
System.out.println("High-degree polynomial with Ridge regularization trained successfully");
}
}
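Polynomial regression is still linear regression; only the inputs change. For a single feature x, the expansion replaces each row {x} with {x, x², ..., x^d}, and the linear model then learns one weight per power. The sketch below shows that expansion in plain Java under simplifying assumptions: one input column, no bias column, and no interaction terms. `PolynomialExpansionSketch` is a hypothetical name, and SuperML's `PolynomialFeatures` may produce a different column layout.

```java
import java.util.Arrays;

/** Expands a single feature into its powers 1..degree. */
final class PolynomialExpansionSketch {

    /** Maps each single-feature row {x} to {x, x^2, ..., x^degree}. */
    static double[][] expand(double[][] x, int degree) {
        double[][] out = new double[x.length][degree];
        for (int i = 0; i < x.length; i++) {
            double power = 1.0;
            for (int d = 0; d < degree; d++) {
                power *= x[i][0]; // x^(d+1)
                out[i][d] = power;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] x = {{2}, {3}};
        System.out.println(Arrays.deepToString(expand(x, 3)));
    }
}
```

High powers grow quickly (3⁸ is already 6561), which is why scaling the inputs and adding regularization, as in the Ridge variant above, usually go hand in hand with high-degree expansions.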
Custom Features and Transformations
public class CustomFeatureEngineering {
public void createCustomFeatures() {
Dataset data = DataLoader.fromCSV("data/house_prices.csv");
// Create new features
data = data.withNewColumn("price_per_sqft",
row -> row.getDouble("price") / row.getDouble("size"));
data = data.withNewColumn("bedroom_ratio",
row -> row.getDouble("bedrooms") / row.getDouble("size"));
data = data.withNewColumn("total_rooms",
row -> row.getDouble("bedrooms") + row.getDouble("bathrooms"));
// Log transformation for skewed features
data = data.withTransformedColumn("log_price", "price", Math::log);
data = data.withTransformedColumn("sqrt_size", "size", Math::sqrt);
// Interaction features
data = data.withNewColumn("size_bedrooms_interaction",
row -> row.getDouble("size") * row.getDouble("bedrooms"));
// Use engineered features in model. price_per_sqft is derived from the
// target (price), so it is left out to avoid target leakage.
String[] features = {"size", "bedrooms", "bathrooms",
"bedroom_ratio", "total_rooms", "size_bedrooms_interaction"};
Dataset modelData = data.selectColumns(features, "price");
DataSplit split = modelData.split(0.8);
LinearRegression model = new LinearRegression();
model.fit(split.getTrainX(), split.getTrainY());
double[] predictions = model.predict(split.getTestX());
double rmse = Metrics.rmse(split.getTestY(), predictions);
System.out.printf("RMSE with engineered features: %.2f%n", rmse);
}
}
Model Validation and Diagnostics
Residual Analysis
import java.util.Arrays;
public class ModelDiagnostics {
public void residualAnalysis(LinearRegression model, DataSplit split) {
double[] predictions = model.predict(split.getTestX());
double[] actual = split.getTestY();
// Calculate residuals
double[] residuals = new double[actual.length];
for (int i = 0; i < residuals.length; i++) {
residuals[i] = actual[i] - predictions[i];
}
// Residual statistics
double meanResidual = Arrays.stream(residuals).average().orElse(0.0);
double stdResidual = calculateStandardDeviation(residuals);
System.out.printf("Mean residual: %.4f (should be close to 0)%n", meanResidual);
System.out.printf("Std residual: %.4f%n", stdResidual);
// Check for heteroscedasticity
checkHomoscedasticity(predictions, residuals);
// Check for normality of residuals
checkNormality(residuals);
}
private void checkHomoscedasticity(double[] predictions, double[] residuals) {
// Simple test: correlation between |residuals| and predictions
double[] absResiduals = Arrays.stream(residuals).map(Math::abs).toArray();
double correlation = calculateCorrelation(predictions, absResiduals);
System.out.printf("Correlation |residuals| vs predictions: %.4f%n", correlation);
if (Math.abs(correlation) > 0.3) {
System.out.println("Warning: Possible heteroscedasticity detected!");
}
}
private void checkNormality(double[] residuals) {
// Simple normality check using skewness and kurtosis
double skewness = calculateSkewness(residuals);
double kurtosis = calculateKurtosis(residuals);
System.out.printf("Residual skewness: %.4f (should be close to 0)%n", skewness);
System.out.printf("Residual kurtosis: %.4f (should be close to 3)%n", kurtosis);
}
}
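The diagnostics above call four statistical helpers that are not shown. A minimal plain-Java sketch of them follows, under stated assumptions: population (divide-by-n) moments, Pearson correlation, and non-excess kurtosis so that a normal distribution scores about 3, which matches the message printed above. The class name `ResidualStatsSketch` is hypothetical; the method names mirror the calls in `ModelDiagnostics`.

```java
/** Moment-based statistics used for residual diagnostics, plain Java. */
final class ResidualStatsSketch {

    static double mean(double[] v) {
        double sum = 0.0;
        for (double x : v) {
            sum += x;
        }
        return sum / v.length;
    }

    /** Population central moment of the given order: mean((x - mean)^order). */
    static double centralMoment(double[] v, int order) {
        double m = mean(v);
        double sum = 0.0;
        for (double x : v) {
            sum += Math.pow(x - m, order);
        }
        return sum / v.length;
    }

    static double calculateStandardDeviation(double[] v) {
        return Math.sqrt(centralMoment(v, 2));
    }

    /** Skewness m3 / m2^1.5; 0 for a symmetric sample. */
    static double calculateSkewness(double[] v) {
        return centralMoment(v, 3) / Math.pow(centralMoment(v, 2), 1.5);
    }

    /** Non-excess kurtosis m4 / m2^2; about 3 for normally distributed data. */
    static double calculateKurtosis(double[] v) {
        double m2 = centralMoment(v, 2);
        return centralMoment(v, 4) / (m2 * m2);
    }

    /** Pearson correlation coefficient between two equal-length arrays. */
    static double calculateCorrelation(double[] a, double[] b) {
        double meanA = mean(a);
        double meanB = mean(b);
        double cov = 0.0;
        double varA = 0.0;
        double varB = 0.0;
        for (int i = 0; i < a.length; i++) {
            cov += (a[i] - meanA) * (b[i] - meanB);
            varA += (a[i] - meanA) * (a[i] - meanA);
            varB += (b[i] - meanB) * (b[i] - meanB);
        }
        return cov / Math.sqrt(varA * varB);
    }

    public static void main(String[] args) {
        double[] sample = {1, 2, 3, 4, 5};
        System.out.printf("std=%.4f skew=%.4f kurtosis=%.4f%n",
                calculateStandardDeviation(sample),
                calculateSkewness(sample),
                calculateKurtosis(sample));
    }
}
```

These moment checks are rough screens rather than formal tests; with small test sets, skewness and kurtosis estimates are noisy, so treat borderline values with caution.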
Cross-Validation
import java.util.Arrays;
public class CrossValidationExample {
public void crossValidateModel() {
Dataset data = preprocessData(DataLoader.fromCSV("data/house_prices.csv"));
// K-Fold Cross-Validation
KFoldValidator validator = new KFoldValidator(5); // k = 5 folds
CrossValidationResult result = validator.validate(
new LinearRegression(),
data.getFeatures(),
data.getTargets(),
Metrics.RMSE
);
System.out.println("Cross-Validation Results:");
System.out.printf("Mean RMSE: %.2f (±%.2f)%n",
result.getMean(), result.getStd());
System.out.printf("Individual fold scores: %s%n",
Arrays.toString(result.getScores()));
// Time Series Cross-Validation (if data has time component)
if (data.hasColumn("date")) {
TimeSeriesValidator tsValidator = new TimeSeriesValidator();
CrossValidationResult tsResult = tsValidator.validate(
new LinearRegression(),
data.getFeatures(),
data.getTargets(),
data.getColumn("date"),
Metrics.RMSE
);
System.out.printf("Time Series CV Mean RMSE: %.2f%n", tsResult.getMean());
}
}
}
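Under the hood, k-fold cross-validation is mostly index bookkeeping: the rows are partitioned into k test folds, and each fold is held out once while the model trains on the rest. The sketch below shows one simple partitioning scheme in plain Java, using contiguous, unshuffled folds. `KFoldSketch` is a hypothetical name, and real validators (SuperML's `KFoldValidator` possibly included) commonly shuffle the rows first.

```java
import java.util.Arrays;

/** Splits row indices 0..n-1 into k contiguous test folds. */
final class KFoldSketch {

    /** Returns k arrays of test indices; a fold's training set is every other index. */
    static int[][] testFolds(int n, int k) {
        if (k < 2 || k > n) {
            throw new IllegalArgumentException("k must be between 2 and n");
        }
        int[][] folds = new int[k][];
        int base = n / k;  // minimum fold size
        int extra = n % k; // the first 'extra' folds get one more row
        int start = 0;
        for (int f = 0; f < k; f++) {
            int size = base + (f < extra ? 1 : 0);
            folds[f] = new int[size];
            for (int j = 0; j < size; j++) {
                folds[f][j] = start + j;
            }
            start += size;
        }
        return folds;
    }

    public static void main(String[] args) {
        int[][] folds = testFolds(10, 3);
        for (int f = 0; f < folds.length; f++) {
            System.out.println("Fold " + f + " test indices: " + Arrays.toString(folds[f]));
        }
    }
}
```

For 10 rows and 3 folds this yields test folds of sizes 4, 3, and 3. The time-series variant above differs in one key respect: training rows must always precede the test rows in time, so folds cannot be shuffled.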
Complete Real-World Example
Let’s put everything together with a comprehensive example:
import java.util.HashMap;
import java.util.Map;
public class ComprehensiveHousePricePrediction {
public static void main(String[] args) {
ComprehensiveHousePricePrediction predictor =
new ComprehensiveHousePricePrediction();
predictor.runCompleteAnalysis();
}
public void runCompleteAnalysis() {
// Load and explore data
Dataset data = DataLoader.fromCSV("data/house_prices_complete.csv");
exploreData(data);
// Preprocess data
data = preprocessData(data);
// Feature engineering
data = engineerFeatures(data);
// Model selection and training
Map<String, Double> modelResults = compareModels(data);
// Select best model and final training
String bestModel = getBestModel(modelResults);
LinearRegression finalModel = trainFinalModel(data, bestModel);
// Model interpretation and validation
interpretFinalModel(finalModel, data);
// Save model for production use
saveModel(finalModel, "house_price_model.pkl");
}
private void exploreData(Dataset data) {
System.out.println("=== Data Exploration ===");
System.out.println("Dataset shape: " + data.getShape());
System.out.println("Missing values: " + data.getMissingCounts());
// Basic statistics
DataSummary summary = data.describe();
System.out.println(summary);
// Correlation with target
Map<String, Double> correlations = data.getCorrelationWithTarget("price");
correlations.entrySet().stream()
.sorted(Map.Entry.<String, Double>comparingByValue().reversed())
.limit(10)
.forEach(entry ->
System.out.printf("%-15s: %.3f%n", entry.getKey(), entry.getValue()));
}
private Dataset preprocessData(Dataset data) {
System.out.println("\n=== Data Preprocessing ===");
// Handle missing values
data = data.fillMissing(Map.of(
"age", Strategy.MEDIAN,
"garage", 0.0,
"basement", 0.0
));
// Remove outliers
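OutlierDetector detector = new IQRDetector(1.5); // IQR multiplier
boolean[] outliers = detector.detect(data, "price");
// Arrays.stream has no boolean[] overload, so count the flags with a loop
int outliersRemoved = 0;
for (boolean isOutlier : outliers) {
if (isOutlier) outliersRemoved++;
}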
System.out.println("Outliers removed: " + outliersRemoved);
data = data.removeOutliers(outliers);
return data;
}
private Dataset engineerFeatures(Dataset data) {
System.out.println("\n=== Feature Engineering ===");
// Create new features. Features derived from the target (such as price
// per square foot) are deliberately avoided: they leak the answer into
// the inputs and are unavailable at prediction time.
data = data.withNewColumn("rooms_total",
row -> row.getDouble("bedrooms") + row.getDouble("bathrooms"));
data = data.withNewColumn("age_group", row -> {
double age = row.getDouble("age");
if (age < 5) return "new";
else if (age < 20) return "modern";
else if (age < 50) return "mature";
else return "old";
});
// One-hot encode categorical features
OneHotEncoder encoder = new OneHotEncoder();
data = encoder.fitTransform(data, "age_group", "neighborhood");
// Scale numerical features
StandardScaler scaler = new StandardScaler();
String[] numericalFeatures = {"size", "bedrooms", "bathrooms", "age",
"garage", "basement", "rooms_total"};
data = scaler.fitTransform(data, numericalFeatures);
System.out.println("Final feature count: " + data.getFeatureNames().size());
return data;
}
private Map<String, Double> compareModels(Dataset data) {
System.out.println("\n=== Model Comparison ===");
DataSplit split = data.split(0.8);
Map<String, Double> results = new HashMap<>();
// Linear Regression
LinearRegression lr = new LinearRegression();
lr.fit(split.getTrainX(), split.getTrainY());
double[] lrPreds = lr.predict(split.getTestX());
results.put("LinearRegression", Metrics.rmse(split.getTestY(), lrPreds));
// Ridge Regression
RidgeRegressionCV ridge = new RidgeRegressionCV()
.setAlphas(0.1, 1.0, 10.0, 100.0)
.setCv(5);
ridge.fit(split.getTrainX(), split.getTrainY());
double[] ridgePreds = ridge.predict(split.getTestX());
results.put("Ridge", Metrics.rmse(split.getTestY(), ridgePreds));
// Lasso Regression
LassoRegressionCV lasso = new LassoRegressionCV()
.setAlphas(0.1, 1.0, 10.0, 100.0)
.setCv(5);
lasso.fit(split.getTrainX(), split.getTrainY());
double[] lassoPreds = lasso.predict(split.getTestX());
results.put("Lasso", Metrics.rmse(split.getTestY(), lassoPreds));
// Elastic Net
ElasticNetCV elastic = new ElasticNetCV()
.setAlphas(0.1, 1.0, 10.0)
.setL1Ratios(0.1, 0.5, 0.9)
.setCv(5);
elastic.fit(split.getTrainX(), split.getTrainY());
double[] elasticPreds = elastic.predict(split.getTestX());
results.put("ElasticNet", Metrics.rmse(split.getTestY(), elasticPreds));
// Display results
results.entrySet().stream()
.sorted(Map.Entry.comparingByValue())
.forEach(entry ->
System.out.printf("%-15s: RMSE = %.2f%n", entry.getKey(), entry.getValue()));
return results;
}
private String getBestModel(Map<String, Double> results) {
return results.entrySet().stream()
.min(Map.Entry.comparingByValue())
.map(Map.Entry::getKey)
.orElse("LinearRegression");
}
private LinearRegression trainFinalModel(Dataset data, String modelType) {
System.out.println("\n=== Final Model Training ===");
System.out.println("Best model: " + modelType);
// Train on full dataset
LinearRegression finalModel;
switch (modelType) {
case "Ridge":
finalModel = new RidgeRegressionCV().setAlphas(0.1, 1.0, 10.0, 100.0);
break;
case "Lasso":
finalModel = new LassoRegressionCV().setAlphas(0.1, 1.0, 10.0, 100.0);
break;
case "ElasticNet":
finalModel = new ElasticNetCV().setAlphas(0.1, 1.0, 10.0);
break;
default:
finalModel = new LinearRegression();
}
finalModel.fit(data.getFeatures(), data.getTargets());
return finalModel;
}
private void interpretFinalModel(LinearRegression model, Dataset data) {
System.out.println("\n=== Model Interpretation ===");
double[] coefficients = model.getCoefficients();
String[] featureNames = data.getFeatureNames().toArray(new String[0]);
// Top 10 most important features
Map<String, Double> importance = new HashMap<>();
for (int i = 0; i < featureNames.length; i++) {
importance.put(featureNames[i], Math.abs(coefficients[i]));
}
System.out.println("Top 10 most important features:");
importance.entrySet().stream()
.sorted(Map.Entry.<String, Double>comparingByValue().reversed())
.limit(10)
.forEach(entry ->
System.out.printf("%-20s: %.4f%n", entry.getKey(), entry.getValue()));
}
private void saveModel(LinearRegression model, String filename) {
try {
ModelSerializer.save(model, filename);
System.out.println("\nModel saved to: " + filename);
} catch (Exception e) {
System.err.println("Failed to save model: " + e.getMessage());
}
}
}
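One detail worth noting in the pipeline above is that features are scaled on the whole dataset before the train/test split, so test-set statistics influence the training inputs. A common refinement is to fit the scaling statistics on the training rows only and reuse them for the test rows. The sketch below illustrates that pattern for one column in plain Java; `TrainOnlyScalerSketch` is a hypothetical class using population standard deviation, not SuperML's `StandardScaler`.

```java
/** Z-score scaler whose statistics come from the training data only. */
final class TrainOnlyScalerSketch {
    private double mean;
    private double std = 1.0;

    /** Learns mean and population standard deviation from the training column. */
    void fit(double[] train) {
        double sum = 0.0;
        for (double v : train) {
            sum += v;
        }
        mean = sum / train.length;

        double squared = 0.0;
        for (double v : train) {
            squared += (v - mean) * (v - mean);
        }
        double computed = Math.sqrt(squared / train.length);
        std = computed > 0.0 ? computed : 1.0; // guard against constant columns
    }

    /** Applies the learned statistics; safe to call on train or test values. */
    double[] transform(double[] values) {
        double[] out = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            out[i] = (values[i] - mean) / std;
        }
        return out;
    }

    public static void main(String[] args) {
        TrainOnlyScalerSketch scaler = new TrainOnlyScalerSketch();
        scaler.fit(new double[] {1, 2, 3});                      // training rows only
        double[] scaledTest = scaler.transform(new double[] {2, 4});
        System.out.printf("Scaled test values: %.4f, %.4f%n", scaledTest[0], scaledTest[1]);
    }
}
```

The same fitted scaler must also be applied at prediction time, which is why the serving example below loads a saved preprocessing pipeline alongside the model.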
Production Deployment
Model Serving Example
@RestController
public class HousePriceController {
private final LinearRegression model;
private final Pipeline preprocessor;
public HousePriceController() {
this.model = ModelSerializer.load("house_price_model.pkl");
this.preprocessor = Pipeline.load("preprocessing_pipeline.pkl");
}
@PostMapping("/predict")
public PredictionResponse predict(@RequestBody HouseFeatures features) {
try {
// Convert to dataset
Dataset input = Dataset.fromMap(features.toMap());
// Preprocess
Dataset processed = preprocessor.transform(input);
// Predict
double prediction = model.predict(processed.getFeatures())[0];
// Calculate confidence interval (optional)
double[] interval = model.predictWithConfidence(processed.getFeatures())[0];
return new PredictionResponse(prediction, interval[0], interval[1]);
} catch (Exception e) {
throw new PredictionException("Prediction failed: " + e.getMessage());
}
}
}
Summary
In this comprehensive tutorial, we covered:
- Basic linear regression concepts and implementation
- Multiple linear regression with real datasets
- Regularization techniques (Ridge, Lasso, Elastic Net)
- Polynomial regression for non-linear relationships
- Feature engineering and custom transformations
- Model validation and diagnostics
- Cross-validation for robust evaluation
- Complete end-to-end machine learning pipeline
- Production deployment considerations
Linear regression is a powerful tool for understanding relationships between variables and making predictions. While simple, it forms the foundation for many advanced machine learning techniques. The key to success with linear regression is proper data preprocessing, feature engineering, and model validation.
In the next tutorial, we’ll explore classification problems using logistic regression, which extends these concepts to categorical predictions.