Skip to content

Commit 668652d

Browse files
authored
Update LLM Judge Example and Adapter (#175)
* fix langfuse rate limit issue * to revert later, get 50 random traces to query * don't skip * make judgment async * bump limit up * lower concurrency for gemini * small limit to see if we get the error still * test * test * try this * fix * fix * no split * ok wtf * try something else * test * 1 run * same as aime now * try osmething else * remove gpt * gpt * try to mute and see what happens * monkey patch * try * broken still * how about 2 and 4 * fix single turn rollout acompletion * add back * test repro * add * undo weird changes i made * big run with kimi judge * lol * add timing filter * unique traces * update adapter
1 parent 908d14a commit 668652d

File tree

4 files changed

+241
-170
lines changed

4 files changed

+241
-170
lines changed

eval_protocol/adapters/langfuse.py

Lines changed: 128 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from langfuse.api.resources.commons.types.observations_view import ObservationsView
88
import logging
9+
import random
10+
import time
911
from datetime import datetime, timedelta
1012
from typing import Any, Dict, Iterator, List, Optional, cast
1113

@@ -59,54 +61,154 @@ def __init__(self):
5961
def get_evaluation_rows(
    self,
    limit: int = 100,
    sample_size: int = 50,
    tags: Optional[List[str]] = None,
    user_id: Optional[str] = None,
    session_id: Optional[str] = None,
    hours_back: Optional[int] = None,
    from_timestamp: Optional[datetime] = None,
    to_timestamp: Optional[datetime] = None,
    include_tool_calls: bool = True,
    sleep_between_gets: float = 2.5,
    max_retries: int = 3,
) -> List[EvaluationRow]:
    """Pull traces from Langfuse and convert to EvaluationRow format.

    Two-phase fetch: (1) page through trace *summaries* until ``limit`` are
    collected, then (2) randomly sample ``sample_size`` of them and fetch
    full details one-by-one with sleeps, to stay under Langfuse rate limits.

    Args:
        limit: Max number of trace summaries to collect via pagination (pre-sampling)
        sample_size: Number of traces to fetch full details for (sampled from collected summaries)
        tags: Filter by specific tags
        user_id: Filter by user ID
        session_id: Filter by session ID
        hours_back: Filter traces from this many hours ago
        from_timestamp: Explicit start time (overrides hours_back)
        to_timestamp: Explicit end time (overrides hours_back)
        include_tool_calls: Whether to include tool calling traces
        sleep_between_gets: Sleep time between individual trace.get() calls (2.5s for 30 req/min limit)
        max_retries: Maximum retries for rate limit errors

    Returns:
        List[EvaluationRow]: Converted evaluation rows
    """
    eval_rows = []

    # Determine time window: explicit from/to takes precedence over hours_back
    if from_timestamp is None and to_timestamp is None and hours_back:
        # NOTE(review): datetime.now() is timezone-naive here — confirm the
        # Langfuse API expects naive local time rather than aware UTC.
        to_timestamp = datetime.now()
        from_timestamp = to_timestamp - timedelta(hours=hours_back)

    # Collect trace summaries via pagination (up to limit)
    all_traces = []
    page = 1
    collected = 0

    while collected < limit:
        current_page_limit = min(100, limit - collected)  # Langfuse API max is 100

        logger.debug(
            "Fetching page %d with limit %d (collected: %d/%d)", page, current_page_limit, collected, limit
        )

        # Fetch trace list with retry logic
        traces = None
        list_retries = 0
        while list_retries < max_retries:
            try:
                traces = self.client.api.trace.list(
                    page=page,
                    limit=current_page_limit,
                    tags=tags,
                    user_id=user_id,
                    session_id=session_id,
                    from_timestamp=from_timestamp,
                    to_timestamp=to_timestamp,
                    order_by="timestamp.desc",
                )
                break
            except Exception as e:
                # Rate-limit detection is string-based ("429" in the message)
                # because the SDK surfaces a generic exception type here.
                list_retries += 1
                if "429" in str(e) and list_retries < max_retries:
                    sleep_time = 2**list_retries  # Exponential backoff
                    logger.warning(
                        "Rate limit hit on trace.list(), retrying in %ds (attempt %d/%d)",
                        sleep_time,
                        list_retries,
                        max_retries,
                    )
                    time.sleep(sleep_time)
                else:
                    # Non-429 error, or retries exhausted: stop listing but
                    # keep whatever rows were already converted.
                    logger.error("Failed to fetch trace list after %d retries: %s", max_retries, e)
                    return eval_rows  # Return what we have so far

        if not traces or not traces.data:
            logger.debug("No more traces found on page %d", page)
            break

        logger.debug("Collected %d traces from page %d", len(traces.data), page)

        all_traces.extend(traces.data)
        collected += len(traces.data)

        # Check if we have more pages
        if hasattr(traces.meta, "page") and hasattr(traces.meta, "total_pages"):
            if traces.meta.page >= traces.meta.total_pages:
                break
        elif len(traces.data) < current_page_limit:
            # No pagination metadata and the page came back short:
            # assume we've reached the end of the result set.
            break

        page += 1

    if not all_traces:
        logger.debug("No traces found")
        return eval_rows

    # Randomly sample traces to fetch full details (respect rate limits)
    # NOTE(review): random.sample() is unseeded, so the selection is
    # nondeterministic across runs.
    actual_sample_size = min(sample_size, len(all_traces))
    selected_traces = random.sample(all_traces, actual_sample_size)

    logger.debug("Randomly selected %d traces from %d collected", actual_sample_size, len(all_traces))

    # Process each selected trace with sleep and retry logic
    for trace_info in selected_traces:
        # Sleep between gets to avoid rate limits
        if sleep_between_gets > 0:
            time.sleep(sleep_between_gets)

        # Fetch full trace details with retry logic
        trace_full = None
        detail_retries = 0
        while detail_retries < max_retries:
            try:
                trace_full = self.client.api.trace.get(trace_info.id)
                break
            except Exception as e:
                detail_retries += 1
                if "429" in str(e) and detail_retries < max_retries:
                    sleep_time = 2**detail_retries  # Exponential backoff
                    logger.warning(
                        "Rate limit hit on trace.get(%s), retrying in %ds (attempt %d/%d)",
                        trace_info.id,
                        sleep_time,
                        detail_retries,
                        max_retries,
                    )
                    time.sleep(sleep_time)
                else:
                    # A failed detail fetch skips only this trace; the
                    # remaining sampled traces are still processed.
                    logger.warning("Failed to fetch trace %s after %d retries: %s", trace_info.id, max_retries, e)
                    break  # Skip this trace

        if trace_full:
            try:
                eval_row = self._convert_trace_to_evaluation_row(trace_full, include_tool_calls)
                if eval_row:
                    eval_rows.append(eval_row)
            except (AttributeError, ValueError, KeyError) as e:
                logger.warning("Failed to convert trace %s: %s", trace_info.id, e)
                continue

    logger.info(
        "Successfully processed %d selected traces into %d evaluation rows", len(selected_traces), len(eval_rows)
    )
    return eval_rows
111213

112214
def get_evaluation_rows_by_ids(

eval_protocol/mcp/execution/policy.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,12 @@ async def _make_llm_call(self, messages: List[Dict[str, Any]], tools: List[Dict[
194194
request_params["tools"] = tools
195195

196196
try:
197-
response = await acompletion(model=self.model_id, **request_params)
197+
response = await acompletion(
198+
model=self.model_id,
199+
**request_params,
200+
# api_base="https://litellm-cloud-proxy-prod-zfdbl7ykrq-uc.a.run.app/v1",
201+
# extra_body={"tags": ["kimi-k2-tau-bench"]},
202+
)
198203

199204
# Log cache hit/miss for monitoring
200205
hidden = getattr(response, "_hidden_params", {})

eval_protocol/quickstart/llm_judge.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import os
6+
from datetime import datetime
67
from typing import List, Dict, Any, Optional
78
from tqdm import tqdm
89

@@ -14,32 +15,44 @@
1415
from eval_protocol.quickstart.utils import (
1516
split_multi_turn_rows,
1617
JUDGE_CONFIGS,
17-
fetch_langfuse_traces_as_evaluation_rows,
1818
calculate_bootstrap_scores,
1919
push_scores_to_langfuse,
20-
run_judgment,
20+
run_judgment_async,
2121
)
22+
import asyncio
23+
from openai import AsyncOpenAI
24+
from eval_protocol.adapters.langfuse import create_langfuse_adapter
2225

23-
import concurrent.futures
24-
from concurrent.futures import ThreadPoolExecutor
26+
adapter = create_langfuse_adapter()
2527

2628

27-
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Skip in CI")
2829
@pytest.mark.asyncio
2930
@evaluation_test(
30-
input_rows=[fetch_langfuse_traces_as_evaluation_rows()],
31+
input_rows=[
32+
adapter.get_evaluation_rows(
33+
to_timestamp=datetime(2025, 9, 12, 0, 11, 18),
34+
limit=711,
35+
sample_size=50,
36+
sleep_between_gets=3.0,
37+
max_retries=5,
38+
)
39+
],
3140
completion_params=[
41+
{"model": "gpt-4.1"},
3242
{
33-
"model": "fireworks_ai/accounts/fireworks/models/qwen3-235b-a22b-instruct-2507",
43+
"max_tokens": 131000,
44+
"extra_body": {"reasoning_effort": "medium"},
45+
"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
3446
},
3547
{
3648
"max_tokens": 131000,
3749
"extra_body": {"reasoning_effort": "low"},
38-
"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
50+
"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-20b",
3951
},
4052
],
4153
rollout_processor=SingleTurnRolloutProcessor(),
4254
preprocess_fn=split_multi_turn_rows,
55+
max_concurrent_rollouts=64,
4356
mode="all",
4457
)
4558
async def test_llm_judge(rows: list[EvaluationRow]) -> list[EvaluationRow]:
@@ -73,11 +86,21 @@ async def test_llm_judge(rows: list[EvaluationRow]) -> list[EvaluationRow]:
7386
judgments = []
7487
max_concurrency = JUDGE_CONFIGS[judge_name]["max_concurrency"]
7588

76-
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
77-
futures = [executor.submit(run_judgment, row, model_name, judge_name) for row in rows]
89+
judge_config = JUDGE_CONFIGS[judge_name]
90+
91+
async with AsyncOpenAI(
92+
api_key=judge_config.get("api_key"), base_url=judge_config.get("base_url")
93+
) as shared_client:
94+
semaphore = asyncio.Semaphore(max_concurrency)
95+
96+
async def run_judgment(row):
97+
async with semaphore:
98+
return await run_judgment_async(row, model_name, judge_name, shared_client)
99+
100+
tasks = [run_judgment(row) for row in rows]
78101

79-
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Generating judgments"):
80-
result = future.result()
102+
for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Generating judgments"):
103+
result = await coro
81104
if result and result["games"][0] and result["games"][1]:
82105
judgments.append(result)
83106

0 commit comments

Comments
 (0)