📖 Lesson ⏱️ 75 minutes

Linear Regression in Java

Implementing linear regression using SuperML Java

Linear Regression with SuperML Java

Linear regression is one of the fundamental algorithms in machine learning and often the first algorithm students learn. It's simple to understand, implement, and interpret, making it an excellent starting point for your machine learning journey with SuperML Java.

What is Linear Regression?

Linear regression models the relationship between a dependent variable (target) and one or more independent variables (features) by fitting a linear equation to observed data. The goal is to find the line (or, with several features, the hyperplane) that minimizes the sum of squared differences between predicted and actual values.

Mathematical Foundation

For simple linear regression:

y = β₀ + β₁x + ε

For multiple linear regression:

y = β₀ + β₁x₁ + β₂x₂ + ... + βₙxₙ + ε

Where:

  • y is the target variable
  • x₁, x₂, ..., xₙ are feature variables
  • β₀ is the intercept (bias)
  • β₁, β₂, ..., βₙ are coefficients (weights)
  • ε is the error term
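
The equations above say what the model looks like, but not how the coefficients are chosen. A common choice is ordinary least squares, sketched below. The symbols m (number of observations) and X (the m × (n+1) feature matrix with a leading column of ones) are notation introduced here for illustration. The closed form holds only when XᵀX is invertible, and whether a given library solves the problem this way or iteratively is an implementation detail.

```latex
\hat{\beta}
  = \arg\min_{\beta} \sum_{i=1}^{m}
    \Bigl( y_i - \beta_0 - \sum_{j=1}^{n} \beta_j x_{ij} \Bigr)^{2}
\qquad\Longrightarrow\qquad
\hat{\beta} = (X^{\top} X)^{-1} X^{\top} y
```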

Simple Linear Regression

Let's start with a basic example using house size to predict price:

import org.superml.linear_model.LinearRegression;
import org.superml.datasets.Datasets;
import org.superml.metrics.Metrics;

public class SimpleLinearRegressionExample {
    
    public static void main(String[] args) {
        // Create sample data: house size vs price
        double[][] X = {{800}, {1000}, {1200}, {1400}, {1600}, {1800}, {2000}};
        double[] y = {150000, 180000, 210000, 240000, 270000, 300000, 330000};
        
        // Create and train the model
        LinearRegression model = new LinearRegression();
        model.fit(X, y);
        
        // Make predictions
        double[][] testX = {{900}, {1500}, {2200}};
        double[] predictions = model.predict(testX);
        
        // Display results
        System.out.println("Model coefficients:");
        System.out.println("Intercept: " + model.getIntercept());
        System.out.println("Coefficient: " + model.getCoefficients()[0]);
        
        System.out.println("\nPredictions:");
        for (int i = 0; i < testX.length; i++) {
            System.out.printf("House size: %.0f sq ft -> Predicted price: $%.2f%n", 
                testX[i][0], predictions[i]);
        }
    }
}
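
To see what fitting does for a single feature, here is a minimal sketch of the closed-form least-squares solution in plain Java. It is not SuperML's implementation. The class and method names (SimpleOlsSketch, slope, intercept) are made up for illustration.

```java
// Sketch only: closed-form least squares for one feature, no library needed.
class SimpleOlsSketch {

    /** Slope = sum((x - meanX) * (y - meanY)) / sum((x - meanX)^2). */
    static double slope(double[] x, double[] y) {
        double meanX = mean(x);
        double meanY = mean(y);
        double sxy = 0.0;
        double sxx = 0.0;
        for (int i = 0; i < x.length; i++) {
            sxy += (x[i] - meanX) * (y[i] - meanY);
            sxx += (x[i] - meanX) * (x[i] - meanX);
        }
        return sxy / sxx;
    }

    /** Intercept = meanY - slope * meanX, so the line passes through the point of means. */
    static double intercept(double[] x, double[] y) {
        return mean(y) - slope(x, y) * mean(x);
    }

    static double mean(double[] values) {
        double sum = 0.0;
        for (double v : values) {
            sum += v;
        }
        return sum / values.length;
    }

    public static void main(String[] args) {
        // Same house size / price data as the example above
        double[] size = {800, 1000, 1200, 1400, 1600, 1800, 2000};
        double[] price = {150000, 180000, 210000, 240000, 270000, 300000, 330000};

        double b1 = slope(size, price);
        double b0 = intercept(size, price);
        System.out.printf("Intercept: %.2f, slope: %.2f%n", b0, b1);
        System.out.printf("Predicted price for 1500 sq ft: %.2f%n", b0 + b1 * 1500);
    }
}
```

The sample data lies exactly on a line (every extra 200 sq ft adds 30,000), so the fit recovers a slope of 150 and an intercept of 30,000 with zero residuals. Real data will not be this tidy.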

Loading Real Data

Let's work with a more realistic dataset:

public class HousePricePrediction {
    
    public void predictHousePrices() {
        // Load housing dataset
        Dataset data = DataLoader.fromCSV("data/house_prices.csv");
        
        // Display basic information
        System.out.println("Dataset shape: " + data.getShape());
        System.out.println("Features: " + data.getFeatureNames());
        System.out.println(data.describe());
        
        // Select features for simple regression
        Dataset subset = data.selectColumns("size", "price");
        
        // Split into features and target
        double[][] X = subset.getFeatures("size");
        double[] y = subset.getTarget("price");
        
        // Split into train/test sets
        DataSplit split = subset.split(0.8);
        
        // Train the model
        LinearRegression model = new LinearRegression();
        model.fit(split.getTrainX(), split.getTrainY());
        
        // Evaluate on test set
        double[] predictions = model.predict(split.getTestX());
        double rmse = Metrics.rmse(split.getTestY(), predictions);
        double r2 = Metrics.r2Score(split.getTestY(), predictions);
        
        System.out.printf("Test RMSE: %.2f%n", rmse);
        System.out.printf("Test R²: %.3f%n", r2);
    }
}
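
The two metrics used above are easy to compute by hand, which helps when interpreting them. The sketch below uses the standard definitions of RMSE and R². The class name RegressionMetricsSketch is hypothetical, and the code is independent of the Metrics class in the listing.

```java
// Sketch only: RMSE and R² from their textbook definitions.
class RegressionMetricsSketch {

    /** Root mean squared error: sqrt(mean((actual - predicted)^2)). */
    static double rmse(double[] actual, double[] predicted) {
        double sumSquared = 0.0;
        for (int i = 0; i < actual.length; i++) {
            double error = actual[i] - predicted[i];
            sumSquared += error * error;
        }
        return Math.sqrt(sumSquared / actual.length);
    }

    /** R² = 1 - SS_res / SS_tot, where SS_tot is measured around the mean of the actual values. */
    static double r2(double[] actual, double[] predicted) {
        double mean = 0.0;
        for (double v : actual) {
            mean += v;
        }
        mean /= actual.length;

        double ssRes = 0.0;
        double ssTot = 0.0;
        for (int i = 0; i < actual.length; i++) {
            ssRes += (actual[i] - predicted[i]) * (actual[i] - predicted[i]);
            ssTot += (actual[i] - mean) * (actual[i] - mean);
        }
        return 1.0 - ssRes / ssTot;
    }

    public static void main(String[] args) {
        double[] actual = {3.0, -0.5, 2.0, 7.0};
        double[] predicted = {2.5, 0.0, 2.0, 8.0};
        System.out.printf("RMSE: %.4f%n", rmse(actual, predicted));
        System.out.printf("R²: %.4f%n", r2(actual, predicted));
    }
}
```

RMSE is in the same units as the target (dollars, for house prices), while R² is unitless and compares the model against always predicting the mean.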

Multiple Linear Regression

Real-world problems usually involve multiple features:

import java.util.HashMap;
import java.util.Map;

public class MultipleLinearRegressionExample {
    
    public void multipleFeatureExample() {
        // Load dataset with multiple features
        Dataset data = DataLoader.fromCSV("data/house_prices_complete.csv");
        
        // Preprocessing
        data = preprocessData(data);
        
        // Select relevant features
        String[] features = {"size", "bedrooms", "bathrooms", "age", "garage"};
        Dataset modelData = data.selectColumns(features, "price");
        
        // Split data
        DataSplit split = modelData.split(0.8);
        
        // Create and configure model
        LinearRegression model = new LinearRegression()
            .setFitIntercept(true)
            .setNormalize(false);  // Already preprocessed
        
        // Train model
        model.fit(split.getTrainX(), split.getTrainY());
        
        // Model interpretation
        interpretModel(model, features);
        
        // Evaluation
        evaluateModel(model, split);
    }
    
    private Dataset preprocessData(Dataset data) {
        // Handle missing values
        data = data.fillMissing(Strategy.MEAN);
        
        // Remove outliers (optional)
        OutlierDetector detector = new IQRDetector(1.5);  // IQR multiplier (Java has no named arguments)
        boolean[] outliers = detector.detect(data, "price");
        data = data.removeOutliers(outliers);
        
        // Scale features
        StandardScaler scaler = new StandardScaler();
        String[] numericFeatures = {"size", "bedrooms", "bathrooms", "age", "garage"};
        data = scaler.fitTransform(data, numericFeatures);
        
        return data;
    }
    
    private void interpretModel(LinearRegression model, String[] features) {
        System.out.println("Model Coefficients:");
        System.out.printf("Intercept: %.2f%n", model.getIntercept());
        
        double[] coefficients = model.getCoefficients();
        for (int i = 0; i < features.length; i++) {
            System.out.printf("%-12s: %.4f%n", features[i], coefficients[i]);
        }
        
        // Feature importance (absolute coefficient values).
        // Magnitudes are comparable only because the features were standardized.
        System.out.println("\nFeature Importance (|coefficient|):");
        Map<String, Double> importance = new HashMap<>();
        for (int i = 0; i < features.length; i++) {
            importance.put(features[i], Math.abs(coefficients[i]));
        }
        
        importance.entrySet().stream()
            .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
            .forEach(entry -> 
                System.out.printf("%-12s: %.4f%n", entry.getKey(), entry.getValue()));
    }
    
    private void evaluateModel(LinearRegression model, DataSplit split) {
        // Training metrics
        double[] trainPreds = model.predict(split.getTrainX());
        double trainRmse = Metrics.rmse(split.getTrainY(), trainPreds);
        double trainR2 = Metrics.r2Score(split.getTrainY(), trainPreds);
        
        // Test metrics
        double[] testPreds = model.predict(split.getTestX());
        double testRmse = Metrics.rmse(split.getTestY(), testPreds);
        double testR2 = Metrics.r2Score(split.getTestY(), testPreds);
        
        System.out.println("\nModel Performance:");
        System.out.printf("Training RMSE: %.2f, R²: %.3f%n", trainRmse, trainR2);
        System.out.printf("Test RMSE: %.2f, R²: %.3f%n", testRmse, testR2);
        
        // Check for overfitting
        if (trainR2 - testR2 > 0.1) {
            System.out.println("Warning: Possible overfitting detected!");
        }
    }
}
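
The outlier step in preprocessData relies on the interquartile-range rule. As a rough sketch of what such a detector might do, the code below flags values outside [Q1 − k·IQR, Q3 + k·IQR]. The class IqrOutlierSketch is hypothetical. The quantile uses linear interpolation between order statistics, which is one of several common definitions, so a library's detector may flag slightly different points.

```java
import java.util.Arrays;

// Sketch only: IQR-based outlier flags for a single numeric column.
class IqrOutlierSketch {

    /** Quantile by linear interpolation; expects an already sorted array and q in [0, 1]. */
    static double quantile(double[] sorted, double q) {
        double position = q * (sorted.length - 1);
        int lower = (int) Math.floor(position);
        int upper = (int) Math.ceil(position);
        double fraction = position - lower;
        return sorted[lower] + fraction * (sorted[upper] - sorted[lower]);
    }

    /** Returns true for each value outside [Q1 - k*IQR, Q3 + k*IQR]. */
    static boolean[] detect(double[] values, double multiplier) {
        double[] sorted = values.clone();
        Arrays.sort(sorted);
        double q1 = quantile(sorted, 0.25);
        double q3 = quantile(sorted, 0.75);
        double iqr = q3 - q1;
        double low = q1 - multiplier * iqr;
        double high = q3 + multiplier * iqr;

        boolean[] flags = new boolean[values.length];
        for (int i = 0; i < values.length; i++) {
            flags[i] = values[i] < low || values[i] > high;
        }
        return flags;
    }

    public static void main(String[] args) {
        double[] prices = {1, 2, 3, 4, 5, 6, 7, 8, 100};
        System.out.println(Arrays.toString(detect(prices, 1.5)));
    }
}
```

With these nine values, Q1 is 3 and Q3 is 7, so the bounds are −3 and 13 and only the value 100 is flagged.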

Regularization Techniques

Regularization helps prevent overfitting by penalizing large coefficients:

Ridge Regression (L2 Regularization)

public class RidgeRegressionExample {
    
    public void ridgeRegressionDemo() {
        Dataset data = DataLoader.fromCSV("data/house_prices.csv");
        data = preprocessData(data);
        DataSplit split = data.split(0.8);
        
        // Ridge regression with different alpha values
        double[] alphas = {0.1, 1.0, 10.0, 100.0};
        
        for (double alpha : alphas) {
            RidgeRegression model = new RidgeRegression(alpha);
            model.fit(split.getTrainX(), split.getTrainY());
            
            double[] predictions = model.predict(split.getTestX());
            double rmse = Metrics.rmse(split.getTestY(), predictions);
            double r2 = Metrics.r2Score(split.getTestY(), predictions);
            
            System.out.printf("Alpha: %.1f, RMSE: %.2f, R²: %.3f%n", alpha, rmse, r2);
        }
    }
    
    public void findOptimalAlpha() {
        Dataset data = preprocessData(DataLoader.fromCSV("data/house_prices.csv"));
        
        // Cross-validation for hyperparameter tuning
        RidgeRegressionCV ridgeCV = new RidgeRegressionCV()
            .setAlphas(0.1, 1.0, 10.0, 100.0, 1000.0)
            .setCv(5);  // 5-fold cross-validation
        
        ridgeCV.fit(data.getFeatures(), data.getTargets());
        
        System.out.println("Optimal alpha: " + ridgeCV.getBestAlpha());
        System.out.println("Best CV score: " + ridgeCV.getBestScore());
    }
}
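
To make the effect of alpha concrete, here is a minimal one-feature sketch. It assumes the objective sum((y − b0 − b1·x)²) + alpha·b1² with an unpenalized intercept. That scaling convention is an assumption and may differ from the library's. Under it, the slope has the closed form Sxy / (Sxx + alpha). The class name RidgeShrinkageSketch is hypothetical.

```java
// Sketch only: how the ridge penalty shrinks a single slope toward zero.
class RidgeShrinkageSketch {

    /** Ridge slope for one feature: Sxy / (Sxx + alpha), computed on centered data. */
    static double ridgeSlope(double[] x, double[] y, double alpha) {
        double meanX = mean(x);
        double meanY = mean(y);
        double sxy = 0.0;
        double sxx = 0.0;
        for (int i = 0; i < x.length; i++) {
            sxy += (x[i] - meanX) * (y[i] - meanY);
            sxx += (x[i] - meanX) * (x[i] - meanX);
        }
        return sxy / (sxx + alpha);
    }

    static double mean(double[] values) {
        double sum = 0.0;
        for (double v : values) {
            sum += v;
        }
        return sum / values.length;
    }

    public static void main(String[] args) {
        double[] x = {1, 2, 3, 4, 5};
        double[] y = {2, 4, 6, 8, 10};  // exactly y = 2x
        for (double alpha : new double[] {0.0, 10.0, 90.0}) {
            System.out.printf("alpha = %.1f -> slope = %.3f%n", alpha, ridgeSlope(x, y, alpha));
        }
    }
}
```

With alpha = 0 the slope is the ordinary least-squares value of 2. Larger alphas pull it toward zero without ever reaching it, which is why Ridge shrinks coefficients but does not remove features.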

Lasso Regression (L1 Regularization)

import java.util.Arrays;

public class LassoRegressionExample {
    
    public void lassoRegressionDemo() {
        Dataset data = preprocessData(DataLoader.fromCSV("data/house_prices.csv"));
        DataSplit split = data.split(0.8);
        
        // Lasso regression for feature selection
        LassoRegression model = new LassoRegression(1.0);  // alpha = 1.0
        model.fit(split.getTrainX(), split.getTrainY());
        
        // Check which features were selected (non-zero coefficients)
        double[] coefficients = model.getCoefficients();
        String[] featureNames = data.getFeatureNames().toArray(new String[0]);
        
        System.out.println("Selected features:");
        for (int i = 0; i < coefficients.length; i++) {
            if (Math.abs(coefficients[i]) > 1e-6) {
                System.out.printf("%-15s: %.4f%n", featureNames[i], coefficients[i]);
            }
        }
        
        // Evaluation
        double[] predictions = model.predict(split.getTestX());
        double rmse = Metrics.rmse(split.getTestY(), predictions);
        System.out.printf("Test RMSE: %.2f%n", rmse);
    }
    
    public void lassoPath() {
        Dataset data = preprocessData(DataLoader.fromCSV("data/house_prices.csv"));
        
        // Compute regularization path
        LassoPath path = new LassoPath();
        path.fit(data.getFeatures(), data.getTargets());
        
        // Get coefficients for different alpha values
        double[] alphas = path.getAlphas();
        double[][] coefficientPaths = path.getCoefficientPaths();
        
        System.out.println("Regularization path:");
        for (int i = 0; i < alphas.length; i += 10) {  // Sample every 10th alpha
            System.out.printf("Alpha: %.4f, Non-zero coefficients: %d%n", 
                alphas[i], countNonZero(coefficientPaths[i]));
        }
    }
    
    private int countNonZero(double[] coefficients) {
        return (int) Arrays.stream(coefficients)
            .map(Math::abs)
            .filter(c -> c > 1e-6)
            .count();
    }
}
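
Lasso can set coefficients exactly to zero because of the soft-thresholding operator that appears in its solution. The sketch below shows the operator and the resulting one-feature slope. It assumes the objective 0.5·sum((y − b0 − b1·x)²) + alpha·|b1|. Libraries often scale the loss differently, so treat the alpha values as illustrative. The class name SoftThresholdSketch is hypothetical.

```java
// Sketch only: soft-thresholding, the mechanism behind lasso's exact zeros.
class SoftThresholdSketch {

    /** Shrinks z toward zero by threshold, and returns exactly 0 when |z| <= threshold. */
    static double softThreshold(double z, double threshold) {
        if (z > threshold) {
            return z - threshold;
        }
        if (z < -threshold) {
            return z + threshold;
        }
        return 0.0;
    }

    /** One-feature lasso slope: softThreshold(Sxy, alpha) / Sxx on centered data. */
    static double lassoSlope(double[] x, double[] y, double alpha) {
        double meanX = 0.0;
        double meanY = 0.0;
        for (int i = 0; i < x.length; i++) {
            meanX += x[i];
            meanY += y[i];
        }
        meanX /= x.length;
        meanY /= y.length;

        double sxy = 0.0;
        double sxx = 0.0;
        for (int i = 0; i < x.length; i++) {
            sxy += (x[i] - meanX) * (y[i] - meanY);
            sxx += (x[i] - meanX) * (x[i] - meanX);
        }
        return softThreshold(sxy, alpha) / sxx;
    }

    public static void main(String[] args) {
        double[] x = {1, 2, 3, 4, 5};
        double[] y = {2, 4, 6, 8, 10};
        for (double alpha : new double[] {0.0, 5.0, 25.0}) {
            System.out.printf("alpha = %.1f -> slope = %.3f%n", alpha, lassoSlope(x, y, alpha));
        }
    }
}
```

Here Sxy is 20, so any alpha of 20 or more drives the slope to exactly zero. This is the behavior the "selected features" loop above depends on.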

Elastic Net (L1 + L2 Regularization)

public class ElasticNetExample {
    
    public void elasticNetDemo() {
        Dataset data = preprocessData(DataLoader.fromCSV("data/house_prices.csv"));
        DataSplit split = data.split(0.8);
        
        // Elastic Net combines Ridge and Lasso
        ElasticNet model = new ElasticNet()
            .setAlpha(1.0)          // Overall regularization strength
            .setL1Ratio(0.5);       // Mix between L1 (1.0) and L2 (0.0)
        
        model.fit(split.getTrainX(), split.getTrainY());
        
        // Cross-validation for optimal parameters
        ElasticNetCV modelCV = new ElasticNetCV()
            .setAlphas(0.1, 1.0, 10.0)
            .setL1Ratios(0.1, 0.5, 0.7, 0.9)
            .setCv(5);
        
        modelCV.fit(split.getTrainX(), split.getTrainY());
        
        System.out.println("Best alpha: " + modelCV.getBestAlpha());
        System.out.println("Best l1_ratio: " + modelCV.getBestL1Ratio());
        
        double[] predictions = modelCV.predict(split.getTestX());
        double rmse = Metrics.rmse(split.getTestY(), predictions);
        System.out.printf("Test RMSE: %.2f%n", rmse);
    }
}
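
The two parameters are easier to reason about once the penalty is written out. The sketch below computes alpha · (l1Ratio · ‖β‖₁ + 0.5 · (1 − l1Ratio) · ‖β‖₂²), the parameterization popularized by scikit-learn. Whether SuperML Java uses exactly this form is an assumption, so check the library documentation before relying on it. The class name ElasticNetPenaltySketch is hypothetical.

```java
// Sketch only: the elastic net penalty term under a scikit-learn-style parameterization.
class ElasticNetPenaltySketch {

    /** alpha * (l1Ratio * sum|b| + 0.5 * (1 - l1Ratio) * sum(b^2)). */
    static double penalty(double[] coefficients, double alpha, double l1Ratio) {
        double l1 = 0.0;
        double l2Squared = 0.0;
        for (double c : coefficients) {
            l1 += Math.abs(c);
            l2Squared += c * c;
        }
        return alpha * (l1Ratio * l1 + 0.5 * (1.0 - l1Ratio) * l2Squared);
    }

    public static void main(String[] args) {
        double[] coefficients = {3.0, -4.0};
        System.out.printf("l1Ratio = 1.0 (pure L1): %.2f%n", penalty(coefficients, 1.0, 1.0));
        System.out.printf("l1Ratio = 0.5 (mixed):   %.2f%n", penalty(coefficients, 1.0, 0.5));
        System.out.printf("l1Ratio = 0.0 (pure L2): %.2f%n", penalty(coefficients, 1.0, 0.0));
    }
}
```

At l1Ratio = 1.0 only the absolute values count (Lasso-like), and at 0.0 only the squares count (Ridge-like). Values in between trade sparsity against smooth shrinkage.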

Advanced Features

Polynomial Regression

public class PolynomialRegressionExample {
    
    public void polynomialRegressionDemo() {
        // Generate sample data with non-linear relationship
        double[][] X = generateNonLinearData();
        double[] y = generateNonLinearTargets(X);
        
        DataSplit split = splitData(X, y, 0.8);
        
        // Try different polynomial degrees
        int[] degrees = {1, 2, 3, 4, 5};
        
        for (int degree : degrees) {
            // Create polynomial features
            PolynomialFeatures polyFeatures = new PolynomialFeatures(degree);
            double[][] XPoly = polyFeatures.fitTransform(split.getTrainX());
            double[][] XTestPoly = polyFeatures.transform(split.getTestX());
            
            // Fit linear regression on polynomial features
            LinearRegression model = new LinearRegression();
            model.fit(XPoly, split.getTrainY());
            
            // Evaluate
            double[] predictions = model.predict(XTestPoly);
            double rmse = Metrics.rmse(split.getTestY(), predictions);
            double r2 = Metrics.r2Score(split.getTestY(), predictions);
            
            System.out.printf("Degree: %d, RMSE: %.3f, R²: %.3f%n", degree, rmse, r2);
        }
    }
    
    // Use with regularization to prevent overfitting
    public void regularizedPolynomialRegression() {
        double[][] X = generateNonLinearData();
        double[] y = generateNonLinearTargets(X);
        
        // High-degree polynomial with regularization
        PolynomialFeatures polyFeatures = new PolynomialFeatures(8);  // degree 8
        double[][] XPoly = polyFeatures.fitTransform(X);
        
        // Ridge regression to control overfitting
        RidgeRegression model = new RidgeRegression(100.0);  // alpha = 100.0
        model.fit(XPoly, y);
        
        System.out.println("High-degree polynomial with Ridge regularization trained successfully");
    }
}
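
Polynomial regression is still linear regression; only the inputs change. As a minimal sketch for a single feature, the code below expands each x into [x, x², ..., x^degree]. It deliberately omits a constant column, assuming the regression adds its own intercept. A real PolynomialFeatures transformer may include a bias column and cross terms between features. The class name PolynomialExpansionSketch is hypothetical.

```java
import java.util.Arrays;

// Sketch only: polynomial feature expansion for one input feature.
class PolynomialExpansionSketch {

    /** Maps each x[i] to the row [x, x^2, ..., x^degree]. */
    static double[][] expand(double[] x, int degree) {
        double[][] expanded = new double[x.length][degree];
        for (int i = 0; i < x.length; i++) {
            double power = 1.0;
            for (int d = 0; d < degree; d++) {
                power *= x[i];          // after this line, power == x[i]^(d+1)
                expanded[i][d] = power;
            }
        }
        return expanded;
    }

    public static void main(String[] args) {
        double[][] features = expand(new double[] {2.0, 3.0}, 3);
        System.out.println(Arrays.deepToString(features));
    }
}
```

The expanded columns grow quickly in magnitude (x⁸ for a house size of 2000 is about 2.6 × 10²⁶), which is one reason high-degree fits are usually paired with scaling and regularization.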

Custom Features and Transformations

public class CustomFeatureEngineering {
    
    public void createCustomFeatures() {
        Dataset data = DataLoader.fromCSV("data/house_prices.csv");
        
        // Create new features
        data = data.withNewColumn("price_per_sqft", 
            row -> row.getDouble("price") / row.getDouble("size"));
        
        data = data.withNewColumn("bedroom_ratio", 
            row -> row.getDouble("bedrooms") / row.getDouble("size"));
        
        data = data.withNewColumn("total_rooms", 
            row -> row.getDouble("bedrooms") + row.getDouble("bathrooms"));
        
        // Log transformation for skewed features
        data = data.withTransformedColumn("log_price", "price", Math::log);
        data = data.withTransformedColumn("sqrt_size", "size", Math::sqrt);
        
        // Interaction features
        data = data.withNewColumn("size_bedrooms_interaction",
            row -> row.getDouble("size") * row.getDouble("bedrooms"));
        
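        // Use engineered features in model.
        // price_per_sqft is left out: it is computed from the target (price),
        // so using it as a predictor would leak the answer into the model.
        String[] features = {"size", "bedrooms", "bathrooms",
                            "bedroom_ratio", "total_rooms", "size_bedrooms_interaction"};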
        
        Dataset modelData = data.selectColumns(features, "price");
        DataSplit split = modelData.split(0.8);
        
        LinearRegression model = new LinearRegression();
        model.fit(split.getTrainX(), split.getTrainY());
        
        double[] predictions = model.predict(split.getTestX());
        double rmse = Metrics.rmse(split.getTestY(), predictions);
        System.out.printf("RMSE with engineered features: %.2f%n", rmse);
    }
}

Model Validation and Diagnostics

Residual Analysis

import java.util.Arrays;

public class ModelDiagnostics {
    
    public void residualAnalysis(LinearRegression model, DataSplit split) {
        double[] predictions = model.predict(split.getTestX());
        double[] actual = split.getTestY();
        
        // Calculate residuals
        double[] residuals = new double[actual.length];
        for (int i = 0; i < residuals.length; i++) {
            residuals[i] = actual[i] - predictions[i];
        }
        
        // Residual statistics
        double meanResidual = Arrays.stream(residuals).average().orElse(0.0);
        double stdResidual = calculateStandardDeviation(residuals);
        
        System.out.printf("Mean residual: %.4f (should be close to 0)%n", meanResidual);
        System.out.printf("Std residual: %.4f%n", stdResidual);
        
        // Check for heteroscedasticity
        checkHomoscedasticity(predictions, residuals);
        
        // Check for normality of residuals
        checkNormality(residuals);
    }
    
    private void checkHomoscedasticity(double[] predictions, double[] residuals) {
        // Simple test: correlation between |residuals| and predictions
        double[] absResiduals = Arrays.stream(residuals).map(Math::abs).toArray();
        double correlation = calculateCorrelation(predictions, absResiduals);
        
        System.out.printf("Correlation |residuals| vs predictions: %.4f%n", correlation);
        if (Math.abs(correlation) > 0.3) {
            System.out.println("Warning: Possible heteroscedasticity detected!");
        }
    }
    
    private void checkNormality(double[] residuals) {
        // Simple normality check using skewness and kurtosis
        double skewness = calculateSkewness(residuals);
        double kurtosis = calculateKurtosis(residuals);
        
        System.out.printf("Residual skewness: %.4f (should be close to 0)%n", skewness);
        System.out.printf("Residual kurtosis: %.4f (should be close to 3)%n", kurtosis);
    }
}
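
The diagnostics above call helpers (calculateStandardDeviation, calculateCorrelation, calculateSkewness, calculateKurtosis) that the listing does not define. One plausible way to write them is sketched below, using population (divide-by-n) moments and non-excess kurtosis, which matches the "close to 3" guidance for normally distributed residuals. The class name ResidualStatsSketch and the choice of population rather than sample statistics are assumptions.

```java
// Sketch only: moment-based statistics for residual diagnostics.
class ResidualStatsSketch {

    static double mean(double[] values) {
        double sum = 0.0;
        for (double v : values) {
            sum += v;
        }
        return sum / values.length;
    }

    /** k-th central moment: mean((v - mean)^k). */
    static double centralMoment(double[] values, int k) {
        double m = mean(values);
        double sum = 0.0;
        for (double v : values) {
            sum += Math.pow(v - m, k);
        }
        return sum / values.length;
    }

    /** Population standard deviation (divides by n, not n - 1). */
    static double standardDeviation(double[] values) {
        return Math.sqrt(centralMoment(values, 2));
    }

    /** Skewness = m3 / m2^1.5; 0 for a symmetric distribution. */
    static double skewness(double[] values) {
        return centralMoment(values, 3) / Math.pow(centralMoment(values, 2), 1.5);
    }

    /** Non-excess kurtosis = m4 / m2^2; about 3 for a normal distribution. */
    static double kurtosis(double[] values) {
        double m2 = centralMoment(values, 2);
        return centralMoment(values, 4) / (m2 * m2);
    }

    /** Pearson correlation coefficient between two equally long arrays. */
    static double correlation(double[] a, double[] b) {
        double meanA = mean(a);
        double meanB = mean(b);
        double sab = 0.0;
        double saa = 0.0;
        double sbb = 0.0;
        for (int i = 0; i < a.length; i++) {
            double da = a[i] - meanA;
            double db = b[i] - meanB;
            sab += da * db;
            saa += da * da;
            sbb += db * db;
        }
        return sab / Math.sqrt(saa * sbb);
    }

    public static void main(String[] args) {
        double[] residuals = {1, 2, 3, 4, 5};
        System.out.printf("std = %.4f, skewness = %.4f, kurtosis = %.4f%n",
            standardDeviation(residuals), skewness(residuals), kurtosis(residuals));
    }
}
```

These moment checks are only rough screens. With small test sets, skewness and kurtosis estimates are noisy, so a residual plot is usually more informative than any single number.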

Cross-Validation

import java.util.Arrays;

public class CrossValidationExample {
    
    public void crossValidateModel() {
        Dataset data = preprocessData(DataLoader.fromCSV("data/house_prices.csv"));
        
        // K-Fold Cross-Validation
        KFoldValidator validator = new KFoldValidator(5);  // k = 5 folds
        CrossValidationResult result = validator.validate(
            new LinearRegression(), 
            data.getFeatures(), 
            data.getTargets(),
            Metrics.RMSE
        );
        
        System.out.println("Cross-Validation Results:");
        System.out.printf("Mean RMSE: %.2f (±%.2f)%n", 
            result.getMean(), result.getStd());
        System.out.printf("Individual fold scores: %s%n", 
            Arrays.toString(result.getScores()));
        
        // Time Series Cross-Validation (if data has time component)
        if (data.hasColumn("date")) {
            TimeSeriesValidator tsValidator = new TimeSeriesValidator();
            CrossValidationResult tsResult = tsValidator.validate(
                new LinearRegression(),
                data.getFeatures(),
                data.getTargets(),
                data.getColumn("date"),
                Metrics.RMSE
            );
            
            System.out.printf("Time Series CV Mean RMSE: %.2f%n", tsResult.getMean());
        }
    }
}
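
Under the hood, k-fold cross-validation is mostly index bookkeeping. The sketch below shows one way a validator might partition n samples into k test folds. It uses contiguous folds without shuffling to keep the output deterministic; real validators typically shuffle first, and the class name KFoldSketch is hypothetical. Each fold serves once as the test set while the remaining indices form the training set.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Sketch only: splitting sample indices into k contiguous test folds.
class KFoldSketch {

    /** Returns k arrays of test indices; the first (n % k) folds get one extra sample. */
    static List<int[]> testFolds(int n, int k) {
        List<int[]> folds = new ArrayList<>();
        int start = 0;
        for (int fold = 0; fold < k; fold++) {
            int size = n / k + (fold < n % k ? 1 : 0);
            int[] indices = new int[size];
            for (int i = 0; i < size; i++) {
                indices[i] = start + i;
            }
            folds.add(indices);
            start += size;
        }
        return folds;
    }

    public static void main(String[] args) {
        List<int[]> folds = testFolds(7, 3);
        for (int fold = 0; fold < folds.size(); fold++) {
            System.out.printf("Fold %d test indices: %s%n",
                fold, Arrays.toString(folds.get(fold)));
        }
    }
}
```

For time-ordered data, contiguous unshuffled folds like these still leak future information into training. That is why the listing above switches to a dedicated time-series validator when a date column is present.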

Complete Real-World Example

Let's put everything together with a comprehensive example:

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class ComprehensiveHousePricePrediction {
    
    public static void main(String[] args) {
        ComprehensiveHousePricePrediction predictor = 
            new ComprehensiveHousePricePrediction();
        predictor.runCompleteAnalysis();
    }
    
    public void runCompleteAnalysis() {
        // Load and explore data
        Dataset data = DataLoader.fromCSV("data/house_prices_complete.csv");
        exploreData(data);
        
        // Preprocess data
        data = preprocessData(data);
        
        // Feature engineering
        data = engineerFeatures(data);
        
        // Model selection and training
        Map<String, Double> modelResults = compareModels(data);
        
        // Select best model and final training
        String bestModel = getBestModel(modelResults);
        LinearRegression finalModel = trainFinalModel(data, bestModel);
        
        // Model interpretation and validation
        interpretFinalModel(finalModel, data);
        
        // Save model for production use
        saveModel(finalModel, "house_price_model.pkl");
    }
    
    private void exploreData(Dataset data) {
        System.out.println("=== Data Exploration ===");
        System.out.println("Dataset shape: " + data.getShape());
        System.out.println("Missing values: " + data.getMissingCounts());
        
        // Basic statistics
        DataSummary summary = data.describe();
        System.out.println(summary);
        
        // Correlation with target
        Map<String, Double> correlations = data.getCorrelationWithTarget("price");
        correlations.entrySet().stream()
            .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
            .limit(10)
            .forEach(entry -> 
                System.out.printf("%-15s: %.3f%n", entry.getKey(), entry.getValue()));
    }
    
    private Dataset preprocessData(Dataset data) {
        System.out.println("\n=== Data Preprocessing ===");
        
        // Handle missing values
        data = data.fillMissing(Map.of(
            "age", Strategy.MEDIAN,
            "garage", 0.0,
            "basement", 0.0
        ));
        
        // Remove outliers
        OutlierDetector detector = new IQRDetector(1.5);  // IQR multiplier
        boolean[] outliers = detector.detect(data, "price");
        // Arrays.stream has no boolean[] overload, so count with a loop
        int outliersRemoved = 0;
        for (boolean isOutlier : outliers) {
            if (isOutlier) {
                outliersRemoved++;
            }
        }
        System.out.println("Outliers removed: " + outliersRemoved);
        data = data.removeOutliers(outliers);
        
        return data;
    }
    
    private Dataset engineerFeatures(Dataset data) {
        System.out.println("\n=== Feature Engineering ===");
        
        // Create new features. A price-per-square-foot column is deliberately
        // not created here: it would be computed from the target (price) and
        // leak it into the features.
        data = data.withNewColumn("rooms_total", 
            row -> row.getDouble("bedrooms") + row.getDouble("bathrooms"));
        
        data = data.withNewColumn("age_group", row -> {
            double age = row.getDouble("age");
            if (age < 5) return "new";
            else if (age < 20) return "modern";
            else if (age < 50) return "mature";
            else return "old";
        });
        
        // One-hot encode categorical features
        OneHotEncoder encoder = new OneHotEncoder();
        data = encoder.fitTransform(data, "age_group", "neighborhood");
        
        // Scale numerical features
        StandardScaler scaler = new StandardScaler();
        String[] numericalFeatures = {"size", "bedrooms", "bathrooms", "age", 
                                     "garage", "basement", "rooms_total"};
        data = scaler.fitTransform(data, numericalFeatures);
        
        System.out.println("Final feature count: " + data.getFeatureNames().size());
        return data;
    }
    
    private Map<String, Double> compareModels(Dataset data) {
        System.out.println("\n=== Model Comparison ===");
        
        DataSplit split = data.split(0.8);
        Map<String, Double> results = new HashMap<>();
        
        // Linear Regression
        LinearRegression lr = new LinearRegression();
        lr.fit(split.getTrainX(), split.getTrainY());
        double[] lrPreds = lr.predict(split.getTestX());
        results.put("LinearRegression", Metrics.rmse(split.getTestY(), lrPreds));
        
        // Ridge Regression
        RidgeRegressionCV ridge = new RidgeRegressionCV()
            .setAlphas(0.1, 1.0, 10.0, 100.0)
            .setCv(5);
        ridge.fit(split.getTrainX(), split.getTrainY());
        double[] ridgePreds = ridge.predict(split.getTestX());
        results.put("Ridge", Metrics.rmse(split.getTestY(), ridgePreds));
        
        // Lasso Regression
        LassoRegressionCV lasso = new LassoRegressionCV()
            .setAlphas(0.1, 1.0, 10.0, 100.0)
            .setCv(5);
        lasso.fit(split.getTrainX(), split.getTrainY());
        double[] lassoPreds = lasso.predict(split.getTestX());
        results.put("Lasso", Metrics.rmse(split.getTestY(), lassoPreds));
        
        // Elastic Net
        ElasticNetCV elastic = new ElasticNetCV()
            .setAlphas(0.1, 1.0, 10.0)
            .setL1Ratios(0.1, 0.5, 0.9)
            .setCv(5);
        elastic.fit(split.getTrainX(), split.getTrainY());
        double[] elasticPreds = elastic.predict(split.getTestX());
        results.put("ElasticNet", Metrics.rmse(split.getTestY(), elasticPreds));
        
        // Display results
        results.entrySet().stream()
            .sorted(Map.Entry.comparingByValue())
            .forEach(entry -> 
                System.out.printf("%-15s: RMSE = %.2f%n", entry.getKey(), entry.getValue()));
        
        return results;
    }
    
    private String getBestModel(Map<String, Double> results) {
        return results.entrySet().stream()
            .min(Map.Entry.comparingByValue())
            .map(Map.Entry::getKey)
            .orElse("LinearRegression");
    }
    
    private LinearRegression trainFinalModel(Dataset data, String modelType) {
        System.out.println("\n=== Final Model Training ===");
        System.out.println("Best model: " + modelType);
        
        // Train on full dataset
        LinearRegression finalModel;
        switch (modelType) {
            case "Ridge":
                finalModel = new RidgeRegressionCV().setAlphas(0.1, 1.0, 10.0, 100.0);
                break;
            case "Lasso":
                finalModel = new LassoRegressionCV().setAlphas(0.1, 1.0, 10.0, 100.0);
                break;
            case "ElasticNet":
                finalModel = new ElasticNetCV().setAlphas(0.1, 1.0, 10.0);
                break;
            default:
                finalModel = new LinearRegression();
        }
        
        finalModel.fit(data.getFeatures(), data.getTargets());
        return finalModel;
    }
    
    private void interpretFinalModel(LinearRegression model, Dataset data) {
        System.out.println("\n=== Model Interpretation ===");
        
        double[] coefficients = model.getCoefficients();
        String[] featureNames = data.getFeatureNames().toArray(new String[0]);
        
        // Top 10 most important features
        Map<String, Double> importance = new HashMap<>();
        for (int i = 0; i < featureNames.length; i++) {
            importance.put(featureNames[i], Math.abs(coefficients[i]));
        }
        
        System.out.println("Top 10 most important features:");
        importance.entrySet().stream()
            .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
            .limit(10)
            .forEach(entry -> 
                System.out.printf("%-20s: %.4f%n", entry.getKey(), entry.getValue()));
    }
    
    private void saveModel(LinearRegression model, String filename) {
        try {
            ModelSerializer.save(model, filename);
            System.out.println("\nModel saved to: " + filename);
        } catch (Exception e) {
            System.err.println("Failed to save model: " + e.getMessage());
        }
    }
}

Production Deployment

Model Serving Example

@RestController
public class HousePriceController {
    
    private final LinearRegression model;
    private final Pipeline preprocessor;
    
    public HousePriceController() {
        this.model = ModelSerializer.load("house_price_model.pkl");
        this.preprocessor = Pipeline.load("preprocessing_pipeline.pkl");
    }
    
    @PostMapping("/predict")
    public PredictionResponse predict(@RequestBody HouseFeatures features) {
        try {
            // Convert to dataset
            Dataset input = Dataset.fromMap(features.toMap());
            
            // Preprocess
            Dataset processed = preprocessor.transform(input);
            
            // Predict
            double prediction = model.predict(processed.getFeatures())[0];
            
            // Calculate confidence interval (optional)
            double[] interval = model.predictWithConfidence(processed.getFeatures())[0];
            
            return new PredictionResponse(prediction, interval[0], interval[1]);
            
        } catch (Exception e) {
            throw new PredictionException("Prediction failed: " + e.getMessage());
        }
    }
}

Summary

In this comprehensive tutorial, we covered:

  • Basic linear regression concepts and implementation
  • Multiple linear regression with real datasets
  • Regularization techniques (Ridge, Lasso, Elastic Net)
  • Polynomial regression for non-linear relationships
  • Feature engineering and custom transformations
  • Model validation and diagnostics
  • Cross-validation for robust evaluation
  • Complete end-to-end machine learning pipeline
  • Production deployment considerations

Linear regression is a powerful tool for understanding relationships between variables and making predictions. While simple, it forms the foundation for many advanced machine learning techniques. The key to success with linear regression is proper data preprocessing, feature engineering, and model validation.

In the next tutorial, we'll explore classification problems using logistic regression, which extends these concepts to categorical predictions.