Building scalable real-time applications requires careful consideration of your WebSocket implementation, especially in a multi-node environment. In this guide, we'll explore how to create a robust WebSocket architecture using NestJS and Socket.IO with the Redis adapter for horizontal scaling.
Setting Up WebSocket Gateway
// chat/chat.gateway.ts
import {
WebSocketGateway,
WebSocketServer,
SubscribeMessage,
OnGatewayConnection,
OnGatewayDisconnect,
OnGatewayInit
} from '@nestjs/websockets';
import { Server, Socket } from 'socket.io';
import { createAdapter } from '@socket.io/redis-adapter';
import { Redis } from 'ioredis';
@WebSocketGateway({
cors: {
origin: process.env.FRONTEND_URL,
credentials: true
},
namespace: '/chat'
})
export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect {
@WebSocketServer() server: Server;
private pubClient: Redis;
private subClient: Redis;
async afterInit() {
// The server injected by @WebSocketServer() only becomes available after
// gateway initialization, so the Redis adapter is wired up here instead
// of in the constructor.
this.pubClient = new Redis({
host: process.env.REDIS_HOST,
port: parseInt(process.env.REDIS_PORT ?? '6379', 10),
password: process.env.REDIS_PASSWORD
});
this.subClient = this.pubClient.duplicate();
// Wait for both Redis connections (ioredis queues commands until ready)
await Promise.all([this.pubClient.ping(), this.subClient.ping()]);
// Set up the Redis adapter so emits reach clients on every node.
// Alternatively (and per the NestJS docs), the adapter can be configured
// once for the whole app via a custom IoAdapter and app.useWebSocketAdapter().
this.server.adapter(createAdapter(this.pubClient, this.subClient));
}
async handleConnection(client: Socket) {
const userId = this.getUserIdFromToken(client.handshake.auth.token);
if (!userId) {
client.disconnect();
return;
}
await this.joinUserRooms(client, userId);
this.server.emit('userConnected', { userId });
}
async handleDisconnect(client: Socket) {
const userId = this.getUserIdFromToken(client.handshake.auth.token);
if (userId) {
this.server.emit('userDisconnected', { userId });
}
}
@SubscribeMessage('sendMessage')
async handleMessage(client: Socket, payload: {
roomId: string;
message: string;
}) {
const userId = this.getUserIdFromToken(client.handshake.auth.token);
if (!userId) return;
const message = {
id: generateUniqueId(),
userId,
roomId: payload.roomId,
content: payload.message,
timestamp: new Date()
};
await this.saveMessage(message);
this.server.to(payload.roomId).emit('newMessage', message);
}
private async joinUserRooms(client: Socket, userId: string) {
const rooms = await this.getUserRooms(userId);
rooms.forEach(room => client.join(room.id));
}
private async saveMessage(message: any) {
// Implement message persistence logic
}
private async getUserRooms(userId: string) {
// Implement room retrieval logic
return [];
}
private getUserIdFromToken(token: string): string | null {
// Implement token verification logic
return null;
}
}
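Before moving on, it helps to see how these pieces might be registered together. The module below is a sketch rather than part of the original code: the entity file paths and the Redis provider registration are assumptions, shown only because the services introduced in the following sections inject Redis and TypeORM repositories directly.
// chat/chat.module.ts -- a minimal wiring sketch; file paths, entity names,
// and the Redis provider shape are assumptions
import { Module } from '@nestjs/common';
import { TypeOrmModule } from '@nestjs/typeorm';
import { Redis } from 'ioredis';
import { ChatGateway } from './chat.gateway';
import { ChatService } from './chat.service';
import { RoomService } from './room.service';
import { WebSocketHealthService } from './health.service';
import { Room } from './room.entity';
import { Message } from './message.entity';

@Module({
  imports: [TypeOrmModule.forFeature([Room, Message])],
  providers: [
    ChatGateway,
    ChatService,
    RoomService,
    WebSocketHealthService,
    {
      // The services in the next sections inject Redis directly, so a client
      // has to be registered as a provider somewhere; this is one possible shape
      provide: Redis,
      useFactory: () =>
        new Redis({
          host: process.env.REDIS_HOST,
          port: parseInt(process.env.REDIS_PORT ?? '6379', 10),
          password: process.env.REDIS_PASSWORD
        })
    }
  ]
})
export class ChatModule {}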
Implementing Room Management
// chat/room.service.ts
@Injectable()
export class RoomService {
constructor(
@InjectRepository(Room)
private roomRepository: Repository<Room>,
private readonly redis: Redis,
// ModuleRef (from @nestjs/core) is used by addUserToRoom() to look up the gateway
private readonly moduleRef: ModuleRef
) {}
async createRoom(data: CreateRoomDto) {
const room = await this.roomRepository.save({
...data,
id: generateUniqueId()
});
// Store room info in Redis for quick access
await this.redis.hset(
`room:${room.id}`,
{
name: room.name,
type: room.type,
createdAt: room.createdAt.toISOString()
}
);
return room;
}
async addUserToRoom(roomId: string, userId: string) {
await this.roomRepository
.createQueryBuilder()
.relation(Room, 'members')
.of(roomId)
.add(userId);
// Add user to room in Redis
await this.redis.sadd(`room:${roomId}:members`, userId);
// Notify room members
const gateway = this.moduleRef.get(ChatGateway);
gateway.server.to(roomId).emit('userJoinedRoom', { roomId, userId });
}
async getRoomMembers(roomId: string): Promise<string[]> {
// Try Redis first
const members = await this.redis.smembers(`room:${roomId}:members`);
if (members.length > 0) {
return members;
}
// Fallback to database
const room = await this.roomRepository.findOne({
where: { id: roomId },
relations: ['members']
});
if (!room) {
return [];
}
const memberIds = room.members.map(member => member.id);
// Cache in Redis (SADD needs at least one member)
if (memberIds.length > 0) {
await this.redis.sadd(`room:${roomId}:members`, ...memberIds);
}
return memberIds;
}
}
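The Room entity and CreateRoomDto used above never appear in the original snippets. Here is a minimal sketch of what they might look like, assuming TypeORM and class-validator; any field beyond name, type, createdAt, and members is an assumption.
// chat/room.entity.ts and chat/dto/create-room.dto.ts -- assumed shapes,
// shown together here for brevity
import {
  Entity,
  PrimaryColumn,
  Column,
  CreateDateColumn,
  ManyToMany,
  JoinTable
} from 'typeorm';
import { IsIn, IsString, Length } from 'class-validator';
import { User } from '../users/user.entity'; // assumed to exist elsewhere

@Entity()
export class Room {
  @PrimaryColumn()
  id: string;

  @Column()
  name: string;

  @Column()
  type: string; // e.g. 'direct' | 'group'

  @CreateDateColumn()
  createdAt: Date;

  // The 'members' relation used by addUserToRoom() and getRoomMembers()
  @ManyToMany(() => User)
  @JoinTable()
  members: User[];
}

export class CreateRoomDto {
  @IsString()
  @Length(1, 100)
  name: string;

  @IsIn(['direct', 'group'])
  type: string;
}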
Implementing Message Broadcasting
// chat/chat.service.ts
@Injectable()
export class ChatService {
constructor(
@InjectRepository(Message)
private messageRepository: Repository<Message>,
private readonly redis: Redis,
// ModuleRef is used below to resolve the gateway without a circular import
private readonly moduleRef: ModuleRef
) {}
async broadcastToRoom(roomId: string, event: string, data: any) {
const gateway = this.moduleRef.get(ChatGateway);
gateway.server.to(roomId).emit(event, data);
// Store event in Redis for recovery/history
await this.redis.lpush(
`room:${roomId}:events`,
JSON.stringify({
event,
data,
timestamp: Date.now()
})
);
// Trim event history to last 100 events
await this.redis.ltrim(`room:${roomId}:events`, 0, 99);
}
async getLatestMessages(roomId: string, limit: number = 50) {
const messages = await this.messageRepository.find({
where: { roomId },
order: { timestamp: 'DESC' },
take: limit
});
// Cache messages in Redis
await this.redis.setex(
`room:${roomId}:messages`,
3600, // 1 hour
JSON.stringify(messages)
);
return messages;
}
}
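The service above writes messages and events into Redis, but nothing in the snippets reads them back. Two methods that could be added to ChatService for that purpose are sketched below; the method names are assumptions.
// chat/chat.service.ts (continued) -- possible read paths for the cached data
async getCachedMessages(roomId: string, limit = 50) {
  // Serve from the cache written by getLatestMessages() when possible
  const cached = await this.redis.get(`room:${roomId}:messages`);
  if (cached) {
    return JSON.parse(cached);
  }
  // Cache miss: fall back to the database, which also repopulates the cache
  return this.getLatestMessages(roomId, limit);
}

async getRecentEvents(roomId: string, count = 100) {
  // Events are LPUSH-ed, so index 0 is the most recent one
  const raw = await this.redis.lrange(`room:${roomId}:events`, 0, count - 1);
  return raw.map(entry => JSON.parse(entry));
}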
Implementing Health Checks and Monitoring
// chat/health.service.ts
@Injectable()
export class WebSocketHealthService {
constructor(
private readonly redis: Redis,
@Inject(forwardRef(() => ChatGateway))
private readonly chatGateway: ChatGateway
) {}
async checkHealth() {
const nodeId = process.env.NODE_ID;
const timestamp = Date.now();
// Update node heartbeat
await this.redis.hset(
'ws:nodes',
nodeId,
JSON.stringify({
timestamp,
connections: this.getConnectionCount(),
memory: process.memoryUsage()
})
);
// Clean up stale nodes
await this.cleanupStaleNodes();
}
private getConnectionCount(): number {
return this.chatGateway.server.engine.clientsCount;
}
private async cleanupStaleNodes() {
const nodes = await this.redis.hgetall('ws:nodes');
const staleThreshold = Date.now() - 30000; // 30 seconds
for (const [nodeId, data] of Object.entries(nodes)) {
const nodeData = JSON.parse(data);
if (nodeData.timestamp < staleThreshold) {
await this.redis.hdel('ws:nodes', nodeId);
}
}
}
}
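checkHealth() has to run periodically for the heartbeat data to stay fresh. One way to do that, assuming @nestjs/schedule is installed and ScheduleModule.forRoot() is registered, is a small scheduler provider:
// chat/health.scheduler.ts -- a scheduling sketch; the 10-second interval is
// an assumption chosen to stay well inside the 30-second staleness threshold
import { Injectable } from '@nestjs/common';
import { Interval } from '@nestjs/schedule';
import { WebSocketHealthService } from './health.service';

@Injectable()
export class WebSocketHealthScheduler {
  constructor(private readonly health: WebSocketHealthService) {}

  // Refresh this node's heartbeat and prune stale nodes every 10 seconds
  @Interval(10000)
  async tick() {
    await this.health.checkHealth();
  }
}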
Load Balancing Configuration
// nginx.conf
http {
upstream websocket_nodes {
# Enable sticky sessions based on IP
ip_hash;
server ws1.example.com:3000;
server ws2.example.com:3000;
server ws3.example.com:3000;
}
server {
listen 80;
server_name ws.example.com;
location /socket.io/ {
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header Host $host;
proxy_pass http://websocket_nodes;
# WebSocket specific settings
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection 'upgrade';
proxy_cache_bypass $http_upgrade;
# Timeouts
proxy_read_timeout 600s;
proxy_send_timeout 600s;
}
}
}
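On the client side, connections go through the load balancer rather than to an individual node. A minimal socket.io-client sketch is shown below; the URL, token storage, and room id are assumptions.
// client/chat-client.ts -- a client connection sketch (URL, token storage,
// and room id are assumptions)
import { io } from 'socket.io-client';

const socket = io('http://ws.example.com/chat', {
  // Matches the token the gateway reads from client.handshake.auth.token
  auth: { token: localStorage.getItem('accessToken') },
  // Default Socket.IO transports; the long-polling phase is why the nginx
  // config above relies on ip_hash sticky sessions
  transports: ['polling', 'websocket'],
  reconnection: true,
  reconnectionDelay: 1000
});

socket.on('newMessage', (message) => {
  console.log('message received', message);
});

socket.emit('sendMessage', { roomId: 'general', message: 'hello' });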
Scaling WebSocket Connections
When scaling WebSocket applications across multiple nodes, several challenges need to be handled, including session management, cross-node broadcasting, and connection-state synchronization.
// chat/scaling.service.ts
interface ConnectionStats {
totalNodes: number;
totalConnections: number;
nodesStats: { nodeId: string; connections: number }[];
}
@Injectable()
export class ScalingService implements OnModuleInit {
constructor(
private readonly redis: Redis,
private readonly chatGateway: ChatGateway
) {}
async onModuleInit() {
// Async setup runs from a lifecycle hook rather than the constructor
// so that failures are awaited instead of silently swallowed
await this.setupNodeDiscovery();
}
private async setupNodeDiscovery() {
const nodeId = process.env.NODE_ID;
// Register node
await this.redis.sadd('ws:active_nodes', nodeId);
// Handle node shutdown
process.on('SIGTERM', async () => {
await this.redis.srem('ws:active_nodes', nodeId);
await this.redis.del(`ws:node:${nodeId}:connections`);
process.exit(0);
});
// Setup periodic health check
setInterval(async () => {
await this.redis.setex(
`ws:node:${nodeId}:heartbeat`,
30, // TTL 30 seconds
Date.now().toString()
);
}, 15000);
}
async handleGlobalBroadcast(event: string, data: any) {
// With the Redis adapter in place, a plain server.emit() already reaches
// clients connected to every node
this.chatGateway.server.emit(event, data);
// Additionally publish on a Redis channel so other nodes can react in
// application code (see the subscriber sketch after this service)
await this.redis.publish('ws:broadcast',
JSON.stringify({ event, data, nodeId: process.env.NODE_ID })
);
}
async getGlobalConnectionStats(): Promise<ConnectionStats> {
const nodes = await this.redis.smembers('ws:active_nodes');
const stats = await Promise.all(
nodes.map(async nodeId => {
const connections = await this.redis.get(
`ws:node:${nodeId}:connections`
);
return {
nodeId,
connections: parseInt(connections || '0')
};
})
);
return {
totalNodes: nodes.length,
totalConnections: stats.reduce(
(sum, node) => sum + node.connections,
0
),
nodesStats: stats
};
}
}
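handleGlobalBroadcast() publishes on the ws:broadcast channel, but no subscriber appears in the original code. The sketch below shows methods that could be added to ScalingService (and called from onModuleInit) to consume that channel; the onRemoteBroadcast() hook is a placeholder and an assumption.
// chat/scaling.service.ts (continued) -- a possible subscriber for ws:broadcast
private async setupBroadcastSubscriber() {
  // A Redis connection in subscriber mode cannot issue other commands,
  // so a dedicated connection is duplicated just for subscribing
  const subscriber = this.redis.duplicate();
  await subscriber.subscribe('ws:broadcast');
  subscriber.on('message', (_channel, raw) => {
    const { event, data, nodeId } = JSON.parse(raw);
    // Ignore messages this node published itself
    if (nodeId === process.env.NODE_ID) return;
    // Client delivery is already handled by the Redis adapter; use this hook
    // for node-local side effects such as cache invalidation or metrics
    this.onRemoteBroadcast(event, data);
  });
}

private onRemoteBroadcast(event: string, data: any) {
  // Placeholder for node-local reactions to a cross-node broadcast
}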
Implementing Failover and Recovery
// chat/recovery.service.ts
@Injectable()
export class RecoveryService {
constructor(
private readonly redis: Redis,
private readonly chatGateway: ChatGateway
) {}
async handleNodeFailure(failedNodeId: string) {
// Get all rooms affected by the node failure
const affectedRooms = await this.redis.smembers(
`ws:node:${failedNodeId}:rooms`
);
// Redistribute room responsibilities
for (const roomId of affectedRooms) {
await this.redistributeRoom(roomId);
}
// Clean up failed node data
await this.redis.del(`ws:node:${failedNodeId}:rooms`);
await this.redis.srem('ws:active_nodes', failedNodeId);
}
private async redistributeRoom(roomId: string) {
const members = await this.redis.smembers(
`room:${roomId}:members`
);
// Notify members to reconnect
this.chatGateway.server.to(roomId).emit('reconnectRequired', {
roomId,
reason: 'node_failure'
});
// Update room state
await this.ensureRoomStateConsistency(roomId, members);
}
private async ensureRoomStateConsistency(
roomId: string,
members: string[]
) {
// Implement room state recovery logic
// This might involve re-establishing subscriptions,
// recovering missed messages, etc.
}
}
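Nothing in the snippets actually calls handleNodeFailure(). A failure detector could compare the active-node set against the heartbeat keys written by ScalingService (ws:node:<id>:heartbeat expires 30 seconds after the last refresh). The watchdog below is a sketch under that assumption; in production you would also want a distributed lock so only one surviving node performs the recovery.
// chat/recovery.watchdog.ts -- a failure-detection sketch built on the
// heartbeat keys written by ScalingService
import { Injectable, OnModuleInit } from '@nestjs/common';
import { Redis } from 'ioredis';
import { RecoveryService } from './recovery.service';

@Injectable()
export class RecoveryWatchdog implements OnModuleInit {
  constructor(
    private readonly redis: Redis,
    private readonly recoveryService: RecoveryService
  ) {}

  onModuleInit() {
    // Check for dead nodes every 15 seconds
    setInterval(() => this.detectFailedNodes().catch(console.error), 15000);
  }

  private async detectFailedNodes() {
    const nodes = await this.redis.smembers('ws:active_nodes');
    for (const nodeId of nodes) {
      // The heartbeat key expires 30 seconds after the node's last refresh,
      // so a missing key means the node has stopped responding
      const alive = await this.redis.exists(`ws:node:${nodeId}:heartbeat`);
      if (!alive && nodeId !== process.env.NODE_ID) {
        await this.recoveryService.handleNodeFailure(nodeId);
      }
    }
  }
}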
Performance Monitoring and Optimization
// chat/monitoring.service.ts
@Injectable()
export class MonitoringService {
constructor(
private readonly redis: Redis,
private readonly chatGateway: ChatGateway
) {
this.setupMetrics();
}
private setupMetrics() {
// "Metrics" is a stand-in for the project's metrics client
// (for example a prom-client wrapper exposing gauge()/increment())
const metrics = new Metrics();
// Monitor connection count
setInterval(() => {
const count = this.chatGateway.server.engine.clientsCount;
metrics.gauge('ws_connections', count);
}, 5000);
// Monitor message rate: the Socket.IO server itself does not emit a
// 'message' event, so count inbound events on each connected socket instead
this.chatGateway.server.on('connection', (socket) => {
socket.onAny(() => metrics.increment('ws_messages_total'));
});
// Monitor Redis adapter latency
setInterval(async () => {
const start = Date.now();
await this.redis.ping();
const latency = Date.now() - start;
metrics.gauge('redis_adapter_latency', latency);
}, 10000);
}
async getPerformanceStats() {
const nodeId = process.env.NODE_ID;
const stats = {
connections: this.chatGateway.server.engine.clientsCount,
memory: process.memoryUsage(),
cpu: process.cpuUsage(),
uptime: process.uptime(),
roomCount: await this.getRoomCount(),
messageRate: await this.getMessageRate()
};
// Store stats in Redis
await this.redis.hset(
'ws:stats',
nodeId,
JSON.stringify(stats)
);
return stats;
}
private async getRoomCount(): Promise<number> {
// adapter.rooms is a synchronous Map; note that it also contains one
// implicit room per connected socket id
const rooms = this.chatGateway.server.sockets.adapter.rooms;
return rooms.size;
}
private async getMessageRate(): Promise<number> {
const now = Date.now();
const minute = 60 * 1000;
const count = await this.redis.zcount(
'ws:message_timestamps',
now - minute,
now
);
return count / 60; // Messages per second
}
}
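getMessageRate() reads the ws:message_timestamps sorted set, but nothing in the snippets writes to it. A recorder like the one below could be added to MonitoringService and called from the gateway's handleMessage(); the retention window is an assumption.
// chat/monitoring.service.ts (continued) -- populating ws:message_timestamps
async recordMessage() {
  const now = Date.now();
  // Score and member are both timestamps; a random suffix keeps members
  // unique when several messages arrive in the same millisecond
  await this.redis.zadd(
    'ws:message_timestamps',
    now,
    `${now}:${Math.random().toString(36).slice(2)}`
  );
  // Drop entries older than five minutes so the set stays small
  await this.redis.zremrangebyscore(
    'ws:message_timestamps',
    0,
    now - 5 * 60 * 1000
  );
}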
Best Practices and Considerations
- Implement proper authentication and authorization
- Handle connection state management across nodes
- Use Redis adapter for horizontal scaling
- Implement retry mechanisms for failed operations
- Monitor WebSocket server health and performance
- Handle backpressure and rate limiting (see the rate-limiter sketch after this list)
- Implement proper error handling and logging
- Use proper load balancing with sticky sessions
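As an example of the rate-limiting point above, a simple fixed-window limiter can be built on Redis INCR and EXPIRE. The sketch below is one possible shape; handleMessage() would call allow(userId) before persisting and broadcasting, and drop or reject messages once the limit is hit.
// chat/rate-limiter.ts -- a fixed-window rate limiter sketch; the limit,
// window, and key prefix are assumptions
import { Injectable } from '@nestjs/common';
import { Redis } from 'ioredis';

@Injectable()
export class MessageRateLimiter {
  constructor(private readonly redis: Redis) {}

  // Returns true while the user stays under `limit` messages per `windowSeconds`
  async allow(userId: string, limit = 20, windowSeconds = 10): Promise<boolean> {
    const key = `ratelimit:${userId}`;
    const count = await this.redis.incr(key);
    if (count === 1) {
      // The first message in the window starts the expiry clock
      await this.redis.expire(key, windowSeconds);
    }
    return count <= limit;
  }
}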
By following these patterns and implementing proper scaling strategies, you can build a robust WebSocket architecture that handles real-time communication efficiently at scale. Remember to regularly monitor your system's performance and adjust your implementation based on actual usage patterns and requirements.