Skip to content

Commit c276c07

Browse files
committed
fix
1 parent ae0b4c1 commit c276c07

File tree

2 files changed

+140
-24
lines changed

2 files changed

+140
-24
lines changed

eval_protocol/cli_commands/create_rft.py

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,37 @@ def _build_trimmed_dataset_id(evaluator_id: str) -> str:
251251
return f"{base}{suffix}"
252252

253253

254+
def _resolve_selected_test(
255+
project_root: str,
256+
evaluator_id: Optional[str],
257+
selected_tests: Optional[list] = None,
258+
) -> tuple[Optional[str], Optional[str]]:
259+
"""
260+
Resolve a single test's source file path and function name to use downstream.
261+
Priority:
262+
1) If selected_tests provided and length == 1, use it.
263+
2) Else discover tests; if exactly one test, use it.
264+
3) Else, if evaluator_id provided, match by normalized '<file-stem>-<func-name>'.
265+
Returns: (file_path, func_name) or (None, None) if unresolved.
266+
"""
267+
try:
268+
tests = selected_tests if selected_tests is not None else _discover_tests(project_root)
269+
if not tests:
270+
return None, None
271+
if len(tests) == 1:
272+
return tests[0].file_path, tests[0].qualname.split(".")[-1]
273+
if evaluator_id:
274+
for t in tests:
275+
func_name = t.qualname.split(".")[-1]
276+
source_file_name = os.path.splitext(os.path.basename(t.file_path))[0]
277+
candidate = _normalize_evaluator_id(f"{source_file_name}-{func_name}")
278+
if candidate == evaluator_id:
279+
return t.file_path, func_name
280+
return None, None
281+
except Exception:
282+
return None, None
283+
284+
254285
def _poll_evaluator_status(
255286
evaluator_resource_name: str, api_key: str, api_base: str, timeout_minutes: int = 10
256287
) -> bool:
@@ -354,12 +385,15 @@ def create_rft_command(args) -> int:
354385
if len(selected_tests) != 1:
355386
print("Error: Please select exactly one evaluation test for 'create rft'.")
356387
return 1
388+
# Derive evaluator_id from user's single selection
357389
chosen = selected_tests[0]
358390
func_name = chosen.qualname.split(".")[-1]
359391
source_file_name = os.path.splitext(os.path.basename(chosen.file_path))[0]
360392
evaluator_id = _normalize_evaluator_id(f"{source_file_name}-{func_name}")
361-
selected_test_file_path = chosen.file_path
362-
selected_test_func_name = func_name
393+
# Resolve selected test once for downstream
394+
selected_test_file_path, selected_test_func_name = _resolve_selected_test(
395+
project_root, evaluator_id, selected_tests=selected_tests
396+
)
363397
# Resolve evaluator resource name to fully-qualified format required by API
364398
evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"
365399

@@ -392,6 +426,11 @@ def create_rft_command(args) -> int:
392426
print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.")
393427
return 1
394428
skip_upload = True
429+
# Populate selected test info for dataset inference later
430+
st_path, st_func = _resolve_selected_test(project_root, evaluator_id)
431+
if st_path and st_func:
432+
selected_test_file_path = st_path
433+
selected_test_func_name = st_func
395434
except requests.exceptions.RequestException:
396435
pass
397436

@@ -402,32 +441,16 @@ def create_rft_command(args) -> int:
402441

403442
tests = _discover_tests(project_root)
404443
selected_entry: Optional[str] = None
405-
if len(tests) == 1:
406-
func_name = tests[0].qualname.split(".")[-1]
407-
abs_path = os.path.abspath(tests[0].file_path)
444+
st_path, st_func = _resolve_selected_test(project_root, evaluator_id, selected_tests=tests)
445+
if st_path and st_func:
446+
abs_path = os.path.abspath(st_path)
408447
try:
409448
rel = os.path.relpath(abs_path, project_root)
410449
except Exception:
411450
rel = abs_path
412-
selected_entry = f"{rel}::{func_name}"
413-
selected_test_file_path = tests[0].file_path
414-
selected_test_func_name = func_name
415-
else:
416-
# Try to match evaluator_id to a discovered test's normalized ID
417-
for t in tests:
418-
func_name = t.qualname.split(".")[-1]
419-
source_file_name = os.path.splitext(os.path.basename(t.file_path))[0]
420-
candidate = _normalize_evaluator_id(f"{source_file_name}-{func_name}")
421-
if candidate == evaluator_id:
422-
abs_path = os.path.abspath(t.file_path)
423-
try:
424-
rel = os.path.relpath(abs_path, project_root)
425-
except Exception:
426-
rel = abs_path
427-
selected_entry = f"{rel}::{func_name}"
428-
selected_test_file_path = t.file_path
429-
selected_test_func_name = func_name
430-
break
451+
selected_entry = f"{rel}::{st_func}"
452+
selected_test_file_path = st_path
453+
selected_test_func_name = st_func
431454
# If still unresolved and multiple tests exist, fail fast to avoid uploading unintended evaluators
432455
if selected_entry is None and len(tests) > 1:
433456
print(

tests/test_cli_create_rft_infer.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,3 +615,96 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
615615
assert captured["dataset_id"] is not None
616616
assert captured["dataset_id"].startswith("test-input-ds-test-input-ds-dataset-")
617617
assert captured["jsonl_path"] == str(id_jsonl)
618+
619+
620+
def test_create_rft_quiet_existing_evaluator_infers_dataset_from_matching_test(tmp_path, monkeypatch):
    """When the evaluator already exists (upload skipped) and several tests are
    discovered, dataset inference must select the test whose normalized
    '<file-stem>-<func-name>' id matches the supplied evaluator_id."""
    # Fresh project directory used as the CLI's working directory.
    project = tmp_path / "proj"
    project.mkdir()
    monkeypatch.chdir(project)

    # Environment the CLI reads credentials and endpoint from.
    monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy")
    monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123")
    monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai")

    # Two discoverable tests, so resolution cannot rely on "exactly one test".
    alpha_file = project / "evals" / "alpha.py"
    beta_file = project / "evals" / "beta.py"
    alpha_file.parent.mkdir(parents=True, exist_ok=True)
    alpha_file.write_text("# alpha", encoding="utf-8")
    beta_file.write_text("# beta", encoding="utf-8")
    alpha_test = SimpleNamespace(qualname="alpha.test_one", file_path=str(alpha_file))
    beta_test = SimpleNamespace(qualname="beta.test_two", file_path=str(beta_file))
    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [alpha_test, beta_test])

    # The remote evaluator reports ACTIVE, which makes the CLI skip the upload.
    class _ActiveEvaluatorResponse:
        ok = True

        def json(self):
            return {"state": "ACTIVE"}

        def raise_for_status(self):
            return None

    monkeypatch.setattr(cr.requests, "get", lambda *a, **k: _ActiveEvaluatorResponse())
    monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True)

    # JSONL surfaced through the input_dataset extractor for the matching
    # test (beta.test_two).
    jsonl_path = project / "data.jsonl"
    jsonl_path.write_text('{"c":3}\n', encoding="utf-8")

    # Extractor stubs: the dataloader path yields nothing; the input-dataset
    # path returns our JSONL unconditionally — the selection itself is driven
    # by evaluator_id, not by the extractor.
    def _fake_input_dataset_extractor(file_path, func_name):
        return str(jsonl_path)

    monkeypatch.setattr(cr, "_extract_jsonl_from_dataloader", lambda f, fn: None)
    monkeypatch.setattr(cr, "_extract_jsonl_from_input_dataset", _fake_input_dataset_extractor)
    monkeypatch.setattr(cr, "detect_dataset_builder", lambda metric_dir: None)

    recorded = {"dataset_id": None, "jsonl_path": None}

    def _capture_dataset_upload(account_id, api_key, api_base, dataset_id, display_name, jsonl_path):
        recorded["dataset_id"] = dataset_id
        recorded["jsonl_path"] = jsonl_path
        return dataset_id, {"name": f"accounts/{account_id}/datasets/{dataset_id}", "state": "UPLOADING"}

    monkeypatch.setattr(cr, "create_dataset_from_jsonl", _capture_dataset_upload)
    monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"})

    import argparse

    # Evaluator id that normalizes to beta.test_two's derived id.
    eval_id = cr._normalize_evaluator_id("beta-test_two")
    args = argparse.Namespace(
        evaluator_id=eval_id,
        yes=True,
        dry_run=False,
        force=False,
        env_file=None,
        dataset_id=None,
        dataset_jsonl=None,
        dataset_display_name=None,
        dataset_builder=None,
        base_model=None,
        warm_start_from="accounts/acct123/models/ft-abc123",
        output_model=None,
        n=None,
        max_tokens=None,
        learning_rate=None,
        batch_size=None,
        epochs=None,
        lora_rank=None,
        max_context_length=None,
        chunk_size=None,
        eval_auto_carveout=None,
    )

    rc = cr.create_rft_command(args)
    assert rc == 0
    assert recorded["dataset_id"] is not None
    # The dataset id must be derived from the evaluator id.
    assert recorded["dataset_id"].startswith(f"{eval_id}-dataset-")
    assert recorded["jsonl_path"] == str(jsonl_path)

0 commit comments

Comments
 (0)