πŸ“– Lesson ⏱️ 90 minutes

Classification with Logistic Regression

Building classification models in Java

Classification with Logistic Regression in SuperML Java

Logistic regression is one of the most fundamental and widely used algorithms for classification problems. Despite its name containing β€œregression,” it’s actually a classification algorithm that predicts the probability of class membership. This tutorial will teach you how to implement and optimize logistic regression using SuperML Java.

Understanding Logistic Regression

Mathematical Foundation

While linear regression predicts continuous values, logistic regression predicts probabilities using the logistic function:

p = 1 / (1 + e^(-(Ξ²β‚€ + β₁x₁ + Ξ²β‚‚xβ‚‚ + ... + Ξ²β‚™xβ‚™)))

This ensures predictions are always between 0 and 1, making them interpretable as probabilities.
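To make the formula concrete, here is a small plain-Java sketch (independent of SuperML; the coefficient values are made up) that evaluates the logistic function for one sample:

```java
public class LogisticFunctionDemo {
    // p = 1 / (1 + e^(-z)), where z is the linear combination of features
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Probability for one sample: intercept + dot(coefficients, features)
    static double predictProba(double intercept, double[] coef, double[] x) {
        double z = intercept;
        for (int i = 0; i < coef.length; i++) {
            z += coef[i] * x[i];
        }
        return sigmoid(z);
    }

    public static void main(String[] args) {
        double[] coef = {1.2, -0.7};   // made-up coefficients
        double[] sample = {0.5, 1.0};
        double p = predictProba(-0.1, coef, sample);
        System.out.printf("P(class=1) = %.4f%n", p);  // always in (0, 1)
    }
}
```

Because the exponent is negated, a large positive z pushes p toward 1 and a large negative z toward 0, with z = 0 giving exactly 0.5.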

When to Use Logistic Regression

  • Binary classification: Email spam detection, medical diagnosis
  • Multiclass classification: Image classification, text categorization
  • Probability estimation: When you need calibrated probabilities
  • Feature interpretation: Understanding which features influence predictions
  • Baseline models: As a strong baseline for more complex algorithms

Binary Classification

Let’s start with a binary classification example: predicting customer churn.

import org.superml.linear_model.LogisticRegression;
import org.superml.datasets.DataLoader;
import org.superml.metrics.Metrics;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class CustomerChurnPrediction {
    
    public static void main(String[] args) {
        CustomerChurnPrediction predictor = new CustomerChurnPrediction();
        predictor.predictCustomerChurn();
    }
    
    public void predictCustomerChurn() {
        // Load customer data
        Dataset data = DataLoader.fromCSV("data/customer_churn.csv");
        
        // Explore the data
        exploreData(data);
        
        // Preprocess the data
        data = preprocessData(data);
        
        // Train and evaluate model
        trainAndEvaluateModel(data);
    }
    
    private void exploreData(Dataset data) {
        System.out.println("=== Data Exploration ===");
        System.out.println("Dataset shape: " + data.getShape());
        System.out.println("Features: " + data.getFeatureNames());
        
        // Check class distribution
        Map<Object, Integer> classDistribution = data.getValueCounts("churn");
        System.out.println("Class distribution:");
        classDistribution.forEach((key, value) -> 
            System.out.printf("Class %s: %d (%.1f%%)%n", 
                key, value, 100.0 * value / data.getRowCount()));
        
        // Check for class imbalance
        double imbalanceRatio = (double) classDistribution.get(0) / classDistribution.get(1);
        if (imbalanceRatio > 2 || imbalanceRatio < 0.5) {
            System.out.println("Warning: Class imbalance detected!");
        }
    }
    
    private Dataset preprocessData(Dataset data) {
        System.out.println("\n=== Data Preprocessing ===");
        
        // Handle missing values
        data = data.fillMissing(Map.of(
            "total_charges", Strategy.MEDIAN,
            "monthly_charges", Strategy.MEAN
        ));
        
        // Encode categorical variables
        OneHotEncoder encoder = new OneHotEncoder();
        String[] categoricalColumns = {"gender", "contract", "payment_method", 
                                     "internet_service", "phone_service"};
        data = encoder.fitTransform(data, categoricalColumns);
        
        // Scale numerical features
        StandardScaler scaler = new StandardScaler();
        String[] numericalColumns = {"tenure", "monthly_charges", "total_charges"};
        data = scaler.fitTransform(data, numericalColumns);
        
        System.out.println("Preprocessed shape: " + data.getShape());
        return data;
    }
    
    private void trainAndEvaluateModel(Dataset data) {
        System.out.println("\n=== Model Training and Evaluation ===");
        
        // Split data
        DataSplit split = data.split(0.8, true);  // 80/20 split, stratified
        
        // Create and train logistic regression model
        LogisticRegression model = new LogisticRegression()
            .setMaxIterations(1000)
            .setTolerance(1e-6)
            .setRegularization(Regularization.L2)
            .setC(1.0);  // Inverse of regularization strength
        
        model.fit(split.getTrainX(), split.getTrainY());
        
        // Make predictions
        double[] predictions = model.predict(split.getTestX());
        double[] probabilities = model.predictProba(split.getTestX());
        
        // Evaluate performance
        evaluateModel(split.getTestY(), predictions, probabilities);
        
        // Interpret model
        interpretModel(model, data.getFeatureNames());
    }
    
    private void evaluateModel(double[] yTrue, double[] yPred, double[] probabilities) {
        // Basic metrics
        double accuracy = Metrics.accuracy(yTrue, yPred);
        double precision = Metrics.precision(yTrue, yPred);
        double recall = Metrics.recall(yTrue, yPred);
        double f1 = Metrics.f1Score(yTrue, yPred);
        
        System.out.printf("Accuracy:  %.3f%n", accuracy);
        System.out.printf("Precision: %.3f%n", precision);
        System.out.printf("Recall:    %.3f%n", recall);
        System.out.printf("F1-Score:  %.3f%n", f1);
        
        // Confusion Matrix
        ConfusionMatrix cm = Metrics.confusionMatrix(yTrue, yPred);
        System.out.println("\nConfusion Matrix:");
        System.out.println(cm);
        
        // ROC AUC
        double rocAuc = Metrics.rocAuc(yTrue, probabilities);
        System.out.printf("ROC AUC:   %.3f%n", rocAuc);
        
        // Precision-Recall AUC
        double prAuc = Metrics.precisionRecallAuc(yTrue, probabilities);
        System.out.printf("PR AUC:    %.3f%n", prAuc);
    }
    
    private void interpretModel(LogisticRegression model, List<String> featureNames) {
        System.out.println("\n=== Model Interpretation ===");
        
        double[] coefficients = model.getCoefficients();
        double intercept = model.getIntercept();
        
        System.out.printf("Intercept: %.4f%n", intercept);
        System.out.println("\nFeature Coefficients (odds ratios):");
        
        // Create list of feature importance
        List<FeatureImportance> importance = new ArrayList<>();
        for (int i = 0; i < featureNames.size(); i++) {
            double coef = coefficients[i];
            double oddsRatio = Math.exp(coef);
            importance.add(new FeatureImportance(featureNames.get(i), coef, oddsRatio));
        }
        
        // Sort by absolute coefficient value
        importance.sort((a, b) -> Double.compare(Math.abs(b.coefficient), Math.abs(a.coefficient)));
        
        // Display top 10 features
        System.out.println("Top 10 most important features:");
        importance.stream().limit(10).forEach(fi -> {
            String direction = fi.coefficient > 0 ? "↑" : "↓";
            System.out.printf("%-25s: %s %.4f (OR: %.3f)%n", 
                fi.name, direction, fi.coefficient, fi.oddsRatio);
        });
    }
    
    private static class FeatureImportance {
        String name;
        double coefficient;
        double oddsRatio;
        
        FeatureImportance(String name, double coefficient, double oddsRatio) {
            this.name = name;
            this.coefficient = coefficient;
            this.oddsRatio = oddsRatio;
        }
    }
}
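The odds-ratio arithmetic that interpretModel reports can be checked by hand: exponentiating a coefficient gives the factor by which the odds change per one-unit increase in that feature. A standalone sketch with toy numbers (not SuperML output):

```java
public class OddsRatioDemo {
    // exp(coefficient) is the multiplicative change in odds
    // for a one-unit increase in the feature.
    static double oddsRatio(double coefficient) {
        return Math.exp(coefficient);
    }

    public static void main(String[] args) {
        double coef = Math.log(2.0);  // about 0.693
        System.out.printf("coef=%.3f -> OR=%.3f (odds double per unit)%n",
            coef, oddsRatio(coef));
        System.out.printf("coef=%.3f -> OR=%.3f (odds halve per unit)%n",
            -coef, oddsRatio(-coef));
        System.out.printf("coef=%.3f -> OR=%.3f (no effect)%n",
            0.0, oddsRatio(0.0));
    }
}
```

This is why a coefficient of 0 maps to an odds ratio of exactly 1: the feature leaves the odds unchanged.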

Multiclass Classification

For problems with more than two classes, SuperML Java supports several strategies:

public class MulticlassClassificationExample {
    
    public void irisClassification() {
        // Load the famous Iris dataset
        Dataset data = DataLoader.fromCSV("data/iris.csv");
        
        System.out.println("=== Iris Multiclass Classification ===");
        System.out.println("Classes: " + data.getUniqueValues("species"));
        
        // Preprocess
        StandardScaler scaler = new StandardScaler();
        data = scaler.fitTransform(data, "sepal_length", "sepal_width", 
                                  "petal_length", "petal_width");
        
        DataSplit split = data.split(0.8, true);  // stratified split
        
        // One-vs-Rest (OvR) strategy
        LogisticRegression ovrModel = new LogisticRegression()
            .setMultiClass(MultiClass.OVR)
            .setSolver(Solver.LIBLINEAR);
        
        ovrModel.fit(split.getTrainX(), split.getTrainY());
        
        // Multinomial strategy (preferred for multiclass)
        LogisticRegression multinomialModel = new LogisticRegression()
            .setMultiClass(MultiClass.MULTINOMIAL)
            .setSolver(Solver.LBFGS)
            .setMaxIterations(1000);
        
        multinomialModel.fit(split.getTrainX(), split.getTrainY());
        
        // Compare both approaches
        compareMulticlassStrategies(ovrModel, multinomialModel, split);
    }
    
    private void compareMulticlassStrategies(LogisticRegression ovrModel, 
                                           LogisticRegression multinomialModel, 
                                           DataSplit split) {
        
        System.out.println("\n=== Strategy Comparison ===");
        
        // OvR predictions
        double[] ovrPreds = ovrModel.predict(split.getTestX());
        double[][] ovrProbs = ovrModel.predictProba(split.getTestX());
        
        double ovrAccuracy = Metrics.accuracy(split.getTestY(), ovrPreds);
        double ovrF1 = Metrics.f1Score(split.getTestY(), ovrPreds, F1Average.WEIGHTED);
        
        // Multinomial predictions
        double[] multinomialPreds = multinomialModel.predict(split.getTestX());
        double[][] multinomialProbs = multinomialModel.predictProba(split.getTestX());
        
        double multinomialAccuracy = Metrics.accuracy(split.getTestY(), multinomialPreds);
        double multinomialF1 = Metrics.f1Score(split.getTestY(), multinomialPreds, F1Average.WEIGHTED);
        
        System.out.printf("One-vs-Rest      - Accuracy: %.3f, F1: %.3f%n", ovrAccuracy, ovrF1);
        System.out.printf("Multinomial      - Accuracy: %.3f, F1: %.3f%n", multinomialAccuracy, multinomialF1);
        
        // Detailed classification report
        ClassificationReport report = Metrics.classificationReport(
            split.getTestY(), multinomialPreds);
        System.out.println("\nDetailed Classification Report:");
        System.out.println(report);
        
        // Confusion matrix for multiclass
        MulticlassConfusionMatrix mcm = Metrics.multiclassConfusionMatrix(
            split.getTestY(), multinomialPreds);
        System.out.println("\nMulticlass Confusion Matrix:");
        System.out.println(mcm);
    }
}
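The difference between the two strategies comes down to how class scores become probabilities: OvR fits an independent binary sigmoid per class and normalizes afterwards, while the multinomial model applies a softmax over all class scores jointly. A plain-Java softmax sketch (the class scores are made up):

```java
import java.util.Arrays;

public class SoftmaxDemo {
    // Softmax turns per-class scores into probabilities that sum to 1.
    static double[] softmax(double[] scores) {
        double max = Arrays.stream(scores).max().orElse(0.0);  // for numerical stability
        double[] exp = new double[scores.length];
        double sum = 0.0;
        for (int i = 0; i < scores.length; i++) {
            exp[i] = Math.exp(scores[i] - max);
            sum += exp[i];
        }
        for (int i = 0; i < exp.length; i++) {
            exp[i] /= sum;
        }
        return exp;
    }

    public static void main(String[] args) {
        double[] scores = {2.0, 1.0, 0.1};  // made-up scores for 3 classes
        System.out.println(Arrays.toString(softmax(scores)));
    }
}
```

Subtracting the maximum score before exponentiating changes nothing mathematically but prevents overflow for large scores.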

Regularization and Hyperparameter Tuning

Prevent overfitting and optimize performance:

public class LogisticRegressionTuning {
    
    public void hyperparameterTuning() {
        Dataset data = preprocessData(DataLoader.fromCSV("data/classification_data.csv"));
        
        // Hold out a test set first so the grid search never sees it
        DataSplit split = data.split(0.8);
        
        // Grid search for optimal hyperparameters
        GridSearchCV gridSearch = new GridSearchCV()
            .setEstimator(new LogisticRegression())
            .setParamGrid(Map.of(
                "C", Arrays.asList(0.01, 0.1, 1.0, 10.0, 100.0),
                "penalty", Arrays.asList("l1", "l2", "elasticnet"),
                "solver", Arrays.asList("liblinear", "lbfgs", "saga")
            ))
            .setScoring(Scoring.F1_WEIGHTED)
            .setCv(5)
            .setVerbose(true);
        
        gridSearch.fit(split.getTrainX(), split.getTrainY());
        
        System.out.println("Best parameters: " + gridSearch.getBestParams());
        System.out.println("Best CV score: " + gridSearch.getBestScore());
        
        // Evaluate the best model from the search on the held-out test set
        LogisticRegression bestModel = gridSearch.getBestEstimator();
        double[] predictions = bestModel.predict(split.getTestX());
        double testF1 = Metrics.f1Score(split.getTestY(), predictions, F1Average.WEIGHTED);
        
        System.out.printf("Test F1 Score: %.3f%n", testF1);
    }
    
    public void regularizationPaths() {
        Dataset data = preprocessData(DataLoader.fromCSV("data/classification_data.csv"));
        
        // L1 regularization path (Lasso)
        double[] cValues = {100, 10, 1, 0.1, 0.01, 0.001};
        
        System.out.println("=== L1 Regularization Path ===");
        System.out.printf("%-8s %-10s %-15s%n", "C", "Accuracy", "Non-zero Coef");
        
        for (double c : cValues) {
            LogisticRegression model = new LogisticRegression()
                .setC(c)
                .setPenalty("l1")
                .setSolver("liblinear");
            
            // Cross-validation
            CrossValidationResult cvResult = CrossValidator.validate(
                model, data.getFeatures(), data.getTargets(), 5, Scoring.ACCURACY);
            
            model.fit(data.getFeatures(), data.getTargets());
            int nonZeroCoef = countNonZeroCoefficients(model.getCoefficients());
            
            System.out.printf("%-8.3f %-10.3f %-15d%n", c, cvResult.getMean(), nonZeroCoef);
        }
    }
    
    private int countNonZeroCoefficients(double[] coefficients) {
        return (int) Arrays.stream(coefficients)
            .map(Math::abs)
            .filter(c -> c > 1e-6)
            .count();
    }
}
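As noted above, C is the inverse of regularization strength: with L2, the penalized objective is the log loss plus (1/C) times the squared norm of the coefficients. A self-contained sketch with toy weights shows how shrinking C inflates the penalty term:

```java
public class RegularizedLossDemo {
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Binary cross-entropy for one sample plus an L2 penalty on the weights.
    // lambda = 1/C: a smaller C means a larger penalty, pushing coefficients toward 0.
    static double penalizedLoss(double[] coef, double[] x, double y, double lambda) {
        double z = 0.0;
        for (int i = 0; i < coef.length; i++) {
            z += coef[i] * x[i];
        }
        double p = sigmoid(z);
        double logLoss = -(y * Math.log(p) + (1 - y) * Math.log(1 - p));
        double l2 = 0.0;
        for (double c : coef) {
            l2 += c * c;
        }
        return logLoss + lambda * l2;
    }

    public static void main(String[] args) {
        double[] coef = {0.8, -0.3};  // made-up weights
        double[] x = {1.0, 2.0};
        System.out.printf("C=100:  loss=%.4f%n", penalizedLoss(coef, x, 1.0, 1.0 / 100.0));
        System.out.printf("C=0.01: loss=%.4f%n", penalizedLoss(coef, x, 1.0, 1.0 / 0.01));
    }
}
```

Because the log-loss term is identical in both calls, the entire difference comes from the penalty; this is why small C values in the regularization path above zero out more coefficients.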

Advanced Features

Class Imbalance Handling

public class ImbalancedClassificationExample {
    
    public void handleClassImbalance() {
        Dataset data = DataLoader.fromCSV("data/imbalanced_data.csv");
        
        // Check class distribution
        Map<Object, Integer> classCounts = data.getValueCounts("target");
        System.out.println("Original class distribution: " + classCounts);
        
        DataSplit split = data.split(0.8, true);  // stratified split
        
        // Approach 1: Class weights
        LogisticRegression weightedModel = new LogisticRegression()
            .setClassWeight("balanced");  // Automatically balance classes
        
        weightedModel.fit(split.getTrainX(), split.getTrainY());
        
        // Approach 2: Custom class weights
        Map<Integer, Double> customWeights = Map.of(
            0, 1.0,
            1, 5.0  // Give minority class 5x weight
        );
        
        LogisticRegression customWeightedModel = new LogisticRegression()
            .setClassWeight(customWeights);
        
        customWeightedModel.fit(split.getTrainX(), split.getTrainY());
        
        // Approach 3: Threshold tuning
        double[] probabilities = weightedModel.predictProba(split.getTestX());
        double[] predictions = adjustThreshold(probabilities, 0.3);  // Lower threshold
        
        // Evaluate all approaches
        evaluateImbalancedModel(split.getTestY(), predictions, probabilities);
    }
    
    private double[] adjustThreshold(double[] probabilities, double threshold) {
        return Arrays.stream(probabilities)
            .map(p -> p >= threshold ? 1.0 : 0.0)
            .toArray();
    }
    
    private void evaluateImbalancedModel(double[] yTrue, double[] yPred, double[] probabilities) {
        System.out.println("\n=== Imbalanced Dataset Evaluation ===");
        
        // Standard metrics
        double accuracy = Metrics.accuracy(yTrue, yPred);
        double precision = Metrics.precision(yTrue, yPred);
        double recall = Metrics.recall(yTrue, yPred);
        double f1 = Metrics.f1Score(yTrue, yPred);
        
        // Metrics specific to imbalanced datasets
        double balancedAccuracy = Metrics.balancedAccuracy(yTrue, yPred);
        double mcc = Metrics.matthewsCorrelationCoefficient(yTrue, yPred);
        
        System.out.printf("Accuracy:         %.3f%n", accuracy);
        System.out.printf("Balanced Accuracy:%.3f%n", balancedAccuracy);
        System.out.printf("Precision:        %.3f%n", precision);
        System.out.printf("Recall:           %.3f%n", recall);
        System.out.printf("F1-Score:         %.3f%n", f1);
        System.out.printf("MCC:              %.3f%n", mcc);
        
        // ROC and PR curves are more informative for imbalanced data
        double rocAuc = Metrics.rocAuc(yTrue, probabilities);
        double prAuc = Metrics.precisionRecallAuc(yTrue, probabilities);
        
        System.out.printf("ROC AUC:          %.3f%n", rocAuc);
        System.out.printf("PR AUC:           %.3f%n", prAuc);
    }
}
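A common convention for the "balanced" setting (used by scikit-learn; assuming SuperML's option behaves the same) is weight(c) = nSamples / (nClasses * count(c)), so rarer classes receive proportionally larger weights. A standalone sketch of that formula:

```java
import java.util.HashMap;
import java.util.Map;

public class BalancedWeightsDemo {
    // weight(c) = nSamples / (nClasses * count(c)): rarer classes get larger weights
    static Map<Integer, Double> balancedWeights(int[] labels) {
        Map<Integer, Integer> counts = new HashMap<>();
        for (int label : labels) {
            counts.merge(label, 1, Integer::sum);
        }
        Map<Integer, Double> weights = new HashMap<>();
        int nClasses = counts.size();
        for (Map.Entry<Integer, Integer> e : counts.entrySet()) {
            weights.put(e.getKey(), (double) labels.length / (nClasses * e.getValue()));
        }
        return weights;
    }

    public static void main(String[] args) {
        int[] labels = {0, 0, 0, 0, 0, 0, 0, 0, 1, 1};  // 80/20 imbalance
        System.out.println(balancedWeights(labels));     // minority class weighted higher
    }
}
```

With the 80/20 split above, the majority class gets weight 0.625 and the minority class 2.5, so each class contributes equally to the weighted loss overall.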

Feature Selection

public class FeatureSelectionExample {
    
    public void selectiveFeatureAnalysis() {
        Dataset data = preprocessData(DataLoader.fromCSV("data/high_dimensional_data.csv"));
        
        System.out.println("Original features: " + data.getFeatureCount());
        
        // Statistical feature selection
        data = statisticalFeatureSelection(data);
        
        // L1-based feature selection
        data = l1FeatureSelection(data);
        
        // Recursive feature elimination
        data = recursiveFeatureElimination(data);
        
        System.out.println("Final features: " + data.getFeatureCount());
    }
    
    private Dataset statisticalFeatureSelection(Dataset data) {
        // Chi-square test for categorical features
        ChiSquareSelector chiSquare = new ChiSquareSelector(50);  // keep the top 50 features
        Dataset selected = chiSquare.fitTransform(data);
        
        System.out.println("After chi-square selection: " + selected.getFeatureCount());
        return selected;
    }
    
    private Dataset l1FeatureSelection(Dataset data) {
        // Use L1 penalty to select features
        SelectFromModel selector = new SelectFromModel(
            new LogisticRegression().setPenalty("l1").setC(0.1)
        );
        
        Dataset selected = selector.fitTransform(data);
        System.out.println("After L1 selection: " + selected.getFeatureCount());
        
        // Show selected features
        boolean[] mask = selector.getSupport();
        List<String> selectedFeatures = selector.getFeatureNames(data.getFeatureNames());
        System.out.println("Selected features: " + selectedFeatures);
        
        return selected;
    }
    
    private Dataset recursiveFeatureElimination(Dataset data) {
        // Recursive Feature Elimination
        RFE rfe = new RFE(
            new LogisticRegression(),
            20  // number of features to keep
        );
        
        Dataset selected = rfe.fitTransform(data);
        System.out.println("After RFE: " + selected.getFeatureCount());
        
        // Feature rankings
        int[] rankings = rfe.getRanking();
        List<String> featureNames = data.getFeatureNames();
        
        System.out.println("Top 10 features by RFE ranking:");
        IntStream.range(0, rankings.length)
            .boxed()
            .sorted((i, j) -> Integer.compare(rankings[i], rankings[j]))
            .limit(10)
            .forEach(i -> System.out.printf("%s (rank %d)%n", featureNames.get(i), rankings[i]));
        
        return selected;
    }
}
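Conceptually, SelectFromModel-style selection keeps the features whose fitted absolute coefficient exceeds a threshold; L1 regularization makes this effective by driving uninformative coefficients to exactly zero. A standalone sketch with made-up coefficients:

```java
import java.util.ArrayList;
import java.util.List;

public class CoefficientSelectionDemo {
    // Keep the features whose absolute coefficient exceeds the threshold.
    static List<String> selectFeatures(String[] names, double[] coef, double threshold) {
        List<String> kept = new ArrayList<>();
        for (int i = 0; i < names.length; i++) {
            if (Math.abs(coef[i]) > threshold) {
                kept.add(names[i]);
            }
        }
        return kept;
    }

    public static void main(String[] args) {
        String[] names = {"tenure", "monthly_charges", "noise_1", "noise_2"};
        double[] coef = {1.4, -0.9, 0.0, 1e-8};  // L1 zeroes out uninformative features
        System.out.println(selectFeatures(names, coef, 1e-6));
    }
}
```

The 1e-6 threshold mirrors the tolerance used by countNonZeroCoefficients above: values that small are numerically zero.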

Real-World Application: Email Spam Detection

Let’s build a complete email spam detection system:

public class EmailSpamDetection {
    
    public static void main(String[] args) {
        EmailSpamDetection detector = new EmailSpamDetection();
        detector.buildSpamDetector();
    }
    
    public void buildSpamDetector() {
        // Load email dataset
        Dataset data = DataLoader.fromCSV("data/emails.csv");
        
        System.out.println("=== Email Spam Detection System ===");
        System.out.println("Total emails: " + data.getRowCount());
        
        // Text preprocessing and feature extraction
        data = extractTextFeatures(data);
        
        // Train and evaluate model
        LogisticRegression model = trainSpamModel(data);
        
        // Deploy model for real-time prediction
        deployModel(model);
    }
    
    private Dataset extractTextFeatures(Dataset data) {
        System.out.println("\n=== Text Feature Extraction ===");
        
        // TF-IDF vectorization
        TfIdfVectorizer vectorizer = new TfIdfVectorizer()
            .setMaxFeatures(5000)
            .setNgrams(1, 2)  // Unigrams and bigrams
            .setStopWords(StopWords.ENGLISH)
            .setMinDf(2)      // Ignore terms that appear in less than 2 documents
            .setMaxDf(0.95);  // Ignore terms that appear in more than 95% of documents
        
        // Apply to email text
        Dataset textFeatures = vectorizer.fitTransform(data, "email_text");
        
        // Add email metadata features
        data = data.withNewColumn("email_length", 
            row -> (double) row.getString("email_text").length());
        
        data = data.withNewColumn("capital_ratio",
            row -> calculateCapitalRatio(row.getString("email_text")));
        
        data = data.withNewColumn("exclamation_count",
            row -> (double) countCharacter(row.getString("email_text"), '!'));
        
        data = data.withNewColumn("url_count",
            row -> (double) countUrls(row.getString("email_text")));
        
        // Combine text features with metadata
        Dataset combined = textFeatures.concat(
            data.selectColumns("email_length", "capital_ratio", 
                             "exclamation_count", "url_count", "is_spam"));
        
        System.out.println("Total features extracted: " + combined.getFeatureCount());
        return combined;
    }
    
    private double calculateCapitalRatio(String text) {
        if (text.isEmpty()) return 0.0;  // avoid division by zero
        long capitals = text.chars().filter(Character::isUpperCase).count();
        return (double) capitals / text.length();
    }
    
    private int countCharacter(String text, char character) {
        return (int) text.chars().filter(c -> c == character).count();
    }
    
    private int countUrls(String text) {
        Pattern urlPattern = Pattern.compile("http[s]?://\\S+");
        Matcher matcher = urlPattern.matcher(text);
        int count = 0;
        while (matcher.find()) count++;
        return count;
    }
    
    private LogisticRegression trainSpamModel(Dataset data) {
        System.out.println("\n=== Model Training ===");
        
        // Handle class imbalance
        Map<Object, Integer> classCounts = data.getValueCounts("is_spam");
        System.out.println("Class distribution: " + classCounts);
        
        // Split data
        DataSplit split = data.split(0.8, true);  // stratified split
        
        // Hyperparameter tuning
        GridSearchCV gridSearch = new GridSearchCV()
            .setEstimator(new LogisticRegression())
            .setParamGrid(Map.of(
                "C", Arrays.asList(0.1, 1.0, 10.0),
                "penalty", Arrays.asList("l1", "l2"),
                "class_weight", Arrays.asList("balanced", null)
            ))
            .setScoring(Scoring.F1)
            .setCv(5);
        
        gridSearch.fit(split.getTrainX(), split.getTrainY());
        LogisticRegression bestModel = gridSearch.getBestEstimator();
        
        // Final evaluation
        double[] predictions = bestModel.predict(split.getTestX());
        double[] probabilities = bestModel.predictProba(split.getTestX());
        
        evaluateSpamModel(split.getTestY(), predictions, probabilities);
        
        return bestModel;
    }
    
    private void evaluateSpamModel(double[] yTrue, double[] yPred, double[] probabilities) {
        System.out.println("\n=== Spam Detection Performance ===");
        
        // Basic metrics
        double accuracy = Metrics.accuracy(yTrue, yPred);
        double precision = Metrics.precision(yTrue, yPred);
        double recall = Metrics.recall(yTrue, yPred);
        double f1 = Metrics.f1Score(yTrue, yPred);
        
        System.out.printf("Accuracy:  %.3f%n", accuracy);
        System.out.printf("Precision: %.3f (%.1f%% of flagged emails are actually spam)%n", 
            precision, precision * 100);
        System.out.printf("Recall:    %.3f (%.1f%% of spam emails are caught)%n", 
            recall, recall * 100);
        System.out.printf("F1-Score:  %.3f%n", f1);
        
        // False positive/negative analysis
        ConfusionMatrix cm = Metrics.confusionMatrix(yTrue, yPred);
        int falsePositives = cm.getFalsePositives();
        int falseNegatives = cm.getFalseNegatives();
        
        System.out.printf("False Positives: %d (legitimate emails marked as spam)%n", falsePositives);
        System.out.printf("False Negatives: %d (spam emails not caught)%n", falseNegatives);
        
        // ROC AUC
        double rocAuc = Metrics.rocAuc(yTrue, probabilities);
        System.out.printf("ROC AUC: %.3f%n", rocAuc);
    }
    
    private void deployModel(LogisticRegression model) {
        System.out.println("\n=== Model Deployment Ready ===");
        
        // Save model for production use
        ModelSerializer.save(model, "spam_detector.model");
        
        // Example of real-time prediction
        String newEmail = "Congratulations! You've won $1000000! Click here now!!!";
        double spamProbability = predictSingleEmail(model, newEmail);
        
        System.out.printf("Example email spam probability: %.3f%n", spamProbability);
        System.out.println("Classification: " + (spamProbability > 0.5 ? "SPAM" : "NOT SPAM"));
    }
    
    private double predictSingleEmail(LogisticRegression model, String emailText) {
        // This would use the same preprocessing pipeline as training
        // Simplified for demonstration
        Map<String, Object> features = extractEmailFeatures(emailText);
        Dataset singleEmail = Dataset.fromMap(features);
        return model.predictProba(singleEmail.getFeatures())[0];
    }
    
    private Map<String, Object> extractEmailFeatures(String emailText) {
        // Extract same features as during training
        Map<String, Object> features = new HashMap<>();
        features.put("email_length", (double) emailText.length());
        features.put("capital_ratio", calculateCapitalRatio(emailText));
        features.put("exclamation_count", (double) countCharacter(emailText, '!'));
        features.put("url_count", (double) countUrls(emailText));
        // Add TF-IDF features...
        return features;
    }
}
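Under the hood, TF-IDF scores a term highly when it is frequent in one document but rare across the corpus. A tiny self-contained sketch of the scoring (using the smoothed IDF variant; real vectorizers such as the one above also normalize vectors and handle n-grams):

```java
public class TfIdfDemo {
    // tf-idf = termCount * (log((1 + nDocs) / (1 + docFreq)) + 1), the smoothed variant
    static double tfidf(int termCountInDoc, int nDocs, int docsContainingTerm) {
        double idf = Math.log((1.0 + nDocs) / (1.0 + docsContainingTerm)) + 1.0;
        return termCountInDoc * idf;
    }

    public static void main(String[] args) {
        int nDocs = 1000;
        // "winner" appears 3x in this email but in only 5 emails overall -> high score
        System.out.printf("winner: %.3f%n", tfidf(3, nDocs, 5));
        // "the" appears 3x here but in nearly every email -> low score
        System.out.printf("the:    %.3f%n", tfidf(3, nDocs, 990));
    }
}
```

This is also why the setMinDf and setMaxDf settings above help: terms in almost no documents or almost all documents carry little discriminative signal.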

Model Interpretation and Explainability

public class ModelExplainability {
    
    public void explainPredictions(LogisticRegression model, Dataset data) {
        System.out.println("=== Model Explainability ===");
        
        // Global interpretation
        explainGlobalModel(model, data.getFeatureNames());
        
        // Local interpretation for individual predictions
        explainLocalPredictions(model, data);
        
        // Feature importance using permutation
        calculatePermutationImportance(model, data);
    }
    
    private void explainGlobalModel(LogisticRegression model, List<String> featureNames) {
        System.out.println("\n--- Global Model Interpretation ---");
        
        double[] coefficients = model.getCoefficients();
        
        // Sort features by absolute coefficient value
        List<FeatureImportance> importance = new ArrayList<>();
        for (int i = 0; i < featureNames.size(); i++) {
            importance.add(new FeatureImportance(
                featureNames.get(i), 
                coefficients[i], 
                Math.exp(coefficients[i])
            ));
        }
        
        importance.sort((a, b) -> Double.compare(Math.abs(b.coefficient), Math.abs(a.coefficient)));
        
        System.out.println("Top 10 most influential features:");
        importance.stream().limit(10).forEach(fi -> {
            String interpretation = interpretCoefficient(fi.coefficient, fi.oddsRatio);
            System.out.printf("%-25s: coef=%.4f, OR=%.3f (%s)%n", 
                fi.name, fi.coefficient, fi.oddsRatio, interpretation);
        });
    }
    
    private String interpretCoefficient(double coefficient, double oddsRatio) {
        if (coefficient > 0) {
            return String.format("increases odds by %.0f%%", (oddsRatio - 1) * 100);
        } else {
            return String.format("decreases odds by %.0f%%", (1 - oddsRatio) * 100);
        }
    }
    
    private void explainLocalPredictions(LogisticRegression model, Dataset data) {
        System.out.println("\n--- Local Prediction Explanations ---");
        
        // Explain first few predictions
        for (int i = 0; i < Math.min(3, data.getRowCount()); i++) {
            double[] sample = data.getRow(i);
            double probability = model.predictProba(new double[][]{sample})[0];
            
            System.out.printf("\nSample %d prediction probability: %.3f%n", i, probability);
            
            // Calculate contribution of each feature
            double[] contributions = calculateFeatureContributions(model, sample);
            List<String> featureNames = data.getFeatureNames();
            
            // Show top contributing features
            IntStream.range(0, contributions.length)
                .boxed()
                .sorted((a, b) -> Double.compare(Math.abs(contributions[b]), Math.abs(contributions[a])))
                .limit(5)
                .forEach(idx -> {
                    System.out.printf("  %-20s: %.4f (value: %.3f)%n", 
                        featureNames.get(idx), contributions[idx], sample[idx]);
                });
        }
    }
    
    private double[] calculateFeatureContributions(LogisticRegression model, double[] sample) {
        double[] coefficients = model.getCoefficients();
        double[] contributions = new double[sample.length];
        
        for (int i = 0; i < sample.length; i++) {
            contributions[i] = coefficients[i] * sample[i];
        }
        
        return contributions;
    }
    
    private void calculatePermutationImportance(LogisticRegression model, Dataset data) {
        System.out.println("\n--- Permutation Feature Importance ---");
        
        DataSplit split = data.split(0.8);
        double baseScore = Metrics.accuracy(
            split.getTestY(), 
            model.predict(split.getTestX())
        );
        
        List<String> featureNames = data.getFeatureNames();
        Map<String, Double> importance = new HashMap<>();
        
        for (int featureIdx = 0; featureIdx < featureNames.size(); featureIdx++) {
            // Permute feature values
            double[][] permutedX = permuteFeature(split.getTestX(), featureIdx);
            
            // Calculate score with permuted feature
            double permutedScore = Metrics.accuracy(
                split.getTestY(),
                model.predict(permutedX)
            );
            
            // Importance is the decrease in performance
            double featureImportance = baseScore - permutedScore;
            importance.put(featureNames.get(featureIdx), featureImportance);
        }
        
        // Display sorted importance
        importance.entrySet().stream()
            .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
            .limit(10)
            .forEach(entry -> 
                System.out.printf("%-25s: %.4f%n", entry.getKey(), entry.getValue()));
    }
    
    private double[][] permuteFeature(double[][] X, int featureIdx) {
        double[][] permuted = Arrays.stream(X).map(double[]::clone).toArray(double[][]::new);
        
        // Create a copy of the feature column and shuffle it
        double[] featureColumn = new double[X.length];
        for (int i = 0; i < X.length; i++) {
            featureColumn[i] = X[i][featureIdx];
        }
        
        // Simple shuffle
        Random random = new Random(42);
        for (int i = featureColumn.length - 1; i > 0; i--) {
            int j = random.nextInt(i + 1);
            double temp = featureColumn[i];
            featureColumn[i] = featureColumn[j];
            featureColumn[j] = temp;
        }
        
        // Apply shuffled values
        for (int i = 0; i < permuted.length; i++) {
            permuted[i][featureIdx] = featureColumn[i];
        }
        
        return permuted;
    }
}

Summary

In this comprehensive tutorial, we covered:

  • Binary Classification: Customer churn prediction with detailed evaluation
  • Multiclass Classification: Iris dataset with OvR and multinomial strategies
  • Regularization: L1, L2, and elastic net with hyperparameter tuning
  • Class Imbalance: Handling imbalanced datasets with various techniques
  • Feature Selection: Statistical, L1-based, and recursive feature elimination
  • Real-World Application: Complete email spam detection system
  • Model Interpretation: Global and local explainability techniques

Logistic regression is a powerful and interpretable algorithm that serves as an excellent foundation for classification problems. Its probabilistic output and linear decision boundary make it particularly valuable when you need to understand how features influence predictions.

Key takeaways:

  1. Always check class distribution and handle imbalance appropriately
  2. Use cross-validation for reliable performance estimates
  3. Regularization is crucial for high-dimensional data
  4. Feature engineering can significantly improve performance
  5. Model interpretation is as important as performance metrics

In the next tutorial, we’ll explore tree-based algorithms like Decision Trees and Random Forest, which can capture non-linear relationships and interactions between features.