Skip to content
84 changes: 83 additions & 1 deletion agent-sdk-client/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,61 @@
"""Configuration for sdk-client Lambda."""
import logging
import os
import tomllib
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

logger = logging.getLogger(__name__)
DEFAULT_CONFIG_PATH = Path(__file__).with_name("config.toml")


def extract_command(text: Optional[str]) -> Optional[str]:
"""Extract command (with leading slash) from text, ignoring arguments/bot names."""
if not text:
return None

trimmed = text.strip()
if not trimmed.startswith('/'):
return None

command = trimmed.split()[0]
if '@' in command:
command = command.split('@', 1)[0]
command = command.strip()
if not command or command == '/':
return None
return command


def _load_config(config_path: Path = DEFAULT_CONFIG_PATH) -> tuple[list[str], dict[str, str]]:
"""Load agent/local commands from TOML config file."""
if not config_path.exists():
return [], {}

try:
with config_path.open('rb') as f:
data = tomllib.load(f)
agent_commands = data.get('agent_commands', {}).get('commands', [])
if not isinstance(agent_commands, list):
logger.warning("Agent commands config is not a list; ignoring configuration")
agent_commands = []
agent_commands = [cmd for cmd in agent_commands if isinstance(cmd, str)]

local_commands_raw = data.get('local_commands', {})
if not isinstance(local_commands_raw, dict):
logger.warning("Local commands config is not a table; ignoring configuration")
local_commands_raw = {}
local_commands = {
f"/{name.lstrip('/')}" if not name.startswith('/') else name: str(value)
for name, value in local_commands_raw.items()
if isinstance(name, str) and isinstance(value, str)
}

return agent_commands, local_commands
except (OSError, tomllib.TOMLDecodeError) as exc: # pragma: no cover - defensive logging
logger.warning("Failed to load command configuration: %s", exc)
return [], {}


@dataclass
Expand All @@ -11,13 +66,40 @@ class Config:
agent_server_url: str
auth_token: str
queue_url: str
agent_commands: list[str]
local_commands: dict[str, str]

@classmethod
def from_env(cls) -> 'Config':
def from_env(cls, config_path: Optional[Path] = None) -> 'Config':
"""Load configuration from environment variables."""
agent_cmds, local_cmds = _load_config(config_path or DEFAULT_CONFIG_PATH)
return cls(
telegram_token=os.getenv('TELEGRAM_BOT_TOKEN', ''),
agent_server_url=os.getenv('AGENT_SERVER_URL', ''),
auth_token=os.getenv('SDK_CLIENT_AUTH_TOKEN', 'default-token'),
queue_url=os.getenv('QUEUE_URL', ''),
agent_commands=agent_cmds,
local_commands=local_cmds,
)

def get_command(self, text: Optional[str]) -> Optional[str]:
return extract_command(text)

def is_agent_command(self, cmd: Optional[str]) -> bool:
return bool(cmd) and cmd in self.agent_commands

def is_local_command(self, cmd: Optional[str]) -> bool:
return bool(cmd) and cmd in self.local_commands

def local_response(self, cmd: str) -> str:
return self.local_commands.get(cmd, "Unsupported command.")

def unknown_command_message(self) -> str:
parts = []
if self.agent_commands:
parts.append("Agent commands:\n" + "\n".join(self.agent_commands))
if self.local_commands:
parts.append("Local commands:\n" + "\n".join(self.local_commands.keys()))
if not parts:
return "Unsupported command."
return "Unsupported command.\n\n" + "\n\n".join(parts)
10 changes: 10 additions & 0 deletions agent-sdk-client/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[agent_commands]
# Commands forwarded to the Agent backend
commands = [
"/custom-skill",
"/hello-world",
]

[local_commands]
# Local-only commands handled by the client
help = "Hello World"
38 changes: 38 additions & 0 deletions agent-sdk-client/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,44 @@ async def process_message(message_data: dict) -> None:
logger.warning("Received update with no message or edited_message")
return

cmd = config.get_command(message.text)
if cmd:
if config.is_local_command(cmd):
logger.info(
"Handling local command in consumer (fallback path)",
extra={'chat_id': message.chat_id, 'message_id': message.message_id},
)
try:
await bot.send_message(
chat_id=message.chat_id,
text=config.local_response(cmd),
message_thread_id=message.message_thread_id,
reply_to_message_id=message.message_id,
)
except Exception:
logger.warning("Failed to send local command response", exc_info=True)
return

if not config.is_agent_command(cmd):
# Defensive guard: producer should already block non-agent commands.
logger.info(
"Skipping non-agent command (consumer fallback)",
extra={
'chat_id': message.chat_id,
'message_id': message.message_id,
},
)
try:
await bot.send_message(
chat_id=message.chat_id,
text=config.unknown_command_message(),
message_thread_id=message.message_thread_id,
reply_to_message_id=message.message_id,
)
except Exception:
logger.warning("Failed to send local command response", exc_info=True)
return

# Send typing indicator
await bot.send_chat_action(
chat_id=message.chat_id,
Expand Down
36 changes: 36 additions & 0 deletions agent-sdk-client/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,33 @@ def _send_to_sqs_safe(sqs, queue_url: str, message_body: dict) -> bool:
return False


def _handle_local_command(bot: Bot, message, config: Config, cmd: str) -> bool:
"""Handle local commands or unknown commands."""
if config.is_local_command(cmd):
text = config.local_response(cmd)
else:
text = config.unknown_command_message()

try:
bot.send_message(
chat_id=message.chat_id,
text=text,
message_thread_id=message.message_thread_id,
reply_to_message_id=message.message_id,
)
except Exception:
logger.warning("Failed to send local command response", exc_info=True)

logger.info(
'Handled non-whitelisted command locally',
extra={
'chat_id': message.chat_id,
'message_id': message.message_id,
},
)
return True


def lambda_handler(event: dict, context: Any) -> dict:
"""Lambda entry point - Producer.

Expand Down Expand Up @@ -147,6 +174,15 @@ def lambda_handler(event: dict, context: Any) -> dict:
logger.debug('Ignoring webhook without text message')
return {'statusCode': 200}

cmd = config.get_command(message.text)
if cmd and config.is_local_command(cmd):
_handle_local_command(bot, message, config, cmd)
return {'statusCode': 200}

if cmd and not config.is_agent_command(cmd):
_handle_local_command(bot, message, config, cmd)
return {'statusCode': 200}

# Write to SQS for async processing
sqs = _get_sqs_client()
message_body = {
Expand Down
117 changes: 117 additions & 0 deletions tests/test_command_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import importlib.util
from pathlib import Path

import pytest

CLIENT_CONFIG_PATH = Path(__file__).resolve().parent.parent / "agent-sdk-client" / "config.py"
spec = importlib.util.spec_from_file_location("agent_sdk_client_config", CLIENT_CONFIG_PATH)
config_module = importlib.util.module_from_spec(spec)
assert spec.loader is not None
spec.loader.exec_module(config_module)
Config = config_module.Config
extract_command = config_module.extract_command


def load_config_from_text(text: str, tmp_path: Path) -> Config:
config_path = tmp_path / "config.toml"
config_path.write_text(text)
return Config.from_env(config_path=config_path)


def test_load_agent_and_local_commands(tmp_path):
cfg = load_config_from_text(
"""[agent_commands]
commands = ["/a", "/b"]

[local_commands]
help = "Hello"
""",
tmp_path,
)
assert cfg.agent_commands == ["/a", "/b"]
assert cfg.local_commands == {"/help": "Hello"}


@pytest.mark.parametrize(
"text,cmd",
[
("hello world", None),
("/allowed", "/allowed"),
("/allowed extra args", "/allowed"),
("/allowed@bot", "/allowed"),
("/@bot", None),
("/", None),
(None, None),
],
)
def test_extract_command(text, cmd):
assert extract_command(text) == cmd


def test_command_classification(tmp_path):
cfg = load_config_from_text(
"""[agent_commands]
commands = ["/agent"]

[local_commands]
help = "Hello World"
""",
tmp_path,
)
assert cfg.is_agent_command("/agent")
assert not cfg.is_agent_command("/other")
assert cfg.is_local_command("/help")
assert not cfg.is_local_command("/agent")


def test_unknown_command_message_lists_known():
cfg = Config(
telegram_token="",
agent_server_url="",
auth_token="",
queue_url="",
agent_commands=["/agent1"],
local_commands={"/help": "hi"},
)
msg = cfg.unknown_command_message()
assert "Agent commands" in msg and "/agent1" in msg
assert "Local commands" in msg and "/help" in msg


def test_invalid_agent_commands_type(tmp_path, caplog):
with caplog.at_level("WARNING"):
cfg = load_config_from_text(
"""[agent_commands]
commands = "not-a-list"
""",
tmp_path,
)
assert cfg.agent_commands == []
assert any("Agent commands config is not a list" in rec.message for rec in caplog.records)


def test_invalid_local_commands_type(tmp_path, caplog):
cfg = load_config_from_text(
"""[local_commands]
value = 1
""",
tmp_path,
)
assert cfg.local_commands == {}


def test_missing_config_file(tmp_path):
missing = tmp_path / "missing.toml"
cfg = Config.from_env(config_path=missing)
assert cfg.agent_commands == []
assert cfg.local_commands == {}


def test_malformed_toml_returns_empty(tmp_path, caplog):
path = tmp_path / "bad.toml"
path.write_text("not = [ [")
with caplog.at_level("WARNING"):
cfg = Config.from_env(config_path=path)
assert cfg.agent_commands == []
assert cfg.local_commands == {}
assert any("Failed to load command configuration" in rec.message for rec in caplog.records)