Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions eval_protocol/cli_commands/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,34 @@ def _add_flag(
)


def _resolve_available_flags(
    parser: argparse.ArgumentParser,
    primary_flag: str,
    alias_flags: list[str],
    reserved_flags: set[str],
) -> list[str]:
    """Return the option strings that can be registered for one argument.

    The primary flag always comes first in the result. Each alias follows in
    its original order, unless it repeats the primary flag or an earlier
    alias, appears in ``reserved_flags``, or is already registered on
    ``parser``.

    Args:
        parser: Parser whose already-registered option strings are consulted.
        primary_flag: Canonical option string for the argument.
        alias_flags: Candidate additional option strings, in priority order.
        reserved_flags: Option strings that must never be used as an alias.

    Returns:
        ``[primary_flag, *usable_aliases]``.

    Raises:
        ValueError: If ``primary_flag`` is already registered on ``parser``.
    """
    # argparse keeps its option-string registry on a private attribute; fall
    # back to an empty mapping if the parser object does not expose it.
    registered = getattr(parser, "_option_string_actions", {})
    if primary_flag in registered:
        raise ValueError(f"Conflicting primary option string: {primary_flag}")

    # dict.fromkeys de-duplicates while keeping first-seen order.
    usable_aliases = [
        candidate
        for candidate in dict.fromkeys(alias_flags)
        if candidate != primary_flag
        and candidate not in reserved_flags
        and candidate not in registered
    ]
    return [primary_flag, *usable_aliases]


def add_args_from_callable_signature(
parser: argparse.ArgumentParser,
fn: Callable[..., Any],
Expand All @@ -722,6 +750,7 @@ def add_args_from_callable_signature(
sig = inspect.signature(fn)
help = _parse_args_section_from_doc(inspect.getdoc(fn) or "")
hints = typing.get_type_hints(fn, include_extras=True)
arg_specs: list[tuple[str, list[str], Any, str | None]] = []

for name in sig.parameters.keys():
resolved_type = unwrap_union(hints.get(name))
Expand All @@ -738,17 +767,21 @@ def add_args_from_callable_signature(
prefix = name.replace("_", "-")
field_kebab = field_name.replace("_", "-")
flag_name = f"--{prefix}-{field_kebab}"
flags = [flag_name] + aliases.get(f"{name}.{field_name}", []) + [f"--{field_kebab}"]
alias_flags = aliases.get(f"{name}.{field_name}", []) + [f"--{field_kebab}"]
help_text = help_overrides.get(f"{name}.{field_name}", field_help.get(field_name))

_add_flag(parser, flags, field_hints.get(field_name, field_type), help_text)
arg_specs.append((flag_name, alias_flags, field_hints.get(field_name, field_type), help_text))
continue

if name in top_level_skip:
continue

flag_name = "--" + name.replace("_", "-")
flags = [flag_name] + aliases.get(name, [])
help_text = help_overrides.get(name, help.get(name))
arg_specs.append((flag_name, aliases.get(name, []), hints.get(name), help_text))

reserved_flags = {primary_flag for primary_flag, _, _, _ in arg_specs}

_add_flag(parser, flags, hints.get(name), help_text)
for primary_flag, alias_flags, hint, help_text in arg_specs:
flags = _resolve_available_flags(parser, primary_flag, alias_flags, reserved_flags)
_add_flag(parser, flags, hint, help_text)
37 changes: 31 additions & 6 deletions tests/test_adapters_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,31 @@
from typing import Any, Dict

import pytest
import requests

from eval_protocol.models import EvaluationRow, InputMetadata, Message


def _skip_on_external_dataset_network_error(exc: Exception, provider: str) -> None:
    """Skip the current test if ``exc`` looks like an upstream network failure.

    The test is skipped when ``exc`` is a ``requests`` exception or when its
    message contains one of the known transient-failure phrases
    (case-insensitive). Otherwise the function returns normally, so the
    caller can re-raise the original error.

    Args:
        exc: The exception raised while talking to the external dataset service.
        provider: Human-readable service name used in the skip message.
    """
    transient_markers = (
        "read timed out",
        "connection reset",
        "connection aborted",
        "temporary failure",
        "failed to establish a new connection",
        "max retries exceeded",
        "name or service not known",
        "service unavailable",
    )

    looks_transient = isinstance(exc, requests.exceptions.RequestException)
    if not looks_transient:
        # Only inspect the message text when the exception type alone was inconclusive.
        lowered_message = str(exc).lower()
        looks_transient = any(marker in lowered_message for marker in transient_markers)

    if looks_transient:
        pytest.skip(f"Skipping due to {provider} network issue: {exc}")


class TestLangfuseAdapterE2E:
"""End-to-end tests for Langfuse adapter with real deployment."""

Expand Down Expand Up @@ -295,13 +316,17 @@ def math_transform(row: Dict[str, Any]) -> Dict[str, Any]:
}

# Create adapter
adapter = create_huggingface_adapter(
dataset_id="SuperSecureHuman/competition_math_hf_dataset",
transform_fn=math_transform,
)
try:
adapter = create_huggingface_adapter(
dataset_id="SuperSecureHuman/competition_math_hf_dataset",
transform_fn=math_transform,
)

# Test loading data
rows = list(adapter.get_evaluation_rows(split="test", limit=3))
# Test loading data
rows = list(adapter.get_evaluation_rows(split="test", limit=3))
except Exception as exc:
_skip_on_external_dataset_network_error(exc, provider="Hugging Face")
raise

# Verify data
assert len(rows) > 0, "Should retrieve MATH dataset data"
Expand Down
61 changes: 61 additions & 0 deletions tests/test_cli_args.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import argparse
import subprocess
import sys
from typing import TypedDict

import pytest

# Module to be tested
from eval_protocol.cli import parse_args
from eval_protocol.cli_commands.utils import add_args_from_callable_signature


def test_unknown_flag_fails_fast(capsys):
Expand Down Expand Up @@ -41,3 +44,61 @@ def test_verbose_flag():

parsed_not_verbose, _ = parse_args(["upload", "--path", "."])
assert parsed_not_verbose.verbose is False


def test_add_args_skips_duplicate_nested_shorthand_flags():
    """Two nested configs with the same field name must not collide on the shorthand flag."""

    class AwsS3Config(TypedDict):
        credentials_secret: str

    class AzureBlobStorageConfig(TypedDict):
        credentials_secret: str

    def create(
        *,
        aws_s3_config: AwsS3Config | None = None,
        azure_blob_storage_config: AzureBlobStorageConfig | None = None,
    ) -> None:
        return None

    cli_parser = argparse.ArgumentParser()
    add_args_from_callable_signature(cli_parser, create)

    # The shorthand is expected to bind to the first config that declares the
    # field; the second config stays reachable through its prefixed flag.
    argv = [
        "--credentials-secret",
        "aws-short",
        "--azure-blob-storage-config-credentials-secret",
        "azure-prefixed",
    ]
    namespace = cli_parser.parse_args(argv)

    assert namespace.aws_s3_config_credentials_secret == "aws-short"
    assert namespace.azure_blob_storage_config_credentials_secret == "azure-prefixed"
    assert "--credentials-secret" in cli_parser._option_string_actions


def test_add_args_preserves_top_level_canonical_flag_over_nested_shorthand():
    """A top-level parameter keeps its canonical flag even if a nested field shares the name."""

    class InferenceParameters(TypedDict):
        extra_body: str

    def create(
        *,
        inference_parameters: InferenceParameters | None = None,
        extra_body: str | None = None,
    ) -> None:
        return None

    cli_parser = argparse.ArgumentParser()
    add_args_from_callable_signature(cli_parser, create)

    # `--extra-body` must resolve to the top-level parameter; the nested field
    # is only reachable through its fully prefixed flag.
    argv = [
        "--extra-body",
        "top-level",
        "--inference-parameters-extra-body",
        "nested",
    ]
    namespace = cli_parser.parse_args(argv)

    assert namespace.extra_body == "top-level"
    assert namespace.inference_parameters_extra_body == "nested"
Loading