📖 Lesson ⏱️ 120 minutes

Model Serving: REST APIs and FastAPI

Build and deploy a REST API to serve predictions in real-time

From Model to API

You have a trained model. A pickle file sitting in models/churn_model.pkl. Now you need to make it useful to the rest of the organization. The product team needs to show a churn probability in the customer dashboard. The CRM system needs to trigger a retention workflow when churn risk exceeds 80%. The mobile app needs a real-time score when a customer opens it.

None of these consumers can load a pickle file. They need an HTTP API: send a JSON payload, get a JSON response. That’s what this lesson builds.

FastAPI is the right tool for this. It’s fast (comparable to Node.js and Go for I/O-bound workloads), produces automatic OpenAPI documentation, validates inputs with Pydantic, and has excellent async support.


What Makes ML Serving Different

Serving an ML model isn’t like serving a typical web API. The key differences:

Model loading is expensive. Loading a large model (scikit-learn, XGBoost, PyTorch) can take seconds. If you load it on every request, your API will be catastrophically slow. You load it once at startup and keep it in memory.

Preprocessing must match training. Every transformation applied to features during training must be applied identically at inference time. If you one-hot-encoded a column during training, you must one-hot-encode it at inference. If you scaled features with a StandardScaler, you must use the same scaler (same mean and std) at inference — not refit it.
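
To make the "same scaler, not refit" rule concrete, here is a minimal standard-library sketch that stands in for StandardScaler. The function names, the sample values, and the JSON file are illustrative assumptions, not part of this lesson's codebase.

```python
import json
import statistics
import tempfile
from pathlib import Path


def fit_scaler(values):
    """Compute scaling parameters from TRAINING data only."""
    return {"mean": statistics.fmean(values), "std": statistics.pstdev(values)}


def transform(value, params):
    """Scale one value with previously fitted parameters (never refit here)."""
    return (value - params["mean"]) / params["std"]


# Hypothetical training column for monthly_charges.
train_monthly_charges = [20.0, 40.0, 60.0, 80.0]

with tempfile.TemporaryDirectory() as tmp:
    params_path = Path(tmp) / "scaler_params.json"

    # Training time: fit once and persist the parameters next to the model.
    params_path.write_text(json.dumps(fit_scaler(train_monthly_charges)))

    # Inference time: load the saved parameters and apply them unchanged.
    saved_params = json.loads(params_path.read_text())

scaled = transform(80.0, saved_params)
print(round(scaled, 4))
```

With scikit-learn, a common way to get the same guarantee is to put the preprocessing steps and the estimator in a single Pipeline and persist that one object, so the fitted scaler travels with the model.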

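Batching matters at scale. At 1,000 requests per minute (about 17 per second) with a 50ms model, one worker spends roughly 0.8 seconds of every second on inference, so it can keep up. At 10,000 requests per minute (about 167 per second), the same model would need more than 8 seconds of compute per second of traffic, so you need to batch requests together and run them through the model in a single forward pass.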
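
One way to batch requests that arrive separately is server-side micro-batching: hold each request for a few milliseconds, then score everything collected in one model call. The sketch below uses only asyncio. MicroBatcher, predict_batch, max_batch_size, and max_wait_s are hypothetical names, and the fake model is a stand-in for predict_proba. Treat it as an illustration of the idea, not a production implementation.

```python
import asyncio


class MicroBatcher:
    """Group concurrent requests so the model is called once per batch."""

    def __init__(self, predict_batch, max_batch_size=32, max_wait_s=0.01):
        self._predict_batch = predict_batch  # callable: list of rows -> list of scores
        self._max_batch_size = max_batch_size
        self._max_wait_s = max_wait_s
        self._queue = asyncio.Queue()

    async def predict(self, row):
        """Called per request: enqueue the row and wait for its score."""
        future = asyncio.get_running_loop().create_future()
        await self._queue.put((row, future))
        return await future

    async def run(self):
        """Background task: drain the queue in batches until cancelled."""
        loop = asyncio.get_running_loop()
        while True:
            batch = [await self._queue.get()]  # wait for the first request
            deadline = loop.time() + self._max_wait_s
            while len(batch) < self._max_batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self._queue.get(), remaining))
                except asyncio.TimeoutError:
                    break
            rows = [row for row, _ in batch]
            try:
                scores = self._predict_batch(rows)  # one model call for the whole batch
            except Exception as exc:
                for _, future in batch:
                    if not future.done():
                        future.set_exception(exc)
                continue
            for (_, future), score in zip(batch, scores):
                if not future.done():
                    future.set_result(score)


async def demo():
    batch_sizes = []

    def fake_predict_batch(rows):
        # Stand-in for model.predict_proba(rows)[:, 1]
        batch_sizes.append(len(rows))
        return [sum(row) for row in rows]

    batcher = MicroBatcher(fake_predict_batch, max_batch_size=8, max_wait_s=0.05)
    worker = asyncio.create_task(batcher.run())
    results = await asyncio.gather(*(batcher.predict([i, i]) for i in range(5)))
    worker.cancel()
    try:
        await worker
    except asyncio.CancelledError:
        pass
    return results, batch_sizes


results, batch_sizes = asyncio.run(demo())
print(results, batch_sizes)
```

The trade-off is latency for throughput: every request may wait up to max_wait_s before it is scored, so the wait should stay small relative to your latency budget.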

Validation is non-negotiable. A model receiving unexpected input (wrong types, missing features, out-of-range values) will either crash or silently return garbage. Validate every input before it reaches the model.


The Complete FastAPI Application

# src/api/app.py
import logging
import os
import time
from contextlib import asynccontextmanager
from pathlib import Path

import joblib
import pandas as pd
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, field_validator

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ─── Data models ─────────────────────────────────────────────────────────────

class CustomerFeatures(BaseModel):
    """Input features for a single customer churn prediction."""
    tenure_months: int = Field(
        ...,
        ge=0,
        le=120,
        description="Number of months the customer has been with us"
    )
    monthly_charges: float = Field(
        ...,
        ge=0.0,
        le=500.0,
        description="Current monthly charge in USD"
    )
    total_charges: float = Field(
        ...,
        ge=0.0,
        description="Total amount charged to date in USD"
    )
    num_products: int = Field(
        ...,
        ge=1,
        le=10,
        description="Number of products/services the customer subscribes to"
    )
    has_support_calls: int = Field(
        ...,
        ge=0,
        le=1,
        description="Whether the customer called support in the last 30 days (0 or 1)"
    )

    @field_validator("has_support_calls")
    @classmethod
    def validate_binary(cls, v):
        if v not in (0, 1):
            raise ValueError("has_support_calls must be 0 or 1")
        return v


class PredictionResponse(BaseModel):
    """Structured response for a single prediction."""
    churn_probability: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Probability of churn (0.0 to 1.0)"
    )
    churn_prediction: bool = Field(
        ...,
        description="True if predicted to churn"
    )
    risk_tier: str = Field(
        ...,
        description="Risk tier: LOW, MEDIUM, HIGH"
    )
    model_version: str

    @classmethod
    def from_probability(cls, prob: float, version: str) -> "PredictionResponse":
        return cls(
            churn_probability=round(prob, 4),
            churn_prediction=prob >= 0.5,
            risk_tier="HIGH" if prob >= 0.7 else "MEDIUM" if prob >= 0.4 else "LOW",
            model_version=version,
        )


class BatchPredictionRequest(BaseModel):
    customers: list[CustomerFeatures] = Field(
        ...,
        min_length=1,
        max_length=100,  # Limit batch size
    )


class BatchPredictionResponse(BaseModel):
    predictions: list[PredictionResponse]
    batch_size: int
    latency_ms: float


class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    model_version: str
    uptime_seconds: float


# ─── Application state ────────────────────────────────────────────────────────

class ModelState:
    model = None
    model_version: str = "unknown"
    start_time: float = 0.0


state = ModelState()

FEATURE_COLUMNS = [
    "tenure_months",
    "monthly_charges",
    "total_charges",
    "num_products",
    "has_support_calls",
]

# ─── Startup and shutdown ─────────────────────────────────────────────────────

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup, release on shutdown."""
    # STARTUP
    state.start_time = time.time()
    model_path = Path(os.getenv("MODEL_PATH", "models/churn_model.pkl"))
    version_path = model_path.parent / "model_version.txt"

    logger.info(f"Loading model from {model_path}")
    if not model_path.exists():
        logger.error(f"Model file not found: {model_path}")
        raise RuntimeError(f"Model not found at {model_path}")

    state.model = joblib.load(model_path)
    state.model_version = (
        version_path.read_text().strip()
        if version_path.exists()
        else "1.0.0"
    )
    logger.info(f"Model loaded successfully. Version: {state.model_version}")

    yield  # Application runs here

    # SHUTDOWN
    logger.info("Shutting down. Releasing model from memory.")
    state.model = None


# ─── FastAPI app ──────────────────────────────────────────────────────────────

app = FastAPI(
    title="Churn Prediction API",
    description="Real-time customer churn prediction service",
    version="1.0.0",
    lifespan=lifespan,
)

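# Wide-open CORS is convenient for local development.
# In production, list the specific origins that are allowed to call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)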


# ─── Middleware: request logging ──────────────────────────────────────────────

@app.middleware("http")
async def log_requests(request: Request, call_next):
    start = time.perf_counter()
    response = await call_next(request)
    duration_ms = (time.perf_counter() - start) * 1000
    logger.info(
        f"{request.method} {request.url.path} "
        f"status={response.status_code} "
        f"latency={duration_ms:.1f}ms"
    )
    return response


# ─── Endpoints ────────────────────────────────────────────────────────────────

@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint for load balancers and Kubernetes probes."""
    return HealthResponse(
        status="healthy" if state.model is not None else "unhealthy",
        model_loaded=state.model is not None,
        model_version=state.model_version,
        uptime_seconds=round(time.time() - state.start_time, 1),
    )


@app.post("/predict", response_model=PredictionResponse)
async def predict(customer: CustomerFeatures):
    """Predict churn probability for a single customer."""
    if state.model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Build feature array in the exact order the model was trained on
    features = pd.DataFrame(
        [[getattr(customer, col) for col in FEATURE_COLUMNS]],
        columns=FEATURE_COLUMNS,
    )

    # Get probability of churn (class 1)
    prob = float(state.model.predict_proba(features)[0][1])

    return PredictionResponse.from_probability(prob, state.model_version)


@app.post("/predict/batch", response_model=BatchPredictionResponse)
async def predict_batch(request: BatchPredictionRequest):
    """Predict churn probability for a batch of customers.

    More efficient than calling /predict N times — all predictions
    run in a single model forward pass.
    """
    if state.model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    start = time.perf_counter()

    # Build batch feature matrix
    rows = [
        [getattr(customer, col) for col in FEATURE_COLUMNS]
        for customer in request.customers
    ]
    features = pd.DataFrame(rows, columns=FEATURE_COLUMNS)

    # Single forward pass for all customers
    probabilities = state.model.predict_proba(features)[:, 1]

    predictions = [
        PredictionResponse.from_probability(float(prob), state.model_version)
        for prob in probabilities
    ]

    latency_ms = (time.perf_counter() - start) * 1000

    return BatchPredictionResponse(
        predictions=predictions,
        batch_size=len(predictions),
        latency_ms=round(latency_ms, 2),
    )

Running the API

# Install dependencies
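# (the quotes keep shells such as zsh from expanding the square brackets;
#  requests is only needed for the integration tests below)
pip install fastapi "uvicorn[standard]" pydantic joblib scikit-learn pandas requests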

# Start the server
uvicorn src.api.app:app --host 0.0.0.0 --port 8000 --reload

# In production, remove --reload and add workers:
uvicorn src.api.app:app --host 0.0.0.0 --port 8000 --workers 4

Testing with curl and Python

# Health check
curl http://localhost:8000/health

# Single prediction
curl -X POST http://localhost:8000/predict \
  -H "Content-Type: application/json" \
  -d '{
    "tenure_months": 6,
    "monthly_charges": 79.99,
    "total_charges": 479.94,
    "num_products": 1,
    "has_support_calls": 1
  }'

Response:

{
  "churn_probability": 0.7234,
  "churn_prediction": true,
  "risk_tier": "HIGH",
  "model_version": "1.0.0"
}

# test_api.py: integration tests (run them while the server is up on localhost:8000)
import requests

BASE_URL = "http://localhost:8000"

def test_health():
    resp = requests.get(f"{BASE_URL}/health")
    assert resp.status_code == 200
    data = resp.json()
    assert data["model_loaded"] is True
    assert data["status"] == "healthy"

def test_predict():
    payload = {
        "tenure_months": 6,
        "monthly_charges": 79.99,
        "total_charges": 479.94,
        "num_products": 1,
        "has_support_calls": 1,
    }
    resp = requests.post(f"{BASE_URL}/predict", json=payload)
    assert resp.status_code == 200
    data = resp.json()
    assert 0.0 <= data["churn_probability"] <= 1.0
    assert data["risk_tier"] in ("LOW", "MEDIUM", "HIGH")

def test_predict_validation_error():
    # tenure_months > 120 should fail validation
    payload = {
        "tenure_months": 999,
        "monthly_charges": 79.99,
        "total_charges": 479.94,
        "num_products": 1,
        "has_support_calls": 1,
    }
    resp = requests.post(f"{BASE_URL}/predict", json=payload)
    assert resp.status_code == 422  # Unprocessable Entity

def test_batch_predict():
    payload = {
        "customers": [
            {"tenure_months": 6, "monthly_charges": 80.0,
             "total_charges": 480.0, "num_products": 1, "has_support_calls": 1},
            {"tenure_months": 48, "monthly_charges": 45.0,
             "total_charges": 2160.0, "num_products": 3, "has_support_calls": 0},
        ]
    }
    resp = requests.post(f"{BASE_URL}/predict/batch", json=payload)
    assert resp.status_code == 200
    data = resp.json()
    assert data["batch_size"] == 2
    assert len(data["predictions"]) == 2
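
A consumer such as the CRM workflow from the introduction does not need the requests library. The sketch below calls /predict with only the standard library and applies the 80% retention rule. The function names, the timeout, and the base URL are assumptions for illustration.

```python
import json
import urllib.request

RETENTION_THRESHOLD = 0.8  # CRM rule from the introduction: churn risk above 80%


def build_predict_request(base_url, features):
    """Build (but do not send) a POST request for the /predict endpoint."""
    return urllib.request.Request(
        url=f"{base_url}/predict",
        data=json.dumps(features).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def fetch_prediction(base_url, features, timeout_s=2.0):
    """Send the request and decode the JSON body.

    Raises urllib.error.HTTPError for non-2xx responses (including 422
    validation errors) and urllib.error.URLError for network failures.
    """
    request = build_predict_request(base_url, features)
    with urllib.request.urlopen(request, timeout=timeout_s) as response:
        return json.loads(response.read().decode("utf-8"))


def needs_retention_workflow(prediction):
    """True when the returned churn probability exceeds the CRM threshold."""
    return prediction["churn_probability"] > RETENTION_THRESHOLD


# Usage, with the server from this lesson running locally:
# prediction = fetch_prediction("http://localhost:8000", {
#     "tenure_months": 6, "monthly_charges": 79.99, "total_charges": 479.94,
#     "num_products": 1, "has_support_calls": 1,
# })
# if needs_retention_workflow(prediction):
#     ...  # trigger the retention workflow
```

Always set a timeout on the client side. A scoring call that hangs should not stall the dashboard or the mobile app that depends on it.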

Auto-Generated API Documentation

FastAPI generates interactive API docs automatically. Visit:

  • http://localhost:8000/docs — Swagger UI (try requests in the browser)
  • http://localhost:8000/redoc — ReDoc (cleaner for documentation)

This documentation is always in sync with your Pydantic models. No manual maintenance.


Performance: Why Model Loading at Startup Matters

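Consider 1,000 requests per minute. If each request spends 50ms loading the model plus 5ms on inference, that is 55ms per request, or 55,000ms of processing per minute. If you load once at startup instead, each request costs only the 5ms of inference: 5,000ms per minute. Even if the one-time startup load takes 2,000ms, it pays for itself after 40 requests (2,000ms divided by the 50ms saved per request), which is under three seconds of traffic at this rate.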
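
The same arithmetic as a short script, so you can plug in your own numbers. The function names are illustrative, and the figures are the ones from the paragraph above.

```python
import math


def per_request_loading_ms(requests, load_ms, inference_ms):
    """Total processing time when the model is reloaded on every request."""
    return requests * (load_ms + inference_ms)


def load_once_ms(requests, startup_load_ms, inference_ms):
    """Total processing time when the model is loaded a single time at startup."""
    return startup_load_ms + requests * inference_ms


def break_even_requests(startup_load_ms, per_request_load_ms):
    """Number of requests after which the one-time startup load has paid for itself."""
    return math.ceil(startup_load_ms / per_request_load_ms)


requests_per_minute = 1_000

reload_total = per_request_loading_ms(requests_per_minute, load_ms=50, inference_ms=5)
first_minute_total = load_once_ms(requests_per_minute, startup_load_ms=2_000, inference_ms=5)
break_even = break_even_requests(startup_load_ms=2_000, per_request_load_ms=50)

print(reload_total)        # 55000 ms per minute when loading on every request
print(first_minute_total)  # 7000 ms in the first minute (2000 startup + 5000 inference)
print(break_even)          # 40 requests
```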

For larger models (PyTorch, transformers), loading can take 30-60 seconds. Per-request loading would make your API unusable.

The lifespan context manager in the code above loads the model exactly once when the server starts. All requests share the same model object in memory. Sharing it across threads is generally safe for scikit-learn and XGBoost models, because prediction reads the fitted parameters without modifying them. For PyTorch, call model.eval() once after loading, wrap inference in torch.no_grad(), and make sure no request path ever calls .train().
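
Thread safety matters because blocking inference is usually run in worker threads, which keeps the event loop free to accept other requests. The sketch below shows the pattern with the standard library only. FakeModel is a stand-in whose predict call just sleeps, and the names are assumptions. Real CPU-bound inference overlaps across threads only where the underlying library releases the GIL.

```python
import asyncio
import time


class FakeModel:
    """Stand-in for a loaded model: read-only state, blocking predict call."""

    def __init__(self, weight):
        self.weight = weight  # never mutated after loading

    def predict(self, x):
        time.sleep(0.1)  # simulates blocking inference work
        return x * self.weight


MODEL = FakeModel(weight=2.0)  # loaded once, shared by every request


async def handle_request(x):
    # Run the blocking call in a worker thread so the event loop stays responsive.
    return await asyncio.to_thread(MODEL.predict, x)


async def serve_concurrently(inputs):
    start = time.perf_counter()
    results = await asyncio.gather(*(handle_request(x) for x in inputs))
    return results, time.perf_counter() - start


results, elapsed = asyncio.run(serve_concurrently([1.0, 2.0, 3.0, 4.0]))
print(results)
print(f"{elapsed:.2f}s for 4 overlapping requests")
```

In FastAPI, declaring an endpoint with plain def instead of async def has a similar effect, because FastAPI runs such endpoints in a thread pool. An async def endpoint that calls predict_proba directly runs on the event loop and blocks other requests on that worker until it returns.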


Horizontal Scaling

When your single server can’t handle the load:

# Run multiple worker processes on the same machine (for CPU models)
uvicorn src.api.app:app --workers 4

# Each worker loads the model independently
# 4 workers = 4 model instances, up to ~4x throughput if at least 4 CPU cores are available

Each worker is a separate OS process, each with its own model copy in memory. For a 500MB model with 4 workers, that’s 2GB of RAM — plan your instance size accordingly.
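
A rough sizing helper for that planning step. The per-worker overhead figure is a placeholder assumption (interpreter, libraries, request buffers), so measure your own process before relying on it.

```python
def total_model_memory_mb(workers, model_mb):
    """Memory used by model copies alone: one copy per worker process."""
    return workers * model_mb


def max_workers(ram_mb, model_mb, cpu_cores, overhead_mb=200):
    """Largest worker count that fits in RAM, capped at the number of CPU cores.

    overhead_mb is a hypothetical per-worker allowance on top of the model.
    Returns 0 when not even one worker fits.
    """
    by_memory = ram_mb // (model_mb + overhead_mb)
    return min(cpu_cores, by_memory)


print(total_model_memory_mb(workers=4, model_mb=500))          # 2000 (about 2GB)
print(max_workers(ram_mb=4096, model_mb=500, cpu_cores=4))     # 4: CPU-bound limit
print(max_workers(ram_mb=2048, model_mb=500, cpu_cores=4))     # 2: memory-bound limit
```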

For GPU models, each worker process loads its own copy of the model into GPU memory, so several workers on one GPU can exhaust that memory quickly. Start with 1 worker per GPU.

In Kubernetes (the next-but-one lesson), you scale by adding more pods — each pod is one server process, each with its own model copy.


Summary

The FastAPI serving layer you built in this lesson:

  • Loads the model once at startup — not per request
  • Validates every input with Pydantic before it touches the model
  • Returns structured, typed responses
  • Supports both single and batch predictions
  • Provides a health check endpoint for infrastructure monitoring
  • Generates interactive API documentation automatically

The serving API is the interface between your ML system and the world. Get it right here, and everything downstream — containerization, Kubernetes scaling, monitoring — becomes straightforward. The next lesson covers what happens after you deploy: detecting when your model starts to fail in production.