Skip to content

Commit 84949cd

Browse files
committed
perf: Add server-side filtering for ep logs to improve performance
When opening the logs UI with a filterConfig URL parameter (e.g., filtering by invocation_id), the backend now filters data server-side instead of sending all logs to the frontend for client-side filtering.

Changes:
- Add invocation_ids and limit parameters to SqliteEvaluationRowStore.read_rows()
- Add server-side filtering using SQLite JSON extraction
- Default limit of 1000 rows for initial WebSocket load
- WebSocket endpoint parses invocation_ids from URL query params
- Frontend extracts invocation_ids from URL filterConfig and passes to WebSocket
- Real-time broadcast messages filtered per-connection based on subscribed invocation_ids
- Order results by most recent first (DESC by rowid)

This significantly speeds up ep logs when viewing filtered results by avoiding loading all historical data.
1 parent 7f8056c commit 84949cd

9 files changed

Lines changed: 399 additions & 53 deletions

File tree

eval_protocol/dataset_logger/__init__.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from typing import List, Optional
23

34
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
45
from eval_protocol.dataset_logger.sqlite_dataset_logger_adapter import SqliteDatasetLoggerAdapter
@@ -14,7 +15,7 @@ class _NoOpLogger(DatasetLogger):
1415
def log(self, row):
1516
return None
1617

17-
def read(
    self,
    rollout_id: Optional[str] = None,
    invocation_ids: Optional[List[str]] = None,
    limit: Optional[int] = None,
) -> list:
    """Return no rows; the no-op logger stores nothing.

    The parameters mirror the read() signature of the other loggers in
    this package so the no-op logger can be called the same way; all of
    them are ignored.

    Args:
        rollout_id: Ignored.
        invocation_ids: Ignored.
        limit: Ignored.

    Returns:
        Always an empty list.
    """
    return []
1920

2021
return _NoOpLogger()
@@ -33,8 +34,13 @@ def _get_logger(self):
3334
def log(self, row):
3435
return self._get_logger().log(row)
3536

36-
def read(
    self,
    rollout_id: Optional[str] = None,
    invocation_ids: Optional[List[str]] = None,
    limit: Optional[int] = None,
):
    """Delegate a read to the logger returned by ``self._get_logger()``.

    Args:
        rollout_id: Forwarded unchanged as the ``rollout_id`` keyword.
        invocation_ids: Forwarded unchanged as the ``invocation_ids`` keyword.
        limit: Forwarded unchanged as the ``limit`` keyword.

    Returns:
        Whatever the underlying logger's ``read`` returns.
    """
    backing_logger = self._get_logger()
    return backing_logger.read(
        rollout_id=rollout_id,
        invocation_ids=invocation_ids,
        limit=limit,
    )
3844

3945

4046
default_logger: DatasetLogger = _LazyLogger()

eval_protocol/dataset_logger/dataset_logger.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,19 @@ def log(self, row: "EvaluationRow") -> None:
2424
pass
2525

2626
@abstractmethod
27-
def read(self, row_id: Optional[str] = None) -> List["EvaluationRow"]:
27+
def read(
28+
self,
29+
rollout_id: Optional[str] = None,
30+
invocation_ids: Optional[List[str]] = None,
31+
limit: Optional[int] = None,
32+
) -> List["EvaluationRow"]:
2833
"""
29-
Retrieve EvaluationRow logs.
34+
Retrieve EvaluationRow logs with optional filtering.
3035
3136
Args:
32-
row_id (Optional[str]): If provided, filter logs by this row_id.
37+
rollout_id (Optional[str]): If provided, filter logs by this rollout_id.
38+
invocation_ids (Optional[List[str]]): If provided, filter logs by these invocation_ids.
39+
limit (Optional[int]): If provided, limit the number of rows returned (most recent first).
3340
3441
Returns:
3542
List[EvaluationRow]: List of retrieved evaluation rows.

eval_protocol/dataset_logger/sqlite_dataset_logger_adapter.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,13 @@ def log(self, row: "EvaluationRow") -> None:
3838
logger.error(f"[EVENT_BUS_EMIT] Failed to emit row_upserted event for rollout_id {rollout_id}: {e}")
3939
pass
4040

41-
def read(
    self,
    rollout_id: Optional[str] = None,
    invocation_ids: Optional[List[str]] = None,
    limit: Optional[int] = None,
) -> List["EvaluationRow"]:
    """Read stored rows and rebuild them as ``EvaluationRow`` objects.

    Args:
        rollout_id: Passed through to the store's ``read_rows``.
        invocation_ids: Passed through to the store's ``read_rows``.
        limit: Passed through to the store's ``read_rows``.

    Returns:
        One ``EvaluationRow`` per dictionary returned by the store, in the
        order the store returned them.
    """
    # Function-scope import, as in the original.
    # NOTE(review): presumably avoids a circular import with eval_protocol.models - confirm.
    from eval_protocol.models import EvaluationRow

    raw_rows = self._store.read_rows(
        rollout_id=rollout_id,
        invocation_ids=invocation_ids,
        limit=limit,
    )
    return [EvaluationRow(**payload) for payload in raw_rows]

eval_protocol/dataset_logger/sqlite_evaluation_row_store.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from typing import List, Optional
33

4-
from peewee import CharField, Model, SqliteDatabase
4+
from peewee import CharField, Model, SqliteDatabase, fn, SQL
55
from playhouse.sqlite_ext import JSONField
66

77
from eval_protocol.event_bus.sqlite_event_bus_database import (
@@ -67,12 +67,55 @@ def _do_upsert(self, rollout_id: str, data: dict) -> None:
6767
else:
6868
self._EvaluationRow.create(rollout_id=rollout_id, data=data)
6969

70-
def read_rows(
    self,
    rollout_id: Optional[str] = None,
    invocation_ids: Optional[List[str]] = None,
    limit: Optional[int] = None,
) -> List[dict]:
    """
    Read evaluation rows from the database with optional filtering.

    Args:
        rollout_id: If given, keep only rows whose rollout_id equals it.
        invocation_ids: If given and non-empty, keep only rows whose JSON
            value at ``$.execution_metadata.invocation_id`` equals any of
            these IDs. ``None`` or an empty list applies no invocation filter.
        limit: If given, return at most this many rows.

    Returns:
        The ``data`` payload of each matching row, ordered by rowid descending.
    """
    row_model = self._EvaluationRow
    query = row_model.select()

    if rollout_id is not None:
        query = query.where(row_model.rollout_id == rollout_id)

    if invocation_ids:
        # One IN (...) test instead of an OR-chain of equality checks: a folded
        # OR-chain nests one level deeper per ID and can exceed SQLite's
        # expression-depth limit on long ID lists; IN stays flat.
        invocation_id_expr = fn.json_extract(row_model.data, "$.execution_metadata.invocation_id")
        query = query.where(invocation_id_expr.in_(list(invocation_ids)))

    # rowid DESC is used as the "most recent first" ordering.
    # NOTE(review): rowid tracks insertion, so a row updated in place by an
    # upsert presumably keeps its original position - confirm this is intended.
    query = query.order_by(SQL("rowid DESC"))

    if limit is not None:
        query = query.limit(limit)

    return [record["data"] for record in query.dicts()]
77120

78121
def delete_row(self, rollout_id: str) -> int:

eval_protocol/utils/logs_server.py

Lines changed: 118 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ def enable_debug_mode():
4545
print("Debug mode enabled for all relevant loggers")
4646

4747

48+
DEFAULT_MAX_LOGS_LIMIT = 1000 # Default limit for initial log load to prevent slowdowns
49+
50+
4851
class WebSocketManager:
4952
"""Manages WebSocket connections and broadcasts messages."""
5053

@@ -53,17 +56,42 @@ def __init__(self):
5356
self._broadcast_queue: Queue = Queue()
5457
self._broadcast_task: Optional[asyncio.Task] = None
5558
self._lock = threading.Lock()
59+
# Track which invocation_ids each connection is subscribed to (None = all)
60+
self._connection_filters: Dict[WebSocket, Optional[List[str]]] = {}
5661

57-
async def connect(self, websocket: WebSocket):
62+
async def connect(
63+
self,
64+
websocket: WebSocket,
65+
invocation_ids: Optional[List[str]] = None,
66+
limit: Optional[int] = None,
67+
):
68+
"""
69+
Connect a WebSocket client and send initial logs.
70+
71+
Args:
72+
websocket: The WebSocket connection
73+
invocation_ids: Optional list of invocation_ids to filter logs
74+
limit: Maximum number of logs to send initially (defaults to DEFAULT_MAX_LOGS_LIMIT)
75+
"""
5876
logger.debug("[WEBSOCKET_CONNECT] New websocket connection attempt")
5977
await websocket.accept()
6078
with self._lock:
6179
self.active_connections.append(websocket)
80+
self._connection_filters[websocket] = invocation_ids
6281
connection_count = len(self.active_connections)
63-
logger.info(f"[WEBSOCKET_CONNECT] WebSocket connected. Total connections: {connection_count}")
82+
logger.info(
83+
f"[WEBSOCKET_CONNECT] WebSocket connected. Total connections: {connection_count}, "
84+
f"invocation_ids filter: {invocation_ids}, limit: {limit}"
85+
)
86+
87+
# Use provided limit or default
88+
effective_limit = limit if limit is not None else DEFAULT_MAX_LOGS_LIMIT
6489

65-
logger.debug("[WEBSOCKET_CONNECT] Reading logs for initialization")
66-
logs = default_logger.read()
90+
logger.debug(
91+
f"[WEBSOCKET_CONNECT] Reading logs for initialization with "
92+
f"invocation_ids={invocation_ids}, limit={effective_limit}"
93+
)
94+
logs = default_logger.read(invocation_ids=invocation_ids, limit=effective_limit)
6795
logger.debug(f"[WEBSOCKET_CONNECT] Found {len(logs)} logs to send")
6896

6997
data = {
@@ -82,16 +110,25 @@ def disconnect(self, websocket: WebSocket):
82110
logger.debug("[WEBSOCKET_DISCONNECT] Removed websocket from active connections")
83111
else:
84112
logger.debug("[WEBSOCKET_DISCONNECT] Websocket was not in active connections")
113+
# Clean up connection filter
114+
if websocket in self._connection_filters:
115+
del self._connection_filters[websocket]
85116
connection_count = len(self.active_connections)
86117
logger.info(f"[WEBSOCKET_DISCONNECT] WebSocket disconnected. Total connections: {connection_count}")
87118

88119
def broadcast_row_upserted(self, row: "EvaluationRow"):
89120
"""Broadcast a row-upsert event to all connected clients.
90121
91122
Safe no-op if server loop is not running or there are no connections.
123+
Messages are only sent to connections whose invocation_id filter matches the row,
124+
or to connections with no filter (subscribed to all).
92125
"""
93126
rollout_id = row.execution_metadata.rollout_id if row.execution_metadata else "unknown"
94-
logger.debug(f"[WEBSOCKET_BROADCAST] Starting broadcast for rollout_id: {rollout_id}")
127+
row_invocation_id = row.execution_metadata.invocation_id if row.execution_metadata else None
128+
logger.debug(
129+
f"[WEBSOCKET_BROADCAST] Starting broadcast for rollout_id: {rollout_id}, "
130+
f"invocation_id: {row_invocation_id}"
131+
)
95132

96133
with self._lock:
97134
active_connections_count = len(self.active_connections)
@@ -105,9 +142,9 @@ def broadcast_row_upserted(self, row: "EvaluationRow"):
105142
f"[WEBSOCKET_BROADCAST] Successfully serialized message (length: {len(json_message)}) for rollout_id: {rollout_id}"
106143
)
107144

108-
# Queue the message for broadcasting in the main event loop
145+
# Queue the message for broadcasting in the main event loop, along with invocation_id for filtering
109146
logger.debug(f"[WEBSOCKET_BROADCAST] Queuing message for broadcast for rollout_id: {rollout_id}")
110-
self._broadcast_queue.put(json_message)
147+
self._broadcast_queue.put((json_message, row_invocation_id))
111148
logger.debug(f"[WEBSOCKET_BROADCAST] Successfully queued message for rollout_id: {rollout_id}")
112149
except Exception as e:
113150
logger.error(
@@ -121,15 +158,25 @@ async def _start_broadcast_loop(self):
121158
try:
122159
# Wait for a message to be queued
123160
logger.debug("[WEBSOCKET_BROADCAST_LOOP] Waiting for message from queue")
124-
message_data = await asyncio.get_event_loop().run_in_executor(None, self._broadcast_queue.get)
161+
queue_item = await asyncio.get_event_loop().run_in_executor(None, self._broadcast_queue.get)
162+
163+
# Queue item is a tuple of (json_message, row_invocation_id)
164+
if isinstance(queue_item, tuple):
165+
json_message, row_invocation_id = queue_item
166+
else:
167+
# Backward compatibility: if it's just a string, send to all
168+
json_message = str(queue_item)
169+
row_invocation_id = None
170+
125171
logger.debug(
126-
f"[WEBSOCKET_BROADCAST_LOOP] Retrieved message from queue (length: {len(str(message_data))})"
172+
f"[WEBSOCKET_BROADCAST_LOOP] Retrieved message from queue (length: {len(json_message)}), "
173+
f"invocation_id: {row_invocation_id}"
127174
)
128175

129-
# Regular string message for all connections
130-
logger.debug("[WEBSOCKET_BROADCAST_LOOP] Sending message to all connections")
131-
await self._send_text_to_all_connections(str(message_data))
132-
logger.debug("[WEBSOCKET_BROADCAST_LOOP] Successfully sent message to all connections")
176+
# Send message to connections that match the filter
177+
logger.debug("[WEBSOCKET_BROADCAST_LOOP] Sending message to filtered connections")
178+
await self._send_text_to_filtered_connections(json_message, row_invocation_id)
179+
logger.debug("[WEBSOCKET_BROADCAST_LOOP] Successfully sent message to connections")
133180

134181
except Exception as e:
135182
logger.error(f"[WEBSOCKET_BROADCAST_LOOP] Error in broadcast loop: {e}")
@@ -138,28 +185,54 @@ async def _start_broadcast_loop(self):
138185
logger.info("[WEBSOCKET_BROADCAST_LOOP] Broadcast loop cancelled")
139186
break
140187

141-
async def _send_text_to_all_connections(self, text: str):
188+
async def _send_text_to_filtered_connections(self, text: str, row_invocation_id: Optional[str] = None):
189+
"""
190+
Send text to connections that match the invocation_id filter.
191+
192+
Args:
193+
text: The message to send
194+
row_invocation_id: The invocation_id of the row being sent.
195+
Connections with no filter (None) receive all messages.
196+
Connections with a filter only receive messages where row_invocation_id is in their filter.
197+
"""
142198
with self._lock:
143199
connections = list(self.active_connections)
200+
connection_filters = dict(self._connection_filters)
201+
202+
# Filter connections based on their subscribed invocation_ids
203+
eligible_connections = []
204+
for conn in connections:
205+
conn_filter = connection_filters.get(conn)
206+
if conn_filter is None:
207+
# No filter means subscribed to all
208+
eligible_connections.append(conn)
209+
elif row_invocation_id is not None and row_invocation_id in conn_filter:
210+
# Row's invocation_id matches connection's filter
211+
eligible_connections.append(conn)
212+
# else: skip this connection
213+
214+
logger.debug(
215+
f"[WEBSOCKET_SEND] Attempting to send to {len(eligible_connections)} of {len(connections)} connections "
216+
f"(filtered by invocation_id: {row_invocation_id})"
217+
)
144218

145-
logger.debug(f"[WEBSOCKET_SEND] Attempting to send to {len(connections)} connections")
146-
147-
if not connections:
148-
logger.debug("[WEBSOCKET_SEND] No connections available, skipping send")
219+
if not eligible_connections:
220+
logger.debug("[WEBSOCKET_SEND] No eligible connections, skipping send")
149221
return
150222

151223
tasks = []
152-
failed_connections = []
224+
task_connections = [] # Track which connection each task corresponds to
153225

154-
for i, connection in enumerate(connections):
226+
for i, connection in enumerate(eligible_connections):
155227
try:
156228
logger.debug(f"[WEBSOCKET_SEND] Preparing to send to connection {i}")
157229
tasks.append(connection.send_text(text))
230+
task_connections.append(connection)
158231
except Exception as e:
159232
logger.error(f"[WEBSOCKET_SEND] Failed to prepare send to WebSocket {i}: {e}")
160-
failed_connections.append(connection)
161233

162234
# Execute all sends in parallel
235+
failed_connections = []
163236
if tasks:
164237
logger.debug(f"[WEBSOCKET_SEND] Executing {len(tasks)} parallel sends")
165238
results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -169,7 +242,7 @@ async def _send_text_to_all_connections(self, text: str):
169242
for i, result in enumerate(results):
170243
if isinstance(result, Exception):
171244
logger.error(f"[WEBSOCKET_SEND] Failed to send text to WebSocket {i}: {result}")
172-
failed_connections.append(connections[i])
245+
failed_connections.append(task_connections[i])
173246
else:
174247
logger.debug(f"[WEBSOCKET_SEND] Successfully sent to connection {i}")
175248

@@ -180,6 +253,8 @@ async def _send_text_to_all_connections(self, text: str):
180253
for connection in failed_connections:
181254
try:
182255
self.active_connections.remove(connection)
256+
if connection in self._connection_filters:
257+
del self._connection_filters[connection]
183258
except ValueError:
184259
pass
185260

@@ -393,7 +468,27 @@ def _setup_websocket_routes(self):
393468

394469
@self.app.websocket("/ws")
395470
async def websocket_endpoint(websocket: WebSocket):
396-
await self.websocket_manager.connect(websocket)
471+
# Parse query parameters from WebSocket connection URL
472+
# invocation_ids: comma-separated list of invocation IDs to filter
473+
# limit: maximum number of initial logs to load
474+
query_params = websocket.query_params
475+
invocation_ids_param = query_params.get("invocation_ids")
476+
limit_param = query_params.get("limit")
477+
478+
invocation_ids: Optional[List[str]] = None
479+
if invocation_ids_param:
480+
invocation_ids = [id.strip() for id in invocation_ids_param.split(",") if id.strip()]
481+
logger.info(f"[WEBSOCKET] Client filtering by invocation_ids: {invocation_ids}")
482+
483+
limit: Optional[int] = None
484+
if limit_param:
485+
try:
486+
limit = int(limit_param)
487+
logger.info(f"[WEBSOCKET] Client requested limit: {limit}")
488+
except ValueError:
489+
logger.warning(f"[WEBSOCKET] Invalid limit parameter: {limit_param}")
490+
491+
await self.websocket_manager.connect(websocket, invocation_ids=invocation_ids, limit=limit)
397492
try:
398493
while True:
399494
# Keep connection alive (for evaluation row updates)

0 commit comments

Comments
 (0)