Model Serving: REST APIs and FastAPI
Build and deploy a REST API to serve predictions in real-time
From Model to API
You have a trained model. A pickle file sitting in models/churn_model.pkl. Now you need to make it useful to the rest of the organization. The product team needs to show a churn probability in the customer dashboard. The CRM system needs to trigger a retention workflow when churn risk exceeds 80%. The mobile app needs a real-time score when a customer opens it.
None of these consumers can load a pickle file. They need an HTTP API: send a JSON payload, get a JSON response. That’s what this lesson builds.
FastAPI is the right tool for this. It’s fast (comparable to Node.js and Go for I/O-bound workloads), produces automatic OpenAPI documentation, validates inputs with Pydantic, and has excellent async support.
What Makes ML Serving Different
Serving an ML model isn’t like serving a typical web API. The key differences:
Model loading is expensive. Loading a large model (scikit-learn, XGBoost, PyTorch) can take seconds. If you load it on every request, your API will be catastrophically slow. You load it once at startup and keep it in memory.
Preprocessing must match training. Every transformation applied to features during training must be applied identically at inference time. If you one-hot-encoded a column during training, you must one-hot-encode it at inference. If you scaled features with a StandardScaler, you must use the same scaler (same mean and std) at inference — not refit it.
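The sketch below shows the same rule without any dependencies. It fits the scaling statistics once on training data, saves them next to the model, and reuses the stored values at inference. `FittedScaler`, its `to_json`/`from_json` helpers, and the sample numbers are hypothetical illustrations. In a scikit-learn project you would more typically wrap the scaler and model in a single `Pipeline` and save that one object with joblib.

```python
import json
import statistics
from dataclasses import dataclass


@dataclass(frozen=True)
class FittedScaler:
    """Hypothetical stand-in for a fitted StandardScaler on one feature."""

    mean: float
    std: float

    @classmethod
    def fit(cls, values: list[float]) -> "FittedScaler":
        # Population std (divide by N), matching StandardScaler's convention.
        return cls(mean=statistics.fmean(values), std=statistics.pstdev(values))

    def transform(self, value: float) -> float:
        return (value - self.mean) / self.std

    def to_json(self) -> str:
        return json.dumps({"mean": self.mean, "std": self.std})

    @classmethod
    def from_json(cls, text: str) -> "FittedScaler":
        data = json.loads(text)
        return cls(mean=data["mean"], std=data["std"])


# Training time: fit on the training data once and save next to the model.
train_monthly_charges = [20.0, 40.0, 60.0, 80.0]
saved_params = FittedScaler.fit(train_monthly_charges).to_json()

# Inference time: reload the saved statistics and reuse them unchanged.
scaler = FittedScaler.from_json(saved_params)
scaled_request = scaler.transform(79.99)

# Anti-pattern: refitting on request data. A single request has std 0,
# so in this sketch refit.transform(79.99) would raise ZeroDivisionError.
refit = FittedScaler.fit([79.99])
```

Even with many requests, refitting would make the statistics drift with whatever traffic happens to arrive. Persisting the training-time values is what keeps inference identical to training.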
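Batching matters at scale. At 1,000 requests per minute with a 50ms model, a single worker already spends about 50 of every 60 seconds on inference, so a few workers keep you comfortable. At 10,000 requests per minute that grows to 500 seconds of inference per minute. Rather than only adding processes, you batch requests together and run them through the model in a single forward pass, which usually costs far less than the same number of one-row calls.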
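One common way to get that single call is server-side micro-batching. Each request enqueues its features and awaits a future. A background task drains the queue into batches of up to `max_batch_size` rows, waiting at most `max_wait_s` for a batch to fill. This is a minimal asyncio sketch, not part of the API built below. `MicroBatcher`, its parameters, and the stub batch function are all hypothetical.

```python
import asyncio
from typing import Callable, Sequence

# Hypothetical type: a batch function maps N feature rows to N scores.
BatchFn = Callable[[Sequence[list[float]]], list[float]]


class MicroBatcher:
    """Hypothetical helper that groups concurrent requests into one model call."""

    def __init__(self, predict_batch_fn: BatchFn, max_batch_size: int = 32,
                 max_wait_s: float = 0.005) -> None:
        self._predict_batch_fn = predict_batch_fn
        self._max_batch_size = max_batch_size
        self._max_wait_s = max_wait_s
        self._queue: asyncio.Queue = asyncio.Queue()

    async def predict(self, features: list[float]) -> float:
        # Each caller enqueues its row with a future that the batching loop resolves.
        future = asyncio.get_running_loop().create_future()
        await self._queue.put((features, future))
        return await future

    async def run(self) -> None:
        loop = asyncio.get_running_loop()
        while True:
            # Wait (without a deadline) for the first request of the next batch.
            batch = [await self._queue.get()]
            deadline = loop.time() + self._max_wait_s
            # Keep collecting until the batch is full or the wait budget runs out.
            while len(batch) < self._max_batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    item = await asyncio.wait_for(self._queue.get(), remaining)
                except asyncio.TimeoutError:
                    break
                batch.append(item)
            # One model call for the whole batch, then hand each caller its result.
            scores = self._predict_batch_fn([row for row, _ in batch])
            for (_, future), score in zip(batch, scores):
                if not future.done():  # skip callers that gave up (cancelled)
                    future.set_result(score)


# Usage with a stub "model" that records batch sizes and sums each row.
batch_sizes: list[int] = []


def stub_predict_batch(rows: Sequence[list[float]]) -> list[float]:
    batch_sizes.append(len(rows))
    return [float(sum(row)) for row in rows]


async def main() -> list[float]:
    batcher = MicroBatcher(stub_predict_batch, max_batch_size=8, max_wait_s=0.05)
    worker = asyncio.create_task(batcher.run())
    results = await asyncio.gather(*(batcher.predict([float(i), 1.0]) for i in range(20)))
    worker.cancel()
    return list(results)


results = asyncio.run(main())
```

To wire this into a FastAPI app, you would typically start `run()` as a background task in the lifespan handler and `await batcher.predict(...)` from an `async def` endpoint. A real `predict_proba` call inside `run()` blocks the event loop while it runs, so you would likely move it to a thread with `loop.run_in_executor`. Error handling is also left out here: a failed batch should call `set_exception` on each pending future.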
Validation is non-negotiable. A model receiving unexpected input (wrong types, missing features, out-of-range values) will either crash or silently return garbage. Validate every input before it reaches the model.
The Complete FastAPI Application
```python
# src/api/app.py
import logging
import os
import time
from contextlib import asynccontextmanager
from pathlib import Path

import joblib
import pandas as pd
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ConfigDict, Field, field_validator

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# ─── Data models ─────────────────────────────────────────────────────────────

class CustomerFeatures(BaseModel):
    """Input features for a single customer churn prediction."""

    tenure_months: int = Field(
        ...,
        ge=0,
        le=120,
        description="Number of months the customer has been with us",
    )
    monthly_charges: float = Field(
        ...,
        ge=0.0,
        le=500.0,
        description="Current monthly charge in USD",
    )
    total_charges: float = Field(
        ...,
        ge=0.0,
        description="Total amount charged to date in USD",
    )
    num_products: int = Field(
        ...,
        ge=1,
        le=10,
        description="Number of products/services the customer subscribes to",
    )
    has_support_calls: int = Field(
        ...,
        ge=0,
        le=1,
        description="Whether the customer called support in the last 30 days (0 or 1)",
    )

    @field_validator("has_support_calls")
    @classmethod
    def validate_binary(cls, v: int) -> int:
        if v not in (0, 1):
            raise ValueError("has_support_calls must be 0 or 1")
        return v


class PredictionResponse(BaseModel):
    """Structured response for a single prediction."""

    # Older Pydantic v2 releases warn about field names starting with "model_";
    # an empty protected_namespaces tuple allows model_version without the warning.
    model_config = ConfigDict(protected_namespaces=())

    churn_probability: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Probability of churn (0.0 to 1.0)",
    )
    churn_prediction: bool = Field(
        ...,
        description="True if predicted to churn",
    )
    risk_tier: str = Field(
        ...,
        description="Risk tier: LOW, MEDIUM, HIGH",
    )
    model_version: str

    @classmethod
    def from_probability(cls, prob: float, version: str) -> "PredictionResponse":
        return cls(
            churn_probability=round(prob, 4),
            churn_prediction=prob >= 0.5,
            risk_tier="HIGH" if prob >= 0.7 else "MEDIUM" if prob >= 0.4 else "LOW",
            model_version=version,
        )


class BatchPredictionRequest(BaseModel):
    customers: list[CustomerFeatures] = Field(
        ...,
        min_length=1,
        max_length=100,  # Limit batch size
    )


class BatchPredictionResponse(BaseModel):
    predictions: list[PredictionResponse]
    batch_size: int
    latency_ms: float


class HealthResponse(BaseModel):
    model_config = ConfigDict(protected_namespaces=())  # allows model_loaded / model_version

    status: str
    model_loaded: bool
    model_version: str
    uptime_seconds: float


# ─── Application state ────────────────────────────────────────────────────────

class ModelState:
    model = None
    model_version: str = "unknown"
    start_time: float = 0.0


state = ModelState()

FEATURE_COLUMNS = [
    "tenure_months",
    "monthly_charges",
    "total_charges",
    "num_products",
    "has_support_calls",
]


# ─── Startup and shutdown ─────────────────────────────────────────────────────

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup, release on shutdown."""
    # STARTUP
    state.start_time = time.time()
    model_path = Path(os.getenv("MODEL_PATH", "models/churn_model.pkl"))
    version_path = model_path.parent / "model_version.txt"

    logger.info(f"Loading model from {model_path}")
    if not model_path.exists():
        logger.error(f"Model file not found: {model_path}")
        raise RuntimeError(f"Model not found at {model_path}")

    state.model = joblib.load(model_path)
    state.model_version = (
        version_path.read_text().strip()
        if version_path.exists()
        else "1.0.0"
    )
    logger.info(f"Model loaded successfully. Version: {state.model_version}")

    yield  # Application runs here

    # SHUTDOWN
    logger.info("Shutting down. Releasing model from memory.")
    state.model = None


# ─── FastAPI app ──────────────────────────────────────────────────────────────

app = FastAPI(
    title="Churn Prediction API",
    description="Real-time customer churn prediction service",
    version="1.0.0",
    lifespan=lifespan,
)

# Allowing every origin is convenient in development; restrict it in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# ─── Middleware: request logging ──────────────────────────────────────────────

@app.middleware("http")
async def log_requests(request: Request, call_next):
    start = time.perf_counter()
    response = await call_next(request)
    duration_ms = (time.perf_counter() - start) * 1000
    logger.info(
        f"{request.method} {request.url.path} "
        f"status={response.status_code} "
        f"latency={duration_ms:.1f}ms"
    )
    return response


# ─── Endpoints ────────────────────────────────────────────────────────────────

@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint for load balancers and Kubernetes probes."""
    return HealthResponse(
        status="healthy" if state.model is not None else "unhealthy",
        model_loaded=state.model is not None,
        model_version=state.model_version,
        uptime_seconds=round(time.time() - state.start_time, 1),
    )


# The prediction endpoints use plain `def`, not `async def`: FastAPI runs sync
# endpoints in a threadpool, so CPU-bound predict_proba calls don't block the event loop.
@app.post("/predict", response_model=PredictionResponse)
def predict(customer: CustomerFeatures):
    """Predict churn probability for a single customer."""
    if state.model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Build feature array in the exact order the model was trained on
    features = pd.DataFrame(
        [[getattr(customer, col) for col in FEATURE_COLUMNS]],
        columns=FEATURE_COLUMNS,
    )

    # Get probability of churn (class 1)
    prob = float(state.model.predict_proba(features)[0][1])
    return PredictionResponse.from_probability(prob, state.model_version)


@app.post("/predict/batch", response_model=BatchPredictionResponse)
def predict_batch(request: BatchPredictionRequest):
    """Predict churn probability for a batch of customers.

    More efficient than calling /predict N times: all predictions
    run in a single model forward pass.
    """
    if state.model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    start = time.perf_counter()

    # Build batch feature matrix
    rows = [
        [getattr(customer, col) for col in FEATURE_COLUMNS]
        for customer in request.customers
    ]
    features = pd.DataFrame(rows, columns=FEATURE_COLUMNS)

    # Single forward pass for all customers
    probabilities = state.model.predict_proba(features)[:, 1]

    predictions = [
        PredictionResponse.from_probability(float(prob), state.model_version)
        for prob in probabilities
    ]
    latency_ms = (time.perf_counter() - start) * 1000

    return BatchPredictionResponse(
        predictions=predictions,
        batch_size=len(predictions),
        latency_ms=round(latency_ms, 2),
    )
```
Running the API
```bash
# Install dependencies (quote the extras so shells like zsh don't expand the brackets)
pip install fastapi "uvicorn[standard]" pydantic joblib scikit-learn pandas

# Start the server (development: reloads on code changes)
uvicorn src.api.app:app --host 0.0.0.0 --port 8000 --reload

# In production, remove --reload and add workers:
uvicorn src.api.app:app --host 0.0.0.0 --port 8000 --workers 4
```
Testing with curl and Python
```bash
# Health check
curl http://localhost:8000/health

# Single prediction
curl -X POST http://localhost:8000/predict \
  -H "Content-Type: application/json" \
  -d '{
    "tenure_months": 6,
    "monthly_charges": 79.99,
    "total_charges": 479.94,
    "num_products": 1,
    "has_support_calls": 1
  }'
```
Response:
```json
{
  "churn_probability": 0.7234,
  "churn_prediction": true,
  "risk_tier": "HIGH",
  "model_version": "1.0.0"
}
```

```python
# test_api.py: integration tests.
# Run with `pytest test_api.py` while the API is running on localhost:8000.
import requests

BASE_URL = "http://localhost:8000"


def test_health():
    resp = requests.get(f"{BASE_URL}/health")
    assert resp.status_code == 200
    data = resp.json()
    assert data["model_loaded"] is True
    assert data["status"] == "healthy"


def test_predict():
    payload = {
        "tenure_months": 6,
        "monthly_charges": 79.99,
        "total_charges": 479.94,
        "num_products": 1,
        "has_support_calls": 1,
    }
    resp = requests.post(f"{BASE_URL}/predict", json=payload)
    assert resp.status_code == 200
    data = resp.json()
    assert 0.0 <= data["churn_probability"] <= 1.0
    assert data["risk_tier"] in ("LOW", "MEDIUM", "HIGH")


def test_predict_validation_error():
    # tenure_months > 120 should fail validation
    payload = {
        "tenure_months": 999,
        "monthly_charges": 79.99,
        "total_charges": 479.94,
        "num_products": 1,
        "has_support_calls": 1,
    }
    resp = requests.post(f"{BASE_URL}/predict", json=payload)
    assert resp.status_code == 422  # Unprocessable Entity


def test_batch_predict():
    payload = {
        "customers": [
            {"tenure_months": 6, "monthly_charges": 80.0,
             "total_charges": 480.0, "num_products": 1, "has_support_calls": 1},
            {"tenure_months": 48, "monthly_charges": 45.0,
             "total_charges": 2160.0, "num_products": 3, "has_support_calls": 0},
        ]
    }
    resp = requests.post(f"{BASE_URL}/predict/batch", json=payload)
    assert resp.status_code == 200
    data = resp.json()
    assert data["batch_size"] == 2
    assert len(data["predictions"]) == 2
```
Auto-Generated API Documentation
FastAPI generates interactive API docs automatically. Visit:
- Swagger UI at http://localhost:8000/docs (try requests in the browser)
- ReDoc at http://localhost:8000/redoc (cleaner for documentation)
This documentation is always in sync with your Pydantic models. No manual maintenance.
Performance: Why Model Loading at Startup Matters
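Consider 1,000 requests per minute, with a model that takes 50ms to load and 5ms to run. If you load the model on every request, each request costs 50ms + 5ms = 55ms, which adds up to 55,000ms of processing per minute. A single worker is then busy 55 of every 60 seconds. If you load the model once at startup instead, each request costs only the 5ms of inference, or 5,000ms per minute. The one-time load is paid once per process, even if it is a slow 2,000ms, so the startup cost is amortized almost immediately.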
For larger models (PyTorch, transformers), loading can take 30-60 seconds. Per-request loading would make your API unusable.
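The arithmetic above fits in a small helper that turns a request rate and per-request cost into busy time per minute. `busy_ms_per_minute` is a hypothetical name, and the inputs are the illustrative numbers from this section, not measurements.

```python
def busy_ms_per_minute(requests_per_minute: int, load_ms: float, infer_ms: float,
                       load_per_request: bool) -> float:
    """Model work per minute in ms; a one-time startup load is not counted."""
    per_request_ms = infer_ms + (load_ms if load_per_request else 0.0)
    return requests_per_minute * per_request_ms


MINUTE_MS = 60_000

# Illustrative numbers from the text: 1,000 requests/minute, 50ms load, 5ms inference.
per_request_load = busy_ms_per_minute(1_000, load_ms=50, infer_ms=5, load_per_request=True)
load_at_startup = busy_ms_per_minute(1_000, load_ms=50, infer_ms=5, load_per_request=False)

# Share of each minute one worker spends busy (above 1.0 means it can't keep up).
utilization_per_request = per_request_load / MINUTE_MS
utilization_at_startup = load_at_startup / MINUTE_MS

# A 30-second load per request (a large model) needs ~500 minutes of work per minute.
slow_model = busy_ms_per_minute(1_000, load_ms=30_000, infer_ms=5, load_per_request=True)
```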
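The lifespan context manager in the code above loads the model exactly once when the server starts, and every request handled by that process shares the same model object in memory. Sharing one fitted scikit-learn or XGBoost model across concurrent requests is generally safe, because prediction only reads the fitted parameters and doesn't modify internal state. For PyTorch, call model.eval() once after loading so that layers such as dropout and batch normalization use inference behavior. Then run each prediction inside torch.no_grad() so no gradients are tracked.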
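FastAPI runs endpoints declared with plain `def` in a threadpool, so concurrent requests can call the shared model from several threads at once. You may not be sure a model object tolerates that, for example a custom Python class that writes to `self` during prediction. In that case, one conservative option is to serialize calls with a lock. This stdlib-only sketch uses a hypothetical `LockedModel` wrapper around a hypothetical stub model. It is not part of the API above.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class StatefulStubModel:
    """Hypothetical model that writes to self during predict (not thread-safe)."""

    def __init__(self) -> None:
        self.last_input = 0.0

    def predict(self, x: float) -> float:
        self.last_input = x  # shared mutable state on the model object
        # Without a lock, another thread could overwrite last_input before this read.
        return self.last_input * 2


class LockedModel:
    """Hypothetical wrapper: only one thread at a time may call the wrapped model."""

    def __init__(self, model: StatefulStubModel) -> None:
        self._model = model
        self._lock = threading.Lock()

    def predict(self, x: float) -> float:
        with self._lock:
            return self._model.predict(x)


shared_model = LockedModel(StatefulStubModel())  # loaded once, shared by all threads
with ThreadPoolExecutor(max_workers=8) as pool:
    outputs = list(pool.map(shared_model.predict, [float(i) for i in range(100)]))
```

The lock removes concurrency inside one process, so add it only where a model actually needs it. Separate worker processes each have their own model and lock, so they still run in parallel.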
Horizontal Scaling
When your single server can’t handle the load:
```bash
# Run multiple worker processes on the same machine (for CPU models)
uvicorn src.api.app:app --workers 4
# Each worker loads the model independently:
# 4 workers = 4 model instances, up to ~4x throughput with at least 4 CPU cores
```
Each worker is a separate OS process with its own copy of the model in memory. For a 500MB model with 4 workers, that's 2GB of RAM for the models alone, so plan your instance size accordingly.
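That memory rule gives a rough sizing calculation: RAM needed is about model size times workers, plus per-process overhead. The helper below is a hypothetical back-of-the-envelope calculator. Capping workers at the CPU core count is a heuristic for CPU-bound models. `overhead_mb` is an assumed placeholder for the Python runtime and libraries, which you should measure for your own service.

```python
def max_workers(total_ram_mb: int, model_mb: int, overhead_mb: int,
                cpu_cores: int, reserve_mb: int = 0) -> int:
    """Largest worker count that fits in RAM, capped at the number of CPU cores."""
    per_worker_mb = model_mb + overhead_mb  # every process holds its own model copy
    fits_in_ram = (total_ram_mb - reserve_mb) // per_worker_mb
    return max(0, min(cpu_cores, fits_in_ram))


# The example from the text: 4 workers x 500MB = 2,000MB for the models alone.
models_only_mb = 4 * 500

# Hypothetical machines: ~300MB assumed runtime overhead per process.
# The larger one has 8GB RAM with 1GB reserved for the OS and other processes.
workers_big_box = max_workers(8_192, model_mb=500, overhead_mb=300, cpu_cores=4, reserve_mb=1_024)
workers_small_box = max_workers(2_048, model_mb=500, overhead_mb=300, cpu_cores=4)
```

On the smaller machine, RAM rather than CPU limits the worker count, which is the situation the paragraph above warns about.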
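For GPU models, workers don't share one model on the GPU. Each worker process loads its own copy into GPU memory and creates its own CUDA context, so several workers on one GPU multiply memory use and can exhaust it. Use 1 worker per GPU to avoid memory overflow, and rely on batching to keep the GPU busy.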
In Kubernetes (covered two lessons from now), you scale by adding more pods: each pod is one server process with its own model copy.
Summary
The FastAPI serving layer you built in this lesson:
- Loads the model once at startup — not per request
- Validates every input with Pydantic before it touches the model
- Returns structured, typed responses
- Supports both single and batch predictions
- Provides a health check endpoint for infrastructure monitoring
- Generates interactive API documentation automatically
The serving API is the interface between your ML system and the world. Get it right here, and everything downstream — containerization, Kubernetes scaling, monitoring — becomes straightforward. The next lesson covers what happens after you deploy: detecting when your model starts to fail in production.
