Skip to content

Commit d43f4fb

Browse files
committed
Fix duplicate CLI shorthand flag registration
1 parent 7172373 commit d43f4fb

File tree

2 files changed

+98
-4
lines changed

2 files changed

+98
-4
lines changed

eval_protocol/cli_commands/utils.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,34 @@ def _add_flag(
705705
)
706706

707707

708+
def _resolve_available_flags(
709+
parser: argparse.ArgumentParser,
710+
primary_flag: str,
711+
alias_flags: list[str],
712+
reserved_flags: set[str],
713+
) -> list[str]:
714+
option_string_actions = getattr(parser, "_option_string_actions", {})
715+
if primary_flag in option_string_actions:
716+
raise ValueError(f"Conflicting primary option string: {primary_flag}")
717+
718+
flags = [primary_flag]
719+
seen = {primary_flag}
720+
721+
for alias in alias_flags:
722+
if alias in seen:
723+
continue
724+
seen.add(alias)
725+
726+
if alias in reserved_flags:
727+
continue
728+
if alias in option_string_actions:
729+
continue
730+
731+
flags.append(alias)
732+
733+
return flags
734+
735+
708736
def add_args_from_callable_signature(
709737
parser: argparse.ArgumentParser,
710738
fn: Callable[..., Any],
@@ -722,6 +750,7 @@ def add_args_from_callable_signature(
722750
sig = inspect.signature(fn)
723751
help = _parse_args_section_from_doc(inspect.getdoc(fn) or "")
724752
hints = typing.get_type_hints(fn, include_extras=True)
753+
arg_specs: list[tuple[str, list[str], Any, str | None]] = []
725754

726755
for name in sig.parameters.keys():
727756
resolved_type = unwrap_union(hints.get(name))
@@ -738,17 +767,21 @@ def add_args_from_callable_signature(
738767
prefix = name.replace("_", "-")
739768
field_kebab = field_name.replace("_", "-")
740769
flag_name = f"--{prefix}-{field_kebab}"
741-
flags = [flag_name] + aliases.get(f"{name}.{field_name}", []) + [f"--{field_kebab}"]
770+
alias_flags = aliases.get(f"{name}.{field_name}", []) + [f"--{field_kebab}"]
742771
help_text = help_overrides.get(f"{name}.{field_name}", field_help.get(field_name))
743772

744-
_add_flag(parser, flags, field_hints.get(field_name, field_type), help_text)
773+
arg_specs.append((flag_name, alias_flags, field_hints.get(field_name, field_type), help_text))
745774
continue
746775

747776
if name in top_level_skip:
748777
continue
749778

750779
flag_name = "--" + name.replace("_", "-")
751-
flags = [flag_name] + aliases.get(name, [])
752780
help_text = help_overrides.get(name, help.get(name))
781+
arg_specs.append((flag_name, aliases.get(name, []), hints.get(name), help_text))
782+
783+
reserved_flags = {primary_flag for primary_flag, _, _, _ in arg_specs}
753784

754-
_add_flag(parser, flags, hints.get(name), help_text)
785+
for primary_flag, alias_flags, hint, help_text in arg_specs:
786+
flags = _resolve_available_flags(parser, primary_flag, alias_flags, reserved_flags)
787+
_add_flag(parser, flags, hint, help_text)

tests/test_cli_args.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import argparse
12
import subprocess
23
import sys
4+
from typing import TypedDict
35

46
import pytest
57

68
# Module to be tested
79
from eval_protocol.cli import parse_args
10+
from eval_protocol.cli_commands.utils import add_args_from_callable_signature
811

912

1013
def test_unknown_flag_fails_fast(capsys):
@@ -41,3 +44,61 @@ def test_verbose_flag():
4144

4245
parsed_not_verbose, _ = parse_args(["upload", "--path", "."])
4346
assert parsed_not_verbose.verbose is False
47+
48+
49+
def test_add_args_skips_duplicate_nested_shorthand_flags():
50+
class AwsS3Config(TypedDict):
51+
credentials_secret: str
52+
53+
class AzureBlobStorageConfig(TypedDict):
54+
credentials_secret: str
55+
56+
def create(
57+
*,
58+
aws_s3_config: AwsS3Config | None = None,
59+
azure_blob_storage_config: AzureBlobStorageConfig | None = None,
60+
) -> None:
61+
return None
62+
63+
parser = argparse.ArgumentParser()
64+
add_args_from_callable_signature(parser, create)
65+
66+
parsed = parser.parse_args(
67+
[
68+
"--credentials-secret",
69+
"aws-short",
70+
"--azure-blob-storage-config-credentials-secret",
71+
"azure-prefixed",
72+
]
73+
)
74+
75+
assert parsed.aws_s3_config_credentials_secret == "aws-short"
76+
assert parsed.azure_blob_storage_config_credentials_secret == "azure-prefixed"
77+
assert "--credentials-secret" in parser._option_string_actions
78+
79+
80+
def test_add_args_preserves_top_level_canonical_flag_over_nested_shorthand():
81+
class InferenceParameters(TypedDict):
82+
extra_body: str
83+
84+
def create(
85+
*,
86+
inference_parameters: InferenceParameters | None = None,
87+
extra_body: str | None = None,
88+
) -> None:
89+
return None
90+
91+
parser = argparse.ArgumentParser()
92+
add_args_from_callable_signature(parser, create)
93+
94+
parsed = parser.parse_args(
95+
[
96+
"--extra-body",
97+
"top-level",
98+
"--inference-parameters-extra-body",
99+
"nested",
100+
]
101+
)
102+
103+
assert parsed.extra_body == "top-level"
104+
assert parsed.inference_parameters_extra_body == "nested"

0 commit comments

Comments
 (0)