Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions spikee/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def _do_single_request(
guardrail = True
if hasattr(gt, "categories"):
guardrail_categories = gt.categories
print("[Guardrail Triggered] {}: {}".format(entry["id"], error_message))
# print("[Guardrail Triggered] {}: {}".format(entry["id"], error_message))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this mean every time a guardrail is triggered we now print this on the console? Or am I reading this wrong?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It used to. I'm commenting it out so that it is now only displayed within the progress bar, but is still available for debugging.


except MultiTurnSkip as ms:
error_message = str(ms)
Expand Down Expand Up @@ -720,6 +720,7 @@ def _run_threaded(
total_dataset_size,
initial_processed,
initial_success,
initial_guardrail
):
lock = threading.Lock()
bar_all = tqdm(
Expand All @@ -731,7 +732,11 @@ def _run_threaded(
position=0,
initial=initial_processed,
)
bar_entries.set_postfix(success=initial_success)
if initial_guardrail > 0:
bar_entries.set_postfix(success=initial_success, guardrails=initial_guardrail)
else:
bar_entries.set_postfix(success=initial_success)

executor = ThreadPoolExecutor(max_workers=num_threads)
futures = {
executor.submit(
Expand All @@ -750,6 +755,7 @@ def _run_threaded(
for entry in entries
}
success = initial_success
guardrail = initial_guardrail
try:
for fut in as_completed(futures):
entry = futures[fut]
Expand All @@ -758,12 +764,18 @@ def _run_threaded(
if isinstance(res, list):
for r in res:
success += int(r.get("success", False))
guardrail += int(r.get("guardrail", False))
append_jsonl_entry(output_file, r, lock)
else:
success += int(res.get("success", False))
guardrail += int(res.get("guardrail", False))
append_jsonl_entry(output_file, res, lock)
bar_entries.update(1)
bar_entries.set_postfix(success=success)

if guardrail > 0:
bar_entries.set_postfix(success=success, guardrails=guardrail)
else:
bar_entries.set_postfix(success=success)
except Exception as e:
print(f"[Error] Entry ID {entry['id']}: {e}")
traceback.print_exc()
Expand Down Expand Up @@ -915,6 +927,8 @@ def test_dataset(args):
print(f"[Info] Output will be saved to: {output_file}")

success_count = sum(1 for r in results if r.get("success"))
guardrail_count = sum(1 for r in results if r.get("guardrail"))

_run_threaded(
to_process,
target_module,
Expand All @@ -930,6 +944,7 @@ def test_dataset(args):
len(dataset_json),
len(completed_ids),
success_count,
guardrail_count
)

print(f"[Done] Testing finished. Results saved to {output_file}")
Expand Down