@@ -45,6 +45,9 @@ def enable_debug_mode():
4545 print ("Debug mode enabled for all relevant loggers" )
4646
4747
48+ DEFAULT_MAX_LOGS_LIMIT = 1000 # Default limit for initial log load to prevent slowdowns
49+
50+
4851class WebSocketManager :
4952 """Manages WebSocket connections and broadcasts messages."""
5053
@@ -53,17 +56,42 @@ def __init__(self):
5356 self ._broadcast_queue : Queue = Queue ()
5457 self ._broadcast_task : Optional [asyncio .Task ] = None
5558 self ._lock = threading .Lock ()
59+ # Track which invocation_ids each connection is subscribed to (None = all)
60+ self ._connection_filters : Dict [WebSocket , Optional [List [str ]]] = {}
5661
57- async def connect (self , websocket : WebSocket ):
62+ async def connect (
63+ self ,
64+ websocket : WebSocket ,
65+ invocation_ids : Optional [List [str ]] = None ,
66+ limit : Optional [int ] = None ,
67+ ):
68+ """
69+ Connect a WebSocket client and send initial logs.
70+
71+ Args:
72+ websocket: The WebSocket connection
73+ invocation_ids: Optional list of invocation_ids to filter logs
74+ limit: Maximum number of logs to send initially (defaults to DEFAULT_MAX_LOGS_LIMIT)
75+ """
5876 logger .debug ("[WEBSOCKET_CONNECT] New websocket connection attempt" )
5977 await websocket .accept ()
6078 with self ._lock :
6179 self .active_connections .append (websocket )
80+ self ._connection_filters [websocket ] = invocation_ids
6281 connection_count = len (self .active_connections )
63- logger .info (f"[WEBSOCKET_CONNECT] WebSocket connected. Total connections: { connection_count } " )
82+ logger .info (
83+ f"[WEBSOCKET_CONNECT] WebSocket connected. Total connections: { connection_count } , "
84+ f"invocation_ids filter: { invocation_ids } , limit: { limit } "
85+ )
86+
87+ # Use provided limit or default
88+ effective_limit = limit if limit is not None else DEFAULT_MAX_LOGS_LIMIT
6489
65- logger .debug ("[WEBSOCKET_CONNECT] Reading logs for initialization" )
66- logs = default_logger .read ()
90+ logger .debug (
91+ f"[WEBSOCKET_CONNECT] Reading logs for initialization with "
92+ f"invocation_ids={ invocation_ids } , limit={ effective_limit } "
93+ )
94+ logs = default_logger .read (invocation_ids = invocation_ids , limit = effective_limit )
6795 logger .debug (f"[WEBSOCKET_CONNECT] Found { len (logs )} logs to send" )
6896
6997 data = {
@@ -82,16 +110,25 @@ def disconnect(self, websocket: WebSocket):
82110 logger .debug ("[WEBSOCKET_DISCONNECT] Removed websocket from active connections" )
83111 else :
84112 logger .debug ("[WEBSOCKET_DISCONNECT] Websocket was not in active connections" )
113+ # Clean up connection filter
114+ if websocket in self ._connection_filters :
115+ del self ._connection_filters [websocket ]
85116 connection_count = len (self .active_connections )
86117 logger .info (f"[WEBSOCKET_DISCONNECT] WebSocket disconnected. Total connections: { connection_count } " )
87118
88119 def broadcast_row_upserted (self , row : "EvaluationRow" ):
89120 """Broadcast a row-upsert event to all connected clients.
90121
91122 Safe no-op if server loop is not running or there are no connections.
123+ Messages are only sent to connections whose invocation_id filter matches the row,
124+ or to connections with no filter (subscribed to all).
92125 """
93126 rollout_id = row .execution_metadata .rollout_id if row .execution_metadata else "unknown"
94- logger .debug (f"[WEBSOCKET_BROADCAST] Starting broadcast for rollout_id: { rollout_id } " )
127+ row_invocation_id = row .execution_metadata .invocation_id if row .execution_metadata else None
128+ logger .debug (
129+ f"[WEBSOCKET_BROADCAST] Starting broadcast for rollout_id: { rollout_id } , "
130+ f"invocation_id: { row_invocation_id } "
131+ )
95132
96133 with self ._lock :
97134 active_connections_count = len (self .active_connections )
@@ -105,9 +142,9 @@ def broadcast_row_upserted(self, row: "EvaluationRow"):
105142 f"[WEBSOCKET_BROADCAST] Successfully serialized message (length: { len (json_message )} ) for rollout_id: { rollout_id } "
106143 )
107144
108- # Queue the message for broadcasting in the main event loop
145+ # Queue the message for broadcasting in the main event loop, along with invocation_id for filtering
109146 logger .debug (f"[WEBSOCKET_BROADCAST] Queuing message for broadcast for rollout_id: { rollout_id } " )
110- self ._broadcast_queue .put (json_message )
147+ self ._broadcast_queue .put (( json_message , row_invocation_id ) )
111148 logger .debug (f"[WEBSOCKET_BROADCAST] Successfully queued message for rollout_id: { rollout_id } " )
112149 except Exception as e :
113150 logger .error (
@@ -121,15 +158,25 @@ async def _start_broadcast_loop(self):
121158 try :
122159 # Wait for a message to be queued
123160 logger .debug ("[WEBSOCKET_BROADCAST_LOOP] Waiting for message from queue" )
124- message_data = await asyncio .get_event_loop ().run_in_executor (None , self ._broadcast_queue .get )
161+ queue_item = await asyncio .get_event_loop ().run_in_executor (None , self ._broadcast_queue .get )
162+
163+ # Queue item is a tuple of (json_message, row_invocation_id)
164+ if isinstance (queue_item , tuple ):
165+ json_message , row_invocation_id = queue_item
166+ else :
167+ # Backward compatibility: a bare string carries no invocation_id, so it reaches only unfiltered connections
168+ json_message = str (queue_item )
169+ row_invocation_id = None
170+
125171 logger .debug (
126- f"[WEBSOCKET_BROADCAST_LOOP] Retrieved message from queue (length: { len (str (message_data ))} )"
172+ f"[WEBSOCKET_BROADCAST_LOOP] Retrieved message from queue (length: { len (json_message )} ), "
173+ f"invocation_id: { row_invocation_id } "
127174 )
128175
129- # Regular string message for all connections
130- logger .debug ("[WEBSOCKET_BROADCAST_LOOP] Sending message to all connections" )
131- await self ._send_text_to_all_connections ( str ( message_data ) )
132- logger .debug ("[WEBSOCKET_BROADCAST_LOOP] Successfully sent message to all connections" )
176+ # Send message to connections that match the filter
177+ logger .debug ("[WEBSOCKET_BROADCAST_LOOP] Sending message to filtered connections" )
178+ await self ._send_text_to_filtered_connections ( json_message , row_invocation_id )
179+ logger .debug ("[WEBSOCKET_BROADCAST_LOOP] Successfully sent message to connections" )
133180
134181 except Exception as e :
135182 logger .error (f"[WEBSOCKET_BROADCAST_LOOP] Error in broadcast loop: { e } " )
@@ -138,28 +185,54 @@ async def _start_broadcast_loop(self):
138185 logger .info ("[WEBSOCKET_BROADCAST_LOOP] Broadcast loop cancelled" )
139186 break
140187
141- async def _send_text_to_all_connections (self , text : str ):
188+ async def _send_text_to_filtered_connections (self , text : str , row_invocation_id : Optional [str ] = None ):
189+ """
190+ Send text to connections that match the invocation_id filter.
191+
192+ Args:
193+ text: The message to send
194+ row_invocation_id: The invocation_id of the row being sent.
195+ Connections with no filter (None) receive all messages.
196+ Connections with a filter only receive messages whose row_invocation_id is in their filter; a row with no invocation_id (None) reaches only unfiltered connections.
197+ """
142198 with self ._lock :
143199 connections = list (self .active_connections )
200+ connection_filters = dict (self ._connection_filters )
201+
202+ # Filter connections based on their subscribed invocation_ids
203+ eligible_connections = []
204+ for conn in connections :
205+ conn_filter = connection_filters .get (conn )
206+ if conn_filter is None :
207+ # No filter means subscribed to all
208+ eligible_connections .append (conn )
209+ elif row_invocation_id is not None and row_invocation_id in conn_filter :
210+ # Row's invocation_id matches connection's filter
211+ eligible_connections .append (conn )
212+ # else: skip this connection
213+
214+ logger .debug (
215+ f"[WEBSOCKET_SEND] Attempting to send to { len (eligible_connections )} of { len (connections )} connections "
216+ f"(filtered by invocation_id: { row_invocation_id } )"
217+ )
144218
145- logger .debug (f"[WEBSOCKET_SEND] Attempting to send to { len (connections )} connections" )
146-
147- if not connections :
148- logger .debug ("[WEBSOCKET_SEND] No connections available, skipping send" )
219+ if not eligible_connections :
220+ logger .debug ("[WEBSOCKET_SEND] No eligible connections, skipping send" )
149221 return
150222
151223 tasks = []
152- failed_connections = []
224+ task_connections = [] # Track which connection each task corresponds to
153225
154- for i , connection in enumerate (connections ):
226+ for i , connection in enumerate (eligible_connections ):
155227 try :
156228 logger .debug (f"[WEBSOCKET_SEND] Preparing to send to connection { i } " )
157229 tasks .append (connection .send_text (text ))
230+ task_connections .append (connection )
158231 except Exception as e :
159232 logger .error (f"[WEBSOCKET_SEND] Failed to prepare send to WebSocket { i } : { e } " )
160- failed_connections .append (connection )
161233
162234 # Execute all sends in parallel
235+ failed_connections = []
163236 if tasks :
164237 logger .debug (f"[WEBSOCKET_SEND] Executing { len (tasks )} parallel sends" )
165238 results = await asyncio .gather (* tasks , return_exceptions = True )
@@ -169,7 +242,7 @@ async def _send_text_to_all_connections(self, text: str):
169242 for i , result in enumerate (results ):
170243 if isinstance (result , Exception ):
171244 logger .error (f"[WEBSOCKET_SEND] Failed to send text to WebSocket { i } : { result } " )
172- failed_connections .append (connections [i ])
245+ failed_connections .append (task_connections [i ])
173246 else :
174247 logger .debug (f"[WEBSOCKET_SEND] Successfully sent to connection { i } " )
175248
@@ -180,6 +253,8 @@ async def _send_text_to_all_connections(self, text: str):
180253 for connection in failed_connections :
181254 try :
182255 self .active_connections .remove (connection )
256+ if connection in self ._connection_filters :
257+ del self ._connection_filters [connection ]
183258 except ValueError :
184259 pass
185260
@@ -393,7 +468,27 @@ def _setup_websocket_routes(self):
393468
394469 @self .app .websocket ("/ws" )
395470 async def websocket_endpoint (websocket : WebSocket ):
396- await self .websocket_manager .connect (websocket )
471+ # Parse query parameters from WebSocket connection URL
472+ # invocation_ids: comma-separated list of invocation IDs to filter
473+ # limit: maximum number of initial logs to load
474+ query_params = websocket .query_params
475+ invocation_ids_param = query_params .get ("invocation_ids" )
476+ limit_param = query_params .get ("limit" )
477+
478+ invocation_ids : Optional [List [str ]] = None
479+ if invocation_ids_param :
480+ invocation_ids = [id .strip () for id in invocation_ids_param .split ("," ) if id .strip ()]
481+ logger .info (f"[WEBSOCKET] Client filtering by invocation_ids: { invocation_ids } " )
482+
483+ limit : Optional [int ] = None
484+ if limit_param :
485+ try :
486+ limit = int (limit_param )
487+ logger .info (f"[WEBSOCKET] Client requested limit: { limit } " )
488+ except ValueError :
489+ logger .warning (f"[WEBSOCKET] Invalid limit parameter: { limit_param } " )
490+
491+ await self .websocket_manager .connect (websocket , invocation_ids = invocation_ids , limit = limit )
397492 try :
398493 while True :
399494 # Keep connection alive (for evaluation row updates)
0 commit comments