Skip to content

Commit 060d72c

Browse files
author
Dylan Huang
committed
fix mock tests
1 parent 5e7a5fa commit 060d72c

File tree

3 files changed

+75
-134
lines changed

3 files changed

+75
-134
lines changed

tests/test_cli_create_rft.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@ def _write_json(path: str, data: dict) -> None:
2424
def stub_fireworks(monkeypatch) -> dict[str, Any]:
2525
"""
2626
Stub Fireworks SDK so tests stay offline and so create_rft.py can inspect a stable
27-
create() signature (it uses inspect.signature(Fireworks().reinforcement_fine_tuning_jobs.create)).
27+
create() signature (it uses inspect.signature(create_fireworks_client().reinforcement_fine_tuning_jobs.create)).
2828
2929
Returns:
3030
A dict containing the last captured create() kwargs under key "kwargs".
@@ -72,12 +72,15 @@ def create(
7272
return SimpleNamespace(name=f"accounts/{account_id}/reinforcementFineTuningJobs/xyz")
7373

7474
class _FakeFW:
75-
def __init__(self, api_key=None, base_url=None):
75+
def __init__(self, api_key=None, base_url=None, account_id=None, default_headers=None):
7676
self.api_key = api_key
7777
self.base_url = base_url
78+
self.account_id = account_id
79+
self.default_headers = default_headers
7880
self.reinforcement_fine_tuning_jobs = _FakeJobs()
7981

80-
monkeypatch.setattr(cr, "Fireworks", _FakeFW)
82+
# Patch create_fireworks_client to return our fake client
83+
monkeypatch.setattr(cr, "create_fireworks_client", lambda **kwargs: _FakeFW(**kwargs))
8184
return captured
8285

8386

tests/test_ep_upload_e2e.py

Lines changed: 51 additions & 125 deletions
Original file line number | Diff line number | Diff line change
@@ -80,8 +80,8 @@ def mock_gcs_upload():
8080

8181
@pytest.fixture
8282
def mock_fireworks_client():
83-
"""Mock the Fireworks SDK client used in evaluation.py"""
84-
with patch("eval_protocol.evaluation.Fireworks") as mock_fw_class:
83+
"""Mock the Fireworks SDK client used in fireworks_client.py"""
84+
with patch("eval_protocol.fireworks_client.Fireworks") as mock_fw_class:
8585
mock_client = MagicMock()
8686
mock_fw_class.return_value = mock_client
8787

@@ -92,32 +92,27 @@ def mock_fireworks_client():
9292
mock_create_response.description = "Test description"
9393
mock_client.evaluators.create.return_value = mock_create_response
9494

95-
# Mock evaluators.get_upload_endpoint response - will be set dynamically
96-
def get_upload_endpoint_side_effect(evaluator_id, filename_to_size):
95+
# Mock evaluator_versions.create response
96+
mock_version_response = MagicMock()
97+
mock_version_response.name = "accounts/test_account/evaluators/test-eval/versions/v1"
98+
mock_client.evaluator_versions.create.return_value = mock_version_response
99+
100+
# Mock evaluator_versions.get_upload_endpoint response - will be set dynamically
101+
def get_upload_endpoint_side_effect(evaluator_id, version_id, filename_to_size):
97102
response = MagicMock()
98103
signed_urls = {}
99104
for filename in filename_to_size.keys():
100105
signed_urls[filename] = f"https://storage.googleapis.com/test-bucket/{filename}?signed=true"
101106
response.filename_to_signed_urls = signed_urls
102107
return response
103108

104-
mock_client.evaluators.get_upload_endpoint.side_effect = get_upload_endpoint_side_effect
109+
mock_client.evaluator_versions.get_upload_endpoint.side_effect = get_upload_endpoint_side_effect
105110

106-
# Mock evaluators.validate_upload response
111+
# Mock evaluator_versions.validate_upload response
107112
mock_validate_response = MagicMock()
108113
mock_validate_response.success = True
109114
mock_validate_response.valid = True
110-
mock_client.evaluators.validate_upload.return_value = mock_validate_response
111-
112-
yield mock_client
113-
114-
115-
@pytest.fixture
116-
def mock_platform_api_client():
117-
"""Mock the Fireworks SDK client used in platform_api.py for secrets"""
118-
with patch("eval_protocol.platform_api.Fireworks") as mock_fw_class:
119-
mock_client = MagicMock()
120-
mock_fw_class.return_value = mock_client
115+
mock_client.evaluator_versions.validate_upload.return_value = mock_validate_response
121116

122117
# Mock secrets.get - raise NotFoundError to simulate secret doesn't exist
123118
from fireworks import NotFoundError
@@ -129,13 +124,23 @@ def mock_platform_api_client():
129124
)
130125

131126
# Mock secrets.create - successful
132-
mock_create_response = MagicMock()
133-
mock_create_response.name = "accounts/test_account/secrets/test-secret"
134-
mock_client.secrets.create.return_value = mock_create_response
127+
mock_secrets_create_response = MagicMock()
128+
mock_secrets_create_response.name = "accounts/test_account/secrets/test-secret"
129+
mock_client.secrets.create.return_value = mock_secrets_create_response
135130

136131
yield mock_client
137132

138133

134+
@pytest.fixture
135+
def mock_platform_api_client(mock_fireworks_client):
136+
"""
137+
Mock the Fireworks SDK client for secrets.
138+
This is now just an alias for mock_fireworks_client since both use the same patched location.
139+
The mock_fireworks_client fixture already includes secrets mocking.
140+
"""
141+
yield mock_fireworks_client
142+
143+
139144
def test_ep_upload_discovers_and_uploads_evaluation_test(
140145
mock_env_variables, mock_fireworks_client, mock_platform_api_client, mock_gcs_upload, monkeypatch
141146
):
@@ -219,13 +224,18 @@ async def test_simple_evaluation(row: EvaluationRow) -> EvaluationRow:
219224
# Step 1: Create evaluator
220225
assert mock_fireworks_client.evaluators.create.called, "Should call evaluators.create"
221226

222-
# Step 2: Get upload endpoint
223-
assert mock_fireworks_client.evaluators.get_upload_endpoint.called, (
224-
"Should call evaluators.get_upload_endpoint"
227+
# Step 1b: Create evaluator version
228+
assert mock_fireworks_client.evaluator_versions.create.called, "Should call evaluator_versions.create"
229+
230+
# Step 2: Get upload endpoint (via evaluator_versions API)
231+
assert mock_fireworks_client.evaluator_versions.get_upload_endpoint.called, (
232+
"Should call evaluator_versions.get_upload_endpoint"
225233
)
226234

227-
# Step 3: Validate upload
228-
assert mock_fireworks_client.evaluators.validate_upload.called, "Should call evaluators.validate_upload"
235+
# Step 3: Validate upload (via evaluator_versions API)
236+
assert mock_fireworks_client.evaluator_versions.validate_upload.called, (
237+
"Should call evaluator_versions.validate_upload"
238+
)
229239

230240
# Step 4: GCS upload
231241
assert mock_gcs_upload.send.called, "Should upload tar.gz to GCS"
@@ -325,8 +335,9 @@ async def test_multi_model_eval(row: EvaluationRow) -> EvaluationRow:
325335

326336
# Verify upload flow completed via Fireworks SDK
327337
assert mock_fireworks_client.evaluators.create.called
328-
assert mock_fireworks_client.evaluators.get_upload_endpoint.called
329-
assert mock_fireworks_client.evaluators.validate_upload.called
338+
assert mock_fireworks_client.evaluator_versions.create.called
339+
assert mock_fireworks_client.evaluator_versions.get_upload_endpoint.called
340+
assert mock_fireworks_client.evaluator_versions.validate_upload.called
330341
assert mock_gcs_upload.send.called
331342

332343
finally:
@@ -505,17 +516,24 @@ async def test_math_correctness(row: EvaluationRow) -> EvaluationRow:
505516
# Step 1: Create evaluator
506517
assert mock_fireworks_client.evaluators.create.called, "Missing create call"
507518

508-
# Step 2: Get upload endpoint
509-
assert mock_fireworks_client.evaluators.get_upload_endpoint.called, "Missing getUploadEndpoint call"
519+
# Step 1b: Create evaluator version
520+
assert mock_fireworks_client.evaluator_versions.create.called, "Missing evaluator_versions.create call"
521+
522+
# Step 2: Get upload endpoint (via evaluator_versions API)
523+
assert mock_fireworks_client.evaluator_versions.get_upload_endpoint.called, (
524+
"Missing evaluator_versions.get_upload_endpoint call"
525+
)
510526

511527
# Step 3: Upload to GCS
512528
assert mock_gcs_upload.send.called, "Missing GCS upload"
513529
gcs_request = mock_gcs_upload.send.call_args[0][0]
514530
assert gcs_request.method == "PUT"
515531
assert "storage.googleapis.com" in gcs_request.url
516532

517-
# Step 4: Validate
518-
assert mock_fireworks_client.evaluators.validate_upload.called, "Missing validateUpload call"
533+
# Step 4: Validate (via evaluator_versions API)
534+
assert mock_fireworks_client.evaluator_versions.validate_upload.called, (
535+
"Missing evaluator_versions.validate_upload call"
536+
)
519537

520538
# 4. VERIFY PAYLOAD DETAILS
521539
create_call = mock_fireworks_client.evaluators.create.call_args
@@ -532,8 +550,8 @@ async def test_math_correctness(row: EvaluationRow) -> EvaluationRow:
532550
assert "test_math_eval.py::test_math_correctness" in entry_point
533551

534552
# 5. VERIFY TAR.GZ WAS CREATED AND UPLOADED
535-
# Check getUploadEndpoint call payload
536-
upload_call = mock_fireworks_client.evaluators.get_upload_endpoint.call_args
553+
# Check getUploadEndpoint call payload (via evaluator_versions API)
554+
upload_call = mock_fireworks_client.evaluator_versions.get_upload_endpoint.call_args
537555
assert upload_call is not None
538556
filename_to_size = upload_call.kwargs.get("filename_to_size", {})
539557
assert filename_to_size, "Should have filename_to_size"
@@ -582,95 +600,3 @@ def test_create_tar_includes_dockerignored_files(tmp_path):
582600

583601
for expected_path in expected_paths:
584602
assert expected_path in names, f"Expected {expected_path} in archive"
585-
586-
587-
def test_ep_upload_force_flag_triggers_delete_flow(
588-
mock_env_variables,
589-
mock_gcs_upload,
590-
mock_platform_api_client,
591-
):
592-
"""
593-
Test that --force flag triggers the check/delete/recreate flow
594-
"""
595-
from eval_protocol.cli_commands.upload import upload_command, _discover_tests
596-
597-
test_content = """
598-
from eval_protocol.pytest import evaluation_test
599-
from eval_protocol.models import EvaluationRow
600-
601-
@evaluation_test(input_rows=[[EvaluationRow()]])
602-
async def test_force_eval(row: EvaluationRow) -> EvaluationRow:
603-
return row
604-
"""
605-
606-
test_project_dir, test_file_path = create_test_project_with_evaluation_test(test_content, "test_force.py")
607-
608-
original_cwd = os.getcwd()
609-
610-
try:
611-
os.chdir(test_project_dir)
612-
613-
# Mock the Fireworks client with evaluator existing (for force flow)
614-
with patch("eval_protocol.evaluation.Fireworks") as mock_fw_class:
615-
mock_client = MagicMock()
616-
mock_fw_class.return_value = mock_client
617-
618-
# Mock evaluators.get to return an existing evaluator (not raise NotFoundError)
619-
mock_existing_evaluator = MagicMock()
620-
mock_existing_evaluator.name = "accounts/test_account/evaluators/test-force"
621-
mock_client.evaluators.get.return_value = mock_existing_evaluator
622-
623-
# Mock evaluators.delete
624-
mock_client.evaluators.delete.return_value = None
625-
626-
# Mock evaluators.create response
627-
mock_create_response = MagicMock()
628-
mock_create_response.name = "accounts/test_account/evaluators/test-force"
629-
mock_client.evaluators.create.return_value = mock_create_response
630-
631-
# Mock get_upload_endpoint
632-
def get_upload_endpoint_side_effect(evaluator_id, filename_to_size):
633-
response = MagicMock()
634-
signed_urls = {}
635-
for filename in filename_to_size.keys():
636-
signed_urls[filename] = f"https://storage.googleapis.com/test-bucket/{filename}?signed=true"
637-
response.filename_to_signed_urls = signed_urls
638-
return response
639-
640-
mock_client.evaluators.get_upload_endpoint.side_effect = get_upload_endpoint_side_effect
641-
642-
# Mock validate_upload
643-
mock_client.evaluators.validate_upload.return_value = MagicMock()
644-
645-
discovered_tests = _discover_tests(test_project_dir)
646-
647-
args = argparse.Namespace(
648-
path=test_project_dir,
649-
entry=None,
650-
id="test-force",
651-
display_name=None,
652-
description=None,
653-
force=True, # Force flag enabled
654-
yes=True,
655-
)
656-
657-
with patch("eval_protocol.cli_commands.upload._prompt_select") as mock_select:
658-
mock_select.return_value = discovered_tests
659-
exit_code = upload_command(args)
660-
661-
assert exit_code == 0
662-
663-
# Verify check happened (evaluators.get was called)
664-
assert mock_client.evaluators.get.called, "Should check if evaluator exists"
665-
666-
# Verify delete happened (since evaluator existed)
667-
assert mock_client.evaluators.delete.called, "Should delete existing evaluator"
668-
669-
# Verify create happened after delete
670-
assert mock_client.evaluators.create.called, "Should create evaluator after delete"
671-
672-
finally:
673-
os.chdir(original_cwd)
674-
if test_project_dir in sys.path:
675-
sys.path.remove(test_project_dir)
676-
shutil.rmtree(test_project_dir, ignore_errors=True)

tests/test_evaluation.py

Lines changed: 18 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@ def test_create_evaluation_helper(monkeypatch):
4141

4242
# Track SDK calls
4343
create_called = False
44+
version_create_called = False
4445
upload_endpoint_called = False
4546
validate_called = False
4647

@@ -61,7 +62,16 @@ def mock_create(evaluator_id, evaluator):
6162
assert evaluator["description"] == "Test description"
6263
return mock_evaluator_result
6364

64-
def mock_get_upload_endpoint(evaluator_id, filename_to_size):
65+
# Mock evaluator_versions.create
66+
mock_version_result = MagicMock()
67+
mock_version_result.name = "accounts/test_account/evaluators/test-eval/versions/v1"
68+
69+
def mock_version_create(evaluator_id, evaluator_version):
70+
nonlocal version_create_called
71+
version_create_called = True
72+
return mock_version_result
73+
74+
def mock_get_upload_endpoint(evaluator_id, version_id, filename_to_size):
6575
nonlocal upload_endpoint_called
6676
upload_endpoint_called = True
6777
mock_response = MagicMock()
@@ -71,7 +81,7 @@ def mock_get_upload_endpoint(evaluator_id, filename_to_size):
7181
mock_response.filename_to_signed_urls = signed_urls
7282
return mock_response
7383

74-
def mock_validate_upload(evaluator_id, body):
84+
def mock_validate_upload(evaluator_id, version_id):
7585
nonlocal validate_called
7686
validate_called = True
7787
return MagicMock()
@@ -83,13 +93,14 @@ def mock_validate_upload(evaluator_id, body):
8393
mock_gcs_response.raise_for_status = MagicMock()
8494
mock_session.send.return_value = mock_gcs_response
8595

86-
# Patch the Fireworks client
87-
with patch("eval_protocol.evaluation.Fireworks") as mock_fireworks_class:
96+
# Patch the Fireworks client at the location where it's imported
97+
with patch("eval_protocol.fireworks_client.Fireworks") as mock_fireworks_class:
8898
mock_client = MagicMock()
8999
mock_fireworks_class.return_value = mock_client
90100
mock_client.evaluators.create = mock_create
91-
mock_client.evaluators.get_upload_endpoint = mock_get_upload_endpoint
92-
mock_client.evaluators.validate_upload = mock_validate_upload
101+
mock_client.evaluator_versions.create = mock_version_create
102+
mock_client.evaluator_versions.get_upload_endpoint = mock_get_upload_endpoint
103+
mock_client.evaluator_versions.validate_upload = mock_validate_upload
93104

94105
# Patch requests.Session for GCS upload
95106
monkeypatch.setattr("requests.Session", lambda: mock_session)
@@ -109,6 +120,7 @@ def mock_validate_upload(evaluator_id, body):
109120

110121
# Verify full upload flow was executed
111122
assert create_called, "Create endpoint should be called"
123+
assert version_create_called, "Version create should be called"
112124
assert upload_endpoint_called, "GetUploadEndpoint should be called"
113125
assert validate_called, "ValidateUpload should be called"
114126
assert mock_session.send.called, "GCS upload should happen"

0 commit comments

Comments
 (0)