Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 101 additions & 38 deletions websocket/src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ const redisPublisher = new Redis(process.env.REDIS_URL!);
const cooldownCache = new Map<string, number>();
const COOLDOWN_KEY_PREFIX = 'cooldown:edit:';
const HEARTBEAT_INTERVAL = 30_000;
const COOLDOWN_MS = 30_000;
const MAX_MESSAGE_LENGTH = 2000;

[redisSubscriber, redisPublisher].forEach((r) => {
r.on('error', (err) => {
Expand All @@ -20,13 +22,24 @@ const HEARTBEAT_INTERVAL = 30_000;
});
});

redisSubscriber.on('message', (channel, msg) => {
if (!channel.startsWith('updates:')) return;
const articleId = channel.substring(8);
const subscribedArticles = new Set<string>();

Array.from(articleUsers.get(articleId) || [])
.flatMap((userId) => Array.from(userSockets.get(userId) || []))
.forEach((socket) => socket.send(msg));
redisSubscriber.on('message', (channel, msg) => {
try {
if (!channel.startsWith('updates:')) return;
const articleId = channel.substring(8);
Array.from(articleUsers.get(articleId) || [])
.flatMap((userId) => Array.from(userSockets.get(userId) || []))
.forEach((socket) => {
try {
socket.send(msg);
} catch (e) {
console.error('Failed to send redis update to socket', e);
}
});
} catch (e) {
console.error('Error handling redis message:', e);
}
});

type WebSocketData = {
Expand Down Expand Up @@ -105,6 +118,14 @@ async function handleSetArticle(
prevArticle.delete(userId);
if (prevArticle.size === 0) {
articleUsers.delete(ws.data.articleId);
if (subscribedArticles.has(ws.data.articleId)) {
try {
await redisSubscriber.unsubscribe('updates:' + ws.data.articleId);
} catch (e) {
console.error('Failed to unsubscribe from redis channel:', e);
}
subscribedArticles.delete(ws.data.articleId);
}
}
await broadcastUserCount(ws.data.articleId);
}
Expand All @@ -122,7 +143,14 @@ async function handleSetArticle(

if (!articleUsers.has(articleId)) {
articleUsers.set(articleId, new Set([userId]));
await redisSubscriber.subscribe('updates:' + articleId);
if (!subscribedArticles.has(articleId)) {
try {
await redisSubscriber.subscribe('updates:' + articleId);
subscribedArticles.add(articleId);
} catch (e) {
console.error('Failed to subscribe to redis channel:', e);
}
}
await broadcastUserCount(articleId);
} else {
const wasAdded = !articleUsers.get(articleId)!.has(userId);
Expand Down Expand Up @@ -282,19 +310,38 @@ async function handleWordHover(
return;
}

await redisPublisher.publish(
`updates:${articleId}`,
JSON.stringify({
type: 'word_hover',
try {
await redisPublisher.publish(
`updates:${articleId}`,
JSON.stringify({
type: 'word_hover',
data: {
newWord,
wordIndex,
editorId: userId,
editorName: ws.data.user!.name,
editorImage: ws.data.user!.image,
}
})
);
const expiryTime = Date.now() + COOLDOWN_MS;
const cooldownKey = `${COOLDOWN_KEY_PREFIX}${userId}`;
try {
await redisPublisher.set(cooldownKey, expiryTime.toString(), 'PX', COOLDOWN_MS);
cooldownCache.set(userId, expiryTime);
} catch (e) {
console.error('Failed to set cooldown in redis:', e);
}
} catch (e) {
console.error('Failed to publish word_hover:', e);
ws.send(JSON.stringify({
type: 'error',
data: {
newWord,
wordIndex,
editorId: userId,
editorName: ws.data.user!.name,
editorImage: ws.data.user!.image,
code: 'PUBLISH_FAILED',
message: 'Failed to broadcast edit'
}
})
);
}));
}
}

async function handleWordLeave(
Expand All @@ -309,16 +356,20 @@ async function handleWordLeave(
return;
}

await redisPublisher.publish(
`updates:${articleId}`,
JSON.stringify({
type: 'word_leave',
data: {
wordIndex,
editorId: userId
}
})
);
try {
await redisPublisher.publish(
`updates:${articleId}`,
JSON.stringify({
type: 'word_leave',
data: {
wordIndex,
editorId: userId
}
})
);
} catch (e) {
console.error('Failed to publish word_leave:', e);
}
}

const server = Bun.serve<WebSocketData>({
Expand Down Expand Up @@ -360,27 +411,31 @@ const server = Bun.serve<WebSocketData>({
ws.data.lastActivity = Date.now();

if (typeof msg !== 'string') return;
if (msg.length > MAX_MESSAGE_LENGTH) return;

try {
const data = JSON.parse(msg) as { type: string; article?: any; wordIndex?: number; newWord?: string };
switch (data.type) {
const parsed = JSON.parse(msg) as { type: string; article?: any; wordIndex?: number; newWord?: string };
if (!parsed || typeof parsed.type !== 'string') return;
const allowedTypes = new Set(['set_article', 'get_active_articles', 'word_hover', 'word_leave', 'pong']);
if (!allowedTypes.has(parsed.type)) return;
switch (parsed.type) {
case 'set_article':
if (data.article) {
await handleSetArticle(ws, data.article.id);
if (parsed.article && parsed.article.id) {
await handleSetArticle(ws, parsed.article.id);
}
break;
case 'get_active_articles':
await handleGetActiveArticles(ws);
break;
case 'word_hover':
await handleWordHover(ws, {
wordIndex: data.wordIndex!,
newWord: data.newWord!
wordIndex: parsed.wordIndex!,
newWord: parsed.newWord!
});
break;
case 'word_leave':
await handleWordLeave(ws, {
wordIndex: data.wordIndex!
wordIndex: parsed.wordIndex!
});
break;
case 'pong':
Expand All @@ -402,7 +457,7 @@ const server = Bun.serve<WebSocketData>({
}, HEARTBEAT_INTERVAL);
},

close(ws) {
async close(ws) {
if (ws.data.articleId && ws.data.user) {
const userId = ws.data.user.id;

Expand All @@ -423,9 +478,17 @@ const server = Bun.serve<WebSocketData>({
users.delete(userId);
if (users.size === 0) {
articleUsers.delete(ws.data.articleId);
if (subscribedArticles.has(ws.data.articleId)) {
try {
await redisSubscriber.unsubscribe('updates:' + ws.data.articleId);
} catch (e) {
console.error('Failed to unsubscribe from redis channel during close:', e);
}
subscribedArticles.delete(ws.data.articleId);
}
}

broadcastUserCount(ws.data.articleId);
await broadcastUserCount(ws.data.articleId);

const message = JSON.stringify({
type: 'user_disconnected',
Expand Down Expand Up @@ -453,4 +516,4 @@ const server = Bun.serve<WebSocketData>({
}
});

console.log(`WebSocket server is running on port ${server.port}`);
console.log(`WebSocket server is running on port ${server.port}`);