
Commit 2dbd868

xzrderek and Dylan Huang authored
RemoteRolloutProcessor use data loader and pulling based on tags (#217)
* Add in data loader and pulling from tags
* break early
* fix tests
* use mocks instead in the test
* remove the requester_metadata stuff
* pipelined
* removing pyright
* take out litellm
* add types
* clean up
* add typescript simple example (#218)
* add typescript simple example
* publish npm package for eval protocol (#219)
* publish typescript SDK
* add createLangfuseConfigTags function and update version to 0.1.1
* use eval-protocol npm dependency
* refactor statusInfoSchema to use a record type and update version to 0.1.2
* add eval_metadata to langfuse_row in RemoteRolloutProcessor
* Refactor data generator function name and update eval-protocol version to 0.1.2
* done
* move folders

---------

Co-authored-by: Dylan Huang <dhuang@fireworks.ai>
1 parent 671c882 · commit 2dbd868

28 files changed: +3084 -452 lines changed

eval_protocol/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -62,6 +62,13 @@
 except ImportError:
     LangSmithAdapter = None
 
+# Remote server types
+from .types.remote_rollout_processor import (
+    InitRequest,
+    RolloutMetadata,
+    StatusResponse,
+    create_langfuse_config_tags,
+)
 
 warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")
 

@@ -110,6 +117,11 @@
     # Submodules
     "rewards",
     "mcp",
+    # Remote server types
+    "InitRequest",
+    "RolloutMetadata",
+    "StatusResponse",
+    "create_langfuse_config_tags",
 ]
 
 from . import _version
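These exports make the remote-server types importable from the package root, and they pair with the "key:value" tag scheme that the Langfuse adapter below parses. A minimal sketch of what a consumer sees after this change (the IDs are illustrative, and the exact call signature of create_langfuse_config_tags is not shown in this diff, so only the import is demonstrated):

    from eval_protocol import (
        InitRequest,
        RolloutMetadata,
        StatusResponse,
        create_langfuse_config_tags,
    )

    # Illustrative tags in the "key:value" format the adapter below looks for:
    tags = [
        "invocation_id:inv-123",
        "experiment_id:exp-456",
        "rollout_id:ro-789",
        "run_id:run-000",
        "row_id:row-001",
    ]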

eval_protocol/adapters/langfuse.py

Lines changed: 36 additions & 115 deletions
@@ -12,7 +12,7 @@
 from typing import Any, Dict, List, Optional, Protocol, TYPE_CHECKING, cast
 
 from langfuse.api.resources.commons.types.observations_view import ObservationsView
-from eval_protocol.models import EvaluationRow, InputMetadata, Message
+from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message
 from .base import BaseAdapter
 from .utils import extract_messages_from_data
 
@@ -82,14 +82,41 @@ def convert_trace_to_evaluation_row(
         if not messages:
             return None
 
+        execution_metadata = ExecutionMetadata()
+        row_id = None
+
+        if trace.tags:
+            for tag in trace.tags:
+                if tag.startswith("invocation_id:"):
+                    execution_metadata.invocation_id = tag.split(":", 1)[1]
+                elif tag.startswith("experiment_id:"):
+                    execution_metadata.experiment_id = tag.split(":", 1)[1]
+                elif tag.startswith("rollout_id:"):
+                    execution_metadata.rollout_id = tag.split(":", 1)[1]
+                elif tag.startswith("run_id:"):
+                    execution_metadata.run_id = tag.split(":", 1)[1]
+                elif tag.startswith("row_id:"):
+                    row_id = tag.split(":", 1)[1]
+
+                if (
+                    execution_metadata.invocation_id
+                    and execution_metadata.experiment_id
+                    and execution_metadata.rollout_id
+                    and execution_metadata.run_id
+                    and row_id
+                ):
+                    break  # Break early if we've found all the metadata we need
+
         return EvaluationRow(
             messages=messages,
             tools=tools,
             input_metadata=InputMetadata(
+                row_id=row_id,
                 session_data={
                     "langfuse_trace_id": trace.id,  # Store the trace ID here
-                }
+                },
             ),
+            execution_metadata=execution_metadata,
         )
 
     except (AttributeError, ValueError, KeyError) as e:
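The tag format above is a flat "key:value" encoding on the Langfuse trace. A standalone sketch of the round trip, using the same split(":", 1) parsing as the loop above (the helper name and IDs are illustrative, not part of this commit):

    def parse_metadata_tags(tags: list[str]) -> dict[str, str]:
        """Collect 'key:value' tags into a dict, splitting on the first colon only."""
        parsed: dict[str, str] = {}
        for tag in tags:
            if ":" in tag:
                key, value = tag.split(":", 1)  # values may themselves contain ":"
                parsed[key] = value
        return parsed

    meta = parse_metadata_tags(["rollout_id:ro-789", "row_id:row-001"])
    assert meta == {"rollout_id": "ro-789", "row_id": "row-001"}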
@@ -259,9 +286,6 @@ def get_evaluation_rows(
         max_retries: int = 3,
         span_name: Optional[str] = None,
         converter: Optional[TraceConverter] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-        requester_metadata: Optional[Dict[str, Any]] = None,
-        requester_metadata_contains: Optional[str] = None,
     ) -> List[EvaluationRow]:
         """Pull traces from Langfuse and convert to EvaluationRow format.
 
@@ -296,10 +320,6 @@
             to_timestamp = datetime.now()
         from_timestamp = to_timestamp - timedelta(hours=hours_back)
 
-        # If filtering by metadata/requester_metadata, prefer fetching metadata fields
-        if (metadata is not None or requester_metadata is not None or requester_metadata_contains) and not fields:
-            fields = "core,metadata,observations"
-
         # Collect trace summaries via pagination (up to limit)
        all_traces = []
        page = 1
@@ -332,16 +352,18 @@
                     to_timestamp=to_timestamp,
                     order_by="timestamp.desc",
                 )
+
+                # If there are no results on the first page, it may be an indexing delay (the remote
+                # rollout processor just finished pushing rows to Langfuse), so treat it as retryable
+                if traces and traces.meta and traces.meta.total_items == 0 and page == 1:
+                    raise Exception("Empty results - indexing delay")
+
                 break
             except Exception as e:
                 list_retries += 1
-                if "429" in str(e) and list_retries < max_retries:
+                if list_retries < max_retries and ("429" in str(e) or "Empty results" in str(e)):
                     sleep_time = 2**list_retries  # Exponential backoff
                     logger.warning(
-                        "Rate limit hit on trace.list(), retrying in %ds (attempt %d/%d)",
-                        sleep_time,
-                        list_retries,
-                        max_retries,
+                        "Retrying in %ds (attempt %d/%d): %s", sleep_time, list_retries, max_retries, str(e)
                     )
                     time.sleep(sleep_time)
                 else:
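Both failure modes (HTTP 429 rate limits and empty first pages caused by indexing delay) now get the same exponential backoff. A minimal standalone sketch of the pattern, where fetch_page is a hypothetical stand-in for the traces.list call:

    import time

    def list_with_retry(fetch_page, max_retries: int = 3):
        """Retry on 429s and on empty results, sleeping 2, 4, 8... seconds between attempts."""
        retries = 0
        while True:
            try:
                result = fetch_page()
                if not result:  # stand-in for traces.meta.total_items == 0 on page 1
                    raise Exception("Empty results - indexing delay")
                return result
            except Exception as e:
                retries += 1
                if retries < max_retries and ("429" in str(e) or "Empty results" in str(e)):
                    time.sleep(2 ** retries)  # exponential backoff, as in the hunk above
                else:
                    raise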
@@ -379,74 +401,6 @@
             selected_traces = all_traces
             logger.debug("Processing all %d collected traces (no sampling)", len(all_traces))
 
-        # Helper to check if a trace matches provided metadata filters. We look in multiple places
-        # to account for Langfuse moving fields (e.g., metadata vs requester_metadata) and SDK shape.
-        def _trace_matches_metadata_filters(trace_obj: Any) -> bool:
-            if metadata is None and requester_metadata is None:
-                return True
-
-            def _as_dict(val: Any) -> Dict[str, Any]:
-                if val is None:
-                    return {}
-                if isinstance(val, dict):
-                    return val
-                # Some SDK objects expose .model_dump() or behave like pydantic models
-                dump = getattr(val, "model_dump", None)
-                if callable(dump):
-                    try:
-                        return dump()  # type: ignore[no-any-return]
-                    except Exception:
-                        return {}
-                return {}
-
-            # Try common locations for metadata on full trace
-            trace_meta = _as_dict(getattr(trace_obj, "metadata", None))
-            trace_req_meta = _as_dict(getattr(trace_obj, "requester_metadata", None))
-            # Some Langfuse deployments nest requester_metadata inside metadata
-            nested_req_meta = {}
-            try:
-                if isinstance(trace_meta, dict) and isinstance(trace_meta.get("requester_metadata"), dict):
-                    nested_req_meta = _as_dict(trace_meta.get("requester_metadata"))
-            except Exception:
-                nested_req_meta = {}
-
-            # Fallbacks: sometimes metadata is embedded in input
-            input_meta = {}
-            try:
-                inp = getattr(trace_obj, "input", None)
-                if isinstance(inp, dict):
-                    input_meta = _as_dict(inp.get("metadata"))
-            except Exception:
-                input_meta = {}
-
-            # Combine for matching convenience (later keys override earlier for equality check only)
-            combined_meta = {**trace_meta, **input_meta}
-            combined_req_meta = {**trace_req_meta}
-
-            # Also merge nested requester metadata when present
-            if nested_req_meta:
-                combined_req_meta = {**combined_req_meta, **nested_req_meta}
-
-            def _is_subset(needle: Dict[str, Any], haystack: Dict[str, Any]) -> bool:
-                for k, v in needle.items():
-                    if haystack.get(k) != v:
-                        return False
-                return True
-
-            ok_meta = True
-            ok_req_meta = True
-
-            if metadata is not None:
-                # Accept match if found either in metadata or requester_metadata buckets
-                ok_meta = _is_subset(metadata, combined_meta) or _is_subset(metadata, combined_req_meta)
-
-            if requester_metadata is not None:
-                ok_req_meta = _is_subset(requester_metadata, combined_req_meta) or _is_subset(
-                    requester_metadata, combined_meta
-                )
-
-            return ok_meta and ok_req_meta
-
         # Process each selected trace with sleep and retry logic
         for trace_info in selected_traces:
             # Sleep between gets to avoid rate limits
@@ -483,39 +437,6 @@ def _is_subset(needle: Dict[str, Any], haystack: Dict[str, Any]) -> bool:
                     break  # Skip this trace
 
             if trace_full:
-                # If metadata filters are provided, skip non-matching traces early
-                try:
-                    if not _trace_matches_metadata_filters(trace_full):
-                        continue
-                except Exception:
-                    # Be permissive on filter errors; treat as non-match
-                    continue
-
-                # If observations carry requester_metadata, allow substring filtering
-                if requester_metadata_contains:
-                    contains_val = requester_metadata_contains
-                    found_match = False
-                    try:
-                        for obs in getattr(trace_full, "observations", []) or []:
-                            obs_rmd = getattr(obs, "requester_metadata", None)
-                            if isinstance(obs_rmd, dict) and any(
-                                (isinstance(v, str) and contains_val in v) for v in obs_rmd.values()
-                            ):
-                                found_match = True
-                                break
-                            obs_md = getattr(obs, "metadata", None)
-                            if isinstance(obs_md, dict):
-                                nested = obs_md.get("requester_metadata")
-                                if isinstance(nested, dict) and any(
-                                    (isinstance(v, str) and contains_val in v) for v in nested.values()
-                                ):
-                                    found_match = True
-                                    break
-                    except Exception:
-                        found_match = False
-                    if not found_match:
-                        continue
-
                 try:
                     if converter:
                         eval_row = converter(trace_full, include_tool_calls, span_name)

eval_protocol/pytest/evaluation_test.py

Lines changed: 18 additions & 20 deletions
@@ -72,7 +72,7 @@ def evaluation_test(
     input_dataset: Sequence[DatasetPathParam] | None = None,
     input_rows: Sequence[list[EvaluationRow]] | None = None,
     data_loaders: Sequence[EvaluationDataLoader] | EvaluationDataLoader | None = None,
-    dataset_adapter: Callable[[list[dict[str, Any]]], Dataset] = default_dataset_adapter,  # pyright: ignore[reportExplicitAny]
+    dataset_adapter: Callable[[list[dict[str, Any]]], Dataset] = default_dataset_adapter,
     rollout_processor: RolloutProcessor | None = None,
     evaluation_test_kwargs: Sequence[EvaluationInputParam | None] | None = None,
     rollout_processor_kwargs: RolloutProcessorInputParam | None = None,
@@ -418,9 +418,7 @@ async def _execute_groupwise_eval_with_semaphore(
                 all_results[run_idx] = results
             elif mode == "groupwise":
                 # rollout all the completion_params for the same row at once, and then send the output to the test_func
-                row_groups = defaultdict(  # pyright: ignore[reportUnknownVariableType]
-                    list
-                )  # key: row_id, value: list of rollout_result
+                row_groups = defaultdict(list)  # key: row_id, value: list of rollout_result
                 tasks: list[asyncio.Task[list[EvaluationRow]]] = []
                 # completion_groups = []
                 for idx, cp in enumerate(original_completion_params):
@@ -435,13 +433,13 @@
                     )
                     lst = []
 
-                    async def _collect_result(config, lst):  # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
+                    async def _collect_result(config, lst):
                         result = []
                         async for row in rollout_processor_with_retry(
                             rollout_processor, lst, config, run_idx
                         ):  # pyright: ignore[reportUnknownArgumentType]
-                            result.append(row)  # pyright: ignore[reportUnknownMemberType]
-                        return result  # pyright: ignore[reportUnknownVariableType]
+                            result.append(row)
+                        return result
 
                     for ori_row in fresh_dataset:
                         copied_row = ori_row.model_copy(deep=True)
@@ -450,32 +448,32 @@ async def _collect_result(config, lst):
                             str(ori_row.execution_metadata.rollout_id) + "_" + str(idx)
                         )
                         copied_row.input_metadata.completion_params = cp if cp is not None else {}
-                        lst.append(copied_row)  # pyright: ignore[reportUnknownMemberType]
-                    tasks.append(asyncio.create_task(_collect_result(config, lst)))  # pyright: ignore[reportUnknownArgumentType]
+                        lst.append(copied_row)
+                    tasks.append(asyncio.create_task(_collect_result(config, lst)))
                 rollout_results = await asyncio.gather(*tasks)
                 for result in rollout_results:
                     for row in result:
-                        row_groups[row.input_metadata.row_id].append(row)  # pyright: ignore[reportUnknownMemberType]
+                        row_groups[row.input_metadata.row_id].append(row)
                 tasks = []
-                for _, rows in row_groups.items():  # pyright: ignore[reportUnknownVariableType]
-                    tasks.append(asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows)))  # pyright: ignore[reportUnknownArgumentType]
+                for _, rows in row_groups.items():
+                    tasks.append(asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows)))
                 results = []
                 for task in tasks:
                     res = await task
-                    results.extend(res)  # pyright: ignore[reportUnknownMemberType]
+                    results.extend(res)
                 all_results[run_idx] = results
             else:
                 # Batch mode: collect all results first, then evaluate (no pipelining)
                 input_dataset = []
                 async for row in rollout_processor_with_retry(
                     rollout_processor, fresh_dataset, config, run_idx
                 ):
-                    input_dataset.append(row)  # pyright: ignore[reportUnknownMemberType]
+                    input_dataset.append(row)
                 # NOTE: we will still evaluate errored rows (give users control over this)
                 # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
                 results = await execute_pytest(
                     test_func,
-                    processed_dataset=input_dataset,  # pyright: ignore[reportUnknownArgumentType]
+                    processed_dataset=input_dataset,
                     evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
                 )
                 if (
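The groupwise path above fans out one rollout task per completion-params variant, awaits them together, and regroups the rows by row_id. A simplified, self-contained sketch of that shape (the rollout coroutine is a mock; the real code uses rollout_processor_with_retry):

    import asyncio
    from collections import defaultdict

    async def mock_rollout(rows: list[dict], config: str) -> list[dict]:
        # Stand-in for the per-config rollout: tag each row with the config it ran under.
        return [{**row, "config": config} for row in rows]

    async def groupwise(rows: list[dict], configs: list[str]) -> dict:
        tasks = [asyncio.create_task(mock_rollout(rows, cfg)) for cfg in configs]
        row_groups = defaultdict(list)  # key: row_id, value: list of rollout results
        for result in await asyncio.gather(*tasks):
            for row in result:
                row_groups[row["row_id"]].append(row)
        return row_groups

    groups = asyncio.run(groupwise([{"row_id": "r1"}, {"row_id": "r2"}], ["cfg-a", "cfg-b"]))
    assert len(groups["r1"]) == 2  # one result per completion-params variant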
@@ -538,16 +536,16 @@
         # for groupwise mode, the result contains eval output from multiple completion_params, and we need to differentiate them
         # rollout_id is used to differentiate results from different completion_params
         if mode == "groupwise":
-            results_by_group = [  # pyright: ignore[reportUnknownVariableType]
+            results_by_group = [
                 [[] for _ in range(num_runs)] for _ in range(len(original_completion_params))
             ]
             for i_run, result in enumerate(all_results):
                 for r in result:
                     completion_param_idx = int(r.execution_metadata.rollout_id.split("_")[1])  # pyright: ignore[reportOptionalMemberAccess]
-                    results_by_group[completion_param_idx][i_run].append(r)  # pyright: ignore[reportUnknownMemberType]
-            for rollout_id, result in enumerate(results_by_group):  # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType]
+                    results_by_group[completion_param_idx][i_run].append(r)
+            for rollout_id, result in enumerate(results_by_group):
                 postprocess(
-                    result,  # pyright: ignore[reportUnknownArgumentType]
+                    result,
                     aggregation_method,
                     passed_threshold,
                     active_logger,
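The "_<idx>" suffix appended to rollout_id in the groupwise hunk earlier is what makes this regrouping possible: split("_")[1] recovers the completion-params index. A tiny sketch of the encode/decode pair, which (like the code above) assumes the base rollout_id contains no underscore:

    def encode_rollout_id(base_id: str, completion_param_idx: int) -> str:
        return f"{base_id}_{completion_param_idx}"

    def decode_completion_param_idx(rollout_id: str) -> int:
        # Mirrors rollout_id.split("_")[1] in the hunk above.
        return int(rollout_id.split("_")[1])

    assert decode_completion_param_idx(encode_rollout_id("abc123", 2)) == 2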
@@ -599,7 +597,7 @@
     pytest_wrapper = pytest.mark.asyncio(pytest_wrapper)
 
     # Create the dual mode wrapper
-    dual_mode_wrapper = create_dual_mode_wrapper(  # pyright: ignore[reportUnknownVariableType]
+    dual_mode_wrapper = create_dual_mode_wrapper(
         test_func, mode, max_concurrent_rollouts, max_concurrent_evaluations, pytest_wrapper
     )
 