22import json
33import logging
44import os
5+ import threading
56import time
67from contextlib import asynccontextmanager
8+ from queue import Queue
79from typing import TYPE_CHECKING , Any , List , Optional
810
911import uvicorn
1012from fastapi import FastAPI , WebSocket , WebSocketDisconnect
1113
1214from eval_protocol .dataset_logger import default_logger
15+ from eval_protocol .dataset_logger .dataset_logger import LOG_EVENT_TYPE
1316from eval_protocol .event_bus import event_bus
1417from eval_protocol .utils .vite_server import ViteServer
1518
@@ -24,54 +27,85 @@ class WebSocketManager:
2427
2528 def __init__ (self ):
2629 self .active_connections : List [WebSocket ] = []
27- self ._loop = None
30+ self ._broadcast_queue : Queue = Queue ()
31+ self ._broadcast_task : Optional [asyncio .Task ] = None
32+ self ._lock = threading .Lock ()
2833
2934 async def connect (self , websocket : WebSocket ):
3035 await websocket .accept ()
31- self .active_connections .append (websocket )
32- logger .info (f"WebSocket connected. Total connections: { len (self .active_connections )} " )
36+ with self ._lock :
37+ self .active_connections .append (websocket )
38+ connection_count = len (self .active_connections )
39+ logger .info (f"WebSocket connected. Total connections: { connection_count } " )
3340 logs = default_logger .read ()
34- asyncio .run_coroutine_threadsafe (
35- websocket .send_text (
36- json .dumps (
37- {
38- "type" : "initialize_logs" ,
39- "logs" : [log .model_dump_json (exclude_none = True ) for log in logs ],
40- }
41- )
42- ),
43- self ._loop ,
41+ await websocket .send_text (
42+ json .dumps ({"type" : "initialize_logs" , "logs" : [log .model_dump_json (exclude_none = True ) for log in logs ]})
4443 )
4544
4645 def disconnect (self , websocket : WebSocket ):
47- if websocket in self .active_connections :
48- self .active_connections .remove (websocket )
49- logger .info (f"WebSocket disconnected. Total connections: { len (self .active_connections )} " )
46+ with self ._lock :
47+ if websocket in self .active_connections :
48+ self .active_connections .remove (websocket )
49+ connection_count = len (self .active_connections )
50+ logger .info (f"WebSocket disconnected. Total connections: { connection_count } " )
5051
5152 def broadcast_row_upserted (self , row : "EvaluationRow" ):
5253 """Broadcast a row-upsert event to all connected clients.
5354
5455 Safe no-op if server loop is not running or there are no connections.
5556 """
56- if not self ._loop or not self .active_connections :
57- return
58-
5957 try :
6058 # Serialize pydantic model
6159 json_message = json .dumps ({"type" : "log" , "row" : json .loads (row .model_dump_json (exclude_none = True ))})
60+ # Queue the message for broadcasting in the main event loop
61+ self ._broadcast_queue .put (json_message )
6262 except Exception as e :
6363 logger .error (f"Failed to serialize row for broadcast: { e } " )
64+
65+ async def _start_broadcast_loop (self ):
66+ """Start the broadcast loop that processes queued messages."""
67+ while True :
68+ try :
69+ # Wait for a message to be queued
70+ message = await asyncio .get_event_loop ().run_in_executor (None , self ._broadcast_queue .get )
71+ await self ._send_text_to_all_connections (message )
72+ except Exception as e :
73+ logger .error (f"Error in broadcast loop: { e } " )
74+ await asyncio .sleep (0.1 )
75+ except asyncio .CancelledError :
76+ logger .info ("Broadcast loop cancelled" )
77+ break
78+
79+ async def _send_text_to_all_connections (self , text : str ):
80+ with self ._lock :
81+ connections = list (self .active_connections )
82+
83+ if not connections :
6484 return
6585
66- for connection in list (self .active_connections ):
86+ tasks = []
87+ for connection in connections :
6788 try :
68- asyncio . run_coroutine_threadsafe (connection .send_text (json_message ), self . _loop )
89+ tasks . append (connection .send_text (text ) )
6990 except Exception as e :
70- logger .error (f"Failed to send row_upserted to WebSocket: { e } " )
71- try :
72- self .active_connections .remove (connection )
73- except ValueError :
74- pass
91+ logger .error (f"Failed to send text to WebSocket: { e } " )
92+ with self ._lock :
93+ try :
94+ self .active_connections .remove (connection )
95+ except ValueError :
96+ pass
97+ if tasks :
98+ await asyncio .gather (* tasks , return_exceptions = True )
99+
100+ def start_broadcast_loop (self ):
101+ """Start the broadcast loop in the current event loop."""
102+ if self ._broadcast_task is None or self ._broadcast_task .done ():
103+ self ._broadcast_task = asyncio .create_task (self ._start_broadcast_loop ())
104+
105+ def stop_broadcast_loop (self ):
106+ """Stop the broadcast loop."""
107+ if self ._broadcast_task and not self ._broadcast_task .done ():
108+ self ._broadcast_task .cancel ()
75109
76110
77111class LogsServer (ViteServer ):
@@ -95,12 +129,7 @@ def __init__(
95129 # Initialize WebSocket manager
96130 self .websocket_manager = WebSocketManager ()
97131
98- @asynccontextmanager
99- async def lifespan (app : FastAPI ):
100- self .websocket_manager ._loop = asyncio .get_running_loop ()
101- yield
102-
103- super ().__init__ (build_dir , host , port , index_file , lifespan = lifespan )
132+ super ().__init__ (build_dir , host , port , index_file )
104133
105134 # Add WebSocket endpoint
106135 self ._setup_websocket_routes ()
@@ -130,16 +159,18 @@ async def websocket_endpoint(websocket: WebSocket):
130159 @self .app .get ("/api/status" )
131160 async def status ():
132161 """Get server status including active connections."""
162+ with self .websocket_manager ._lock :
163+ active_connections_count = len (self .websocket_manager .active_connections )
133164 return {
134165 "status" : "ok" ,
135166 "build_dir" : str (self .build_dir ),
136- "active_connections" : len ( self . websocket_manager . active_connections ) ,
167+ "active_connections" : active_connections_count ,
137168 "watch_paths" : self .watch_paths ,
138169 }
139170
140171 def _handle_event (self , event_type : str , data : Any ) -> None :
141172 """Handle events from the event bus."""
142- if event_type in ["log" ]:
173+ if event_type in [LOG_EVENT_TYPE ]:
143174 from eval_protocol .models import EvaluationRow
144175
145176 data = EvaluationRow (** data )
@@ -157,8 +188,8 @@ async def run_async(self):
157188 logger .info (f"Serving files from: { self .build_dir } " )
158189 logger .info ("WebSocket endpoint available at /ws" )
159190
160- # Store the event loop for WebSocket manager
161- self .websocket_manager ._loop = asyncio . get_running_loop ()
191+ # Start the broadcast loop
192+ self .websocket_manager .start_broadcast_loop ()
162193
163194 config = uvicorn .Config (
164195 self .app ,
@@ -172,6 +203,9 @@ async def run_async(self):
172203
173204 except KeyboardInterrupt :
174205 logger .info ("Shutting down LogsServer..." )
206+ finally :
207+ # Clean up broadcast loop
208+ self .websocket_manager .stop_broadcast_loop ()
175209
176210 def run (self ):
177211 """
0 commit comments