Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 96 additions & 29 deletions omnigent/spec/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,49 @@ class _ConfigYamlLoader(yaml.SafeLoader):
)


def _parse_int_field(raw: object, field_name: str) -> int:
"""
Coerce an integer config field while rejecting YAML booleans.

Python treats ``bool`` as a subclass of ``int``. Without this guard,
values like ``false`` silently become ``0`` for fields such as
``executor.max_iterations``.
"""
if isinstance(raw, bool):
raise OmnigentError(
f"{field_name} must be an integer, got boolean {raw!r}",
code=ErrorCode.INVALID_INPUT,
)
try:
return int(raw)
except (TypeError, ValueError) as exc:
raise OmnigentError(
f"{field_name} must be an integer, got {raw!r}",
code=ErrorCode.INVALID_INPUT,
) from exc


def _parse_float_field(raw: object, field_name: str) -> float:
"""
Coerce a numeric config field while rejecting YAML booleans.

``float(True)`` becomes ``1.0`` in Python, which is not a useful
interpretation for timing and threshold fields.
"""
if isinstance(raw, bool):
raise OmnigentError(
f"{field_name} must be a number, got boolean {raw!r}",
code=ErrorCode.INVALID_INPUT,
)
try:
return float(raw)
except (TypeError, ValueError) as exc:
raise OmnigentError(
f"{field_name} must be a number, got {raw!r}",
code=ErrorCode.INVALID_INPUT,
) from exc


def parse(root: Path, *, expand_env: bool = True) -> AgentSpec:
"""
Parse an agent image directory into an :class:`AgentSpec`.
Expand Down Expand Up @@ -299,7 +342,11 @@ def _parse_llm(
connection = expand_env_vars(raw_dict) if expand_env else raw_dict
profile_raw = raw.get("profile")
profile = str(profile_raw) if profile_raw is not None else None
request_timeout = int(raw["request_timeout"]) if "request_timeout" in raw else 300
request_timeout = (
_parse_int_field(raw["request_timeout"], "llm.request_timeout")
if "request_timeout" in raw
else 300
)
retry = _parse_retry(raw.get("retry"))
reserved = {"model", "connection", "profile", "request_timeout", "retry"}
extra = {k: v for k, v in raw.items() if k not in reserved}
Expand Down Expand Up @@ -360,7 +407,7 @@ def _parse_tools_config(
"""
if raw is None:
return ToolsConfig()
timeout = int(raw["timeout"]) if "timeout" in raw else 60
timeout = _parse_int_field(raw["timeout"], "tools.timeout") if "timeout" in raw else 60
retry = _parse_retry(raw.get("retry"))
builtins = _parse_builtin_tools(raw.get("builtins", []))
sandbox = _parse_sandbox_config(raw.get("sandbox"))
Expand Down Expand Up @@ -469,17 +516,27 @@ def _parse_retry(
return RetryPolicy()
defaults = RetryPolicy()
return RetryPolicy(
max_retries=int(raw.get("max_retries", defaults.max_retries)),
backoff_base_s=float(raw.get("backoff_base_s", defaults.backoff_base_s)),
backoff_max_s=float(raw.get("backoff_max_s", defaults.backoff_max_s)),
max_retries=_parse_int_field(
raw.get("max_retries", defaults.max_retries),
"retry.max_retries",
),
backoff_base_s=_parse_float_field(
raw.get("backoff_base_s", defaults.backoff_base_s),
"retry.backoff_base_s",
),
backoff_max_s=_parse_float_field(
raw.get("backoff_max_s", defaults.backoff_max_s),
"retry.backoff_max_s",
),
jitter=bool(raw.get("jitter", defaults.jitter)),
timeout_per_request_s=(
float(raw["timeout_per_request_s"])
_parse_float_field(raw["timeout_per_request_s"], "retry.timeout_per_request_s")
if raw.get("timeout_per_request_s") is not None
else defaults.timeout_per_request_s
),
retryable_status_codes=tuple(
int(c) for c in raw.get("retryable_status_codes", defaults.retryable_status_codes)
_parse_int_field(c, "retry.retryable_status_codes")
for c in raw.get("retryable_status_codes", defaults.retryable_status_codes)
),
)

Expand Down Expand Up @@ -534,7 +591,9 @@ def _parse_executor(
if etype == "omnigent" and profile is not None and "profile" not in config:
config["profile"] = profile
raw_cw = raw.get("context_window")
context_window: int | None = int(raw_cw) if raw_cw is not None else None
context_window: int | None = (
_parse_int_field(raw_cw, "executor.context_window") if raw_cw is not None else None
)
raw_model = raw.get("model")
model: str | None = str(raw_model) if raw_model is not None else None
# Parse ``executor.connection:`` — same shape as ``llm.connection:``
Expand All @@ -549,8 +608,11 @@ def _parse_executor(
auth = _parse_executor_auth(raw, expand_env=expand_env)
return ExecutorSpec(
type=etype,
timeout=int(raw.get("timeout", 3600)),
max_iterations=int(raw.get("max_iterations", 1000)),
timeout=_parse_int_field(raw.get("timeout", 3600), "executor.timeout"),
max_iterations=_parse_int_field(
raw.get("max_iterations", 1000),
"executor.max_iterations",
),
profile=profile,
config=config,
model=model,
Expand Down Expand Up @@ -766,7 +828,10 @@ def _parse_terminals(
allow_cwd_override=bool(entry.get("allow_cwd_override", False)),
allow_sandbox_override=bool(entry.get("allow_sandbox_override", False)),
log_file=entry.get("log_file"),
scrollback=int(entry.get("scrollback", 10000)),
scrollback=_parse_int_field(
entry.get("scrollback", 10000),
f"terminals.{name}.scrollback",
),
session_prefix=str(entry.get("session_prefix", "omni_")),
tmux_allow_passthrough=bool(entry.get("tmux_allow_passthrough", False)),
tmux_start_on_attach=bool(entry.get("tmux_start_on_attach", False)),
Expand Down Expand Up @@ -1671,8 +1736,14 @@ def _parse_compaction(
if raw is None:
return None
return CompactionConfig(
trigger_threshold=float(raw.get("trigger_threshold", 0.8)),
recent_window=int(raw.get("recent_window", 5)),
trigger_threshold=_parse_float_field(
raw.get("trigger_threshold", 0.8),
"compaction.trigger_threshold",
),
recent_window=_parse_int_field(
raw.get("recent_window", 5),
"compaction.recent_window",
),
)


Expand Down Expand Up @@ -2312,7 +2383,11 @@ def _parse_http_mcp_server(
expand_env_vars(raw.get("headers", {})) if expand_env else raw.get("headers", {})
),
description=raw.get("description"),
timeout=int(raw["timeout"]) if "timeout" in raw else None,
timeout=(
_parse_int_field(raw["timeout"], f"MCP server {name!r}.timeout")
if "timeout" in raw
else None
),
retry=_parse_retry(raw["retry"]) if "retry" in raw else None,
)

Expand Down Expand Up @@ -2404,7 +2479,11 @@ def _parse_stdio_mcp_server(
args=[str(a) for a in raw_args],
env={str(k): str(v) for k, v in env.items()},
description=raw.get("description"),
timeout=int(raw["timeout"]) if "timeout" in raw else None,
timeout=(
_parse_int_field(raw["timeout"], f"MCP server {name!r}.timeout")
if "timeout" in raw
else None
),
retry=_parse_retry(raw["retry"]) if "retry" in raw else None,
)

Expand Down Expand Up @@ -2578,13 +2657,7 @@ def _parse_guardrails_ask_timeout(raw: Any) -> int:
:raises OmnigentError: On non-integer or non-positive
values.
"""
try:
value = int(raw)
except (TypeError, ValueError) as exc:
raise OmnigentError(
f"guardrails.ask_timeout must be an integer, got {raw!r}",
code=ErrorCode.INVALID_INPUT,
) from exc
value = _parse_int_field(raw, "guardrails.ask_timeout")
if value <= 0:
raise OmnigentError(
"guardrails.ask_timeout must be > 0 "
Expand Down Expand Up @@ -3250,13 +3323,7 @@ def _parse_policy_ask_timeout(
"""
if raw is None:
return None
try:
value = int(raw)
except (TypeError, ValueError) as exc:
raise OmnigentError(
f"policy {policy_name!r}: `ask_timeout` must be an integer, got {raw!r}",
code=ErrorCode.INVALID_INPUT,
) from exc
value = _parse_int_field(raw, f"policy {policy_name!r}: `ask_timeout`")
if value <= 0:
raise OmnigentError(
f"policy {policy_name!r}: `ask_timeout` must be > 0 "
Expand Down
100 changes: 100 additions & 0 deletions tests/spec/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2050,6 +2050,90 @@ def test_parse_executor_defaults(tmp_path: Path) -> None:
assert spec.executor.type == "omnigent"


@pytest.mark.parametrize(
    ("config", "match"),
    [
        (
            {"llm": {"model": "openai/gpt-4o", "request_timeout": True}},
            r"llm\.request_timeout must be an integer",
        ),
        (
            {"tools": {"timeout": False}},
            r"tools\.timeout must be an integer",
        ),
        (
            {"llm": {"model": "openai/gpt-4o", "retry": {"max_retries": True}}},
            r"retry\.max_retries must be an integer",
        ),
        (
            {"llm": {"model": "openai/gpt-4o", "retry": {"backoff_base_s": False}}},
            r"retry\.backoff_base_s must be a number",
        ),
        (
            {
                "llm": {
                    "model": "openai/gpt-4o",
                    "retry": {"retryable_status_codes": [429, True]},
                }
            },
            r"retry\.retryable_status_codes must be an integer",
        ),
        (
            {"executor": {"timeout": True}},
            r"executor\.timeout must be an integer",
        ),
        (
            {"executor": {"max_iterations": False}},
            r"executor\.max_iterations must be an integer",
        ),
        (
            {"executor": {"context_window": True}},
            r"executor\.context_window must be an integer",
        ),
        (
            {"compaction": {"recent_window": False}},
            r"compaction\.recent_window must be an integer",
        ),
        (
            {"compaction": {"trigger_threshold": True}},
            r"compaction\.trigger_threshold must be a number",
        ),
        (
            {"guardrails": {"ask_timeout": True}},
            r"guardrails\.ask_timeout must be an integer",
        ),
    ],
)
def test_parse_rejects_boolean_values_for_numeric_config_fields(
    tmp_path: Path,
    config: dict[str, object],
    match: str,
) -> None:
    """Each case puts a boolean in a numeric field; parsing must raise.

    ``match`` is the regex the raised ``OmnigentError`` message has to
    satisfy, so the test also pins which field is named in the error.
    """
    # Start from the version key, then layer the case-specific section on top.
    document: dict[str, object] = {"spec_version": 1}
    document.update(config)

    config_file = tmp_path / "config.yaml"
    config_file.write_text(yaml.dump(document))

    with pytest.raises(OmnigentError, match=match):
        parse(tmp_path)


def test_parse_rejects_boolean_terminal_scrollback(tmp_path: Path) -> None:
    """A boolean ``scrollback`` on a terminal entry must be rejected."""
    terminal_entry = {"command": "bash", "scrollback": False}
    document = {"spec_version": 1, "terminals": {"main": terminal_entry}}

    config_file = tmp_path / "config.yaml"
    config_file.write_text(yaml.dump(document))

    # The error message is expected to name the offending terminal.
    expected_message = r"terminals\.main\.scrollback must be an integer"
    with pytest.raises(OmnigentError, match=expected_message):
        parse(tmp_path)


def test_parse_executor_config_field(tmp_path: Path) -> None:
"""Executor block with a ``config`` sub-block parses string values.

Expand Down Expand Up @@ -2127,6 +2211,22 @@ def test_parse_mcp_server_with_timeout_and_retry(
assert mcp.retry.max_retries == 7


def test_parse_rejects_boolean_mcp_timeout(agent_dir: Path) -> None:
    """A boolean ``timeout`` in an HTTP MCP server file must be rejected."""
    server_dir = agent_dir / "tools" / "mcp"
    server_dir.mkdir(parents=True)

    serialized = yaml.dump(
        {
            "name": "slow-service",
            "transport": "http",
            "url": "http://localhost:9000/mcp",
            "timeout": True,
        }
    )
    server_file = server_dir / "slow.yaml"
    server_file.write_text(serialized)

    # The error message is expected to identify the server by its name.
    expected_message = r"MCP server 'slow-service'\.timeout"
    with pytest.raises(OmnigentError, match=expected_message):
        parse(agent_dir)


def test_parse_mcp_stdio_minimal(agent_dir: Path) -> None:
"""
Parse a stdio MCP server with only the required ``command``.
Expand Down
Loading