diff --git a/.gitignore b/.gitignore index 789c18f2..9a8f8f83 100644 --- a/.gitignore +++ b/.gitignore @@ -38,3 +38,4 @@ docs/superpowers/ .superpowers/ skills-lock.json skills/ +.planning/ diff --git a/.planning/REQUIREMENTS.md b/.planning/REQUIREMENTS.md new file mode 100644 index 00000000..1d5acec8 --- /dev/null +++ b/.planning/REQUIREMENTS.md @@ -0,0 +1,75 @@ +# Requirements: Per-Model Provider Routing + +**Defined:** 2026-06-11 +**Core Value:** 一键切换 AI 编程工具的底层 provider,零配置摩擦 + +## v1 Requirements (Milestone 1) + +### 数据存储 (DB) + +- [x] **DB-01**: 数据库 Schema 从 v10 升级到 v11,新增 `model_routes` 表 +- [x] **DB-02**: `model_routes` 表包含字段:id, app_type, pattern (通配符), provider_id, priority (排序), enabled (开关), created_at, updated_at +- [x] **DB-03**: 支持 CRUD 操作:创建路由规则、列出所有规则、更新规则、删除规则 +- [x] **DB-04**: 规则按 priority 排序,同 priority 按创建时间排序 +- [x] **DB-05**: 创建规则时验证 provider_id 存在且属于同一 app_type +- [x] **DB-06**: Schema 升级向下兼容:空 model_routes 表 = 行为不变 + +### 路由引擎 (Router) + +- [x] **RT-01**: ModelRouter 在代理请求处理流程中先于 ProviderRouter 执行 +- [x] **RT-02**: 支持 `*` 通配符匹配 model 名称(如 `*sonnet*`、`claude-*`、`*-4-5`) +- [x] **RT-03**: 多个规则匹配时,选择 priority 最高(数字最小)的 enabled 规则 +- [x] **RT-04**: 无匹配规则时,回退到现有的 ProviderRouter 逻辑(行为不变) +- [x] **RT-05**: 规则指向的 provider 不存在时,记录 warning 日志并回退 +- [x] **RT-06**: 路由选中的 provider 为单 provider(不使用 failover 队列) + +### CLI 命令 (CLI) + +- [x] **CL-01**: `cc-switch proxy model-route list [--app ]` — 列出所有路由规则 +- [x] **CL-02**: `cc-switch proxy model-route add [--priority ] [--app ]` — 添加路由 +- [x] **CL-03**: `cc-switch proxy model-route remove ` — 删除路由 +- [x] **CL-04**: `cc-switch proxy model-route toggle ` — 切换启用/禁用 +- [x] **CL-05**: `cc-switch proxy model-route update [--pattern] [--provider] [--priority]` — 更新路由 +- [x] **CL-06**: 命令输出人类可读的表格格式(与现有 proxy 命令风格一致) + +### TUI 界面 (TUI) + +- [x] **UI-01**: 在代理设置页面中增加模型路由管理入口 +- [x] **UI-02**: 路由规则列表表格:显示 pattern、目标 provider、优先级、启用状态 +- [x] **UI-03**: 支持创建新规则:输入 pattern + 选择 provider + 设置优先级 +- [x] **UI-04**: 支持编辑/删除/切换启用状态 +- [ ] **UI-05**: 界面风格与现有 TUI 一致(配色、布局、快捷键) + +### 同步 (Sync) + +- [ ] **SY-01**: model_routes 变更时触发 WebDAV 自动同步(若已配置) +- [ ] **SY-02**: model_routes 变更时触发 S3 自动同步(若已配置) + +### 测试 (TEST) + +- [x] **TE-01**: model_routes DAO 的 CRUD 单元测试 +- [x] **TE-02**: ModelRouter 通配符匹配逻辑的单元测试 +- [x] **TE-03**: Schema v10→v11 迁移测试 +- [ ] **TE-04**: 代理路由集成测试:匹配规则→选中正确 provider +- [ ] **TE-05**: 代理回退集成测试:无匹配→回退到现有逻辑 +- [x] **TE-06**: CLI 命令集成测试 + +## Out of Scope + +| Feature | Reason | +|---------|--------| +| 正则表达式匹配(仅支持 `*` 通配符) | 与上游 cc-switch PR 保持一致,`*` 覆盖 95% 用例 | +| 多 provider failover for model routes | 设计决策:路由规则选中单 provider,匹配失败回退到现有 failover | +| 基于请求内容的动态路由(非 model 名称) | 复杂度高,无明确用例 | +| 路由规则导入/导出 | 可通过 WebDAV/S3 同步覆盖此需求 | + +## Traceability + +| Requirement | Phase | Status | +|-------------|-------|--------| +| DB-01 ~ DB-06 | Phase 1: Database | Pending | +| RT-01 ~ RT-06 | Phase 2: Router Engine | Pending | +| CL-01 ~ CL-06 | Phase 3: CLI Commands | Pending | +| UI-01 ~ UI-05 | Phase 4: TUI Interface | Pending | +| SY-01 ~ SY-02 | Phase 5: Sync Integration | Pending | +| TE-01 ~ TE-06 | Phase 6: Testing | Pending | diff --git a/.planning/ROADMAP.md b/.planning/ROADMAP.md new file mode 100644 index 00000000..504ed6a1 --- /dev/null +++ b/.planning/ROADMAP.md @@ -0,0 +1,192 @@ +# Roadmap: Per-Model Provider Routing + +**Created:** 2026-06-11 +**Milestone:** 1 +**Total phases:** 6 +**Estimated effort:** 17-27 hours (~2.5-4 days) + +--- + +## Phase Dependency Graph + +``` +Phase 1: Database Layer + ↓ +Phase 2: Router Engine + Proxy Integration + ↓ +┌───────────────┬───────────────┐ +↓ ↓ ↓ +Phase 3: Phase 4: Phase 5: +CLI Commands TUI Interface Sync Integration + ↓ ↓ ↓ +└───────────────┴───────────────┘ + ↓ + Phase 6: Final Testing & PR Prep +``` + +Phases 3, 4, 5 可并行执行(都只依赖 Phase 2)。 + +--- + +## Phase 1: Database Layer + +**Goal:** 创建 `model_routes` 表和相关 DAO,完成 Schema v10→v11 迁移 + +**Depends on:** 无 +**Estimated effort:** 2-3 小时 +**Files to touch:** ~4 files, ~230 lines + +### Tasks + +1. **Schema v11 migration** + - 在 `database/schema.rs` 中实现 `migrate_v10_to_v11()` + - 创建 `model_routes` 表:id INTEGER PK, app_type TEXT NOT NULL, pattern TEXT NOT NULL, provider_id TEXT NOT NULL, priority INTEGER DEFAULT 0, enabled INTEGER DEFAULT 1, created_at TEXT, updated_at TEXT + - 添加 FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE + - 更新 `CURRENT_SCHEMA_VERSION` 常量 + +2. **ModelRoute 类型定义** + - 在 `provider.rs`(或新建 `model_route.rs`)中定义 `ModelRoute` struct + - 实现 Serialize/Deserialize/Clone/Debug + +3. **model_routes DAO** + - 新建 `database/dao/model_routes.rs` + - `list_routes(app_type) → Vec` — 按 priority ASC, created_at ASC + - `create_route(route) → ModelRoute` + - `update_route(id, updates) → ModelRoute` + - `delete_route(id)` + - `toggle_route(id)` + - `get_route(id) → Option` + - 创建时验证 provider_id 存在且属于同 app_type + - 在 `database/dao/mod.rs` 中注册模块 + +4. **Database 集成** + - 在 `database/mod.rs` 中暴露 DAO 方法 + - 确保 `try_new()` 自动执行迁移 + +### Verification +- [ ] `cargo test database` — 所有数据库测试通过 +- [ ] DAO CRUD 测试覆盖所有操作 +- [ ] Schema 迁移测试:v10 数据库升级到 v11 后数据完整 +- [ ] 向下兼容:无 model_routes 时所有现有功能正常 + +**Covers:** DB-01 ~ DB-06, TE-01, TE-03 + +--- + +## Phase 2: Router Engine + Proxy Integration + +**Goal:** 实现 ModelRouter 通配符匹配引擎,并集成到代理请求处理流程 + +**Depends on:** Phase 1 +**Estimated effort:** 4-6 小时 +**Files to touch:** ~8 files, ~500 lines +**Plans:** 1 plan + +### Plans + +- [x] 02-01-PLAN.md — ModelRouter engine creation, HandlerContext integration, ProxyServerState wiring, integration tests + +**Covers:** RT-01 ~ RT-06, TE-02 + + +--- + +## Phase 3: CLI Commands + +**Goal:** 实现 `cc-switch proxy model-route` 子命令组 + +**Depends on:** Phase 1(仅需 DAO,可与 Phase 2 并行) +**Estimated effort:** 1-2 小时 +**Files to touch:** ~2 files, ~70 lines +**Plans:** 1/1 plans complete + +### Plans + +- [x] 03-01-PLAN.md — ModelRouteCommand enum definition + ProxyCommand integration + command handler implementation + tests + +**Covers:** CL-01 ~ CL-06, TE-06 + +--- + +## Phase 4: TUI Interface + +**Goal:** 在 ratatui TUI 的代理设置区域增加模型路由管理界面 + +**Depends on:** Phase 1 + Phase 2(需要 DAO 和 ModelRouter 工作正常) +**Estimated effort:** 6-10 小时(最大工作量) +**Files to touch:** ~10 files, ~400 lines +**Plans:** 2/2 plans complete + +### Plans + +- [x] 04-01-PLAN.md — ModelRouteSnapshot data type, Route::SettingsModelRoutes, Settings menu entry, table rendering placeholder +- [x] 04-02-PLAN.md — Action variants, runtime action handlers, multi-step Add/Edit overlays, delete confirmation, toggle, keyboard wiring + +**Covers:** UI-01 ~ UI-05 + +--- + +## Phase 5: Sync Integration + +**Goal:** model_routes 变更时触发 WebDAV/S3 自动同步 + +**Depends on:** Phase 1(仅需 DAO) +**Estimated effort:** 0.5-1 小时 +**Files to touch:** ~2 files, ~10 lines + +### Tasks + +1. **WebDAV 同步触发** + - 在 `services/webdav_auto_sync.rs` 中添加 model_routes 表变更的触发 + - 在 DAO 的 create/update/delete 方法中调用 sync trigger + +2. **S3 同步触发** + - 在 `services/s3_auto_sync.rs` 中同样添加触发 + - 保持与现有同步机制一致的模式 + +### Verification +- [ ] 配置 WebDAV 同步后,添加/修改路由规则触发同步 +- [ ] 配置 S3 同步后,添加/修改路由规则触发同步 + +**Covers:** SY-01 ~ SY-02 + +--- + +## Phase 6: Final Testing & PR Preparation + +**Goal:** 全面测试,清理代码,准备可合并的纯净 PR 分支 + +**Depends on:** Phase 3, 4, 5(全部完成) +**Estimated effort:** 3-5 小时 +**Plans:** 1 plan + +### Plans + +- [ ] 06-01-PLAN.md — Verify integration tests (TE-04/TE-05 already exist), run full test suite + quality gates, prepare clean PR branch with .planning/ excluded + +**Covers:** TE-04, TE-05 + +--- + +## Risk Register + +| Risk | Severity | Mitigation | +|------|----------|------------| +| handler_context 结构与 cc-switch 差异过大,ModelRouter 集成点不匹配 | MEDIUM | Phase 2 开始前详细对比两个项目的 handler_context 结构 | +| TUI 表单组件不够灵活,无法实现 pattern + provider picker 组合输入 | MEDIUM | Phase 4 开始前评估现有 TUI 组件能力,必要时简化输入流程 | +| Schema 迁移与现有备份/恢复机制冲突 | LOW | Phase 1 先研究现有迁移模式和备份逻辑 | +| 上游 PR 的变更在 cc-switch-cli 中路径/API 不同 | LOW | 每个 Phase 对照当前代码库做适配,不盲目复制 | + +--- + +## Traceability + +| Phase | Requirements Covered | Est. Effort | +|-------|---------------------|-------------| +| Phase 1: Database | DB-01~06, TE-01, TE-03 | 2-3h | +| Phase 2: Router Engine | RT-01~06, TE-02 | 4-6h | +| Phase 3: CLI Commands | CL-01~06, TE-06 | 1-2h | +| Phase 4: TUI Interface | UI-01~05 | 6-10h | +| Phase 5: Sync | SY-01~02 | 0.5-1h | +| Phase 6: Testing & PR | TE-04~05 | 3-5h | +| **Total** | **31 requirements** | **17-27h** | diff --git a/.planning/STATE.md b/.planning/STATE.md new file mode 100644 index 00000000..f09171d7 --- /dev/null +++ b/.planning/STATE.md @@ -0,0 +1,102 @@ +--- +gsd_state_version: 1.0 +milestone: v1.0 +milestone_name: milestone +current_phase: Phase 4 (complete) +status: in_progress +last_updated: "2026-06-12T01:22:05.651Z" +progress: + total_phases: 6 + completed_phases: 4 + total_plans: 6 + completed_plans: 6 + percent: 67 +--- + +# State: CC-Switch CLI + +**Last updated:** 2026-06-12 +**Active milestone:** Milestone 1 — Per-Model Provider Routing +**Current phase:** Phase 3 (complete) + +## Project Reference + +See: `.planning/PROJECT.md` (updated 2026-06-11) + +**Core value:** 一键切换 AI 编程工具的底层 provider,零配置摩擦 +**Current focus:** 实现 per-model provider routing(根据模型名称将代理请求路由到不同 provider) + +## Milestone Progress + +| Phase | Status | Est. Effort | Started | Completed | +|-------|--------|-------------|---------|-----------| +| Phase 1: Database | ✅ Complete | 2-3h | 2026-06-11 | 2026-06-11 | +| Phase 2: Router Engine | ✅ Complete | 4-6h | 2026-06-11 | 2026-06-12 | +| Phase 3: CLI Commands | ✅ Complete | 1-2h | 2026-06-11 | 2026-06-12 | +| Phase 4: TUI Interface | ✅ Complete | 6-10h | 2026-06-12 | 2026-06-12 | +| Phase 5: Sync Integration | ⬜ Pending | 0.5-1h | — | — | +| Phase 6: Testing & PR Prep | ⬜ Pending | 3-5h | — | — | + +## Reference Artifacts + +- Codebase map: `.planning/codebase/` (7 documents, 2391 lines, generated 2026-06-11) +- Phase 1 Research: `.planning/phase-1/RESEARCH.md` +- Phase 1 Plan: `.planning/phases/01-database/01-01-PLAN.md` (1 plan, 3 tasks, 1 wave) +- Phase 1 Summary: `.planning/phases/01-database/01-01-SUMMARY.md` +- Phase 2 Research: `.planning/phase-2/RESEARCH.md` +- Phase 2 Plan: `.planning/phases/02-router/02-01-PLAN.md` (1 plan, 3 tasks, 1 wave) +- Phase 2 Summary: `.planning/phases/02-router/02-01-SUMMARY.md` +- Phase 3 Research: `.planning/phase-3/RESEARCH.md` +- Phase 3 Summary: `.planning/phases/03-cli/03-01-SUMMARY.md` +- Phase 4 Plan 01: `.planning/phases/04-tui-interface/04-01-PLAN.md` (1 plan, 2 tasks, 1 wave) +- Phase 4 Summary 01: `.planning/phases/04-tui-interface/04-01-SUMMARY.md` +- Phase 4 Plan 02: `.planning/phases/04-tui-interface/04-02-PLAN.md` (1 plan, 2 tasks, 1 wave) +- Phase 4 Summary 02: `.planning/phases/04-tui-interface/04-02-SUMMARY.md` + +## Working State + +- **Branch:** `main` (clean) +- **Last commit:** `e10ef89 style(04-tui-interface): apply cargo fmt formatting fixes` +- **Schema version:** v11 + +## Quick Start (Next Session) + +```bash + +# Phase 4 is complete. Phase 5 (Sync Integration) is next. + +/gsd-execute-phase 05-sync --wave 1 +``` + +## Notes + +- 上游 PR #4081 于 2026-06-11 提交,当前状态 OPEN,有一次 codex review 但无实质性修改要求 +- cc-switch-cli 与 cc-switch 的关键差异:无 React 前端、ratatui TUI、代理架构细节可能不同 +- Phase 4 (TUI) 是最大的工作量来源(35-40%),取决于现有 TUI 组件的复用程度 +- Phase 1 completed: model_routes table, ModelRoute type, CRUD DAO — all foundations in place +- Phase 2 completed: ModelRouter engine, proxy integration — route matching works end-to-end +- Phase 3 complete: CLI commands for model-route CRUD (1 plan, 2 tasks, 1 wave) +- Phase 3 Summary: `.planning/phases/03-cli/03-01-SUMMARY.md` +- Phase 4 Plan 01: `.planning/phases/04-tui-interface/04-01-PLAN.md` (1 plan, 2 tasks, 1 wave) +- Phase 4 Summary 01: `.planning/phases/04-tui-interface/04-01-SUMMARY.md` +- Phase 4 Wave 1 (04-01) complete: model routes TUI scaffolding (data types, navigation, table rendering) +- Phase 4 Wave 2 (04-02) complete: model routes full CRUD operations via TUI overlays + +## Performance Metrics + +| Phase | Plan | Duration | Notes | +|-------|------|----------|-------| +| Phase 01-database P01 | 18 min | 3 tasks | 7 files | +| Phase 02-router P01 | 67 min | 3 tasks | 6 files | +| Phase 03-cli P01 | 7 min | 2 tasks | 1 file | +| Phase 04-tui-interface P01 | 10 min | 2 tasks | 8 files + 1 new | +| Phase 04-tui-interface P02 | ~10 min | 2 tasks | 9 files + 1 new | + +## Decisions + +- [Phase 1]: ModelRoute type in separate model_route.rs module (matches upstream PR #4081 structure) +- [Phase 2]: ModelRouter holds Arc only — no caching, reads routes fresh on every request +- [Phase 2]: Single provider for matched routes (no failover queue) — matches upstream design decision +- [Phase 3]: cli/mod.rs unchanged — Clap derive auto-discovers ProxyCommand::ModelRoute via existing dispatch +- [Phase 4]: Model routes rendering uses dedicated ui/model_routes.rs module (matches existing config.rs sub-page pattern) +- [Phase 4 P2]: Multi-step overlay flow (pattern -> provider -> priority) for Add/Edit; Space toggles with no toast diff --git a/src-tauri/.gitignore b/src-tauri/.gitignore index 502406b4..7d30e7df 100644 --- a/src-tauri/.gitignore +++ b/src-tauri/.gitignore @@ -2,3 +2,4 @@ # will have compiled files and executables /target/ /gen/schemas +.planning/ diff --git a/src-tauri/src/cli/commands/proxy.rs b/src-tauri/src/cli/commands/proxy.rs index 9164dba5..a3fcdd12 100644 --- a/src-tauri/src/cli/commands/proxy.rs +++ b/src-tauri/src/cli/commands/proxy.rs @@ -4,6 +4,7 @@ use crate::app_config::AppType; use crate::cli::proxy_settings::{validate_proxy_listen_address, validate_proxy_listen_port}; use crate::cli::ui::{highlight, info, success}; use crate::error::AppError; +use crate::model_route::ModelRoute; use crate::{AppState, ProxyConfig}; #[cfg(unix)] @@ -13,11 +14,45 @@ use crate::daemon::ipc::protocol::{Request as DaemonRequest, Response as DaemonR #[cfg(unix)] use crate::daemon::supervisor::{DAEMON_SOCKET_ENV, SESSION_TOKEN_ENV}; +#[derive(Subcommand, Debug, Clone)] +pub enum ModelRouteCommand { + /// List model routing rules + List, + /// Add a model routing rule + Add { + /// Wildcard pattern (e.g., *sonnet*, claude-*) + pattern: String, + /// Provider ID to route matching models to + provider_id: String, + /// Priority (lower = higher priority) + #[arg(long, default_value = "0")] + priority: i32, + }, + /// Remove a model routing rule + Remove { id: String }, + /// Toggle a model routing rule on/off + Toggle { id: String }, + /// Update a model routing rule + Update { + id: String, + #[arg(long)] + pattern: Option, + #[arg(long)] + provider_id: Option, + #[arg(long)] + priority: Option, + }, +} + #[derive(Subcommand, Debug, Clone)] pub enum ProxyCommand { /// Show current proxy configuration and routes Show, + /// Manage model-based routing rules + #[command(subcommand)] + ModelRoute(ModelRouteCommand), + /// Enable the persisted proxy switch Enable, @@ -54,6 +89,10 @@ pub enum ProxyCommand { pub fn execute(cmd: ProxyCommand, app: Option) -> Result<(), AppError> { let app_type = app.unwrap_or(AppType::Claude); match cmd { + ProxyCommand::ModelRoute(subcmd) => { + let state = get_state()?; + handle_model_route(&state, &app_type, subcmd) + } ProxyCommand::Show => show_proxy(), ProxyCommand::Enable => set_proxy_enabled(app_type, true), ProxyCommand::Disable => set_proxy_enabled(app_type, false), @@ -69,6 +108,116 @@ pub fn execute(cmd: ProxyCommand, app: Option) -> Result<(), AppError> } } +fn print_model_routes(routes: &[ModelRoute]) { + if routes.is_empty() { + println!("{}", info("No model routing rules found.")); + return; + } + let mut table = comfy_table::Table::new(); + table.load_preset(comfy_table::presets::UTF8_FULL); + table.set_header(vec!["ID", "Pattern", "Provider", "Priority", "Enabled"]); + for r in routes { + table.add_row(vec![ + r.id.clone(), + r.pattern.clone(), + r.provider_id.clone(), + r.priority.to_string(), + if r.enabled { "yes" } else { "no" }.to_string(), + ]); + } + println!("{table}"); +} + +fn handle_model_route( + state: &AppState, + app: &AppType, + cmd: ModelRouteCommand, +) -> Result<(), AppError> { + match cmd { + ModelRouteCommand::List => { + let routes = state.db.list_model_routes(app.as_str())?; + print_model_routes(&routes); + } + ModelRouteCommand::Add { + pattern, + provider_id, + priority, + } => { + let route = ModelRoute { + id: String::new(), + app_type: app.as_str().to_string(), + pattern: pattern.clone(), + provider_id: provider_id.clone(), + priority, + enabled: true, + created_at: None, + hit_count: 0, + last_hit_at: None, + updated_at: None, + }; + let created = state.db.create_model_route(&route)?; + println!( + "{}", + success(&format!( + "Model route created: id={}, pattern=\"{}\" → provider={}, priority={}", + created.id, created.pattern, created.provider_id, created.priority + )) + ); + } + ModelRouteCommand::Remove { id } => { + state.db.delete_model_route(&id)?; + println!("{}", success(&format!("Model route {id} removed."))); + } + ModelRouteCommand::Toggle { id } => { + let toggled = state.db.toggle_model_route(&id)?; + let status = if toggled.enabled { + "enabled" + } else { + "disabled" + }; + println!( + "{}", + success(&format!( + "Model route {id} toggled: pattern=\"{}\" now {status}.", + toggled.pattern + )) + ); + } + ModelRouteCommand::Update { + id, + pattern, + provider_id, + priority, + } => { + let existing = state + .db + .get_model_route(&id)? + .ok_or_else(|| AppError::Database("model_route not found".to_string()))?; + let updated = ModelRoute { + id: existing.id.clone(), + app_type: app.as_str().to_string(), + pattern: pattern.unwrap_or(existing.pattern), + provider_id: provider_id.unwrap_or(existing.provider_id), + priority: priority.unwrap_or(existing.priority), + enabled: existing.enabled, + created_at: None, + hit_count: 0, + last_hit_at: None, + updated_at: None, + }; + let result = state.db.update_model_route(&id, &updated)?; + println!( + "{}", + success(&format!( + "Model route {id} updated: pattern=\"{}\" → provider={}, priority={}.", + result.pattern, result.provider_id, result.priority + )) + ); + } + } + Ok(()) +} + fn get_state() -> Result { AppState::try_new() } @@ -647,8 +796,15 @@ mod tests { Database, MultiAppConfig, ProxyService, }; - use super::{apply_overrides, build_proxy_overview_lines, load_proxy_app_ports}; + use super::{ + apply_overrides, build_proxy_overview_lines, handle_model_route, load_proxy_app_ports, + ModelRouteCommand, + }; + use crate::app_config::AppType; use crate::cli::proxy_settings::validate_proxy_listen_port; + use crate::database::lock_conn; + use crate::error::AppError; + use crate::model_route::ModelRoute; #[test] fn cli_proxy_listen_port_validation_rejects_reserved_ports() { @@ -803,4 +959,462 @@ mod tests { "proxy show output should not hard-code automatic failover as disabled" ); } + + // --------------------------------------------------------------------------- + // Model-route command tests + // --------------------------------------------------------------------------- + + fn seed_provider(db: &Database, app_type: &str, id: &str) -> Result<(), AppError> { + let conn = lock_conn!(db.conn); + conn.execute( + "INSERT INTO providers (id, app_type, name, settings_config, meta) + VALUES (?1, ?2, ?3, '{}', '{}')", + rusqlite::params![id, app_type, id], + ) + .map_err(|e| AppError::Database(e.to_string()))?; + Ok(()) + } + + #[test] + fn model_route_list_empty_shows_no_routes_message() { + let db = Arc::new(Database::memory().expect("create database")); + let state = crate::AppState { + db: db.clone(), + config: RwLock::new(MultiAppConfig::default()), + proxy_service: ProxyService::new(db.clone()), + }; + let app = AppType::Claude; + + let result = handle_model_route(&state, &app, ModelRouteCommand::List); + assert!(result.is_ok(), "list should succeed"); + } + + #[test] + fn model_route_add_and_list_roundtrip() { + let db = Arc::new(Database::memory().expect("create database")); + seed_provider(&db, "claude", "test-prov").expect("seed provider"); + let state = crate::AppState { + db: db.clone(), + config: RwLock::new(MultiAppConfig::default()), + proxy_service: ProxyService::new(db.clone()), + }; + let app = AppType::Claude; + + // Add a route + let result = handle_model_route( + &state, + &app, + ModelRouteCommand::Add { + pattern: "*-4-5".to_string(), + provider_id: "test-prov".to_string(), + priority: 0, + }, + ); + assert!(result.is_ok(), "add should succeed"); + + // Verify via list + let routes = db.list_model_routes("claude").expect("list routes"); + assert_eq!(routes.len(), 1); + let route = &routes[0]; + assert_eq!(route.pattern, "*-4-5"); + assert_eq!(route.provider_id, "test-prov"); + assert!(route.enabled); + } + + #[test] + fn model_route_add_rejects_nonexistent_provider() { + let db = Arc::new(Database::memory().expect("create database")); + let state = crate::AppState { + db: db.clone(), + config: RwLock::new(MultiAppConfig::default()), + proxy_service: ProxyService::new(db.clone()), + }; + let app = AppType::Claude; + + let result = handle_model_route( + &state, + &app, + ModelRouteCommand::Add { + pattern: "*-4-5".to_string(), + provider_id: "nonexistent".to_string(), + priority: 0, + }, + ); + assert!(result.is_err(), "add with nonexistent provider should fail"); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("provider") && err.contains("not found"), + "expected provider not found error, got: {err}" + ); + } + + #[test] + fn model_route_add_with_explicit_priority() { + let db = Arc::new(Database::memory().expect("create database")); + seed_provider(&db, "claude", "test-prov").expect("seed provider"); + let state = crate::AppState { + db: db.clone(), + config: RwLock::new(MultiAppConfig::default()), + proxy_service: ProxyService::new(db.clone()), + }; + let app = AppType::Claude; + + let result = handle_model_route( + &state, + &app, + ModelRouteCommand::Add { + pattern: "*-sonnet".to_string(), + provider_id: "test-prov".to_string(), + priority: 7, + }, + ); + assert!(result.is_ok(), "add with priority should succeed"); + + let routes = db.list_model_routes("claude").expect("list routes"); + assert_eq!(routes.len(), 1); + assert_eq!(routes[0].priority, 7); + } + + #[test] + fn model_route_remove_deletes_by_id() { + let db = Arc::new(Database::memory().expect("create database")); + seed_provider(&db, "claude", "test-prov").expect("seed provider"); + let state = crate::AppState { + db: db.clone(), + config: RwLock::new(MultiAppConfig::default()), + proxy_service: ProxyService::new(db.clone()), + }; + let app = AppType::Claude; + + // Add then remove + let route_id = db + .create_model_route(&ModelRoute { + id: String::new(), + app_type: "claude".to_string(), + pattern: "*-sonnet".to_string(), + provider_id: "test-prov".to_string(), + priority: 0, + enabled: true, + created_at: None, + hit_count: 0, + last_hit_at: None, + updated_at: None, + }) + .expect("create route") + .id; + + let result = handle_model_route( + &state, + &app, + ModelRouteCommand::Remove { + id: route_id.clone(), + }, + ); + assert!(result.is_ok(), "remove should succeed"); + + let routes = db.list_model_routes("claude").expect("list routes"); + assert!(routes.is_empty(), "route should be deleted"); + } + + #[test] + fn model_route_remove_nonexistent_id_errors() { + let db = Arc::new(Database::memory().expect("create database")); + let state = crate::AppState { + db: db.clone(), + config: RwLock::new(MultiAppConfig::default()), + proxy_service: ProxyService::new(db.clone()), + }; + let app = AppType::Claude; + + let result = handle_model_route( + &state, + &app, + ModelRouteCommand::Remove { + id: "missing-route".to_string(), + }, + ); + assert!(result.is_err(), "remove nonexistent should fail"); + } + + #[test] + fn model_route_toggle_flips_enabled() { + let db = Arc::new(Database::memory().expect("create database")); + seed_provider(&db, "claude", "test-prov").expect("seed provider"); + let state = crate::AppState { + db: db.clone(), + config: RwLock::new(MultiAppConfig::default()), + proxy_service: ProxyService::new(db.clone()), + }; + let app = AppType::Claude; + + // Create an enabled route + let route_id = db + .create_model_route(&ModelRoute { + id: String::new(), + app_type: "claude".to_string(), + pattern: "*-sonnet".to_string(), + provider_id: "test-prov".to_string(), + priority: 0, + enabled: true, + created_at: None, + hit_count: 0, + last_hit_at: None, + updated_at: None, + }) + .expect("create route") + .id; + + // Toggle off + let result = handle_model_route( + &state, + &app, + ModelRouteCommand::Toggle { + id: route_id.clone(), + }, + ); + assert!(result.is_ok(), "toggle should succeed"); + + let route = db + .get_model_route(&route_id) + .expect("get route") + .expect("route exists"); + assert!(!route.enabled, "should be disabled after toggle"); + + // Toggle on + handle_model_route( + &state, + &app, + ModelRouteCommand::Toggle { + id: route_id.clone(), + }, + ) + .expect("toggle back"); + let route = db + .get_model_route(&route_id) + .expect("get route") + .expect("route exists"); + assert!(route.enabled, "should be enabled after second toggle"); + } + + #[test] + fn model_route_toggle_nonexistent_id_errors() { + let db = Arc::new(Database::memory().expect("create database")); + let state = crate::AppState { + db: db.clone(), + config: RwLock::new(MultiAppConfig::default()), + proxy_service: ProxyService::new(db.clone()), + }; + let app = AppType::Claude; + + let result = handle_model_route( + &state, + &app, + ModelRouteCommand::Toggle { + id: "missing-route".to_string(), + }, + ); + assert!(result.is_err(), "toggle nonexistent should fail"); + } + + #[test] + fn model_route_update_changes_pattern_only() { + let db = Arc::new(Database::memory().expect("create database")); + seed_provider(&db, "claude", "test-prov").expect("seed provider"); + let state = crate::AppState { + db: db.clone(), + config: RwLock::new(MultiAppConfig::default()), + proxy_service: ProxyService::new(db.clone()), + }; + let app = AppType::Claude; + + let route_id = db + .create_model_route(&ModelRoute { + id: String::new(), + app_type: "claude".to_string(), + pattern: "original-*".to_string(), + provider_id: "test-prov".to_string(), + priority: 5, + enabled: true, + created_at: None, + hit_count: 0, + last_hit_at: None, + updated_at: None, + }) + .expect("create route") + .id; + + let result = handle_model_route( + &state, + &app, + ModelRouteCommand::Update { + id: route_id.clone(), + pattern: Some("new-pattern-*".to_string()), + provider_id: None, + priority: None, + }, + ); + assert!(result.is_ok(), "update pattern should succeed"); + + let route = db + .get_model_route(&route_id) + .expect("get route") + .expect("route exists"); + assert_eq!(route.pattern, "new-pattern-*"); + assert_eq!(route.provider_id, "test-prov"); // unchanged + assert_eq!(route.priority, 5); // unchanged + } + + #[test] + fn model_route_update_changes_provider_only() { + let db = Arc::new(Database::memory().expect("create database")); + seed_provider(&db, "claude", "test-prov").expect("seed provider"); + seed_provider(&db, "claude", "other-prov").expect("seed provider"); + let state = crate::AppState { + db: db.clone(), + config: RwLock::new(MultiAppConfig::default()), + proxy_service: ProxyService::new(db.clone()), + }; + let app = AppType::Claude; + + let route_id = db + .create_model_route(&ModelRoute { + id: String::new(), + app_type: "claude".to_string(), + pattern: "*-sonnet".to_string(), + provider_id: "test-prov".to_string(), + priority: 5, + enabled: true, + created_at: None, + hit_count: 0, + last_hit_at: None, + updated_at: None, + }) + .expect("create route") + .id; + + let result = handle_model_route( + &state, + &app, + ModelRouteCommand::Update { + id: route_id.clone(), + pattern: None, + provider_id: Some("other-prov".to_string()), + priority: None, + }, + ); + assert!(result.is_ok(), "update provider should succeed"); + + let route = db + .get_model_route(&route_id) + .expect("get route") + .expect("route exists"); + assert_eq!(route.provider_id, "other-prov"); + assert_eq!(route.pattern, "*-sonnet"); // unchanged + } + + #[test] + fn model_route_update_changes_priority_only() { + let db = Arc::new(Database::memory().expect("create database")); + seed_provider(&db, "claude", "test-prov").expect("seed provider"); + let state = crate::AppState { + db: db.clone(), + config: RwLock::new(MultiAppConfig::default()), + proxy_service: ProxyService::new(db.clone()), + }; + let app = AppType::Claude; + + let route_id = db + .create_model_route(&ModelRoute { + id: String::new(), + app_type: "claude".to_string(), + pattern: "*-sonnet".to_string(), + provider_id: "test-prov".to_string(), + priority: 5, + enabled: true, + created_at: None, + hit_count: 0, + last_hit_at: None, + updated_at: None, + }) + .expect("create route") + .id; + + let result = handle_model_route( + &state, + &app, + ModelRouteCommand::Update { + id: route_id.clone(), + pattern: None, + provider_id: None, + priority: Some(99), + }, + ); + assert!(result.is_ok(), "update priority should succeed"); + + let route = db + .get_model_route(&route_id) + .expect("get route") + .expect("route exists"); + assert_eq!(route.priority, 99); + } + + #[test] + fn model_route_update_nonexistent_id_errors() { + let db = Arc::new(Database::memory().expect("create database")); + let state = crate::AppState { + db: db.clone(), + config: RwLock::new(MultiAppConfig::default()), + proxy_service: ProxyService::new(db.clone()), + }; + let app = AppType::Claude; + + let result = handle_model_route( + &state, + &app, + ModelRouteCommand::Update { + id: "missing-route".to_string(), + pattern: Some("new-*".to_string()), + provider_id: None, + priority: None, + }, + ); + assert!(result.is_err(), "update nonexistent should fail"); + } + + #[test] + fn model_route_with_codex_app_type() { + let db = Arc::new(Database::memory().expect("create database")); + seed_provider(&db, "codex", "codex-prov").expect("seed provider"); + let state = crate::AppState { + db: db.clone(), + config: RwLock::new(MultiAppConfig::default()), + proxy_service: ProxyService::new(db.clone()), + }; + let app = AppType::Codex; + + // Add a codex route + let result = handle_model_route( + &state, + &app, + ModelRouteCommand::Add { + pattern: "gpt-*".to_string(), + provider_id: "codex-prov".to_string(), + priority: 0, + }, + ); + assert!(result.is_ok(), "add codex route should succeed"); + + // Verify stored under codex + let routes = db.list_model_routes("codex").expect("list codex routes"); + assert_eq!(routes.len(), 1); + assert_eq!(routes[0].app_type, "codex"); + assert_eq!(routes[0].pattern, "gpt-*"); + + // Codex routes should NOT appear in claude listing + let claude_routes = db.list_model_routes("claude").expect("list claude routes"); + assert!( + claude_routes.is_empty(), + "codex routes should not leak to claude" + ); + } } diff --git a/src-tauri/src/cli/i18n.rs b/src-tauri/src/cli/i18n.rs index 5a6e8fef..09c7badb 100644 --- a/src-tauri/src/cli/i18n.rs +++ b/src-tauri/src/cli/i18n.rs @@ -3936,6 +3936,150 @@ pub mod texts { } } + pub fn tui_settings_model_routes_title() -> &'static str { + if is_chinese() { + "模型路由" + } else { + "Model Routes" + } + } + + pub fn tui_toast_model_route_added() -> &'static str { + if is_chinese() { + "已添加模型路由" + } else { + "Model route added" + } + } + + pub fn tui_toast_model_route_updated() -> &'static str { + if is_chinese() { + "已更新模型路由" + } else { + "Model route updated" + } + } + + pub fn tui_toast_model_route_deleted() -> &'static str { + if is_chinese() { + "已删除模型路由" + } else { + "Model route deleted" + } + } + + pub fn tui_model_route_add_pattern_title() -> &'static str { + if is_chinese() { + "添加模型路由 — 模型模式" + } else { + "Add Model Route — Pattern" + } + } + + pub fn tui_model_route_add_pattern_prompt() -> &'static str { + if is_chinese() { + "输入模型名称模式(如 *-sonnet, gpt-4*)" + } else { + "Enter model name pattern (e.g. *-sonnet, gpt-4*)" + } + } + + pub fn tui_model_route_add_provider_title() -> &'static str { + if is_chinese() { + "添加模型路由 — 供应商" + } else { + "Add Model Route — Provider" + } + } + + pub fn tui_model_route_add_provider_prompt() -> &'static str { + if is_chinese() { + "输入供应商 ID" + } else { + "Enter provider ID" + } + } + + pub fn tui_model_route_add_priority_title() -> &'static str { + if is_chinese() { + "添加模型路由 — 优先级" + } else { + "Add Model Route — Priority" + } + } + + pub fn tui_model_route_add_priority_prompt() -> &'static str { + if is_chinese() { + "输入优先级(数值越小越优先,默认 0)" + } else { + "Enter priority (lower = higher priority, default 0)" + } + } + + pub fn tui_model_route_edit_pattern_title() -> &'static str { + if is_chinese() { + "编辑模型路由 — 模型模式" + } else { + "Edit Model Route — Pattern" + } + } + + pub fn tui_model_route_edit_pattern_prompt() -> &'static str { + if is_chinese() { + "输入模型名称模式" + } else { + "Enter model name pattern" + } + } + + pub fn tui_model_route_edit_provider_title() -> &'static str { + if is_chinese() { + "编辑模型路由 — 供应商" + } else { + "Edit Model Route — Provider" + } + } + + pub fn tui_model_route_edit_provider_prompt() -> &'static str { + if is_chinese() { + "输入供应商 ID" + } else { + "Enter provider ID" + } + } + + pub fn tui_model_route_edit_priority_title() -> &'static str { + if is_chinese() { + "编辑模型路由 — 优先级" + } else { + "Edit Model Route — Priority" + } + } + + pub fn tui_model_route_edit_priority_prompt() -> &'static str { + if is_chinese() { + "输入优先级" + } else { + "Enter priority" + } + } + + pub fn tui_model_route_confirm_delete_message(pattern: &str) -> String { + if is_chinese() { + format!("确认删除模型路由 \"{pattern}\"?此操作不可撤销。") + } else { + format!("Delete model route \"{pattern}\"? This cannot be undone.") + } + } + + pub fn tui_model_route_confirm_delete_title() -> &'static str { + if is_chinese() { + "删除模型路由" + } else { + "Delete Model Route" + } + } + pub fn tui_managed_accounts_not_loaded() -> &'static str { if is_chinese() { "未加载" diff --git a/src-tauri/src/cli/tui/app.rs b/src-tauri/src/cli/tui/app.rs index 6630c15b..72c01206 100644 --- a/src-tauri/src/cli/tui/app.rs +++ b/src-tauri/src/cli/tui/app.rs @@ -1,5 +1,6 @@ use crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; use ratatui::prelude::Size; +use std::collections::HashMap; use std::collections::HashSet; use unicode_width::UnicodeWidthChar; diff --git a/src-tauri/src/cli/tui/app/app_state.rs b/src-tauri/src/cli/tui/app/app_state.rs index c8a88485..b13119d5 100644 --- a/src-tauri/src/cli/tui/app/app_state.rs +++ b/src-tauri/src/cli/tui/app/app_state.rs @@ -1,4 +1,5 @@ use super::*; +use std::collections::HashMap; #[derive(Debug, Clone)] pub enum Action { @@ -111,6 +112,23 @@ pub enum Action { field: ProviderAddField, claude_idx: Option, }, + ModelRouteAdd { + pattern: String, + provider_id: String, + priority: i32, + }, + ModelRouteEdit { + id: String, + pattern: String, + provider_id: String, + priority: i32, + }, + ModelRouteDelete { + id: String, + }, + ModelRouteToggle { + id: String, + }, UsageCustomRange { range: data::UsageCustomRange, }, @@ -441,11 +459,12 @@ pub enum SettingsItem { SkipClaudeOnboarding, ClaudePluginIntegration, Proxy, + ModelRoutes, CheckForUpdates, } impl SettingsItem { - pub const ALL: [SettingsItem; 9] = [ + pub const ALL: [SettingsItem; 10] = [ SettingsItem::ManagedAccounts, SettingsItem::Language, SettingsItem::VisibleAppsMode, @@ -454,19 +473,22 @@ impl SettingsItem { SettingsItem::SkipClaudeOnboarding, SettingsItem::ClaudePluginIntegration, SettingsItem::Proxy, + SettingsItem::ModelRoutes, SettingsItem::CheckForUpdates, ]; } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum LocalProxySettingsItem { + ProxySwitch, ListenAddress, ListenPort, AutoFailover, } impl LocalProxySettingsItem { - pub const ALL: [LocalProxySettingsItem; 3] = [ + pub const ALL: [LocalProxySettingsItem; 4] = [ + LocalProxySettingsItem::ProxySwitch, LocalProxySettingsItem::ListenAddress, LocalProxySettingsItem::ListenPort, LocalProxySettingsItem::AutoFailover, @@ -543,6 +565,9 @@ pub struct App { pub proxy_output_activity_samples: Vec, pub proxy_activity_last_input_tokens: Option, pub proxy_activity_last_output_tokens: Option, + /// 按 provider 聚合的 activity 样本(provider_id → samples),用于仪表盘点阵图多色展示 + pub proxy_provider_activity_samples: HashMap>, + pub proxy_activity_last_provider_tokens: Option>, pub proxy_visual_state: Option, pub proxy_visual_transition: Option, pub quota_auto_target_key: Option, @@ -584,6 +609,8 @@ pub struct App { pub settings_idx: usize, pub settings_proxy_idx: usize, pub settings_managed_accounts_idx: usize, + /// Selected index in the model routes table. + pub model_routes_idx: usize, pub managed_auth_status: Option, pub managed_auth_loading: bool, pub managed_auth_login: Option, diff --git a/src-tauri/src/cli/tui/app/content_config.rs b/src-tauri/src/cli/tui/app/content_config.rs index b44c87bc..48f77a20 100644 --- a/src-tauri/src/cli/tui/app/content_config.rs +++ b/src-tauri/src/cli/tui/app/content_config.rs @@ -800,6 +800,9 @@ impl App { Action::None } Some(SettingsItem::Proxy) => self.push_route_and_switch(Route::SettingsProxy), + Some(SettingsItem::ModelRoutes) => { + self.push_route_and_switch(Route::SettingsModelRoutes) + } Some(SettingsItem::CheckForUpdates) => Action::CheckUpdate, None => Action::None, }, @@ -869,46 +872,53 @@ impl App { self.settings_proxy_idx = (self.settings_proxy_idx + 1).min(items_len - 1); Action::None } - KeyCode::Enter => match LocalProxySettingsItem::ALL.get(self.settings_proxy_idx) { - Some(LocalProxySettingsItem::AutoFailover) => { - self.request_auto_failover_toggle(data) - } - Some(LocalProxySettingsItem::ListenAddress) => { - if data.proxy.running { - self.push_toast( - texts::tui_toast_proxy_settings_stop_before_edit(), - ToastKind::Info, - ); - return Action::None; + KeyCode::Enter | KeyCode::Char(' ') => { + match LocalProxySettingsItem::ALL.get(self.settings_proxy_idx) { + Some(LocalProxySettingsItem::ProxySwitch) => { + return Action::SetProxyEnabled { + enabled: !data.proxy.enabled, + } } - self.overlay = Overlay::TextInput(TextInputState { - title: texts::tui_settings_proxy_title().to_string(), - prompt: texts::tui_settings_proxy_listen_address_prompt().to_string(), - input: TextInput::new(data.proxy.configured_listen_address.clone()), - submit: TextSubmit::SettingsProxyListenAddress, - secret: false, - }); - Action::None - } - Some(LocalProxySettingsItem::ListenPort) => { - if data.proxy.running { - self.push_toast( - texts::tui_toast_proxy_settings_stop_before_edit(), - ToastKind::Info, - ); - return Action::None; + Some(LocalProxySettingsItem::AutoFailover) => { + self.request_auto_failover_toggle(data) } - self.overlay = Overlay::TextInput(TextInputState { - title: texts::tui_settings_proxy_title().to_string(), - prompt: texts::tui_settings_proxy_listen_port_prompt().to_string(), - input: TextInput::new(data.proxy.configured_listen_port.to_string()), - submit: TextSubmit::SettingsProxyListenPort, - secret: false, - }); - Action::None + Some(LocalProxySettingsItem::ListenAddress) => { + if data.proxy.running { + self.push_toast( + texts::tui_toast_proxy_settings_stop_before_edit(), + ToastKind::Info, + ); + return Action::None; + } + self.overlay = Overlay::TextInput(TextInputState { + title: texts::tui_settings_proxy_title().to_string(), + prompt: texts::tui_settings_proxy_listen_address_prompt().to_string(), + input: TextInput::new(data.proxy.configured_listen_address.clone()), + submit: TextSubmit::SettingsProxyListenAddress, + secret: false, + }); + Action::None + } + Some(LocalProxySettingsItem::ListenPort) => { + if data.proxy.running { + self.push_toast( + texts::tui_toast_proxy_settings_stop_before_edit(), + ToastKind::Info, + ); + return Action::None; + } + self.overlay = Overlay::TextInput(TextInputState { + title: texts::tui_settings_proxy_title().to_string(), + prompt: texts::tui_settings_proxy_listen_port_prompt().to_string(), + input: TextInput::new(data.proxy.configured_listen_port.to_string()), + submit: TextSubmit::SettingsProxyListenPort, + secret: false, + }); + Action::None + } + None => Action::None, } - None => Action::None, - }, + } _ => Action::None, } } @@ -940,6 +950,61 @@ impl App { } } + pub(crate) fn on_settings_model_routes_key(&mut self, key: KeyEvent, data: &UiData) -> Action { + let routes_len = data.model_routes.rows.len(); + match key.code { + KeyCode::Up => { + self.model_routes_idx = self.model_routes_idx.saturating_sub(1); + Action::None + } + KeyCode::Down => { + if routes_len > 0 { + self.model_routes_idx = (self.model_routes_idx + 1).min(routes_len - 1); + } + Action::None + } + KeyCode::Char('a') => { + self.overlay = Overlay::TextInput(TextInputState { + title: texts::tui_model_route_add_pattern_title().to_string(), + prompt: texts::tui_model_route_add_pattern_prompt().to_string(), + input: TextInput::new(String::new()), + submit: TextSubmit::ModelRouteAddPattern, + secret: false, + }); + Action::None + } + KeyCode::Char('e') => { + if let Some(row) = data.model_routes.rows.get(self.model_routes_idx) { + self.overlay = Overlay::TextInput(TextInputState { + title: texts::tui_model_route_edit_pattern_title().to_string(), + prompt: texts::tui_model_route_edit_pattern_prompt().to_string(), + input: TextInput::new(row.pattern.clone()), + submit: TextSubmit::ModelRouteEditPattern { id: row.id.clone() }, + secret: false, + }); + } + Action::None + } + KeyCode::Char('d') => { + if let Some(row) = data.model_routes.rows.get(self.model_routes_idx) { + self.overlay = Overlay::Confirm(ConfirmOverlay { + title: texts::tui_model_route_confirm_delete_title().to_string(), + message: texts::tui_model_route_confirm_delete_message(&row.pattern), + action: ConfirmAction::ModelRouteDelete { id: row.id.clone() }, + }); + } + Action::None + } + KeyCode::Char(' ') => { + if let Some(row) = data.model_routes.rows.get(self.model_routes_idx) { + return Action::ModelRouteToggle { id: row.id.clone() }; + } + Action::None + } + _ => Action::None, + } + } + fn managed_auth_account_count(&self) -> usize { self.managed_auth_status .as_ref() diff --git a/src-tauri/src/cli/tui/app/menu.rs b/src-tauri/src/cli/tui/app/menu.rs index e81df2a8..88c51902 100644 --- a/src-tauri/src/cli/tui/app/menu.rs +++ b/src-tauri/src/cli/tui/app/menu.rs @@ -65,6 +65,8 @@ impl App { proxy_output_activity_samples: Vec::new(), proxy_activity_last_input_tokens: None, proxy_activity_last_output_tokens: None, + proxy_provider_activity_samples: HashMap::new(), + proxy_activity_last_provider_tokens: None, proxy_visual_state: None, proxy_visual_transition: None, quota_auto_target_key: None, @@ -102,6 +104,7 @@ impl App { settings_idx: 0, settings_proxy_idx: 0, settings_managed_accounts_idx: 0, + model_routes_idx: 0, managed_auth_status: None, managed_auth_loading: false, managed_auth_login: None, @@ -164,9 +167,10 @@ impl App { | Route::SkillsDiscover | Route::SkillsRepos | Route::SkillDetail { .. } => NavItem::Skills, - Route::Settings | Route::SettingsProxy | Route::SettingsManagedAccounts => { - NavItem::Settings - } + Route::Settings + | Route::SettingsProxy + | Route::SettingsManagedAccounts + | Route::SettingsModelRoutes => NavItem::Settings, } } @@ -324,8 +328,10 @@ impl App { pub(crate) fn reset_proxy_activity(&mut self, input_tokens: u64, output_tokens: u64) { self.proxy_input_activity_samples.clear(); self.proxy_output_activity_samples.clear(); + self.proxy_provider_activity_samples.clear(); self.proxy_activity_last_input_tokens = Some(input_tokens); self.proxy_activity_last_output_tokens = Some(output_tokens); + self.proxy_activity_last_provider_tokens = None; } pub(crate) fn observe_proxy_token_activity(&mut self, input_tokens: u64, output_tokens: u64) { @@ -365,6 +371,61 @@ impl App { } } + /// 按 provider 记录 token activity 样本,用于仪表盘点阵图多色展示 + pub(crate) fn observe_proxy_provider_activity( + &mut self, + provider_token_map: &HashMap, + ) { + // proxy 重启会令主 token 计数回退,触发 observe_proxy_token_activity 清空主样本。 + // 这里同步清空 provider 样本,保持列对齐,避免颜色栈错位退化为单色。 + let main_len = self.proxy_input_activity_samples.len(); + let prev_len = self + .proxy_provider_activity_samples + .values() + .map(|s| s.len()) + .max() + .unwrap_or(0); + if prev_len > main_len { + for samples in self.proxy_provider_activity_samples.values_mut() { + samples.clear(); + } + } + + let first_tick = self.proxy_activity_last_provider_tokens.is_none(); + let prev_map = self + .proxy_activity_last_provider_tokens + .clone() + .unwrap_or_default(); + + // Compute per-provider deltas(首 tick 全为 0,与主样本首列对齐) + for (provider_id, current_tokens) in provider_token_map { + let prev = prev_map.get(provider_id).copied().unwrap_or(0); + let delta = if first_tick || *current_tokens < prev { + 0 + } else { + current_tokens.saturating_sub(prev) + }; + let samples = self + .proxy_provider_activity_samples + .entry(provider_id.clone()) + .or_default(); + samples.push(delta); + while samples.len() > PROXY_ACTIVITY_WINDOW { + samples.remove(0); + } + } + + // Pad all provider samples to match input/output sample length + let target_len = main_len; + for samples in self.proxy_provider_activity_samples.values_mut() { + while samples.len() < target_len { + samples.insert(0, 0); + } + } + + self.proxy_activity_last_provider_tokens = Some(provider_token_map.clone()); + } + pub fn push_toast(&mut self, message: impl Into, kind: ToastKind) { self.toast = Some(Toast::new(message, kind)); } @@ -764,6 +825,7 @@ impl App { Route::Settings => self.on_settings_key(key, data), Route::SettingsProxy => self.on_settings_proxy_key(key, data), Route::SettingsManagedAccounts => self.on_settings_managed_accounts_key(key, data), + Route::SettingsModelRoutes => self.on_settings_model_routes_key(key, data), Route::Main => match key.code { KeyCode::Char('r') => Action::LocalEnvRefresh, KeyCode::Char('p') | KeyCode::Char('P') => self.main_proxy_action(data), @@ -933,5 +995,12 @@ impl App { } else { self.config_webdav_idx = self.config_webdav_idx.min(config_webdav_len - 1); } + + let routes_len = data.model_routes.rows.len(); + if routes_len == 0 { + self.model_routes_idx = 0; + } else { + self.model_routes_idx = self.model_routes_idx.min(routes_len - 1); + } } } diff --git a/src-tauri/src/cli/tui/app/overlay_handlers/dialogs.rs b/src-tauri/src/cli/tui/app/overlay_handlers/dialogs.rs index c09003f1..538245a1 100644 --- a/src-tauri/src/cli/tui/app/overlay_handlers/dialogs.rs +++ b/src-tauri/src/cli/tui/app/overlay_handlers/dialogs.rs @@ -143,6 +143,9 @@ impl App { }; return Some(Action::None); } + ConfirmAction::ModelRouteDelete { id } => { + Action::ModelRouteDelete { id: id.clone() } + } }; self.close_overlay(); action @@ -363,6 +366,101 @@ impl App { } TextSubmit::WebDavJianguoyunUsername => self.handle_webdav_username_submit(raw), TextSubmit::WebDavJianguoyunPassword => self.handle_webdav_password_submit(raw), + TextSubmit::ModelRouteAddPattern => { + if raw.is_empty() { + self.push_toast( + texts::tui_toast_provider_add_missing_fields(), + ToastKind::Warning, + ); + self.overlay = Overlay::TextInput(TextInputState { + title: texts::tui_model_route_add_pattern_title().to_string(), + prompt: texts::tui_model_route_add_pattern_prompt().to_string(), + input: TextInput::new(raw), + submit: TextSubmit::ModelRouteAddPattern, + secret: false, + }); + return Action::None; + } + // 打开 provider 选择器而非文本输入 + self.overlay = Overlay::ModelRouteProviderPicker { + pattern: raw, + selected: 0, + editing: false, + existing_id: None, + }; + Action::None + } + TextSubmit::ModelRouteAddProvider { .. } => { + // 不再使用 — provider 选择器直接跳到优先级步骤 + Action::None + } + TextSubmit::ModelRouteAddPriority { + pattern, + provider_id, + } => { + let priority: i32 = raw.trim().parse().unwrap_or(0); + Action::ModelRouteAdd { + pattern, + provider_id, + priority, + } + } + TextSubmit::ModelRouteEditPattern { id } => { + if raw.is_empty() { + self.push_toast( + texts::tui_toast_provider_add_missing_fields(), + ToastKind::Warning, + ); + self.overlay = Overlay::TextInput(TextInputState { + title: texts::tui_model_route_edit_pattern_title().to_string(), + prompt: texts::tui_model_route_edit_pattern_prompt().to_string(), + input: TextInput::new(raw), + submit: TextSubmit::ModelRouteEditPattern { id }, + secret: false, + }); + return Action::None; + } + // 编辑时预选当前 provider,避免回车静默改成首个 provider (Codex P2) + let selected = data + .model_routes + .rows + .iter() + .find(|row| row.id == id) + .and_then(|route| { + data.providers + .rows + .iter() + .position(|p| p.id == route.provider_id) + }) + .unwrap_or(0); + self.overlay = Overlay::ModelRouteProviderPicker { + pattern: raw, + + selected, + + editing: true, + + existing_id: Some(id), + }; + + Action::None + } + + TextSubmit::ModelRouteEditProvider { .. } => Action::None, + + TextSubmit::ModelRouteEditPriority { + id, + pattern, + provider_id, + } => { + let priority: i32 = raw.trim().parse().unwrap_or(0); + Action::ModelRouteEdit { + id, + pattern, + provider_id, + priority, + } + } } } diff --git a/src-tauri/src/cli/tui/app/overlay_handlers/views.rs b/src-tauri/src/cli/tui/app/overlay_handlers/views.rs index 7b4c8eae..5342479c 100644 --- a/src-tauri/src/cli/tui/app/overlay_handlers/views.rs +++ b/src-tauri/src/cli/tui/app/overlay_handlers/views.rs @@ -36,6 +36,9 @@ impl App { if let Some(action) = self.handle_backup_picker_key(key, data) { return Some(action); } + if let Some(action) = self.handle_model_route_provider_picker_key(key, data) { + return Some(action); + } if let Some(action) = self.handle_text_view_overlay_key(key, data) { return Some(action); } @@ -332,4 +335,96 @@ impl App { _ => Action::None, }) } + + fn handle_model_route_provider_picker_key( + &mut self, + key: KeyEvent, + data: &UiData, + ) -> Option { + let Overlay::ModelRouteProviderPicker { + pattern, + selected, + editing, + existing_id, + } = &mut self.overlay + else { + return None; + }; + + let providers = &data.providers.rows; + + Some(match key.code { + KeyCode::Esc => { + self.overlay = Overlay::TextInput(TextInputState { + title: if *editing { + texts::tui_model_route_edit_pattern_title().to_string() + } else { + texts::tui_model_route_add_pattern_title().to_string() + }, + prompt: if *editing { + texts::tui_model_route_edit_pattern_prompt().to_string() + } else { + texts::tui_model_route_add_pattern_prompt().to_string() + }, + input: TextInput::new(pattern.clone()), + submit: if *editing { + TextSubmit::ModelRouteEditPattern { + id: existing_id.clone().unwrap_or_default(), + } + } else { + TextSubmit::ModelRouteAddPattern + }, + secret: false, + }); + Action::None + } + KeyCode::Up => { + *selected = selected.saturating_sub(1); + Action::None + } + KeyCode::Down => { + if !providers.is_empty() { + *selected = (*selected + 1).min(providers.len() - 1); + } + Action::None + } + KeyCode::Enter => { + if let Some(provider_row) = providers.get(*selected) { + let provider_id = provider_row.id.clone(); + let pattern = std::mem::take(pattern); + let is_editing = *editing; + let eid = existing_id.clone(); + // 编辑时预填原有 priority,避免误改顺序;新增时默认 0 + let priority_input = if is_editing { + eid.as_ref() + .and_then(|id| data.model_routes.rows.iter().find(|row| &row.id == id)) + .map(|row| row.priority.to_string()) + .unwrap_or_else(|| "0".to_string()) + } else { + "0".to_string() + }; + self.overlay = Overlay::TextInput(TextInputState { + title: texts::tui_model_route_add_priority_title().to_string(), + prompt: texts::tui_model_route_add_priority_prompt().to_string(), + input: TextInput::new(priority_input), + submit: if is_editing { + TextSubmit::ModelRouteEditPriority { + id: eid.unwrap_or_default(), + pattern, + provider_id, + } + } else { + TextSubmit::ModelRouteAddPriority { + pattern, + provider_id, + } + }, + secret: false, + }); + } + Action::None + } + _ => Action::None, + }) + } } diff --git a/src-tauri/src/cli/tui/app/tests.rs b/src-tauri/src/cli/tui/app/tests.rs index 3e343263..5ae57412 100644 --- a/src-tauri/src/cli/tui/app/tests.rs +++ b/src-tauri/src/cli/tui/app/tests.rs @@ -1133,6 +1133,56 @@ mod tests { assert_eq!(app.proxy_activity_last_output_tokens, Some(8)); } + #[test] + fn proxy_provider_activity_aligns_with_main_samples_on_first_tick() { + let mut app = App::new(Some(AppType::Claude)); + + // 首 tick:主样本 push 一个 0,provider 样本必须同长(修复前会因 + // 静默 return 而落后一列,导致点阵图颜色栈错位退化为单色)。 + app.reset_proxy_activity(10, 20); + app.observe_proxy_token_activity(10, 20); + let mut map = HashMap::new(); + map.insert("p1".to_string(), 5); + app.observe_proxy_provider_activity(&map); + + assert_eq!(app.proxy_input_activity_samples.len(), 1); + assert_eq!( + app.proxy_provider_activity_samples + .get("p1") + .map(|s| s.len()), + Some(1), + "provider samples must align with main samples from the first tick" + ); + } + + #[test] + fn proxy_provider_activity_resyncs_after_proxy_restart() { + let mut app = App::new(Some(AppType::Claude)); + + // 正常积累几个 tick + app.reset_proxy_activity(0, 0); + for i in 1..=3 { + app.observe_proxy_token_activity(i * 10, i * 20); + let mut map = HashMap::new(); + map.insert("p1".to_string(), i * 5); + app.observe_proxy_provider_activity(&map); + } + assert_eq!(app.proxy_input_activity_samples.len(), 3); + assert_eq!(app.proxy_provider_activity_samples["p1"].len(), 3); + + // proxy 重启:主计数回退触发主样本清空,provider 样本必须同步清空 + app.observe_proxy_token_activity(1, 2); + assert_eq!(app.proxy_input_activity_samples, vec![0]); + let mut map = HashMap::new(); + map.insert("p1".to_string(), 1); + app.observe_proxy_provider_activity(&map); + assert_eq!( + app.proxy_provider_activity_samples["p1"].len(), + 1, + "provider samples must resync after proxy restart realigns main samples" + ); + } + #[test] fn proxy_transition_starts_when_proxy_route_state_changes() { let mut app = App::new(Some(AppType::Claude)); diff --git a/src-tauri/src/cli/tui/app/types.rs b/src-tauri/src/cli/tui/app/types.rs index e185c4b3..12f4b1db 100644 --- a/src-tauri/src/cli/tui/app/types.rs +++ b/src-tauri/src/cli/tui/app/types.rs @@ -473,6 +473,9 @@ pub enum ConfirmAction { ClaudeModelFillAll { source_idx: usize, }, + ModelRouteDelete { + id: String, + }, } #[derive(Debug, Clone)] @@ -509,6 +512,26 @@ pub enum TextSubmit { }, WebDavJianguoyunUsername, WebDavJianguoyunPassword, + ModelRouteAddPattern, + ModelRouteAddProvider { + pattern: String, + }, + ModelRouteAddPriority { + pattern: String, + provider_id: String, + }, + ModelRouteEditPattern { + id: String, + }, + ModelRouteEditProvider { + id: String, + pattern: String, + }, + ModelRouteEditPriority { + id: String, + pattern: String, + provider_id: String, + }, } #[derive(Debug, Clone)] @@ -622,6 +645,12 @@ pub enum Overlay { UsageQueryTemplatePicker { selected: usize, }, + ModelRouteProviderPicker { + pattern: String, + selected: usize, + editing: bool, // true=edit mode (has existing id), false=add mode + existing_id: Option, // for edit mode + }, ManagedAccountPicker { auth_provider: String, selected: usize, @@ -733,6 +762,7 @@ impl Overlay { matches!( self, Overlay::BackupPicker { .. } + | Overlay::ModelRouteProviderPicker { .. } | Overlay::TextView(_) | Overlay::CommonSnippetPicker { .. } | Overlay::ProviderTestMenu { .. } @@ -772,6 +802,7 @@ impl Overlay { | Overlay::Help(_) | Overlay::Confirm(_) | Overlay::BackupPicker { .. } + | Overlay::ModelRouteProviderPicker { .. } | Overlay::TextView(_) | Overlay::CommonSnippetPicker { .. } | Overlay::ProviderTestMenu { .. } diff --git a/src-tauri/src/cli/tui/data.rs b/src-tauri/src/cli/tui/data.rs index 772b4552..cb7b7b63 100644 --- a/src-tauri/src/cli/tui/data.rs +++ b/src-tauri/src/cli/tui/data.rs @@ -73,6 +73,23 @@ pub(crate) struct ProviderQuotaState { pub(crate) updated_at: Option, } +#[derive(Debug, Clone)] +pub struct ModelRouteRow { + pub id: String, + pub pattern: String, + pub provider_id: String, + pub provider_name: String, + pub priority: i32, + pub enabled: bool, + pub hit_count: i64, + pub last_hit_at: Option, +} + +#[derive(Debug, Clone, Default)] +pub struct ModelRouteSnapshot { + pub rows: Vec, +} + #[derive(Debug, Clone, Default)] pub(crate) struct QuotaSnapshot { by_provider: HashMap, @@ -296,6 +313,8 @@ pub struct ProxySnapshot { pub last_error: Option, #[allow(dead_code)] pub current_app_target: Option, + /// 按 provider 聚合的预估 token 数(provider_id → token_count) + pub provider_token_map: HashMap, } impl ProxySnapshot { @@ -838,6 +857,7 @@ pub struct UiData { pub proxy: ProxySnapshot, pub usage: UsageSnapshot, pub pricing: ModelPricingSnapshot, + pub model_routes: ModelRouteSnapshot, pub(crate) quota: QuotaSnapshot, pub(crate) reload_token: UiDataReloadToken, } @@ -853,6 +873,7 @@ impl Default for UiData { proxy: ProxySnapshot::default(), usage: UsageSnapshot::default(), pricing: ModelPricingSnapshot::default(), + model_routes: ModelRouteSnapshot::default(), quota: QuotaSnapshot::default(), reload_token: UiDataReloadToken::default(), } @@ -930,6 +951,8 @@ impl UiData { }; let proxy = load_proxy_snapshot_from_state(state, app_type)?; + let model_routes = load_model_routes_snapshot(state, app_type, &providers)?; + Ok(Self { providers, mcp, @@ -937,6 +960,7 @@ impl UiData { config, skills, proxy, + model_routes, usage: UsageSnapshot::default(), pricing: ModelPricingSnapshot::default(), quota: QuotaSnapshot::default(), @@ -945,7 +969,11 @@ impl UiData { } pub(crate) fn refresh_proxy_snapshot(&mut self, app_type: &AppType) -> Result<(), AppError> { - self.proxy = load_proxy_snapshot(app_type)?; + let state = load_state()?; + self.proxy = load_proxy_snapshot_from_state(&state, app_type)?; + // 同时刷新 model_routes 命中统计,使仪表盘的路由命中图例 + // 能反映代理运行期间累积的 hit_count(否则停留在 UiData::load 快照)。 + self.model_routes = load_model_routes_snapshot(&state, app_type, &self.providers)?; Ok(()) } @@ -962,6 +990,7 @@ impl UiData { config: self.config.loading_projection(app_type), skills: self.skills.clone(), proxy, + model_routes: ModelRouteSnapshot::default(), usage: UsageSnapshot::default(), pricing: ModelPricingSnapshot::default(), quota: QuotaSnapshot::default(), @@ -2596,10 +2625,46 @@ fn load_proxy_snapshot_from_state( .filter(|value| !value.is_empty()) .map(str::to_string), current_app_target, + provider_token_map: runtime_status.provider_token_map, }) }) } +fn load_model_routes_snapshot( + state: &AppState, + app_type: &AppType, + providers: &ProvidersSnapshot, +) -> Result { + let model_routes = state.db.list_model_routes(app_type.as_str())?; + + let mut rows = model_routes + .into_iter() + .map(|route| { + let provider_name = providers + .rows + .iter() + .find(|p| p.id == route.provider_id) + .map(|p| crate::cli::tui::data::provider_display_name(app_type, p)) + .unwrap_or_else(|| route.provider_id.clone()); + + ModelRouteRow { + id: route.id, + pattern: route.pattern, + provider_id: route.provider_id, + provider_name, + priority: route.priority, + enabled: route.enabled, + hit_count: route.hit_count, + last_hit_at: route.last_hit_at, + } + }) + .collect::>(); + + rows.sort_by(|a, b| a.priority.cmp(&b.priority).then_with(|| a.id.cmp(&b.id))); + + Ok(ModelRouteSnapshot { rows }) +} + fn load_skills_snapshot() -> Result { Ok(SkillsSnapshot { installed: SkillService::list_installed()?, diff --git a/src-tauri/src/cli/tui/mod.rs b/src-tauri/src/cli/tui/mod.rs index 6de6c512..79732fa2 100644 --- a/src-tauri/src/cli/tui/mod.rs +++ b/src-tauri/src/cli/tui/mod.rs @@ -979,6 +979,10 @@ fn cache_invalidation_for_action(action: &Action) -> CacheInvalidation { | Action::ProviderDelete { .. } | Action::ProviderSetFailoverQueue { .. } | Action::ProviderMoveFailoverQueue { .. } + | Action::ModelRouteAdd { .. } + | Action::ModelRouteEdit { .. } + | Action::ModelRouteDelete { .. } + | Action::ModelRouteToggle { .. } | Action::EditorSubmit { submit: EditorSubmit::ProviderAdd | EditorSubmit::ProviderEdit { .. }, .. @@ -1980,6 +1984,7 @@ pub fn run(app_override: Option) -> Result<(), AppError> { data.proxy.estimated_input_tokens_total, data.proxy.estimated_output_tokens_total, ); + app.observe_proxy_provider_activity(&data.proxy.provider_token_map); } } queue_current_quota_refresh_if_due( diff --git a/src-tauri/src/cli/tui/route.rs b/src-tauri/src/cli/tui/route.rs index ea1b6d83..2ccce8ad 100644 --- a/src-tauri/src/cli/tui/route.rs +++ b/src-tauri/src/cli/tui/route.rs @@ -27,6 +27,7 @@ pub enum Route { Settings, SettingsProxy, SettingsManagedAccounts, + SettingsModelRoutes, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] diff --git a/src-tauri/src/cli/tui/runtime_actions/mod.rs b/src-tauri/src/cli/tui/runtime_actions/mod.rs index f8fc914c..59fb228b 100644 --- a/src-tauri/src/cli/tui/runtime_actions/mod.rs +++ b/src-tauri/src/cli/tui/runtime_actions/mod.rs @@ -18,6 +18,7 @@ mod config; mod editor; mod helpers; mod mcp; +mod model_routes; mod pricing; mod prompts; mod providers; @@ -429,6 +430,19 @@ pub(crate) fn handle_action( Action::PromptOpenImportCandidate { filename, content } => { prompts::open_import_candidate(&mut ctx, filename, content) } + Action::ModelRouteAdd { + pattern, + provider_id, + priority, + } => model_routes::handle_add(&mut ctx, pattern, provider_id, priority), + Action::ModelRouteEdit { + id, + pattern, + provider_id, + priority, + } => model_routes::handle_edit(&mut ctx, id, pattern, provider_id, priority), + Action::ModelRouteDelete { id } => model_routes::handle_delete(&mut ctx, id), + Action::ModelRouteToggle { id } => model_routes::handle_toggle(&mut ctx, id), Action::ConfigExport { path } => config::export(&mut ctx, path), Action::ConfigShowFull => config::show_full(&mut ctx), Action::ConfigImport { path } => config::import(&mut ctx, path), diff --git a/src-tauri/src/cli/tui/runtime_actions/model_routes.rs b/src-tauri/src/cli/tui/runtime_actions/model_routes.rs new file mode 100644 index 00000000..60ec8b4d --- /dev/null +++ b/src-tauri/src/cli/tui/runtime_actions/model_routes.rs @@ -0,0 +1,133 @@ +use crate::cli::i18n::texts; +use crate::error::AppError; +use crate::model_route::ModelRoute; + +use super::super::app::ToastKind; +use super::super::data::{load_state, ModelRouteRow, ModelRouteSnapshot}; +use super::RuntimeActionContext; + +fn refresh_model_routes_data(ctx: &mut RuntimeActionContext<'_>) -> Result<(), AppError> { + let state = load_state()?; + let routes = state.db.list_model_routes(ctx.app.app_type.as_str())?; + + let rows: Vec = routes + .into_iter() + .map(|route| { + let provider_name = ctx + .data + .providers + .rows + .iter() + .find(|p| p.id == route.provider_id) + .map(|p| super::super::data::provider_display_name(&ctx.app.app_type, p)) + .unwrap_or_else(|| route.provider_id.clone()); + + ModelRouteRow { + id: route.id, + pattern: route.pattern, + provider_id: route.provider_id, + provider_name, + priority: route.priority, + enabled: route.enabled, + hit_count: route.hit_count, + last_hit_at: route.last_hit_at, + } + }) + .collect(); + + ctx.data.model_routes = ModelRouteSnapshot { rows }; + ctx.app.clamp_selections(ctx.data); + ctx.data.mark_current_app_data_changed(); + Ok(()) +} + +pub(super) fn handle_add( + ctx: &mut RuntimeActionContext<'_>, + pattern: String, + provider_id: String, + priority: i32, +) -> Result<(), AppError> { + let state = load_state()?; + let route = ModelRoute { + id: String::new(), + app_type: ctx.app.app_type.as_str().to_string(), + pattern, + provider_id, + priority, + enabled: true, + created_at: None, + + hit_count: 0, + + last_hit_at: None, + updated_at: None, + }; + + state.db.create_model_route(&route)?; + refresh_model_routes_data(ctx)?; + ctx.app + .push_toast(texts::tui_toast_model_route_added(), ToastKind::Success); + ctx.app.overlay = super::super::app::Overlay::None; + Ok(()) +} + +pub(super) fn handle_edit( + ctx: &mut RuntimeActionContext<'_>, + id: String, + pattern: String, + provider_id: String, + priority: i32, +) -> Result<(), AppError> { + let state = load_state()?; + // 保留已有的 enabled 状态,不因编辑而静默恢复已禁用的路由 + let enabled = state + .db + .get_model_route(&id) + .ok() + .flatten() + .map(|existing| existing.enabled) + .unwrap_or(true); + let route = ModelRoute { + id: String::new(), + app_type: ctx.app.app_type.as_str().to_string(), + pattern, + provider_id, + priority, + enabled, + created_at: None, + + hit_count: 0, + + last_hit_at: None, + updated_at: None, + }; + + state.db.update_model_route(&id, &route)?; + refresh_model_routes_data(ctx)?; + ctx.app + .push_toast(texts::tui_toast_model_route_updated(), ToastKind::Success); + ctx.app.overlay = super::super::app::Overlay::None; + Ok(()) +} + +pub(super) fn handle_delete( + ctx: &mut RuntimeActionContext<'_>, + id: String, +) -> Result<(), AppError> { + let state = load_state()?; + state.db.delete_model_route(&id)?; + refresh_model_routes_data(ctx)?; + ctx.app + .push_toast(texts::tui_toast_model_route_deleted(), ToastKind::Success); + Ok(()) +} + +pub(super) fn handle_toggle( + ctx: &mut RuntimeActionContext<'_>, + id: String, +) -> Result<(), AppError> { + let state = load_state()?; + state.db.toggle_model_route(&id)?; + refresh_model_routes_data(ctx)?; + Ok(()) +} diff --git a/src-tauri/src/cli/tui/ui.rs b/src-tauri/src/cli/tui/ui.rs index ee618f99..2ba25a64 100644 --- a/src-tauri/src/cli/tui/ui.rs +++ b/src-tauri/src/cli/tui/ui.rs @@ -39,6 +39,7 @@ mod editor; mod forms; mod main_page; mod mcp; +mod model_routes; mod overlay; mod pricing; mod prompts; @@ -61,6 +62,7 @@ use editor::*; use forms::*; use main_page::*; use mcp::*; +use model_routes::*; use overlay::*; use pricing::*; use prompts::*; @@ -197,6 +199,9 @@ fn render_content( Route::SettingsManagedAccounts => { render_settings_managed_accounts(frame, app, data, content_area, theme) } + Route::SettingsModelRoutes => { + render_settings_model_routes(frame, app, data, content_area, theme) + } } } diff --git a/src-tauri/src/cli/tui/ui/config.rs b/src-tauri/src/cli/tui/ui/config.rs index 56efcf5b..ac89a83b 100644 --- a/src-tauri/src/cli/tui/ui/config.rs +++ b/src-tauri/src/cli/tui/ui/config.rs @@ -21,6 +21,7 @@ pub(super) fn webdav_config_item_label(item: &WebDavConfigItem) -> &'static str pub(super) fn local_proxy_settings_item_label(item: &LocalProxySettingsItem) -> &'static str { match item { + LocalProxySettingsItem::ProxySwitch => crate::t!("Proxy enabled", "代理开关"), LocalProxySettingsItem::ListenAddress => texts::tui_settings_proxy_listen_address_label(), LocalProxySettingsItem::ListenPort => texts::tui_settings_proxy_listen_port_label(), LocalProxySettingsItem::AutoFailover => crate::t!("Automatic failover", "自动故障转移"), @@ -2565,6 +2566,10 @@ pub(super) fn render_settings( data.proxy.configured_listen_address, data.proxy.configured_listen_port, ), ), + super::app::SettingsItem::ModelRoutes => ( + texts::tui_settings_model_routes_title().to_string(), + format!("{} rules", data.model_routes.rows.len()), + ), super::app::SettingsItem::CheckForUpdates => ( texts::tui_settings_check_for_updates().to_string(), format!("v{}", env!("CARGO_PKG_VERSION")), @@ -3106,6 +3111,14 @@ pub(super) fn render_settings_proxy( let rows_data = LocalProxySettingsItem::ALL .iter() .map(|item| match item { + LocalProxySettingsItem::ProxySwitch => ( + local_proxy_settings_item_label(item).to_string(), + if data.proxy.enabled { + texts::enabled().to_string() + } else { + texts::disabled().to_string() + }, + ), LocalProxySettingsItem::ListenAddress => ( local_proxy_settings_item_label(item).to_string(), data.proxy.configured_listen_address.clone(), diff --git a/src-tauri/src/cli/tui/ui/main_page.rs b/src-tauri/src/cli/tui/ui/main_page.rs index 58f8dfd0..7b7df2fc 100644 --- a/src-tauri/src/cli/tui/ui/main_page.rs +++ b/src-tauri/src/cli/tui/ui/main_page.rs @@ -1,7 +1,14 @@ use crate::cli::tui::data; +use std::collections::{HashMap, HashSet}; use super::*; +/// Dracula purple — used for input (downstream) graph to contrast with accent-colored output. +const DRACULA_PURPLE: (u8, u8, u8) = (189, 147, 249); + +/// 图例中最低显示 token 数(近期窗口增量);低于此值的 provider 会从图例中隐藏,避免 0% 干扰主图例 +const LEGEND_MIN_RECENT_TOKENS: u64 = 1_000; + fn opencode_configured_provider_count(data: &UiData) -> usize { data.providers .rows @@ -291,12 +298,16 @@ pub(super) fn render_main( .split(chunks[1]); if current_app_routed { + // 收集近期 token 活动按 provider 聚合(用于多色图例,与点阵图同口径) + let route_hits = + collect_route_hits_for_dashboard(data, &app.proxy_provider_activity_samples); render_proxy_activity_dashboard( frame, hero_chunks[0], theme, &app.proxy_input_activity_samples, &app.proxy_output_activity_samples, + &app.proxy_provider_activity_samples, &uptime_text, &proxy_last_error_text, data.proxy.last_error.is_some(), @@ -305,6 +316,7 @@ pub(super) fn render_main( auto_failover_queue_len, data.proxy.estimated_input_tokens_total, data.proxy.estimated_output_tokens_total, + &route_hits, ); } else { render_logo_hero(frame, hero_chunks[0], theme); @@ -328,6 +340,7 @@ fn render_proxy_activity_dashboard( theme: &super::theme::Theme, input_activity_samples: &[u64], output_activity_samples: &[u64], + provider_activity_samples: &HashMap>, uptime_text: &str, proxy_last_error_text: &str, has_proxy_error: bool, @@ -336,6 +349,7 @@ fn render_proxy_activity_dashboard( auto_failover_queue_len: usize, input_tokens_total: u64, output_tokens_total: u64, + route_hits: &[ProviderHitInfo], ) -> Rect { let has_token_traffic = input_tokens_total > 0 || output_tokens_total > 0; let title_output_style = if has_token_traffic { @@ -346,7 +360,9 @@ fn render_proxy_activity_dashboard( Style::default().fg(theme.surface) }; let title_input_style = if has_token_traffic { - Style::default().fg(theme.cyan).add_modifier(Modifier::BOLD) + Style::default() + .fg(theme::terminal_palette_color(DRACULA_PURPLE)) + .add_modifier(Modifier::BOLD) } else { Style::default().fg(theme.surface) }; @@ -422,6 +438,40 @@ fn render_proxy_activity_dashboard( ); } + // 多色 Provider 近期流量图例(与点阵图共用近期 token 口径) + // 过滤掉过小流量(< LEGEND_MIN_RECENT_TOKENS tok)的 provider + let display_hits: Vec<&ProviderHitInfo> = route_hits + .iter() + .filter(|h| h.recent_tokens >= LEGEND_MIN_RECENT_TOKENS) + .take(5) + .collect(); + if !display_hits.is_empty() { + // 总量基于所有 route_hits(含 < LEGEND_MIN_RECENT_TOKENS 的),让百分比统计更准 + let total_tokens: u64 = route_hits.iter().map(|h| h.recent_tokens).sum(); + if total_tokens > 0 { + let legend_label = crate::t!("Recent tokens", "近期流量"); + meta_spans.push(Span::raw(" ")); + meta_spans.push(Span::styled(format!("{legend_label}: "), label_style)); + meta_plain.push_str(" "); + meta_plain.push_str(&legend_label); + meta_plain.push_str(": "); + for (i, hit) in display_hits.iter().enumerate() { + if i > 0 { + meta_spans.push(Span::raw(", ")); + meta_plain.push_str(", "); + } + let pct = (hit.recent_tokens as f64 / total_tokens as f64) * 100.0; + let tok_text = format_estimated_token_compact(hit.recent_tokens); + let text = format!("{} {}% ({})", hit.display_name, pct as i32, tok_text); + meta_spans.push(Span::styled( + text.clone(), + Style::default().fg(hit.color).add_modifier(Modifier::BOLD), + )); + meta_plain.push_str(&text); + } + } + } + let max_text_height = inner.height.saturating_sub(2).clamp(1, 4); let text_height = wrapped_display_line_count(&meta_plain, inner.width).min(max_text_height); let graph_height = inner.height.saturating_sub(text_height).max(2); @@ -443,37 +493,121 @@ fn render_proxy_activity_dashboard( let lower_height = graph_height.saturating_sub(upper_height).max(1); let wave_width = sections[1].width.saturating_sub(1); let mut graph_lines = Vec::new(); - let upper_style = Style::default().fg(theme.accent); - let lower_style = if theme.no_color { + + // 从图例数据构建 provider_id → 颜色映射(与 legend 颜色一致) + let mut provider_color_map: HashMap = route_hits + .iter() + .map(|h| (h.provider_id.clone(), h.color)) + .collect(); + + // 补全颜色:直接切换的 provider(不在 route_hits 中)但仍在活动 sample 里。 + // 复用图例同款调色板,按 i % 8 取色,确保点阵有颜色。 + let palette: [Color; 8] = + PER_PROVIDER_PALETTE_RGBS.map(|rgb| theme::terminal_palette_color(rgb)); + let palette_len = palette.len(); + if palette_len > 0 { + // 先收齐所有缺失颜色的 provider_id,避免借用冲突 + let missing: Vec = provider_activity_samples + .keys() + .filter(|id| !provider_color_map.contains_key(*id)) + .cloned() + .collect(); + for (i, provider_id) in missing.iter().enumerate() { + provider_color_map.insert(provider_id.clone(), palette[i % palette_len]); + } + } + + let visible_provider_ids: HashSet = + route_hits.iter().map(|h| h.provider_id.clone()).collect(); + let visible_samples: Vec<(&String, &Vec)> = provider_activity_samples + .iter() + .filter(|(id, _)| visible_provider_ids.contains(*id)) + .collect(); + + // 点阵每列实际占据的行数(从底部算)。颜色只填点阵字符所在的区间,避免 minor + // provider 颜色被分配到点阵空白行而不可见(图例颜色与点阵颜色对不上的根因)。 + let upper_filled = + column_filled_rows(wave_width as usize, upper_height, output_activity_samples); + let lower_filled = + column_filled_rows(wave_width as usize, lower_height, input_activity_samples); + + let upper_color_stacks = compute_column_color_stacks( + visible_samples.iter().copied(), + wave_width as usize, + &provider_color_map, + upper_height as usize, + &upper_filled, + ); + let lower_color_stacks = compute_column_color_stacks( + visible_samples.iter().copied(), + wave_width as usize, + &provider_color_map, + lower_height as usize, + &lower_filled, + ); + + let upper_rows = proxy_wave_lines( + wave_width, + upper_height, + true, + output_activity_samples, + &DOTS, + false, + ); + let lower_rows = proxy_wave_lines( + wave_width, + lower_height, + true, + input_activity_samples, + &REV_DOTS, + true, + ); + + let default_upper = Style::default().fg(theme.accent); + let default_lower = if theme.no_color { Style::default() } else { - Style::default().fg(theme.cyan) + Style::default().fg(theme::terminal_palette_color(DRACULA_PURPLE)) }; - graph_lines.extend( - proxy_wave_lines( - wave_width, - upper_height, - true, - output_activity_samples, - &DOTS, - false, - ) - .into_iter() - .map(|row| Line::from(vec![Span::raw(" "), Span::styled(row, upper_style)])), - ); - graph_lines.extend( - proxy_wave_lines( - wave_width, - lower_height, - true, - input_activity_samples, - &REV_DOTS, - true, - ) - .into_iter() - .map(|row| Line::from(vec![Span::raw(" "), Span::styled(row, lower_style)])), - ); + // 上半部分(output),每列按 provider 颜色 + for (row_idx, row) in upper_rows.iter().enumerate() { + let mut spans = vec![Span::raw(" ")]; + for (col_idx, ch) in row.chars().enumerate() { + let style = match stack_color_at(&upper_color_stacks, col_idx, row_idx) { + Some(provider_color) => { + if theme.no_color { + Style::default().add_modifier(Modifier::BOLD) + } else { + // 上半部使用 provider 颜色,稍微调亮 + Style::default().fg(provider_color) + } + } + None => default_upper, + }; + spans.push(Span::styled(ch.to_string(), style)); + } + graph_lines.push(Line::from(spans)); + } + + // 下半部分(input),使用与上半部相同的 per-provider 颜色 + for (row_idx, row) in lower_rows.iter().enumerate() { + let mut spans = vec![Span::raw(" ")]; + for (col_idx, ch) in row.chars().enumerate() { + let style = match stack_color_at(&lower_color_stacks, col_idx, row_idx) { + Some(provider_color) => { + if theme.no_color { + Style::default().add_modifier(Modifier::BOLD) + } else { + Style::default().fg(provider_color) + } + } + None => default_lower, + }; + spans.push(Span::styled(ch.to_string(), style)); + } + graph_lines.push(Line::from(spans)); + } frame.render_widget( Paragraph::new(graph_lines).wrap(Wrap { trim: false }), @@ -491,6 +625,243 @@ fn wrapped_display_line_count(text: &str, width: u16) -> u16 { UnicodeWidthStr::width(text).max(1).div_ceil(width as usize) as u16 } +/// 点阵图多色 palette(与 legend 共用同一组颜色) +const PER_PROVIDER_PALETTE_RGBS: [(u8, u8, u8); 8] = [ + (189, 147, 249), // 紫 + (135, 206, 250), // 天蓝 + (255, 160, 122), // 浅三文鱼 + (144, 238, 144), // 浅绿 + (221, 160, 221), // 李子紫 + (255, 215, 0), // 金 + (127, 255, 212), // 碧绿 + (176, 196, 222), // 淡钢蓝 +]; + +/// 根据 per-provider 活动样本,计算每列的垂直颜色栈。 +/// 颜色只填充该列点阵实际占据的行(`column_filled_rows`,从底部算), +/// 并在区间内按 token 占比分配行高:dominant 在底部,minor 紧贴其上。 +/// 这样每个 provider 的颜色都落在有点阵字符的行上,minor provider 也可见。 +fn compute_column_color_stacks<'a>( + provider_activity_samples: impl IntoIterator)>, + num_columns: usize, + provider_color_map: &HashMap, + stack_height: usize, + column_filled_rows: &[usize], +) -> Vec>> { + if num_columns == 0 || stack_height == 0 { + return vec![vec![None; stack_height]; num_columns]; + } + + let provider_activity_samples = provider_activity_samples.into_iter().collect::>(); + if provider_activity_samples.is_empty() { + return vec![vec![None; stack_height]; num_columns]; + } + + let mut color_stacks = vec![vec![None; stack_height]; num_columns]; + for col in 0..num_columns { + // 该列点阵实际占据的行数(从底部算)。颜色只填这个区间,避免 minor + // provider 的颜色被分配到点阵空白行(图例与点阵颜色对不上的根因)。 + let filled = column_filled_rows + .get(col) + .copied() + .unwrap_or(0) + .min(stack_height); + if filled == 0 { + continue; + } + + let mut entries = Vec::new(); + for (provider_id, samples) in &provider_activity_samples { + let tokens = samples.get(col).copied().unwrap_or(0); + if tokens > 0 { + if let Some(color) = provider_color_map.get(*provider_id).copied() { + entries.push((provider_id.as_str(), tokens, color)); + } + } + } + if entries.is_empty() { + continue; + } + + entries.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0))); + let total_tokens = entries.iter().map(|(_, tokens, _)| *tokens).sum::(); + // 在 [0, filled) 内分配行数,dominant 占高 idx(点阵底部),minor 占低 idx(顶部字符行)。 + let mut rows = allocate_provider_rows(&entries, total_tokens, filled); + rows.reverse(); + + let base = stack_height - filled; + let mut idx = 0; + for (entry_idx, row_count) in rows { + let color = entries[entry_idx].2; + for _ in 0..row_count { + if idx >= filled { + break; + } + color_stacks[col][base + idx] = Some(color); + idx += 1; + } + } + } + color_stacks +} + +fn stack_color_at( + color_stacks: &[Vec>], + col_idx: usize, + row_idx: usize, +) -> Option { + color_stacks + .get(col_idx) + .and_then(|stack| stack.get(row_idx)) + .copied() + .flatten() +} + +/// 计算点阵每列实际占据的行数(从底部算),与 `proxy_wave_lines` 的渲染口径一致。 +/// 颜色栈据此只填充点阵有字符的区间,确保 provider 颜色落在可见的字符行上。 +fn column_filled_rows(width: usize, height: u16, samples: &[u64]) -> Vec { + if width == 0 || height == 0 { + return Vec::new(); + } + let recent = super::proxy_wave::recent_samples(width, true, samples); + let scaled = super::proxy_wave::scale_samples(height, &recent, true); + scaled + .iter() + .map(|v| ((*v as usize) + 7) / 8) + .map(|rows| rows.min(height as usize)) + .collect() +} + +fn allocate_provider_rows( + entries: &[(&str, u64, Color)], + total_tokens: u64, + stack_height: usize, +) -> Vec<(usize, usize)> { + if entries.is_empty() || total_tokens == 0 || stack_height == 0 { + return Vec::new(); + } + + let mut allocations = entries + .iter() + .enumerate() + .map(|(idx, (_, tokens, _))| { + let exact = (*tokens as f64 / total_tokens as f64) * stack_height as f64; + let mut rows = exact.floor() as usize; + if rows == 0 { + rows = 1; + } + (idx, rows, exact - exact.floor()) + }) + .collect::>(); + + let mut total_rows = allocations.iter().map(|(_, rows, _)| *rows).sum::(); + while total_rows > stack_height { + if let Some((_, rows, _)) = allocations + .iter_mut() + .filter(|(_, rows, _)| *rows > 1) + .min_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal)) + { + *rows -= 1; + total_rows -= 1; + } else { + break; + } + } + + while total_rows < stack_height { + if let Some((_, rows, _)) = allocations + .iter_mut() + .max_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal)) + { + *rows += 1; + total_rows += 1; + } else { + break; + } + } + + allocations + .into_iter() + .filter_map(|(idx, rows, _)| (rows > 0).then_some((idx, rows))) + .collect() +} + +/// Provider 命中信息(用于仪表盘多色图例和点阵图着色) +#[derive(Clone)] +struct ProviderHitInfo { + provider_id: String, + display_name: String, + /// 最近 PROXY_ACTIVITY_WINDOW 窗口的 token 增量总和(近期实际流量) + recent_tokens: u64, + color: Color, +} + +/// 从近期 token 活动样本按 provider 聚合(与点阵图同口径),分配不同颜色。 +/// 聚合源为 `samples`(按 provider 的窗口 token 增量),并补齐 model_routes 中 +/// enabled 但近期无流量的 provider(其 recent_tokens 为 0,会被图例阈值过滤)。 +fn collect_route_hits_for_dashboard( + data: &UiData, + samples: &HashMap>, +) -> Vec { + let mut agg: HashMap = HashMap::new(); + + // 1) 近期 token 增量是主信号:每个窗口 delta 之和 + for (provider_id, sample_vec) in samples { + let sum: u64 = sample_vec.iter().sum(); + agg.insert(provider_id.clone(), sum); + } + + // 2) 并集 model_routes enabled 的 provider(近期无流量的 recent_tokens 记 0, + // 下游由 LEGEND_MIN_RECENT_TOKENS 阈值过滤) + for row in &data.model_routes.rows { + if !row.enabled { + continue; + } + agg.entry(row.provider_id.clone()).or_insert(0); + } + + if agg.is_empty() { + return Vec::new(); + } + + let mut v: Vec<(String, u64)> = agg.into_iter().collect(); + // recent_tokens 降序;相同值按 provider_id 字典序,保证测试与显示稳定 + v.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0))); + // 使用与点阵图相同的 palette,确保颜色一致 + let palette: [Color; 8] = + PER_PROVIDER_PALETTE_RGBS.map(|rgb| theme::terminal_palette_color(rgb)); + v.into_iter() + .enumerate() + .map(|(i, (provider_id, recent_tokens))| { + let display_name = data + .providers + .rows + .iter() + .find(|p| p.id == provider_id) + .map(|p| { + // 截断过长的 provider 名 + let s = p.provider.name.clone(); + if s.chars().count() > 8 { + let truncated: String = s.chars().take(6).collect(); + format!("{truncated}…") + } else { + s + } + }) + .unwrap_or_else(|| { + // provider 已被删除时使用 id 前 8 字符 + provider_id.chars().take(8).collect() + }); + ProviderHitInfo { + provider_id: provider_id.clone(), + display_name, + recent_tokens, + color: palette[i % palette.len()], + } + }) + .collect() +} + fn render_logo_hero(frame: &mut Frame<'_>, area: Rect, theme: &super::theme::Theme) { let logo_lines = logo_hero_lines(theme); let logo_height = (logo_lines.len() as u16).min(area.height); @@ -717,3 +1088,178 @@ pub(super) fn proxy_activity_wave(width: u16, current_app_routed: bool, samples: .next() .unwrap_or_default() } + +#[cfg(test)] +mod tests { + use super::*; + use crate::cli::tui::data::{ModelRouteRow, ProviderRow}; + use crate::provider::Provider; + use serde_json::Value; + + /// 构造一个最小可用的 Provider(仅 id/name 有意义,其余留空) + fn make_provider(id: &str, name: &str) -> Provider { + Provider { + id: id.to_string(), + name: name.to_string(), + settings_config: Value::Null, + website_url: None, + category: None, + created_at: None, + sort_index: None, + notes: None, + meta: None, + icon: None, + icon_color: None, + in_failover_queue: false, + } + } + + /// 构造一个最小 ProviderRow + fn make_provider_row(id: &str, name: &str) -> ProviderRow { + ProviderRow { + id: id.to_string(), + provider: make_provider(id, name), + api_url: None, + is_current: false, + is_in_config: false, + is_saved: false, + is_default_model: false, + primary_model_id: None, + default_model_id: None, + } + } + + /// 构造仅含给定 providers 的 UiData + fn make_ui_data_with_providers(providers: &[(&str, &str)]) -> UiData { + let mut data = UiData::default(); + data.providers.rows = providers + .iter() + .map(|(id, name)| make_provider_row(id, name)) + .collect(); + data + } + + #[test] + fn collect_aggregates_recent_tokens_from_samples() { + let data = make_ui_data_with_providers(&[("p1", "DeepSeek"), ("p2", "Minimax")]); + let mut samples = HashMap::new(); + samples.insert("p1".to_string(), vec![100, 200, 300]); // sum = 600 + samples.insert("p2".to_string(), vec![50, 50, 50]); // sum = 150 + + let result = collect_route_hits_for_dashboard(&data, &samples); + assert_eq!(result.len(), 2); + // recent_tokens 降序:p1 在前 + assert_eq!(result[0].provider_id, "p1"); + assert_eq!(result[0].recent_tokens, 600); + assert_eq!(result[1].provider_id, "p2"); + assert_eq!(result[1].recent_tokens, 150); + assert_eq!(result[0].display_name, "DeepSeek"); + assert_eq!(result[1].display_name, "Minimax"); + } + + #[test] + fn collect_returns_empty_when_no_samples_and_no_enabled_routes() { + let data = UiData::default(); + let samples = HashMap::new(); + let result = collect_route_hits_for_dashboard(&data, &samples); + assert!( + result.is_empty(), + "expected empty Vec, got {} entries", + result.len() + ); + } + + #[test] + fn collect_unions_samples_with_model_routes_enabled_providers() { + let mut data = + make_ui_data_with_providers(&[("p_routed", "Routed"), ("p_direct", "Direct")]); + // model_routes 含一个 enabled route 指向 p_routed(无 samples,近期无流量) + data.model_routes.rows.push(ModelRouteRow { + id: "r1".to_string(), + pattern: "*".to_string(), + provider_id: "p_routed".to_string(), + provider_name: "Routed".to_string(), + priority: 0, + enabled: true, + hit_count: 999, // 历史命中不应影响 recent_tokens 口径 + last_hit_at: None, + }); + // samples 含 p_direct(直接切换,无 route) + let mut samples = HashMap::new(); + samples.insert("p_direct".to_string(), vec![400, 400]); // sum = 800 + + let result = collect_route_hits_for_dashboard(&data, &samples); + let ids: Vec<&str> = result.iter().map(|h| h.provider_id.as_str()).collect(); + assert!(ids.contains(&"p_direct"), "p_direct should be in union"); + assert!( + ids.contains(&"p_routed"), + "p_routed should be in union via model_routes" + ); + // recent_tokens 降序:p_direct(800) 在前,p_routed(0) 在后 + assert_eq!(result[0].provider_id, "p_direct"); + assert_eq!(result[1].provider_id, "p_routed"); + assert_eq!(result[1].recent_tokens, 0); + } + + #[test] + fn color_stacks_keep_multiple_providers_in_same_column() { + let mut samples = HashMap::new(); + samples.insert("p1".to_string(), vec![90]); + samples.insert("p2".to_string(), vec![10]); + + let p1 = Color::Rgb(255, 0, 0); + let p2 = Color::Rgb(0, 255, 0); + let colors = HashMap::from([("p1".to_string(), p1), ("p2".to_string(), p2)]); + + // 点阵画满 4 行:dominant(p1) 占底部,minor(p2) 占顶部字符行。 + let stacks = compute_column_color_stacks(samples.iter(), 1, &colors, 4, &[4]); + + assert_eq!(stacks.len(), 1); + assert_eq!(stacks[0].len(), 4); + assert!( + stacks[0].contains(&Some(p1)), + "dominant provider should be present" + ); + assert!( + stacks[0].contains(&Some(p2)), + "smaller provider should still be visible in the same column" + ); + } + + #[test] + fn color_stacks_allow_single_provider_to_fill_column() { + let mut samples = HashMap::new(); + samples.insert("p1".to_string(), vec![100]); + samples.insert("p2".to_string(), vec![0]); + + let p1 = Color::Rgb(255, 0, 0); + let p2 = Color::Rgb(0, 255, 0); + let colors = HashMap::from([("p1".to_string(), p1), ("p2".to_string(), p2)]); + + let stacks = compute_column_color_stacks(samples.iter(), 1, &colors, 3, &[3]); + + assert_eq!(stacks[0], vec![Some(p1), Some(p1), Some(p1)]); + } + + #[test] + fn color_stacks_only_fill_rendered_rows() { + // Regression: 点阵只画 2 行(filled=2),stack_height=4。颜色必须只填 + // 点阵字符所在的 [2, 4) 区间,minor(p2) 在顶部字符行(base=2),dominant(p1) + // 在底部,[0, 2) 的空白行保持 None,避免图例颜色与点阵颜色对不上。 + let mut samples = HashMap::new(); + samples.insert("p1".to_string(), vec![90]); + samples.insert("p2".to_string(), vec![10]); + + let p1 = Color::Rgb(255, 0, 0); + let p2 = Color::Rgb(0, 255, 0); + let colors = HashMap::from([("p1".to_string(), p1), ("p2".to_string(), p2)]); + + let stacks = compute_column_color_stacks(samples.iter(), 1, &colors, 4, &[2]); + + assert_eq!( + stacks[0], + vec![None, None, Some(p2), Some(p1)], + "colors must occupy only the rendered [base, stack_height) rows" + ); + } +} diff --git a/src-tauri/src/cli/tui/ui/model_routes.rs b/src-tauri/src/cli/tui/ui/model_routes.rs new file mode 100644 index 00000000..b17b3202 --- /dev/null +++ b/src-tauri/src/cli/tui/ui/model_routes.rs @@ -0,0 +1,91 @@ +use ratatui::{ + layout::{Constraint, Direction, Layout, Rect}, + style::{Modifier, Style}, + widgets::{Block, BorderType, Borders, Cell, Row, Table, TableState}, + Frame, +}; + +use crate::cli::i18n::texts; + +use super::{ + app::{App, Focus}, + shared::{ + highlight_symbol, inset_left, pane_border_style, render_key_bar_center, selection_style, + CONTENT_INSET_LEFT, + }, + theme::Theme, +}; + +use crate::cli::tui::data::UiData; + +pub(super) fn render_settings_model_routes( + frame: &mut Frame<'_>, + app: &App, + data: &UiData, + area: Rect, + theme: &Theme, +) { + let title = texts::tui_settings_model_routes_title(); + + let header_cells = vec![ + Cell::from("Pattern"), + Cell::from("Provider"), + Cell::from("Priority"), + Cell::from("Enabled"), + ]; + let header = + Row::new(header_cells).style(Style::default().fg(theme.dim).add_modifier(Modifier::BOLD)); + + let rows = data.model_routes.rows.iter().map(|r| { + Row::new(vec![ + Cell::from(r.pattern.clone()), + Cell::from(r.provider_name.clone()), + Cell::from(r.priority.to_string()), + Cell::from(if r.enabled { "Yes" } else { "No" }), + ]) + }); + + let constraints = vec![ + Constraint::Percentage(30), + Constraint::Percentage(35), + Constraint::Length(10), + Constraint::Length(8), + ]; + + let outer = Block::default() + .borders(Borders::ALL) + .border_type(BorderType::Plain) + .border_style(pane_border_style(app, Focus::Content, theme)) + .title(title); + frame.render_widget(outer.clone(), area); + let inner = outer.inner(area); + + let chunks = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Length(1), Constraint::Min(0)]) + .split(inner); + + if app.focus == Focus::Content { + let selected = data.model_routes.rows.get(app.model_routes_idx); + let mut key_items: Vec<(&str, &str)> = vec![ + ("a", texts::tui_key_add()), + ("Space", texts::tui_key_toggle()), + ]; + if selected.is_some() { + key_items.push(("e", texts::tui_key_edit())); + key_items.push(("d", texts::tui_key_delete())); + }; + key_items.push(("\u{2191}\u{2193}", texts::tui_key_move())); + render_key_bar_center(frame, chunks[0], theme, &key_items); + } + + let table = Table::new(rows, constraints) + .header(header) + .block(Block::default().borders(Borders::NONE)) + .row_highlight_style(selection_style(theme)) + .highlight_symbol(highlight_symbol(theme)); + + let mut state = TableState::default(); + state.select(Some(app.model_routes_idx)); + frame.render_stateful_widget(table, inset_left(chunks[1], CONTENT_INSET_LEFT), &mut state); +} diff --git a/src-tauri/src/cli/tui/ui/overlay/basic.rs b/src-tauri/src/cli/tui/ui/overlay/basic.rs index d1832404..b0834bca 100644 --- a/src-tauri/src/cli/tui/ui/overlay/basic.rs +++ b/src-tauri/src/cli/tui/ui/overlay/basic.rs @@ -390,3 +390,70 @@ fn render_scrolling_lines(frame: &mut Frame<'_>, area: Rect, lines: &[String], s frame.render_widget(Paragraph::new(shown).wrap(Wrap { trim: false }), area); } + +pub(super) fn render_model_route_provider_picker( + frame: &mut Frame<'_>, + data: &UiData, + content_area: Rect, + theme: &theme::Theme, + selected: usize, +) { + use crate::app_config::AppType; + use crate::cli::tui::data::provider_display_name; + use ratatui::widgets::{Clear, List, ListItem, ListState}; + use unicode_width::UnicodeWidthStr; + + let area = centered_rect(OVERLAY_LG.0, OVERLAY_LG.1, content_area); + frame.render_widget(Clear, area); + + let block = Block::default() + .borders(Borders::ALL) + .border_type(BorderType::Plain) + .border_style(overlay_border_style(theme, false)) + .title(crate::t!("Select provider", "选择供应商")); + let inner = block.inner(area); + frame.render_widget(block, area); + + let chunks = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Length(1), Constraint::Min(0)]) + .split(inner); + + render_key_bar_center( + frame, + chunks[0], + theme, + &[ + ("Enter", crate::t!("Confirm", "确认")), + ("Esc", crate::t!("Back", "返回")), + ], + ); + + let body_area = inset_top(chunks[1], 1); + let items: Vec = data + .providers + .rows + .iter() + .map(|row| { + let display = provider_display_name(&AppType::Claude, row); + let extra = if row.is_current { " (*)" } else { "" }; + let mut text = format!("{display}{extra}"); + if UnicodeWidthStr::width(text.as_str()) as u16 > body_area.width.saturating_sub(2) { + text = text + .chars() + .take(body_area.width.saturating_sub(5) as usize) + .collect::() + + "…"; + } + ListItem::new(Line::from(Span::raw(text))) + }) + .collect(); + + let list = List::new(items) + .highlight_style(selection_style(theme)) + .highlight_symbol(highlight_symbol(theme)); + + let mut state = ListState::default(); + state.select(Some(selected)); + frame.render_stateful_widget(list, body_area, &mut state); +} diff --git a/src-tauri/src/cli/tui/ui/overlay/render.rs b/src-tauri/src/cli/tui/ui/overlay/render.rs index 69dd6365..abadeb2a 100644 --- a/src-tauri/src/cli/tui/ui/overlay/render.rs +++ b/src-tauri/src/cli/tui/ui/overlay/render.rs @@ -21,6 +21,15 @@ pub(crate) fn render_overlay( Overlay::BackupPicker { selected } => { super::basic::render_backup_picker_overlay(frame, data, content_area, theme, *selected) } + Overlay::ModelRouteProviderPicker { selected, .. } => { + super::basic::render_model_route_provider_picker( + frame, + data, + content_area, + theme, + *selected, + ) + } Overlay::TextView(view) => super::basic::render_text_view_overlay( frame, content_area, diff --git a/src-tauri/src/cli/tui/ui/proxy_wave.rs b/src-tauri/src/cli/tui/ui/proxy_wave.rs index 8664c429..5fd4f4ac 100644 --- a/src-tauri/src/cli/tui/ui/proxy_wave.rs +++ b/src-tauri/src/cli/tui/ui/proxy_wave.rs @@ -57,7 +57,7 @@ pub(super) fn proxy_wave_lines( rows } -fn recent_samples(width: usize, current_app_routed: bool, samples: &[u64]) -> Vec { +pub(super) fn recent_samples(width: usize, current_app_routed: bool, samples: &[u64]) -> Vec { if !current_app_routed { return vec![0; width]; } @@ -73,7 +73,7 @@ fn recent_samples(width: usize, current_app_routed: bool, samples: &[u64]) -> Ve out } -fn scale_samples(height: u16, samples: &[u64], show_idle_baseline: bool) -> Vec { +pub(super) fn scale_samples(height: u16, samples: &[u64], show_idle_baseline: bool) -> Vec { let baseline = if show_idle_baseline { 1 } else { 0 }; let max = samples.iter().copied().max().unwrap_or(0); if max == 0 { diff --git a/src-tauri/src/cli/tui/ui/tests.rs b/src-tauri/src/cli/tui/ui/tests.rs index 1fb1ca6c..b8a43c4d 100644 --- a/src-tauri/src/cli/tui/ui/tests.rs +++ b/src-tauri/src/cli/tui/ui/tests.rs @@ -23,7 +23,7 @@ use crate::{ Focus, Overlay, TextInputState, TextSubmit, UsagePane, }, data::{ - ConfigSnapshot, McpSnapshot, ModelPricingRow, ModelPricingSnapshot, + ConfigSnapshot, McpSnapshot, ModelPricingRow, ModelPricingSnapshot, ModelRouteSnapshot, OpenClawWorkspaceSnapshot, PromptsSnapshot, ProviderRow, ProvidersSnapshot, ProxySnapshot, SkillsSnapshot, UiData, UsageLogRow, UsageProviderStatsRow, UsageRangePreset, UsageSnapshot, UsageSummarySnapshot, UsageTrendBucket, @@ -1562,6 +1562,7 @@ pub(super) fn minimal_data(_app_type: &AppType) -> UiData { usage: UsageSnapshot::default(), pricing: Default::default(), quota: Default::default(), + model_routes: ModelRouteSnapshot::default(), reload_token: Default::default(), } } diff --git a/src-tauri/src/database/dao/mod.rs b/src-tauri/src/database/dao/mod.rs index ab7e2742..a076ed88 100644 --- a/src-tauri/src/database/dao/mod.rs +++ b/src-tauri/src/database/dao/mod.rs @@ -5,6 +5,7 @@ pub mod failover; pub mod mcp; pub mod model_pricing; +pub mod model_routes; pub mod prompts; pub mod providers; pub mod providers_seed; diff --git a/src-tauri/src/database/dao/model_routes.rs b/src-tauri/src/database/dao/model_routes.rs new file mode 100644 index 00000000..e6f04fc7 --- /dev/null +++ b/src-tauri/src/database/dao/model_routes.rs @@ -0,0 +1,473 @@ +//! 模型路由 DAO (Model Route Data Access Object) +//! +//! 管理 model_routes 表的 CRUD 操作,为 per-model provider routing 提供持久化层。 +//! 支持按 app_type 列出路由、创建/更新/删除路由、切换启用状态、记录命中统计。 +//! id 使用 UUID v4 (TEXT PRIMARY KEY),与上游 cc-switch 一致。 + +use crate::database::{lock_conn, Database}; +use crate::error::AppError; +use crate::model_route::ModelRoute; + +const SELECT_COLS: &str = "id, app_type, pattern, provider_id, priority, enabled, hit_count, last_hit_at, created_at, updated_at"; + +impl Database { + /// 列出指定 app_type 的所有模型路由,按 priority ASC, created_at ASC 排序 + pub fn list_model_routes(&self, app_type: &str) -> Result, AppError> { + let conn = lock_conn!(self.conn); + + let mut stmt = conn + .prepare(&format!( + "SELECT {SELECT_COLS} FROM model_routes WHERE app_type = ?1 ORDER BY priority ASC, created_at ASC" + )) + .map_err(|e| AppError::Database(e.to_string()))?; + + let items = stmt + .query_map([app_type], |row| Ok(row_to_route(row))) + .map_err(|e| AppError::Database(e.to_string()))? + .collect::, _>>() + .map_err(|e| AppError::Database(e.to_string()))?; + + Ok(items) + } + + /// 根据 ID 获取单个模型路由 + pub fn get_model_route(&self, id: &str) -> Result, AppError> { + let conn = lock_conn!(self.conn); + + let mut stmt = conn + .prepare(&format!( + "SELECT {SELECT_COLS} FROM model_routes WHERE id = ?1" + )) + .map_err(|e| AppError::Database(e.to_string()))?; + + let mut rows = stmt + .query_map([id], |row| Ok(row_to_route(row))) + .map_err(|e| AppError::Database(e.to_string()))?; + + rows.next() + .transpose() + .map_err(|e| AppError::Database(e.to_string())) + } + + /// 创建模型路由(生成 UUID id,验证 provider_id 存在) + pub fn create_model_route(&self, route: &ModelRoute) -> Result { + let conn = lock_conn!(self.conn); + + let provider_exists: bool = conn + .query_row( + "SELECT COUNT(*) > 0 FROM providers WHERE id = ?1 AND app_type = ?2", + rusqlite::params![&route.provider_id, &route.app_type], + |row| row.get(0), + ) + .map_err(|e| AppError::Database(e.to_string()))?; + + if !provider_exists { + return Err(AppError::Database(format!( + "provider '{}' not found for app '{}'", + route.provider_id, route.app_type + ))); + } + + let id = if route.id.is_empty() { + uuid::Uuid::new_v4().to_string() + } else { + route.id.clone() + }; + + let mut stmt = conn + .prepare(&format!( + "INSERT INTO model_routes (id, app_type, pattern, provider_id, priority, enabled) + VALUES (?1, ?2, ?3, ?4, ?5, ?6) + RETURNING {SELECT_COLS}" + )) + .map_err(|e| AppError::Database(e.to_string()))?; + + stmt.query_row( + rusqlite::params![ + &id, + &route.app_type, + &route.pattern, + &route.provider_id, + route.priority, + route.enabled as i32, + ], + |row| Ok(row_to_route(row)), + ) + .map_err(|e| AppError::Database(e.to_string())) + } + + /// 更新模型路由 + pub fn update_model_route(&self, id: &str, route: &ModelRoute) -> Result { + let conn = lock_conn!(self.conn); + + let current_provider: String = conn + .query_row( + "SELECT provider_id FROM model_routes WHERE id = ?1", + [id], + |row| row.get(0), + ) + .map_err(|e| AppError::Database(e.to_string()))?; + + if route.provider_id != current_provider { + let provider_exists: bool = conn + .query_row( + "SELECT COUNT(*) > 0 FROM providers WHERE id = ?1 AND app_type = ?2", + rusqlite::params![&route.provider_id, &route.app_type], + |row| row.get(0), + ) + .map_err(|e| AppError::Database(e.to_string()))?; + + if !provider_exists { + return Err(AppError::Database(format!( + "provider '{}' not found for app '{}'", + route.provider_id, route.app_type + ))); + } + } + + let mut stmt = conn + .prepare(&format!( + "UPDATE model_routes SET + pattern = ?1, provider_id = ?2, priority = ?3, enabled = ?4, + updated_at = datetime('now') + WHERE id = ?5 + RETURNING {SELECT_COLS}" + )) + .map_err(|e| AppError::Database(e.to_string()))?; + + stmt.query_row( + rusqlite::params![ + &route.pattern, + &route.provider_id, + route.priority, + route.enabled as i32, + id, + ], + |row| Ok(row_to_route(row)), + ) + .map_err(|e| AppError::Database(e.to_string())) + } + + /// 删除模型路由 + pub fn delete_model_route(&self, id: &str) -> Result<(), AppError> { + let conn = lock_conn!(self.conn); + + let changes = conn + .execute("DELETE FROM model_routes WHERE id = ?1", [id]) + .map_err(|e| AppError::Database(e.to_string()))?; + + if changes == 0 { + return Err(AppError::Database("model_route not found".to_string())); + } + + Ok(()) + } + + /// 切换模型路由的启用状态 + pub fn toggle_model_route(&self, id: &str) -> Result { + let conn = lock_conn!(self.conn); + + let mut stmt = conn + .prepare(&format!( + "UPDATE model_routes SET + enabled = NOT enabled, + updated_at = datetime('now') + WHERE id = ?1 + RETURNING {SELECT_COLS}" + )) + .map_err(|e| AppError::Database(e.to_string()))?; + + stmt.query_row([id], |row| Ok(row_to_route(row))) + .map_err(|e| AppError::Database(e.to_string())) + } + + /// 记录一次命中(增加 hit_count 并更新 last_hit_at) + /// 使用 UPDATE 而非事务,性能更好;last_hit_at 只在每次调用时更新(不频繁) + pub fn record_model_route_hit(&self, id: &str) -> Result<(), AppError> { + let conn = lock_conn!(self.conn); + + let changes = conn + .execute( + "UPDATE model_routes SET + hit_count = hit_count + 1, + last_hit_at = datetime('now') + WHERE id = ?1", + [id], + ) + .map_err(|e| AppError::Database(e.to_string()))?; + + if changes == 0 { + return Err(AppError::Database("model_route not found".to_string())); + } + + Ok(()) + } + + /// 获取所有启用的 model_routes(按 app_type + provider_id 聚合用于仪表盘) + /// 返回 (app_type, provider_id, total_hits) 列表 + pub fn aggregate_route_hits_by_provider(&self) -> Result, AppError> { + let conn = lock_conn!(self.conn); + + let mut stmt = conn + .prepare( + "SELECT app_type, provider_id, SUM(hit_count) as total + FROM model_routes + WHERE enabled = 1 + GROUP BY app_type, provider_id + ORDER BY total DESC", + ) + .map_err(|e| AppError::Database(e.to_string()))?; + + let rows = stmt + .query_map([], |row| { + Ok((row.get(0)?, row.get(1)?, row.get::<_, i64>(2)?)) + }) + .map_err(|e| AppError::Database(e.to_string()))? + .collect::, _>>() + .map_err(|e| AppError::Database(e.to_string()))?; + + Ok(rows) + } +} + +fn row_to_route(row: &rusqlite::Row) -> ModelRoute { + ModelRoute { + id: row.get(0).expect("id"), + app_type: row.get(1).expect("app_type"), + pattern: row.get(2).expect("pattern"), + provider_id: row.get(3).expect("provider_id"), + priority: row.get(4).expect("priority"), + enabled: row.get::<_, i32>(5).expect("enabled") != 0, + hit_count: row.get(6).expect("hit_count"), + last_hit_at: row.get(7).expect("last_hit_at"), + created_at: row.get(8).expect("created_at"), + updated_at: row.get(9).expect("updated_at"), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn seed_provider(db: &Database, app_type: &str, id: &str) -> Result<(), AppError> { + let conn = lock_conn!(db.conn); + conn.execute( + "INSERT INTO providers (id, app_type, name, settings_config, meta) + VALUES (?1, ?2, ?3, '{}', '{}')", + rusqlite::params![id, app_type, id], + ) + .map_err(|e| AppError::Database(e.to_string()))?; + Ok(()) + } + + fn test_route(pattern: &str, provider_id: &str, priority: i32) -> ModelRoute { + ModelRoute { + id: String::new(), + app_type: "claude".into(), + pattern: pattern.into(), + provider_id: provider_id.into(), + priority, + enabled: true, + hit_count: 0, + last_hit_at: None, + created_at: None, + updated_at: None, + } + } + + #[test] + fn create_and_get_model_route_roundtrip() -> Result<(), AppError> { + let db = Database::memory()?; + seed_provider(&db, "claude", "test-prov")?; + + let created = db.create_model_route(&test_route("*-sonnet", "test-prov", 10))?; + + assert_eq!(created.id.len(), 36); + assert_eq!(created.pattern, "*-sonnet"); + assert_eq!(created.provider_id, "test-prov"); + assert_eq!(created.priority, 10); + assert!(created.enabled); + assert_eq!(created.hit_count, 0); + assert!(created.created_at.is_some()); + + let got = db.get_model_route(&created.id)?; + assert!(got.is_some()); + assert_eq!(got.unwrap().pattern, "*-sonnet"); + + Ok(()) + } + + #[test] + fn create_model_route_rejects_invalid_provider() -> Result<(), AppError> { + let db = Database::memory()?; + + let result = db.create_model_route(&test_route("*-sonnet", "nonexistent", 10)); + assert!(result.is_err()); + + let err = result.unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains("provider") && msg.contains("not found"), + "expected provider not found error, got: {msg}" + ); + + Ok(()) + } + + #[test] + fn list_model_routes_ordered_by_priority() -> Result<(), AppError> { + let db = Database::memory()?; + seed_provider(&db, "claude", "p1")?; + + let r1 = db.create_model_route(&test_route("mid", "p1", 5))?; + let r2 = db.create_model_route(&test_route("low", "p1", 1))?; + let r3 = db.create_model_route(&test_route("high", "p1", 3))?; + + let routes = db.list_model_routes("claude")?; + assert_eq!(routes.len(), 3); + assert_eq!(routes[0].id, r2.id); + assert_eq!(routes[0].priority, 1); + assert_eq!(routes[1].id, r3.id); + assert_eq!(routes[1].priority, 3); + assert_eq!(routes[2].id, r1.id); + assert_eq!(routes[2].priority, 5); + + Ok(()) + } + + #[test] + fn update_model_route_modifies_fields() -> Result<(), AppError> { + let db = Database::memory()?; + seed_provider(&db, "claude", "p1")?; + seed_provider(&db, "claude", "p2")?; + + let created = db.create_model_route(&test_route("*-sonnet", "p1", 10))?; + + let updated = db.update_model_route( + &created.id, + &ModelRoute { + id: created.id.clone(), + app_type: "claude".into(), + pattern: "claude-*".into(), + provider_id: "p2".into(), + priority: 5, + enabled: false, + hit_count: 0, + last_hit_at: None, + created_at: None, + updated_at: None, + }, + )?; + + assert_eq!(updated.pattern, "claude-*"); + assert_eq!(updated.provider_id, "p2"); + assert_eq!(updated.priority, 5); + assert!(!updated.enabled); + + let got = db.get_model_route(&created.id)?; + assert!(got.is_some()); + let got = got.unwrap(); + assert_eq!(got.pattern, "claude-*"); + assert!(!got.enabled); + + Ok(()) + } + + #[test] + fn toggle_model_route_flips_enabled() -> Result<(), AppError> { + let db = Database::memory()?; + seed_provider(&db, "claude", "p1")?; + + let created = db.create_model_route(&test_route("*-sonnet", "p1", 10))?; + assert!(created.enabled); + + let toggled_off = db.toggle_model_route(&created.id)?; + assert!(!toggled_off.enabled); + + let toggled_on = db.toggle_model_route(&created.id)?; + assert!(toggled_on.enabled); + + Ok(()) + } + + #[test] + fn delete_model_route_removes_row() -> Result<(), AppError> { + let db = Database::memory()?; + seed_provider(&db, "claude", "p1")?; + + let created = db.create_model_route(&test_route("*-sonnet", "p1", 10))?; + + db.delete_model_route(&created.id)?; + + let got = db.get_model_route(&created.id)?; + assert!(got.is_none()); + + let result = db.delete_model_route("nonexistent-id"); + assert!(result.is_err()); + + Ok(()) + } + + #[test] + fn record_model_route_hit_increments_count() -> Result<(), AppError> { + let db = Database::memory()?; + seed_provider(&db, "claude", "p1")?; + + let created = db.create_model_route(&test_route("*-sonnet", "p1", 10))?; + assert_eq!(created.hit_count, 0); + + db.record_model_route_hit(&created.id)?; + db.record_model_route_hit(&created.id)?; + db.record_model_route_hit(&created.id)?; + + let got = db.get_model_route(&created.id)?.unwrap(); + assert_eq!(got.hit_count, 3); + assert!(got.last_hit_at.is_some()); + + Ok(()) + } + + #[test] + fn aggregate_route_hits_by_provider_groups_correctly() -> Result<(), AppError> { + let db = Database::memory()?; + seed_provider(&db, "claude", "p1")?; + seed_provider(&db, "claude", "p2")?; + seed_provider(&db, "codex", "cx1")?; + + let r1 = db.create_model_route(&test_route("*sonnet*", "p1", 1))?; + let r2 = db.create_model_route(&test_route("*opus*", "p2", 2))?; + let mut codex_route = test_route("*codex*", "cx1", 1); + codex_route.app_type = "codex".to_string(); + let r3 = db.create_model_route(&codex_route)?; + let _r4 = db.create_model_route(&test_route("disabled", "p1", 5))?; + + // r4 is disabled + db.toggle_model_route( + &db.list_model_routes("claude")? + .iter() + .find(|r| r.pattern == "disabled") + .unwrap() + .id, + )?; + + // 5 hits to claude/p1, 3 to claude/p2, 2 to codex/cx1 + for _ in 0..5 { + db.record_model_route_hit(&r1.id)?; + } + for _ in 0..3 { + db.record_model_route_hit(&r2.id)?; + } + for _ in 0..2 { + db.record_model_route_hit(&r3.id)?; + } + + let agg = db.aggregate_route_hits_by_provider()?; + // r4 was disabled but got 0 hits, so it should be filtered out + assert_eq!(agg.len(), 3); + assert_eq!(agg[0], ("claude".to_string(), "p1".to_string(), 5)); + assert_eq!(agg[1], ("claude".to_string(), "p2".to_string(), 3)); + assert_eq!(agg[2], ("codex".to_string(), "cx1".to_string(), 2)); + + Ok(()) + } +} diff --git a/src-tauri/src/database/mod.rs b/src-tauri/src/database/mod.rs index dfa4cde5..fc7420da 100644 --- a/src-tauri/src/database/mod.rs +++ b/src-tauri/src/database/mod.rs @@ -59,7 +59,7 @@ static DATABASE_PERMISSION_CHECK: Once = Once::new(); /// 当前 Schema 版本号 /// 每次修改表结构时递增,并在 schema.rs 中添加相应的迁移逻辑 -pub(crate) const SCHEMA_VERSION: i32 = 11; +pub(crate) const SCHEMA_VERSION: i32 = 12; fn database_open_flags() -> OpenFlags { OpenFlags::SQLITE_OPEN_READ_WRITE diff --git a/src-tauri/src/database/schema.rs b/src-tauri/src/database/schema.rs index e4126297..7df8e61f 100644 --- a/src-tauri/src/database/schema.rs +++ b/src-tauri/src/database/schema.rs @@ -264,6 +264,8 @@ impl Database { ) .map_err(|e| AppError::Database(e.to_string()))?; + Self::create_model_routes_table(conn)?; + // 尝试添加 live_takeover_active 列到 proxy_config 表 let _ = conn.execute( "ALTER TABLE proxy_config ADD COLUMN live_takeover_active INTEGER NOT NULL DEFAULT 0", @@ -413,6 +415,11 @@ impl Database { Self::migrate_v10_to_v11(conn)?; Self::set_user_version(conn, 11)?; } + 11 => { + log::info!("迁移数据库从 v11 到 v12(添加模型路由表和命中统计字段)"); + Self::migrate_v11_to_v12(conn)?; + Self::set_user_version(conn, 12)?; + } _ => { return Err(AppError::Database(format!( "未知的数据库版本 {version},无法迁移到 {SCHEMA_VERSION}" @@ -520,6 +527,17 @@ impl Database { "BOOLEAN NOT NULL DEFAULT 0", )?; + // model_routes 统计字段(cc-switch v12 未含,留作向后兼容 + 命中追踪) + if Self::table_exists(conn, "model_routes")? { + Self::add_column_if_missing( + conn, + "model_routes", + "hit_count", + "INTEGER NOT NULL DEFAULT 0", + )?; + Self::add_column_if_missing(conn, "model_routes", "last_hit_at", "TEXT")?; + } + // 添加代理超时配置字段 if Self::table_exists(conn, "proxy_config")? { // 兼容旧版本缺失的基础字段 @@ -1289,6 +1307,56 @@ impl Database { Ok(()) } + /// v11 -> v12 迁移:添加模型路由表和命中统计字段。 + fn migrate_v11_to_v12(conn: &Connection) -> Result<(), AppError> { + Self::create_model_routes_table(conn)?; + log::info!("v11 -> v12 迁移完成:已添加模型路由表和命中统计字段"); + Ok(()) + } + + fn create_model_routes_table(conn: &Connection) -> Result<(), AppError> { + conn.execute( + "CREATE TABLE IF NOT EXISTS model_routes ( + id TEXT PRIMARY KEY, + app_type TEXT NOT NULL, + pattern TEXT NOT NULL, + provider_id TEXT NOT NULL, + priority INTEGER NOT NULL DEFAULT 0, + enabled INTEGER NOT NULL DEFAULT 1, + hit_count INTEGER NOT NULL DEFAULT 0, + last_hit_at TEXT, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + updated_at TEXT NOT NULL DEFAULT (datetime('now')), + FOREIGN KEY (provider_id, app_type) REFERENCES providers(id, app_type) ON DELETE CASCADE + )", + [], + ) + .map_err(|e| AppError::Database(format!("创建 model_routes 表失败: {e}")))?; + + Self::add_column_if_missing( + conn, + "model_routes", + "hit_count", + "INTEGER NOT NULL DEFAULT 0", + )?; + Self::add_column_if_missing(conn, "model_routes", "last_hit_at", "TEXT")?; + + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_model_routes_lookup + ON model_routes(app_type, enabled, priority DESC, created_at ASC, id ASC)", + [], + ) + .map_err(|e| AppError::Database(format!("创建 model_routes lookup 索引失败: {e}")))?; + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_model_routes_provider + ON model_routes(provider_id, app_type)", + [], + ) + .map_err(|e| AppError::Database(format!("创建 model_routes provider 索引失败: {e}")))?; + + Ok(()) + } + /// 插入默认模型定价数据 /// 格式: (model_id, display_name, input, output, cache_read, cache_creation) /// 注意: model_id 使用短横线格式(如 claude-haiku-4-5),与 API 返回的模型名称标准化后一致 diff --git a/src-tauri/src/database/tests.rs b/src-tauri/src/database/tests.rs index eff1e382..6b0f54cf 100644 --- a/src-tauri/src/database/tests.rs +++ b/src-tauri/src/database/tests.rs @@ -4,6 +4,7 @@ use super::*; use crate::app_config::MultiAppConfig; +use crate::model_route::ModelRoute; use crate::prompt::Prompt; use crate::provider::{Provider, ProviderManager}; use indexmap::IndexMap; @@ -208,13 +209,13 @@ fn schema_migration_sets_user_version_when_missing() { fn schema_migration_rejects_future_version() { let conn = Connection::open_in_memory().expect("open memory db"); Database::create_tables_on_conn(&conn).expect("create tables"); - Database::set_user_version(&conn, SCHEMA_VERSION + 1).expect("set future version"); + Database::set_user_version(&conn, SCHEMA_VERSION + 2).expect("set future version"); let err = Database::apply_schema_migrations_on_conn(&conn).expect_err("should reject higher version"); let message = err.to_string(); assert!(message.contains("由较新版本的 CC Switch 创建")); - assert!(message.contains(&format!("数据库版本: {}", SCHEMA_VERSION + 1))); + assert!(message.contains(&format!("数据库版本: {}", SCHEMA_VERSION + 2))); assert!(message.contains(&format!("最高支持数据库版本: {SCHEMA_VERSION}"))); assert!(message.contains("cc-switch update")); } @@ -227,7 +228,7 @@ fn init_rejects_future_schema_before_creating_tables() { let _guard = ConfigDirEnvGuard::set(temp.path()); let db_path = temp.path().join("cc-switch.db"); let conn = Connection::open(&db_path).expect("open db"); - Database::set_user_version(&conn, SCHEMA_VERSION + 1).expect("set future version"); + Database::set_user_version(&conn, SCHEMA_VERSION + 2).expect("set future version"); drop(conn); let err = match Database::init() { @@ -2288,3 +2289,452 @@ fn model_pricing_upsert_rejects_invalid_values() { assert!(ModelPricingUpdate::new("bad-negative", "Bad Negative", "-1", "1", "0", "0").is_err()); assert!(ModelPricingUpdate::new("", "Blank Model", "1", "1", "0", "0").is_err()); } + +#[test] +fn schema_migration_v10_to_v12_adds_model_routes_table() { + let conn = Connection::open_in_memory().expect("open memory db"); + conn.execute_batch( + r#" + CREATE TABLE providers ( + id TEXT NOT NULL, + app_type TEXT NOT NULL, + name TEXT NOT NULL, + settings_config TEXT NOT NULL, + website_url TEXT, + category TEXT, + created_at INTEGER, + sort_index INTEGER, + notes TEXT, + icon TEXT, + icon_color TEXT, + meta TEXT NOT NULL DEFAULT '{}', + is_current BOOLEAN NOT NULL DEFAULT 0, + in_failover_queue BOOLEAN NOT NULL DEFAULT 0, + PRIMARY KEY (id, app_type) + ); + CREATE TABLE mcp_servers ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + server_config TEXT NOT NULL, + description TEXT, + homepage TEXT, + docs TEXT, + tags TEXT NOT NULL DEFAULT '[]', + enabled_claude BOOLEAN NOT NULL DEFAULT 0, + enabled_codex BOOLEAN NOT NULL DEFAULT 0, + enabled_gemini BOOLEAN NOT NULL DEFAULT 0, + enabled_opencode BOOLEAN NOT NULL DEFAULT 0, + enabled_hermes BOOLEAN NOT NULL DEFAULT 0 + ); + CREATE TABLE skills ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT, + directory TEXT NOT NULL, + repo_owner TEXT, + repo_name TEXT, + repo_branch TEXT DEFAULT 'main', + readme_url TEXT, + enabled_claude BOOLEAN NOT NULL DEFAULT 0, + enabled_codex BOOLEAN NOT NULL DEFAULT 0, + enabled_gemini BOOLEAN NOT NULL DEFAULT 0, + enabled_opencode BOOLEAN NOT NULL DEFAULT 0, + enabled_hermes BOOLEAN NOT NULL DEFAULT 0, + installed_at INTEGER NOT NULL DEFAULT 0, + content_hash TEXT, + updated_at INTEGER NOT NULL DEFAULT 0 + ); + CREATE TABLE prompts ( + id TEXT NOT NULL, + app_type TEXT NOT NULL, + name TEXT NOT NULL, + content TEXT NOT NULL, + description TEXT, + enabled BOOLEAN NOT NULL DEFAULT 1, + created_at INTEGER, + updated_at INTEGER, + PRIMARY KEY (id, app_type) + ); + CREATE TABLE skill_repos ( + owner TEXT NOT NULL, + name TEXT NOT NULL, + branch TEXT NOT NULL DEFAULT 'main', + enabled BOOLEAN NOT NULL DEFAULT 1, + PRIMARY KEY (owner, name) + ); + CREATE TABLE settings (key TEXT PRIMARY KEY, value TEXT); + CREATE TABLE proxy_config ( + app_type TEXT PRIMARY KEY CHECK (app_type IN ('claude','codex','gemini')), + proxy_enabled INTEGER NOT NULL DEFAULT 0, + listen_address TEXT NOT NULL DEFAULT '127.0.0.1', + listen_port INTEGER NOT NULL DEFAULT 15721, + enable_logging INTEGER NOT NULL DEFAULT 1, + enabled INTEGER NOT NULL DEFAULT 0, + auto_failover_enabled INTEGER NOT NULL DEFAULT 0, + max_retries INTEGER NOT NULL DEFAULT 3, + streaming_first_byte_timeout INTEGER NOT NULL DEFAULT 60, + streaming_idle_timeout INTEGER NOT NULL DEFAULT 120, + non_streaming_timeout INTEGER NOT NULL DEFAULT 600, + circuit_failure_threshold INTEGER NOT NULL DEFAULT 4, + circuit_success_threshold INTEGER NOT NULL DEFAULT 2, + circuit_timeout_seconds INTEGER NOT NULL DEFAULT 60, + circuit_error_rate_threshold REAL NOT NULL DEFAULT 0.6, + circuit_min_requests INTEGER NOT NULL DEFAULT 10, + default_cost_multiplier TEXT NOT NULL DEFAULT '1', + pricing_model_source TEXT NOT NULL DEFAULT 'response', + created_at TEXT NOT NULL DEFAULT (datetime('now')), + updated_at TEXT NOT NULL DEFAULT (datetime('now')) + ); + CREATE TABLE proxy_request_logs ( + request_id TEXT PRIMARY KEY, + provider_id TEXT NOT NULL, + app_type TEXT NOT NULL, + model TEXT NOT NULL, + request_model TEXT, + input_tokens INTEGER NOT NULL DEFAULT 0, + output_tokens INTEGER NOT NULL DEFAULT 0, + cache_read_tokens INTEGER NOT NULL DEFAULT 0, + cache_creation_tokens INTEGER NOT NULL DEFAULT 0, + input_cost_usd TEXT NOT NULL DEFAULT '0', + output_cost_usd TEXT NOT NULL DEFAULT '0', + cache_read_cost_usd TEXT NOT NULL DEFAULT '0', + cache_creation_cost_usd TEXT NOT NULL DEFAULT '0', + total_cost_usd TEXT NOT NULL DEFAULT '0', + latency_ms INTEGER NOT NULL, + first_token_ms INTEGER, + duration_ms INTEGER, + status_code INTEGER NOT NULL, + error_message TEXT, + session_id TEXT, + provider_type TEXT, + is_streaming INTEGER NOT NULL DEFAULT 0, + cost_multiplier TEXT NOT NULL DEFAULT '1.0', + created_at INTEGER NOT NULL, + data_source TEXT NOT NULL DEFAULT 'proxy' + ); + CREATE TABLE stream_check_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + provider_id TEXT NOT NULL, + provider_name TEXT NOT NULL, + app_type TEXT NOT NULL, + status TEXT NOT NULL, + success INTEGER NOT NULL, + message TEXT NOT NULL, + response_time_ms INTEGER, + http_status INTEGER, + model_used TEXT, + retry_count INTEGER DEFAULT 0, + tested_at INTEGER NOT NULL + ); + CREATE TABLE model_pricing ( + model_id TEXT PRIMARY KEY, + display_name TEXT NOT NULL, + input_cost_per_million TEXT NOT NULL, + output_cost_per_million TEXT NOT NULL, + cache_read_cost_per_million TEXT NOT NULL DEFAULT '0', + cache_creation_cost_per_million TEXT NOT NULL DEFAULT '0' + ); + INSERT INTO model_pricing ( + model_id, display_name, input_cost_per_million, output_cost_per_million, + cache_read_cost_per_million, cache_creation_cost_per_million + ) VALUES ('test-model', 'Test Model', '1.0', '2.0', '0.1', '0'); + CREATE TABLE proxy_live_backup ( + app_type TEXT PRIMARY KEY, + original_config TEXT NOT NULL, + backed_up_at TEXT NOT NULL + ); + CREATE TABLE usage_daily_rollups ( + date TEXT NOT NULL, + app_type TEXT NOT NULL, + provider_id TEXT NOT NULL, + model TEXT NOT NULL, + request_count INTEGER NOT NULL DEFAULT 0, + success_count INTEGER NOT NULL DEFAULT 0, + input_tokens INTEGER NOT NULL DEFAULT 0, + output_tokens INTEGER NOT NULL DEFAULT 0, + cache_read_tokens INTEGER NOT NULL DEFAULT 0, + cache_creation_tokens INTEGER NOT NULL DEFAULT 0, + total_cost_usd TEXT NOT NULL DEFAULT '0', + avg_latency_ms INTEGER NOT NULL DEFAULT 0, + PRIMARY KEY (date, app_type, provider_id, model) + ); + CREATE TABLE session_log_sync ( + file_path TEXT PRIMARY KEY, + last_modified INTEGER NOT NULL, + last_line_offset INTEGER NOT NULL DEFAULT 0, + last_synced_at INTEGER NOT NULL + ); + "#, + ) + .expect("seed v10 schema"); + + Database::set_user_version(&conn, 10).expect("set user_version=10"); + Database::apply_schema_migrations_on_conn(&conn).expect("apply migrations"); + + assert_eq!( + Database::get_user_version(&conn).expect("version after migration"), + SCHEMA_VERSION + ); + + assert!( + Database::table_exists(&conn, "model_routes").expect("check model_routes exists"), + "model_routes table should exist after v10 -> v12 migration" + ); + assert!( + Database::has_column(&conn, "model_routes", "pattern").expect("check pattern column"), + "model_routes.pattern column should exist" + ); + assert!( + Database::has_column(&conn, "model_routes", "priority").expect("check priority column"), + "model_routes.priority column should exist" + ); +} + +#[test] +fn model_route_dao_crud_roundtrip() { + let db = Database::memory().expect("create memory db"); + + // Seed a provider for FK validation + let conn = db.conn.lock().expect("lock conn"); + conn.execute( + "INSERT INTO providers (id, app_type, name, settings_config, meta) + VALUES ('test-prov', 'claude', 'Test Provider', '{}', '{}')", + [], + ) + .expect("seed provider"); + drop(conn); + + // Create + let created = db + .create_model_route(&ModelRoute { + id: String::new(), + app_type: "claude".into(), + pattern: "*-sonnet".into(), + provider_id: "test-prov".into(), + priority: 10, + enabled: true, + hit_count: 0, + last_hit_at: None, + created_at: None, + updated_at: None, + }) + .expect("create model route"); + + assert_eq!(created.id.len(), 36); // UUID v4 + assert_eq!(created.pattern, "*-sonnet"); + assert_eq!(created.provider_id, "test-prov"); + assert_eq!(created.priority, 10); + assert!(created.enabled); + assert!(created.created_at.is_some()); + + // Get by id + let got = db.get_model_route(&created.id).expect("get model route"); + assert!(got.is_some()); + assert_eq!(got.unwrap().pattern, "*-sonnet"); + + // Create second route + let second = db + .create_model_route(&ModelRoute { + id: String::new(), + app_type: "claude".into(), + pattern: "gpt-*".into(), + provider_id: "test-prov".into(), + priority: 20, + enabled: true, + hit_count: 0, + last_hit_at: None, + created_at: None, + updated_at: None, + }) + .expect("create second route"); + + // FK constraint: reject non-existent provider + let result = db.create_model_route(&ModelRoute { + id: String::new(), + app_type: "claude".into(), + pattern: "bad-*".into(), + provider_id: "nonexistent".into(), + priority: 1, + enabled: true, + hit_count: 0, + last_hit_at: None, + created_at: None, + updated_at: None, + }); + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("not found"), + "expected 'not found' error, got: {err_msg}" + ); + + // Update + let updated = db + .update_model_route( + &created.id, + &ModelRoute { + id: String::new(), + app_type: "claude".into(), + pattern: "claude-*".into(), + provider_id: "test-prov".into(), + priority: 5, + enabled: false, + hit_count: 0, + last_hit_at: None, + created_at: None, + updated_at: None, + }, + ) + .expect("update model route"); + + assert_eq!(updated.pattern, "claude-*"); + assert_eq!(updated.priority, 5); + assert!(!updated.enabled); + + // Toggle + let toggled_off = db.toggle_model_route(&created.id).expect("toggle off"); + assert!(toggled_off.enabled, "toggle off should re-enable"); + + let toggled_on = db.toggle_model_route(&created.id).expect("toggle on"); + assert!(!toggled_on.enabled, "toggle on should disable"); + + // Delete + db.delete_model_route(&created.id) + .expect("delete model route"); + let gone = db.get_model_route(&created.id).expect("get deleted route"); + assert!(gone.is_none()); + + // Clean up the second route + db.delete_model_route(&second.id) + .expect("delete second route"); + + // List ordering: create 3 routes with priorities 5, 1, 3 + db.create_model_route(&ModelRoute { + id: String::new(), + app_type: "claude".into(), + pattern: "mid".into(), + provider_id: "test-prov".into(), + priority: 5, + enabled: true, + hit_count: 0, + last_hit_at: None, + created_at: None, + updated_at: None, + }) + .expect("create priority 5"); + db.create_model_route(&ModelRoute { + id: String::new(), + app_type: "claude".into(), + pattern: "low".into(), + provider_id: "test-prov".into(), + priority: 1, + enabled: true, + hit_count: 0, + last_hit_at: None, + created_at: None, + updated_at: None, + }) + .expect("create priority 1"); + db.create_model_route(&ModelRoute { + id: String::new(), + app_type: "claude".into(), + pattern: "high".into(), + provider_id: "test-prov".into(), + priority: 3, + enabled: true, + hit_count: 0, + last_hit_at: None, + created_at: None, + updated_at: None, + }) + .expect("create priority 3"); + + let routes = db.list_model_routes("claude").expect("list routes"); + assert_eq!(routes.len(), 3); + assert_eq!(routes[0].priority, 1); + assert_eq!(routes[1].priority, 3); + assert_eq!(routes[2].priority, 5); + + // List filtering: create a codex route + let conn2 = db.conn.lock().expect("lock conn"); + conn2 + .execute( + "INSERT INTO providers (id, app_type, name, settings_config, meta) + VALUES ('codex-prov', 'codex', 'Codex Provider', '{}', '{}')", + [], + ) + .expect("seed codex provider"); + drop(conn2); + + db.create_model_route(&ModelRoute { + id: String::new(), + app_type: "codex".into(), + pattern: "*-codex".into(), + provider_id: "codex-prov".into(), + priority: 1, + enabled: true, + hit_count: 0, + last_hit_at: None, + created_at: None, + updated_at: None, + }) + .expect("create codex route"); + + let claude_routes = db.list_model_routes("claude").expect("list claude routes"); + assert_eq!(claude_routes.len(), 3, "only claude routes listed"); + + let codex_routes = db.list_model_routes("codex").expect("list codex routes"); + assert_eq!(codex_routes.len(), 1); + assert_eq!(codex_routes[0].pattern, "*-codex"); +} + +#[test] +fn model_route_cascade_delete_on_provider_removal() { + let db = Database::memory().expect("create memory db"); + + // Seed provider + let conn = db.conn.lock().expect("lock conn"); + conn.execute( + "INSERT INTO providers (id, app_type, name, settings_config, meta) + VALUES ('cascade-prov', 'claude', 'Cascade Provider', '{}', '{}')", + [], + ) + .expect("seed provider"); + drop(conn); + + // Create a model_route pointing to this provider + db.create_model_route(&ModelRoute { + id: String::new(), + app_type: "claude".into(), + pattern: "*-test".into(), + provider_id: "cascade-prov".into(), + priority: 1, + enabled: true, + hit_count: 0, + last_hit_at: None, + created_at: None, + updated_at: None, + }) + .expect("create model route"); + + assert_eq!( + db.list_model_routes("claude").expect("list routes").len(), + 1 + ); + + // Delete the provider — should cascade delete the model_route + let conn2 = db.conn.lock().expect("lock conn"); + conn2 + .execute( + "DELETE FROM providers WHERE id = 'cascade-prov' AND app_type = 'claude'", + [], + ) + .expect("delete provider"); + drop(conn2); + + let routes = db.list_model_routes("claude").expect("list after cascade"); + assert!( + routes.is_empty(), + "model_routes should be empty after provider cascade delete" + ); +} diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 0af995d7..4aa3edb0 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -18,6 +18,7 @@ mod import_export; #[allow(dead_code)] mod init_status; mod mcp; +mod model_route; mod openclaw_config; mod opencode_config; mod prompt; @@ -59,6 +60,7 @@ pub use mcp::{ sync_enabled_to_codex, sync_enabled_to_gemini, sync_single_server_to_claude, sync_single_server_to_codex, sync_single_server_to_gemini, }; +pub use model_route::ModelRoute; pub use provider::{Provider, ProviderMeta, UsageScript}; pub use proxy::{ProxyConfig, ProxyServerInfo, ProxyStatus}; pub use services::{ diff --git a/src-tauri/src/model_route.rs b/src-tauri/src/model_route.rs new file mode 100644 index 00000000..35ed6982 --- /dev/null +++ b/src-tauri/src/model_route.rs @@ -0,0 +1,53 @@ +//! 模型路由类型定义 (Model Route type definition) +//! +//! 定义 per-model provider routing 的数据结构,用于根据模型名称模式 +//! 将代理请求路由到不同的 provider。 + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelRoute { + pub id: String, + pub app_type: String, + pub pattern: String, + pub provider_id: String, + pub priority: i32, + pub enabled: bool, + pub hit_count: i64, + pub last_hit_at: Option, + pub created_at: Option, + pub updated_at: Option, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn model_route_serialization_roundtrip_camelcase() { + let route = ModelRoute { + id: "test-id-001".into(), + app_type: "claude".into(), + pattern: "*-sonnet".into(), + provider_id: "test-prov".into(), + priority: 10, + enabled: true, + hit_count: 0, + last_hit_at: None, + created_at: Some("2025-01-01 00:00:00".into()), + updated_at: Some("2025-01-01 00:00:00".into()), + }; + + let json = serde_json::to_string(&route).expect("serialize"); + assert!(json.contains("\"appType\""), "camelCase: {}", json); + assert!(json.contains("\"providerId\""), "camelCase: {}", json); + assert!(json.contains("\"createdAt\""), "camelCase: {}", json); + assert!(json.contains("\"updatedAt\""), "camelCase: {}", json); + + let deserialized: ModelRoute = serde_json::from_str(&json).expect("deserialize"); + assert_eq!(deserialized.id, "test-id-001"); + assert_eq!(deserialized.created_at, Some("2025-01-01 00:00:00".into())); + assert_eq!(deserialized.updated_at, Some("2025-01-01 00:00:00".into())); + } +} diff --git a/src-tauri/src/proxy/handler_context.rs b/src-tauri/src/proxy/handler_context.rs index 3eabc6c5..14e1e6ea 100644 --- a/src-tauri/src/proxy/handler_context.rs +++ b/src-tauri/src/proxy/handler_context.rs @@ -8,18 +8,36 @@ use crate::provider::Provider; use super::{ error::ProxyError, + model_mapper::provider_has_explicit_role_mapping, provider_router::ProviderRouter, - providers::gemini_shadow::GeminiShadowStore, server::ProxyServerState, session::extract_session_id, types::{AppProxyConfig, CopilotOptimizerConfig, OptimizerConfig, RectifierConfig}, }; +/// Extract the model identifier from a Gemini API path like +/// `/v1beta/models/gemini-2.5-pro:generateContent` or +/// `/v1/models/gemini-2.5-flash:streamGenerateContent`. Returns `None` if +/// the path does not match the expected `models/[:action]` shape. +fn extract_gemini_model_from_path(path: &str) -> Option { + // Find the "models/" segment and take what follows up to ":" or end. + let idx = path.find("/models/")?; + let after = &path[idx + "/models/".len()..]; + let end = after.find([':', '?', '/']).unwrap_or(after.len()); + let model = &after[..end]; + if model.is_empty() { + None + } else { + Some(model.to_string()) + } +} + pub struct HandlerContext { pub start_time: Instant, pub state: ProxyServerState, pub app_type: AppType, pub provider_router: Arc, + pub route_source: Option, providers: Vec, pub app_proxy: AppProxyConfig, pub rectifier_config: RectifierConfig, @@ -37,6 +55,7 @@ impl HandlerContext { app_type: AppType, headers: &HeaderMap, body: &Value, + path: &str, ) -> Result { let _ = crate::settings::reload_settings(); let current_provider_id_at_start = @@ -48,7 +67,73 @@ impl HandlerContext { let start_time = Instant::now(); let provider_router = state.provider_router.clone(); - let providers = provider_router.select_providers(app_type.as_str()).await?; + let model_router = state.model_router.clone(); + // Gemini 请求的 model 在 URI 路径中(如 /v1beta/models/gemini-2.5-pro:generateContent), + // 标准 Claude/Codex/OpenAI 请求的 model 在 JSON body 中。 + let request_model = body + .get("model") + .and_then(|value| value.as_str()) + .map(|s| s.to_string()) + .or_else(|| extract_gemini_model_from_path(path)) + .unwrap_or_else(|| "unknown".to_string()); + + let manual_provider = current_provider_id_at_start + .is_empty() + .then_some(None) + .unwrap_or_else(|| { + state + .db + .get_provider_by_id(¤t_provider_id_at_start, app_type.as_str()) + .ok() + .flatten() + }); + + // A manual Claude provider switch writes role-model mappings into live config + // (for example client-visible aliases mapped to provider-specific upstream + // models). Treat that selected provider as the user's active choice and let + // normal-priority automatic routes yield to it. + let manual_role_provider = if matches!(app_type, AppType::Claude) { + manual_provider + .clone() + .filter(|provider| provider_has_explicit_role_mapping(provider, &request_model)) + } else { + None + }; + + // Model route matching first. The router compares generic route priority + // against the active manual provider choice; it does not special-case model + // families or provider names. + let (providers, route_source) = match model_router + .match_route_respecting_manual_provider( + app_type.as_str(), + &request_model, + manual_role_provider.as_ref(), + ) + .await + { + Ok(Some((_route_id, provider))) => (vec![provider], Some("model_route".to_string())), + Ok(None) => { + if let Some(provider) = manual_role_provider { + // No model route matched — use manual role mapping as fallback + (vec![provider], Some("manual_provider_model".to_string())) + } else { + // RT-04: no match, fallback to existing ProviderRouter + let providers = provider_router.select_providers(app_type.as_str()).await?; + (providers, None) + } + } + Err(e) => { + if let Some(provider) = manual_role_provider { + log::warn!("model route lookup failed: {e}, using manual role mapping"); + (vec![provider], Some("manual_provider_model".to_string())) + } else { + // RT-05: match_route error (DB error), log warning and fallback + log::warn!("model route lookup failed: {e}, falling back to provider router"); + let providers = provider_router.select_providers(app_type.as_str()).await?; + (providers, None) + } + } + }; let app_proxy = state .db @@ -63,11 +148,6 @@ impl HandlerContext { let rectifier_config = state.db.get_rectifier_config().unwrap_or_default(); let optimizer_config = state.db.get_optimizer_config().unwrap_or_default(); let copilot_optimizer_config = state.db.get_copilot_optimizer_config().unwrap_or_default(); - let request_model = body - .get("model") - .and_then(|value| value.as_str()) - .unwrap_or("unknown") - .to_string(); let session_result = extract_session_id(headers, body, app_type.as_str()); Ok(Self { @@ -75,6 +155,7 @@ impl HandlerContext { state: state.clone(), app_type, provider_router, + route_source, providers, app_proxy, rectifier_config, @@ -134,11 +215,18 @@ mod tests { use serde_json::json; use serial_test::serial; + use std::collections::HashMap; use std::env; use tempfile::TempDir; use tokio::sync::RwLock; - use crate::{database::Database, proxy::types::ProxyConfig}; + use crate::{ + database::Database, + proxy::{ + model_router::ModelRouter, providers::gemini_shadow::GeminiShadowStore, + types::ProxyConfig, + }, + }; struct TempHome { #[allow(dead_code)] @@ -214,9 +302,11 @@ mod tests { status: Arc::new(RwLock::new(Default::default())), start_time: Arc::new(RwLock::new(None)), current_providers: Arc::new(RwLock::new(Default::default())), - provider_router: Arc::new(ProviderRouter::new(db)), + provider_router: Arc::new(ProviderRouter::new(db.clone())), + model_router: Arc::new(ModelRouter::new(db)), codex_chat_history: Arc::new(Default::default()), gemini_shadow: Arc::new(GeminiShadowStore::default()), + provider_token_map: Arc::new(RwLock::new(HashMap::new())), } } @@ -250,6 +340,7 @@ mod tests { AppType::Claude, &HeaderMap::new(), &json!({"model": "claude-3-7-sonnet-20250219"}), + "", ) .await .expect("load handler context"); @@ -290,6 +381,7 @@ mod tests { AppType::Claude, &HeaderMap::new(), &json!({"model": "claude-3-7-sonnet-20250219"}), + "", ) .await .expect("load handler context"); @@ -331,6 +423,7 @@ mod tests { AppType::Claude, &HeaderMap::new(), &json!({"model": "claude-3-7-sonnet-20250219"}), + "", ) .await }) @@ -349,4 +442,240 @@ mod tests { assert_eq!(context.providers()[0].id, "claude-failover"); assert_eq!(context.current_provider_id_at_start, "claude-current"); } + + #[tokio::test] + #[serial(home_settings)] + async fn model_route_match_bypasses_failover_queue() { + let _home = TempHome::new(); + let db = Arc::new(Database::memory().expect("create memory database")); + let current = test_provider("claude-current", 1); + let failover = test_provider("claude-failover", 0); + + db.save_provider("claude", ¤t) + .expect("save current provider"); + db.save_provider("claude", &failover) + .expect("save failover provider"); + db.set_current_provider("claude", ¤t.id) + .expect("set current provider"); + + // Enable auto failover so select_providers would normally return the queue + let mut config = db + .get_proxy_config_for_app("claude") + .await + .expect("read app proxy config"); + config.enabled = true; + config.auto_failover_enabled = true; + db.update_proxy_config_for_app(config) + .await + .expect("enable auto failover"); + + // Create model route: pattern "*sonnet*" → claude-current (priority 1) + use crate::model_route::ModelRoute; + let route = ModelRoute { + id: String::new(), + app_type: "claude".into(), + pattern: "*sonnet*".into(), + provider_id: "claude-current".into(), + priority: 1, + enabled: true, + hit_count: 0, + last_hit_at: None, + created_at: None, + updated_at: None, + }; + db.create_model_route(&route).expect("create model route"); + + let state = test_state(db); + let context = HandlerContext::load( + &state, + AppType::Claude, + &HeaderMap::new(), + &json!({"model": "claude-sonnet-4-6"}), + "", + ) + .await + .expect("load handler context"); + + // Model route matched — single provider, not the failover queue + assert_eq!(context.providers().len(), 1); + assert_eq!(context.providers()[0].id, "claude-current"); + assert_eq!(context.route_source, Some("model_route".to_string())); + } + + #[tokio::test] + #[serial(home_settings)] + async fn manual_role_mapping_beats_normal_priority_model_route() { + let _home = TempHome::new(); + let db = Arc::new(Database::memory().expect("create memory database")); + let mut current = test_provider("deepseek-current", 1); + current.name = "DeepSeek".to_string(); + current.settings_config = json!({ + "env": { + "ANTHROPIC_DEFAULT_OPUS_MODEL": "deepseek-v4-pro[1m]" + } + }); + let route_target = test_provider("pp-coder", 0); + + db.save_provider("claude", ¤t) + .expect("save current provider"); + db.save_provider("claude", &route_target) + .expect("save route target provider"); + db.set_current_provider("claude", ¤t.id) + .expect("set current provider"); + + let mut config = db + .get_proxy_config_for_app("claude") + .await + .expect("read app proxy config"); + config.enabled = true; + config.auto_failover_enabled = true; + db.update_proxy_config_for_app(config) + .await + .expect("enable auto failover"); + + use crate::model_route::ModelRoute; + let route = ModelRoute { + id: String::new(), + app_type: "claude".into(), + pattern: "*".into(), + provider_id: route_target.id.clone(), + priority: 0, + enabled: true, + hit_count: 0, + last_hit_at: None, + created_at: None, + updated_at: None, + }; + db.create_model_route(&route).expect("create model route"); + + let state = test_state(db); + let context = HandlerContext::load( + &state, + AppType::Claude, + &HeaderMap::new(), + &json!({"model": "claude-opus-4-8[1M]"}), + "", + ) + .await + .expect("load handler context"); + + // Normal-priority automatic routes are fallbacks. A manual provider with an + // explicit mapping must keep the request on the selected provider. + assert_eq!(context.providers().len(), 1); + assert_eq!(context.providers()[0].id, "deepseek-current"); + assert_eq!( + context.route_source, + Some("manual_provider_model".to_string()) + ); + } + + #[tokio::test] + #[serial(home_settings)] + async fn higher_priority_model_route_beats_manual_role_mapping() { + let _home = TempHome::new(); + let db = Arc::new(Database::memory().expect("create memory database")); + + // Manual provider: deepseek-current, with explicit opus role mapping + let mut current = test_provider("deepseek-current", 1); + current.name = "DeepSeek".to_string(); + current.settings_config = json!({ + "env": { + "ANTHROPIC_DEFAULT_OPUS_MODEL": "deepseek-v4-pro[1m]" + } + }); + + // Another provider that a *specific* model route should direct to + let specific_target = test_provider("specific-opus-prov", 0); + + db.save_provider("claude", ¤t) + .expect("save current provider"); + db.save_provider("claude", &specific_target) + .expect("save specific target provider"); + db.set_current_provider("claude", ¤t.id) + .expect("set current provider"); + + let mut config = db + .get_proxy_config_for_app("claude") + .await + .expect("read app proxy config"); + config.enabled = true; + config.auto_failover_enabled = true; + db.update_proxy_config_for_app(config) + .await + .expect("enable auto failover"); + + use crate::model_route::ModelRoute; + let specific_route = ModelRoute { + id: String::new(), + app_type: "claude".into(), + pattern: "*".into(), + provider_id: specific_target.id.clone(), + priority: -2, + enabled: true, + hit_count: 0, + last_hit_at: None, + created_at: None, + updated_at: None, + }; + db.create_model_route(&specific_route) + .expect("create specific model route"); + + let state = test_state(db); + let context = HandlerContext::load( + &state, + AppType::Claude, + &HeaderMap::new(), + &json!({"model": "claude-opus-4-8[1M]"}), + "", + ) + .await + .expect("load handler context"); + + // Routes with explicit higher priority can still win over manual selection. + assert_eq!(context.providers().len(), 1); + assert_eq!(context.providers()[0].id, "specific-opus-prov"); + assert_eq!(context.route_source, Some("model_route".to_string())); + } + + #[tokio::test] + #[serial(home_settings)] + async fn no_model_route_falls_back_to_provider_router() { + let _home = TempHome::new(); + let db = Arc::new(Database::memory().expect("create memory database")); + let current = test_provider("claude-current", 1); + let failover = test_provider("claude-failover", 0); + + db.save_provider("claude", ¤t) + .expect("save current provider"); + db.save_provider("claude", &failover) + .expect("save failover provider"); + db.set_current_provider("claude", ¤t.id) + .expect("set current provider"); + + let mut config = db + .get_proxy_config_for_app("claude") + .await + .expect("read app proxy config"); + config.enabled = true; + config.auto_failover_enabled = true; + db.update_proxy_config_for_app(config) + .await + .expect("enable auto failover"); + + // No model route matches "gemini-2.5-pro" + let state = test_state(db); + let context = HandlerContext::load( + &state, + AppType::Claude, + &HeaderMap::new(), + &json!({"model": "gemini-2.5-pro"}), + "", + ) + .await + .expect("load handler context"); + + // Falls back to normal ProviderRouter behavior (failover queue) + assert_eq!(context.providers()[0].id, "claude-failover"); + assert_eq!(context.route_source, None); + } } diff --git a/src-tauri/src/proxy/handlers.rs b/src-tauri/src/proxy/handlers.rs index fd7af934..6d7a8b0a 100644 --- a/src-tauri/src/proxy/handlers.rs +++ b/src-tauri/src/proxy/handlers.rs @@ -10,6 +10,8 @@ use std::time::{Duration, Instant}; use crate::{app_config::AppType, provider::Provider}; +use super::model_mapper::strip_one_m_suffix_for_upstream; + use super::{ error::ProxyError, forwarder::{ForwardOptions, RequestForwarder}, @@ -38,6 +40,160 @@ pub async fn get_status(State(state): State) -> impl IntoRespo Json(state.snapshot_status().await) } +/// Handle `GET /v1/models` — return merged model list from model routes +/// and provider env configs. +/// +/// Emits a protocol superset (Anthropic + OpenAI) so that both +/// Anthropic clients (via `ANTHROPIC_BASE_URL`) and OpenAI-style +/// clients can consume the response. +pub async fn handle_models(State(state): State) -> impl IntoResponse { + let db = state.db; + let app_type = "claude"; + + let mut model_ids: Vec = Vec::new(); + + // 1. Collect model names from all providers' env config + if let Ok(providers) = db.get_all_providers(app_type) { + for provider in providers.values() { + let env = provider.settings_config.get("env"); + if let Some(env) = env { + let keys = [ + "ANTHROPIC_DEFAULT_OPUS_MODEL", + "ANTHROPIC_DEFAULT_SONNET_MODEL", + "ANTHROPIC_DEFAULT_HAIKU_MODEL", + "ANTHROPIC_MODEL", + ]; + for key in &keys { + if let Some(val) = env + .get(*key) + .and_then(|v| v.as_str()) + .filter(|v| !v.is_empty()) + { + let cleaned = strip_one_m_suffix_for_upstream(val).to_string(); + if !model_ids.contains(&cleaned) { + model_ids.push(cleaned); + } + } + } + } + } + } + + // 2. Add standard Claude role models from route patterns + if let Ok(routes) = db.list_model_routes(app_type) { + for route in &routes { + if !route.enabled { + continue; + } + let pattern_lower = route.pattern.trim().to_ascii_lowercase(); + let standard_models = match pattern_lower.as_str() { + "*haiku*" | "haiku" => vec!["claude-haiku-4-5-20251001"], + "*sonnet*" | "sonnet" => vec!["claude-sonnet-4-6"], + "*opus*" | "opus" => vec!["claude-opus-4-8"], + _ => Vec::new(), + }; + for m in standard_models { + if !model_ids.contains(&m.to_string()) { + model_ids.push(m.to_string()); + } + } + } + } + + // 3. Build protocol superset: Anthropic + OpenAI fields + let data: Vec = model_ids + .iter() + .map(|id| { + let display_name = model_display_name(id); + json!({ + // Anthropic fields + "type": "model", + "display_name": display_name, + "created_at": "2025-01-01T00:00:00Z", + // OpenAI fields + "id": id, + "object": "model", + "created": 1700000000, + "owned_by": "cc-switch" + }) + }) + .collect(); + + let first_id = model_ids.first().cloned(); + let last_id = model_ids.last().cloned(); + + Json(json!({ + // Anthropic pagination + "type": "page", + "has_more": false, + "first_id": first_id, + "last_id": last_id, + // OpenAI + "object": "list", + "data": data + })) +} + +/// Map a model id to a human-readable display name for Anthropic's +/// `display_name` field on GET /v1/models. +fn model_display_name(id: &str) -> String { + // Some common well-known model patterns + let mapping: &[(&str, &str)] = &[ + ("claude-opus-4-8-20250514", "Claude 4.8 Opus"), + ("claude-opus-4-8", "Claude 4.8 Opus"), + ("claude-sonnet-4-6-20250514", "Claude 4.6 Sonnet"), + ("claude-sonnet-4-6", "Claude 4.6 Sonnet"), + ("claude-haiku-4-5-20251001", "Claude 4.5 Haiku"), + ("claude-haiku-4-5", "Claude 4.5 Haiku"), + ("claude-opus-4-5-20251101", "Claude 4.5 Opus"), + ("claude-opus-4-5", "Claude 4.5 Opus"), + ("claude-sonnet-4-5-20250915", "Claude 4.5 Sonnet"), + ("claude-sonnet-4-5", "Claude 4.5 Sonnet"), + ("claude-haiku-3-5-20250112", "Claude 3.5 Haiku"), + ("claude-haiku-3-5", "Claude 3.5 Haiku"), + ("deepseek-v4-pro", "DeepSeek V4 Pro"), + ("deepseek-v4", "DeepSeek V4"), + ("deepseek-v3-1", "DeepSeek V3.1"), + ("deepseek-v3", "DeepSeek V3"), + ("deepseek-r1", "DeepSeek R1"), + ("gpt-5", "GPT-5"), + ("gpt-5-mini", "GPT-5 Mini"), + ("gpt-5-nano", "GPT-5 Nano"), + ("gpt-4.1", "GPT-4.1"), + ("gpt-4.1-mini", "GPT-4.1 Mini"), + ("gpt-4.1-nano", "GPT-4.1 Nano"), + ("gemini-3.0-pro", "Gemini 3.0 Pro"), + ("gemini-2.5-pro", "Gemini 2.5 Pro"), + ("gemini-2.5-flash", "Gemini 2.5 Flash"), + ("gemini-2.5-flash-lite", "Gemini 2.5 Flash Lite"), + ("minimax-m2.5", "MiniMax M2.5"), + ("minimax-m1", "MiniMax M1"), + ("kimi-k2.5", "Kimi K2.5"), + ("kimi-k2", "Kimi K2"), + ("qwen3-coder", "Qwen3 Coder"), + ("qwen3-235b", "Qwen3 235B"), + ]; + + let id_lower = id.to_ascii_lowercase(); + for (pattern, name) in mapping { + if id_lower == *pattern { + return name.to_string(); + } + } + + // Fallback: title-case the segments + id.split('-') + .map(|seg| { + let mut chars = seg.chars(); + match chars.next() { + None => String::new(), + Some(first) => first.to_uppercase().chain(chars).collect(), + } + }) + .collect::>() + .join(" ") +} + pub async fn handle_messages( State(state): State, headers: HeaderMap, @@ -123,10 +279,11 @@ async fn handle_claude_request( headers: HeaderMap, body: Value, ) -> Response { + let estimated_input_tokens = estimate_tokens_from_value(&body); state - .record_estimated_input_tokens(estimate_tokens_from_value(&body)) + .record_estimated_input_tokens(estimated_input_tokens) .await; - let context = match HandlerContext::load(&state, AppType::Claude, &headers, &body).await { + let context = match HandlerContext::load(&state, AppType::Claude, &headers, &body, "").await { Ok(context) => context, Err(error) => { state.record_request_error(&error).await; @@ -209,6 +366,8 @@ async fn handle_claude_request( app_type: context.app_type.clone(), provider: forward_result.provider.clone(), current_provider_id_at_start: context.current_provider_id_at_start.clone(), + is_model_routed: context.route_source.as_deref() == Some("model_route"), + estimated_input_tokens, }); let first_byte_timeout = remaining_timeout(first_byte_timeout, request_started_at); let idle_timeout = context.streaming_idle_timeout(); @@ -326,6 +485,8 @@ async fn handle_claude_request( app_type: context.app_type.clone(), provider: provider.clone(), current_provider_id_at_start: context.current_provider_id_at_start.clone(), + is_model_routed: context.route_source.as_deref() == Some("model_route"), + estimated_input_tokens, }); let api_format = super::providers::get_claude_api_format(provider); let response_result = if adapter.needs_transform(provider) { @@ -465,10 +626,11 @@ async fn handle_passthrough_request( app_type: AppType, endpoint: String, ) -> Response { + let estimated_input_tokens = estimate_tokens_from_value(&body); state - .record_estimated_input_tokens(estimate_tokens_from_value(&body)) + .record_estimated_input_tokens(estimated_input_tokens) .await; - let context = match HandlerContext::load(&state, app_type, &headers, &body).await { + let context = match HandlerContext::load(&state, app_type, &headers, &body, &endpoint).await { Ok(context) => context, Err(error) => { state.record_request_error(&error).await; @@ -558,6 +720,8 @@ async fn handle_passthrough_request( app_type: context.app_type.clone(), provider: forward_result.provider.clone(), current_provider_id_at_start: context.current_provider_id_at_start.clone(), + is_model_routed: context.route_source.as_deref() == Some("model_route"), + estimated_input_tokens, }); let response_result = match response { super::forwarder::StreamingResponse::Live(response) @@ -667,6 +831,7 @@ async fn handle_passthrough_request( streaming_first_byte_timeout, non_streaming_timeout, codex_tool_context.unwrap_or_default(), + estimated_input_tokens, ) .await; } @@ -709,6 +874,8 @@ async fn handle_passthrough_request( app_type: context.app_type.clone(), provider: forward_result.provider.clone(), current_provider_id_at_start: context.current_provider_id_at_start.clone(), + is_model_routed: context.route_source.as_deref() == Some("model_route"), + estimated_input_tokens, }); let status = response.status; let request_log = Some(RequestLogContext::from_handler( @@ -930,6 +1097,7 @@ async fn finish_codex_live_aware_response( streaming_first_byte_timeout: Option, non_streaming_timeout: Option, tool_context: super::providers::transform_codex_chat::CodexToolContext, + estimated_input_tokens: u64, ) -> Response { let provider = forward_result.provider; let response = forward_result.response; @@ -938,6 +1106,8 @@ async fn finish_codex_live_aware_response( app_type: context.app_type.clone(), provider: provider.clone(), current_provider_id_at_start: context.current_provider_id_at_start.clone(), + is_model_routed: context.route_source.as_deref() == Some("model_route"), + estimated_input_tokens, }); if super::providers::should_convert_codex_responses_to_chat(&provider, endpoint) { @@ -1131,7 +1301,7 @@ fn remaining_timeout(timeout: Option, started_at: Instant) -> Option Option { + let model_lower = original_model.to_lowercase(); + + if model_lower.contains("haiku") { + return self.haiku_model.clone(); + } + if model_lower.contains("opus") { + return self.opus_model.clone(); + } + if model_lower.contains("sonnet") { + return self.sonnet_model.clone(); + } + + None + } +} + +pub fn provider_has_explicit_role_mapping(provider: &Provider, original_model: &str) -> bool { + let Some(mapped) = + ModelMapping::from_provider(provider).map_explicit_role_model(original_model) + else { + return false; + }; + + mapped.trim() != original_model.trim() } pub fn apply_model_mapping( @@ -186,4 +212,19 @@ mod tests { let result = strip_one_m_suffix_for_upstream_from_body(body); assert_eq!(result["model"], "deepseek-v4-pro"); } + + #[test] + fn detects_explicit_role_mapping_without_using_default_model() { + let mut provider = provider_with_mapping("deepseek-v4-pro [1M]"); + provider.settings_config["env"]["ANTHROPIC_MODEL"] = json!("default-model"); + + assert!(provider_has_explicit_role_mapping( + &provider, + "claude-sonnet-4-6[1M]" + )); + assert!(!provider_has_explicit_role_mapping( + &provider, + "some-custom-model" + )); + } } diff --git a/src-tauri/src/proxy/model_router.rs b/src-tauri/src/proxy/model_router.rs new file mode 100644 index 00000000..997dfdb6 --- /dev/null +++ b/src-tauri/src/proxy/model_router.rs @@ -0,0 +1,582 @@ +//! Model Router — per-model provider routing engine +//! +//! When a model route matches, the request uses the route-targeted provider only (single +//! provider, no failover queue). When no model route matches, the request falls back to +//! existing ProviderRouter logic. +//! +//! Wildcard * in pattern matches zero or more characters in model name, case-insensitively. +//! Multiple matching rules resolve to the one with lowest priority number (highest priority). +//! Disabled rules (enabled=false) are never matched. + +use std::sync::Arc; + +use regex::Regex; + +use crate::database::Database; +use crate::provider::Provider; + +use super::error::ProxyError; + +// Route priority uses lower numbers as higher priority. Manual provider +// selection outranks normal automatic routes (default 0), while an explicitly +// higher-priority route (< -1) can still override it. +const MANUAL_PROVIDER_PRIORITY: i32 = -1; + +pub struct ModelRouter { + db: Arc, +} + +impl ModelRouter { + pub fn new(db: Arc) -> Self { + Self { db } + } + + /// Match a model name against stored model routes for the given app_type. + /// + /// Routes are ordered by priority ASC (lowest number = highest priority). + /// The first enabled route whose pattern matches `model` wins. + /// Returns the matched (route_id, Provider) if found, or None if no route matches. + pub async fn match_route( + &self, + app_type: &str, + model: &str, + ) -> Result, ProxyError> { + self.match_route_internal(app_type, model, None).await + } + + pub async fn match_route_respecting_manual_provider( + &self, + app_type: &str, + model: &str, + manual_provider: Option<&Provider>, + ) -> Result, ProxyError> { + self.match_route_internal(app_type, model, manual_provider) + .await + } + + async fn match_route_internal( + &self, + app_type: &str, + model: &str, + manual_provider: Option<&Provider>, + ) -> Result, ProxyError> { + if model.is_empty() { + return Ok(None); + } + + let routes = self + .db + .list_model_routes(app_type) + .map_err(|e| ProxyError::DatabaseError(format!("list_model_routes: {e}")))?; + + for route in routes { + if !route.enabled { + continue; + } + if should_skip_route_for_manual_provider(route.priority, manual_provider) { + continue; + } + + let regex = match compile_pattern(&route.pattern) { + Ok(re) => re, + Err(_) => { + log::warn!( + "model route pattern '{}' is not a valid pattern, skipping", + route.pattern + ); + continue; + } + }; + + if regex.is_match(model) { + let provider_opt = self + .db + .get_provider_by_id(&route.provider_id, app_type) + .map_err(|e| ProxyError::DatabaseError(format!("get_provider_by_id: {e}")))?; + let Some(provider) = provider_opt else { + log::warn!( + "model route matched but provider '{}' not found for app '{}' (route={}, pattern={})", + route.provider_id, app_type, route.id, route.pattern + ); + continue; + }; + // 记录命中(异步 + spawn_blocking 避免阻塞) + let db = self.db.clone(); + let route_id = route.id.clone(); + let model_str = model.to_string(); + let pattern = route.pattern.clone(); + let provider_name = provider.name.clone(); + let provider_id = provider.id.clone(); + let app_type_owned = app_type.to_string(); + tokio::task::spawn_blocking(move || { + if let Err(e) = db.record_model_route_hit(&route_id) { + log::warn!("failed to record model_route hit: {e}"); + } else { + log::info!( + "model route matched: app={app_type_owned}, model={model_str}, pattern={pattern} → provider={provider_name} (id={provider_id})" + ); + } + }); + return Ok(Some((route.id, provider))); + } + } + + Ok(None) + } +} + +fn should_skip_route_for_manual_provider( + route_priority: i32, + manual_provider: Option<&Provider>, +) -> bool { + manual_provider.is_some() && route_priority >= MANUAL_PROVIDER_PRIORITY +} + +/// Compile a model route pattern into a case-insensitive regex. +/// +/// The only special character is `*`, which becomes `.*`. +/// All other characters are treated as literals (regex meta-characters are escaped). +/// Exact patterns (no `*`) are anchored with `^...$`. +fn compile_pattern(pattern: &str) -> Result { + if !pattern.contains('*') { + // Exact match — anchor and escape + let escaped = regex::escape(pattern); + return Regex::new(&format!("(?i)^{escaped}$")); + } + + // Split on *, escape each segment, join with .* and anchor at the start. + // ^ prevents substring matches (e.g. "claude-*" matching "xclaude-opus"). + // Patterns that do NOT end with '*' are also anchored at the end ($): a + // suffix rule like "*-4-5" then matches only ids ending in "-4-5" and not + // "claude-haiku-4-55". Patterns ending in '*' (e.g. "claude-*", "sonnet*") + // stay open-ended prefix matches; use "*sonnet*" to match a substring. + let ends_with_wild = pattern.ends_with('*'); + let segments: Vec<&str> = pattern.split('*').collect(); + let mut regex_str = String::from("(?i)^"); + for (i, segment) in segments.iter().enumerate() { + if i > 0 { + regex_str.push_str(".*"); + } + regex_str.push_str(®ex::escape(segment)); + } + if !ends_with_wild { + regex_str.push('$'); + } + + Regex::new(®ex_str) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::model_route::ModelRoute; + + fn seed_provider(db: &Database, app_type: &str, id: &str) { + // lock_conn! macro expands to a scope that uses AppError — we call the raw + // Mutex::lock to avoid requiring an AppError import here. + let guard = db.conn.lock().unwrap_or_else(|e| e.into_inner()); + guard + .execute( + "INSERT INTO providers (id, app_type, name, settings_config, meta) + VALUES (?1, ?2, ?3, '{}', '{}')", + rusqlite::params![id, app_type, id], + ) + .expect("seed provider"); + } + + fn test_route( + app_type: &str, + pattern: &str, + provider_id: &str, + priority: i32, + enabled: bool, + ) -> ModelRoute { + ModelRoute { + id: String::new(), + app_type: app_type.into(), + pattern: pattern.into(), + provider_id: provider_id.into(), + priority, + enabled, + hit_count: 0, + last_hit_at: None, + created_at: None, + updated_at: None, + } + } + + fn manual_provider(id: &str) -> Provider { + Provider { + id: id.to_string(), + name: id.to_string(), + settings_config: serde_json::json!({}), + website_url: None, + category: None, + created_at: None, + sort_index: None, + notes: None, + meta: None, + icon: None, + icon_color: None, + in_failover_queue: false, + } + } + + // --- Unit tests for compile_pattern --- + + #[test] + fn compile_pattern_exact_match() { + let re = compile_pattern("claude-sonnet-4-6").expect("compile exact pattern"); + assert!(re.is_match("claude-sonnet-4-6")); + assert!(!re.is_match("claude-sonnet-4-55")); + // Leading/trailing text should not match (anchored) + assert!(!re.is_match("prefix-claude-sonnet-4-6")); + } + + #[test] + fn compile_pattern_star_middle() { + let re = compile_pattern("*sonnet*").expect("compile *sonnet*"); + assert!(re.is_match("claude-sonnet-4-6")); + assert!(re.is_match("sonnet")); + assert!(!re.is_match("opus")); + } + + #[test] + fn compile_pattern_star_suffix() { + let re = compile_pattern("claude-*").expect("compile claude-*"); + assert!(re.is_match("claude-opus-4-8")); + assert!(!re.is_match("gemini-2.5-pro")); + // 锚定保证:前缀匹配,不可中间包含 + assert!(!re.is_match("xclaude-opus")); + } + + #[test] + fn compile_pattern_star_middle_anchored() { + // *sonnet* 加 ^ 锚定后,必须从开头匹配,但 .* 仍允许中间任意内容 + let re = compile_pattern("*sonnet*").expect("compile *sonnet*"); + assert!(re.is_match("sonnet")); + assert!(re.is_match("claude-sonnet-4-6")); + assert!(re.is_match("claude-sonnet")); + // 包含 "sonnet" 的都匹配(.*sonnet.* 语义) + assert!(re.is_match("claude- haikuxxsonnetyy")); + assert!(!re.is_match("claude-haiku-4-6")); + } + + #[test] + fn compile_pattern_prefix_anchor_prevents_substring() { + // claude-* 加 ^ 后,不再匹配 xclaude-opus + let re = compile_pattern("claude-*").expect("compile claude-*"); + assert!(re.is_match("claude-opus-4-8")); + assert!(re.is_match("claude-")); + assert!(!re.is_match("xclaude-opus")); + assert!(!re.is_match("gemini-2.5-pro")); + } + + #[test] + fn compile_pattern_star_prefix() { + let re = compile_pattern("*-4-5").expect("compile *-4-5"); + assert!(re.is_match("claude-haiku-4-5")); + assert!(re.is_match("deepseek-4-5")); + assert!(!re.is_match("claude-haiku-4-6")); + } + + #[test] + fn compile_pattern_regex_meta_chars_escaped() { + // + is a regex quantifier — should be treated as literal + let re = compile_pattern("gpt-4+").expect("compile gpt-4+"); + assert!(re.is_match("gpt-4+")); + assert!(!re.is_match("gpt-4")); + assert!(!re.is_match("gpt-4++")); + } + + // --- Integration tests for match_route (uses in-memory DB) --- + + #[tokio::test] + async fn test_match_route_exact_pattern() { + let db = Arc::new(Database::memory().expect("create memory database")); + seed_provider(&db, "claude", "prov-sonnet"); + + let route = test_route("claude", "claude-sonnet-4-6", "prov-sonnet", 1, true); + db.create_model_route(&route).expect("create route"); + + let router = ModelRouter::new(db); + let result = router + .match_route("claude", "claude-sonnet-4-6") + .await + .expect("match_route"); + assert!(result.is_some()); + assert_eq!(result.unwrap().1.id, "prov-sonnet"); + } + + #[tokio::test] + async fn test_match_route_star_sonnet_star() { + let db = Arc::new(Database::memory().expect("create memory database")); + seed_provider(&db, "claude", "prov-sonnet"); + + let route = test_route("claude", "*sonnet*", "prov-sonnet", 1, true); + db.create_model_route(&route).expect("create route"); + + let router = ModelRouter::new(db); + assert!(router + .match_route("claude", "claude-sonnet-4-6") + .await + .expect("match_route") + .is_some()); + assert!(router + .match_route("claude", "sonnet") + .await + .expect("match_route") + .is_some()); + } + + #[tokio::test] + async fn test_match_route_claude_star() { + let db = Arc::new(Database::memory().expect("create memory database")); + seed_provider(&db, "claude", "prov-claude"); + + let route = test_route("claude", "claude-*", "prov-claude", 1, true); + db.create_model_route(&route).expect("create route"); + + let router = ModelRouter::new(db); + assert!(router + .match_route("claude", "claude-opus-4-8") + .await + .expect("match_route") + .is_some()); + assert!(router + .match_route("claude", "gemini-2.5-pro") + .await + .expect("match_route") + .is_none()); + } + + #[tokio::test] + async fn test_match_route_star_suffix() { + let db = Arc::new(Database::memory().expect("create memory database")); + seed_provider(&db, "claude", "prov-45"); + + let route = test_route("claude", "*-4-5", "prov-45", 1, true); + db.create_model_route(&route).expect("create route"); + + let router = ModelRouter::new(db); + assert!(router + .match_route("claude", "claude-haiku-4-5") + .await + .expect("match_route") + .is_some()); + assert!(router + .match_route("claude", "deepseek-4-5") + .await + .expect("match_route") + .is_some()); + } + + #[tokio::test] + async fn test_match_route_star_suffix_rejects_partial() { + // Regression (Codex P2): "*-4-5" must not match "claude-haiku-4-55". + // Non-trailing-* suffix rules are anchored at the end, so a longer id + // that merely contains "-4-5" as a substring is not matched. + let db = Arc::new(Database::memory().expect("create memory database")); + seed_provider(&db, "claude", "prov-45"); + + let route = test_route("claude", "*-4-5", "prov-45", 1, true); + db.create_model_route(&route).expect("create route"); + + let router = ModelRouter::new(db); + assert!(router + .match_route("claude", "claude-haiku-4-55") + .await + .expect("match_route") + .is_none()); + } + + #[tokio::test] + async fn test_match_route_priority() { + let db = Arc::new(Database::memory().expect("create memory database")); + seed_provider(&db, "claude", "prov-high"); + seed_provider(&db, "claude", "prov-low"); + + // Higher priority (lower number) should win + let route_high = test_route("claude", "*sonnet*", "prov-high", 1, true); + let route_low = test_route("claude", "*sonnet*", "prov-low", 10, true); + db.create_model_route(&route_high) + .expect("create high-priority route"); + db.create_model_route(&route_low) + .expect("create low-priority route"); + + let router = ModelRouter::new(db); + let result = router + .match_route("claude", "claude-sonnet-4-6") + .await + .expect("match_route"); + assert!(result.is_some()); + assert_eq!(result.unwrap().1.id, "prov-high"); + } + + #[tokio::test] + async fn test_match_route_disabled_skipped() { + let db = Arc::new(Database::memory().expect("create memory database")); + seed_provider(&db, "claude", "prov-disabled"); + + let route = test_route("claude", "*sonnet*", "prov-disabled", 1, false); + db.create_model_route(&route) + .expect("create disabled route"); + + let router = ModelRouter::new(db); + assert!(router + .match_route("claude", "claude-sonnet-4-6") + .await + .expect("match_route") + .is_none()); + } + + #[tokio::test] + async fn test_match_route_no_match() { + let db = Arc::new(Database::memory().expect("create memory database")); + seed_provider(&db, "claude", "prov-specific"); + + let route = test_route("claude", "claude-*", "prov-specific", 1, true); + db.create_model_route(&route).expect("create route"); + + let router = ModelRouter::new(db); + assert!(router + .match_route("claude", "gemini-2.5-pro") + .await + .expect("match_route") + .is_none()); + } + + #[tokio::test] + async fn test_match_route_empty_model() { + let db = Arc::new(Database::memory().expect("create memory database")); + seed_provider(&db, "claude", "prov-any"); + + let route = test_route("claude", "*", "prov-any", 1, true); + db.create_model_route(&route).expect("create route"); + + let router = ModelRouter::new(db); + assert!(router + .match_route("claude", "") + .await + .expect("match_route") + .is_none()); + } + + #[tokio::test] + async fn test_match_route_case_insensitive() { + let db = Arc::new(Database::memory().expect("create memory database")); + seed_provider(&db, "claude", "prov-case"); + + let route = test_route("claude", "claude-sonnet-*", "prov-case", 1, true); + db.create_model_route(&route).expect("create route"); + + let router = ModelRouter::new(db); + let result = router + .match_route("claude", "CLAUDE-SONNET-4-6") + .await + .expect("match_route"); + assert!(result.is_some()); + assert_eq!(result.unwrap().1.id, "prov-case"); + } + + #[tokio::test] + async fn test_match_route_regex_meta_chars() { + let db = Arc::new(Database::memory().expect("create memory database")); + seed_provider(&db, "claude", "prov-meta"); + + // gpt-4+ has a literal + — the pattern's + is escaped, not a regex quantifier + let route = test_route("claude", "gpt-4+", "prov-meta", 1, true); + db.create_model_route(&route).expect("create route"); + + let router = ModelRouter::new(db); + let result = router + .match_route("claude", "gpt-4+") + .await + .expect("match_route"); + assert!(result.is_some()); + assert_eq!(result.unwrap().1.id, "prov-meta"); + } + + #[tokio::test] + async fn test_match_route_missing_provider() { + let db = Arc::new(Database::memory().expect("create memory database")); + + // FK constraint prevents create_model_route from referencing a non-existent + // provider. Disable foreign keys to insert a dangling route, then re-enable. + let guard = db.conn.lock().unwrap_or_else(|e| e.into_inner()); + guard + .execute_batch("PRAGMA foreign_keys = OFF") + .expect("disable foreign keys"); + guard + .execute( + "INSERT INTO model_routes (id, app_type, pattern, provider_id, priority, enabled) + VALUES (?1, ?2, ?3, ?4, ?5, ?6)", + rusqlite::params![ + uuid::Uuid::new_v4().to_string(), + "claude", + "*-missing", + "prov-missing", + 1, + true + ], + ) + .expect("insert dangling model route"); + guard + .execute_batch("PRAGMA foreign_keys = ON") + .expect("re-enable foreign keys"); + drop(guard); + + let router = ModelRouter::new(db); + let result = router + .match_route("claude", "claude-missing") + .await + .expect("match_route"); + // Provider doesn't exist — get_provider_by_id returns None + assert!(result.is_none()); + } + + #[tokio::test] + async fn normal_route_priority_yields_to_manual_provider() { + let db = Arc::new(Database::memory().expect("create memory database")); + seed_provider(&db, "claude", "automatic-provider"); + + let route = test_route("claude", "*", "automatic-provider", 0, true); + db.create_model_route(&route).expect("create route"); + + let router = ModelRouter::new(db); + let manual_provider = manual_provider("manually-selected"); + let result = router + .match_route_respecting_manual_provider( + "claude", + "any-request-model", + Some(&manual_provider), + ) + .await + .expect("match route"); + + assert!(result.is_none()); + } + + #[tokio::test] + async fn explicit_higher_priority_route_can_override_manual_provider() { + let db = Arc::new(Database::memory().expect("create memory database")); + seed_provider(&db, "claude", "explicit-route-provider"); + + let route = test_route("claude", "*", "explicit-route-provider", -2, true); + db.create_model_route(&route).expect("create route"); + + let router = ModelRouter::new(db); + let manual_provider = manual_provider("manually-selected"); + let result = router + .match_route_respecting_manual_provider( + "claude", + "any-request-model", + Some(&manual_provider), + ) + .await + .expect("match route") + .expect("higher-priority route should match"); + + assert_eq!(result.1.id, "explicit-route-provider"); + } +} diff --git a/src-tauri/src/proxy/response_handler.rs b/src-tauri/src/proxy/response_handler.rs index 99d9f592..3811ebb4 100644 --- a/src-tauri/src/proxy/response_handler.rs +++ b/src-tauri/src/proxy/response_handler.rs @@ -28,6 +28,13 @@ pub struct SuccessSyncInfo { pub app_type: AppType, pub provider: Provider, pub current_provider_id_at_start: String, + /// 当为 true 时,跳过 set_current_provider / update_live_backup, + /// 因为 provider 是模型路由命中选中的,不是用户主动切换的。 + pub is_model_routed: bool, + /// 该请求估算的 input token(请求入口算出,随请求带到此处按 provider 归类)。 + /// 用于让 per-provider 活动统计同时覆盖 input 流量,使点阵图 input/output 波形 + /// 都能正确按 provider 着色。 + pub estimated_input_tokens: u64, } impl ResponseHandler { @@ -55,6 +62,15 @@ impl ResponseHandler { state .record_estimated_output_tokens(estimated_output_tokens) .await; + if let Some(ref sync) = success_sync { + state + .record_provider_activity( + &sync.provider.id, + sync.estimated_input_tokens + .saturating_add(estimated_output_tokens), + ) + .await; + } if status.is_success() { if let Some(success_sync) = success_sync { state @@ -62,6 +78,7 @@ impl ResponseHandler { &success_sync.app_type, &success_sync.provider, &success_sync.current_provider_id_at_start, + success_sync.is_model_routed, ) .await; } @@ -281,12 +298,20 @@ impl StreamingOutcomeRecorder { state .record_estimated_output_tokens(estimated_output_tokens) .await; - if let Some(success_sync) = success_sync { + if let Some(ref sync) = success_sync { + state + .record_provider_activity( + &sync.provider.id, + sync.estimated_input_tokens + .saturating_add(estimated_output_tokens), + ) + .await; state .sync_successful_provider_selection( - &success_sync.app_type, - &success_sync.provider, - &success_sync.current_provider_id_at_start, + &sync.app_type, + &sync.provider, + &sync.current_provider_id_at_start, + sync.is_model_routed, ) .await; } @@ -306,12 +331,20 @@ impl StreamingOutcomeRecorder { state .record_estimated_output_tokens(estimated_output_tokens) .await; - if let Some(success_sync) = success_sync { + if let Some(ref sync) = success_sync { + state + .record_provider_activity( + &sync.provider.id, + sync.estimated_input_tokens + .saturating_add(estimated_output_tokens), + ) + .await; state .sync_successful_provider_selection( - &success_sync.app_type, - &success_sync.provider, - &success_sync.current_provider_id_at_start, + &sync.app_type, + &sync.provider, + &sync.current_provider_id_at_start, + sync.is_model_routed, ) .await; } diff --git a/src-tauri/src/proxy/response_handler/tests.rs b/src-tauri/src/proxy/response_handler/tests.rs index eb3b7977..98a1c987 100644 --- a/src-tauri/src/proxy/response_handler/tests.rs +++ b/src-tauri/src/proxy/response_handler/tests.rs @@ -16,8 +16,8 @@ use crate::{ database::Database, provider::Provider, proxy::{ - provider_router::ProviderRouter, providers::gemini_shadow::GeminiShadowStore, - types::ProxyConfig, + model_router::ModelRouter, provider_router::ProviderRouter, + providers::gemini_shadow::GeminiShadowStore, types::ProxyConfig, }, test_support::TestEnvGuard, }; @@ -52,9 +52,11 @@ fn test_state_with_db(db: Arc) -> ProxyServerState { status: Arc::new(RwLock::new(crate::proxy::types::ProxyStatus::default())), start_time: Arc::new(RwLock::new(None)), current_providers: Arc::new(RwLock::new(HashMap::new())), - provider_router: Arc::new(ProviderRouter::new(db)), + provider_router: Arc::new(ProviderRouter::new(db.clone())), + model_router: Arc::new(ModelRouter::new(db)), codex_chat_history: Arc::new(Default::default()), gemini_shadow: Arc::new(GeminiShadowStore::default()), + provider_token_map: Arc::new(RwLock::new(HashMap::new())), } } @@ -107,6 +109,61 @@ async fn buffered_failures_still_accumulate_output_tokens() { assert_eq!(snapshot.estimated_output_tokens_total, 9); } +#[tokio::test] +async fn buffered_success_records_input_and_output_tokens_per_provider() { + // provider_token_map 应同时累积 input + output token(按服务 provider 归类), + // 让点阵图 input/output 波形都能正确按 provider 着色。 + let state = test_state(); + state.record_request_start().await; + + let provider = test_provider_with_settings( + "zhipu", + "Zhipu", + json!({"apiKey": "zhipu-key", "base_url": "https://zhipu.example"}), + ); + let estimated_input_tokens: u64 = 4_000u64; + let estimated_output_tokens: u64 = 600u64; + + let response = PreparedResponse { + response: Response::builder() + .status(StatusCode::OK) + .body(Body::from("ok response body")) + .expect("response"), + stream_completion: None, + estimated_output_tokens, + upstream_error_summary: None, + body_bytes: Some(Bytes::from_static(b"ok response body")), + }; + + let _ = ResponseHandler::finish_buffered( + &state, + Ok(response), + reqwest::StatusCode::OK, + Some(SuccessSyncInfo { + app_type: AppType::Claude, + provider: provider.clone(), + current_provider_id_at_start: provider.id.clone(), + is_model_routed: false, + estimated_input_tokens, + }), + None, + ) + .await; + settle_tasks().await; + + let snapshot = state.snapshot_status().await; + let recorded = snapshot + .provider_token_map + .get(&provider.id) + .copied() + .unwrap_or(0); + assert!( + recorded >= estimated_input_tokens.saturating_add(estimated_output_tokens), + "provider_token_map should record input+output tokens (>= {}), got {recorded}", + estimated_input_tokens.saturating_add(estimated_output_tokens) + ); +} + #[tokio::test] async fn interrupted_streams_keep_partial_output_estimate() { let state = test_state(); @@ -273,6 +330,8 @@ async fn streaming_success_syncs_failover_state_after_body_drains() { app_type: AppType::Claude, provider: failover.clone(), current_provider_id_at_start: current.id.clone(), + is_model_routed: false, + estimated_input_tokens: 0, }), None, ) diff --git a/src-tauri/src/proxy/server.rs b/src-tauri/src/proxy/server.rs index 6c941612..72b1e5d9 100644 --- a/src-tauri/src/proxy/server.rs +++ b/src-tauri/src/proxy/server.rs @@ -19,6 +19,7 @@ use super::{ circuit_breaker::CircuitBreakerConfig, error::ProxyError, handlers, + model_router::ModelRouter, provider_router::ProviderRouter, providers::codex_chat_history::CodexChatHistoryStore, providers::gemini_shadow::GeminiShadowStore, @@ -35,8 +36,10 @@ pub struct ProxyServerState { pub start_time: Arc>>, pub current_providers: Arc>>, pub provider_router: Arc, + pub model_router: Arc, pub codex_chat_history: Arc, pub gemini_shadow: Arc, + pub provider_token_map: Arc>>, } impl ProxyServerState { @@ -61,6 +64,8 @@ impl ProxyServerState { active_targets.sort_by(|left, right| left.app_type.cmp(&right.app_type)); status.active_targets = active_targets; + status.provider_token_map = self.provider_token_map.read().await.clone(); + status } @@ -91,6 +96,14 @@ impl ProxyServerState { status.estimated_output_tokens_total.saturating_add(tokens); } + /// 按 provider 记录预估 token 数,用于仪表盘点阵图多色展示。 + /// 即使 token 估算为 0(非流式响应无 char_count 估算),也至少记录一次 + /// 命中计数为 1,避免点阵图因 estimated_output_tokens == 0 而完全空。 + pub async fn record_provider_activity(&self, provider_id: &str, tokens: u64) { + let mut map = self.provider_token_map.write().await; + *map.entry(provider_id.to_string()).or_default() += tokens.max(1); + } + pub async fn record_active_target(&self, app_type: &AppType, provider: &Provider) { self.current_providers.write().await.insert( app_type.as_str().to_string(), @@ -107,6 +120,7 @@ impl ProxyServerState { app_type: &AppType, provider: &Provider, current_provider_id_at_start: &str, + is_model_routed: bool, ) { self.record_active_target(app_type, provider).await; @@ -114,6 +128,12 @@ impl ProxyServerState { return; } + // 模型路由选中的 provider 不应切换当前 provider / 更新 live backup。 + // 路由命中是瞬态行为,不应覆盖用户主动选择的 provider。 + if is_model_routed { + return; + } + let takeover_enabled = self .db .get_proxy_config_for_app(app_type.as_str()) @@ -279,9 +299,11 @@ mod tests { status: Arc::new(RwLock::new(ProxyStatus::default())), start_time: Arc::new(RwLock::new(None)), current_providers: Arc::new(RwLock::new(HashMap::new())), - provider_router: Arc::new(ProviderRouter::new(db)), + provider_router: Arc::new(ProviderRouter::new(db.clone())), + model_router: Arc::new(ModelRouter::new(db)), codex_chat_history: Arc::new(CodexChatHistoryStore::default()), gemini_shadow: Arc::new(GeminiShadowStore::default()), + provider_token_map: Arc::new(RwLock::new(HashMap::new())), } } @@ -330,7 +352,7 @@ mod tests { let state = test_state(db.clone()); state - .sync_successful_provider_selection(&AppType::Claude, &failover, ¤t.id) + .sync_successful_provider_selection(&AppType::Claude, &failover, ¤t.id, false) .await; assert_eq!( @@ -381,7 +403,7 @@ mod tests { let state = test_state(db.clone()); state - .sync_successful_provider_selection(&AppType::Claude, ¤t, ¤t.id) + .sync_successful_provider_selection(&AppType::Claude, ¤t, ¤t.id, false) .await; assert_eq!( @@ -435,7 +457,7 @@ mod tests { let state = test_state(db.clone()); state - .sync_successful_provider_selection(&AppType::Claude, &failover, ¤t.id) + .sync_successful_provider_selection(&AppType::Claude, &failover, ¤t.id, false) .await; assert_eq!( @@ -491,6 +513,7 @@ pub struct ProxyServer { impl ProxyServer { pub fn new(config: ProxyConfig, db: Arc) -> Self { let provider_router = Arc::new(ProviderRouter::new(db.clone())); + let model_router = Arc::new(ModelRouter::new(db.clone())); let managed_session_token = std::env::var(PROXY_RUNTIME_SESSION_TOKEN_ENV_KEY) .ok() .filter(|value| !value.trim().is_empty()); @@ -507,8 +530,10 @@ impl ProxyServer { start_time: Arc::new(RwLock::new(None)), current_providers: Arc::new(RwLock::new(HashMap::new())), provider_router, + model_router, codex_chat_history: Arc::new(CodexChatHistoryStore::default()), gemini_shadow: Arc::new(GeminiShadowStore::default()), + provider_token_map: Arc::new(RwLock::new(HashMap::new())), }, shutdown_tx: Arc::new(RwLock::new(None)), server_handle: Arc::new(RwLock::new(None)), @@ -624,6 +649,7 @@ impl ProxyServer { Router::new() .route("/health", get(handlers::health_check)) .route("/status", get(handlers::get_status)) + .route("/v1/models", get(handlers::handle_models)) .route("/v1/messages", post(handlers::handle_messages)) .route("/claude/v1/messages", post(handlers::handle_messages)) .route("/chat/completions", post(handlers::handle_chat_completions)) diff --git a/src-tauri/src/proxy/types.rs b/src-tauri/src/proxy/types.rs index a0deed15..55423589 100644 --- a/src-tauri/src/proxy/types.rs +++ b/src-tauri/src/proxy/types.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap}; use serde::{Deserialize, Serialize}; @@ -103,6 +103,9 @@ pub struct ProxyStatus { /// 当前活跃的 daemon-managed worker 列表 #[serde(default)] pub active_workers: Vec, + /// 按 provider 聚合的预估 token 数(provider_id → token_count) + #[serde(default)] + pub provider_token_map: HashMap, } /// 活跃的 daemon-managed worker 信息