diff --git a/bin/ask b/bin/ask index f3167bff..5572cfb1 100755 --- a/bin/ask +++ b/bin/ask @@ -45,9 +45,16 @@ sys.path.insert(0, str(lib_dir)) from compat import read_stdin_text, setup_windows_encoding setup_windows_encoding() +from aliases import load_aliases, resolve_alias from cli_output import EXIT_ERROR, EXIT_OK from providers import parse_qualified_provider from session_utils import find_project_session_file +from agent_comm import ( + AgentMessage, broadcast_message, build_chain_messages, + parse_chain_spec, resolve_agent_to_provider, wrap_message, +) +from task_router import auto_route +from team_config import load_team_config, resolve_team_agent # Provider to daemon command mapping @@ -481,11 +488,22 @@ def make_task_id() -> str: def _usage() -> None: + aliases = load_aliases() + alias_list = ", ".join(f"{k}→{v}" for k, v in sorted(aliases.items())) print("Usage: ask [options] ", file=sys.stderr) print("", file=sys.stderr) print("Providers:", file=sys.stderr) print(" gemini, codex, opencode, droid, claude, copilot, codebuddy, qwen", file=sys.stderr) print("", file=sys.stderr) + print("Aliases:", file=sys.stderr) + print(f" {alias_list}", file=sys.stderr) + print("", file=sys.stderr) + team = load_team_config(Path.cwd()) + if team: + agents = ", ".join(f"{a.name}→{a.provider}" for a in team.agents) + print(f"Team '{team.name}' ({team.strategy}):", file=sys.stderr) + print(f" {agents}", file=sys.stderr) + print("", file=sys.stderr) print("Options:", file=sys.stderr) print(" -h, --help Show this help message", file=sys.stderr) print(" -t, --timeout SECONDS Request timeout (default: 3600)", file=sys.stderr) @@ -493,6 +511,10 @@ def _usage() -> None: print(" --foreground Run in foreground (no nohup/background)", file=sys.stderr) print(" --background Force background mode", file=sys.stderr) print(" --no-wrap Don't wrap with CCB protocol markers", file=sys.stderr) + print(" --auto Auto-select provider based on message content", file=sys.stderr) + print(" --to Send message to another agent (inter-agent)", file=sys.stderr) + print(" --broadcast Send message to all team agents", file=sys.stderr) + print(" --chain Run agent chain: \"a:task1 | b:task2 | c:task3\"", file=sys.stderr) def main(argv: list[str]) -> int: @@ -500,28 +522,24 @@ def main(argv: list[str]) -> int: _usage() return EXIT_ERROR - # First argument must be the provider + # First argument must be the provider (or --auto / --help) raw_provider = argv[1].lower() if raw_provider in ("-h", "--help"): _usage() return EXIT_OK - base_provider, instance = parse_qualified_provider(raw_provider) - - if base_provider not in PROVIDER_DAEMONS: - print(f"[ERROR] Unknown provider: {base_provider}", file=sys.stderr) - print(f"[ERROR] Available: {', '.join(PROVIDER_DAEMONS.keys())}", file=sys.stderr) - return EXIT_ERROR - - daemon_cmd = PROVIDER_DAEMONS[base_provider] - provider = raw_provider # keep full qualified key for daemon routing + auto_mode = raw_provider == "--auto" + cwd = Path.cwd() - # Parse remaining arguments + # Parse remaining arguments (shared by both auto and normal modes) timeout: float = 3600.0 notify_mode = False no_wrap = False foreground_mode = _default_foreground() + to_agent: str = "" + broadcast_mode = False + chain_spec: str = "" parts: list[str] = [] it = iter(argv[2:]) @@ -551,6 +569,23 @@ def main(argv: list[str]) -> int: if token == "--no-wrap": no_wrap = True continue + if token == "--to": + try: + to_agent = next(it).strip().lower() + except StopIteration: + print("[ERROR] --to requires an agent name", file=sys.stderr) + return EXIT_ERROR + continue + if token == "--broadcast": + broadcast_mode = True + continue + if token == "--chain": + try: + chain_spec = next(it).strip() + except StopIteration: + print("[ERROR] --chain requires a spec string", file=sys.stderr) + return EXIT_ERROR + continue parts.append(token) message = " ".join(parts).strip() @@ -560,6 +595,115 @@ def main(argv: list[str]) -> int: print("[ERROR] Message cannot be empty", file=sys.stderr) return EXIT_ERROR + team = load_team_config(cwd) + aliases = load_aliases(cwd) + + # --chain mode: sequential multi-agent pipeline + if chain_spec: + steps = parse_chain_spec(chain_spec) + if not steps: + print("[ERROR] Invalid chain spec", file=sys.stderr) + return EXIT_ERROR + chain_msgs = build_chain_messages(steps) + ask_cmd = str(Path(__file__).resolve()) + prev_output = "" + for i, msg in enumerate(chain_msgs): + # Resolve agent name to provider + target = resolve_agent_to_provider(msg.receiver, team, aliases) + if not target: + print(f"[ERROR] Unknown agent in chain: {msg.receiver}", file=sys.stderr) + return EXIT_ERROR + # Inject previous output as context + if prev_output: + msg = AgentMessage(sender=msg.sender, receiver=target, content=msg.content, context=prev_output) + else: + msg = AgentMessage(sender=msg.sender, receiver=target, content=msg.content) + wrapped = wrap_message(msg) + print(f"[CHAIN {i+1}/{len(chain_msgs)}] {msg.sender} → {target}", file=sys.stderr) + try: + result = subprocess.run( + [sys.executable, ask_cmd, target, "--foreground", "--no-wrap"], + input=wrapped, capture_output=True, text=True, + timeout=timeout, + ) + prev_output = result.stdout.strip() + if prev_output: + print(prev_output) + except subprocess.TimeoutExpired: + print(f"[ERROR] Chain step {i+1} timed out", file=sys.stderr) + return EXIT_ERROR + except Exception as e: + print(f"[ERROR] Chain step {i+1}: {e}", file=sys.stderr) + return EXIT_ERROR + return EXIT_OK + + # --broadcast mode: send to all team agents + if broadcast_mode: + if not team: + print("[ERROR] --broadcast requires a team config (.ccb/team.json)", file=sys.stderr) + return EXIT_ERROR + sender = raw_provider if not auto_mode else "auto" + msgs = broadcast_message(sender, message, team) + if not msgs: + print("[WARN] No broadcast recipients", file=sys.stderr) + return EXIT_OK + ask_cmd = str(Path(__file__).resolve()) + for msg in msgs: + wrapped = wrap_message(msg) + print(f"[BROADCAST] → {msg.receiver}", file=sys.stderr) + proc = subprocess.Popen( + [sys.executable, ask_cmd, msg.receiver, "--foreground", "--no-wrap"], + stdin=subprocess.PIPE, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + start_new_session=True, + close_fds=True, + ) + try: + proc.stdin.write(wrapped.encode("utf-8")) + proc.stdin.close() + except Exception: + pass + print(f"[BROADCAST] Sent to {len(msgs)} agents", file=sys.stderr) + return EXIT_OK + + # --to mode: wrap message with sender metadata and redirect to target agent + if to_agent: + target_provider = resolve_agent_to_provider(to_agent, team, aliases) + if not target_provider: + print(f"[ERROR] Unknown target agent: {to_agent}", file=sys.stderr) + return EXIT_ERROR + sender = raw_provider if not auto_mode else "auto" + msg = AgentMessage(sender=sender, receiver=target_provider, content=message) + message = wrap_message(msg) + raw_provider = target_provider + print(f"[TO] {sender} → {to_agent} ({target_provider})", file=sys.stderr) + + # --auto mode: select provider based on message content + elif auto_mode: + route = auto_route(message, team) + raw_provider = route.provider + print(f"[AUTO] → {raw_provider}" + (f" ({route.reason})" if route.reason else ""), file=sys.stderr) + else: + # Resolution order: team agents > aliases > direct provider names + team_agent = resolve_team_agent(raw_provider, team) + if team_agent: + raw_provider = team_agent.provider + else: + base_part, _, instance_part = raw_provider.partition(":") + base_part = resolve_alias(base_part, aliases) + raw_provider = f"{base_part}:{instance_part}" if instance_part else base_part + + base_provider, instance = parse_qualified_provider(raw_provider) + + if base_provider not in PROVIDER_DAEMONS: + print(f"[ERROR] Unknown provider: {base_provider}", file=sys.stderr) + print(f"[ERROR] Available: {', '.join(PROVIDER_DAEMONS.keys())}", file=sys.stderr) + return EXIT_ERROR + + daemon_cmd = PROVIDER_DAEMONS[base_provider] + provider = raw_provider # keep full qualified key for daemon routing + # Notify mode: sync send, no wait for reply (used for hook notifications) if notify_mode: _require_caller() diff --git a/lib/agent_comm.py b/lib/agent_comm.py new file mode 100644 index 00000000..d556ce9e --- /dev/null +++ b/lib/agent_comm.py @@ -0,0 +1,196 @@ +"""Inter-agent communication for CCB Agent Teams. + +Supports three communication patterns: +1. Directed message: ask a --to b "请审查这段代码" +2. Task chain: sequential multi-agent pipeline +3. Broadcast: notify all team members + +Messages are wrapped with metadata so the receiving agent knows +the sender and context. Execution uses the existing ask infrastructure. +""" + +from __future__ import annotations + +import json +import os +import subprocess +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Optional, Sequence + +from aliases import load_aliases, resolve_alias +from team_config import TeamAgent, TeamConfig, load_team_config, resolve_team_agent + + +@dataclass +class AgentMessage: + """A message from one agent to another.""" + sender: str + receiver: str + content: str + context: str = "" # optional context from previous agent's output + + +def resolve_agent_to_provider( + name: str, + team: Optional[TeamConfig], + aliases: Dict[str, str], +) -> Optional[str]: + """Resolve an agent name to a provider, checking team agents then aliases. + + Returns None if the name doesn't resolve to a known provider. + """ + key = (name or "").strip().lower() + if not key: + return None + + # Team agent takes priority + agent = resolve_team_agent(key, team) + if agent: + return agent.provider + + # Try alias + resolved = resolve_alias(key, aliases) + return resolved if resolved else None + + +def wrap_message(msg: AgentMessage) -> str: + """Wrap a message with sender metadata for the receiving agent. + + The wrapped format allows the receiving agent to understand context. + """ + lines = [] + lines.append(f"[CCB_FROM agent={msg.sender}]") + if msg.context: + lines.append(f"[CCB_CONTEXT]\n{msg.context}\n[/CCB_CONTEXT]") + lines.append(msg.content) + return "\n".join(lines) + + +def build_chain_script( + steps: List[AgentMessage], + ask_cmd: str, + timeout: float = 3600.0, + foreground: bool = True, +) -> str: + """Build a shell script that executes a chain of agent tasks sequentially. + + Each step feeds the previous output as context to the next agent. + Returns the shell script content. + """ + import shlex + + lines = ["#!/bin/sh", "set -e", ""] + lines.append("# CCB Agent Chain") + lines.append(f"# Steps: {len(steps)}") + lines.append("") + + for i, step in enumerate(steps): + step_var = f"STEP{i}_OUTPUT" + provider = step.receiver + content = step.content + + # First step: no context from previous + if i == 0: + wrapped = wrap_message(step) + else: + # Inject previous output as context + step_with_ctx = AgentMessage( + sender=step.sender, + receiver=step.receiver, + content=step.content, + context=f"$STEP{i-1}_OUTPUT", + ) + # For shell script, we build inline + wrapped = f"[CCB_FROM agent={step.sender}]\\n[CCB_CONTEXT]\\n$STEP{i-1}_OUTPUT\\n[/CCB_CONTEXT]\\n{content}" + + fg_flag = "--foreground" if foreground else "" + timeout_flag = f"--timeout {timeout}" if timeout else "" + + lines.append(f"echo '[CHAIN] Step {i+1}/{len(steps)}: {step.sender} → {provider}'") + if i == 0: + msg_escaped = shlex.quote(wrapped) + lines.append(f"{step_var}=$(python3 {shlex.quote(ask_cmd)} {shlex.quote(provider)} {fg_flag} {timeout_flag} {msg_escaped})") + else: + lines.append(f'WRAPPED="[CCB_FROM agent={step.sender}]') + lines.append(f"[CCB_CONTEXT]") + lines.append(f"$STEP{i-1}_OUTPUT") + lines.append(f"[/CCB_CONTEXT]") + lines.append(f'{content}"') + lines.append(f'{step_var}=$(echo "$WRAPPED" | python3 {shlex.quote(ask_cmd)} {shlex.quote(provider)} {fg_flag} {timeout_flag})') + lines.append(f'echo "${{step_var}}"') + lines.append("") + + # Final output is last step's output + if steps: + lines.append(f'echo "$STEP{len(steps)-1}_OUTPUT"') + + return "\n".join(lines) + + +def broadcast_message( + sender: str, + content: str, + team: TeamConfig, + exclude_sender: bool = True, +) -> List[AgentMessage]: + """Create messages to broadcast to all team agents. + + Returns a list of AgentMessages, one per recipient. + """ + messages = [] + sender_lower = sender.lower() + for agent in team.agents: + if exclude_sender and agent.name.lower() == sender_lower: + continue + messages.append(AgentMessage( + sender=sender, + receiver=agent.provider, + content=content, + )) + return messages + + +def parse_chain_spec(spec: str) -> List[tuple[str, str]]: + """Parse a chain specification string into (agent, task) pairs. + + Format: "agent1:task1 | agent2:task2 | agent3:task3" + Each segment is "agent:task" separated by " | ". + + Returns list of (agent_name, task_description) tuples. + """ + steps = [] + for segment in spec.split("|"): + segment = segment.strip() + if not segment: + continue + if ":" in segment: + agent, _, task = segment.partition(":") + agent = agent.strip() + task = task.strip() + if agent and task: + steps.append((agent, task)) + else: + # No colon: treat whole segment as task with empty agent + steps.append(("", segment.strip())) + return steps + + +def build_chain_messages( + chain: List[tuple[str, str]], +) -> List[AgentMessage]: + """Convert a parsed chain spec into AgentMessages. + + Each step's sender is the previous step's receiver (or 'user' for first). + """ + messages = [] + prev_agent = "user" + for agent_name, task in chain: + messages.append(AgentMessage( + sender=prev_agent, + receiver=agent_name, + content=task, + )) + prev_agent = agent_name + return messages diff --git a/lib/aliases.py b/lib/aliases.py new file mode 100644 index 00000000..4dc9627f --- /dev/null +++ b/lib/aliases.py @@ -0,0 +1,62 @@ +"""Agent name aliases for CCB. + +Resolves short aliases (a, b, c, ...) to provider names. + +Configuration layers (higher overrides lower): +1. Hardcoded defaults (DEFAULT_ALIASES) +2. ~/.ccb/aliases.json (global) +3. .ccb/aliases.json (project-level, relative to work_dir) +""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path +from typing import Dict, Optional + +DEFAULT_ALIASES: Dict[str, str] = { + "a": "codex", + "b": "gemini", + "c": "claude", + "d": "opencode", + "e": "droid", + "f": "copilot", + "g": "codebuddy", + "h": "qwen", +} + + +def _load_json(path: Path) -> Dict[str, str]: + """Load aliases from a JSON file, returning {} on any error.""" + try: + if not path.is_file(): + return {} + data = json.loads(path.read_text(encoding="utf-8")) + if not isinstance(data, dict): + return {} + # Only keep str->str entries + return {str(k): str(v) for k, v in data.items()} + except (json.JSONDecodeError, OSError, ValueError): + print(f"[WARN] Failed to parse alias config: {path}", file=sys.stderr) + return {} + + +def load_aliases(work_dir: Optional[Path] = None) -> Dict[str, str]: + """Merge alias configs: defaults < ~/.ccb/aliases.json < .ccb/aliases.json.""" + merged = dict(DEFAULT_ALIASES) + + global_path = Path.home() / ".ccb" / "aliases.json" + merged.update(_load_json(global_path)) + + if work_dir is not None: + project_path = work_dir / ".ccb" / "aliases.json" + merged.update(_load_json(project_path)) + + return merged + + +def resolve_alias(name: str, aliases: Dict[str, str]) -> str: + """Resolve an alias to a provider name. Non-aliases pass through unchanged.""" + key = (name or "").strip().lower() + return aliases.get(key, key) diff --git a/lib/task_router.py b/lib/task_router.py new file mode 100644 index 00000000..01c3ca27 --- /dev/null +++ b/lib/task_router.py @@ -0,0 +1,210 @@ +"""Smart task routing for CCB Agent Teams. + +Routes tasks to the best provider based on message content analysis. +Supports keyword matching (Chinese + English) and team skill-based matching. + +Used by `ask --auto "message"` to auto-select the best provider. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Sequence + +from team_config import TeamAgent, TeamConfig + + +@dataclass +class RouteResult: + """Result of routing a task to a provider.""" + provider: str + model: str = "" + reason: str = "" + score: float = 0.0 + + +# --------------------------------------------------------------------------- +# Keyword → provider routing rules +# --------------------------------------------------------------------------- + +@dataclass +class RoutingRule: + """A keyword-based routing rule.""" + provider: str + model: str + keywords: List[str] + weight: float = 1.0 + + +# Default routing rules (reference: HiveMind ProviderRouter + CLAUDE.md mapping) +DEFAULT_ROUTING_RULES: List[RoutingRule] = [ + RoutingRule( + provider="gemini", model="3f", + keywords=["frontend", "前端", "react", "vue", "css", "html", "ui", "design", "设计", "样式", "组件", "tailwind", "nextjs"], + weight=1.5, + ), + RoutingRule( + provider="codex", model="o3", + keywords=["algorithm", "算法", "math", "数学", "proof", "证明", "reasoning", "推理", "逻辑", "logic", "complexity", "复杂度"], + weight=1.5, + ), + RoutingRule( + provider="codex", model="o3", + keywords=["review", "审查", "审核", "audit", "security", "安全", "code review", "代码审查"], + weight=1.5, + ), + RoutingRule( + provider="qwen", model="", + keywords=["python", "编程", "代码生成", "code", "coding", "implement", "实现", "sql", "database", "数据库", "数据分析"], + weight=1.0, + ), + RoutingRule( + provider="kimi", model="thinking", + keywords=["中文", "chinese", "翻译", "translate", "translation", "文案", "写作", "writing", "长文", "文档", "document", "总结", "summary", "分析"], + weight=1.0, + ), + RoutingRule( + provider="kimi", model="", + keywords=["explain", "解释", "概念", "concept", "快速", "quick", "shell", "bash", "运维", "devops"], + weight=0.8, + ), + RoutingRule( + provider="claude", model="", + keywords=["architecture", "架构", "设计模式", "design pattern", "重构", "refactor", "planning", "规划"], + weight=1.0, + ), +] + +# Default fallback when no keywords match +DEFAULT_FALLBACK = RouteResult(provider="kimi", model="", reason="default fallback", score=0.0) + + +def _score_message(message: str, keywords: List[str]) -> float: + """Score a message against a list of keywords. Returns number of matches.""" + text = message.lower() + score = 0.0 + for kw in keywords: + if kw.lower() in text: + score += 1.0 + return score + + +def route_by_keywords( + message: str, + rules: Optional[List[RoutingRule]] = None, + fallback: Optional[RouteResult] = None, +) -> RouteResult: + """Route a message to a provider based on keyword matching. + + Returns the RouteResult with the highest weighted score. + """ + if rules is None: + rules = DEFAULT_ROUTING_RULES + if fallback is None: + fallback = DEFAULT_FALLBACK + + if not message or not message.strip(): + return fallback + + best: Optional[RouteResult] = None + best_score = 0.0 + + for rule in rules: + raw_score = _score_message(message, rule.keywords) + if raw_score <= 0: + continue + weighted = raw_score * rule.weight + if weighted > best_score: + best_score = weighted + matched = [kw for kw in rule.keywords if kw.lower() in message.lower()] + best = RouteResult( + provider=rule.provider, + model=rule.model, + reason=f"keywords: {', '.join(matched[:3])}", + score=weighted, + ) + + return best if best else fallback + + +# --------------------------------------------------------------------------- +# Team skill-based routing +# --------------------------------------------------------------------------- + +def _score_agent_skills(agent: TeamAgent, message: str) -> float: + """Score a team agent against a message based on skills + role keywords.""" + if not agent.skills and not agent.role: + return 0.0 + text = message.lower() + score = 0.0 + for skill in agent.skills: + if skill in text: + score += 1.5 # skills are more specific, higher weight + if agent.role and agent.role in text: + score += 1.0 + return score + + +def route_by_team( + message: str, + team: TeamConfig, +) -> Optional[RouteResult]: + """Route a message to the best team agent based on skills and role matching. + + Returns None if no agent has a positive match score. + """ + if not message or not message.strip() or not team.agents: + return None + + best_agent: Optional[TeamAgent] = None + best_score = 0.0 + + for agent in team.agents: + score = _score_agent_skills(agent, message) + if score > best_score: + best_score = score + best_agent = agent + + if best_agent is None: + return None + + matched = [] + text = message.lower() + for s in best_agent.skills: + if s in text: + matched.append(s) + if best_agent.role and best_agent.role in text: + matched.append(f"role:{best_agent.role}") + + return RouteResult( + provider=best_agent.provider, + model=best_agent.model, + reason=f"team:{best_agent.name} ({', '.join(matched[:3])})", + score=best_score, + ) + + +# --------------------------------------------------------------------------- +# Unified auto-route +# --------------------------------------------------------------------------- + +def auto_route( + message: str, + team: Optional[TeamConfig] = None, +) -> RouteResult: + """Auto-route a message to the best provider. + + Resolution order: + 1. Team skill-based matching (if team config exists) + 2. Keyword-based matching (default rules) + 3. Default fallback + """ + # Try team-based routing first + if team is not None: + result = route_by_team(message, team) + if result: + return result + + # Fall back to keyword-based routing + return route_by_keywords(message) diff --git a/lib/team_config.py b/lib/team_config.py new file mode 100644 index 00000000..a884f648 --- /dev/null +++ b/lib/team_config.py @@ -0,0 +1,129 @@ +"""Team configuration for CCB Agent Teams. + +Loads team config from JSON files and resolves team agent names to providers. + +Configuration layers (higher overrides lower): +1. ~/.ccb/team.json (global) +2. .ccb/team.json (project-level) + +A team config defines named agents with provider, model, role, and skills. +Team agent names take priority over aliases when resolving provider names. +""" + +from __future__ import annotations + +import json +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Optional + + +@dataclass +class TeamAgent: + """A named agent within a team.""" + name: str + provider: str + model: str = "" + role: str = "" + skills: List[str] = field(default_factory=list) + + +@dataclass +class TeamConfig: + """Team configuration with named agents and allocation strategy.""" + name: str + agents: List[TeamAgent] = field(default_factory=list) + strategy: str = "skill_based" # round_robin | load_balance | skill_based + description: str = "" + + def agent_map(self) -> Dict[str, TeamAgent]: + """Build name → TeamAgent lookup (case-insensitive).""" + return {a.name.lower(): a for a in self.agents} + + +VALID_STRATEGIES = {"round_robin", "load_balance", "skill_based"} + + +def _parse_agent(raw: dict) -> Optional[TeamAgent]: + """Parse a single agent entry from JSON. Returns None on invalid data.""" + if not isinstance(raw, dict): + return None + name = str(raw.get("name", "")).strip() + provider = str(raw.get("provider", "")).strip().lower() + if not name or not provider: + return None + return TeamAgent( + name=name.lower(), + provider=provider, + model=str(raw.get("model", "")).strip(), + role=str(raw.get("role", "")).strip().lower(), + skills=[str(s).strip().lower() for s in raw.get("skills", []) if str(s).strip()], + ) + + +def _load_team_json(path: Path) -> Optional[TeamConfig]: + """Load a team config from a JSON file. Returns None on any error.""" + try: + if not path.is_file(): + return None + data = json.loads(path.read_text(encoding="utf-8")) + if not isinstance(data, dict): + return None + except (json.JSONDecodeError, OSError, ValueError): + print(f"[WARN] Failed to parse team config: {path}", file=sys.stderr) + return None + + name = str(data.get("name", "")).strip() + if not name: + name = "default" + + strategy = str(data.get("strategy", "skill_based")).strip().lower() + if strategy not in VALID_STRATEGIES: + strategy = "skill_based" + + agents: List[TeamAgent] = [] + for raw_agent in data.get("agents", []): + agent = _parse_agent(raw_agent) + if agent: + agents.append(agent) + + if not agents: + return None + + return TeamConfig( + name=name, + agents=agents, + strategy=strategy, + description=str(data.get("description", "")).strip(), + ) + + +def load_team_config(work_dir: Optional[Path] = None) -> Optional[TeamConfig]: + """Load team config: project .ccb/team.json overrides global ~/.ccb/team.json. + + Returns None if no valid team config is found. + """ + global_path = Path.home() / ".ccb" / "team.json" + global_config = _load_team_json(global_path) + + project_config: Optional[TeamConfig] = None + if work_dir is not None: + project_path = work_dir / ".ccb" / "team.json" + project_config = _load_team_json(project_path) + + # Project-level takes full priority (not merged) + return project_config or global_config + + +def resolve_team_agent( + name: str, + team: Optional[TeamConfig], +) -> Optional[TeamAgent]: + """Resolve a name to a TeamAgent. Returns None if not a team agent.""" + if team is None: + return None + key = (name or "").strip().lower() + if not key: + return None + return team.agent_map().get(key) diff --git a/test/test_agent_comm.py b/test/test_agent_comm.py new file mode 100644 index 00000000..553dc5ad --- /dev/null +++ b/test/test_agent_comm.py @@ -0,0 +1,307 @@ +"""Tests for lib/agent_comm.py — inter-agent communication.""" + +from __future__ import annotations + +import pytest + +from agent_comm import ( + AgentMessage, + broadcast_message, + build_chain_messages, + parse_chain_spec, + resolve_agent_to_provider, + wrap_message, +) +from team_config import TeamAgent, TeamConfig + + +# --------------------------------------------------------------------------- +# AgentMessage +# --------------------------------------------------------------------------- + +class TestAgentMessage: + def test_basic(self): + msg = AgentMessage(sender="codex", receiver="gemini", content="hello") + assert msg.sender == "codex" + assert msg.receiver == "gemini" + assert msg.content == "hello" + assert msg.context == "" + + def test_with_context(self): + msg = AgentMessage(sender="a", receiver="b", content="review", context="prev output") + assert msg.context == "prev output" + + +# --------------------------------------------------------------------------- +# wrap_message +# --------------------------------------------------------------------------- + +class TestWrapMessage: + def test_basic_wrap(self): + msg = AgentMessage(sender="codex", receiver="gemini", content="hello") + result = wrap_message(msg) + assert "[CCB_FROM agent=codex]" in result + assert "hello" in result + assert "[CCB_CONTEXT]" not in result + + def test_wrap_with_context(self): + msg = AgentMessage(sender="a", receiver="b", content="review this", context="code here") + result = wrap_message(msg) + assert "[CCB_FROM agent=a]" in result + assert "[CCB_CONTEXT]" in result + assert "code here" in result + assert "[/CCB_CONTEXT]" in result + assert "review this" in result + + def test_wrap_order(self): + msg = AgentMessage(sender="x", receiver="y", content="task", context="ctx") + result = wrap_message(msg) + lines = result.split("\n") + # First line should be FROM + assert lines[0] == "[CCB_FROM agent=x]" + # Last line should be content + assert lines[-1] == "task" + + +# --------------------------------------------------------------------------- +# resolve_agent_to_provider +# --------------------------------------------------------------------------- + +class TestResolveAgentToProvider: + @pytest.fixture() + def team(self) -> TeamConfig: + return TeamConfig( + name="test", + agents=[ + TeamAgent(name="researcher", provider="gemini"), + TeamAgent(name="coder", provider="codex"), + ], + ) + + def test_team_agent(self, team): + aliases = {"a": "codex"} + assert resolve_agent_to_provider("researcher", team, aliases) == "gemini" + + def test_alias(self, team): + aliases = {"a": "codex", "b": "gemini"} + assert resolve_agent_to_provider("a", None, aliases) == "codex" + + def test_team_over_alias(self, team): + aliases = {"researcher": "kimi"} # alias would give kimi + # Team agent should win + assert resolve_agent_to_provider("researcher", team, aliases) == "gemini" + + def test_direct_provider(self): + aliases = {} + assert resolve_agent_to_provider("kimi", None, aliases) == "kimi" + + def test_empty_name(self): + assert resolve_agent_to_provider("", None, {}) is None + + def test_none_name(self): + assert resolve_agent_to_provider(None, None, {}) is None + + def test_case_insensitive(self, team): + aliases = {} + assert resolve_agent_to_provider("Researcher", team, aliases) == "gemini" + + +# --------------------------------------------------------------------------- +# broadcast_message +# --------------------------------------------------------------------------- + +class TestBroadcastMessage: + @pytest.fixture() + def team(self) -> TeamConfig: + return TeamConfig( + name="dev", + agents=[ + TeamAgent(name="a", provider="codex"), + TeamAgent(name="b", provider="gemini"), + TeamAgent(name="c", provider="claude"), + ], + ) + + def test_broadcast_excludes_sender(self, team): + msgs = broadcast_message("a", "hello everyone", team, exclude_sender=True) + assert len(msgs) == 2 + receivers = [m.receiver for m in msgs] + assert "codex" not in receivers # sender excluded + assert "gemini" in receivers + assert "claude" in receivers + + def test_broadcast_includes_sender(self, team): + msgs = broadcast_message("a", "hello", team, exclude_sender=False) + assert len(msgs) == 3 + + def test_broadcast_content(self, team): + msgs = broadcast_message("a", "sync up", team) + for m in msgs: + assert m.content == "sync up" + assert m.sender == "a" + + def test_broadcast_empty_team(self): + team = TeamConfig(name="empty", agents=[]) + msgs = broadcast_message("a", "hello", team) + assert msgs == [] + + def test_broadcast_sender_not_in_team(self, team): + msgs = broadcast_message("external", "hello", team) + assert len(msgs) == 3 # all agents receive + + +# --------------------------------------------------------------------------- +# parse_chain_spec +# --------------------------------------------------------------------------- + +class TestParseChainSpec: + def test_basic_chain(self): + result = parse_chain_spec("a:research | b:implement | c:review") + assert result == [("a", "research"), ("b", "implement"), ("c", "review")] + + def test_single_step(self): + result = parse_chain_spec("codex:write code") + assert result == [("codex", "write code")] + + def test_empty_string(self): + assert parse_chain_spec("") == [] + + def test_whitespace_handling(self): + result = parse_chain_spec(" a : task1 | b : task2 ") + assert result == [("a", "task1"), ("b", "task2")] + + def test_no_colon_skips(self): + # Segment without colon: agent is empty, treated as task only + result = parse_chain_spec("just a task") + assert result == [("", "just a task")] + + def test_mixed_valid_invalid(self): + result = parse_chain_spec("a:task1 | | b:task2") + assert result == [("a", "task1"), ("b", "task2")] + + def test_empty_agent_after_colon(self): + result = parse_chain_spec(":task") + assert result == [] # empty agent + + +# --------------------------------------------------------------------------- +# build_chain_messages +# --------------------------------------------------------------------------- + +class TestBuildChainMessages: + def test_basic_chain(self): + chain = [("gemini", "research"), ("codex", "implement"), ("claude", "review")] + msgs = build_chain_messages(chain) + assert len(msgs) == 3 + + assert msgs[0].sender == "user" + assert msgs[0].receiver == "gemini" + assert msgs[0].content == "research" + + assert msgs[1].sender == "gemini" + assert msgs[1].receiver == "codex" + assert msgs[1].content == "implement" + + assert msgs[2].sender == "codex" + assert msgs[2].receiver == "claude" + assert msgs[2].content == "review" + + def test_single_step(self): + msgs = build_chain_messages([("codex", "do it")]) + assert len(msgs) == 1 + assert msgs[0].sender == "user" + assert msgs[0].receiver == "codex" + + def test_empty_chain(self): + assert build_chain_messages([]) == [] + + +# --------------------------------------------------------------------------- +# Integration: --to flow simulation +# --------------------------------------------------------------------------- + +class TestToFlowIntegration: + """Simulate the --to flow as implemented in bin/ask.""" + + def test_to_resolves_and_wraps(self): + team = TeamConfig(name="t", agents=[ + TeamAgent(name="reviewer", provider="claude"), + ]) + aliases = {"a": "codex"} + + sender = "codex" + to_agent = "reviewer" + message = "please review this code" + + target = resolve_agent_to_provider(to_agent, team, aliases) + assert target == "claude" + + msg = AgentMessage(sender=sender, receiver=target, content=message) + wrapped = wrap_message(msg) + assert "[CCB_FROM agent=codex]" in wrapped + assert "please review this code" in wrapped + + def test_to_with_alias(self): + aliases = {"a": "codex", "b": "gemini"} + + target = resolve_agent_to_provider("b", None, aliases) + assert target == "gemini" + + def test_chain_then_wrap(self): + chain = [("gemini", "research topic"), ("codex", "implement")] + msgs = build_chain_messages(chain) + + # Verify each step wraps correctly + for msg in msgs: + wrapped = wrap_message(msg) + assert f"[CCB_FROM agent={msg.sender}]" in wrapped + assert msg.content in wrapped + + +# --------------------------------------------------------------------------- +# Chain CLI flow simulation +# --------------------------------------------------------------------------- + +class TestChainFlowIntegration: + """Simulate the --chain flow as implemented in bin/ask.""" + + def test_chain_resolves_agents(self): + """Verify each step in a chain resolves to valid providers.""" + team = TeamConfig(name="t", agents=[ + TeamAgent(name="researcher", provider="gemini"), + TeamAgent(name="coder", provider="codex"), + TeamAgent(name="reviewer", provider="claude"), + ]) + aliases = {"a": "codex"} + + spec = "researcher:analyze | coder:implement | reviewer:check" + steps = parse_chain_spec(spec) + msgs = build_chain_messages(steps) + + for msg in msgs: + target = resolve_agent_to_provider(msg.receiver, team, aliases) + assert target is not None, f"Failed to resolve: {msg.receiver}" + + def test_chain_context_passing(self): + """Verify context from previous step is included in wrapped message.""" + msg = AgentMessage( + sender="gemini", receiver="codex", + content="implement this", + context="Research result: use quicksort", + ) + wrapped = wrap_message(msg) + assert "[CCB_FROM agent=gemini]" in wrapped + assert "[CCB_CONTEXT]" in wrapped + assert "Research result: use quicksort" in wrapped + assert "implement this" in wrapped + + def test_chain_with_aliases(self): + """Verify aliases work in chain specs.""" + aliases = {"a": "codex", "b": "gemini"} + steps = parse_chain_spec("b:research | a:implement") + msgs = build_chain_messages(steps) + + target0 = resolve_agent_to_provider(msgs[0].receiver, None, aliases) + target1 = resolve_agent_to_provider(msgs[1].receiver, None, aliases) + assert target0 == "gemini" + assert target1 == "codex" diff --git a/test/test_aliases.py b/test/test_aliases.py new file mode 100644 index 00000000..6080159b --- /dev/null +++ b/test/test_aliases.py @@ -0,0 +1,199 @@ +"""Tests for lib/aliases.py — agent name alias resolution.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from aliases import DEFAULT_ALIASES, _load_json, load_aliases, resolve_alias + + +# --------------------------------------------------------------------------- +# resolve_alias +# --------------------------------------------------------------------------- + +class TestResolveAlias: + def test_known_alias(self): + aliases = {"a": "codex", "b": "gemini"} + assert resolve_alias("a", aliases) == "codex" + + def test_unknown_passthrough(self): + assert resolve_alias("kimi", {}) == "kimi" + + def test_case_insensitive(self): + aliases = {"a": "codex"} + assert resolve_alias("A", aliases) == "codex" + + def test_whitespace_stripped(self): + aliases = {"a": "codex"} + assert resolve_alias(" a ", aliases) == "codex" + + def test_empty_string(self): + assert resolve_alias("", {"": "x"}) == "x" + assert resolve_alias("", {}) == "" + + def test_none_safe(self): + assert resolve_alias(None, {}) == "" + + +# --------------------------------------------------------------------------- +# _load_json +# --------------------------------------------------------------------------- + +class TestLoadJson: + def test_missing_file(self, tmp_path: Path): + assert _load_json(tmp_path / "nope.json") == {} + + def test_valid_file(self, tmp_path: Path): + f = tmp_path / "a.json" + f.write_text(json.dumps({"x": "codex", "y": "gemini"})) + assert _load_json(f) == {"x": "codex", "y": "gemini"} + + def test_corrupt_json(self, tmp_path: Path): + f = tmp_path / "bad.json" + f.write_text("{not valid json") + assert _load_json(f) == {} + + def test_non_dict_json(self, tmp_path: Path): + f = tmp_path / "arr.json" + f.write_text(json.dumps([1, 2, 3])) + assert _load_json(f) == {} + + def test_coerces_values_to_str(self, tmp_path: Path): + f = tmp_path / "mixed.json" + f.write_text(json.dumps({"a": 123, "b": True})) + result = _load_json(f) + assert result == {"a": "123", "b": "True"} + + +# --------------------------------------------------------------------------- +# load_aliases +# --------------------------------------------------------------------------- + +class TestLoadAliases: + def test_defaults_only(self, tmp_path: Path): + """No config files → returns DEFAULT_ALIASES.""" + result = load_aliases(work_dir=tmp_path) + assert result == DEFAULT_ALIASES + + def test_global_overrides_default(self, tmp_path: Path, monkeypatch): + home = tmp_path / "home" + ccb_dir = home / ".ccb" + ccb_dir.mkdir(parents=True) + (ccb_dir / "aliases.json").write_text(json.dumps({"a": "gemini"})) + + monkeypatch.setattr(Path, "home", classmethod(lambda cls: home)) + result = load_aliases(work_dir=tmp_path / "project") + assert result["a"] == "gemini" + # Other defaults still present + assert result["b"] == DEFAULT_ALIASES["b"] + + def test_project_overrides_global(self, tmp_path: Path, monkeypatch): + home = tmp_path / "home" + ccb_dir = home / ".ccb" + ccb_dir.mkdir(parents=True) + (ccb_dir / "aliases.json").write_text(json.dumps({"a": "gemini"})) + + proj = tmp_path / "project" + proj_ccb = proj / ".ccb" + proj_ccb.mkdir(parents=True) + (proj_ccb / "aliases.json").write_text(json.dumps({"a": "kimi"})) + + monkeypatch.setattr(Path, "home", classmethod(lambda cls: home)) + result = load_aliases(work_dir=proj) + assert result["a"] == "kimi" + + def test_no_work_dir(self, tmp_path: Path, monkeypatch): + home = tmp_path / "home" + home.mkdir() + monkeypatch.setattr(Path, "home", classmethod(lambda cls: home)) + result = load_aliases(work_dir=None) + assert result == DEFAULT_ALIASES + + def test_custom_alias_added(self, tmp_path: Path, monkeypatch): + home = tmp_path / "home" + ccb_dir = home / ".ccb" + ccb_dir.mkdir(parents=True) + (ccb_dir / "aliases.json").write_text(json.dumps({"z": "deepseek"})) + + monkeypatch.setattr(Path, "home", classmethod(lambda cls: home)) + result = load_aliases(work_dir=tmp_path) + assert result["z"] == "deepseek" + # Defaults preserved + assert result["a"] == DEFAULT_ALIASES["a"] + + +# --------------------------------------------------------------------------- +# Alias + instance (colon-separated) integration +# --------------------------------------------------------------------------- + +class TestAliasWithInstance: + """Test the pattern used in bin/ask: alias:instance resolution.""" + + def test_alias_with_instance(self): + aliases = DEFAULT_ALIASES + raw = "a:review" + base, _, instance = raw.partition(":") + resolved = resolve_alias(base, aliases) + result = f"{resolved}:{instance}" if instance else resolved + assert result == "codex:review" + + def test_plain_alias(self): + aliases = DEFAULT_ALIASES + raw = "b" + base, _, instance = raw.partition(":") + resolved = resolve_alias(base, aliases) + result = f"{resolved}:{instance}" if instance else resolved + assert result == "gemini" + + def test_non_alias_with_instance(self): + aliases = DEFAULT_ALIASES + raw = "codex:auth" + base, _, instance = raw.partition(":") + resolved = resolve_alias(base, aliases) + result = f"{resolved}:{instance}" if instance else resolved + assert result == "codex:auth" + + def test_non_alias_plain(self): + aliases = DEFAULT_ALIASES + raw = "kimi" + base, _, instance = raw.partition(":") + resolved = resolve_alias(base, aliases) + result = f"{resolved}:{instance}" if instance else resolved + assert result == "kimi" + + +# --------------------------------------------------------------------------- +# Integration with parse_qualified_provider +# --------------------------------------------------------------------------- + +class TestIntegrationWithProviders: + """Verify alias resolution works with parse_qualified_provider.""" + + def test_alias_then_parse(self): + from providers import parse_qualified_provider + + aliases = DEFAULT_ALIASES + raw = "a:review" + base, _, instance = raw.partition(":") + base = resolve_alias(base, aliases) + qualified = f"{base}:{instance}" if instance else base + + provider, inst = parse_qualified_provider(qualified) + assert provider == "codex" + assert inst == "review" + + def test_plain_alias_then_parse(self): + from providers import parse_qualified_provider + + aliases = DEFAULT_ALIASES + raw = "c" + base, _, instance = raw.partition(":") + base = resolve_alias(base, aliases) + qualified = f"{base}:{instance}" if instance else base + + provider, inst = parse_qualified_provider(qualified) + assert provider == "claude" + assert inst is None diff --git a/test/test_task_router.py b/test/test_task_router.py new file mode 100644 index 00000000..9e260e82 --- /dev/null +++ b/test/test_task_router.py @@ -0,0 +1,278 @@ +"""Tests for lib/task_router.py — smart task routing.""" + +from __future__ import annotations + +import pytest + +from task_router import ( + DEFAULT_FALLBACK, + DEFAULT_ROUTING_RULES, + RouteResult, + RoutingRule, + auto_route, + route_by_keywords, + route_by_team, + _score_message, + _score_agent_skills, +) +from team_config import TeamAgent, TeamConfig + + +# --------------------------------------------------------------------------- +# _score_message +# --------------------------------------------------------------------------- + +class TestScoreMessage: + def test_single_match(self): + assert _score_message("build a React component", ["react"]) == 1.0 + + def test_multiple_matches(self): + assert _score_message("React CSS HTML frontend", ["react", "css", "html"]) == 3.0 + + def test_no_match(self): + assert _score_message("hello world", ["react", "vue"]) == 0.0 + + def test_case_insensitive(self): + assert _score_message("Build React UI", ["react", "ui"]) == 2.0 + + def test_chinese_keywords(self): + assert _score_message("帮我写一个前端组件", ["前端"]) == 1.0 + + def test_empty_message(self): + assert _score_message("", ["react"]) == 0.0 + + +# --------------------------------------------------------------------------- +# route_by_keywords +# --------------------------------------------------------------------------- + +class TestRouteByKeywords: + def test_frontend_keywords(self): + result = route_by_keywords("帮我写一个 React 前端组件") + assert result.provider == "gemini" + + def test_algorithm_keywords(self): + result = route_by_keywords("分析这个算法的时间复杂度") + assert result.provider == "codex" + + def test_review_keywords(self): + result = route_by_keywords("请审查这段代码的安全性") + assert result.provider == "codex" + + def test_chinese_writing(self): + result = route_by_keywords("翻译这段话成英文") + assert result.provider == "kimi" + + def test_python_coding(self): + result = route_by_keywords("用 Python 实现一个快排") + assert result.provider == "qwen" + + def test_architecture(self): + result = route_by_keywords("帮我重构这个架构设计模式") + assert result.provider == "claude" + + def test_no_match_uses_fallback(self): + result = route_by_keywords("hello world 123") + assert result.provider == DEFAULT_FALLBACK.provider + + def test_empty_message_uses_fallback(self): + result = route_by_keywords("") + assert result.provider == DEFAULT_FALLBACK.provider + + def test_custom_rules(self): + rules = [RoutingRule(provider="custom", model="v1", keywords=["magic"], weight=1.0)] + result = route_by_keywords("do magic stuff", rules=rules) + assert result.provider == "custom" + assert result.model == "v1" + + def test_custom_fallback(self): + fb = RouteResult(provider="fallback_provider", reason="custom fb") + result = route_by_keywords("no match", rules=[], fallback=fb) + assert result.provider == "fallback_provider" + + def test_higher_weight_wins(self): + rules = [ + RoutingRule(provider="low", model="", keywords=["code"], weight=0.5), + RoutingRule(provider="high", model="", keywords=["code"], weight=2.0), + ] + result = route_by_keywords("write code", rules=rules) + assert result.provider == "high" + + def test_more_matches_wins(self): + rules = [ + RoutingRule(provider="one_match", model="", keywords=["react"], weight=1.0), + RoutingRule(provider="two_matches", model="", keywords=["react", "css"], weight=1.0), + ] + result = route_by_keywords("build React with CSS", rules=rules) + assert result.provider == "two_matches" + + def test_reason_includes_keywords(self): + result = route_by_keywords("帮我写 React 前端") + assert "keywords:" in result.reason + + def test_score_is_positive(self): + result = route_by_keywords("用 React 写前端") + assert result.score > 0 + + +# --------------------------------------------------------------------------- +# _score_agent_skills +# --------------------------------------------------------------------------- + +class TestScoreAgentSkills: + def test_skill_match(self): + agent = TeamAgent(name="dev", provider="codex", skills=["python", "rust"]) + assert _score_agent_skills(agent, "write python code") == 1.5 + + def test_role_match(self): + agent = TeamAgent(name="dev", provider="codex", role="review") + assert _score_agent_skills(agent, "please review this") == 1.0 + + def test_skill_and_role_match(self): + agent = TeamAgent(name="dev", provider="codex", role="review", skills=["security"]) + score = _score_agent_skills(agent, "security review needed") + assert score == 2.5 # 1.5 (skill) + 1.0 (role) + + def test_no_match(self): + agent = TeamAgent(name="dev", provider="codex", skills=["python"]) + assert _score_agent_skills(agent, "write rust code") == 0.0 + + def test_no_skills_no_role(self): + agent = TeamAgent(name="dev", provider="codex") + assert _score_agent_skills(agent, "anything") == 0.0 + + +# --------------------------------------------------------------------------- +# route_by_team +# --------------------------------------------------------------------------- + +class TestRouteByTeam: + @pytest.fixture() + def team(self) -> TeamConfig: + return TeamConfig( + name="dev-team", + agents=[ + TeamAgent(name="frontend", provider="gemini", model="3f", role="research", skills=["react", "css", "frontend"]), + TeamAgent(name="backend", provider="codex", model="o3", role="implementation", skills=["python", "api", "database"]), + TeamAgent(name="reviewer", provider="claude", role="review", skills=["security", "architecture"]), + ], + ) + + def test_frontend_task(self, team): + result = route_by_team("build a React frontend component", team) + assert result is not None + assert result.provider == "gemini" + assert "team:frontend" in result.reason + + def test_backend_task(self, team): + result = route_by_team("implement Python API with database", team) + assert result is not None + assert result.provider == "codex" + + def test_review_task(self, team): + result = route_by_team("security review of the architecture", team) + assert result is not None + assert result.provider == "claude" + + def test_no_match(self, team): + result = route_by_team("hello world", team) + assert result is None + + def test_empty_message(self, team): + assert route_by_team("", team) is None + + def test_empty_team(self): + team = TeamConfig(name="empty", agents=[]) + assert route_by_team("anything", team) is None + + def test_best_agent_wins(self, team): + # "python api" matches backend (2 skills), not frontend + result = route_by_team("python api endpoint", team) + assert result is not None + assert result.provider == "codex" + + def test_score_is_positive(self, team): + result = route_by_team("react frontend", team) + assert result is not None + assert result.score > 0 + + +# --------------------------------------------------------------------------- +# auto_route +# --------------------------------------------------------------------------- + +class TestAutoRoute: + @pytest.fixture() + def team(self) -> TeamConfig: + return TeamConfig( + name="dev-team", + agents=[ + TeamAgent(name="fe", provider="gemini", skills=["react", "frontend"]), + TeamAgent(name="be", provider="codex", skills=["python", "api"]), + ], + ) + + def test_team_match_preferred(self, team): + result = auto_route("build react frontend", team) + assert result.provider == "gemini" + assert "team:" in result.reason + + def test_falls_to_keywords_when_no_team_match(self, team): + result = auto_route("翻译这段话", team) + assert result.provider == "kimi" # keyword match, not team + + def test_falls_to_keywords_without_team(self): + result = auto_route("React 前端组件") + assert result.provider == "gemini" + + def test_falls_to_fallback(self): + result = auto_route("hello there") + assert result.provider == DEFAULT_FALLBACK.provider + + def test_no_team_uses_keywords(self): + result = auto_route("分析算法复杂度", None) + assert result.provider == "codex" + + def test_empty_message(self): + result = auto_route("") + assert result.provider == DEFAULT_FALLBACK.provider + + +# --------------------------------------------------------------------------- +# Integration: full ask --auto flow simulation +# --------------------------------------------------------------------------- + +class TestAutoRouteIntegration: + """Simulate the full --auto flow as implemented in bin/ask.""" + + def _simulate_auto(self, message: str, team: TeamConfig | None = None) -> str: + route = auto_route(message, team) + return route.provider + + def test_auto_frontend(self): + assert self._simulate_auto("帮我写一个 React 前端组件") == "gemini" + + def test_auto_algorithm(self): + assert self._simulate_auto("分析这个算法的时间复杂度") == "codex" + + def test_auto_translation(self): + assert self._simulate_auto("翻译这段话成英文") == "kimi" + + def test_auto_python(self): + assert self._simulate_auto("用 Python 实现快排") == "qwen" + + def test_auto_review(self): + assert self._simulate_auto("请审查这段代码") == "codex" + + def test_auto_with_team(self): + team = TeamConfig(name="t", agents=[ + TeamAgent(name="fe", provider="qwen", skills=["react", "frontend"]), + ]) + # Team skill match overrides keyword match + assert self._simulate_auto("build react frontend", team) == "qwen" + + def test_auto_team_no_match_falls_to_keywords(self): + team = TeamConfig(name="t", agents=[ + TeamAgent(name="fe", provider="qwen", skills=["react"]), + ]) + assert self._simulate_auto("分析算法复杂度", team) == "codex" diff --git a/test/test_team_config.py b/test/test_team_config.py new file mode 100644 index 00000000..02a484a1 --- /dev/null +++ b/test/test_team_config.py @@ -0,0 +1,387 @@ +"""Tests for lib/team_config.py — team configuration and agent resolution.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from team_config import ( + VALID_STRATEGIES, + TeamAgent, + TeamConfig, + _load_team_json, + _parse_agent, + load_team_config, + resolve_team_agent, +) + + +# --------------------------------------------------------------------------- +# TeamAgent / TeamConfig dataclasses +# --------------------------------------------------------------------------- + +class TestTeamConfig: + def test_agent_map_lookup(self): + team = TeamConfig( + name="test", + agents=[ + TeamAgent(name="coder", provider="codex", model="o3", role="implementation"), + TeamAgent(name="reviewer", provider="claude", role="review"), + ], + ) + m = team.agent_map() + assert m["coder"].provider == "codex" + assert m["reviewer"].provider == "claude" + + def test_agent_map_case_insensitive(self): + team = TeamConfig( + name="test", + agents=[TeamAgent(name="Coder", provider="codex")], + ) + # Names are lowered during parse, but test direct construction + m = team.agent_map() + assert "coder" in m + + def test_empty_agents(self): + team = TeamConfig(name="empty", agents=[]) + assert team.agent_map() == {} + + +# --------------------------------------------------------------------------- +# _parse_agent +# --------------------------------------------------------------------------- + +class TestParseAgent: + def test_valid_agent(self): + raw = {"name": "coder", "provider": "codex", "model": "o3", "role": "implementation", "skills": ["python", "rust"]} + agent = _parse_agent(raw) + assert agent is not None + assert agent.name == "coder" + assert agent.provider == "codex" + assert agent.model == "o3" + assert agent.role == "implementation" + assert agent.skills == ["python", "rust"] + + def test_minimal_agent(self): + raw = {"name": "bot", "provider": "gemini"} + agent = _parse_agent(raw) + assert agent is not None + assert agent.name == "bot" + assert agent.provider == "gemini" + assert agent.model == "" + assert agent.role == "" + assert agent.skills == [] + + def test_missing_name(self): + assert _parse_agent({"provider": "codex"}) is None + + def test_missing_provider(self): + assert _parse_agent({"name": "bot"}) is None + + def test_empty_name(self): + assert _parse_agent({"name": "", "provider": "codex"}) is None + + def test_not_a_dict(self): + assert _parse_agent("invalid") is None + assert _parse_agent(42) is None + assert _parse_agent(None) is None + + def test_skills_filters_empty(self): + raw = {"name": "bot", "provider": "gemini", "skills": ["python", "", " ", "rust"]} + agent = _parse_agent(raw) + assert agent.skills == ["python", "rust"] + + def test_provider_lowered(self): + raw = {"name": "bot", "provider": "Gemini"} + agent = _parse_agent(raw) + assert agent.provider == "gemini" + + +# --------------------------------------------------------------------------- +# _load_team_json +# --------------------------------------------------------------------------- + +class TestLoadTeamJson: + def test_missing_file(self, tmp_path: Path): + assert _load_team_json(tmp_path / "nope.json") is None + + def test_valid_config(self, tmp_path: Path): + f = tmp_path / "team.json" + f.write_text(json.dumps({ + "name": "dev-team", + "strategy": "skill_based", + "agents": [ + {"name": "coder", "provider": "codex", "model": "o3", "role": "implementation"}, + {"name": "reviewer", "provider": "claude", "role": "review"}, + ], + })) + team = _load_team_json(f) + assert team is not None + assert team.name == "dev-team" + assert team.strategy == "skill_based" + assert len(team.agents) == 2 + + def test_corrupt_json(self, tmp_path: Path): + f = tmp_path / "bad.json" + f.write_text("{invalid json") + assert _load_team_json(f) is None + + def test_non_dict_json(self, tmp_path: Path): + f = tmp_path / "arr.json" + f.write_text(json.dumps([1, 2])) + assert _load_team_json(f) is None + + def test_no_agents_returns_none(self, tmp_path: Path): + f = tmp_path / "team.json" + f.write_text(json.dumps({"name": "empty", "agents": []})) + assert _load_team_json(f) is None + + def test_invalid_agents_skipped(self, tmp_path: Path): + f = tmp_path / "team.json" + f.write_text(json.dumps({ + "name": "partial", + "agents": [ + {"name": "good", "provider": "codex"}, + {"name": "", "provider": "gemini"}, # invalid: empty name + "not_a_dict", # invalid: not dict + {"provider": "claude"}, # invalid: no name + ], + })) + team = _load_team_json(f) + assert team is not None + assert len(team.agents) == 1 + assert team.agents[0].name == "good" + + def test_default_name(self, tmp_path: Path): + f = tmp_path / "team.json" + f.write_text(json.dumps({"agents": [{"name": "a", "provider": "codex"}]})) + team = _load_team_json(f) + assert team.name == "default" + + def test_invalid_strategy_defaults(self, tmp_path: Path): + f = tmp_path / "team.json" + f.write_text(json.dumps({ + "name": "t", + "strategy": "invalid_strategy", + "agents": [{"name": "a", "provider": "codex"}], + })) + team = _load_team_json(f) + assert team.strategy == "skill_based" + + def test_all_valid_strategies(self, tmp_path: Path): + for strategy in VALID_STRATEGIES: + f = tmp_path / f"team_{strategy}.json" + f.write_text(json.dumps({ + "name": "t", + "strategy": strategy, + "agents": [{"name": "a", "provider": "codex"}], + })) + team = _load_team_json(f) + assert team.strategy == strategy + + def test_description_field(self, tmp_path: Path): + f = tmp_path / "team.json" + f.write_text(json.dumps({ + "name": "t", + "description": "My dev team", + "agents": [{"name": "a", "provider": "codex"}], + })) + team = _load_team_json(f) + assert team.description == "My dev team" + + +# --------------------------------------------------------------------------- +# load_team_config +# --------------------------------------------------------------------------- + +class TestLoadTeamConfig: + def test_no_config(self, tmp_path: Path, monkeypatch): + home = tmp_path / "home" + home.mkdir() + monkeypatch.setattr(Path, "home", classmethod(lambda cls: home)) + assert load_team_config(work_dir=tmp_path) is None + + def test_global_config(self, tmp_path: Path, monkeypatch): + home = tmp_path / "home" + ccb_dir = home / ".ccb" + ccb_dir.mkdir(parents=True) + (ccb_dir / "team.json").write_text(json.dumps({ + "name": "global-team", + "agents": [{"name": "bot", "provider": "gemini"}], + })) + monkeypatch.setattr(Path, "home", classmethod(lambda cls: home)) + team = load_team_config(work_dir=tmp_path / "project") + assert team is not None + assert team.name == "global-team" + + def test_project_overrides_global(self, tmp_path: Path, monkeypatch): + home = tmp_path / "home" + ccb_dir = home / ".ccb" + ccb_dir.mkdir(parents=True) + (ccb_dir / "team.json").write_text(json.dumps({ + "name": "global-team", + "agents": [{"name": "bot", "provider": "gemini"}], + })) + + proj = tmp_path / "project" + proj_ccb = proj / ".ccb" + proj_ccb.mkdir(parents=True) + (proj_ccb / "team.json").write_text(json.dumps({ + "name": "project-team", + "agents": [{"name": "coder", "provider": "codex"}], + })) + + monkeypatch.setattr(Path, "home", classmethod(lambda cls: home)) + team = load_team_config(work_dir=proj) + assert team.name == "project-team" + assert team.agents[0].name == "coder" + + def test_no_work_dir(self, tmp_path: Path, monkeypatch): + home = tmp_path / "home" + ccb_dir = home / ".ccb" + ccb_dir.mkdir(parents=True) + (ccb_dir / "team.json").write_text(json.dumps({ + "name": "global", + "agents": [{"name": "bot", "provider": "kimi"}], + })) + monkeypatch.setattr(Path, "home", classmethod(lambda cls: home)) + team = load_team_config(work_dir=None) + assert team is not None + assert team.name == "global" + + +# --------------------------------------------------------------------------- +# resolve_team_agent +# --------------------------------------------------------------------------- + +class TestResolveTeamAgent: + @pytest.fixture() + def team(self) -> TeamConfig: + return TeamConfig( + name="dev", + agents=[ + TeamAgent(name="researcher", provider="gemini", model="3f", role="research"), + TeamAgent(name="coder", provider="codex", model="o3", role="implementation"), + TeamAgent(name="reviewer", provider="claude", role="review"), + ], + ) + + def test_resolve_known_agent(self, team): + agent = resolve_team_agent("researcher", team) + assert agent is not None + assert agent.provider == "gemini" + assert agent.model == "3f" + + def test_resolve_case_insensitive(self, team): + agent = resolve_team_agent("Coder", team) + assert agent is not None + assert agent.provider == "codex" + + def test_resolve_unknown_returns_none(self, team): + assert resolve_team_agent("unknown", team) is None + + def test_resolve_no_team(self): + assert resolve_team_agent("coder", None) is None + + def test_resolve_empty_name(self, team): + assert resolve_team_agent("", team) is None + + def test_resolve_none_name(self, team): + assert resolve_team_agent(None, team) is None + + +# --------------------------------------------------------------------------- +# Integration: team agents override aliases +# --------------------------------------------------------------------------- + +class TestTeamOverridesAlias: + """Verify team agent names take priority over aliases.""" + + def test_team_agent_overrides_alias(self): + from aliases import DEFAULT_ALIASES, resolve_alias + + team = TeamConfig( + name="test", + agents=[TeamAgent(name="a", provider="kimi")], # override alias a→codex + ) + + name = "a" + # Team resolution first + team_agent = resolve_team_agent(name, team) + if team_agent: + provider = team_agent.provider + else: + provider = resolve_alias(name, DEFAULT_ALIASES) + + assert provider == "kimi" # team wins over alias + + def test_non_team_falls_to_alias(self): + from aliases import DEFAULT_ALIASES, resolve_alias + + team = TeamConfig( + name="test", + agents=[TeamAgent(name="coder", provider="codex")], + ) + + name = "a" + team_agent = resolve_team_agent(name, team) + if team_agent: + provider = team_agent.provider + else: + provider = resolve_alias(name, DEFAULT_ALIASES) + + assert provider == "codex" # alias a→codex + + +# --------------------------------------------------------------------------- +# Full resolution flow (as in bin/ask) +# --------------------------------------------------------------------------- + +class TestFullResolutionFlow: + """Simulate the full resolution flow in bin/ask.""" + + def _resolve(self, raw_provider: str, team: TeamConfig | None) -> tuple[str, str | None]: + from aliases import load_aliases, resolve_alias + from providers import parse_qualified_provider + + team_agent = resolve_team_agent(raw_provider, team) + if team_agent: + raw_provider = team_agent.provider + else: + aliases = {"a": "codex", "b": "gemini", "c": "claude"} + base_part, _, instance_part = raw_provider.partition(":") + base_part = resolve_alias(base_part, aliases) + raw_provider = f"{base_part}:{instance_part}" if instance_part else base_part + + return parse_qualified_provider(raw_provider) + + def test_team_agent_resolves(self): + team = TeamConfig(name="t", agents=[ + TeamAgent(name="researcher", provider="gemini", model="3f"), + ]) + provider, instance = self._resolve("researcher", team) + assert provider == "gemini" + assert instance is None + + def test_alias_resolves_without_team(self): + provider, instance = self._resolve("a", None) + assert provider == "codex" + + def test_alias_with_instance(self): + provider, instance = self._resolve("a:review", None) + assert provider == "codex" + assert instance == "review" + + def test_direct_provider(self): + provider, instance = self._resolve("kimi", None) + assert provider == "kimi" + assert instance is None + + def test_team_agent_overrides_alias_letter(self): + team = TeamConfig(name="t", agents=[ + TeamAgent(name="a", provider="qwen"), + ]) + provider, instance = self._resolve("a", team) + assert provider == "qwen" # team wins