diff --git a/bin/ask b/bin/ask index f3167bff..f90aea70 100755 --- a/bin/ask +++ b/bin/ask @@ -45,9 +45,12 @@ sys.path.insert(0, str(lib_dir)) from compat import read_stdin_text, setup_windows_encoding setup_windows_encoding() +from aliases import load_aliases, resolve_alias from cli_output import EXIT_ERROR, EXIT_OK from providers import parse_qualified_provider from session_utils import find_project_session_file +from task_router import auto_route +from team_config import load_team_config, resolve_team_agent # Provider to daemon command mapping @@ -481,11 +484,22 @@ def make_task_id() -> str: def _usage() -> None: + aliases = load_aliases() + alias_list = ", ".join(f"{k}→{v}" for k, v in sorted(aliases.items())) print("Usage: ask [options] ", file=sys.stderr) print("", file=sys.stderr) print("Providers:", file=sys.stderr) print(" gemini, codex, opencode, droid, claude, copilot, codebuddy, qwen", file=sys.stderr) print("", file=sys.stderr) + print("Aliases:", file=sys.stderr) + print(f" {alias_list}", file=sys.stderr) + print("", file=sys.stderr) + team = load_team_config(Path.cwd()) + if team: + agents = ", ".join(f"{a.name}→{a.provider}" for a in team.agents) + print(f"Team '{team.name}' ({team.strategy}):", file=sys.stderr) + print(f" {agents}", file=sys.stderr) + print("", file=sys.stderr) print("Options:", file=sys.stderr) print(" -h, --help Show this help message", file=sys.stderr) print(" -t, --timeout SECONDS Request timeout (default: 3600)", file=sys.stderr) @@ -493,6 +507,7 @@ def _usage() -> None: print(" --foreground Run in foreground (no nohup/background)", file=sys.stderr) print(" --background Force background mode", file=sys.stderr) print(" --no-wrap Don't wrap with CCB protocol markers", file=sys.stderr) + print(" --auto Auto-select provider based on message content", file=sys.stderr) def main(argv: list[str]) -> int: @@ -500,24 +515,17 @@ def main(argv: list[str]) -> int: _usage() return EXIT_ERROR - # First argument must be the provider + # First argument must be the provider (or --auto / --help) raw_provider = argv[1].lower() if raw_provider in ("-h", "--help"): _usage() return EXIT_OK - base_provider, instance = parse_qualified_provider(raw_provider) - - if base_provider not in PROVIDER_DAEMONS: - print(f"[ERROR] Unknown provider: {base_provider}", file=sys.stderr) - print(f"[ERROR] Available: {', '.join(PROVIDER_DAEMONS.keys())}", file=sys.stderr) - return EXIT_ERROR - - daemon_cmd = PROVIDER_DAEMONS[base_provider] - provider = raw_provider # keep full qualified key for daemon routing + auto_mode = raw_provider == "--auto" + cwd = Path.cwd() - # Parse remaining arguments + # Parse remaining arguments (shared by both auto and normal modes) timeout: float = 3600.0 notify_mode = False no_wrap = False @@ -560,6 +568,34 @@ def main(argv: list[str]) -> int: print("[ERROR] Message cannot be empty", file=sys.stderr) return EXIT_ERROR + # --auto mode: select provider based on message content + if auto_mode: + team = load_team_config(cwd) + route = auto_route(message, team) + raw_provider = route.provider + print(f"[AUTO] → {raw_provider}" + (f" ({route.reason})" if route.reason else ""), file=sys.stderr) + else: + # Resolution order: team agents > aliases > direct provider names + team = load_team_config(cwd) + team_agent = resolve_team_agent(raw_provider, team) + if team_agent: + raw_provider = team_agent.provider + else: + aliases = load_aliases(cwd) + base_part, _, instance_part = raw_provider.partition(":") + base_part = resolve_alias(base_part, aliases) + raw_provider = f"{base_part}:{instance_part}" if instance_part else base_part + + base_provider, instance = parse_qualified_provider(raw_provider) + + if base_provider not in PROVIDER_DAEMONS: + print(f"[ERROR] Unknown provider: {base_provider}", file=sys.stderr) + print(f"[ERROR] Available: {', '.join(PROVIDER_DAEMONS.keys())}", file=sys.stderr) + return EXIT_ERROR + + daemon_cmd = PROVIDER_DAEMONS[base_provider] + provider = raw_provider # keep full qualified key for daemon routing + # Notify mode: sync send, no wait for reply (used for hook notifications) if notify_mode: _require_caller() diff --git a/lib/aliases.py b/lib/aliases.py new file mode 100644 index 00000000..4dc9627f --- /dev/null +++ b/lib/aliases.py @@ -0,0 +1,62 @@ +"""Agent name aliases for CCB. + +Resolves short aliases (a, b, c, ...) to provider names. + +Configuration layers (higher overrides lower): +1. Hardcoded defaults (DEFAULT_ALIASES) +2. ~/.ccb/aliases.json (global) +3. .ccb/aliases.json (project-level, relative to work_dir) +""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path +from typing import Dict, Optional + +DEFAULT_ALIASES: Dict[str, str] = { + "a": "codex", + "b": "gemini", + "c": "claude", + "d": "opencode", + "e": "droid", + "f": "copilot", + "g": "codebuddy", + "h": "qwen", +} + + +def _load_json(path: Path) -> Dict[str, str]: + """Load aliases from a JSON file, returning {} on any error.""" + try: + if not path.is_file(): + return {} + data = json.loads(path.read_text(encoding="utf-8")) + if not isinstance(data, dict): + return {} + # Only keep str->str entries + return {str(k): str(v) for k, v in data.items()} + except (json.JSONDecodeError, OSError, ValueError): + print(f"[WARN] Failed to parse alias config: {path}", file=sys.stderr) + return {} + + +def load_aliases(work_dir: Optional[Path] = None) -> Dict[str, str]: + """Merge alias configs: defaults < ~/.ccb/aliases.json < .ccb/aliases.json.""" + merged = dict(DEFAULT_ALIASES) + + global_path = Path.home() / ".ccb" / "aliases.json" + merged.update(_load_json(global_path)) + + if work_dir is not None: + project_path = work_dir / ".ccb" / "aliases.json" + merged.update(_load_json(project_path)) + + return merged + + +def resolve_alias(name: str, aliases: Dict[str, str]) -> str: + """Resolve an alias to a provider name. Non-aliases pass through unchanged.""" + key = (name or "").strip().lower() + return aliases.get(key, key) diff --git a/lib/task_router.py b/lib/task_router.py new file mode 100644 index 00000000..01c3ca27 --- /dev/null +++ b/lib/task_router.py @@ -0,0 +1,210 @@ +"""Smart task routing for CCB Agent Teams. + +Routes tasks to the best provider based on message content analysis. +Supports keyword matching (Chinese + English) and team skill-based matching. + +Used by `ask --auto "message"` to auto-select the best provider. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Sequence + +from team_config import TeamAgent, TeamConfig + + +@dataclass +class RouteResult: + """Result of routing a task to a provider.""" + provider: str + model: str = "" + reason: str = "" + score: float = 0.0 + + +# --------------------------------------------------------------------------- +# Keyword → provider routing rules +# --------------------------------------------------------------------------- + +@dataclass +class RoutingRule: + """A keyword-based routing rule.""" + provider: str + model: str + keywords: List[str] + weight: float = 1.0 + + +# Default routing rules (reference: HiveMind ProviderRouter + CLAUDE.md mapping) +DEFAULT_ROUTING_RULES: List[RoutingRule] = [ + RoutingRule( + provider="gemini", model="3f", + keywords=["frontend", "前端", "react", "vue", "css", "html", "ui", "design", "设计", "样式", "组件", "tailwind", "nextjs"], + weight=1.5, + ), + RoutingRule( + provider="codex", model="o3", + keywords=["algorithm", "算法", "math", "数学", "proof", "证明", "reasoning", "推理", "逻辑", "logic", "complexity", "复杂度"], + weight=1.5, + ), + RoutingRule( + provider="codex", model="o3", + keywords=["review", "审查", "审核", "audit", "security", "安全", "code review", "代码审查"], + weight=1.5, + ), + RoutingRule( + provider="qwen", model="", + keywords=["python", "编程", "代码生成", "code", "coding", "implement", "实现", "sql", "database", "数据库", "数据分析"], + weight=1.0, + ), + RoutingRule( + provider="kimi", model="thinking", + keywords=["中文", "chinese", "翻译", "translate", "translation", "文案", "写作", "writing", "长文", "文档", "document", "总结", "summary", "分析"], + weight=1.0, + ), + RoutingRule( + provider="kimi", model="", + keywords=["explain", "解释", "概念", "concept", "快速", "quick", "shell", "bash", "运维", "devops"], + weight=0.8, + ), + RoutingRule( + provider="claude", model="", + keywords=["architecture", "架构", "设计模式", "design pattern", "重构", "refactor", "planning", "规划"], + weight=1.0, + ), +] + +# Default fallback when no keywords match +DEFAULT_FALLBACK = RouteResult(provider="kimi", model="", reason="default fallback", score=0.0) + + +def _score_message(message: str, keywords: List[str]) -> float: + """Score a message against a list of keywords. Returns number of matches.""" + text = message.lower() + score = 0.0 + for kw in keywords: + if kw.lower() in text: + score += 1.0 + return score + + +def route_by_keywords( + message: str, + rules: Optional[List[RoutingRule]] = None, + fallback: Optional[RouteResult] = None, +) -> RouteResult: + """Route a message to a provider based on keyword matching. + + Returns the RouteResult with the highest weighted score. + """ + if rules is None: + rules = DEFAULT_ROUTING_RULES + if fallback is None: + fallback = DEFAULT_FALLBACK + + if not message or not message.strip(): + return fallback + + best: Optional[RouteResult] = None + best_score = 0.0 + + for rule in rules: + raw_score = _score_message(message, rule.keywords) + if raw_score <= 0: + continue + weighted = raw_score * rule.weight + if weighted > best_score: + best_score = weighted + matched = [kw for kw in rule.keywords if kw.lower() in message.lower()] + best = RouteResult( + provider=rule.provider, + model=rule.model, + reason=f"keywords: {', '.join(matched[:3])}", + score=weighted, + ) + + return best if best else fallback + + +# --------------------------------------------------------------------------- +# Team skill-based routing +# --------------------------------------------------------------------------- + +def _score_agent_skills(agent: TeamAgent, message: str) -> float: + """Score a team agent against a message based on skills + role keywords.""" + if not agent.skills and not agent.role: + return 0.0 + text = message.lower() + score = 0.0 + for skill in agent.skills: + if skill in text: + score += 1.5 # skills are more specific, higher weight + if agent.role and agent.role in text: + score += 1.0 + return score + + +def route_by_team( + message: str, + team: TeamConfig, +) -> Optional[RouteResult]: + """Route a message to the best team agent based on skills and role matching. + + Returns None if no agent has a positive match score. + """ + if not message or not message.strip() or not team.agents: + return None + + best_agent: Optional[TeamAgent] = None + best_score = 0.0 + + for agent in team.agents: + score = _score_agent_skills(agent, message) + if score > best_score: + best_score = score + best_agent = agent + + if best_agent is None: + return None + + matched = [] + text = message.lower() + for s in best_agent.skills: + if s in text: + matched.append(s) + if best_agent.role and best_agent.role in text: + matched.append(f"role:{best_agent.role}") + + return RouteResult( + provider=best_agent.provider, + model=best_agent.model, + reason=f"team:{best_agent.name} ({', '.join(matched[:3])})", + score=best_score, + ) + + +# --------------------------------------------------------------------------- +# Unified auto-route +# --------------------------------------------------------------------------- + +def auto_route( + message: str, + team: Optional[TeamConfig] = None, +) -> RouteResult: + """Auto-route a message to the best provider. + + Resolution order: + 1. Team skill-based matching (if team config exists) + 2. Keyword-based matching (default rules) + 3. Default fallback + """ + # Try team-based routing first + if team is not None: + result = route_by_team(message, team) + if result: + return result + + # Fall back to keyword-based routing + return route_by_keywords(message) diff --git a/lib/team_config.py b/lib/team_config.py new file mode 100644 index 00000000..a884f648 --- /dev/null +++ b/lib/team_config.py @@ -0,0 +1,129 @@ +"""Team configuration for CCB Agent Teams. + +Loads team config from JSON files and resolves team agent names to providers. + +Configuration layers (higher overrides lower): +1. ~/.ccb/team.json (global) +2. .ccb/team.json (project-level) + +A team config defines named agents with provider, model, role, and skills. +Team agent names take priority over aliases when resolving provider names. +""" + +from __future__ import annotations + +import json +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Optional + + +@dataclass +class TeamAgent: + """A named agent within a team.""" + name: str + provider: str + model: str = "" + role: str = "" + skills: List[str] = field(default_factory=list) + + +@dataclass +class TeamConfig: + """Team configuration with named agents and allocation strategy.""" + name: str + agents: List[TeamAgent] = field(default_factory=list) + strategy: str = "skill_based" # round_robin | load_balance | skill_based + description: str = "" + + def agent_map(self) -> Dict[str, TeamAgent]: + """Build name → TeamAgent lookup (case-insensitive).""" + return {a.name.lower(): a for a in self.agents} + + +VALID_STRATEGIES = {"round_robin", "load_balance", "skill_based"} + + +def _parse_agent(raw: dict) -> Optional[TeamAgent]: + """Parse a single agent entry from JSON. Returns None on invalid data.""" + if not isinstance(raw, dict): + return None + name = str(raw.get("name", "")).strip() + provider = str(raw.get("provider", "")).strip().lower() + if not name or not provider: + return None + return TeamAgent( + name=name.lower(), + provider=provider, + model=str(raw.get("model", "")).strip(), + role=str(raw.get("role", "")).strip().lower(), + skills=[str(s).strip().lower() for s in raw.get("skills", []) if str(s).strip()], + ) + + +def _load_team_json(path: Path) -> Optional[TeamConfig]: + """Load a team config from a JSON file. Returns None on any error.""" + try: + if not path.is_file(): + return None + data = json.loads(path.read_text(encoding="utf-8")) + if not isinstance(data, dict): + return None + except (json.JSONDecodeError, OSError, ValueError): + print(f"[WARN] Failed to parse team config: {path}", file=sys.stderr) + return None + + name = str(data.get("name", "")).strip() + if not name: + name = "default" + + strategy = str(data.get("strategy", "skill_based")).strip().lower() + if strategy not in VALID_STRATEGIES: + strategy = "skill_based" + + agents: List[TeamAgent] = [] + for raw_agent in data.get("agents", []): + agent = _parse_agent(raw_agent) + if agent: + agents.append(agent) + + if not agents: + return None + + return TeamConfig( + name=name, + agents=agents, + strategy=strategy, + description=str(data.get("description", "")).strip(), + ) + + +def load_team_config(work_dir: Optional[Path] = None) -> Optional[TeamConfig]: + """Load team config: project .ccb/team.json overrides global ~/.ccb/team.json. + + Returns None if no valid team config is found. + """ + global_path = Path.home() / ".ccb" / "team.json" + global_config = _load_team_json(global_path) + + project_config: Optional[TeamConfig] = None + if work_dir is not None: + project_path = work_dir / ".ccb" / "team.json" + project_config = _load_team_json(project_path) + + # Project-level takes full priority (not merged) + return project_config or global_config + + +def resolve_team_agent( + name: str, + team: Optional[TeamConfig], +) -> Optional[TeamAgent]: + """Resolve a name to a TeamAgent. Returns None if not a team agent.""" + if team is None: + return None + key = (name or "").strip().lower() + if not key: + return None + return team.agent_map().get(key) diff --git a/test/test_aliases.py b/test/test_aliases.py new file mode 100644 index 00000000..6080159b --- /dev/null +++ b/test/test_aliases.py @@ -0,0 +1,199 @@ +"""Tests for lib/aliases.py — agent name alias resolution.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from aliases import DEFAULT_ALIASES, _load_json, load_aliases, resolve_alias + + +# --------------------------------------------------------------------------- +# resolve_alias +# --------------------------------------------------------------------------- + +class TestResolveAlias: + def test_known_alias(self): + aliases = {"a": "codex", "b": "gemini"} + assert resolve_alias("a", aliases) == "codex" + + def test_unknown_passthrough(self): + assert resolve_alias("kimi", {}) == "kimi" + + def test_case_insensitive(self): + aliases = {"a": "codex"} + assert resolve_alias("A", aliases) == "codex" + + def test_whitespace_stripped(self): + aliases = {"a": "codex"} + assert resolve_alias(" a ", aliases) == "codex" + + def test_empty_string(self): + assert resolve_alias("", {"": "x"}) == "x" + assert resolve_alias("", {}) == "" + + def test_none_safe(self): + assert resolve_alias(None, {}) == "" + + +# --------------------------------------------------------------------------- +# _load_json +# --------------------------------------------------------------------------- + +class TestLoadJson: + def test_missing_file(self, tmp_path: Path): + assert _load_json(tmp_path / "nope.json") == {} + + def test_valid_file(self, tmp_path: Path): + f = tmp_path / "a.json" + f.write_text(json.dumps({"x": "codex", "y": "gemini"})) + assert _load_json(f) == {"x": "codex", "y": "gemini"} + + def test_corrupt_json(self, tmp_path: Path): + f = tmp_path / "bad.json" + f.write_text("{not valid json") + assert _load_json(f) == {} + + def test_non_dict_json(self, tmp_path: Path): + f = tmp_path / "arr.json" + f.write_text(json.dumps([1, 2, 3])) + assert _load_json(f) == {} + + def test_coerces_values_to_str(self, tmp_path: Path): + f = tmp_path / "mixed.json" + f.write_text(json.dumps({"a": 123, "b": True})) + result = _load_json(f) + assert result == {"a": "123", "b": "True"} + + +# --------------------------------------------------------------------------- +# load_aliases +# --------------------------------------------------------------------------- + +class TestLoadAliases: + def test_defaults_only(self, tmp_path: Path): + """No config files → returns DEFAULT_ALIASES.""" + result = load_aliases(work_dir=tmp_path) + assert result == DEFAULT_ALIASES + + def test_global_overrides_default(self, tmp_path: Path, monkeypatch): + home = tmp_path / "home" + ccb_dir = home / ".ccb" + ccb_dir.mkdir(parents=True) + (ccb_dir / "aliases.json").write_text(json.dumps({"a": "gemini"})) + + monkeypatch.setattr(Path, "home", classmethod(lambda cls: home)) + result = load_aliases(work_dir=tmp_path / "project") + assert result["a"] == "gemini" + # Other defaults still present + assert result["b"] == DEFAULT_ALIASES["b"] + + def test_project_overrides_global(self, tmp_path: Path, monkeypatch): + home = tmp_path / "home" + ccb_dir = home / ".ccb" + ccb_dir.mkdir(parents=True) + (ccb_dir / "aliases.json").write_text(json.dumps({"a": "gemini"})) + + proj = tmp_path / "project" + proj_ccb = proj / ".ccb" + proj_ccb.mkdir(parents=True) + (proj_ccb / "aliases.json").write_text(json.dumps({"a": "kimi"})) + + monkeypatch.setattr(Path, "home", classmethod(lambda cls: home)) + result = load_aliases(work_dir=proj) + assert result["a"] == "kimi" + + def test_no_work_dir(self, tmp_path: Path, monkeypatch): + home = tmp_path / "home" + home.mkdir() + monkeypatch.setattr(Path, "home", classmethod(lambda cls: home)) + result = load_aliases(work_dir=None) + assert result == DEFAULT_ALIASES + + def test_custom_alias_added(self, tmp_path: Path, monkeypatch): + home = tmp_path / "home" + ccb_dir = home / ".ccb" + ccb_dir.mkdir(parents=True) + (ccb_dir / "aliases.json").write_text(json.dumps({"z": "deepseek"})) + + monkeypatch.setattr(Path, "home", classmethod(lambda cls: home)) + result = load_aliases(work_dir=tmp_path) + assert result["z"] == "deepseek" + # Defaults preserved + assert result["a"] == DEFAULT_ALIASES["a"] + + +# --------------------------------------------------------------------------- +# Alias + instance (colon-separated) integration +# --------------------------------------------------------------------------- + +class TestAliasWithInstance: + """Test the pattern used in bin/ask: alias:instance resolution.""" + + def test_alias_with_instance(self): + aliases = DEFAULT_ALIASES + raw = "a:review" + base, _, instance = raw.partition(":") + resolved = resolve_alias(base, aliases) + result = f"{resolved}:{instance}" if instance else resolved + assert result == "codex:review" + + def test_plain_alias(self): + aliases = DEFAULT_ALIASES + raw = "b" + base, _, instance = raw.partition(":") + resolved = resolve_alias(base, aliases) + result = f"{resolved}:{instance}" if instance else resolved + assert result == "gemini" + + def test_non_alias_with_instance(self): + aliases = DEFAULT_ALIASES + raw = "codex:auth" + base, _, instance = raw.partition(":") + resolved = resolve_alias(base, aliases) + result = f"{resolved}:{instance}" if instance else resolved + assert result == "codex:auth" + + def test_non_alias_plain(self): + aliases = DEFAULT_ALIASES + raw = "kimi" + base, _, instance = raw.partition(":") + resolved = resolve_alias(base, aliases) + result = f"{resolved}:{instance}" if instance else resolved + assert result == "kimi" + + +# --------------------------------------------------------------------------- +# Integration with parse_qualified_provider +# --------------------------------------------------------------------------- + +class TestIntegrationWithProviders: + """Verify alias resolution works with parse_qualified_provider.""" + + def test_alias_then_parse(self): + from providers import parse_qualified_provider + + aliases = DEFAULT_ALIASES + raw = "a:review" + base, _, instance = raw.partition(":") + base = resolve_alias(base, aliases) + qualified = f"{base}:{instance}" if instance else base + + provider, inst = parse_qualified_provider(qualified) + assert provider == "codex" + assert inst == "review" + + def test_plain_alias_then_parse(self): + from providers import parse_qualified_provider + + aliases = DEFAULT_ALIASES + raw = "c" + base, _, instance = raw.partition(":") + base = resolve_alias(base, aliases) + qualified = f"{base}:{instance}" if instance else base + + provider, inst = parse_qualified_provider(qualified) + assert provider == "claude" + assert inst is None diff --git a/test/test_task_router.py b/test/test_task_router.py new file mode 100644 index 00000000..9e260e82 --- /dev/null +++ b/test/test_task_router.py @@ -0,0 +1,278 @@ +"""Tests for lib/task_router.py — smart task routing.""" + +from __future__ import annotations + +import pytest + +from task_router import ( + DEFAULT_FALLBACK, + DEFAULT_ROUTING_RULES, + RouteResult, + RoutingRule, + auto_route, + route_by_keywords, + route_by_team, + _score_message, + _score_agent_skills, +) +from team_config import TeamAgent, TeamConfig + + +# --------------------------------------------------------------------------- +# _score_message +# --------------------------------------------------------------------------- + +class TestScoreMessage: + def test_single_match(self): + assert _score_message("build a React component", ["react"]) == 1.0 + + def test_multiple_matches(self): + assert _score_message("React CSS HTML frontend", ["react", "css", "html"]) == 3.0 + + def test_no_match(self): + assert _score_message("hello world", ["react", "vue"]) == 0.0 + + def test_case_insensitive(self): + assert _score_message("Build React UI", ["react", "ui"]) == 2.0 + + def test_chinese_keywords(self): + assert _score_message("帮我写一个前端组件", ["前端"]) == 1.0 + + def test_empty_message(self): + assert _score_message("", ["react"]) == 0.0 + + +# --------------------------------------------------------------------------- +# route_by_keywords +# --------------------------------------------------------------------------- + +class TestRouteByKeywords: + def test_frontend_keywords(self): + result = route_by_keywords("帮我写一个 React 前端组件") + assert result.provider == "gemini" + + def test_algorithm_keywords(self): + result = route_by_keywords("分析这个算法的时间复杂度") + assert result.provider == "codex" + + def test_review_keywords(self): + result = route_by_keywords("请审查这段代码的安全性") + assert result.provider == "codex" + + def test_chinese_writing(self): + result = route_by_keywords("翻译这段话成英文") + assert result.provider == "kimi" + + def test_python_coding(self): + result = route_by_keywords("用 Python 实现一个快排") + assert result.provider == "qwen" + + def test_architecture(self): + result = route_by_keywords("帮我重构这个架构设计模式") + assert result.provider == "claude" + + def test_no_match_uses_fallback(self): + result = route_by_keywords("hello world 123") + assert result.provider == DEFAULT_FALLBACK.provider + + def test_empty_message_uses_fallback(self): + result = route_by_keywords("") + assert result.provider == DEFAULT_FALLBACK.provider + + def test_custom_rules(self): + rules = [RoutingRule(provider="custom", model="v1", keywords=["magic"], weight=1.0)] + result = route_by_keywords("do magic stuff", rules=rules) + assert result.provider == "custom" + assert result.model == "v1" + + def test_custom_fallback(self): + fb = RouteResult(provider="fallback_provider", reason="custom fb") + result = route_by_keywords("no match", rules=[], fallback=fb) + assert result.provider == "fallback_provider" + + def test_higher_weight_wins(self): + rules = [ + RoutingRule(provider="low", model="", keywords=["code"], weight=0.5), + RoutingRule(provider="high", model="", keywords=["code"], weight=2.0), + ] + result = route_by_keywords("write code", rules=rules) + assert result.provider == "high" + + def test_more_matches_wins(self): + rules = [ + RoutingRule(provider="one_match", model="", keywords=["react"], weight=1.0), + RoutingRule(provider="two_matches", model="", keywords=["react", "css"], weight=1.0), + ] + result = route_by_keywords("build React with CSS", rules=rules) + assert result.provider == "two_matches" + + def test_reason_includes_keywords(self): + result = route_by_keywords("帮我写 React 前端") + assert "keywords:" in result.reason + + def test_score_is_positive(self): + result = route_by_keywords("用 React 写前端") + assert result.score > 0 + + +# --------------------------------------------------------------------------- +# _score_agent_skills +# --------------------------------------------------------------------------- + +class TestScoreAgentSkills: + def test_skill_match(self): + agent = TeamAgent(name="dev", provider="codex", skills=["python", "rust"]) + assert _score_agent_skills(agent, "write python code") == 1.5 + + def test_role_match(self): + agent = TeamAgent(name="dev", provider="codex", role="review") + assert _score_agent_skills(agent, "please review this") == 1.0 + + def test_skill_and_role_match(self): + agent = TeamAgent(name="dev", provider="codex", role="review", skills=["security"]) + score = _score_agent_skills(agent, "security review needed") + assert score == 2.5 # 1.5 (skill) + 1.0 (role) + + def test_no_match(self): + agent = TeamAgent(name="dev", provider="codex", skills=["python"]) + assert _score_agent_skills(agent, "write rust code") == 0.0 + + def test_no_skills_no_role(self): + agent = TeamAgent(name="dev", provider="codex") + assert _score_agent_skills(agent, "anything") == 0.0 + + +# --------------------------------------------------------------------------- +# route_by_team +# --------------------------------------------------------------------------- + +class TestRouteByTeam: + @pytest.fixture() + def team(self) -> TeamConfig: + return TeamConfig( + name="dev-team", + agents=[ + TeamAgent(name="frontend", provider="gemini", model="3f", role="research", skills=["react", "css", "frontend"]), + TeamAgent(name="backend", provider="codex", model="o3", role="implementation", skills=["python", "api", "database"]), + TeamAgent(name="reviewer", provider="claude", role="review", skills=["security", "architecture"]), + ], + ) + + def test_frontend_task(self, team): + result = route_by_team("build a React frontend component", team) + assert result is not None + assert result.provider == "gemini" + assert "team:frontend" in result.reason + + def test_backend_task(self, team): + result = route_by_team("implement Python API with database", team) + assert result is not None + assert result.provider == "codex" + + def test_review_task(self, team): + result = route_by_team("security review of the architecture", team) + assert result is not None + assert result.provider == "claude" + + def test_no_match(self, team): + result = route_by_team("hello world", team) + assert result is None + + def test_empty_message(self, team): + assert route_by_team("", team) is None + + def test_empty_team(self): + team = TeamConfig(name="empty", agents=[]) + assert route_by_team("anything", team) is None + + def test_best_agent_wins(self, team): + # "python api" matches backend (2 skills), not frontend + result = route_by_team("python api endpoint", team) + assert result is not None + assert result.provider == "codex" + + def test_score_is_positive(self, team): + result = route_by_team("react frontend", team) + assert result is not None + assert result.score > 0 + + +# --------------------------------------------------------------------------- +# auto_route +# --------------------------------------------------------------------------- + +class TestAutoRoute: + @pytest.fixture() + def team(self) -> TeamConfig: + return TeamConfig( + name="dev-team", + agents=[ + TeamAgent(name="fe", provider="gemini", skills=["react", "frontend"]), + TeamAgent(name="be", provider="codex", skills=["python", "api"]), + ], + ) + + def test_team_match_preferred(self, team): + result = auto_route("build react frontend", team) + assert result.provider == "gemini" + assert "team:" in result.reason + + def test_falls_to_keywords_when_no_team_match(self, team): + result = auto_route("翻译这段话", team) + assert result.provider == "kimi" # keyword match, not team + + def test_falls_to_keywords_without_team(self): + result = auto_route("React 前端组件") + assert result.provider == "gemini" + + def test_falls_to_fallback(self): + result = auto_route("hello there") + assert result.provider == DEFAULT_FALLBACK.provider + + def test_no_team_uses_keywords(self): + result = auto_route("分析算法复杂度", None) + assert result.provider == "codex" + + def test_empty_message(self): + result = auto_route("") + assert result.provider == DEFAULT_FALLBACK.provider + + +# --------------------------------------------------------------------------- +# Integration: full ask --auto flow simulation +# --------------------------------------------------------------------------- + +class TestAutoRouteIntegration: + """Simulate the full --auto flow as implemented in bin/ask.""" + + def _simulate_auto(self, message: str, team: TeamConfig | None = None) -> str: + route = auto_route(message, team) + return route.provider + + def test_auto_frontend(self): + assert self._simulate_auto("帮我写一个 React 前端组件") == "gemini" + + def test_auto_algorithm(self): + assert self._simulate_auto("分析这个算法的时间复杂度") == "codex" + + def test_auto_translation(self): + assert self._simulate_auto("翻译这段话成英文") == "kimi" + + def test_auto_python(self): + assert self._simulate_auto("用 Python 实现快排") == "qwen" + + def test_auto_review(self): + assert self._simulate_auto("请审查这段代码") == "codex" + + def test_auto_with_team(self): + team = TeamConfig(name="t", agents=[ + TeamAgent(name="fe", provider="qwen", skills=["react", "frontend"]), + ]) + # Team skill match overrides keyword match + assert self._simulate_auto("build react frontend", team) == "qwen" + + def test_auto_team_no_match_falls_to_keywords(self): + team = TeamConfig(name="t", agents=[ + TeamAgent(name="fe", provider="qwen", skills=["react"]), + ]) + assert self._simulate_auto("分析算法复杂度", team) == "codex" diff --git a/test/test_team_config.py b/test/test_team_config.py new file mode 100644 index 00000000..02a484a1 --- /dev/null +++ b/test/test_team_config.py @@ -0,0 +1,387 @@ +"""Tests for lib/team_config.py — team configuration and agent resolution.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from team_config import ( + VALID_STRATEGIES, + TeamAgent, + TeamConfig, + _load_team_json, + _parse_agent, + load_team_config, + resolve_team_agent, +) + + +# --------------------------------------------------------------------------- +# TeamAgent / TeamConfig dataclasses +# --------------------------------------------------------------------------- + +class TestTeamConfig: + def test_agent_map_lookup(self): + team = TeamConfig( + name="test", + agents=[ + TeamAgent(name="coder", provider="codex", model="o3", role="implementation"), + TeamAgent(name="reviewer", provider="claude", role="review"), + ], + ) + m = team.agent_map() + assert m["coder"].provider == "codex" + assert m["reviewer"].provider == "claude" + + def test_agent_map_case_insensitive(self): + team = TeamConfig( + name="test", + agents=[TeamAgent(name="Coder", provider="codex")], + ) + # Names are lowered during parse, but test direct construction + m = team.agent_map() + assert "coder" in m + + def test_empty_agents(self): + team = TeamConfig(name="empty", agents=[]) + assert team.agent_map() == {} + + +# --------------------------------------------------------------------------- +# _parse_agent +# --------------------------------------------------------------------------- + +class TestParseAgent: + def test_valid_agent(self): + raw = {"name": "coder", "provider": "codex", "model": "o3", "role": "implementation", "skills": ["python", "rust"]} + agent = _parse_agent(raw) + assert agent is not None + assert agent.name == "coder" + assert agent.provider == "codex" + assert agent.model == "o3" + assert agent.role == "implementation" + assert agent.skills == ["python", "rust"] + + def test_minimal_agent(self): + raw = {"name": "bot", "provider": "gemini"} + agent = _parse_agent(raw) + assert agent is not None + assert agent.name == "bot" + assert agent.provider == "gemini" + assert agent.model == "" + assert agent.role == "" + assert agent.skills == [] + + def test_missing_name(self): + assert _parse_agent({"provider": "codex"}) is None + + def test_missing_provider(self): + assert _parse_agent({"name": "bot"}) is None + + def test_empty_name(self): + assert _parse_agent({"name": "", "provider": "codex"}) is None + + def test_not_a_dict(self): + assert _parse_agent("invalid") is None + assert _parse_agent(42) is None + assert _parse_agent(None) is None + + def test_skills_filters_empty(self): + raw = {"name": "bot", "provider": "gemini", "skills": ["python", "", " ", "rust"]} + agent = _parse_agent(raw) + assert agent.skills == ["python", "rust"] + + def test_provider_lowered(self): + raw = {"name": "bot", "provider": "Gemini"} + agent = _parse_agent(raw) + assert agent.provider == "gemini" + + +# --------------------------------------------------------------------------- +# _load_team_json +# --------------------------------------------------------------------------- + +class TestLoadTeamJson: + def test_missing_file(self, tmp_path: Path): + assert _load_team_json(tmp_path / "nope.json") is None + + def test_valid_config(self, tmp_path: Path): + f = tmp_path / "team.json" + f.write_text(json.dumps({ + "name": "dev-team", + "strategy": "skill_based", + "agents": [ + {"name": "coder", "provider": "codex", "model": "o3", "role": "implementation"}, + {"name": "reviewer", "provider": "claude", "role": "review"}, + ], + })) + team = _load_team_json(f) + assert team is not None + assert team.name == "dev-team" + assert team.strategy == "skill_based" + assert len(team.agents) == 2 + + def test_corrupt_json(self, tmp_path: Path): + f = tmp_path / "bad.json" + f.write_text("{invalid json") + assert _load_team_json(f) is None + + def test_non_dict_json(self, tmp_path: Path): + f = tmp_path / "arr.json" + f.write_text(json.dumps([1, 2])) + assert _load_team_json(f) is None + + def test_no_agents_returns_none(self, tmp_path: Path): + f = tmp_path / "team.json" + f.write_text(json.dumps({"name": "empty", "agents": []})) + assert _load_team_json(f) is None + + def test_invalid_agents_skipped(self, tmp_path: Path): + f = tmp_path / "team.json" + f.write_text(json.dumps({ + "name": "partial", + "agents": [ + {"name": "good", "provider": "codex"}, + {"name": "", "provider": "gemini"}, # invalid: empty name + "not_a_dict", # invalid: not dict + {"provider": "claude"}, # invalid: no name + ], + })) + team = _load_team_json(f) + assert team is not None + assert len(team.agents) == 1 + assert team.agents[0].name == "good" + + def test_default_name(self, tmp_path: Path): + f = tmp_path / "team.json" + f.write_text(json.dumps({"agents": [{"name": "a", "provider": "codex"}]})) + team = _load_team_json(f) + assert team.name == "default" + + def test_invalid_strategy_defaults(self, tmp_path: Path): + f = tmp_path / "team.json" + f.write_text(json.dumps({ + "name": "t", + "strategy": "invalid_strategy", + "agents": [{"name": "a", "provider": "codex"}], + })) + team = _load_team_json(f) + assert team.strategy == "skill_based" + + def test_all_valid_strategies(self, tmp_path: Path): + for strategy in VALID_STRATEGIES: + f = tmp_path / f"team_{strategy}.json" + f.write_text(json.dumps({ + "name": "t", + "strategy": strategy, + "agents": [{"name": "a", "provider": "codex"}], + })) + team = _load_team_json(f) + assert team.strategy == strategy + + def test_description_field(self, tmp_path: Path): + f = tmp_path / "team.json" + f.write_text(json.dumps({ + "name": "t", + "description": "My dev team", + "agents": [{"name": "a", "provider": "codex"}], + })) + team = _load_team_json(f) + assert team.description == "My dev team" + + +# --------------------------------------------------------------------------- +# load_team_config +# --------------------------------------------------------------------------- + +class TestLoadTeamConfig: + def test_no_config(self, tmp_path: Path, monkeypatch): + home = tmp_path / "home" + home.mkdir() + monkeypatch.setattr(Path, "home", classmethod(lambda cls: home)) + assert load_team_config(work_dir=tmp_path) is None + + def test_global_config(self, tmp_path: Path, monkeypatch): + home = tmp_path / "home" + ccb_dir = home / ".ccb" + ccb_dir.mkdir(parents=True) + (ccb_dir / "team.json").write_text(json.dumps({ + "name": "global-team", + "agents": [{"name": "bot", "provider": "gemini"}], + })) + monkeypatch.setattr(Path, "home", classmethod(lambda cls: home)) + team = load_team_config(work_dir=tmp_path / "project") + assert team is not None + assert team.name == "global-team" + + def test_project_overrides_global(self, tmp_path: Path, monkeypatch): + home = tmp_path / "home" + ccb_dir = home / ".ccb" + ccb_dir.mkdir(parents=True) + (ccb_dir / "team.json").write_text(json.dumps({ + "name": "global-team", + "agents": [{"name": "bot", "provider": "gemini"}], + })) + + proj = tmp_path / "project" + proj_ccb = proj / ".ccb" + proj_ccb.mkdir(parents=True) + (proj_ccb / "team.json").write_text(json.dumps({ + "name": "project-team", + "agents": [{"name": "coder", "provider": "codex"}], + })) + + monkeypatch.setattr(Path, "home", classmethod(lambda cls: home)) + team = load_team_config(work_dir=proj) + assert team.name == "project-team" + assert team.agents[0].name == "coder" + + def test_no_work_dir(self, tmp_path: Path, monkeypatch): + home = tmp_path / "home" + ccb_dir = home / ".ccb" + ccb_dir.mkdir(parents=True) + (ccb_dir / "team.json").write_text(json.dumps({ + "name": "global", + "agents": [{"name": "bot", "provider": "kimi"}], + })) + monkeypatch.setattr(Path, "home", classmethod(lambda cls: home)) + team = load_team_config(work_dir=None) + assert team is not None + assert team.name == "global" + + +# --------------------------------------------------------------------------- +# resolve_team_agent +# --------------------------------------------------------------------------- + +class TestResolveTeamAgent: + @pytest.fixture() + def team(self) -> TeamConfig: + return TeamConfig( + name="dev", + agents=[ + TeamAgent(name="researcher", provider="gemini", model="3f", role="research"), + TeamAgent(name="coder", provider="codex", model="o3", role="implementation"), + TeamAgent(name="reviewer", provider="claude", role="review"), + ], + ) + + def test_resolve_known_agent(self, team): + agent = resolve_team_agent("researcher", team) + assert agent is not None + assert agent.provider == "gemini" + assert agent.model == "3f" + + def test_resolve_case_insensitive(self, team): + agent = resolve_team_agent("Coder", team) + assert agent is not None + assert agent.provider == "codex" + + def test_resolve_unknown_returns_none(self, team): + assert resolve_team_agent("unknown", team) is None + + def test_resolve_no_team(self): + assert resolve_team_agent("coder", None) is None + + def test_resolve_empty_name(self, team): + assert resolve_team_agent("", team) is None + + def test_resolve_none_name(self, team): + assert resolve_team_agent(None, team) is None + + +# --------------------------------------------------------------------------- +# Integration: team agents override aliases +# --------------------------------------------------------------------------- + +class TestTeamOverridesAlias: + """Verify team agent names take priority over aliases.""" + + def test_team_agent_overrides_alias(self): + from aliases import DEFAULT_ALIASES, resolve_alias + + team = TeamConfig( + name="test", + agents=[TeamAgent(name="a", provider="kimi")], # override alias a→codex + ) + + name = "a" + # Team resolution first + team_agent = resolve_team_agent(name, team) + if team_agent: + provider = team_agent.provider + else: + provider = resolve_alias(name, DEFAULT_ALIASES) + + assert provider == "kimi" # team wins over alias + + def test_non_team_falls_to_alias(self): + from aliases import DEFAULT_ALIASES, resolve_alias + + team = TeamConfig( + name="test", + agents=[TeamAgent(name="coder", provider="codex")], + ) + + name = "a" + team_agent = resolve_team_agent(name, team) + if team_agent: + provider = team_agent.provider + else: + provider = resolve_alias(name, DEFAULT_ALIASES) + + assert provider == "codex" # alias a→codex + + +# --------------------------------------------------------------------------- +# Full resolution flow (as in bin/ask) +# --------------------------------------------------------------------------- + +class TestFullResolutionFlow: + """Simulate the full resolution flow in bin/ask.""" + + def _resolve(self, raw_provider: str, team: TeamConfig | None) -> tuple[str, str | None]: + from aliases import load_aliases, resolve_alias + from providers import parse_qualified_provider + + team_agent = resolve_team_agent(raw_provider, team) + if team_agent: + raw_provider = team_agent.provider + else: + aliases = {"a": "codex", "b": "gemini", "c": "claude"} + base_part, _, instance_part = raw_provider.partition(":") + base_part = resolve_alias(base_part, aliases) + raw_provider = f"{base_part}:{instance_part}" if instance_part else base_part + + return parse_qualified_provider(raw_provider) + + def test_team_agent_resolves(self): + team = TeamConfig(name="t", agents=[ + TeamAgent(name="researcher", provider="gemini", model="3f"), + ]) + provider, instance = self._resolve("researcher", team) + assert provider == "gemini" + assert instance is None + + def test_alias_resolves_without_team(self): + provider, instance = self._resolve("a", None) + assert provider == "codex" + + def test_alias_with_instance(self): + provider, instance = self._resolve("a:review", None) + assert provider == "codex" + assert instance == "review" + + def test_direct_provider(self): + provider, instance = self._resolve("kimi", None) + assert provider == "kimi" + assert instance is None + + def test_team_agent_overrides_alias_letter(self): + team = TeamConfig(name="t", agents=[ + TeamAgent(name="a", provider="qwen"), + ]) + provider, instance = self._resolve("a", team) + assert provider == "qwen" # team wins