Skip to content

Commit 4dea043

Browse files
author
Dylan Huang
committed
works with broadcast queue
1 parent f8398de commit 4dea043

1 file changed

Lines changed: 70 additions & 36 deletions

File tree

eval_protocol/utils/logs_server.py

Lines changed: 70 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22
import json
33
import logging
44
import os
5+
import threading
56
import time
67
from contextlib import asynccontextmanager
8+
from queue import Queue
79
from typing import TYPE_CHECKING, Any, List, Optional
810

911
import uvicorn
1012
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
1113

1214
from eval_protocol.dataset_logger import default_logger
15+
from eval_protocol.dataset_logger.dataset_logger import LOG_EVENT_TYPE
1316
from eval_protocol.event_bus import event_bus
1417
from eval_protocol.utils.vite_server import ViteServer
1518

@@ -24,54 +27,85 @@ class WebSocketManager:
2427

2528
def __init__(self):
2629
self.active_connections: List[WebSocket] = []
27-
self._loop = None
30+
self._broadcast_queue: Queue = Queue()
31+
self._broadcast_task: Optional[asyncio.Task] = None
32+
self._lock = threading.Lock()
2833

2934
async def connect(self, websocket: WebSocket):
3035
await websocket.accept()
31-
self.active_connections.append(websocket)
32-
logger.info(f"WebSocket connected. Total connections: {len(self.active_connections)}")
36+
with self._lock:
37+
self.active_connections.append(websocket)
38+
connection_count = len(self.active_connections)
39+
logger.info(f"WebSocket connected. Total connections: {connection_count}")
3340
logs = default_logger.read()
34-
asyncio.run_coroutine_threadsafe(
35-
websocket.send_text(
36-
json.dumps(
37-
{
38-
"type": "initialize_logs",
39-
"logs": [log.model_dump_json(exclude_none=True) for log in logs],
40-
}
41-
)
42-
),
43-
self._loop,
41+
await websocket.send_text(
42+
json.dumps({"type": "initialize_logs", "logs": [log.model_dump_json(exclude_none=True) for log in logs]})
4443
)
4544

4645
def disconnect(self, websocket: WebSocket):
47-
if websocket in self.active_connections:
48-
self.active_connections.remove(websocket)
49-
logger.info(f"WebSocket disconnected. Total connections: {len(self.active_connections)}")
46+
with self._lock:
47+
if websocket in self.active_connections:
48+
self.active_connections.remove(websocket)
49+
connection_count = len(self.active_connections)
50+
logger.info(f"WebSocket disconnected. Total connections: {connection_count}")
5051

5152
def broadcast_row_upserted(self, row: "EvaluationRow"):
5253
"""Broadcast a row-upsert event to all connected clients.
5354
5455
Safe no-op if server loop is not running or there are no connections.
5556
"""
56-
if not self._loop or not self.active_connections:
57-
return
58-
5957
try:
6058
# Serialize pydantic model
6159
json_message = json.dumps({"type": "log", "row": json.loads(row.model_dump_json(exclude_none=True))})
60+
# Queue the message for broadcasting in the main event loop
61+
self._broadcast_queue.put(json_message)
6262
except Exception as e:
6363
logger.error(f"Failed to serialize row for broadcast: {e}")
64+
65+
async def _start_broadcast_loop(self):
66+
"""Start the broadcast loop that processes queued messages."""
67+
while True:
68+
try:
69+
# Wait for a message to be queued
70+
message = await asyncio.get_event_loop().run_in_executor(None, self._broadcast_queue.get)
71+
await self._send_text_to_all_connections(message)
72+
except Exception as e:
73+
logger.error(f"Error in broadcast loop: {e}")
74+
await asyncio.sleep(0.1)
75+
except asyncio.CancelledError:
76+
logger.info("Broadcast loop cancelled")
77+
break
78+
79+
async def _send_text_to_all_connections(self, text: str):
80+
with self._lock:
81+
connections = list(self.active_connections)
82+
83+
if not connections:
6484
return
6585

66-
for connection in list(self.active_connections):
86+
tasks = []
87+
for connection in connections:
6788
try:
68-
asyncio.run_coroutine_threadsafe(connection.send_text(json_message), self._loop)
89+
tasks.append(connection.send_text(text))
6990
except Exception as e:
70-
logger.error(f"Failed to send row_upserted to WebSocket: {e}")
71-
try:
72-
self.active_connections.remove(connection)
73-
except ValueError:
74-
pass
91+
logger.error(f"Failed to send text to WebSocket: {e}")
92+
with self._lock:
93+
try:
94+
self.active_connections.remove(connection)
95+
except ValueError:
96+
pass
97+
if tasks:
98+
await asyncio.gather(*tasks, return_exceptions=True)
99+
100+
def start_broadcast_loop(self):
101+
"""Start the broadcast loop in the current event loop."""
102+
if self._broadcast_task is None or self._broadcast_task.done():
103+
self._broadcast_task = asyncio.create_task(self._start_broadcast_loop())
104+
105+
def stop_broadcast_loop(self):
106+
"""Stop the broadcast loop."""
107+
if self._broadcast_task and not self._broadcast_task.done():
108+
self._broadcast_task.cancel()
75109

76110

77111
class LogsServer(ViteServer):
@@ -95,12 +129,7 @@ def __init__(
95129
# Initialize WebSocket manager
96130
self.websocket_manager = WebSocketManager()
97131

98-
@asynccontextmanager
99-
async def lifespan(app: FastAPI):
100-
self.websocket_manager._loop = asyncio.get_running_loop()
101-
yield
102-
103-
super().__init__(build_dir, host, port, index_file, lifespan=lifespan)
132+
super().__init__(build_dir, host, port, index_file)
104133

105134
# Add WebSocket endpoint
106135
self._setup_websocket_routes()
@@ -130,16 +159,18 @@ async def websocket_endpoint(websocket: WebSocket):
130159
@self.app.get("/api/status")
131160
async def status():
132161
"""Get server status including active connections."""
162+
with self.websocket_manager._lock:
163+
active_connections_count = len(self.websocket_manager.active_connections)
133164
return {
134165
"status": "ok",
135166
"build_dir": str(self.build_dir),
136-
"active_connections": len(self.websocket_manager.active_connections),
167+
"active_connections": active_connections_count,
137168
"watch_paths": self.watch_paths,
138169
}
139170

140171
def _handle_event(self, event_type: str, data: Any) -> None:
141172
"""Handle events from the event bus."""
142-
if event_type in ["log"]:
173+
if event_type in [LOG_EVENT_TYPE]:
143174
from eval_protocol.models import EvaluationRow
144175

145176
data = EvaluationRow(**data)
@@ -157,8 +188,8 @@ async def run_async(self):
157188
logger.info(f"Serving files from: {self.build_dir}")
158189
logger.info("WebSocket endpoint available at /ws")
159190

160-
# Store the event loop for WebSocket manager
161-
self.websocket_manager._loop = asyncio.get_running_loop()
191+
# Start the broadcast loop
192+
self.websocket_manager.start_broadcast_loop()
162193

163194
config = uvicorn.Config(
164195
self.app,
@@ -172,6 +203,9 @@ async def run_async(self):
172203

173204
except KeyboardInterrupt:
174205
logger.info("Shutting down LogsServer...")
206+
finally:
207+
# Clean up broadcast loop
208+
self.websocket_manager.stop_broadcast_loop()
175209

176210
def run(self):
177211
"""

0 commit comments

Comments
 (0)