diff --git a/agent-sdk-client/config.py b/agent-sdk-client/config.py index 791b468..8dfdca7 100644 --- a/agent-sdk-client/config.py +++ b/agent-sdk-client/config.py @@ -1,6 +1,61 @@ """Configuration for sdk-client Lambda.""" +import logging import os +import tomllib from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +logger = logging.getLogger(__name__) +DEFAULT_CONFIG_PATH = Path(__file__).with_name("config.toml") + + +def extract_command(text: Optional[str]) -> Optional[str]: + """Extract command (with leading slash) from text, ignoring arguments/bot names.""" + if not text: + return None + + trimmed = text.strip() + if not trimmed.startswith('/'): + return None + + command = trimmed.split()[0] + if '@' in command: + command = command.split('@', 1)[0] + command = command.strip() + if not command or command == '/': + return None + return command + + +def _load_config(config_path: Path = DEFAULT_CONFIG_PATH) -> tuple[list[str], dict[str, str]]: + """Load agent/local commands from TOML config file.""" + if not config_path.exists(): + return [], {} + + try: + with config_path.open('rb') as f: + data = tomllib.load(f) + agent_commands = data.get('agent_commands', {}).get('commands', []) + if not isinstance(agent_commands, list): + logger.warning("Agent commands config is not a list; ignoring configuration") + agent_commands = [] + agent_commands = [cmd for cmd in agent_commands if isinstance(cmd, str)] + + local_commands_raw = data.get('local_commands', {}) + if not isinstance(local_commands_raw, dict): + logger.warning("Local commands config is not a table; ignoring configuration") + local_commands_raw = {} + local_commands = { + f"/{name.lstrip('/')}" if not name.startswith('/') else name: str(value) + for name, value in local_commands_raw.items() + if isinstance(name, str) and isinstance(value, str) + } + + return agent_commands, local_commands + except (OSError, tomllib.TOMLDecodeError) as exc: # pragma: no cover - defensive logging + logger.warning("Failed to load command configuration: %s", exc) + return [], {} @dataclass @@ -11,13 +66,40 @@ class Config: agent_server_url: str auth_token: str queue_url: str + agent_commands: list[str] + local_commands: dict[str, str] @classmethod - def from_env(cls) -> 'Config': + def from_env(cls, config_path: Optional[Path] = None) -> 'Config': """Load configuration from environment variables.""" + agent_cmds, local_cmds = _load_config(config_path or DEFAULT_CONFIG_PATH) return cls( telegram_token=os.getenv('TELEGRAM_BOT_TOKEN', ''), agent_server_url=os.getenv('AGENT_SERVER_URL', ''), auth_token=os.getenv('SDK_CLIENT_AUTH_TOKEN', 'default-token'), queue_url=os.getenv('QUEUE_URL', ''), + agent_commands=agent_cmds, + local_commands=local_cmds, ) + + def get_command(self, text: Optional[str]) -> Optional[str]: + return extract_command(text) + + def is_agent_command(self, cmd: Optional[str]) -> bool: + return bool(cmd) and cmd in self.agent_commands + + def is_local_command(self, cmd: Optional[str]) -> bool: + return bool(cmd) and cmd in self.local_commands + + def local_response(self, cmd: str) -> str: + return self.local_commands.get(cmd, "Unsupported command.") + + def unknown_command_message(self) -> str: + parts = [] + if self.agent_commands: + parts.append("Agent commands:\n" + "\n".join(self.agent_commands)) + if self.local_commands: + parts.append("Local commands:\n" + "\n".join(self.local_commands.keys())) + if not parts: + return "Unsupported command." + return "Unsupported command.\n\n" + "\n\n".join(parts) diff --git a/agent-sdk-client/config.toml b/agent-sdk-client/config.toml new file mode 100644 index 0000000..4186b4e --- /dev/null +++ b/agent-sdk-client/config.toml @@ -0,0 +1,10 @@ +[agent_commands] +# Commands forwarded to the Agent backend +commands = [ + "/custom-skill", + "/hello-world", +] + +[local_commands] +# Local-only commands handled by the client +help = "Hello World" diff --git a/agent-sdk-client/consumer.py b/agent-sdk-client/consumer.py index 673bb6c..52faf6a 100644 --- a/agent-sdk-client/consumer.py +++ b/agent-sdk-client/consumer.py @@ -55,6 +55,44 @@ async def process_message(message_data: dict) -> None: logger.warning("Received update with no message or edited_message") return + cmd = config.get_command(message.text) + if cmd: + if config.is_local_command(cmd): + logger.info( + "Handling local command in consumer (fallback path)", + extra={'chat_id': message.chat_id, 'message_id': message.message_id}, + ) + try: + await bot.send_message( + chat_id=message.chat_id, + text=config.local_response(cmd), + message_thread_id=message.message_thread_id, + reply_to_message_id=message.message_id, + ) + except Exception: + logger.warning("Failed to send local command response", exc_info=True) + return + + if not config.is_agent_command(cmd): + # Defensive guard: producer should already block non-agent commands. + logger.info( + "Skipping non-agent command (consumer fallback)", + extra={ + 'chat_id': message.chat_id, + 'message_id': message.message_id, + }, + ) + try: + await bot.send_message( + chat_id=message.chat_id, + text=config.unknown_command_message(), + message_thread_id=message.message_thread_id, + reply_to_message_id=message.message_id, + ) + except Exception: + logger.warning("Failed to send local command response", exc_info=True) + return + # Send typing indicator await bot.send_chat_action( chat_id=message.chat_id, diff --git a/agent-sdk-client/handler.py b/agent-sdk-client/handler.py index f681358..ae3cd9f 100644 --- a/agent-sdk-client/handler.py +++ b/agent-sdk-client/handler.py @@ -120,6 +120,33 @@ def _send_to_sqs_safe(sqs, queue_url: str, message_body: dict) -> bool: return False +def _handle_local_command(bot: Bot, message, config: Config, cmd: str) -> bool: + """Handle local commands or unknown commands.""" + if config.is_local_command(cmd): + text = config.local_response(cmd) + else: + text = config.unknown_command_message() + + try: + bot.send_message( + chat_id=message.chat_id, + text=text, + message_thread_id=message.message_thread_id, + reply_to_message_id=message.message_id, + ) + except Exception: + logger.warning("Failed to send local command response", exc_info=True) + + logger.info( + 'Handled non-whitelisted command locally', + extra={ + 'chat_id': message.chat_id, + 'message_id': message.message_id, + }, + ) + return True + + def lambda_handler(event: dict, context: Any) -> dict: """Lambda entry point - Producer. @@ -147,6 +174,15 @@ def lambda_handler(event: dict, context: Any) -> dict: logger.debug('Ignoring webhook without text message') return {'statusCode': 200} + cmd = config.get_command(message.text) + if cmd and config.is_local_command(cmd): + _handle_local_command(bot, message, config, cmd) + return {'statusCode': 200} + + if cmd and not config.is_agent_command(cmd): + _handle_local_command(bot, message, config, cmd) + return {'statusCode': 200} + # Write to SQS for async processing sqs = _get_sqs_client() message_body = { diff --git a/tests/test_command_config.py b/tests/test_command_config.py new file mode 100644 index 0000000..e268340 --- /dev/null +++ b/tests/test_command_config.py @@ -0,0 +1,117 @@ +import importlib.util +from pathlib import Path + +import pytest + +CLIENT_CONFIG_PATH = Path(__file__).resolve().parent.parent / "agent-sdk-client" / "config.py" +spec = importlib.util.spec_from_file_location("agent_sdk_client_config", CLIENT_CONFIG_PATH) +config_module = importlib.util.module_from_spec(spec) +assert spec.loader is not None +spec.loader.exec_module(config_module) +Config = config_module.Config +extract_command = config_module.extract_command + + +def load_config_from_text(text: str, tmp_path: Path) -> Config: + config_path = tmp_path / "config.toml" + config_path.write_text(text) + return Config.from_env(config_path=config_path) + + +def test_load_agent_and_local_commands(tmp_path): + cfg = load_config_from_text( + """[agent_commands] +commands = ["/a", "/b"] + +[local_commands] +help = "Hello" +""", + tmp_path, + ) + assert cfg.agent_commands == ["/a", "/b"] + assert cfg.local_commands == {"/help": "Hello"} + + +@pytest.mark.parametrize( + "text,cmd", + [ + ("hello world", None), + ("/allowed", "/allowed"), + ("/allowed extra args", "/allowed"), + ("/allowed@bot", "/allowed"), + ("/@bot", None), + ("/", None), + (None, None), + ], +) +def test_extract_command(text, cmd): + assert extract_command(text) == cmd + + +def test_command_classification(tmp_path): + cfg = load_config_from_text( + """[agent_commands] +commands = ["/agent"] + +[local_commands] +help = "Hello World" +""", + tmp_path, + ) + assert cfg.is_agent_command("/agent") + assert not cfg.is_agent_command("/other") + assert cfg.is_local_command("/help") + assert not cfg.is_local_command("/agent") + + +def test_unknown_command_message_lists_known(): + cfg = Config( + telegram_token="", + agent_server_url="", + auth_token="", + queue_url="", + agent_commands=["/agent1"], + local_commands={"/help": "hi"}, + ) + msg = cfg.unknown_command_message() + assert "Agent commands" in msg and "/agent1" in msg + assert "Local commands" in msg and "/help" in msg + + +def test_invalid_agent_commands_type(tmp_path, caplog): + with caplog.at_level("WARNING"): + cfg = load_config_from_text( + """[agent_commands] +commands = "not-a-list" +""", + tmp_path, + ) + assert cfg.agent_commands == [] + assert any("Agent commands config is not a list" in rec.message for rec in caplog.records) + + +def test_invalid_local_commands_type(tmp_path, caplog): + cfg = load_config_from_text( + """[local_commands] +value = 1 +""", + tmp_path, + ) + assert cfg.local_commands == {} + + +def test_missing_config_file(tmp_path): + missing = tmp_path / "missing.toml" + cfg = Config.from_env(config_path=missing) + assert cfg.agent_commands == [] + assert cfg.local_commands == {} + + +def test_malformed_toml_returns_empty(tmp_path, caplog): + path = tmp_path / "bad.toml" + path.write_text("not = [ [") + with caplog.at_level("WARNING"): + cfg = Config.from_env(config_path=path) + assert cfg.agent_commands == [] + assert cfg.local_commands == {} + assert any("Failed to load command configuration" in rec.message for rec in caplog.records)