diff --git a/omnigent/spec/parser.py b/omnigent/spec/parser.py
index 8f53ea4ff..5b89d1ad9 100644
--- a/omnigent/spec/parser.py
+++ b/omnigent/spec/parser.py
@@ -112,6 +112,61 @@ class _ConfigYamlLoader(yaml.SafeLoader):
     )
 
 
+def _parse_int_field(raw: object, field_name: str) -> int:
+    """
+    Coerce an integer config field while rejecting YAML booleans.
+
+    Python treats ``bool`` as a subclass of ``int``. Without this guard,
+    values like ``false`` silently become ``0`` for fields such as
+    ``executor.max_iterations``.
+
+    :param raw: Value taken from the parsed config mapping.
+    :param field_name: Field path interpolated into the error message.
+    :returns: The result of ``int(raw)``.
+    :raises OmnigentError: If ``raw`` is a boolean or ``int()`` rejects it.
+    """
+    if isinstance(raw, bool):
+        raise OmnigentError(
+            f"{field_name} must be an integer, got boolean {raw!r}",
+            code=ErrorCode.INVALID_INPUT,
+        )
+    # ``int(float("inf"))`` raises OverflowError, not ValueError.
+    try:
+        return int(raw)
+    except (TypeError, ValueError, OverflowError) as exc:
+        raise OmnigentError(
+            f"{field_name} must be an integer, got {raw!r}",
+            code=ErrorCode.INVALID_INPUT,
+        ) from exc
+
+
+def _parse_float_field(raw: object, field_name: str) -> float:
+    """
+    Coerce a numeric config field while rejecting YAML booleans.
+
+    ``float(True)`` becomes ``1.0`` in Python, which is not a useful
+    interpretation for timing and threshold fields.
+
+    :param raw: Value taken from the parsed config mapping.
+    :param field_name: Field path interpolated into the error message.
+    :returns: The result of ``float(raw)``.
+    :raises OmnigentError: If ``raw`` is a boolean or ``float()`` rejects it.
+    """
+    if isinstance(raw, bool):
+        raise OmnigentError(
+            f"{field_name} must be a number, got boolean {raw!r}",
+            code=ErrorCode.INVALID_INPUT,
+        )
+    # ``float()`` raises OverflowError for integers too large to represent.
+    try:
+        return float(raw)
+    except (TypeError, ValueError, OverflowError) as exc:
+        raise OmnigentError(
+            f"{field_name} must be a number, got {raw!r}",
+            code=ErrorCode.INVALID_INPUT,
+        ) from exc
+
+
 def parse(root: Path, *, expand_env: bool = True) -> AgentSpec:
     """
     Parse an agent image directory into an :class:`AgentSpec`.
@@ -299,7 +342,11 @@ def _parse_llm(
     connection = expand_env_vars(raw_dict) if expand_env else raw_dict
     profile_raw = raw.get("profile")
     profile = str(profile_raw) if profile_raw is not None else None
-    request_timeout = int(raw["request_timeout"]) if "request_timeout" in raw else 300
+    request_timeout = (
+        _parse_int_field(raw["request_timeout"], "llm.request_timeout")
+        if "request_timeout" in raw
+        else 300
+    )
     retry = _parse_retry(raw.get("retry"))
     reserved = {"model", "connection", "profile", "request_timeout", "retry"}
     extra = {k: v for k, v in raw.items() if k not in reserved}
@@ -360,7 +407,7 @@ def _parse_tools_config(
     """
     if raw is None:
         return ToolsConfig()
-    timeout = int(raw["timeout"]) if "timeout" in raw else 60
+    timeout = _parse_int_field(raw["timeout"], "tools.timeout") if "timeout" in raw else 60
     retry = _parse_retry(raw.get("retry"))
     builtins = _parse_builtin_tools(raw.get("builtins", []))
     sandbox = _parse_sandbox_config(raw.get("sandbox"))
@@ -469,17 +516,27 @@ def _parse_retry(
         return RetryPolicy()
     defaults = RetryPolicy()
     return RetryPolicy(
-        max_retries=int(raw.get("max_retries", defaults.max_retries)),
-        backoff_base_s=float(raw.get("backoff_base_s", defaults.backoff_base_s)),
-        backoff_max_s=float(raw.get("backoff_max_s", defaults.backoff_max_s)),
+        max_retries=_parse_int_field(
+            raw.get("max_retries", defaults.max_retries),
+            "retry.max_retries",
+        ),
+        backoff_base_s=_parse_float_field(
+            raw.get("backoff_base_s", defaults.backoff_base_s),
+            "retry.backoff_base_s",
+        ),
+        backoff_max_s=_parse_float_field(
+            raw.get("backoff_max_s", defaults.backoff_max_s),
+            "retry.backoff_max_s",
+        ),
         jitter=bool(raw.get("jitter", defaults.jitter)),
         timeout_per_request_s=(
-            float(raw["timeout_per_request_s"])
+            _parse_float_field(raw["timeout_per_request_s"], "retry.timeout_per_request_s")
             if raw.get("timeout_per_request_s") is not None
             else defaults.timeout_per_request_s
         ),
         retryable_status_codes=tuple(
-            int(c) for c in raw.get("retryable_status_codes", defaults.retryable_status_codes)
+            _parse_int_field(c, "retry.retryable_status_codes")
+            for c in raw.get("retryable_status_codes", defaults.retryable_status_codes)
         ),
     )
 
@@ -534,7 +591,9 @@ def _parse_executor(
     if etype == "omnigent" and profile is not None and "profile" not in config:
         config["profile"] = profile
     raw_cw = raw.get("context_window")
-    context_window: int | None = int(raw_cw) if raw_cw is not None else None
+    context_window: int | None = (
+        _parse_int_field(raw_cw, "executor.context_window") if raw_cw is not None else None
+    )
     raw_model = raw.get("model")
     model: str | None = str(raw_model) if raw_model is not None else None
     # Parse ``executor.connection:`` — same shape as ``llm.connection:``
@@ -549,8 +608,11 @@ def _parse_executor(
     auth = _parse_executor_auth(raw, expand_env=expand_env)
     return ExecutorSpec(
         type=etype,
-        timeout=int(raw.get("timeout", 3600)),
-        max_iterations=int(raw.get("max_iterations", 1000)),
+        timeout=_parse_int_field(raw.get("timeout", 3600), "executor.timeout"),
+        max_iterations=_parse_int_field(
+            raw.get("max_iterations", 1000),
+            "executor.max_iterations",
+        ),
         profile=profile,
         config=config,
         model=model,
@@ -766,7 +828,10 @@ def _parse_terminals(
             allow_cwd_override=bool(entry.get("allow_cwd_override", False)),
             allow_sandbox_override=bool(entry.get("allow_sandbox_override", False)),
             log_file=entry.get("log_file"),
-            scrollback=int(entry.get("scrollback", 10000)),
+            scrollback=_parse_int_field(
+                entry.get("scrollback", 10000),
+                f"terminals.{name}.scrollback",
+            ),
             session_prefix=str(entry.get("session_prefix", "omni_")),
             tmux_allow_passthrough=bool(entry.get("tmux_allow_passthrough", False)),
             tmux_start_on_attach=bool(entry.get("tmux_start_on_attach", False)),
@@ -1671,8 +1736,14 @@ def _parse_compaction(
     if raw is None:
         return None
     return CompactionConfig(
-        trigger_threshold=float(raw.get("trigger_threshold", 0.8)),
-        recent_window=int(raw.get("recent_window", 5)),
+        trigger_threshold=_parse_float_field(
+            raw.get("trigger_threshold", 0.8),
+            "compaction.trigger_threshold",
+        ),
+        recent_window=_parse_int_field(
+            raw.get("recent_window", 5),
+            "compaction.recent_window",
+        ),
     )
 
 
@@ -2312,7 +2383,11 @@ def _parse_http_mcp_server(
             expand_env_vars(raw.get("headers", {})) if expand_env else raw.get("headers", {})
         ),
         description=raw.get("description"),
-        timeout=int(raw["timeout"]) if "timeout" in raw else None,
+        timeout=(
+            _parse_int_field(raw["timeout"], f"MCP server {name!r}.timeout")
+            if "timeout" in raw
+            else None
+        ),
         retry=_parse_retry(raw["retry"]) if "retry" in raw else None,
     )
 
@@ -2404,7 +2479,11 @@ def _parse_stdio_mcp_server(
         args=[str(a) for a in raw_args],
         env={str(k): str(v) for k, v in env.items()},
         description=raw.get("description"),
-        timeout=int(raw["timeout"]) if "timeout" in raw else None,
+        timeout=(
+            _parse_int_field(raw["timeout"], f"MCP server {name!r}.timeout")
+            if "timeout" in raw
+            else None
+        ),
         retry=_parse_retry(raw["retry"]) if "retry" in raw else None,
     )
 
@@ -2578,13 +2657,7 @@ def _parse_guardrails_ask_timeout(raw: Any) -> int:
 
     :raises OmnigentError: On non-integer or non-positive values.
     """
-    try:
-        value = int(raw)
-    except (TypeError, ValueError) as exc:
-        raise OmnigentError(
-            f"guardrails.ask_timeout must be an integer, got {raw!r}",
-            code=ErrorCode.INVALID_INPUT,
-        ) from exc
+    value = _parse_int_field(raw, "guardrails.ask_timeout")
     if value <= 0:
         raise OmnigentError(
             "guardrails.ask_timeout must be > 0 "
@@ -3250,13 +3323,7 @@ def _parse_policy_ask_timeout(
     """
     if raw is None:
         return None
-    try:
-        value = int(raw)
-    except (TypeError, ValueError) as exc:
-        raise OmnigentError(
-            f"policy {policy_name!r}: `ask_timeout` must be an integer, got {raw!r}",
-            code=ErrorCode.INVALID_INPUT,
-        ) from exc
+    value = _parse_int_field(raw, f"policy {policy_name!r}: `ask_timeout`")
     if value <= 0:
         raise OmnigentError(
             f"policy {policy_name!r}: `ask_timeout` must be > 0 "
diff --git a/tests/spec/test_parser.py b/tests/spec/test_parser.py
index 6127418d2..4d62b78dd 100644
--- a/tests/spec/test_parser.py
+++ b/tests/spec/test_parser.py
@@ -2050,6 +2050,98 @@ def test_parse_executor_defaults(tmp_path: Path) -> None:
     assert spec.executor.type == "omnigent"
 
 
+@pytest.mark.parametrize(
+    ("config", "match"),
+    [
+        (
+            {"llm": {"model": "openai/gpt-4o", "request_timeout": True}},
+            r"llm\.request_timeout must be an integer",
+        ),
+        (
+            {"tools": {"timeout": False}},
+            r"tools\.timeout must be an integer",
+        ),
+        (
+            {"llm": {"model": "openai/gpt-4o", "retry": {"max_retries": True}}},
+            r"retry\.max_retries must be an integer",
+        ),
+        (
+            {"llm": {"model": "openai/gpt-4o", "retry": {"backoff_base_s": False}}},
+            r"retry\.backoff_base_s must be a number",
+        ),
+        (
+            {"llm": {"model": "openai/gpt-4o", "retry": {"backoff_max_s": True}}},
+            r"retry\.backoff_max_s must be a number",
+        ),
+        (
+            {"llm": {"model": "openai/gpt-4o", "retry": {"timeout_per_request_s": False}}},
+            r"retry\.timeout_per_request_s must be a number",
+        ),
+        (
+            {
+                "llm": {
+                    "model": "openai/gpt-4o",
+                    "retry": {"retryable_status_codes": [429, True]},
+                }
+            },
+            r"retry\.retryable_status_codes must be an integer",
+        ),
+        (
+            {"executor": {"timeout": True}},
+            r"executor\.timeout must be an integer",
+        ),
+        (
+            {"executor": {"max_iterations": False}},
+            r"executor\.max_iterations must be an integer",
+        ),
+        (
+            {"executor": {"context_window": True}},
+            r"executor\.context_window must be an integer",
+        ),
+        (
+            {"compaction": {"recent_window": False}},
+            r"compaction\.recent_window must be an integer",
+        ),
+        (
+            {"compaction": {"trigger_threshold": True}},
+            r"compaction\.trigger_threshold must be a number",
+        ),
+        (
+            {"guardrails": {"ask_timeout": True}},
+            r"guardrails\.ask_timeout must be an integer",
+        ),
+    ],
+)
+def test_parse_rejects_boolean_values_for_numeric_config_fields(
+    tmp_path: Path,
+    config: dict[str, object],
+    match: str,
+) -> None:
+    """Boolean YAML values must not be accepted as numeric config."""
+    config = {"spec_version": 1, **config}
+    (tmp_path / "config.yaml").write_text(yaml.dump(config))
+
+    with pytest.raises(OmnigentError, match=match):
+        parse(tmp_path)
+
+
+def test_parse_rejects_boolean_terminal_scrollback(tmp_path: Path) -> None:
+    """Terminal scrollback is a line count, not a boolean flag."""
+    config = {
+        "spec_version": 1,
+        "terminals": {
+            "main": {
+                "command": "bash",
+                "scrollback": False,
+            },
+        },
+    }
+    (tmp_path / "config.yaml").write_text(yaml.dump(config))
+
+    with pytest.raises(OmnigentError, match=r"terminals\.main\.scrollback must be an integer"):
+        parse(tmp_path)
+
+
 def test_parse_executor_config_field(tmp_path: Path) -> None:
     """Executor block with a ``config`` sub-block parses string values.
 
@@ -2127,6 +2211,26 @@ def test_parse_mcp_server_with_timeout_and_retry(
     assert mcp.retry.max_retries == 7
 
 
+def test_parse_rejects_boolean_mcp_timeout(agent_dir: Path) -> None:
+    """MCP timeout is a duration in seconds, not a boolean flag."""
+    mcp_dir = agent_dir / "tools" / "mcp"
+    mcp_dir.mkdir(parents=True)
+    mcp_config = {
+        "name": "slow-service",
+        "transport": "http",
+        "url": "http://localhost:9000/mcp",
+        "timeout": True,
+    }
+    (mcp_dir / "slow.yaml").write_text(yaml.dump(mcp_config))
+
+    # Pin the full message so an unrelated error naming this field cannot pass.
+    with pytest.raises(
+        OmnigentError,
+        match=r"MCP server 'slow-service'\.timeout must be an integer",
+    ):
+        parse(agent_dir)
+
+
 def test_parse_mcp_stdio_minimal(agent_dir: Path) -> None:
     """
     Parse a stdio MCP server with only the required ``command``.