ML Model Serving & Streaming
Serve ML models efficiently with streaming responses
Overview
Production ML APIs need efficient model serving: loading models at startup, managing multiple versions, handling batch predictions, and streaming responses for real-time applications. This guide covers practical patterns for serving ML models with FastAPI.
Model Loading and Caching
Application Startup Loading
Load models during application startup for optimal performance:
python
import hashlib
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any

import joblib
from fastapi import FastAPI
# Global model storage: maps a model name (e.g. "classifier_v1") to the
# object returned by joblib.load for that model.
models: dict[str, object] = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Fill the global ``models`` cache on startup and empty it on shutdown.

    The code before ``yield`` runs once before the application starts
    serving. The ``finally`` clause runs when the application stops, and
    also when an exception is raised into this context manager, so the
    cache is always cleared.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    print("Loading ML models...")
    try:
        models["classifier_v1"] = joblib.load("models/classifier_v1.pkl")
        models["classifier_v2"] = joblib.load("models/classifier_v2.pkl")
        models["regressor"] = joblib.load("models/regressor.pkl")
        print(f"Loaded {len(models)} models")
        yield  # Application runs
    finally:
        # Without try/finally this cleanup would be skipped whenever the
        # block is left through an exception.
        print("Cleaning up models...")
        models.clear()
app = FastAPI(lifespan=lifespan)

Dependency Injection for Models
python
from fastapi import Depends, HTTPException
from typing import Annotated
def get_model(model_name: str):
    """Look up a loaded model in the global ``models`` cache.

    Args:
        model_name: Key the model was stored under.

    Returns:
        The cached model object.

    Raises:
        HTTPException: 404 when nothing is cached under ``model_name``.
    """
    if model_name in models:
        return models[model_name]
    raise HTTPException(
        status_code=404,
        detail=f"Model '{model_name}' not found",
    )
@app.post("/predict/{model_name}")
async def predict(
model_name: str,
features: list[float],
model: Annotated[object, Depends(lambda: get_model(model_name))]
):
"""Use dependency injection for model access."""
prediction = model.predict([features])[0]
return {"prediction": float(prediction)}Prediction Endpoints
Single Prediction
python
from pydantic import BaseModel, Field
class PredictionInput(BaseModel):
    """Request body for a single prediction."""

    # Feature values for one instance; validation requires at least one.
    features: list[float] = Field(min_length=1)
class PredictionOutput(BaseModel):
    """Response body for a single prediction."""

    # Predicted value, as a float.
    prediction: float
    # Version string of the model reported back to the caller.
    model_version: str
    # Optional confidence score; None when it is not supplied.
    confidence: float | None = None
@app.post("/predict", response_model=PredictionOutput)
async def predict_single(
data: PredictionInput,
model_version: str = "v1"
) -> PredictionOutput:
"""Make single prediction."""
model = model_manager.get_model("classifier", model_version)
# Predict
prediction = model.predict([data.features])[0]
# Get confidence if available
confidence = None
if hasattr(model, "predict_proba"):
proba = model.predict_proba([data.features])[0]
confidence = float(max(proba))
return PredictionOutput(
prediction=float(prediction),
model_version=model_version,
confidence=confidence
)Batch Predictions
python
from typing import List
class BatchPredictionInput(BaseModel):
    """Request body for a batch of predictions."""

    # One feature vector per instance; validation allows 1 to 1000 instances.
    instances: list[list[float]] = Field(min_length=1, max_length=1000)
class BatchPredictionOutput(BaseModel):
    """Response body for a batch of predictions."""

    # One predicted value per input instance.
    predictions: list[float]
    # Number of predictions returned.
    count: int
    # Version string of the model reported back to the caller.
    model_version: str
@app.post("/predict/batch", response_model=BatchPredictionOutput)
async def predict_batch(
data: BatchPredictionInput,
model_version: str = "v1"
) -> BatchPredictionOutput:
"""Make batch predictions."""
model = model_manager.get_model("classifier", model_version)
# Batch predict
predictions = model.predict(data.instances)
return BatchPredictionOutput(
predictions=[float(p) for p in predictions],
count=len(predictions),
model_version=model_version
)Async Batch Processing
python
import asyncio
from typing import List
async def predict_async(model: Any, features: list[float]) -> Any:
    """Run ``model.predict`` for one feature vector in a worker thread.

    Args:
        model: Object exposing a blocking ``predict(rows)`` method.
        features: Feature values for a single instance.

    Returns:
        The complete result of ``model.predict([features])``, not just its
        first element. Callers index into it themselves, which is why the
        return type is not ``float``.
    """
    # to_thread keeps the blocking predict call off the event loop.
    return await asyncio.to_thread(model.predict, [features])
@app.post("/predict/batch/async")
async def predict_batch_async(
data: BatchPredictionInput,
model_version: str = "v1"
):
"""Process batch predictions concurrently."""
model = model_manager.get_model("classifier", model_version)
# Create tasks for concurrent execution
tasks = [
predict_async(model, features)
for features in data.instances
]
# Wait for all predictions
results = await asyncio.gather(*tasks)
return {
"predictions": [float(r[0]) for r in results],
"count": len(results),
"model_version": model_version
}Model Versioning
Version Management
python
from enum import Enum
class ModelVersion(str, Enum):
    """Model versions a client may request; members are also ``str`` values."""

    V1 = "v1"
    V2 = "v2"
    # Alias that the endpoint using this enum resolves to a concrete version.
    LATEST = "latest"
@app.post("/predict")
async def predict_versioned(
features: list[float],
version: ModelVersion = ModelVersion.LATEST
):
"""Predict with version enum."""
# Map 'latest' to actual version
actual_version = "v2" if version == ModelVersion.LATEST else version.value
model = model_manager.get_model("classifier", actual_version)
prediction = model.predict([features])[0]
return {
"prediction": float(prediction),
"version_requested": version.value,
"version_used": actual_version
}A/B Testing
python
import random
@app.post("/predict/ab-test")
async def predict_ab_test(
features: list[float],
user_id: str
):
"""A/B test between model versions."""
# Deterministic assignment based on user_id
version = "v2" if hash(user_id) % 2 == 0 else "v1"
model = model_manager.get_model("classifier", version)
prediction = model.predict([features])[0]
return {
"prediction": float(prediction),
"model_version": version,
"user_id": user_id
}Server-Sent Events (SSE)
SSE enables server-to-client streaming for real-time updates.
Basic SSE Implementation
python
from fastapi.responses import StreamingResponse
import asyncio
async def event_generator(count: int):
    """Yield ``count`` SSE messages carrying the integers 0 to count - 1.

    Args:
        count: Number of events to produce.

    Yields:
        Strings of the form ``"data: <n>\\n\\n"``, each after a 0.5 s pause.
    """
    for sequence_number in range(count):
        await asyncio.sleep(0.5)  # simulated processing time
        message = f"data: {sequence_number}\n\n"
        yield message
@app.get("/stream/events")
async def stream_events(count: int = 10):
"""Stream events to client."""
return StreamingResponse(
event_generator(count),
media_type="text/event-stream"
)Client consumption (JavaScript):
javascript
const eventSource = new EventSource("/stream/events?count=10");
eventSource.onmessage = (event) => {
console.log("Received:", event.data);
};
eventSource.onerror = () => {
eventSource.close();
};Streaming Predictions
python
from typing import AsyncGenerator
async def stream_predictions(
    features_list: list[list[float]],
    model: Any,
) -> AsyncGenerator[str, None]:
    """Predict each feature vector in turn and yield the result as an SSE message.

    Args:
        features_list: Feature vectors, one per instance.
        model: Object exposing a blocking ``predict(rows)`` method.

    Yields:
        ``"data: <json>\\n\\n"`` strings whose JSON holds the instance
        ``index`` and its ``prediction`` as a float.
    """
    for position, row in enumerate(features_list):
        # Run the blocking predict call in a worker thread.
        output = await asyncio.to_thread(model.predict, [row])

        payload = json.dumps({"index": position, "prediction": float(output[0])})
        yield f"data: {payload}\n\n"

        # Small delay between predictions
        await asyncio.sleep(0.1)
@app.post("/predict/stream")
async def predict_stream(data: BatchPredictionInput, model_version: str = "v1"):
"""Stream batch predictions."""
model = model_manager.get_model("classifier", model_version)
return StreamingResponse(
stream_predictions(data.instances, model),
media_type="text/event-stream"
)Streaming LLM Responses
Perfect for token-by-token text generation:
python
async def generate_text_stream(prompt: str) -> AsyncGenerator[str, None]:
    """Simulate token-by-token LLM output as SSE messages.

    Args:
        prompt: Accepted but not used; the token list is fixed.

    Yields:
        One ``{"token": ...}`` message per token, each after a 0.2 s pause,
        followed by a final ``{"done": true}`` message.
    """
    tokens = (
        "The", "quick", "brown", "fox", "jumps",
        "over", "the", "lazy", "dog",
    )

    for token in tokens:
        await asyncio.sleep(0.2)  # simulated generation time
        payload = json.dumps({'token': token})
        yield f"data: {payload}\n\n"

    # Completion signal so the client knows the stream is finished.
    completion = json.dumps({'done': True})
    yield f"data: {completion}\n\n"
@app.post("/generate/stream")
async def stream_generation(prompt: str):
"""Stream LLM token generation."""
return StreamingResponse(
generate_text_stream(prompt),
media_type="text/event-stream"
)Progress Updates
python
async def training_progress(epochs: int) -> AsyncGenerator[str, None]:
    """Simulate a training run and yield one SSE progress message per epoch.

    Args:
        epochs: Number of epochs to simulate.

    Yields:
        A message per epoch (after a 1 s pause) with ``epoch``,
        ``total_epochs``, simulated ``loss`` and ``accuracy``, then a final
        ``{"status": "complete"}`` message.
    """
    for epoch in range(1, epochs + 1):
        await asyncio.sleep(1)  # simulated epoch duration

        simulated_loss = 1.0 / epoch
        progress = {
            "epoch": epoch,
            "total_epochs": epochs,
            "loss": simulated_loss,
            "accuracy": 1.0 - simulated_loss,
        }
        yield f"data: {json.dumps(progress)}\n\n"

    final_message = json.dumps({'status': 'complete'})
    yield f"data: {final_message}\n\n"
@app.post("/train/stream")
async def stream_training(epochs: int = 10):
"""Stream training progress."""
return StreamingResponse(
training_progress(epochs),
media_type="text/event-stream"
)WebSocket for Real-Time Predictions
WebSocket enables bidirectional communication for interactive ML applications.
Basic WebSocket
python
from fastapi import WebSocket, WebSocketDisconnect
@app.websocket("/ws/predict")
async def websocket_predict(websocket: WebSocket):
"""WebSocket prediction endpoint."""
await websocket.accept()
try:
while True:
# Receive data from client
data = await websocket.receive_json()
features = data.get("features", [])
model_version = data.get("version", "v1")
# Make prediction
model = model_manager.get_model("classifier", model_version)
prediction = model.predict([features])[0]
# Send response
await websocket.send_json({
"prediction": float(prediction),
"model_version": model_version
})
except WebSocketDisconnect:
print("Client disconnected")Client usage (JavaScript):
javascript
const ws = new WebSocket("ws://localhost:8000/ws/predict");
ws.onopen = () => {
ws.send(
JSON.stringify({
features: [1.0, 2.0, 3.0],
version: "v1",
}),
);
};
ws.onmessage = (event) => {
const result = JSON.parse(event.data);
console.log("Prediction:", result.prediction);
};Connection Manager
python
from typing import List
class ConnectionManager:
    """Track accepted WebSocket connections and send JSON messages to them."""

    def __init__(self):
        # Connections in the order they were accepted.
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the handshake and start tracking the connection."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Stop tracking a connection.

        Does nothing when the connection is not tracked, so it is safe to
        call after ``broadcast`` has already dropped it.
        """
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def send_personal(self, message: dict, websocket: WebSocket):
        """Send ``message`` as JSON to one connection."""
        await websocket.send_json(message)

    async def broadcast(self, message: dict):
        """Send ``message`` as JSON to every tracked connection.

        A connection whose send fails because the socket is closed is
        dropped, and the remaining connections still receive the message.
        """
        # Iterate over a snapshot: each await lets other tasks connect or
        # disconnect clients, which would mutate the list mid-iteration.
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except (WebSocketDisconnect, RuntimeError):
                # The peer is gone; forget it and keep broadcasting.
                self.disconnect(connection)
# Module-level ConnectionManager instance used by the WebSocket endpoint.
manager = ConnectionManager()
@app.websocket("/ws/predictions")
async def websocket_endpoint(websocket: WebSocket):
"""Managed WebSocket connection."""
await manager.connect(websocket)
try:
while True:
data = await websocket.receive_json()
# Process prediction
features = data["features"]
model = model_manager.get_model("classifier", "v1")
prediction = model.predict([features])[0]
# Send back to client
await manager.send_personal(
{"prediction": float(prediction)},
websocket
)
except WebSocketDisconnect:
manager.disconnect(websocket)Real-Time Monitoring
python
import time
@app.websocket("/ws/monitor")
async def monitor_stream(websocket: WebSocket):
"""Stream real-time model metrics."""
await websocket.accept()
try:
while True:
# Collect metrics
metrics = {
"timestamp": time.time(),
"active_models": len(model_manager.models),
"memory_mb": 128.5, # Placeholder
"requests_per_sec": 45.2 # Placeholder
}
await websocket.send_json(metrics)
await asyncio.sleep(1) # Update every second
except WebSocketDisconnect:
passStreaming File Uploads/Downloads
Stream File Upload
python
from fastapi import UploadFile
@app.post("/upload/stream")
async def upload_stream(file: UploadFile):
"""Stream file upload processing."""
chunk_size = 1024 * 1024 # 1MB chunks
total_size = 0
async def process_chunks():
"""Process file in chunks."""
nonlocal total_size
while chunk := await file.read(chunk_size):
total_size += len(chunk)
# Process chunk (e.g., save to disk, upload to S3)
yield f"data: {json.dumps({'bytes_processed': total_size})}\n\n"
return StreamingResponse(
process_chunks(),
media_type="text/event-stream"
)Stream File Download
python
from pathlib import Path
async def file_stream(file_path: Path):
    """Yield the contents of ``file_path`` as byte chunks of up to 1 MB.

    Args:
        file_path: File to read, opened in binary mode.

    Yields:
        Consecutive ``bytes`` chunks until the end of the file.
    """
    chunk_size = 1024 * 1024  # 1MB
    with open(file_path, "rb") as handle:
        # iter() with a sentinel stops at the first empty read (end of file).
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            yield chunk
@app.get("/download/model/{model_name}")
async def download_model(model_name: str):
"""Stream model file download."""
file_path = Path(f"models/{model_name}.pkl")
if not file_path.exists():
raise HTTPException(status_code=404, detail="Model not found")
return StreamingResponse(
file_stream(file_path),
media_type="application/octet-stream",
headers={
"Content-Disposition": f"attachment; filename={model_name}.pkl"
}
)When to Use Each Streaming Method
Server-Sent Events (SSE)
Use when:
- Server pushes updates to client
- One-way communication (server → client)
- Progress updates, notifications, logs
- LLM token streaming
- Training progress monitoring
Advantages:
- Simple to implement
- Automatic reconnection
- Works over HTTP
- Native browser support
WebSocket
Use when:
- Bidirectional communication needed
- Real-time interactive applications
- Chat, collaborative editing
- Gaming, live predictions
- Low latency required
Advantages:
- Full-duplex communication
- Lower overhead than SSE
- Better for high-frequency updates
Regular HTTP Streaming
Use when:
- File uploads/downloads
- Large data transfers
- Video/audio streaming
Complete Example: Real-Time Prediction Service
python
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends
from fastapi.responses import StreamingResponse
from contextlib import asynccontextmanager
from typing import AsyncGenerator
import asyncio
import json
# Model manager (from previous examples)
# NOTE(review): ModelManager is not defined or imported in the visible part
# of this file; confirm where it comes from.
model_manager = ModelManager()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load both classifier versions on startup and clear them on shutdown.

    The ``finally`` clause runs when the application stops, and also when an
    exception is raised into this context manager, so the loaded models are
    always released.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    try:
        await model_manager.load_model("classifier", "v1")
        await model_manager.load_model("classifier", "v2")
        yield
    finally:
        # Without try/finally this cleanup would be skipped whenever the
        # block is left through an exception.
        model_manager.models.clear()
# Application instance; the lifespan handler loads the models before the
# first request is served.
app = FastAPI(lifespan=lifespan)
# SSE: Stream batch predictions
@app.post("/predict/stream")
async def stream_predictions(data: BatchPredictionInput):
    """Stream one SSE message per instance of the batch.

    Args:
        data: Validated request body holding the instances.

    Returns:
        A ``text/event-stream`` response; each message's JSON holds the
        instance ``index`` and its ``prediction`` from the "classifier"
        model, version "v1".
    """

    async def sse_events() -> AsyncGenerator[str, None]:
        """Predict instance by instance, yielding each result as it is ready."""
        model = model_manager.get_model("classifier", "v1")
        for position, row in enumerate(data.instances):
            # Run the blocking predict call in a worker thread.
            output = await asyncio.to_thread(model.predict, [row])
            payload = {"index": position, "prediction": float(output[0])}
            yield f"data: {json.dumps(payload)}\n\n"
            await asyncio.sleep(0.05)

    return StreamingResponse(sse_events(), media_type="text/event-stream")
# WebSocket: Interactive predictions
@app.websocket("/ws/predict")
async def websocket_predict(websocket: WebSocket):
"""Real-time prediction over WebSocket."""
await websocket.accept()
try:
while True:
data = await websocket.receive_json()
features = data["features"]
version = data.get("version", "v1")
model = model_manager.get_model("classifier", version)
prediction = await asyncio.to_thread(
model.predict,
[features]
)
await websocket.send_json({
"prediction": float(prediction[0]),
"version": version
})
except WebSocketDisconnect:
passBest Practices
- Load models at startup - Don't reload on each request
- Use async for I/O - Model loading, file operations
- Cache predictions - For identical inputs
- Batch when possible - More efficient than single predictions
- Choose the right streaming method - SSE for server push, WebSocket for bidirectional communication
- Handle disconnections - Clean up resources properly
- Monitor memory - Unload unused models
- Version models - Support multiple versions simultaneously
Next Steps
Continue to Testing FastAPI to learn how to test your ML APIs comprehensively.