Skip to content

Commit b17cf90

Browse files
author
Dylan Huang
committed
Merge branch 'main' into fix-evaluation-test-type-checks
2 parents 3202461 + cbbd407 commit b17cf90

2 files changed

Lines changed: 290 additions & 27 deletions

File tree

eval_protocol/pytest/evaluation_test.py

Lines changed: 248 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
import asyncio
import configparser
import functools
import inspect
import json
import math
import os
import pathlib
import re
import statistics
import time
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable

import pytest
import requests
1317

18+
1419
from eval_protocol.dataset_logger import default_logger
1520
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
1621
from eval_protocol.human_id import generate_id, num_combinations
@@ -62,6 +67,42 @@
6267

6368
from ..common_utils import load_jsonl
6469

70+
from pytest import StashKey
71+
from typing_extensions import Literal
72+
73+
74+
EXPERIMENT_LINKS_STASH_KEY = StashKey[list]()
75+
76+
77+
def _store_experiment_link(experiment_id: str, job_link: str, status: Literal["success", "failure"]):
78+
"""Store experiment link in pytest session stash."""
79+
try:
80+
import sys
81+
82+
# Walk up the call stack to find the pytest session
83+
session = None
84+
frame = sys._getframe()
85+
while frame:
86+
if "session" in frame.f_locals and hasattr(frame.f_locals["session"], "stash"):
87+
session = frame.f_locals["session"]
88+
break
89+
frame = frame.f_back
90+
91+
if session is not None:
92+
global EXPERIMENT_LINKS_STASH_KEY
93+
94+
if EXPERIMENT_LINKS_STASH_KEY not in session.stash:
95+
session.stash[EXPERIMENT_LINKS_STASH_KEY] = []
96+
97+
session.stash[EXPERIMENT_LINKS_STASH_KEY].append(
98+
{"experiment_id": experiment_id, "job_link": job_link, "status": status}
99+
)
100+
else:
101+
pass
102+
103+
except Exception as e:
104+
pass
105+
65106

66107
def postprocess(
67108
all_results: list[list[EvaluationRow]],
@@ -213,22 +254,193 @@ def postprocess(
213254
# Do not fail evaluation if summary writing fails
214255
pass
215256

216-
# # Write all rows from active_logger.read() to a JSONL file in the same directory as the summary
217-
# try:
218-
# if active_logger is not None:
219-
# rows = active_logger.read()
220-
# # Write to a .jsonl file alongside the summary file
221-
# jsonl_path = "logs.jsonl"
222-
# import json
223-
224-
# with open(jsonl_path, "w", encoding="utf-8") as f_jsonl:
225-
# for row in rows:
226-
# json.dump(row.model_dump(exclude_none=True, mode="json"), f_jsonl)
227-
# f_jsonl.write("\n")
228-
# except Exception as e:
229-
# # Do not fail evaluation if log writing fails
230-
# print(e)
231-
# pass
257+
try:
    # Persist one JSONL file per experiment and (best-effort) upload each to
    # Fireworks as a dataset, then create an evaluation job so results are
    # viewable in the dashboard. Opt out with EP_NO_UPLOAD=1 (set by the
    # --ep-no-upload pytest flag).
    should_save_and_upload = os.getenv("EP_NO_UPLOAD") != "1"

    if should_save_and_upload:
        current_run_rows = [item for sublist in all_results for item in sublist]
        if current_run_rows:
            # Group rows by experiment_id so each experiment gets its own file.
            experiments: dict[str, list[EvaluationRow]] = defaultdict(list)
            for row in current_run_rows:
                if row.execution_metadata and row.execution_metadata.experiment_id:
                    experiments[row.execution_metadata.experiment_id].append(row)

            exp_dir = pathlib.Path("experiment_results")
            exp_dir.mkdir(parents=True, exist_ok=True)

            def get_auth_value(key):
                """Get auth value from ~/.fireworks/auth.ini or the environment.

                Checks the [DEFAULT] and [auth] sections (in that order);
                returns os.getenv(key) (None when unset) as the fallback.
                """
                try:
                    config_path = Path.home() / ".fireworks" / "auth.ini"
                    if config_path.exists():
                        config = configparser.ConfigParser()
                        config.read(config_path)
                        # configparser.has_section() always returns False for
                        # "DEFAULT"; it must be probed via has_option() alone,
                        # which treats the default section specially.
                        if config.has_option("DEFAULT", key):
                            return config.get("DEFAULT", key)
                        if config.has_section("auth") and config.has_option("auth", key):
                            return config.get("auth", key)
                except Exception:
                    # Unreadable/malformed config: fall through to environment.
                    pass
                return os.getenv(key)

            # Credentials are loop-invariant: resolve them once, not once per
            # experiment.
            fireworks_api_key = get_auth_value("FIREWORKS_API_KEY")
            fireworks_account_id = get_auth_value("FIREWORKS_ACCOUNT_ID")

            # Create one JSONL file per experiment_id.
            for experiment_id, exp_rows in experiments.items():
                if not experiment_id or not exp_rows:
                    continue

                # Generate dataset name (sanitize for Fireworks API
                # compatibility): lowercase a-z, 0-9, and hyphen (-) only,
                # max 63 characters.
                safe_experiment_id = re.sub(r"[^a-zA-Z0-9-]", "-", experiment_id).lower()
                safe_test_func_name = re.sub(r"[^a-zA-Z0-9-]", "-", test_func_name).lower()
                dataset_name = f"{safe_test_func_name}-{safe_experiment_id}"[:63]

                exp_file = exp_dir / f"{experiment_id}.jsonl"
                with open(exp_file, "w", encoding="utf-8") as f:
                    for row in exp_rows:
                        row_data = row.model_dump(exclude_none=True, mode="json")

                        if row.evaluation_result:
                            row_data["evals"] = {"score": row.evaluation_result.score}
                            row_data["eval_details"] = {
                                "score": row.evaluation_result.score,
                                "is_score_valid": row.evaluation_result.is_score_valid,
                                "reason": row.evaluation_result.reason or "",
                                "metrics": {
                                    name: metric.model_dump() if metric else {}
                                    for name, metric in (row.evaluation_result.metrics or {}).items()
                                },
                            }
                        else:
                            # Default values if no evaluation result.
                            row_data["evals"] = {"score": 0}
                            row_data["eval_details"] = {
                                "score": 0,
                                "is_score_valid": True,
                                "reason": "No evaluation result",
                                "metrics": {},
                            }

                        json.dump(row_data, f, ensure_ascii=False)
                        f.write("\n")

                # Without credentials the file is still written locally above,
                # but the upload for this experiment is skipped with a recorded
                # reason.
                if not fireworks_api_key and not fireworks_account_id:
                    _store_experiment_link(
                        experiment_id,
                        "No Fireworks API key AND account ID found",
                        "failure",
                    )
                    continue
                elif not fireworks_api_key:
                    _store_experiment_link(
                        experiment_id,
                        "No Fireworks API key found",
                        "failure",
                    )
                    continue
                elif not fireworks_account_id:
                    _store_experiment_link(
                        experiment_id,
                        "No Fireworks account ID found",
                        "failure",
                    )
                    continue

                headers = {"Authorization": f"Bearer {fireworks_api_key}", "Content-Type": "application/json"}

                # Make the dataset first.
                dataset_url = f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/datasets"

                dataset_payload = {
                    "dataset": {
                        "displayName": dataset_name,
                        "evalProtocol": {},
                        "format": "FORMAT_UNSPECIFIED",
                        # API expects the count as a string.
                        "exampleCount": f"{len(exp_rows)}",
                    },
                    "datasetId": dataset_name,
                }

                dataset_response = requests.post(dataset_url, json=dataset_payload, headers=headers)

                # Skip this experiment if dataset creation failed.
                if dataset_response.status_code not in [200, 201]:
                    _store_experiment_link(
                        experiment_id,
                        f"Dataset creation failed: {dataset_response.status_code} {dataset_response.text}",
                        "failure",
                    )
                    continue

                dataset_data = dataset_response.json()
                dataset_id = dataset_data.get("datasetId", dataset_name)

                # Upload the JSONL file content.
                upload_url = (
                    f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/datasets/{dataset_id}:upload"
                )
                upload_headers = {"Authorization": f"Bearer {fireworks_api_key}"}

                with open(exp_file, "rb") as f:
                    files = {"file": f}
                    upload_response = requests.post(upload_url, files=files, headers=upload_headers)

                # Skip this experiment if the upload failed.
                if upload_response.status_code not in [200, 201]:
                    _store_experiment_link(
                        experiment_id,
                        f"File upload failed: {upload_response.status_code} {upload_response.text}",
                        "failure",
                    )
                    continue

                # Create evaluation job (optional - the experiment is not
                # skipped if this fails; the failure is just recorded).
                eval_job_url = f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/evaluationJobs"
                # Truncate the job ID to the 63-character limit while keeping
                # the "-job" suffix intact.
                job_id_base = f"{dataset_name}-job"
                if len(job_id_base) > 63:
                    job_id_base = f"{dataset_name[:63 - len('-job')]}-job"

                eval_job_payload = {
                    "evaluationJobId": job_id_base,
                    "evaluationJob": {
                        "evaluator": f"accounts/{fireworks_account_id}/evaluators/dummy",
                        "inputDataset": f"accounts/{fireworks_account_id}/datasets/dummy",
                        "outputDataset": f"accounts/{fireworks_account_id}/datasets/{dataset_id}",
                    },
                }

                eval_response = requests.post(eval_job_url, json=eval_job_payload, headers=headers)

                if eval_response.status_code in [200, 201]:
                    eval_job_data = eval_response.json()
                    job_id = eval_job_data.get("evaluationJobId", job_id_base)

                    _store_experiment_link(
                        experiment_id,
                        f"https://app.fireworks.ai/dashboard/evaluation-jobs/{job_id}",
                        "success",
                    )
                else:
                    _store_experiment_link(
                        experiment_id,
                        f"Job creation failed: {eval_response.status_code} {eval_response.text}",
                        "failure",
                    )

except Exception as e:
    # Do not fail the evaluation if persisting/uploading results fails.
    print(f"Warning: Failed to persist results: {e}")
232444

233445
# Check threshold after logging
234446
if threshold is not None and not passed:
@@ -354,15 +566,26 @@ def decorator(
354566
validate_signature(sig, mode, completion_params)
355567

356568
# Calculate all possible combinations of parameters
357-
combinations = generate_parameter_combinations(
358-
input_dataset,
359-
completion_params,
360-
input_messages,
361-
input_rows,
362-
evaluation_test_kwargs,
363-
max_dataset_rows,
364-
combine_datasets,
365-
)
569+
# Calculate all possible combinations of parameters. Both arms of the former
# `mode == "groupwise"` conditional invoked generate_parameter_combinations
# with identical arguments, so the dead branch is collapsed into a single call
# (behavior unchanged for every mode).
combinations = generate_parameter_combinations(
    input_dataset,
    completion_params,
    input_messages,
    input_rows,
    evaluation_test_kwargs,
    max_dataset_rows,
    combine_datasets,
)
366589
if len(combinations) == 0:
367590
raise ValueError(
368591
"No combinations of parameters were found. Please provide at least a model and one of input_dataset, input_messages, or input_rows."
@@ -710,7 +933,6 @@ def create_dual_mode_wrapper() -> Callable:
710933
Returns:
711934
A callable that can handle both pytest test execution and direct function calls
712935
"""
713-
import asyncio
714936

715937
# Check if the test function is async
716938
is_async = asyncio.iscoroutinefunction(test_func)
@@ -757,7 +979,6 @@ async def dual_mode_wrapper(*args, **kwargs):
757979
}
758980

759981
# Copy all attributes from the pytest wrapper to our dual mode wrapper
760-
import functools
761982

762983
functools.update_wrapper(dual_mode_wrapper, pytest_wrapper)
763984

eval_protocol/pytest/plugin.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from typing import Optional
1818
import json
1919
import pathlib
20+
import sys
21+
from pytest import StashKey
2022

2123

2224
def pytest_addoption(parser) -> None:
@@ -104,6 +106,15 @@ def pytest_addoption(parser) -> None:
104106
"Pass a float >= 0.0 (e.g., 0.05). If only this is set, success threshold defaults to 0.0."
105107
),
106108
)
109+
group.addoption(
110+
"--ep-no-upload",
111+
action="store_true",
112+
default=False,
113+
help=(
114+
"Disable saving and uploading of detailed experiment JSON files to Fireworks. "
115+
"Default: false (experiment JSONs are saved and uploaded by default)."
116+
),
117+
)
107118

108119

109120
def _normalize_max_rows(val: Optional[str]) -> Optional[str]:
@@ -229,6 +240,9 @@ def pytest_configure(config) -> None:
229240
if threshold_env is not None:
230241
os.environ["EP_PASSED_THRESHOLD"] = threshold_env
231242

243+
if config.getoption("--ep-no-upload"):
244+
os.environ["EP_NO_UPLOAD"] = "1"
245+
232246
# Allow ad-hoc overrides of input params via CLI flags
233247
try:
234248
merged: dict = {}
@@ -263,3 +277,31 @@ def pytest_configure(config) -> None:
263277
except Exception:
264278
# best effort, do not crash pytest session
265279
pass
280+
281+
282+
def pytest_sessionfinish(session, exitstatus):
    """Print all Fireworks experiment links collected during the session.

    Reads the entries stashed by ``_store_experiment_link`` (in
    ``evaluation_test``) and prints a summary banner to the real stderr
    (``sys.__stderr__``, bypassing pytest's output capture).

    Best-effort by design: any failure (import error, missing stash, bad
    entries) is swallowed so reporting can never affect the session outcome.
    """
    try:
        from .evaluation_test import EXPERIMENT_LINKS_STASH_KEY

        # Get links from the pytest stash using the shared key.
        links = []
        if EXPERIMENT_LINKS_STASH_KEY in session.stash:
            links = session.stash[EXPERIMENT_LINKS_STASH_KEY]

        if not links:
            return

        out = sys.__stderr__
        print("\n" + "=" * 80, file=out)
        print("🔥 FIREWORKS EXPERIMENT LINKS", file=out)
        print("=" * 80, file=out)

        for link in links:
            # Same line format for both outcomes; only the leading icon differs.
            icon = "🔗" if link["status"] == "success" else "❌"
            print(f"{icon} Experiment {link['experiment_id']}: {link['job_link']}", file=out)

        print("=" * 80, file=out)
        out.flush()
    except Exception:
        # Intentionally silent: reporting must never fail the pytest session.
        pass

0 commit comments

Comments
 (0)