|
1 | 1 | import asyncio |
| 2 | +import configparser |
| 3 | +import functools |
2 | 4 | import inspect |
3 | 5 | import json |
4 | 6 | import math |
5 | 7 | import os |
6 | 8 | import pathlib |
| 9 | +import requests |
7 | 10 | import statistics |
8 | 11 | import time |
9 | 12 | from collections import defaultdict |
| 13 | +from pathlib import Path |
10 | 14 | from typing import Any, Callable |
11 | 15 |
|
12 | 16 | import pytest |
13 | 17 |
|
| 18 | + |
14 | 19 | from eval_protocol.dataset_logger import default_logger |
15 | 20 | from eval_protocol.dataset_logger.dataset_logger import DatasetLogger |
16 | 21 | from eval_protocol.human_id import generate_id, num_combinations |
|
62 | 67 |
|
63 | 68 | from ..common_utils import load_jsonl |
64 | 69 |
|
| 70 | +from pytest import StashKey |
| 71 | +from typing_extensions import Literal |
| 72 | + |
| 73 | + |
| 74 | +EXPERIMENT_LINKS_STASH_KEY = StashKey[list]() |
| 75 | + |
| 76 | + |
| 77 | +def _store_experiment_link(experiment_id: str, job_link: str, status: Literal["success", "failure"]): |
| 78 | + """Store experiment link in pytest session stash.""" |
| 79 | + try: |
| 80 | + import sys |
| 81 | + |
| 82 | + # Walk up the call stack to find the pytest session |
| 83 | + session = None |
| 84 | + frame = sys._getframe() |
| 85 | + while frame: |
| 86 | + if "session" in frame.f_locals and hasattr(frame.f_locals["session"], "stash"): |
| 87 | + session = frame.f_locals["session"] |
| 88 | + break |
| 89 | + frame = frame.f_back |
| 90 | + |
| 91 | + if session is not None: |
| 92 | + global EXPERIMENT_LINKS_STASH_KEY |
| 93 | + |
| 94 | + if EXPERIMENT_LINKS_STASH_KEY not in session.stash: |
| 95 | + session.stash[EXPERIMENT_LINKS_STASH_KEY] = [] |
| 96 | + |
| 97 | + session.stash[EXPERIMENT_LINKS_STASH_KEY].append( |
| 98 | + {"experiment_id": experiment_id, "job_link": job_link, "status": status} |
| 99 | + ) |
| 100 | + else: |
| 101 | + pass |
| 102 | + |
| 103 | + except Exception as e: |
| 104 | + pass |
| 105 | + |
65 | 106 |
|
66 | 107 | def postprocess( |
67 | 108 | all_results: list[list[EvaluationRow]], |
@@ -213,22 +254,193 @@ def postprocess( |
213 | 254 | # Do not fail evaluation if summary writing fails |
214 | 255 | pass |
215 | 256 |
|
216 | | - # # Write all rows from active_logger.read() to a JSONL file in the same directory as the summary |
217 | | - # try: |
218 | | - # if active_logger is not None: |
219 | | - # rows = active_logger.read() |
220 | | - # # Write to a .jsonl file alongside the summary file |
221 | | - # jsonl_path = "logs.jsonl" |
222 | | - # import json |
223 | | - |
224 | | - # with open(jsonl_path, "w", encoding="utf-8") as f_jsonl: |
225 | | - # for row in rows: |
226 | | - # json.dump(row.model_dump(exclude_none=True, mode="json"), f_jsonl) |
227 | | - # f_jsonl.write("\n") |
228 | | - # except Exception as e: |
229 | | - # # Do not fail evaluation if log writing fails |
230 | | - # print(e) |
231 | | - # pass |
| 257 | + try: |
| 258 | + # Default is to save and upload experiment JSONL files, unless explicitly disabled |
| 259 | + should_save_and_upload = os.getenv("EP_NO_UPLOAD") != "1" |
| 260 | + |
| 261 | + if should_save_and_upload: |
| 262 | + current_run_rows = [item for sublist in all_results for item in sublist] |
| 263 | + if current_run_rows: |
| 264 | + experiments: Dict[str, List[EvaluationRow]] = defaultdict(list) |
| 265 | + for row in current_run_rows: |
| 266 | + if row.execution_metadata and row.execution_metadata.experiment_id: |
| 267 | + experiments[row.execution_metadata.experiment_id].append(row) |
| 268 | + |
| 269 | + exp_dir = pathlib.Path("experiment_results") |
| 270 | + exp_dir.mkdir(parents=True, exist_ok=True) |
| 271 | + |
| 272 | + # Create one JSONL file per experiment_id |
| 273 | + for experiment_id, exp_rows in experiments.items(): |
| 274 | + if not experiment_id or not exp_rows: |
| 275 | + continue |
| 276 | + |
| 277 | + # Generate dataset name (sanitize for Fireworks API compatibility) |
| 278 | + # API requires: lowercase a-z, 0-9, and hyphen (-) only |
| 279 | + safe_experiment_id = re.sub(r"[^a-zA-Z0-9-]", "-", experiment_id).lower() |
| 280 | + safe_test_func_name = re.sub(r"[^a-zA-Z0-9-]", "-", test_func_name).lower() |
| 281 | + dataset_name = f"{safe_test_func_name}-{safe_experiment_id}" |
| 282 | + |
| 283 | + if len(dataset_name) > 63: |
| 284 | + dataset_name = dataset_name[:63] |
| 285 | + |
| 286 | + exp_file = exp_dir / f"{experiment_id}.jsonl" |
| 287 | + with open(exp_file, "w", encoding="utf-8") as f: |
| 288 | + for row in exp_rows: |
| 289 | + row_data = row.model_dump(exclude_none=True, mode="json") |
| 290 | + |
| 291 | + if row.evaluation_result: |
| 292 | + row_data["evals"] = {"score": row.evaluation_result.score} |
| 293 | + |
| 294 | + row_data["eval_details"] = { |
| 295 | + "score": row.evaluation_result.score, |
| 296 | + "is_score_valid": row.evaluation_result.is_score_valid, |
| 297 | + "reason": row.evaluation_result.reason or "", |
| 298 | + "metrics": { |
| 299 | + name: metric.model_dump() if metric else {} |
| 300 | + for name, metric in (row.evaluation_result.metrics or {}).items() |
| 301 | + }, |
| 302 | + } |
| 303 | + else: |
| 304 | + # Default values if no evaluation result |
| 305 | + row_data["evals"] = {"score": 0} |
| 306 | + row_data["eval_details"] = { |
| 307 | + "score": 0, |
| 308 | + "is_score_valid": True, |
| 309 | + "reason": "No evaluation result", |
| 310 | + "metrics": {}, |
| 311 | + } |
| 312 | + |
| 313 | + json.dump(row_data, f, ensure_ascii=False) |
| 314 | + f.write("\n") |
| 315 | + |
| 316 | + def get_auth_value(key): |
| 317 | + """Get auth value from config file or environment.""" |
| 318 | + try: |
| 319 | + config_path = Path.home() / ".fireworks" / "auth.ini" |
| 320 | + if config_path.exists(): |
| 321 | + config = configparser.ConfigParser() |
| 322 | + config.read(config_path) |
| 323 | + for section in ["DEFAULT", "auth"]: |
| 324 | + if config.has_section(section) and config.has_option(section, key): |
| 325 | + return config.get(section, key) |
| 326 | + except Exception: |
| 327 | + pass |
| 328 | + return os.getenv(key) |
| 329 | + |
| 330 | + fireworks_api_key = get_auth_value("FIREWORKS_API_KEY") |
| 331 | + fireworks_account_id = get_auth_value("FIREWORKS_ACCOUNT_ID") |
| 332 | + |
| 333 | + if not fireworks_api_key and not fireworks_account_id: |
| 334 | + _store_experiment_link( |
| 335 | + experiment_id, |
| 336 | + "No Fireworks API key AND account ID found", |
| 337 | + "failure", |
| 338 | + ) |
| 339 | + continue |
| 340 | + elif not fireworks_api_key: |
| 341 | + _store_experiment_link( |
| 342 | + experiment_id, |
| 343 | + "No Fireworks API key found", |
| 344 | + "failure", |
| 345 | + ) |
| 346 | + continue |
| 347 | + elif not fireworks_account_id: |
| 348 | + _store_experiment_link( |
| 349 | + experiment_id, |
| 350 | + "No Fireworks account ID found", |
| 351 | + "failure", |
| 352 | + ) |
| 353 | + continue |
| 354 | + |
| 355 | + headers = {"Authorization": f"Bearer {fireworks_api_key}", "Content-Type": "application/json"} |
| 356 | + |
| 357 | + # Make dataset first |
| 358 | + dataset_url = f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/datasets" |
| 359 | + |
| 360 | + dataset_payload = { |
| 361 | + "dataset": { |
| 362 | + "displayName": dataset_name, |
| 363 | + "evalProtocol": {}, |
| 364 | + "format": "FORMAT_UNSPECIFIED", |
| 365 | + "exampleCount": f"{len(exp_rows)}", |
| 366 | + }, |
| 367 | + "datasetId": dataset_name, |
| 368 | + } |
| 369 | + |
| 370 | + dataset_response = requests.post(dataset_url, json=dataset_payload, headers=headers) |
| 371 | + |
| 372 | + # Skip if dataset creation failed |
| 373 | + if dataset_response.status_code not in [200, 201]: |
| 374 | + _store_experiment_link( |
| 375 | + experiment_id, |
| 376 | + f"Dataset creation failed: {dataset_response.status_code} {dataset_response.text}", |
| 377 | + "failure", |
| 378 | + ) |
| 379 | + continue |
| 380 | + |
| 381 | + dataset_data = dataset_response.json() |
| 382 | + dataset_id = dataset_data.get("datasetId", dataset_name) |
| 383 | + |
| 384 | + # Upload the JSONL file content |
| 385 | + upload_url = ( |
| 386 | + f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/datasets/{dataset_id}:upload" |
| 387 | + ) |
| 388 | + upload_headers = {"Authorization": f"Bearer {fireworks_api_key}"} |
| 389 | + |
| 390 | + with open(exp_file, "rb") as f: |
| 391 | + files = {"file": f} |
| 392 | + upload_response = requests.post(upload_url, files=files, headers=upload_headers) |
| 393 | + |
| 394 | + # Skip if upload failed |
| 395 | + if upload_response.status_code not in [200, 201]: |
| 396 | + _store_experiment_link( |
| 397 | + experiment_id, |
| 398 | + f"File upload failed: {upload_response.status_code} {upload_response.text}", |
| 399 | + "failure", |
| 400 | + ) |
| 401 | + continue |
| 402 | + |
| 403 | + # Create evaluation job (optional - don't skip experiment if this fails) |
| 404 | + eval_job_url = f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/evaluationJobs" |
| 405 | + # Truncate job ID to fit 63 character limit |
| 406 | + job_id_base = f"{dataset_name}-job" |
| 407 | + if len(job_id_base) > 63: |
| 408 | + # Keep the "-job" suffix and truncate the dataset_name part |
| 409 | + max_dataset_name_len = 63 - 4 # 4 = len("-job") |
| 410 | + truncated_dataset_name = dataset_name[:max_dataset_name_len] |
| 411 | + job_id_base = f"{truncated_dataset_name}-job" |
| 412 | + |
| 413 | + eval_job_payload = { |
| 414 | + "evaluationJobId": job_id_base, |
| 415 | + "evaluationJob": { |
| 416 | + "evaluator": f"accounts/{fireworks_account_id}/evaluators/dummy", |
| 417 | + "inputDataset": f"accounts/{fireworks_account_id}/datasets/dummy", |
| 418 | + "outputDataset": f"accounts/{fireworks_account_id}/datasets/{dataset_id}", |
| 419 | + }, |
| 420 | + } |
| 421 | + |
| 422 | + eval_response = requests.post(eval_job_url, json=eval_job_payload, headers=headers) |
| 423 | + |
| 424 | + if eval_response.status_code in [200, 201]: |
| 425 | + eval_job_data = eval_response.json() |
| 426 | + job_id = eval_job_data.get("evaluationJobId", job_id_base) |
| 427 | + |
| 428 | + _store_experiment_link( |
| 429 | + experiment_id, |
| 430 | + f"https://app.fireworks.ai/dashboard/evaluation-jobs/{job_id}", |
| 431 | + "success", |
| 432 | + ) |
| 433 | + else: |
| 434 | + _store_experiment_link( |
| 435 | + experiment_id, |
| 436 | + f"Job creation failed: {eval_response.status_code} {eval_response.text}", |
| 437 | + "failure", |
| 438 | + ) |
| 439 | + |
| 440 | + except Exception as e: |
| 441 | + # Do not fail evaluation if experiment JSONL writing fails |
| 442 | + print(f"Warning: Failed to persist results: {e}") |
| 443 | + pass |
232 | 444 |
|
233 | 445 | # Check threshold after logging |
234 | 446 | if threshold is not None and not passed: |
@@ -354,15 +566,26 @@ def decorator( |
354 | 566 | validate_signature(sig, mode, completion_params) |
355 | 567 |
|
356 | 568 | # Calculate all possible combinations of parameters |
357 | | - combinations = generate_parameter_combinations( |
358 | | - input_dataset, |
359 | | - completion_params, |
360 | | - input_messages, |
361 | | - input_rows, |
362 | | - evaluation_test_kwargs, |
363 | | - max_dataset_rows, |
364 | | - combine_datasets, |
365 | | - ) |
| 569 | + if mode == "groupwise": |
| 570 | + combinations = generate_parameter_combinations( |
| 571 | + input_dataset, |
| 572 | + completion_params, |
| 573 | + input_messages, |
| 574 | + input_rows, |
| 575 | + evaluation_test_kwargs, |
| 576 | + max_dataset_rows, |
| 577 | + combine_datasets, |
| 578 | + ) |
| 579 | + else: |
| 580 | + combinations = generate_parameter_combinations( |
| 581 | + input_dataset, |
| 582 | + completion_params, |
| 583 | + input_messages, |
| 584 | + input_rows, |
| 585 | + evaluation_test_kwargs, |
| 586 | + max_dataset_rows, |
| 587 | + combine_datasets, |
| 588 | + ) |
366 | 589 | if len(combinations) == 0: |
367 | 590 | raise ValueError( |
368 | 591 | "No combinations of parameters were found. Please provide at least a model and one of input_dataset, input_messages, or input_rows." |
@@ -710,7 +933,6 @@ def create_dual_mode_wrapper() -> Callable: |
710 | 933 | Returns: |
711 | 934 | A callable that can handle both pytest test execution and direct function calls |
712 | 935 | """ |
713 | | - import asyncio |
714 | 936 |
|
715 | 937 | # Check if the test function is async |
716 | 938 | is_async = asyncio.iscoroutinefunction(test_func) |
@@ -757,7 +979,6 @@ async def dual_mode_wrapper(*args, **kwargs): |
757 | 979 | } |
758 | 980 |
|
759 | 981 | # Copy all attributes from the pytest wrapper to our dual mode wrapper |
760 | | - import functools |
761 | 982 |
|
762 | 983 | functools.update_wrapper(dual_mode_wrapper, pytest_wrapper) |
763 | 984 |
|
|
0 commit comments