2024-02-01
14 min read

Scalable WebSocket Architecture with NestJS and Socket.IO

Learn how to build a scalable WebSocket architecture using NestJS and Socket.IO, including multi-node setup, Redis adapter integration, and best practices for handling real-time communication at scale.

NestJS
WebSockets
Socket.IO
Redis
Scaling
Real-time

Building scalable real-time applications requires careful consideration of WebSocket implementation, especially in a multi-node environment. In this guide, we'll explore how to create a robust WebSocket architecture using NestJS and Socket.IO, with the Socket.IO Redis adapter for horizontal scaling.

Setting Up WebSocket Gateway

typescript
// chat/chat.gateway.ts
import {
  WebSocketGateway,
  WebSocketServer,
  SubscribeMessage,
  OnGatewayConnection,
  OnGatewayDisconnect,
  OnGatewayInit
} from '@nestjs/websockets';
import { Namespace, Server, Socket } from 'socket.io';
import { createAdapter } from '@socket.io/redis-adapter';
import { Redis } from 'ioredis';

@WebSocketGateway({
  cors: {
    origin: process.env.FRONTEND_URL,
    credentials: true
  },
  namespace: '/chat'
})
export class ChatGateway
  implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect
{
  @WebSocketServer() server: Server;
  private pubClient: Redis;
  private subClient: Redis;

  /**
   * Nest lifecycle hook, called once the gateway's server has been created.
   *
   * The Redis adapter is installed here rather than in the constructor:
   * `@WebSocketServer()` is only populated after construction, so
   * `this.server` is still undefined inside the constructor.
   *
   * @param server The server Nest created for this gateway. Because a
   *   `namespace` is configured this may be a Namespace rather than the root
   *   Server; adapters are configured on the root Server, so resolve it first.
   */
  afterInit(server: Server | Namespace): void {
    this.setupRedisAdapter(server instanceof Namespace ? server.server : server);
  }

  /**
   * Creates the Redis pub/sub clients and installs the Socket.IO Redis
   * adapter so that broadcasts reach sockets connected to every node.
   *
   * @param io The root Socket.IO server.
   */
  private setupRedisAdapter(io: Server): void {
    this.pubClient = new Redis({
      host: process.env.REDIS_HOST,
      // Fall back to the standard Redis port; parseInt(undefined) is NaN.
      port: parseInt(process.env.REDIS_PORT ?? '6379', 10),
      password: process.env.REDIS_PASSWORD
    });

    // A connection in subscriber mode cannot issue other commands, so the
    // adapter needs a second, dedicated connection.
    this.subClient = this.pubClient.duplicate();

    // ioredis queues commands until its connection is ready, so the adapter
    // can be attached right away. Awaiting 'connect' first would leave this
    // node on the in-memory adapter until Redis became reachable.
    io.adapter(createAdapter(this.pubClient, this.subClient));
  }

  /**
   * Authenticates a new socket from its handshake token, caches the verified
   * user id on `client.data`, joins the user's rooms and announces the user.
   * Sockets without a valid token are disconnected immediately.
   */
  async handleConnection(client: Socket) {
    const userId = this.getUserIdFromToken(client.handshake.auth.token);
    if (!userId) {
      client.disconnect();
      return;
    }

    // Remember the identity verified at connect time so later events do not
    // re-verify a token that may have expired since.
    client.data.userId = userId;

    await this.joinUserRooms(client, userId);
    this.server.emit('userConnected', { userId });
  }

  /**
   * Announces that an authenticated user disconnected. Sockets that never
   * authenticated (no cached user id) are ignored.
   */
  async handleDisconnect(client: Socket) {
    const userId: string | undefined = client.data.userId;
    if (userId) {
      this.server.emit('userDisconnected', { userId });
    }
  }

  /**
   * Persists a chat message and broadcasts it to the target room.
   *
   * The payload is client-controlled. It is ignored when:
   * - it is malformed or the message is blank;
   * - the socket is not a member of `roomId`, i.e. was not joined to that room
   *   during `handleConnection`.
   */
  @SubscribeMessage('sendMessage')
  async handleMessage(client: Socket, payload: {
    roomId: string;
    message: string;
  }) {
    const userId: string | undefined = client.data.userId;
    if (!userId) return;

    if (
      typeof payload?.roomId !== 'string' ||
      typeof payload.message !== 'string' ||
      payload.message.trim() === '' ||
      !client.rooms.has(payload.roomId)
    ) {
      return;
    }

    const message = {
      id: generateUniqueId(),
      userId,
      roomId: payload.roomId,
      content: payload.message,
      timestamp: new Date()
    };

    await this.saveMessage(message);
    this.server.to(payload.roomId).emit('newMessage', message);
  }

  /** Joins `client` to every room returned for `userId` in a single call. */
  private async joinUserRooms(client: Socket, userId: string) {
    const rooms = await this.getUserRooms(userId);
    // join() accepts an array and may return a Promise depending on the
    // adapter; awaiting it ensures memberships exist before we return.
    await client.join(rooms.map(room => room.id));
  }

  private async saveMessage(message: any) {
    // Implement message persistence logic
  }

  private async getUserRooms(userId: string) {
    // Implement room retrieval logic
    return [];
  }

  private getUserIdFromToken(token: string): string | null {
    // Implement token verification logic
    return null;
  }
}

Implementing Room Management

typescript
// chat/room.service.ts
@Injectable()
export class RoomService {
  constructor(
    @InjectRepository(Room)
    private roomRepository: Repository<Room>,
    private readonly redis: Redis,
    // Used to notify sockets in a room over Socket.IO.
    private readonly chatGateway: ChatGateway
  ) {}

  /**
   * Persists a new room with a generated id and mirrors its name, type and
   * creation time into the Redis hash `room:<id>` for quick access.
   *
   * NOTE(review): assumes `createdAt` is populated by `save()` (e.g. a
   * database-generated column); confirm against the Room entity.
   */
  async createRoom(data: CreateRoomDto) {
    const room = await this.roomRepository.save({
      ...data,
      id: generateUniqueId()
    });

    // Store room info in Redis for quick access
    await this.redis.hset(
      `room:${room.id}`,
      {
        name: room.name,
        type: room.type,
        createdAt: room.createdAt.toISOString()
      }
    );

    return room;
  }

  /**
   * Adds `userId` to the room's members in the database and in the Redis
   * membership cache, then notifies sockets already in the room.
   */
  async addUserToRoom(roomId: string, userId: string) {
    await this.roomRepository
      .createQueryBuilder()
      .relation(Room, 'members')
      .of(roomId)
      .add(userId);

    // Warm the cache from the database first. SADD on a missing key would
    // create a one-member set, which getRoomMembers() would then return as
    // the complete membership.
    await this.getRoomMembers(roomId);
    await this.redis.sadd(`room:${roomId}:members`, userId);

    // NOTE(review): nothing here joins the new member's own sockets to
    // `roomId`; unless that happens elsewhere, they only receive room
    // traffic after reconnecting.
    this.chatGateway.server.to(roomId).emit('userJoinedRoom', { roomId, userId });
  }

  /**
   * Returns the member ids of `roomId`. Reads the Redis set first and falls
   * back to the database, caching the result when the room has members.
   * Returns an empty array when the room does not exist.
   */
  async getRoomMembers(roomId: string): Promise<string[]> {
    const cacheKey = `room:${roomId}:members`;

    const cached = await this.redis.smembers(cacheKey);
    if (cached.length > 0) {
      return cached;
    }

    const room = await this.roomRepository.findOne({
      where: { id: roomId },
      relations: ['members']
    });
    if (!room) {
      return [];
    }

    const memberIds = room.members.map(member => member.id);

    // SADD with no members is a Redis argument error, so cache only non-empty sets.
    if (memberIds.length > 0) {
      await this.redis.sadd(cacheKey, ...memberIds);
    }

    return memberIds;
  }
}

Implementing Message Broadcasting

typescript
// chat/chat.service.ts
@Injectable()
export class ChatService {
  constructor(
    @InjectRepository(Message)
    private messageRepository: Repository<Message>,
    private readonly redis: Redis,
    // Used to emit events to Socket.IO rooms.
    private readonly chatGateway: ChatGateway
  ) {}

  /**
   * Emits `event` with `data` to every socket in `roomId`. The event is then
   * recorded in the Redis list `room:<roomId>:events`, which is kept newest
   * first and trimmed to the latest 100 entries, for recovery/history.
   *
   * @param roomId Socket.IO room to broadcast to.
   * @param event  Event name delivered to clients.
   * @param data   Event payload; JSON.stringify must be able to serialize it
   *   for the history entry.
   */
  async broadcastToRoom(roomId: string, event: string, data: any) {
    this.chatGateway.server.to(roomId).emit(event, data);

    const historyKey = `room:${roomId}:events`;

    // Store event in Redis for recovery/history
    await this.redis.lpush(
      historyKey,
      JSON.stringify({
        event,
        data,
        timestamp: Date.now()
      })
    );

    // Trim event history to last 100 events
    await this.redis.ltrim(historyKey, 0, 99);
  }

  /**
   * Loads up to `limit` messages for `roomId`, newest first, from the
   * database. Writes them to `room:<roomId>:messages` with a one-hour TTL.
   *
   * NOTE(review): nothing in this service reads `room:<roomId>:messages`,
   * and the key does not include `limit`. A reader would need to account for
   * both, and for `timestamp` coming back as a string after JSON round-tripping.
   */
  async getLatestMessages(roomId: string, limit: number = 50) {
    const messages = await this.messageRepository.find({
      where: { roomId },
      order: { timestamp: 'DESC' },
      take: limit
    });

    // Cache messages in Redis
    await this.redis.setex(
      `room:${roomId}:messages`,
      3600, // 1 hour
      JSON.stringify(messages)
    );

    return messages;
  }
}

Implementing Health Checks and Monitoring

typescript
// chat/health.service.ts
@Injectable()
export class WebSocketHealthService {
  constructor(
    private readonly redis: Redis,
    @Inject(forwardRef(() => ChatGateway))
    private readonly chatGateway: ChatGateway
  ) {}

  /**
   * Writes this node's heartbeat into the `ws:nodes` hash. The heartbeat
   * holds the timestamp, connection count and memory usage. Entries older
   * than 30 seconds are then removed.
   *
   * NOTE(review): NODE_ID is used unchecked; if it is unset, every node
   * writes the same hash field. Callers presumably invoke this more often
   * than every 30 s; otherwise live nodes get pruned.
   */
  async checkHealth() {
    const nodeId = process.env.NODE_ID;
    const timestamp = Date.now();

    // Update node heartbeat
    await this.redis.hset(
      'ws:nodes',
      nodeId,
      JSON.stringify({
        timestamp,
        connections: this.getConnectionCount(),
        memory: process.memoryUsage()
      })
    );

    // Clean up stale nodes
    await this.cleanupStaleNodes();
  }

  /**
   * Number of Engine.IO clients connected to this node.
   *
   * NOTE(review): the gateway declares a `namespace`, in which case Nest may
   * inject a Namespace, which has no `engine` (it is reachable via
   * `server.server.engine`). Verify against the Nest version in use.
   */
  private getConnectionCount(): number {
    return this.chatGateway.server.engine.clientsCount;
  }

  /**
   * Removes `ws:nodes` entries whose heartbeat is older than 30 seconds or
   * cannot be parsed, using a single HDEL.
   */
  private async cleanupStaleNodes() {
    const nodes = await this.redis.hgetall('ws:nodes');
    const staleThreshold = Date.now() - 30000; // 30 seconds

    const staleNodeIds = Object.entries(nodes)
      .filter(([, data]) => this.isStale(data, staleThreshold))
      .map(([nodeId]) => nodeId);

    if (staleNodeIds.length > 0) {
      await this.redis.hdel('ws:nodes', ...staleNodeIds);
    }
  }

  /**
   * True when a serialized heartbeat is older than `threshold` (epoch ms)
   * or is not a valid heartbeat at all.
   */
  private isStale(raw: string, threshold: number): boolean {
    try {
      const { timestamp } = JSON.parse(raw) as { timestamp?: unknown };
      return typeof timestamp !== 'number' || timestamp < threshold;
    } catch {
      // A corrupt entry would otherwise make every health check throw here;
      // treat it as stale so it is removed.
      return true;
    }
  }
}

Load Balancing Configuration

nginx
# nginx.conf
# nginx refuses to start if the main configuration has no events context.
events {}

http {
  upstream websocket_nodes {
    # Sticky sessions: Socket.IO HTTP long-polling requests must keep
    # reaching the node that holds the session. ip_hash keys on the first
    # three octets of an IPv4 address (the whole address for IPv6).
    ip_hash;

    server ws1.example.com:3000;
    server ws2.example.com:3000;
    server ws3.example.com:3000;
  }

  server {
    listen 80;
    server_name ws.example.com;

    location /socket.io/ {
      proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
      proxy_set_header Host $host;

      proxy_pass http://websocket_nodes;

      # WebSocket upgrade: needs HTTP/1.1 and explicit forwarding of the
      # hop-by-hop Upgrade/Connection headers.
      proxy_http_version 1.1;
      proxy_set_header Upgrade $http_upgrade;
      proxy_set_header Connection 'upgrade';
      proxy_cache_bypass $http_upgrade;

      # A proxied connection is closed when nothing is read/sent for this
      # long; keep it well above the Socket.IO ping interval.
      proxy_read_timeout 600s;
      proxy_send_timeout 600s;
    }
  }
}

Scaling WebSocket Connections

When scaling WebSocket applications across multiple nodes, we need to handle several challenges including session management, broadcasting, and connection state synchronization.

typescript
// chat/scaling.service.ts
@Injectable()
export class ScalingService {
  /** Heartbeat refresh timer; cleared when the module is destroyed. */
  private heartbeatTimer?: ReturnType<typeof setInterval>;

  constructor(
    private readonly redis: Redis,
    private readonly chatGateway: ChatGateway
  ) {
    // Constructors cannot await; report registration failures explicitly
    // instead of leaving an unhandled promise rejection.
    this.setupNodeDiscovery().catch(err => {
      console.error('Failed to register WebSocket node', err);
    });
  }

  /** Nest lifecycle hook: stop refreshing the heartbeat on shutdown. */
  onModuleDestroy(): void {
    clearInterval(this.heartbeatTimer);
  }

  /**
   * Registers this node in `ws:active_nodes` and keeps the heartbeat key
   * `ws:node:<id>:heartbeat` alive with a 30 s TTL, refreshed every 15 s.
   * On SIGTERM the node deregisters and the process exits.
   *
   * @throws Error when NODE_ID is unset. Every node would otherwise register
   *   under the same "undefined" id.
   */
  private async setupNodeDiscovery() {
    const nodeId = process.env.NODE_ID;
    if (!nodeId) {
      throw new Error('NODE_ID environment variable is not set');
    }

    const heartbeatKey = `ws:node:${nodeId}:heartbeat`;

    // Never rejects: a failed write inside setInterval would otherwise be an
    // unhandled rejection, which terminates the process on current Node.
    const refreshHeartbeat = async () => {
      try {
        await this.redis.setex(
          heartbeatKey,
          30, // TTL 30 seconds
          Date.now().toString()
        );
      } catch (err) {
        console.error('Failed to refresh WebSocket node heartbeat', err);
      }
    };

    // Register node
    await this.redis.sadd('ws:active_nodes', nodeId);

    // Handle node shutdown.
    // NOTE(review): process.exit ends the process before other asynchronous
    // shutdown work (e.g. Nest shutdown hooks) can complete; confirm intended.
    process.on('SIGTERM', async () => {
      try {
        await this.redis.srem('ws:active_nodes', nodeId);
        await this.redis.del(`ws:node:${nodeId}:connections`, heartbeatKey);
      } finally {
        // Exit even if Redis is unreachable during shutdown.
        process.exit(0);
      }
    });

    // Publish the first heartbeat now rather than after the first 15 s tick.
    await refreshHeartbeat();
    this.heartbeatTimer = setInterval(() => {
      void refreshHeartbeat();
    }, 15000);
  }

  /**
   * Publishes `{ event, data }` on the Redis channel `ws:broadcast` and emits
   * the event through the gateway server.
   *
   * NOTE(review): with the Redis adapter installed, `server.emit` already
   * reaches clients on every node. Any subscriber that re-emits
   * `ws:broadcast` messages would deliver the event a second time.
   */
  async handleGlobalBroadcast(event: string, data: any) {
    // Publish to Redis for other nodes
    await this.redis.publish('ws:broadcast',
      JSON.stringify({ event, data })
    );

    // Emit to local connections
    this.chatGateway.server.emit(event, data);
  }

  /**
   * Sums per-node connection counts stored under `ws:node:<id>:connections`
   * for every id in `ws:active_nodes`. Missing counts are treated as 0.
   *
   * NOTE(review): nothing in this service writes the connections keys.
   * Crashed nodes stay in `ws:active_nodes` until removed elsewhere; their
   * heartbeat key could be used to filter them.
   */
  async getGlobalConnectionStats(): Promise<ConnectionStats> {
    const nodes = await this.redis.smembers('ws:active_nodes');
    const stats = await Promise.all(
      nodes.map(async nodeId => {
        const connections = await this.redis.get(
          `ws:node:${nodeId}:connections`
        );
        return {
          nodeId,
          connections: parseInt(connections || '0', 10)
        };
      })
    );

    return {
      totalNodes: nodes.length,
      totalConnections: stats.reduce(
        (sum, node) => sum + node.connections,
        0
      ),
      nodesStats: stats
    };
  }
}

Implementing Failover and Recovery

typescript
// chat/recovery.service.ts
@Injectable()
export class RecoveryService {
  constructor(
    private readonly redis: Redis,
    private readonly chatGateway: ChatGateway
  ) {}

  /**
   * Reacts to the loss of `failedNodeId`.
   *
   * 1. Each room listed in `ws:node:<failedNodeId>:rooms` is redistributed,
   *    strictly one after another.
   * 2. That room set is deleted.
   * 3. The node is removed from `ws:active_nodes`.
   *
   * If any room fails, the remaining rooms and the cleanup are skipped.
   */
  async handleNodeFailure(failedNodeId: string) {
    const nodeRoomsKey = `ws:node:${failedNodeId}:rooms`;
    const lostRoomIds = await this.redis.smembers(nodeRoomsKey);

    // Chain the rooms so each redistribution starts after the previous one settles.
    await lostRoomIds.reduce<Promise<void>>(
      (previous, lostRoomId) =>
        previous.then(() => this.redistributeRoom(lostRoomId)),
      Promise.resolve()
    );

    await this.redis.del(nodeRoomsKey);
    await this.redis.srem('ws:active_nodes', failedNodeId);
  }

  /**
   * Loads the cached member ids of `roomId`, tells the room's sockets to
   * reconnect, then hands both to the consistency hook.
   *
   * NOTE(review): the emit goes to sockets currently in the room, which with
   * the Redis adapter means clients on the surviving nodes. Clients of the
   * failed node are already disconnected; confirm this is the intended audience.
   */
  private async redistributeRoom(roomId: string) {
    const memberIds = await this.redis.smembers(`room:${roomId}:members`);

    const notice = { roomId, reason: 'node_failure' };
    this.chatGateway.server.to(roomId).emit('reconnectRequired', notice);

    await this.ensureRoomStateConsistency(roomId, memberIds);
  }

  /**
   * Placeholder for restoring room state after a node failure, e.g.
   * re-establishing subscriptions or replaying missed messages. Currently a
   * no-op.
   */
  private async ensureRoomStateConsistency(
    roomId: string,
    members: string[]
  ) {
    // Room state recovery logic goes here.
  }
}

Performance Monitoring and Optimization

typescript
// chat/monitoring.service.ts
@Injectable()
export class MonitoringService {
  /** Polling timers started by setupMetrics(); cleared on module destroy. */
  private readonly timers: ReturnType<typeof setInterval>[] = [];

  constructor(
    private readonly redis: Redis,
    private readonly chatGateway: ChatGateway
  ) {}

  /**
   * Nest lifecycle hook. Metrics are wired here rather than in the
   * constructor because the gateway's `@WebSocketServer()` property is
   * assigned only after providers are constructed. Touching
   * `chatGateway.server` in the constructor throws on undefined.
   */
  onModuleInit(): void {
    this.setupMetrics();
  }

  /** Nest lifecycle hook: stop polling so timers don't outlive the app. */
  onModuleDestroy(): void {
    for (const timer of this.timers) {
      clearInterval(timer);
    }
    this.timers.length = 0;
  }

  /**
   * Registers the following metrics:
   * - a connection-count gauge, sampled every 5 s;
   * - a counter of events received from clients;
   * - a Redis round-trip latency gauge, sampled every 10 s.
   */
  private setupMetrics() {
    const metrics = new Metrics();

    // Monitor connection count
    this.timers.push(
      setInterval(() => {
        const count = this.chatGateway.server.engine.clientsCount;
        metrics.gauge('ws_connections', count);
      }, 5000)
    );

    // Monitor message rate. The server object never emits 'message' for
    // client traffic; events arrive on individual sockets, so attach a
    // catch-all listener to each socket as it connects.
    this.chatGateway.server.on('connection', socket => {
      socket.onAny(() => {
        metrics.increment('ws_messages_total');
      });
    });

    // Monitor Redis latency.
    // NOTE(review): this pings the injected client, which is not necessarily
    // the adapter's pub/sub connection.
    this.timers.push(
      setInterval(async () => {
        const start = Date.now();
        try {
          await this.redis.ping();
        } catch {
          // An escaped rejection here would terminate the process on current
          // Node; count the failure instead of recording a latency sample.
          metrics.increment('redis_ping_errors_total');
          return;
        }
        metrics.gauge('redis_adapter_latency', Date.now() - start);
      }, 10000)
    );
  }

  /**
   * Collects a snapshot of this node's stats and stores it as JSON in the
   * `ws:stats` hash under NODE_ID.
   *
   * NOTE(review): NODE_ID is used unchecked; if unset, nodes overwrite each
   * other's entry.
   */
  async getPerformanceStats() {
    const nodeId = process.env.NODE_ID;
    const stats = {
      connections: this.chatGateway.server.engine.clientsCount,
      memory: process.memoryUsage(),
      cpu: process.cpuUsage(),
      uptime: process.uptime(),
      roomCount: await this.getRoomCount(),
      messageRate: await this.getMessageRate()
    };

    // Store stats in Redis
    await this.redis.hset(
      'ws:stats',
      nodeId,
      JSON.stringify(stats)
    );

    return stats;
  }

  /**
   * Number of entries in the adapter's room map.
   *
   * NOTE(review): on a root Server, `sockets` is the default "/" namespace,
   * not "/chat". Socket.IO also puts each socket in a private room named
   * after its id, so this count likely includes those.
   */
  private async getRoomCount(): Promise<number> {
    const rooms = this.chatGateway.server.sockets.adapter.rooms;
    return rooms.size;
  }

  /**
   * Messages per second over the last minute, derived from scores in the
   * `ws:message_timestamps` sorted set.
   *
   * NOTE(review): nothing in this service writes that sorted set.
   */
  private async getMessageRate(): Promise<number> {
    const now = Date.now();
    const minute = 60 * 1000;
    const count = await this.redis.zcount(
      'ws:message_timestamps',
      now - minute,
      now
    );
    return count / 60; // Messages per second
  }
}

Best Practices and Considerations

  • Implement proper authentication and authorization
  • Handle connection state management across nodes
  • Use Redis adapter for horizontal scaling
  • Implement retry mechanisms for failed operations
  • Monitor WebSocket server health and performance
  • Handle backpressure and rate limiting
  • Implement proper error handling and logging
  • Use load balancing with sticky sessions (Socket.IO requires them whenever HTTP long-polling is enabled)

By following these patterns and implementing proper scaling strategies, you can build a robust WebSocket architecture that handles real-time communication efficiently at scale. Remember to regularly monitor your system's performance and adjust your implementation based on actual usage patterns and requirements.