FastAPI WebSockets

Thirdy Gayares

🎓 What You Will Learn

  • WebSocket Basics: Establish persistent, bidirectional connections with clients
  • Connection Management: Handle connect, disconnect, and error scenarios
  • Broadcasting: Send messages to multiple connected clients efficiently
  • Chat Applications: Build real-time chat with user management
  • Live Notifications: Push real-time updates to clients instantly
  • Production Patterns: Scaling, reconnection logic, and error handling
Tags: FastAPI, WebSocket, Real-time

1. Why WebSockets Matter

HTTP is request-response only. Client asks, server answers. WebSockets flip this: both sides can send messages anytime. Perfect for chat, notifications, live dashboards, and collaborative apps.

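Think of HTTP as exchanging letters: the server only ever writes back in reply to a letter from the client. A WebSocket is more like a telephone call: the line stays open and either side can speak at any time.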

2. Setting Up WebSockets in FastAPI

requirements.txt
fastapi==0.104.1
uvicorn[standard]==0.24.0
websockets==12.0
main.py
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
import json

app = FastAPI(title="WebSocket API")

# Enable CORS for web clients
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Change in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Simple echo WebSocket endpoint"""
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Echo: {data}")
    except WebSocketDisconnect:
        print("Client disconnected")
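
// Browser client (JavaScript)
const ws = new WebSocket('ws://localhost:8000/ws');

// Send only once the connection is open; sending earlier throws an error
ws.onopen = () => {
  ws.send('Hello, server!');
};

// Receive echo
ws.onmessage = (event) => {
  console.log(event.data); // "Echo: Hello, server!"
};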

3. Connection Lifecycle & Error Handling

WebSocket connections go through states: connecting, connected, disconnected. Handle each gracefully.

main.py
from datetime import datetime
from typing import Optional

@app.websocket("/ws/lifecycle")
async def websocket_with_logging(websocket: WebSocket):
    """WebSocket with proper lifecycle management"""
    client_id: Optional[str] = None
    
    try:
        # Accept connection
        await websocket.accept()
        client_id = f"client_{id(websocket)}"
        print(f"[{datetime.now()}] {client_id} connected")
        
        while True:
            # Receive message
            data = await websocket.receive_text()
            print(f"[{datetime.now()}] {client_id} sent: {data}")
            
            # Send response
            await websocket.send_text(f"Received: {data}")
    
    except WebSocketDisconnect:
        print(f"[{datetime.now()}] {client_id} disconnected normally")
    
    except Exception as e:
        print(f"[{datetime.now()}] {client_id} error: {e}")
        try:
            # 1011 signals a server error; 1000 would report a normal close
            await websocket.close(code=1011, reason="Server error")
        except Exception:
            pass  # Connection already closed
💡 Connection Codes: 1000 = normal close, 1001 = going away, 1011 = server error. Use appropriate codes for debugging.
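
As a rough illustration of the codes above, the sketch below maps a few RFC 6455 close codes to log-friendly labels and picks a code for a failure. `CLOSE_REASONS`, `describe_close_code`, and `close_code_for` are hypothetical helper names introduced here, not part of FastAPI.

```python
from typing import Optional

# A few close codes defined by RFC 6455 (not an exhaustive list).
CLOSE_REASONS = {
    1000: "normal closure",
    1001: "going away",
    1008: "policy violation",
    1011: "internal server error",
}


def describe_close_code(code: int) -> str:
    """Return a human-readable label for a close code, for log lines."""
    if code in CLOSE_REASONS:
        return CLOSE_REASONS[code]
    if 4000 <= code <= 4999:
        # RFC 6455 reserves this range for private, application-level use.
        return "application-defined"
    return "unknown"


def close_code_for(error: Optional[BaseException]) -> int:
    """Pick 1000 for a clean shutdown and 1011 when the server failed."""
    return 1000 if error is None else 1011
```

A handler could log `describe_close_code(exc.code)` when it catches `WebSocketDisconnect`, which makes it easier to tell a browser tab closing (1001) from a crash (1011).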

4. Broadcasting Messages to Multiple Clients

The real power of WebSockets: send one message to many clients instantly. Store connected clients and broadcast to all.

main.py
from typing import Set

class ConnectionManager:
    def __init__(self):
        self.active_connections: Set[WebSocket] = set()
    
    async def connect(self, websocket: WebSocket):
        """Add a new client connection"""
        await websocket.accept()
        self.active_connections.add(websocket)
    
    async def disconnect(self, websocket: WebSocket):
        """Remove a disconnected client"""
        self.active_connections.discard(websocket)
    
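    async def broadcast(self, message: str):
        """Send message to ALL connected clients"""
        # Iterate over a copy: the set can change while we await a send
        for connection in list(self.active_connections):
            try:
                await connection.send_text(message)
            except Exception as e:
                print(f"Failed to send: {e}")
                # Drop the dead connection so it is not retried forever
                self.active_connections.discard(connection)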

manager = ConnectionManager()

@app.websocket("/ws/broadcast")
async def websocket_broadcast(websocket: WebSocket):
    """WebSocket with broadcasting"""
    await manager.connect(websocket)
    client_count = len(manager.active_connections)
    
    try:
        # Notify all clients
        await manager.broadcast(f"New client joined! Total: {client_count}")
        
        while True:
            data = await websocket.receive_text()
            # Broadcast received message
            await manager.broadcast(f"Message: {data}")
    
    except WebSocketDisconnect:
        await manager.disconnect(websocket)
        remaining = len(manager.active_connections)
        await manager.broadcast(f"Client disconnected. Remaining: {remaining}")
Broadcasting Pattern: Store all connections in a set, iterate and send to each. Simple but powerful for real-time updates.
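
Sending to each client in turn means one slow client delays everyone after it. One possible variation, sketched below, sends to all clients concurrently with `asyncio.gather` and prunes the ones that fail. `broadcast_concurrently` and `FakeSocket` are hypothetical names used only for this illustration; in the endpoint above the connections would be real `WebSocket` objects.

```python
import asyncio
from typing import Any, List, Set


async def broadcast_concurrently(connections: Set[Any], message: str) -> int:
    """Send to every connection at once, drop failures, return the delivered count."""
    targets = list(connections)  # snapshot: the set may change while we await
    results = await asyncio.gather(
        *(conn.send_text(message) for conn in targets),
        return_exceptions=True,  # one failed send must not cancel the others
    )
    delivered = 0
    for conn, result in zip(targets, results):
        if isinstance(result, Exception):
            connections.discard(conn)
        else:
            delivered += 1
    return delivered


class FakeSocket:
    """Stand-in for a WebSocket, used only to exercise the sketch."""

    def __init__(self, fail: bool = False) -> None:
        self.fail = fail
        self.sent: List[str] = []

    async def send_text(self, message: str) -> None:
        if self.fail:
            raise RuntimeError("connection closed")
        self.sent.append(message)
```

With a handful of clients the sequential loop is usually fine; the concurrent form starts to matter when some clients sit on slow networks.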

5. Building a Real Chat Application

Combine broadcasting with user management to build a proper chat system.

main.py
from pydantic import BaseModel
from datetime import datetime

class ChatMessage(BaseModel):
    username: str
    message: str
    timestamp: str

class ChatManager:
    def __init__(self):
        self.active_connections: dict[str, WebSocket] = {}  # username -> websocket
    
    async def connect(self, username: str, websocket: WebSocket):
        await websocket.accept()
        self.active_connections[username] = websocket
        
        # Notify all
        message = f"{username} joined the chat"
        await self.broadcast_system(message)
    
    async def disconnect(self, username: str):
        self.active_connections.pop(username, None)
        message = f"{username} left the chat"
        await self.broadcast_system(message)
    
    async def broadcast_message(self, username: str, text: str):
        """Broadcast user message"""
        msg = ChatMessage(
            username=username,
            message=text,
            timestamp=datetime.now().isoformat()
        )
        
        for connection in self.active_connections.values():
            try:
                await connection.send_json(msg.model_dump())
            except Exception as e:
                print(f"Broadcast error: {e}")
    
    async def broadcast_system(self, text: str):
        """Broadcast system message"""
        msg = {
            "username": "SYSTEM",
            "message": text,
            "timestamp": datetime.now().isoformat(),
            "is_system": True
        }
        
        for connection in self.active_connections.values():
            try:
                await connection.send_json(msg)
            except Exception:
                pass

chat = ChatManager()

@app.websocket("/ws/chat/{username}")
async def websocket_chat(websocket: WebSocket, username: str):
    """Chat endpoint"""
    await chat.connect(username, websocket)
    
    try:
        while True:
            message = await websocket.receive_text()
            await chat.broadcast_message(username, message)
    
    except WebSocketDisconnect:
        await chat.disconnect(username)
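
Because `ChatManager` keys connections by username, a second client that connects under a name already in use silently replaces the first. A minimal sketch of one way to guard against that is shown below; `UsernameRegistry`, `claim`, and `release` are hypothetical names, and the connection objects are treated as opaque.

```python
from typing import Any, Dict


class UsernameRegistry:
    """Tracks one connection per username and refuses duplicates."""

    def __init__(self) -> None:
        self._connections: Dict[str, Any] = {}

    def claim(self, username: str, connection: Any) -> bool:
        """Reserve the name; return False if another connection already holds it."""
        if username in self._connections:
            return False
        self._connections[username] = connection
        return True

    def release(self, username: str, connection: Any) -> None:
        """Free the name, but only if this connection is the one holding it."""
        if self._connections.get(username) is connection:
            del self._connections[username]

    def is_online(self, username: str) -> bool:
        return username in self._connections
```

In the chat endpoint, a refused `claim` could be answered by closing the socket instead of joining the chat. The identity check in `release` keeps a rejected duplicate from evicting the original user when it disconnects.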

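6. Sending Structured Data (JSON)

Plain text messages don't scale: once there is more than one kind of message, the client has to guess what each string means. Use JSON for structured, typed messages.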

main.py
from typing import Literal
from pydantic import BaseModel

class MessageType(BaseModel):
    type: Literal["chat", "ping", "notification", "typing"]

class ChatPayload(BaseModel):
    type: Literal["chat"]
    username: str
    text: str

class TypingPayload(BaseModel):
    type: Literal["typing"]
    username: str

@app.websocket("/ws/json-chat/{username}")
async def json_chat(websocket: WebSocket, username: str):
    """WebSocket with JSON message routing"""
    await websocket.accept()
    
    try:
        while True:
            # Receive JSON
            data = await websocket.receive_json()
            msg_type = data.get("type")
            
            # Route by type
            if msg_type == "chat":
                payload = ChatPayload(**data)
                # Process chat message
                print(f"{payload.username}: {payload.text}")
                await websocket.send_json({
                    "type": "chat_received",
                    "id": id(websocket)
                })
            
            elif msg_type == "typing":
                payload = TypingPayload(**data)
                # Echo the typing indicator back (a real app would broadcast it to other clients)
                await websocket.send_json({
                    "type": "user_typing",
                    "username": payload.username
                })
            
            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})
    
    except WebSocketDisconnect:
        pass
💡 Message Schema: Always structure messages with a type field for routing. Makes client-side handling much cleaner.
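
As the number of message types grows, the if/elif chain can be replaced by a lookup table keyed on the `type` field. The sketch below shows the idea with plain dictionaries so that it runs without FastAPI or Pydantic; `HANDLERS`, `route`, `handle_chat`, and `handle_ping` are hypothetical names, and the reply shapes are illustrative only.

```python
from typing import Any, Callable, Dict

Message = Dict[str, Any]
Handler = Callable[[Message], Message]


def handle_chat(data: Message) -> Message:
    # Raises KeyError if the client omitted "text"; route() turns that into an error reply.
    return {"type": "chat_received", "text": data["text"]}


def handle_ping(data: Message) -> Message:
    return {"type": "pong"}


# One entry per message type; supporting a new type means adding one line here.
HANDLERS: Dict[str, Handler] = {
    "chat": handle_chat,
    "ping": handle_ping,
}


def route(data: Message) -> Message:
    """Look up the handler for data["type"] and return its reply."""
    msg_type = data.get("type")
    handler = HANDLERS.get(msg_type) if isinstance(msg_type, str) else None
    if handler is None:
        return {"type": "error", "message": f"Unknown message type: {msg_type!r}"}
    try:
        return handler(data)
    except KeyError as missing:
        return {"type": "error", "message": f"Missing field: {missing.args[0]}"}
```

Inside the receive loop this would reduce to something like `await websocket.send_json(route(data))`, with unknown or malformed messages answered by an error reply instead of an unhandled exception.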

7. Live Notifications & Alerts

Push real-time notifications to users without polling.

main.py
from enum import Enum

class NotificationType(Enum):
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
    SUCCESS = "success"

class NotificationManager:
    def __init__(self):
        self.subscribers: dict[str, list[WebSocket]] = {}  # topic -> websockets
    
    async def subscribe(self, topic: str, websocket: WebSocket):
        """Subscribe to notification topic"""
        await websocket.accept()
        if topic not in self.subscribers:
            self.subscribers[topic] = []
        self.subscribers[topic].append(websocket)
    
    async def publish(self, topic: str, notification: dict):
        """Publish notification to all subscribers"""
        if topic not in self.subscribers:
            return
        
        dead_connections = []
        for connection in self.subscribers[topic]:
            try:
                await connection.send_json(notification)
            except Exception:
                dead_connections.append(connection)
        
        # Clean up dead connections
        for conn in dead_connections:
            self.subscribers[topic].remove(conn)

notif = NotificationManager()

@app.websocket("/ws/notify/{topic}")
async def websocket_notify(websocket: WebSocket, topic: str):
    """Subscribe to topic notifications"""
    await notif.subscribe(topic, websocket)
    try:
        while True:
            # Keep connection alive
            await websocket.receive_text()
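    except WebSocketDisconnect:
        # Remove this socket so the topic does not keep a dead connection
        if websocket in notif.subscribers.get(topic, []):
            notif.subscribers[topic].remove(websocket)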

@app.post("/notify/{topic}")
async def send_notification(topic: str, title: str, message: str):
    """Send notification to topic subscribers"""
    notification = {
        "type": "notification",
        "title": title,
        "message": message,
        "timestamp": datetime.now().isoformat()
    }
    await notif.publish(topic, notification)
    return {"sent": len(notif.subscribers.get(topic, []))}

// Client (JavaScript) subscribes to the "orders" topic
const ws = new WebSocket('ws://localhost:8000/ws/notify/orders');

// Client receives notifications in real time
ws.onmessage = (event) => {
  const notif = JSON.parse(event.data);
  console.log(notif.title); // "New Order"
};

# Shell: publish a notification (title and message are query parameters)
curl -X POST "http://localhost:8000/notify/orders?title=New%20Order&message=Order%20%23123%20received"
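
The topic bookkeeping behind `subscribe` and `publish` can also be kept in sets, which makes removing a subscriber cheap and safe to repeat. The sketch below is one possible shape, not the manager used above; `TopicRegistry` and its method names are hypothetical, and connections are treated as opaque objects.

```python
from collections import defaultdict
from typing import Any, DefaultDict, Set


class TopicRegistry:
    """Maps each topic to the set of connections subscribed to it."""

    def __init__(self) -> None:
        self._topics: DefaultDict[str, Set[Any]] = defaultdict(set)

    def subscribe(self, topic: str, connection: Any) -> None:
        self._topics[topic].add(connection)

    def unsubscribe(self, topic: str, connection: Any) -> None:
        """Remove the connection; calling this twice is harmless."""
        subscribers = self._topics.get(topic)
        if subscribers is None:
            return
        subscribers.discard(connection)
        if not subscribers:
            # Delete empty topics so the mapping does not grow without bound.
            del self._topics[topic]

    def subscriber_count(self, topic: str) -> int:
        return len(self._topics.get(topic, ()))
```

Calling `unsubscribe` when the socket disconnects keeps the count reported by the POST endpoint in step with the clients that are actually connected.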

8. WebSocket Authentication & Authorization

Secure WebSockets with authentication. Validate tokens before accepting connections.

main.py
from typing import Optional

from fastapi import Query, status
import jwt  # Provided by the PyJWT package (pip install PyJWT)
from datetime import datetime, timedelta

SECRET_KEY = "your-secret-key"  # Placeholder: load the real secret from configuration
ALGORITHM = "HS256"

def verify_websocket_token(token: str) -> Optional[dict]:
    """Verify JWT token from query parameter; return None if it is invalid"""
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        return payload
    except jwt.InvalidTokenError:
        return None

@app.websocket("/ws/secure")
async def websocket_secure(
    websocket: WebSocket,
    token: str = Query(...)
):
    """Secure WebSocket endpoint"""
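    # Verify token before accepting
    user_data = verify_websocket_token(token)
    if not user_data:
        # Closing before accept() rejects the handshake, so the client
        # typically sees a failed connection rather than this code and reason
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Invalid token")
        return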
    
    user_id = user_data.get("user_id")
    await websocket.accept()
    
    try:
        print(f"Authenticated user {user_id} connected")
        while True:
            data = await websocket.receive_text()
            print(f"User {user_id}: {data}")
    
    except WebSocketDisconnect:
        print(f"User {user_id} disconnected")
⚠️ Never accept unauthenticated WebSockets in production. Always verify credentials before accept().
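
To make the token check less of a black box, the sketch below builds and verifies an HS256-signed token using only the standard library. It is an illustration of what a JWT library does under the hood, not a replacement for PyJWT; `issue_token`, `verify_token`, and the `ttl_seconds` parameter are hypothetical names, and the claim layout (`user_id`, `exp`) is assumed to match the endpoint above.

```python
import base64
import hashlib
import hmac
import json
import time
from typing import Optional


def _b64url(data: bytes) -> str:
    """Base64url-encode without padding, as JWTs do."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def _b64url_decode(text: str) -> bytes:
    return base64.urlsafe_b64decode(text + "=" * (-len(text) % 4))


def _sign(signing_input: str, secret: str) -> bytes:
    return hmac.new(secret.encode(), signing_input.encode("ascii"), hashlib.sha256).digest()


def issue_token(user_id: str, secret: str, ttl_seconds: int = 300) -> str:
    """Create header.payload.signature with an expiry claim."""
    header = _b64url(json.dumps({"alg": "HS256", "typ": "JWT"}).encode())
    claims = {"user_id": user_id, "exp": int(time.time()) + ttl_seconds}
    payload = _b64url(json.dumps(claims).encode())
    signing_input = f"{header}.{payload}"
    return f"{signing_input}.{_b64url(_sign(signing_input, secret))}"


def verify_token(token: str, secret: str) -> Optional[dict]:
    """Return the claims if the signature and expiry check out, else None."""
    try:
        header, payload, signature = token.split(".")
        expected = _sign(f"{header}.{payload}", secret)
        # compare_digest avoids leaking how many bytes matched.
        if not hmac.compare_digest(expected, _b64url_decode(signature)):
            return None
        claims = json.loads(_b64url_decode(payload))
    except ValueError:
        return None  # wrong number of parts, bad base64, or bad JSON
    if not isinstance(claims, dict) or claims.get("exp", 0) < time.time():
        return None
    return claims
```

A short expiry suits tokens passed in a query string, since URLs tend to end up in server logs and browser history.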

9. Handling Reconnections & Heartbeats

Networks fail. Implement ping/pong to detect dead connections and auto-reconnect clients.

main.py
import asyncio

@app.websocket("/ws/heartbeat")
async def websocket_heartbeat(websocket: WebSocket):
    """WebSocket with heartbeat"""
    await websocket.accept()
    
    # Heartbeat task
    async def send_heartbeat():
        while True:
            try:
                await websocket.send_json({
                    "type": "ping",
                    "timestamp": datetime.now().isoformat()
                })
                await asyncio.sleep(30)  # Every 30 seconds
            except Exception:
                break
    
    # Start heartbeat in background
    heartbeat_task = asyncio.create_task(send_heartbeat())
    
    try:
        while True:
            message = await websocket.receive_json()
            
            # Respond to pong
            if message.get("type") == "pong":
                print("Client alive")
            else:
                # Process message
                await websocket.send_json({
                    "type": "ack",
                    "id": message.get("id")
                })
    
    except WebSocketDisconnect:
        pass
    finally:
        # Stop the heartbeat however the receive loop ends
        heartbeat_task.cancel()
client.html
<script>
class WebSocketClient {
  constructor(url, reconnectInterval = 3000) {
    this.url = url;
    this.reconnectInterval = reconnectInterval;
    this.connect();
  }
  
  connect() {
    this.ws = new WebSocket(this.url);
    
    this.ws.onopen = () => {
      console.log("Connected");
      this.reconnectAttempts = 0;
    };
    
    this.ws.onmessage = (event) => {
      const message = JSON.parse(event.data);
      
      // Handle heartbeat
      if (message.type === "ping") {
        this.ws.send(JSON.stringify({ type: "pong" }));
      }
    };
    
    this.ws.onclose = () => {
      console.log("Disconnected, reconnecting...");
      setTimeout(() => this.connect(), this.reconnectInterval);
    };
  }
}

// Usage
const client = new WebSocketClient("ws://localhost:8000/ws/heartbeat");
</script>

10. Rate Limiting WebSocket Messages

Prevent spam and DoS attacks by limiting message rates per user.

main.py
from collections import defaultdict
from time import time

class RateLimiter:
    def __init__(self, max_messages: int, window_seconds: int):
        self.max_messages = max_messages
        self.window_seconds = window_seconds
        self.messages: dict[str, list[float]] = defaultdict(list)
    
    def is_allowed(self, user_id: str) -> bool:
        """Check if user exceeded rate limit"""
        now = time()
        window_start = now - self.window_seconds
        
        # Remove old messages outside window
        self.messages[user_id] = [
            t for t in self.messages[user_id]
            if t > window_start
        ]
        
        # Check if over limit
        if len(self.messages[user_id]) >= self.max_messages:
            return False
        
        # Record this message
        self.messages[user_id].append(now)
        return True

limiter = RateLimiter(max_messages=10, window_seconds=60)  # 10/minute

@app.websocket("/ws/ratelimited/{user_id}")
async def websocket_ratelimited(websocket: WebSocket, user_id: str):
    await websocket.accept()
    
    try:
        while True:
            data = await websocket.receive_text()
            
            # Check rate limit
            if not limiter.is_allowed(user_id):
                await websocket.send_json({
                    "type": "error",
                    "message": "Rate limit exceeded"
                })
                continue
            
            # Process message
            await websocket.send_json({
                "type": "received",
                "data": data
            })
    
    except WebSocketDisconnect:
        pass
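
The limiter above rebuilds each user's timestamp list on every message and reads the wall clock directly, which makes it awkward to unit test. The sketch below is a variation on the same sliding-window idea that uses a `deque` and an injectable clock; `SlidingWindowLimiter`, the `clock` parameter, and `forget` are hypothetical names introduced for this example.

```python
from collections import defaultdict, deque
from time import monotonic
from typing import Callable, DefaultDict, Deque


class SlidingWindowLimiter:
    """Allow at most `max_messages` per user within the last `window_seconds`."""

    def __init__(
        self,
        max_messages: int,
        window_seconds: float,
        clock: Callable[[], float] = monotonic,  # injectable so tests can control time
    ) -> None:
        self.max_messages = max_messages
        self.window_seconds = window_seconds
        self._clock = clock
        self._events: DefaultDict[str, Deque[float]] = defaultdict(deque)

    def is_allowed(self, user_id: str) -> bool:
        now = self._clock()
        events = self._events[user_id]
        # Timestamps are appended in order, so expired ones are always at the left.
        while events and events[0] <= now - self.window_seconds:
            events.popleft()
        if len(events) >= self.max_messages:
            return False
        events.append(now)
        return True

    def forget(self, user_id: str) -> None:
        """Drop a user's history, for example when they disconnect."""
        self._events.pop(user_id, None)
```

`monotonic` is used as the default clock because it is not affected by system clock adjustments. Calling `forget` on disconnect keeps the per-user history from accumulating for users who never return.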

11. Testing WebSocket Endpoints

test_websocket.py
import pytest
from fastapi.testclient import TestClient

from main import app  # The FastAPI app defined in main.py

client = TestClient(app)

def test_simple_websocket():
    """Test basic WebSocket communication"""
    with client.websocket_connect("/ws") as websocket:
        # Send message
        websocket.send_text("Hello")
        
        # Receive echo
        data = websocket.receive_text()
        assert data == "Echo: Hello"

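def test_broadcast():
    """Test message broadcasting"""
    with client.websocket_connect("/ws/broadcast") as ws1:
        with client.websocket_connect("/ws/broadcast") as ws2:
            # Drain the "New client joined!" announcements first
            ws1.receive_text()  # ws1's own join
            ws1.receive_text()  # ws2's join
            ws2.receive_text()  # ws2's own join
            
            # ws1 sends message
            ws1.send_text("Broadcast test")
            
            # Both should receive
            assert "Broadcast test" in ws1.receive_text()
            assert "Broadcast test" in ws2.receive_text()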

def test_json_messages():
    """Test JSON message handling"""
    with client.websocket_connect("/ws/json-chat/testuser") as websocket:
        # Send JSON message
        websocket.send_json({
            "type": "ping"
        })
        
        # Receive response
        data = websocket.receive_json()
        assert data["type"] == "pong"

def test_chat():
    """Test chat functionality"""
    with client.websocket_connect("/ws/chat/alice") as alice:
        with client.websocket_connect("/ws/chat/bob") as bob:
            # Bob first receives the system notice about his own join
            assert bob.receive_json()["is_system"] is True
            
            # Alice sends message
            alice.send_text("Hi Bob!")
            
            # Bob receives; ChatMessage stores the text under "message"
            message = bob.receive_json()
            assert message["username"] == "alice"
            assert message["message"] == "Hi Bob!"

12. Production Patterns & Scaling

Local connections work for development. Scale with pub/sub systems for multiple servers.

main.py
import json
from typing import Optional

import redis.asyncio as redis  # redis-py; the standalone aioredis project was merged into it

class RedisConnectionManager:
    """Use Redis for multi-server broadcasting"""
    
    def __init__(self):
        self.active_connections: dict[str, WebSocket] = {}
        self.redis: Optional[redis.Redis] = None
    
    async def startup(self):
        """Initialize Redis connection"""
        self.redis = redis.from_url("redis://localhost")
    
    async def shutdown(self):
        """Close Redis connection"""
        if self.redis:
            await self.redis.aclose()  # redis-py 5.0.1+; older versions use close()
    
    async def connect(self, client_id: str, websocket: WebSocket):
        await websocket.accept()
        self.active_connections[client_id] = websocket
    
    async def broadcast(self, message: dict):
        """Broadcast via Redis pub/sub"""
        # Store in Redis for all servers to see
        await self.redis.publish("broadcast", json.dumps(message))
    
    async def local_send(self, client_id: str, message: dict):
        """Send to local client"""
        if client_id in self.active_connections:
            await self.active_connections[client_id].send_json(message)

# Separate name: `manager` already refers to the in-memory ConnectionManager
redis_manager = RedisConnectionManager()

@app.on_event("startup")
async def startup():
    await redis_manager.startup()

@app.on_event("shutdown")
async def shutdown():
    await redis_manager.shutdown()
💡 Scaling Strategy: Use Redis pub/sub for broadcasting across multiple FastAPI instances. Each server maintains its own client connections but receives broadcasts from Redis.
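
Publishing is only half of the pattern: each server also needs a background task that listens on the channel and forwards messages to its own clients. The sketch below shows that fan-out step in isolation. It assumes events shaped like the dictionaries yielded by redis-py's `pubsub.listen()` (with `type` and `data` keys); `relay_broadcasts`, `FakeSocket`, and `fake_events` are hypothetical names, and the fakes exist only so the example runs without a Redis server.

```python
import asyncio
import json
from typing import Any, AsyncIterator, Dict, List


async def relay_broadcasts(
    events: AsyncIterator[Dict[str, Any]],
    local_connections: Dict[str, Any],
) -> None:
    """Forward each published message to every client connected to this server."""
    async for event in events:
        if event.get("type") != "message":
            continue  # skip subscribe/unsubscribe confirmations
        payload = json.loads(event["data"])
        # Iterate over a copy so failed clients can be removed during the loop.
        for client_id, connection in list(local_connections.items()):
            try:
                await connection.send_json(payload)
            except Exception:
                local_connections.pop(client_id, None)


class FakeSocket:
    """Stand-in for a WebSocket, used only to exercise the sketch."""

    def __init__(self) -> None:
        self.received: List[Any] = []

    async def send_json(self, payload: Any) -> None:
        self.received.append(payload)


async def fake_events() -> AsyncIterator[Dict[str, Any]]:
    """Mimics a subscription confirmation followed by one published message."""
    yield {"type": "subscribe", "data": 1}
    yield {"type": "message", "data": b'{"text": "hello"}'}
```

In a real deployment the `events` argument would come from a pub/sub object subscribed to the `broadcast` channel, and the coroutine would be started as a background task at application startup.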

13. WebSocket Deployment Considerations

Aspect         | Local Development      | Production
---------------|------------------------|-----------------------------------
Server         | Uvicorn single process | Multiple Uvicorn processes + Nginx
Broadcasting   | In-memory set          | Redis pub/sub
Persistence    | None                   | Database + Redis
Monitoring     | Logs                   | Prometheus + Grafana
CORS           | Allow all              | Specific origins
Timeout        | No limit               | Configurable (30-60s)
Load Balancing | N/A                    | Sticky sessions or Redis

14. Common Pitfalls & Solutions

⚠️ Pitfall 1 - Memory Leak: Forgetting to remove disconnected clients from the connections set. Clean up in an except or finally block so removal runs however the handler exits.
⚠️ Pitfall 2 - Lost Connections: Nginx/load balancers may time out idle connections. Implement heartbeat/ping-pong to keep connections alive.
⚠️ Pitfall 3 - Unauthenticated Access: Always verify tokens before accepting WebSocket connections, not after.
⚠️ Pitfall 4 - Unhandled Exceptions: If a broadcast fails for one client, don't crash for others. Wrap send calls in try/except.

15. What's Next

Congratulations! You've mastered WebSocket programming in FastAPI. You can now:
  • Accept WebSocket connections and handle lifecycle
  • Broadcast messages to multiple clients
  • Build real-time chat and notification systems
  • Implement authentication and rate limiting
  • Handle reconnections with heartbeats
  • Scale across multiple servers with Redis

Next Topics to Explore:

  • Collaborative Editing: Combine WebSockets with operational transformation (OT) for real-time document editing like Google Docs.
  • Live Data Streaming: Push live updates from databases using WebSockets for stock tickers, sensor data, etc.
  • WebSocket Rooms/Namespaces: Organize connections into rooms for targeted broadcasting (similar to Socket.IO).
  • Binary Protocol: Use MessagePack or Protocol Buffers for efficient binary communication over WebSockets.

About the Author

Thirdy Gayares

Passionate developer creating custom solutions for everyone. I specialize in building user-friendly tools that solve real-world problems while maintaining the highest standards of security and privacy.