Testing FastAPI
Write comprehensive tests for production-ready ML APIs
Overview
Testing is critical for ML APIs: ensuring endpoints work correctly, validating model predictions, testing database interactions, and verifying error handling. This guide covers testing strategies for FastAPI applications using pytest.
Test Setup
Installation
bash
# Install testing dependencies
uv add --dev pytest pytest-asyncio httpx
Test Project Structure
ml-api/
├── src/
│ ├── main.py
│ ├── routes/
│ └── models/
├── tests/
│ ├── __init__.py
│ ├── conftest.py # Shared fixtures
│ ├── test_api.py # API endpoint tests
│ ├── test_models.py # Model tests
│ ├── test_database.py # Database tests
│ └── test_predictions.py # Prediction logic tests
└── pytest.ini
pytest Configuration
pytest.ini:
ini
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
asyncio_mode = auto
conftest.py - Shared Fixtures
python
import pytest
import pytest_asyncio
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.pool import NullPool
from src.main import app
from src.database import Base, get_db
from src.models.prediction import ModelVersion, Prediction
import os

# Test database URL. CI and other machines can point the suite at a different
# server via the TEST_DATABASE_URL environment variable instead of editing this
# file; the default is the local Postgres database used in this guide.
TEST_DATABASE_URL = os.environ.get(
    "TEST_DATABASE_URL",
    "postgresql+asyncpg://test:test@localhost:5432/test_mldb",
)

# Test engine. NullPool opens a fresh connection for every checkout and closes
# it on release, so no pooled connection outlives the test that created it.
test_engine = create_async_engine(
    TEST_DATABASE_URL,
    poolclass=NullPool,  # Don't pool connections in tests
)

# Session factory for tests. expire_on_commit=False keeps attribute values
# readable after commit() without triggering a reload from the database.
TestSessionLocal = async_sessionmaker(
    test_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
@pytest_asyncio.fixture
async def test_db():
    """Create every table in ``Base.metadata`` before a test and drop them after.

    Each test that depends on this fixture therefore starts from an empty schema.
    """

    async def apply_ddl(operation):
        # engine.begin() opens a transaction and commits it when the block exits.
        async with test_engine.begin() as connection:
            await connection.run_sync(operation)

    await apply_ddl(Base.metadata.create_all)
    yield
    # Teardown: runs once the test body has finished.
    await apply_ddl(Base.metadata.drop_all)
@pytest_asyncio.fixture
async def db_session(test_db):
    """Yield one AsyncSession from ``TestSessionLocal`` for the duration of a test.

    Requesting ``test_db`` guarantees the tables exist before the session is used.
    The session is closed when the ``async with`` block exits.
    """
    session_factory = TestSessionLocal
    async with session_factory() as sess:
        yield sess
@pytest_asyncio.fixture
async def client(db_session):
    """Yield an httpx AsyncClient that calls the FastAPI app in-process.

    The app's ``get_db`` dependency is overridden to hand out this test's
    ``db_session``, so requests see the rows a test inserted through it.

    httpx 0.27 deprecated and 0.28 removed the ``AsyncClient(app=...)``
    shortcut; the app must be wrapped in an explicit ``ASGITransport``.
    NOTE(review): ASGITransport does not send lifespan startup/shutdown
    events. If the app loads models in a lifespan handler, wrap it with a
    lifespan manager first (confirm for this app).
    """
    from httpx import ASGITransport

    async def override_get_db():
        """Override database dependency with the shared test session."""
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            yield ac
    finally:
        # Always remove overrides so they cannot leak into the next test.
        app.dependency_overrides.clear()
# Testing Endpoints
Basic Endpoint Tests
python
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_root_endpoint(client: AsyncClient):
    """GET / answers 200 and its JSON body contains a "name" key."""
    reply = await client.get("/")
    assert reply.status_code == 200
    body = reply.json()
    assert "name" in body
@pytest.mark.asyncio
async def test_health_check(client: AsyncClient):
    """GET /health answers 200 and reports status "healthy"."""
    reply = await client.get("/health")
    assert reply.status_code == 200
    status = reply.json()["status"]
    assert status == "healthy"
# Testing POST Endpoints
python
@pytest.mark.asyncio
async def test_create_prediction(client: AsyncClient, db_session):
    """POST /api/v1/predictions creates a prediction for an existing model.

    Arrange: insert a ModelVersion directly through ``db_session``.
    Act: post a prediction request that references it.
    Assert: 201 with an id, the same model_id, and a prediction value.
    """
    # Names imported in conftest.py are not visible in test modules.
    from src.models.prediction import ModelVersion

    # Arrange: the model version the prediction will reference
    model = ModelVersion(name="test_model", version="1.0.0", framework="sklearn")
    db_session.add(model)
    await db_session.commit()
    await db_session.refresh(model)  # reload the row as stored (incl. its id)

    # Act
    payload = {
        "model_id": model.id,
        "input_data": {"features": [1.0, 2.0, 3.0]},
    }
    response = await client.post("/api/v1/predictions", json=payload)

    # Assert
    assert response.status_code == 201
    data = response.json()
    assert "id" in data
    assert data["model_id"] == model.id
    assert "prediction" in data
# Testing Query Parameters
python
@pytest.mark.asyncio
async def test_list_predictions_with_filters(client: AsyncClient, db_session):
    """GET /api/v1/predictions honours skip/limit pagination and model_id filtering."""
    # Names imported in conftest.py are not visible in test modules.
    from src.models.prediction import ModelVersion, Prediction

    # Arrange: one model with five predictions
    model = ModelVersion(name="test", version="1.0", framework="sklearn")
    db_session.add(model)
    await db_session.commit()
    await db_session.refresh(model)

    db_session.add_all(
        Prediction(model_id=model.id, input_data={"features": [i]}, prediction=float(i))
        for i in range(5)
    )
    await db_session.commit()

    # Pagination: skip the first two of five rows and take two
    response = await client.get("/api/v1/predictions", params={"skip": 2, "limit": 2})
    assert response.status_code == 200
    assert len(response.json()) == 2

    # Filtering: require a non-empty result, because all() over an empty list
    # is True and would let a broken filter pass unnoticed.
    response = await client.get("/api/v1/predictions", params={"model_id": model.id})
    assert response.status_code == 200
    filtered = response.json()
    assert filtered
    assert all(p["model_id"] == model.id for p in filtered)
# Testing Error Cases
python
@pytest.mark.asyncio
async def test_prediction_not_found(client: AsyncClient):
    """Fetching a prediction id that was never created yields 404 "not found"."""
    missing_id = 99999
    response = await client.get(f"/api/v1/predictions/{missing_id}")
    assert response.status_code == 404
    detail = response.json()["detail"]
    assert "not found" in detail.lower()
@pytest.mark.asyncio
async def test_invalid_input(client: AsyncClient):
    """A non-integer model_id is rejected by request validation."""
    bad_payload = {"model_id": "invalid", "input_data": {}}  # model_id should be an int
    result = await client.post("/api/v1/predictions", json=bad_payload)
    # 422 is FastAPI's status code for request validation errors
    assert result.status_code == 422
@pytest.mark.asyncio
async def test_prediction_with_missing_model(client: AsyncClient):
    """Creating a prediction for a model_id that does not exist returns 404."""
    nonexistent_model_id = 99999
    request_body = {
        "model_id": nonexistent_model_id,
        "input_data": {"features": [1.0]},
    }
    response = await client.post("/api/v1/predictions", json=request_body)
    assert response.status_code == 404
# Testing Async Endpoints
Async Test Functions
python
@pytest.mark.asyncio
async def test_async_prediction(client: AsyncClient):
    """POST /predict/batch/async returns one prediction per submitted instance."""
    instances = [[1.0, 2.0], [3.0, 4.0]]
    response = await client.post("/predict/batch/async", json={"instances": instances})
    assert response.status_code == 200
    body = response.json()
    assert "predictions" in body
    assert len(body["predictions"]) == len(instances)
# Testing Background Tasks
python
import asyncio
@pytest.mark.asyncio
async def test_background_task(client: AsyncClient, db_session):
    """POST /predict/json succeeds, and its background task has already run on return.

    No sleep is needed. httpx's ASGITransport awaits the entire ASGI call, and
    Starlette runs a response's background tasks inside that call after the
    response is sent. So the tasks have finished by the time ``client.post``
    returns. A fixed sleep only slows the suite and can still race on a slow
    machine.
    """
    payload = {"features": [1.0, 2.0, 3.0], "user_id": 123}
    response = await client.post("/predict/json", json=payload)
    assert response.status_code == 200

    # TODO: assert the background task's side effect here (e.g. a row visible
    # through db_session, or a captured log record). Until then this test only
    # checks the status code.
# Mocking Dependencies
Override Dependencies
python
from unittest.mock import Mock, AsyncMock
@pytest.mark.asyncio
async def test_with_mocked_model(client: AsyncClient):
    """/predict returns the mocked model's output when ``get_model`` is overridden.

    The override is removed in ``finally`` so a failing assertion cannot leak
    it. Only this test's override is popped, because
    ``dependency_overrides.clear()`` would also remove the ``get_db`` override
    installed by the ``client`` fixture.
    """
    from src.dependencies import get_model
    from src.main import app

    mock_model = Mock()
    mock_model.predict.return_value = [0.75]

    async def mock_get_model():
        return mock_model

    app.dependency_overrides[get_model] = mock_get_model
    try:
        response = await client.post("/predict", json={"features": [1.0, 2.0]})
        assert response.status_code == 200
        assert response.json()["prediction"] == 0.75
    finally:
        app.dependency_overrides.pop(get_model, None)
# Mock External Services
python
from unittest.mock import patch, AsyncMock
@pytest.mark.asyncio
@patch("src.ml.inference.external_api_call")
async def test_with_external_api_mock(mock_api, client: AsyncClient):
    """/process succeeds while the external API call is replaced by a mock.

    ``@patch`` passes the mock as the first positional argument, and pytest
    still injects the ``client`` fixture by name. The patch target is the name
    as looked up in ``src.ml.inference``.
    """
    mock_api.return_value = {"result": "success"}
    request_body = {"data": "test"}

    response = await client.post("/process", json=request_body)

    assert response.status_code == 200
    # The endpoint must have reached the (mocked) external service exactly once.
    mock_api.assert_called_once()
# Database Testing
Test Database Operations
python
from sqlalchemy import select
@pytest.mark.asyncio
async def test_create_model_version(db_session):
    """Persisting a ModelVersion assigns an id and round-trips its fields."""
    # Names imported in conftest.py are not visible in test modules.
    from src.models.prediction import ModelVersion

    model = ModelVersion(
        name="classifier",
        version="1.0.0",
        framework="sklearn",
        metrics={"accuracy": 0.95},
    )
    db_session.add(model)
    await db_session.commit()
    await db_session.refresh(model)  # reload the row as stored in the database

    assert model.id is not None
    assert model.name == "classifier"
    assert model.metrics["accuracy"] == 0.95
@pytest.mark.asyncio
async def test_query_predictions(db_session):
    """A stored Prediction can be selected back by its model_id."""
    # Names imported in conftest.py are not visible in test modules.
    from src.models.prediction import ModelVersion, Prediction

    # Arrange: one model with one prediction
    model = ModelVersion(name="test", version="1.0", framework="sklearn")
    db_session.add(model)
    await db_session.commit()
    await db_session.refresh(model)

    db_session.add(
        Prediction(model_id=model.id, input_data={"features": [1.0]}, prediction=0.8)
    )
    await db_session.commit()

    # Act
    stmt = select(Prediction).where(Prediction.model_id == model.id)
    predictions = (await db_session.execute(stmt)).scalars().all()

    # Assert
    assert len(predictions) == 1
    assert predictions[0].prediction == 0.8
# Test Relationships
python
@pytest.mark.asyncio
async def test_model_prediction_relationship(db_session):
    """ModelVersion.predictions loads every Prediction that references the model."""
    from sqlalchemy import select
    from sqlalchemy.orm import selectinload

    # Names imported in conftest.py are not visible in test modules.
    from src.models.prediction import ModelVersion, Prediction

    # Arrange: flush (not commit) so the INSERT runs and model.id is assigned
    # before the child rows that reference it are built.
    model = ModelVersion(name="test", version="1.0", framework="sklearn")
    db_session.add(model)
    await db_session.flush()

    db_session.add_all(
        Prediction(model_id=model.id, input_data={"features": [i]}, prediction=float(i))
        for i in range(3)
    )
    await db_session.commit()

    # Act: eager-load the collection. Lazy loading on attribute access would
    # need implicit IO, which AsyncSession does not support.
    stmt = (
        select(ModelVersion)
        .options(selectinload(ModelVersion.predictions))
        .where(ModelVersion.id == model.id)
    )
    loaded_model = (await db_session.execute(stmt)).scalar_one()

    # Assert
    assert len(loaded_model.predictions) == 3
# Testing File Uploads
python
from io import BytesIO
@pytest.mark.asyncio
async def test_file_upload(client: AsyncClient):
    """Uploading a CSV to /upload-model echoes back its filename and byte size."""
    csv_bytes = b"test,data\n1.0,2.0\n3.0,4.0"
    # httpx multipart tuple: (filename, file object, content type)
    upload = ("test.csv", BytesIO(csv_bytes), "text/csv")

    response = await client.post("/upload-model", files={"file": upload})

    assert response.status_code == 200
    body = response.json()
    assert body["filename"] == "test.csv"
    assert body["size"] == len(csv_bytes)
@pytest.mark.asyncio
async def test_invalid_file_type(client: AsyncClient):
    """A plain-text upload is refused with 400 and a "not allowed" detail."""
    upload = ("test.txt", BytesIO(b"invalid"), "text/plain")
    response = await client.post("/upload-model", files={"file": upload})
    assert response.status_code == 400
    message = response.json()["detail"].lower()
    assert "not allowed" in message
# Testing Streaming Endpoints
Test SSE Streaming
python
@pytest.mark.asyncio
async def test_sse_stream(client: AsyncClient):
    """/predict/stream emits one server-sent ``data:`` event per input instance."""
    import json  # decodes each event payload

    prefix = "data: "
    async with client.stream(
        "POST",
        "/predict/stream",
        json={"instances": [[1.0], [2.0], [3.0]]},
    ) as response:
        assert response.status_code == 200
        # Compare only the media type; the charset parameter differs between servers/libraries.
        assert response.headers["content-type"].startswith("text/event-stream")

        # Keep only SSE data lines; blank separators and other fields are skipped.
        events = [
            json.loads(line[len(prefix):])
            async for line in response.aiter_lines()
            if line.startswith(prefix)
        ]

    assert len(events) == 3
    assert all("prediction" in e for e in events)
# Test WebSocket
python
def test_websocket_prediction():
    """A JSON message on /ws/predict gets back a prediction tagged with its version.

    Starlette's TestClient is synchronous: it drives the app on its own event
    loop in a separate thread. That makes this a plain ``def`` test, because
    wrapping it in ``async def`` only blocks the pytest-asyncio loop while it
    runs. Entering TestClient as a context manager also runs the app's lifespan
    startup/shutdown.
    """
    from fastapi.testclient import TestClient
    from src.main import app

    with TestClient(app) as test_client:
        with test_client.websocket_connect("/ws/predict") as websocket:
            # Send data
            websocket.send_json({"features": [1.0, 2.0, 3.0], "version": "v1"})

            # Receive response
            data = websocket.receive_json()
            assert "prediction" in data
            assert "version" in data
            assert data["version"] == "v1"
# Testing Model Predictions
Test Model Logic
python
import numpy as np
def test_model_prediction():
    """A single feature row yields exactly one numeric prediction."""
    from src.ml.model_loader import load_model

    model = load_model("classifier", "v1")
    row = np.array([[1.0, 2.0, 3.0]])  # one sample with three features
    output = model.predict(row)

    assert output is not None
    assert len(output) == 1
    assert isinstance(output[0], (int, float, np.number))
def test_model_batch_prediction():
    """Predicting on a three-row batch returns three predictions."""
    from src.ml.model_loader import load_model

    model = load_model("classifier", "v1")
    # Rows [1, 2, 3], [4, 5, 6], [7, 8, 9] as float64
    batch = np.arange(1.0, 10.0).reshape(3, 3)
    outputs = model.predict(batch)
    assert len(outputs) == 3
# Test Model Versioning
python
@pytest.mark.asyncio
async def test_model_version_selection(client: AsyncClient):
    """The ``version`` query parameter selects which model answers /predict."""
    payload = {"features": [1.0, 2.0, 3.0]}
    results = {}
    for version in ("v1", "v2"):
        response = await client.post(f"/predict?version={version}", json=payload)
        assert response.status_code == 200
        results[version] = response.json()

    # Predictions may legitimately differ between versions; only the echoed
    # version tag is checked.
    assert results["v1"]["version"] == "v1"
    assert results["v2"]["version"] == "v2"
# Integration Tests
Full Workflow Test
python
@pytest.mark.asyncio
async def test_full_prediction_workflow(client: AsyncClient, db_session):
    """End to end through the public API: create a model, predict, fetch, list."""
    # 1. Create model version
    model_data = {"name": "classifier", "version": "1.0.0", "framework": "sklearn"}
    response = await client.post("/api/v1/models", json=model_data)
    assert response.status_code == 201
    model = response.json()

    # 2. Make prediction
    prediction_data = {
        "model_id": model["id"],
        "input_data": {"features": [1.0, 2.0, 3.0]},
    }
    response = await client.post("/api/v1/predictions", json=prediction_data)
    assert response.status_code == 201
    prediction = response.json()

    # 3. Retrieve prediction
    response = await client.get(f"/api/v1/predictions/{prediction['id']}")
    assert response.status_code == 200
    retrieved = response.json()
    assert retrieved["id"] == prediction["id"]

    # 4. List predictions for the model. The prediction just created must be
    #    present, which is stronger than merely checking the list is non-empty.
    response = await client.get("/api/v1/predictions", params={"model_id": model["id"]})
    assert response.status_code == 200
    listed_ids = [p["id"] for p in response.json()]
    assert prediction["id"] in listed_ids
# Performance Testing
Benchmarking with pytest-benchmark
bash
uv add --dev pytest-benchmark
python
def test_prediction_performance(benchmark):
    """Time a single-row ``model.predict`` call with pytest-benchmark."""
    from src.ml.model_loader import load_model
    import numpy as np

    # The model is loaded before benchmark() is called, so loading is not timed.
    model = load_model("classifier", "v1")
    single_row = np.array([[1.0, 2.0, 3.0]])

    # benchmark() invokes predict repeatedly and returns the value of a call.
    output = benchmark(model.predict, single_row)
    assert output is not None
# Concurrent Request Testing
python
import asyncio
@pytest.mark.asyncio
async def test_concurrent_predictions(client: AsyncClient):
    """100 simultaneous POST /predict calls all succeed and all carry a prediction.

    NOTE(review): the ``client`` fixture overrides ``get_db`` to yield one shared
    AsyncSession, and an AsyncSession must not be used by concurrent tasks. If
    /predict depends on ``get_db``, this test will fail for that reason. Confirm.
    """
    payload = {"features": [1.0, 2.0, 3.0]}
    request_count = 100

    responses = await asyncio.gather(
        *(client.post("/predict", json=payload) for _ in range(request_count))
    )

    assert all(r.status_code == 200 for r in responses)
    assert all("prediction" in r.json() for r in responses)
# Test Coverage
Generate Coverage Report
bash
# Install coverage
uv add --dev pytest-cov
# Run tests with coverage
pytest --cov=src --cov-report=html
# View report
# Open htmlcov/index.html in a browser
Coverage Configuration
.coveragerc:
ini
[run]
source = src
omit =
*/tests/*
*/migrations/*
*/__init__.py
[report]
exclude_lines =
pragma: no cover
def __repr__
raise AssertionError
raise NotImplementedError
if __name__ == .__main__.:
Best Practices
- Test happy paths and errors - Both success and failure cases
- Use fixtures - Share setup code across tests
- Mock external dependencies - Don't call real APIs in tests
- Test database isolation - Each test uses a clean database
- Test async properly - Use pytest-asyncio for async tests; with `asyncio_mode = auto` in pytest.ini, explicit `@pytest.mark.asyncio` markers are optional
- Meaningful assertions - Assert specific expected values
- Test edge cases - Empty inputs, large batches, invalid data
- Fast tests - Mock slow operations like model loading
- Clear test names - Describe what is being tested
- Arrange-Act-Assert - Structure tests clearly
Running Tests
bash
# Run all tests
pytest
# Run specific test file
pytest tests/test_api.py
# Run specific test
pytest tests/test_api.py::test_create_prediction
# Run with verbose output
pytest -v
# Run with coverage
pytest --cov=src
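# Run in parallel (requires pytest-xdist). All workers share TEST_DATABASE_URL and
# the test_db fixture creates/drops tables per test, so give each worker its own database first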
pytest -n auto
# Run only failed tests
pytest --lf
Next Steps
Continue to Production Deployment to learn how to deploy your tested FastAPI application.