From 6e613991152f4be11a945cde7e52414036d1e725 Mon Sep 17 00:00:00 2001 From: DSLZL Date: Sun, 26 Apr 2026 01:22:21 +0800 Subject: [PATCH 1/8] chore: save local main workspace changes before sync Preserve local .gitignore edits before cherry-picking proxy integration commits from worktree. --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index 06305b9..0b2f9d5 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,9 @@ frontend/node_modules/ frontend/.next/ frontend/out/ *.tsbuildinfo +.omc/ +.spec-workflow/ +docs/superpowers/ # Local test sources stay out of git *_test.go @@ -41,3 +44,4 @@ frontend/out/ *.spec.tsx __tests__/ WEBUI_DEVELOPMENT_GUIDE.md +CLAUDE.md From a97a25677f06d8a60e6500231ea46ed33ee4fc8f Mon Sep 17 00:00:00 2001 From: DSLZL Date: Sun, 26 Apr 2026 01:21:44 +0800 Subject: [PATCH 2/8] feat(proxy): add account-scoped proxy and resin sticky routing Add proxy policy/resolver with N2A env precedence and wire account-aware proxy behavior through login, refresh, discovery, request dispatch, and browser fallback paths. Update tests and docs/config examples for HTTP/SOCKS5 and per-account Resin sticky proxy setup. 
--- README.md | 82 +++++- config.example.json | 60 ++-- internal/app/account_discovery.go | 5 +- internal/app/account_pool.go | 13 +- internal/app/admin_accounts.go | 6 +- internal/app/config.go | 272 ++++++++++++++++++ internal/app/conversations.go | 4 +- internal/app/login_helper.go | 27 +- internal/app/main.go | 2 +- internal/app/notion_client.go | 59 +++- .../app/notion_client_best_effort_test.go | 2 +- .../app/notion_client_browser_transport.go | 15 +- .../notion_client_browser_transport_test.go | 34 ++- internal/app/notion_client_protocol_test.go | 46 ++- internal/app/proxy_policy.go | 121 ++++++++ internal/app/proxy_resolver.go | 131 +++++++++ internal/app/request_dispatch.go | 18 +- internal/app/session_refresh.go | 3 +- 18 files changed, 828 insertions(+), 72 deletions(-) create mode 100644 internal/app/proxy_policy.go create mode 100644 internal/app/proxy_resolver.go diff --git a/README.md b/README.md index fce7709..a789cf4 100644 --- a/README.md +++ b/README.md @@ -15,16 +15,14 @@ ### 本地运行 -```powershell -Set-Location 'E:\WorkSpace\sub2api\chatgpt_register\Nation2API' -& 'D:\Go\bin\go.exe' run .\cmd\notion2api --config .\config.example.json +```bash +go run ./cmd/notion2api --config ./config.example.json ``` ### 本地构建 -```powershell -Set-Location 'E:\WorkSpace\sub2api\chatgpt_register\Nation2API' -& 'D:\Go\bin\go.exe' build .\cmd\notion2api +```bash +go build ./cmd/notion2api ``` ## Docker 部署 @@ -47,17 +45,75 @@ docker compose -f docker-compose.prod.yml up -d --build - Health:`http://127.0.0.1:8787/healthz` - WebUI:`http://127.0.0.1:8787/admin` +## 代理与 Resin 粘性代理 + +### 代理模式 + +`proxy_mode` 支持: + +- `off`:关闭代理 +- `env`:从环境变量读取(优先 `N2A_*`) +- `http`:固定 HTTP 代理 +- `https`:按协议拆分 HTTP/HTTPS 代理 +- `socks5`:SOCKS5/SOCKS5H 代理 +- `resin_forward`:Resin 粘性代理转发 + +### 环境变量优先级(`proxy_mode=env`) + +HTTPS 请求优先顺序: + +1. `N2A_PROXY_HTTPS_URL` +2. `N2A_UPSTREAM_PROXY_HTTPS_URL` +3. `N2A_PROXY_URL` +4. `N2A_UPSTREAM_PROXY_URL` +5. `HTTPS_PROXY` / `https_proxy` +6. 
`ALL_PROXY` / `all_proxy` + +HTTP 请求优先顺序: + +1. `N2A_PROXY_HTTP_URL` +2. `N2A_UPSTREAM_PROXY_HTTP_URL` +3. `N2A_PROXY_URL` +4. `N2A_UPSTREAM_PROXY_URL` +5. `HTTP_PROXY` / `http_proxy` +6. `ALL_PROXY` / `all_proxy` + +也可以直接用环境变量覆盖配置文件中的代理字段: + +- `N2A_PROXY_MODE` +- `N2A_PROXY_URL` +- `N2A_PROXY_HTTP_URL` +- `N2A_PROXY_HTTPS_URL` +- `N2A_RESIN_ENABLED` +- `N2A_RESIN_URL` +- `N2A_RESIN_PLATFORM` +- `N2A_RESIN_MODE` + +### Resin 粘性代理(按账号隔离) + +每个账号都可以独立设置粘性身份: + +- `accounts[].sticky_proxy_account`:显式设置粘性账号名(推荐) +- 未设置时会回退到邮箱派生值 + +当启用 `resin_forward` 时: + +- 代理认证用户名格式:`<resin_platform>.<sticky_proxy_account>` +- 密码使用 `resin_url` 中 token +- 请求会附带 `X-Resin-Account` 头 + ## 配置说明 建议优先检查这些字段: - `api_key`:OpenAI 兼容接口密钥 - `admin.password`:WebUI 登录密码 -- `upstream_base_url`:上游站点地址 -- `upstream_origin`:上游请求 `Origin` -- `accounts`:账号池配置 -- `active_account`:默认激活账号 -- `storage.sqlite_path`:SQLite 数据库路径 +- `upstream_base_url` / `upstream_origin` +- `proxy_mode` / `proxy_url` / `proxy_http_url` / `proxy_https_url` +- `resin_enabled` / `resin_url` / `resin_platform` / `resin_mode` +- `accounts[*].sticky_proxy_account` +- `accounts` / `active_account` +- `storage.sqlite_path` 可直接参考: @@ -67,8 +123,8 @@ docker compose -f docker-compose.prod.yml up -d --build ## 使用建议 - 首次启动后先访问 `/admin`,确认账号、配置和连通性是否正常 -- 常规本地使用直接运行二进制或 `go run` 即可 -- 需要容器化部署时优先使用 Docker Compose +- 修改管理台前端后需执行 `npm --prefix ./frontend run build:static` +- 调整会话延续与存储时,建议同步检查 `internal/app/sqlite_store.go` 的 schema 与迁移兼容性 ## 开源协议 diff --git a/config.example.json b/config.example.json index b185c0e..aebb2b4 100644 --- a/config.example.json +++ b/config.example.json @@ -1,13 +1,24 @@ { - "probe_json": "C:\\Users\\GALIAIS\\AppData\\Local\\Temp\\TestAdminAccountManualImportCreatesSessionFiles722969220\\001\\manual_example_com\\probe.json", + "probe_json": "probe_files/default/probe.json", "host": "127.0.0.1", - "port": 8791, - "api_key": "notion2api-local-key", + "port": 8787, + "api_key": "change-me-openai-key", "upstream_base_url": "https://www.notion.so", 
"upstream_origin": "https://www.notion.so", + "upstream_host_header": "", + "upstream_tls_server_name": "", + "upstream_use_env_proxy": false, + "proxy_mode": "off", + "proxy_url": "", + "proxy_http_url": "", + "proxy_https_url": "", + "resin_enabled": false, + "resin_url": "", + "resin_platform": "Default", + "resin_mode": "forward", "model_id": "auto", "default_model": "auto", - "active_account": "manual@example.com", + "active_account": "", "timeout_sec": 180, "poll_interval_sec": 1.5, "poll_max_rounds": 40, @@ -60,24 +71,35 @@ }, "accounts": [ { - "email": "manual@example.com", - "probe_json": "C:\\Users\\GALIAIS\\AppData\\Local\\Temp\\TestAdminAccountManualImportCreatesSessionFiles722969220\\001\\manual_example_com\\probe.json", - "profile_dir": "C:\\Users\\GALIAIS\\AppData\\Local\\Temp\\TestAdminAccountManualImportCreatesSessionFiles722969220\\001\\manual_example_com", - "storage_state_path": "C:\\Users\\GALIAIS\\AppData\\Local\\Temp\\TestAdminAccountManualImportCreatesSessionFiles722969220\\001\\manual_example_com\\storage_state.json", - "pending_state_path": "C:\\Users\\GALIAIS\\AppData\\Local\\Temp\\TestAdminAccountManualImportCreatesSessionFiles722969220\\001\\manual_example_com\\pending_login.json", - "user_id": "user-1", - "user_name": "manual", - "space_id": "space-1", - "space_name": "manual's Space", - "client_version": "23.13.test", - "status": "expired", - "last_error": "open C:\\Users\\GALIAIS\\AppData\\Local\\Temp\\TestAdminAccountManualImportCreatesSessionFiles722969220\\001\\manual_example_com\\probe.json: The system cannot find the path specified.", - "last_login_at": "2026-03-22T14:30:44+08:00" + "email": "alice@example.com", + "probe_json": "probe_files/notion_accounts/alice_example_com/probe.json", + "profile_dir": "probe_files/notion_accounts/alice_example_com", + "storage_state_path": "probe_files/notion_accounts/alice_example_com/storage_state.json", + "pending_state_path": 
"probe_files/notion_accounts/alice_example_com/pending_login.json", + "proxy_mode": "resin_forward", + "sticky_proxy_account": "alice", + "resin_enabled": true, + "resin_url": "http://127.0.0.1:2260/your-resin-token", + "resin_platform": "Default", + "resin_mode": "forward", + "priority": 100, + "disabled": false + }, + { + "email": "bob@example.com", + "probe_json": "probe_files/notion_accounts/bob_example_com/probe.json", + "profile_dir": "probe_files/notion_accounts/bob_example_com", + "storage_state_path": "probe_files/notion_accounts/bob_example_com/storage_state.json", + "pending_state_path": "probe_files/notion_accounts/bob_example_com/pending_login.json", + "proxy_mode": "socks5", + "proxy_url": "socks5://127.0.0.1:1080", + "priority": 50, + "disabled": false } ], "model_aliases": { - "gpt52": "gpt-5.2", - "gpt54": "gpt-5.4" + "gpt54": "gpt-5.4", + "gpt52": "gpt-5.2" } } diff --git a/internal/app/account_discovery.go b/internal/app/account_discovery.go index a90e42b..0de5322 100644 --- a/internal/app/account_discovery.go +++ b/internal/app/account_discovery.go @@ -158,14 +158,15 @@ func fetchAvailableModelsMetadata(ctx context.Context, client *http.Client, upst return parseProbeModelsBlob(string(raw)), nil } -func discoverImportedAccountMetadata(ctx context.Context, cfg AppConfig, cookies []ProbeCookie, fallback discoveredAccountMetadata) (discoveredAccountMetadata, error) { +func discoverImportedAccountMetadata(ctx context.Context, cfg AppConfig, accountEmail string, cookies []ProbeCookie, fallback discoveredAccountMetadata) (discoveredAccountMetadata, error) { meta := fallback cookies = normalizeProbeCookies(cookies) if len(cookies) == 0 { return meta, fmt.Errorf("cookies are required for auto-discovery") } upstream := cfg.NotionUpstream() - client, err := newNotionLoginHTTPClient(helperTimeout(cfg), upstream) + resolver := NewProxyResolver(cfg) + client, err := newNotionLoginHTTPClient(helperTimeout(cfg), upstream, resolver, accountEmail) if err != nil { 
return meta, err } diff --git a/internal/app/account_pool.go b/internal/app/account_pool.go index ca00c59..a287faf 100644 --- a/internal/app/account_pool.go +++ b/internal/app/account_pool.go @@ -223,6 +223,7 @@ func (s *ServerState) startAutoRelogin(ctx context.Context, cfg AppConfig, accou ProfileDir: account.ProfileDir, PendingPath: account.PendingStatePath, StorageStatePath: account.StorageStatePath, + AccountEmail: account.Email, }) account = mergeAccountWithStatus(cfg, account, status) account = markAccountReloginPending(account, now) @@ -236,13 +237,13 @@ func (s *ServerState) startAutoRelogin(ctx context.Context, cfg AppConfig, accou return cfg, fmt.Errorf("verification code required for %s; auto relogin started (%s)", account.Email, reason) } -func (a *App) runPromptWithSession(ctx context.Context, cfg AppConfig, session SessionInfo, request PromptRunRequest, onDelta func(string) error) (InferenceResult, error) { +func (a *App) runPromptWithSession(ctx context.Context, cfg AppConfig, session SessionInfo, accountEmail string, request PromptRunRequest, onDelta func(string) error) (InferenceResult, error) { if a.runPromptWithSessionOverride != nil { return a.runPromptWithSessionOverride(ctx, cfg, session, request, onDelta) } - client := newNotionAIClient(session, cfg) + client := newNotionAIClient(session, cfg, accountEmail) if onDelta != nil { - client = newNotionAIStreamingClient(session, cfg) + client = newNotionAIStreamingClient(session, cfg, accountEmail) } execute := func(ctx context.Context, current PromptRunRequest, forward func(string) error) (InferenceResult, error) { if forward == nil { @@ -253,16 +254,16 @@ func (a *App) runPromptWithSession(ctx context.Context, cfg AppConfig, session S return execute(ctx, request, onDelta) } -func (a *App) runPromptWithSessionWithSink(ctx context.Context, cfg AppConfig, session SessionInfo, request PromptRunRequest, sink InferenceStreamSink) (InferenceResult, error) { +func (a *App) 
runPromptWithSessionWithSink(ctx context.Context, cfg AppConfig, session SessionInfo, accountEmail string, request PromptRunRequest, sink InferenceStreamSink) (InferenceResult, error) { if a.runPromptWithSessionSinkOverride != nil { return a.runPromptWithSessionSinkOverride(ctx, cfg, session, request, sink) } if a.runPromptWithSessionOverride != nil { return a.runPromptWithSessionOverride(ctx, cfg, session, request, sink.Text) } - client := newNotionAIStreamingClient(session, cfg) + client := newNotionAIStreamingClient(session, cfg, accountEmail) if sink.Text == nil && sink.Reasoning == nil && sink.ReasoningWarmup == nil && sink.KeepAlive == nil { - client = newNotionAIClient(session, cfg) + client = newNotionAIClient(session, cfg, accountEmail) } if sink.Reasoning != nil || sink.ReasoningWarmup != nil || sink.KeepAlive != nil { return client.RunPromptStreamWithSink(ctx, request, sink) diff --git a/internal/app/admin_accounts.go b/internal/app/admin_accounts.go index b18dafd..dbc2a8b 100644 --- a/internal/app/admin_accounts.go +++ b/internal/app/admin_accounts.go @@ -456,7 +456,7 @@ func (a *App) handleAdminAccountsTest(w http.ResponseWriter, r *http.Request) { SuppressUpstreamThreadPersistence: true, } conversationID := a.beginConversation("", "admin_account_test", "account_test", prompt, request) - result, err := a.runPromptWithSession(ctx, cfg, session, request, nil) + result, err := a.runPromptWithSession(ctx, cfg, session, activeEmail, request, nil) if err != nil { a.failConversation(conversationID, err) writeAdminUpstreamError(w, err, map[string]any{"account": activeEmail}) @@ -602,7 +602,7 @@ func buildImportedSession(ctx context.Context, cfg AppConfig, req manualAccountI var discovered discoveredAccountMetadata var discoverErr error if shouldTryDiscovery { - discovered, discoverErr = discoverImportedAccountMetadata(ctx, cfg, probe.Cookies, discoveredAccountMetadata{ + discovered, discoverErr = discoverImportedAccountMetadata(ctx, cfg, probe.Email, 
probe.Cookies, discoveredAccountMetadata{ Email: probe.Email, UserID: probe.UserID, UserName: userName, @@ -786,6 +786,7 @@ func (a *App) handleAdminAccountLoginStart(w http.ResponseWriter, r *http.Reques ProfileDir: account.ProfileDir, PendingPath: account.PendingStatePath, StorageStatePath: account.StorageStatePath, + AccountEmail: account.Email, }) cfg, _, _ = a.State.Snapshot() @@ -848,6 +849,7 @@ func (a *App) handleAdminAccountLoginVerify(w http.ResponseWriter, r *http.Reque PendingPath: account.PendingStatePath, StorageStatePath: account.StorageStatePath, ProbePath: account.ProbeJSON, + AccountEmail: account.Email, }) cfg, _, _ = a.State.Snapshot() diff --git a/internal/app/config.go b/internal/app/config.go index 7a834c7..1fe44f4 100644 --- a/internal/app/config.go +++ b/internal/app/config.go @@ -95,6 +95,15 @@ type NotionAccount struct { LastSuccessAt string `json:"last_success_at,omitempty"` LastRefreshAt string `json:"last_refresh_at,omitempty"` LastReloginAt string `json:"last_relogin_at,omitempty"` + ProxyMode string `json:"proxy_mode,omitempty"` + ProxyURL string `json:"proxy_url,omitempty"` + ProxyHTTPURL string `json:"proxy_http_url,omitempty"` + ProxyHTTPSURL string `json:"proxy_https_url,omitempty"` + StickyProxyAccount string `json:"sticky_proxy_account,omitempty"` + ResinEnabled bool `json:"resin_enabled,omitempty"` + ResinURL string `json:"resin_url,omitempty"` + ResinPlatform string `json:"resin_platform,omitempty"` + ResinMode string `json:"resin_mode,omitempty"` ConsecutiveFailures int `json:"consecutive_failures,omitempty"` TotalSuccesses int `json:"total_successes,omitempty"` TotalFailures int `json:"total_failures,omitempty"` @@ -122,6 +131,14 @@ type AppConfig struct { UpstreamHost string `json:"upstream_host_header,omitempty"` UpstreamTLSServerName string `json:"upstream_tls_server_name,omitempty"` UpstreamUseEnvProxy bool `json:"upstream_use_env_proxy,omitempty"` + ProxyMode string `json:"proxy_mode,omitempty"` + ProxyURL string 
`json:"proxy_url,omitempty"` + ProxyHTTPURL string `json:"proxy_http_url,omitempty"` + ProxyHTTPSURL string `json:"proxy_https_url,omitempty"` + ResinEnabled bool `json:"resin_enabled,omitempty"` + ResinURL string `json:"resin_url,omitempty"` + ResinPlatform string `json:"resin_platform,omitempty"` + ResinMode string `json:"resin_mode,omitempty"` ModelID string `json:"model_id,omitempty"` DefaultModel string `json:"default_model,omitempty"` ActiveAccount string `json:"active_account,omitempty"` @@ -154,6 +171,198 @@ func defaultPromptCognitiveReframingPrefix() string { }, "\n") } +const ( + proxyModeOff = "off" + proxyModeEnv = "env" + proxyModeHTTP = "http" + proxyModeHTTPS = "https" + proxyModeSOCKS5 = "socks5" + proxyModeResinForward = "resin_forward" +) + +var supportedProxyModes = map[string]string{ + proxyModeOff: proxyModeOff, + proxyModeEnv: proxyModeEnv, + proxyModeHTTP: proxyModeHTTP, + proxyModeHTTPS: proxyModeHTTPS, + proxyModeSOCKS5: proxyModeSOCKS5, + proxyModeResinForward: proxyModeResinForward, +} + +func normalizeProxyMode(raw string) string { + mode := strings.ToLower(strings.TrimSpace(raw)) + if mode == "" { + return "" + } + if canonical, ok := supportedProxyModes[mode]; ok { + return canonical + } + return proxyModeOff +} + +func trimProxyFields(mode string, proxyURL string, proxyHTTPURL string, proxyHTTPSURL string, resinURL string, resinPlatform string, resinMode string) (string, string, string, string, string, string, string) { + return normalizeProxyMode(mode), strings.TrimSpace(proxyURL), strings.TrimSpace(proxyHTTPURL), strings.TrimSpace(proxyHTTPSURL), strings.TrimSpace(resinURL), strings.TrimSpace(resinPlatform), strings.TrimSpace(resinMode) +} + +func resolveProxyModeFromN2AEnv() string { + value := strings.TrimSpace(firstNonEmpty( + os.Getenv("N2A_PROXY_MODE"), + os.Getenv("N2A_UPSTREAM_PROXY_MODE"), + )) + if value == "" { + return "" + } + return normalizeProxyMode(value) +} + +func resolveProxyURLFromN2AEnv() string { + return 
strings.TrimSpace(firstNonEmpty( + os.Getenv("N2A_PROXY_URL"), + os.Getenv("N2A_UPSTREAM_PROXY_URL"), + )) +} + +func resolveProxyHTTPURLFromN2AEnv() string { + return strings.TrimSpace(firstNonEmpty( + os.Getenv("N2A_PROXY_HTTP_URL"), + os.Getenv("N2A_UPSTREAM_PROXY_HTTP_URL"), + )) +} + +func resolveProxyHTTPSURLFromN2AEnv() string { + return strings.TrimSpace(firstNonEmpty( + os.Getenv("N2A_PROXY_HTTPS_URL"), + os.Getenv("N2A_UPSTREAM_PROXY_HTTPS_URL"), + )) +} + +func parseBoolEnv(value string) (bool, bool) { + clean := strings.ToLower(strings.TrimSpace(value)) + switch clean { + case "1", "true", "yes", "on": + return true, true + case "0", "false", "no", "off": + return false, true + default: + return false, false + } +} + +func resolveResinEnabledFromN2AEnv() (bool, bool) { + for _, key := range []string{"N2A_RESIN_ENABLED", "N2A_PROXY_RESIN_ENABLED", "N2A_UPSTREAM_RESIN_ENABLED"} { + if parsed, ok := parseBoolEnv(os.Getenv(key)); ok { + return parsed, true + } + } + return false, false +} + +func resolveResinURLFromN2AEnv() string { + return strings.TrimSpace(firstNonEmpty( + os.Getenv("N2A_RESIN_URL"), + os.Getenv("N2A_PROXY_RESIN_URL"), + os.Getenv("N2A_UPSTREAM_RESIN_URL"), + )) +} + +func resolveResinPlatformFromN2AEnv() string { + return strings.TrimSpace(firstNonEmpty( + os.Getenv("N2A_RESIN_PLATFORM"), + os.Getenv("N2A_PROXY_RESIN_PLATFORM"), + os.Getenv("N2A_UPSTREAM_RESIN_PLATFORM"), + )) +} + +func resolveResinModeFromN2AEnv() string { + return strings.TrimSpace(firstNonEmpty( + os.Getenv("N2A_RESIN_MODE"), + os.Getenv("N2A_PROXY_RESIN_MODE"), + os.Getenv("N2A_UPSTREAM_RESIN_MODE"), + )) +} + +func applyN2AProxyEnv(cfg AppConfig) AppConfig { + if mode := resolveProxyModeFromN2AEnv(); mode != "" { + cfg.ProxyMode = mode + } + if value := resolveProxyURLFromN2AEnv(); value != "" { + cfg.ProxyURL = value + } + if value := resolveProxyHTTPURLFromN2AEnv(); value != "" { + cfg.ProxyHTTPURL = value + } + if value := resolveProxyHTTPSURLFromN2AEnv(); 
value != "" { + cfg.ProxyHTTPSURL = value + } + if enabled, ok := resolveResinEnabledFromN2AEnv(); ok { + cfg.ResinEnabled = enabled + } + if value := resolveResinURLFromN2AEnv(); value != "" { + cfg.ResinURL = value + } + if value := resolveResinPlatformFromN2AEnv(); value != "" { + cfg.ResinPlatform = value + } + if value := resolveResinModeFromN2AEnv(); value != "" { + cfg.ResinMode = value + } + return cfg +} + +func proxyEnvKeysForScheme(scheme string) []string { + if strings.EqualFold(strings.TrimSpace(scheme), "https") { + return []string{ + "N2A_PROXY_HTTPS_URL", + "N2A_UPSTREAM_PROXY_HTTPS_URL", + "N2A_PROXY_URL", + "N2A_UPSTREAM_PROXY_URL", + "HTTPS_PROXY", + "https_proxy", + "ALL_PROXY", + "all_proxy", + } + } + return []string{ + "N2A_PROXY_HTTP_URL", + "N2A_UPSTREAM_PROXY_HTTP_URL", + "N2A_PROXY_URL", + "N2A_UPSTREAM_PROXY_URL", + "HTTP_PROXY", + "http_proxy", + "ALL_PROXY", + "all_proxy", + } +} + +func resolveProxyURLForSchemeFromEnv(scheme string) string { + for _, key := range proxyEnvKeysForScheme(scheme) { + value := strings.TrimSpace(os.Getenv(key)) + if value != "" { + return value + } + } + return "" +} + +func (cfg AppConfig) normalizedProxyMode() string { + mode := normalizeProxyMode(cfg.ProxyMode) + if mode == "" { + if cfg.UpstreamUseEnvProxy { + return proxyModeEnv + } + return proxyModeOff + } + return mode +} + +func (cfg NotionAccount) normalizedProxyMode(appCfg AppConfig) string { + mode := normalizeProxyMode(cfg.ProxyMode) + if mode != "" { + return mode + } + return appCfg.normalizedProxyMode() +} + func defaultPromptToolboxCapabilityExpansionPrefix() string { return strings.Join([]string{ "You are a versatile general-purpose AI assistant accessed through a workspace integration.", @@ -189,6 +398,7 @@ func defaultConfig() AppConfig { Host: "127.0.0.1", Port: 8787, UpstreamBaseURL: "https://www.notion.so", + ProxyMode: proxyModeOff, ModelID: "auto", TimeoutSec: 180, PollIntervalSec: 1.5, @@ -261,6 +471,22 @@ func normalizeConfig(cfg 
AppConfig) AppConfig { cfg.UpstreamOrigin = normalizeBaseURL(firstNonEmpty(cfg.UpstreamOrigin, cfg.UpstreamBaseURL)) cfg.UpstreamHost = strings.TrimSpace(cfg.UpstreamHost) cfg.UpstreamTLSServerName = strings.TrimSpace(cfg.UpstreamTLSServerName) + rawProxyMode := strings.TrimSpace(cfg.ProxyMode) + cfg.ProxyMode, cfg.ProxyURL, cfg.ProxyHTTPURL, cfg.ProxyHTTPSURL, cfg.ResinURL, cfg.ResinPlatform, cfg.ResinMode = trimProxyFields( + cfg.ProxyMode, + cfg.ProxyURL, + cfg.ProxyHTTPURL, + cfg.ProxyHTTPSURL, + cfg.ResinURL, + cfg.ResinPlatform, + cfg.ResinMode, + ) + if cfg.ProxyMode == "" { + cfg.ProxyMode = proxyModeOff + } + if cfg.UpstreamUseEnvProxy && rawProxyMode == "" && cfg.ProxyMode == proxyModeOff { + cfg.ProxyMode = proxyModeEnv + } if cfg.Port <= 0 { cfg.Port = 8787 } @@ -351,6 +577,19 @@ func normalizeConfig(cfg AppConfig) AppConfig { cfg.Accounts[i].Status = strings.TrimSpace(cfg.Accounts[i].Status) cfg.Accounts[i].LastError = strings.TrimSpace(cfg.Accounts[i].LastError) cfg.Accounts[i].LastLoginAt = strings.TrimSpace(cfg.Accounts[i].LastLoginAt) + cfg.Accounts[i].ProxyMode, cfg.Accounts[i].ProxyURL, cfg.Accounts[i].ProxyHTTPURL, cfg.Accounts[i].ProxyHTTPSURL, cfg.Accounts[i].ResinURL, cfg.Accounts[i].ResinPlatform, cfg.Accounts[i].ResinMode = trimProxyFields( + cfg.Accounts[i].ProxyMode, + cfg.Accounts[i].ProxyURL, + cfg.Accounts[i].ProxyHTTPURL, + cfg.Accounts[i].ProxyHTTPSURL, + cfg.Accounts[i].ResinURL, + cfg.Accounts[i].ResinPlatform, + cfg.Accounts[i].ResinMode, + ) + cfg.Accounts[i].StickyProxyAccount = strings.TrimSpace(cfg.Accounts[i].StickyProxyAccount) + if cfg.Accounts[i].ProxyMode == "" { + cfg.Accounts[i].ProxyMode = cfg.normalizedProxyMode() + } cfg.Accounts[i] = ensureAccountPaths(cfg, cfg.Accounts[i]) } cfg.ProbeJSON = strings.TrimSpace(cfg.ProbeJSON) @@ -544,6 +783,14 @@ func parseCLI() AppConfig { upstreamHost := flag.String("upstream-host-header", "", "override Host header for upstream requests") upstreamTLSServerName := 
flag.String("upstream-tls-server-name", "", "override TLS SNI server name for upstream requests") upstreamUseEnvProxy := flag.Bool("upstream-use-env-proxy", false, "use HTTP(S)_PROXY/ALL_PROXY from environment for upstream requests") + proxyMode := flag.String("proxy-mode", "", "upstream proxy mode: off/env/http/https/socks5/resin_forward") + proxyURL := flag.String("proxy-url", "", "upstream proxy url") + proxyHTTPURL := flag.String("proxy-http-url", "", "upstream HTTP proxy url") + proxyHTTPSURL := flag.String("proxy-https-url", "", "upstream HTTPS proxy url") + resinEnabled := flag.Bool("resin-enabled", false, "enable resin forwarding") + resinURL := flag.String("resin-url", "", "resin forward url") + resinPlatform := flag.String("resin-platform", "", "resin platform") + resinMode := flag.String("resin-mode", "", "resin mode") modelID := flag.String("model", "", "default public model id") timeoutSec := flag.Int("timeout-sec", 0, "request timeout sec") pollIntervalSec := flag.Float64("poll-interval-sec", 0, "poll interval sec") @@ -580,6 +827,31 @@ func parseCLI() AppConfig { if strings.TrimSpace(*upstreamTLSServerName) != "" { cfg.UpstreamTLSServerName = *upstreamTLSServerName } + if strings.TrimSpace(*proxyMode) != "" { + cfg.ProxyMode = *proxyMode + } + if strings.TrimSpace(*proxyURL) != "" { + cfg.ProxyURL = *proxyURL + } + if strings.TrimSpace(*proxyHTTPURL) != "" { + cfg.ProxyHTTPURL = *proxyHTTPURL + } + if strings.TrimSpace(*proxyHTTPSURL) != "" { + cfg.ProxyHTTPSURL = *proxyHTTPSURL + } + if *resinEnabled { + cfg.ResinEnabled = true + } + if strings.TrimSpace(*resinURL) != "" { + cfg.ResinURL = *resinURL + } + if strings.TrimSpace(*resinPlatform) != "" { + cfg.ResinPlatform = *resinPlatform + } + if strings.TrimSpace(*resinMode) != "" { + cfg.ResinMode = *resinMode + } + cfg = applyN2AProxyEnv(cfg) if *upstreamUseEnvProxy { cfg.UpstreamUseEnvProxy = true } diff --git a/internal/app/conversations.go b/internal/app/conversations.go index 938d697..affb6d7 
100644 --- a/internal/app/conversations.go +++ b/internal/app/conversations.go @@ -1092,7 +1092,7 @@ func (a *App) notionClientForAccount(ctx context.Context, accountEmail string) ( } return nil, fmt.Errorf("load account session for %s: %w", email, err) } - return newNotionAIClient(session, cfg), nil + return newNotionAIClient(session, cfg, email), nil } } if fallbackClient != nil { @@ -1102,7 +1102,7 @@ func (a *App) notionClientForAccount(ctx context.Context, accountEmail string) ( if err != nil { return nil, err } - return newNotionAIClient(session, cfg), nil + return newNotionAIClient(session, cfg, ""), nil } func (a *App) deleteConversation(conversationID string) error { diff --git a/internal/app/login_helper.go b/internal/app/login_helper.go index d47d856..4c62ee8 100644 --- a/internal/app/login_helper.go +++ b/internal/app/login_helper.go @@ -29,6 +29,7 @@ type LoginStartRequest struct { ProfileDir string PendingPath string StorageStatePath string + AccountEmail string } type LoginVerifyRequest struct { @@ -38,6 +39,7 @@ type LoginVerifyRequest struct { PendingPath string StorageStatePath string ProbePath string + AccountEmail string } type loginStorageState struct { @@ -162,7 +164,7 @@ func writeLoginStorageState(path string, payload loginStorageState) error { return writePrettyJSONFile(path, payload) } -func newNotionLoginHTTPClient(timeout time.Duration, upstream NotionUpstream) (*http.Client, error) { +func newNotionLoginHTTPClient(timeout time.Duration, upstream NotionUpstream, resolver *ProxyResolver, accountEmail string) (*http.Client, error) { jar, err := cookiejar.New(nil) if err != nil { return nil, err @@ -171,12 +173,27 @@ func newNotionLoginHTTPClient(timeout time.Duration, upstream NotionUpstream) (* if strings.TrimSpace(upstream.TLSServerName) != "" { tlsConfig.ServerName = strings.TrimSpace(upstream.TLSServerName) } + proxyFunc := upstream.ProxyFunc() return &http.Client{ Timeout: timeout, Jar: jar, Transport: &http.Transport{ 
TLSClientConfig: tlsConfig, - Proxy: upstream.ProxyFunc(), + Proxy: func(req *http.Request) (*url.URL, error) { + if resolver != nil { + proxyURL, _, resolveErr := resolver.ResolveProxyForRequest(accountEmail, req.URL) + if resolveErr != nil { + return nil, resolveErr + } + if proxyURL != nil { + return proxyURL, nil + } + } + if proxyFunc == nil { + return nil, nil + } + return proxyFunc(req) + }, }, }, nil } @@ -628,8 +645,9 @@ func StartEmailLogin(ctx context.Context, cfg AppConfig, req LoginStartRequest) ctx, cancel := helperContext(ctx, cfg) defer cancel() upstream := cfg.NotionUpstream() + resolver := NewProxyResolver(cfg) - client, err := newNotionLoginHTTPClient(helperTimeout(cfg), upstream) + client, err := newNotionLoginHTTPClient(helperTimeout(cfg), upstream, resolver, firstNonEmpty(req.AccountEmail, req.Email)) if err != nil { return failLoginState(req.PendingPath, state, err) } @@ -696,8 +714,9 @@ func VerifyEmailLogin(ctx context.Context, cfg AppConfig, req LoginVerifyRequest ctx, cancel := helperContext(ctx, cfg) defer cancel() upstream := cfg.NotionUpstream() + resolver := NewProxyResolver(cfg) - client, err := newNotionLoginHTTPClient(helperTimeout(cfg), upstream) + client, err := newNotionLoginHTTPClient(helperTimeout(cfg), upstream, resolver, firstNonEmpty(req.AccountEmail, req.Email)) if err != nil { return failLoginState(req.PendingPath, pending, err) } diff --git a/internal/app/main.go b/internal/app/main.go index c35e69f..632ef3e 100644 --- a/internal/app/main.go +++ b/internal/app/main.go @@ -204,7 +204,7 @@ func (s *ServerState) ApplyConfig(cfg AppConfig) error { log.Printf("[startup] session bootstrap skipped for probe=%s active=%s: %v", probePath, activeEmail, err) } else { session = loadedSession - client = newNotionAIClient(loadedSession, cfg) + client = newNotionAIClient(loadedSession, cfg, activeEmail) if activeEmail != "" { cfg.ProbeJSON = loadedSession.ProbePath } diff --git a/internal/app/notion_client.go 
b/internal/app/notion_client.go index 8668ddb..38f6087 100644 --- a/internal/app/notion_client.go +++ b/internal/app/notion_client.go @@ -13,6 +13,7 @@ import ( "log" "mime/multipart" "net/http" + "net/url" "os" "path/filepath" "regexp" @@ -369,6 +370,8 @@ func (e *notionAPIError) Error() string { type NotionAIClient struct { Session SessionInfo Config AppConfig + AccountEmail string + ProxyResolver *ProxyResolver Timeout time.Duration PollInterval time.Duration PollMaxRounds int @@ -612,24 +615,40 @@ func (c *NotionAIClient) ensureSessionLiveMetadata(ctx context.Context) { _ = c.persistSessionProbe() } -func newNotionAIClient(session SessionInfo, cfg AppConfig) *NotionAIClient { - return newNotionAIClientWithMode(session, cfg, false) +func newNotionAIClient(session SessionInfo, cfg AppConfig, accountEmail string) *NotionAIClient { + return newNotionAIClientWithMode(session, cfg, accountEmail, false) } -func newNotionAIStreamingClient(session SessionInfo, cfg AppConfig) *NotionAIClient { - return newNotionAIClientWithMode(session, cfg, true) +func newNotionAIStreamingClient(session SessionInfo, cfg AppConfig, accountEmail string) *NotionAIClient { + return newNotionAIClientWithMode(session, cfg, accountEmail, true) } -func newNotionAIClientWithMode(session SessionInfo, cfg AppConfig, streaming bool) *NotionAIClient { +func newNotionAIClientWithMode(session SessionInfo, cfg AppConfig, accountEmail string, streaming bool) *NotionAIClient { normalizedCfg := normalizeConfig(cfg) + resolver := NewProxyResolver(normalizedCfg) upstream := normalizedCfg.NotionUpstream() tlsConfig := &tls.Config{InsecureSkipVerify: true} if strings.TrimSpace(upstream.TLSServerName) != "" { tlsConfig.ServerName = strings.TrimSpace(upstream.TLSServerName) } + proxyFunc := upstream.ProxyFunc() transport := &http.Transport{ TLSClientConfig: tlsConfig, - Proxy: upstream.ProxyFunc(), + Proxy: func(req *http.Request) (*url.URL, error) { + if resolver != nil { + proxyURL, _, err := 
resolver.ResolveProxyForRequest(accountEmail, req.URL) + if err != nil { + return nil, err + } + if proxyURL != nil { + return proxyURL, nil + } + } + if proxyFunc == nil { + return nil, nil + } + return proxyFunc(req) + }, } timeout := requestTimeout(normalizedCfg) clientTimeout := timeout @@ -640,6 +659,8 @@ func newNotionAIClientWithMode(session SessionInfo, cfg AppConfig, streaming boo return &NotionAIClient{ Session: session, Config: normalizedCfg, + AccountEmail: strings.TrimSpace(accountEmail), + ProxyResolver: resolver, Timeout: timeout, PollInterval: time.Duration(maxFloat(normalizedCfg.PollIntervalSec, 0.5) * float64(time.Second)), PollMaxRounds: maxInt(normalizedCfg.PollMaxRounds, 1), @@ -959,7 +980,6 @@ func (c *NotionAIClient) postJSONResponseWithReferer(ctx context.Context, url st if strings.TrimSpace(contentType) == "" { contentType = "application/json" } - requestContentType := "application/json" body, err := json.Marshal(payload) if err != nil { return nil, err @@ -983,8 +1003,18 @@ func (c *NotionAIClient) postJSONResponseWithReferer(ctx context.Context, url st } req.Header.Set(key, value) } - req.Header.Set("content-type", requestContentType) - c.captureDebugUpstreamRequest(url, headers, payload, body) + req.Header.Set("content-type", "application/json") + if c.ProxyResolver != nil { + if _, extraHeaders, resolveErr := c.ProxyResolver.ResolveProxyForRequest(c.AccountEmail, req.URL); resolveErr == nil { + for key, value := range extraHeaders { + if strings.TrimSpace(key) == "" || strings.TrimSpace(value) == "" { + continue + } + req.Header.Set(key, value) + } + } + } + c.captureDebugUpstreamRequestFromHeader(url, req.Header, payload, body) c.Config.NotionUpstream().ApplyHost(req) resp, err := c.HTTPClient.Do(req) if err != nil { @@ -1171,6 +1201,17 @@ func (c *NotionAIClient) captureDebugUpstreamRequest(url string, headers map[str } } +func (c *NotionAIClient) captureDebugUpstreamRequestFromHeader(url string, header http.Header, payload 
map[string]any, body []byte) { + headers := map[string]string{} + for key, values := range header { + if len(values) == 0 { + continue + } + headers[strings.ToLower(strings.TrimSpace(key))] = strings.TrimSpace(values[0]) + } + c.captureDebugUpstreamRequest(url, headers, payload, body) +} + func isoNowMillis() string { return time.Now().Format("2006-01-02T15:04:05.000Z07:00") } diff --git a/internal/app/notion_client_best_effort_test.go b/internal/app/notion_client_best_effort_test.go index 5248489..bb04158 100644 --- a/internal/app/notion_client_best_effort_test.go +++ b/internal/app/notion_client_best_effort_test.go @@ -24,7 +24,7 @@ func newBestEffortTestClient(baseURL string) *NotionAIClient { Name: "token_v2", Value: "test-cookie", }}, - }, cfg) + }, cfg, "") } func buildThreadErrorRecordMap(threadID string, spaceID string, messageID string, message string, subType string, traceID string) map[string]any { diff --git a/internal/app/notion_client_browser_transport.go b/internal/app/notion_client_browser_transport.go index e182804..6b6dc5d 100644 --- a/internal/app/notion_client_browser_transport.go +++ b/internal/app/notion_client_browser_transport.go @@ -596,12 +596,25 @@ func buildBrowserTransportRequest(client *NotionAIClient, payload map[string]any if originURL == "" { originURL = "https://www.notion.so" } + runURL := client.Config.NotionUpstream().API("runInferenceTranscript") headers := client.baseHeaders("application/x-ndjson", client.Config.NotionUpstream().AIURL()) delete(headers, "cookie") + if client.ProxyResolver != nil { + if parsedRunURL, err := url.Parse(runURL); err == nil { + if _, extraHeaders, resolveErr := client.ProxyResolver.ResolveProxyForRequest(client.AccountEmail, parsedRunURL); resolveErr == nil { + for key, value := range extraHeaders { + if strings.TrimSpace(key) == "" || strings.TrimSpace(value) == "" { + continue + } + headers[key] = value + } + } + } + } return browserTransportRequest{ OriginURL: originURL, AIURL: 
client.Config.NotionUpstream().AIURL(), - RunURL: client.Config.NotionUpstream().API("runInferenceTranscript"), + RunURL: runURL, Headers: headers, Payload: payload, Cookies: client.Session.Cookies, diff --git a/internal/app/notion_client_browser_transport_test.go b/internal/app/notion_client_browser_transport_test.go index 59273a3..c4d9179 100644 --- a/internal/app/notion_client_browser_transport_test.go +++ b/internal/app/notion_client_browser_transport_test.go @@ -275,6 +275,38 @@ func waitForBrowserHelperChildPID(t *testing.T, pidFile string) int { } time.Sleep(25 * time.Millisecond) } - t.Fatalf("timed out waiting for child pid file %s", pidFile) return 0 } + +func TestBuildBrowserTransportRequestAddsResinAccountHeaderWhenEnabled(t *testing.T) { + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.UpstreamBaseURL = "https://www.notion.so" + cfg.UpstreamOrigin = "https://www.notion.so" + cfg.ProxyMode = proxyModeResinForward + cfg.ResinEnabled = true + cfg.ResinURL = "http://127.0.0.1:2260/my-token" + cfg.ResinPlatform = "Default" + cfg.Accounts = []NotionAccount{{ + Email: "alice@example.com", + StickyProxyAccount: "alice", + }} + + client := newNotionAIClientWithMode(SessionInfo{ + ClientVersion: "test-client-version", + UserID: "test-user", + SpaceID: "test-space", + Cookies: []ProbeCookie{{ + Name: "token_v2", + Value: "test-cookie", + }}, + }, cfg, "alice@example.com", true) + + request, err := buildBrowserTransportRequest(client, map[string]any{"threadId": "thread-test"}) + if err != nil { + t.Fatalf("buildBrowserTransportRequest returned error: %v", err) + } + if got, want := request.Headers[defaultResinAccountHeader], "alice"; got != want { + t.Fatalf("%s = %q, want %q", defaultResinAccountHeader, got, want) + } +} diff --git a/internal/app/notion_client_protocol_test.go b/internal/app/notion_client_protocol_test.go index 8027047..6ee2765 100644 --- a/internal/app/notion_client_protocol_test.go +++ b/internal/app/notion_client_protocol_test.go 
@@ -28,7 +28,7 @@ func newProtocolTestClient(cfg AppConfig) *NotionAIClient { Name: "token_v2", Value: "test-cookie", }}, - }, cfg) + }, cfg, "") } func transcriptStepValue(t *testing.T, payload map[string]any, stepType string) map[string]any { @@ -173,3 +173,47 @@ func TestSaveContinuationScaffoldOmitsUnretryableErrorBehavior(t *testing.T) { t.Fatalf("expected saveTransactionsFanout payload to omit unretryable_error_behavior") } } + +func TestPostJSONResponseAddsResinAccountHeaderWhenEnabled(t *testing.T) { + capturedHeader := "" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedHeader = r.Header.Get(defaultResinAccountHeader) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{"ok": true}) + })) + defer server.Close() + + cfg := defaultConfig() + cfg.UpstreamBaseURL = server.URL + cfg.UpstreamOrigin = server.URL + cfg.ProxyMode = proxyModeResinForward + cfg.ResinEnabled = true + cfg.ResinURL = "http://127.0.0.1:2260/my-token" + cfg.ResinPlatform = "Default" + cfg.Accounts = []NotionAccount{{ + Email: "alice@example.com", + StickyProxyAccount: "alice", + }} + + client := newNotionAIClientWithMode(SessionInfo{ + ClientVersion: "test-client-version", + UserID: "test-user", + SpaceID: "test-space", + Cookies: []ProbeCookie{{ + Name: "token_v2", + Value: "test-cookie", + }}, + }, cfg, "alice@example.com", false) + client.HTTPClient.Transport = &http.Transport{} + client.AccountEmail = "alice@example.com" + + if _, err := client.postJSONResponse(context.Background(), server.URL+"/api/v3/markInferenceTranscriptSeen", map[string]any{ + "threadId": "thread-test", + "spaceId": "test-space", + }, "application/json"); err != nil { + t.Fatalf("postJSONResponse failed: %v", err) + } + if got, want := capturedHeader, "alice"; got != want { + t.Fatalf("%s = %q, want %q", defaultResinAccountHeader, got, want) + } +} diff --git a/internal/app/proxy_policy.go 
b/internal/app/proxy_policy.go new file mode 100644 index 0000000..cc7dd35 --- /dev/null +++ b/internal/app/proxy_policy.go @@ -0,0 +1,121 @@ +package app + +import "strings" + +const ( + defaultResinAccountHeader = "X-Resin-Account" + resinModeForward = "forward" +) + +type ResinPolicy struct { + Enabled bool + URL string + Platform string + Mode string + AccountHeader string +} + +type ProxyPolicy struct { + Mode string + URL string + HTTPURL string + HTTPSURL string + Resin ResinPolicy +} + +func normalizeResinMode(raw string) string { + mode := strings.ToLower(strings.TrimSpace(raw)) + switch mode { + case "", resinModeForward: + return mode + default: + return mode + } +} + +func (cfg AppConfig) ResolveProxyPolicy() ProxyPolicy { + mode := cfg.normalizedProxyMode() + if mode == proxyModeOff && cfg.ResinEnabled { + mode = proxyModeResinForward + } + policy := ProxyPolicy{ + Mode: mode, + URL: strings.TrimSpace(cfg.ProxyURL), + HTTPURL: strings.TrimSpace(cfg.ProxyHTTPURL), + HTTPSURL: strings.TrimSpace(cfg.ProxyHTTPSURL), + Resin: ResinPolicy{ + Enabled: cfg.ResinEnabled, + URL: strings.TrimSpace(cfg.ResinURL), + Platform: strings.TrimSpace(cfg.ResinPlatform), + Mode: normalizeResinMode(cfg.ResinMode), + AccountHeader: defaultResinAccountHeader, + }, + } + if policy.Resin.Mode == "" { + policy.Resin.Mode = resinModeForward + } + if policy.Mode == proxyModeResinForward { + policy.Resin.Enabled = true + } + if policy.Mode == proxyModeEnv { + policy.HTTPURL = firstNonEmpty(resolveProxyURLForSchemeFromEnv("http"), policy.HTTPURL, policy.URL) + policy.HTTPSURL = firstNonEmpty(resolveProxyURLForSchemeFromEnv("https"), policy.HTTPSURL, policy.URL) + } else { + policy.HTTPURL = firstNonEmpty(policy.HTTPURL, policy.URL) + policy.HTTPSURL = firstNonEmpty(policy.HTTPSURL, policy.URL) + } + return policy +} + +func (cfg AppConfig) ResolveProxyPolicyForAccount(email string) ProxyPolicy { + policy := cfg.ResolveProxyPolicy() + account, _, ok := cfg.FindAccount(email) + if !ok 
{ + return policy + } + if mode := normalizeProxyMode(account.ProxyMode); mode != "" { + policy.Mode = mode + } + if value := strings.TrimSpace(account.ProxyURL); value != "" { + policy.URL = value + } + if value := strings.TrimSpace(account.ProxyHTTPURL); value != "" { + policy.HTTPURL = value + } + if value := strings.TrimSpace(account.ProxyHTTPSURL); value != "" { + policy.HTTPSURL = value + } + if account.ResinEnabled { + policy.Resin.Enabled = true + } + if value := strings.TrimSpace(account.ResinURL); value != "" { + policy.Resin.URL = value + } + if value := strings.TrimSpace(account.ResinPlatform); value != "" { + policy.Resin.Platform = value + } + if value := normalizeResinMode(account.ResinMode); value != "" { + policy.Resin.Mode = value + } + if policy.Resin.Mode == "" { + policy.Resin.Mode = resinModeForward + } + if policy.Mode == proxyModeResinForward { + policy.Resin.Enabled = true + } + if policy.Mode == proxyModeEnv { + policy.HTTPURL = firstNonEmpty(resolveProxyURLForSchemeFromEnv("http"), policy.HTTPURL, policy.URL) + policy.HTTPSURL = firstNonEmpty(resolveProxyURLForSchemeFromEnv("https"), policy.HTTPSURL, policy.URL) + } else { + policy.HTTPURL = firstNonEmpty(policy.HTTPURL, policy.URL) + policy.HTTPSURL = firstNonEmpty(policy.HTTPSURL, policy.URL) + } + return policy +} + +func (p ProxyPolicy) proxyURLForScheme(scheme string) string { + if strings.EqualFold(strings.TrimSpace(scheme), "https") { + return strings.TrimSpace(firstNonEmpty(p.HTTPSURL, p.URL)) + } + return strings.TrimSpace(firstNonEmpty(p.HTTPURL, p.URL)) +} diff --git a/internal/app/proxy_resolver.go b/internal/app/proxy_resolver.go new file mode 100644 index 0000000..5aa51ba --- /dev/null +++ b/internal/app/proxy_resolver.go @@ -0,0 +1,131 @@ +package app + +import ( + "fmt" + "net/url" + "strings" +) + +type ProxyResolver struct { + cfg AppConfig +} + +func NewProxyResolver(cfg AppConfig) *ProxyResolver { + return &ProxyResolver{cfg: normalizeConfig(cfg)} +} + +func (r 
*ProxyResolver) ResolveProxyForRequest(accountEmail string, target *url.URL) (*url.URL, map[string]string, error) { + if r == nil { + return nil, nil, nil + } + if target == nil { + return nil, nil, nil + } + policy := r.cfg.ResolveProxyPolicyForAccount(accountEmail) + mode := normalizeProxyMode(policy.Mode) + if mode == "" { + mode = proxyModeOff + } + headers := map[string]string{} + switch mode { + case proxyModeOff: + return nil, nil, nil + case proxyModeEnv, proxyModeHTTP, proxyModeHTTPS, proxyModeSOCKS5: + raw := policy.proxyURLForScheme(target.Scheme) + if strings.TrimSpace(raw) == "" { + return nil, nil, nil + } + parsed, err := parseProxyURL(raw) + if err != nil { + return nil, nil, err + } + return parsed, nil, nil + case proxyModeResinForward: + proxyURL, stickyAccount, err := resolveResinForwardProxyURL(policy, accountEmail, r.cfg) + if err != nil { + return nil, nil, err + } + if stickyAccount != "" { + headers[policy.Resin.AccountHeader] = stickyAccount + } + if len(headers) == 0 { + return proxyURL, nil, nil + } + return proxyURL, headers, nil + default: + return nil, nil, nil + } +} + +func parseProxyURL(raw string) (*url.URL, error) { + clean := strings.TrimSpace(raw) + if clean == "" { + return nil, nil + } + parsed, err := url.Parse(clean) + if err != nil { + return nil, fmt.Errorf("parse proxy url %q: %w", clean, err) + } + scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme)) + switch scheme { + case "http", "https", "socks5", "socks5h": + return parsed, nil + default: + return nil, fmt.Errorf("unsupported proxy scheme %q", parsed.Scheme) + } +} + +func resolveResinForwardProxyURL(policy ProxyPolicy, email string, cfg AppConfig) (*url.URL, string, error) { + if !policy.Resin.Enabled { + return nil, "", nil + } + baseURL, token, err := splitResinURL(policy.Resin.URL) + if err != nil { + return nil, "", err + } + if token == "" { + return nil, "", fmt.Errorf("resin token missing") + } + platform := strings.TrimSpace(policy.Resin.Platform) + 
if platform == "" { + platform = "Default" + } + stickyAccount := resinStickyAccountForEmail(cfg, email) + if stickyAccount == "" { + stickyAccount = "account" + } + username := fmt.Sprintf("%s.%s", platform, stickyAccount) + proxyURL := *baseURL + proxyURL.User = url.UserPassword(username, token) + return &proxyURL, stickyAccount, nil +} + +func splitResinURL(raw string) (*url.URL, string, error) { + parsed, err := parseProxyURL(raw) + if err != nil { + return nil, "", err + } + token := strings.Trim(strings.TrimSpace(parsed.Path), "/") + if token == "" && parsed.User != nil { + token, _ = parsed.User.Password() + } + baseURL := *parsed + baseURL.Path = "" + baseURL.RawPath = "" + baseURL.User = nil + baseURL.RawQuery = "" + baseURL.Fragment = "" + return &baseURL, token, nil +} + +func resinStickyAccountForEmail(cfg AppConfig, email string) string { + if account, _, ok := cfg.FindAccount(email); ok { + if value := strings.TrimSpace(account.StickyProxyAccount); value != "" { + return value + } + if value := accountPathSlug(account.Email); value != "" { + return value + } + } + return accountPathSlug(email) +} diff --git a/internal/app/request_dispatch.go b/internal/app/request_dispatch.go index f048b17..fd94585 100644 --- a/internal/app/request_dispatch.go +++ b/internal/app/request_dispatch.go @@ -122,7 +122,7 @@ func (a *App) probeAccountProtocolHealth(ctx context.Context, cfg AppConfig, ses } probeCtx, cancel := context.WithTimeout(ctx, dispatchProtocolProbeTimeout(cfg)) defer cancel() - client := newNotionAIClient(session, cfg) + client := newNotionAIClient(session, cfg, "") _, err := client.listInferenceTranscripts(probeCtx) if isDispatchContextAbort(probeCtx, err) { return nil @@ -188,7 +188,7 @@ func (a *App) runPromptActiveFallback(r *http.Request, request PromptRunRequest, return onDelta(delta) } - result, err := a.runPromptWithSession(ctx, cfg, session, request, wrappedDelta) + result, err := a.runPromptWithSession(ctx, cfg, session, "", request, 
wrappedDelta) if err == nil { return result, nil } @@ -199,7 +199,7 @@ func (a *App) runPromptActiveFallback(r *http.Request, request PromptRunRequest, if probeErr := a.probeAccountProtocolHealth(ctx, cfg, refreshed); probeErr != nil { return InferenceResult{}, probeErr } - return a.runPromptWithSession(ctx, cfg, refreshed, request, wrappedDelta) + return a.runPromptWithSession(ctx, cfg, refreshed, "", request, wrappedDelta) } } } @@ -240,7 +240,7 @@ func (a *App) runPromptActiveFallbackWithSink(r *http.Request, request PromptRun return sink.EmitKeepAlive() } - result, err := a.runPromptWithSessionWithSink(ctx, cfg, session, request, InferenceStreamSink{ + result, err := a.runPromptWithSessionWithSink(ctx, cfg, session, "", request, InferenceStreamSink{ Text: wrappedText, Reasoning: wrappedReasoning, ReasoningWarmup: wrappedReasoningWarmup, @@ -256,7 +256,7 @@ func (a *App) runPromptActiveFallbackWithSink(r *http.Request, request PromptRun if probeErr := a.probeAccountProtocolHealth(ctx, cfg, refreshed); probeErr != nil { return InferenceResult{}, probeErr } - return a.runPromptWithSessionWithSink(ctx, cfg, refreshed, request, InferenceStreamSink{ + return a.runPromptWithSessionWithSink(ctx, cfg, refreshed, "", request, InferenceStreamSink{ Text: wrappedText, Reasoning: wrappedReasoning, ReasoningWarmup: wrappedReasoningWarmup, @@ -303,7 +303,7 @@ func (a *App) runPromptWithAccountPool(r *http.Request, request PromptRunRequest account := markAccountDispatchStart(original, time.Now()) session, err := a.loadReadyDispatchSession(ctx, cfg, account) if err == nil { - result, runErr := a.runPromptWithSession(ctx, cfg, session, request, wrappedDelta) + result, runErr := a.runPromptWithSession(ctx, cfg, session, account.Email, request, wrappedDelta) if runErr == nil { result.AccountEmail = account.Email account.UserID = firstNonEmpty(session.UserID, account.UserID) @@ -335,7 +335,7 @@ func (a *App) runPromptWithAccountPool(r *http.Request, request PromptRunRequest if ok { 
refreshedSession, loadErr := a.loadReadyDispatchSession(ctx, cfg, refreshedAccount) if loadErr == nil { - result, retryErr := a.runPromptWithSession(ctx, cfg, refreshedSession, request, wrappedDelta) + result, retryErr := a.runPromptWithSession(ctx, cfg, refreshedSession, refreshedAccount.Email, request, wrappedDelta) if retryErr == nil { result.AccountEmail = refreshedAccount.Email refreshedAccount.UserID = firstNonEmpty(refreshedSession.UserID, refreshedAccount.UserID) @@ -432,7 +432,7 @@ func (a *App) runPromptWithAccountPoolWithSink(r *http.Request, request PromptRu account := markAccountDispatchStart(original, time.Now()) session, err := a.loadReadyDispatchSession(ctx, cfg, account) if err == nil { - result, runErr := a.runPromptWithSessionWithSink(ctx, cfg, session, request, InferenceStreamSink{ + result, runErr := a.runPromptWithSessionWithSink(ctx, cfg, session, account.Email, request, InferenceStreamSink{ Text: wrappedText, Reasoning: wrappedReasoning, ReasoningWarmup: wrappedReasoningWarmup, @@ -468,7 +468,7 @@ func (a *App) runPromptWithAccountPoolWithSink(r *http.Request, request PromptRu if refreshedAccount, _, ok := cfg.FindAccount(account.Email); ok { refreshedSession, loadErr := a.loadReadyDispatchSession(ctx, cfg, refreshedAccount) if loadErr == nil { - result, retryErr := a.runPromptWithSessionWithSink(ctx, cfg, refreshedSession, request, InferenceStreamSink{ + result, retryErr := a.runPromptWithSessionWithSink(ctx, cfg, refreshedSession, refreshedAccount.Email, request, InferenceStreamSink{ Text: wrappedText, Reasoning: wrappedReasoning, ReasoningWarmup: wrappedReasoningWarmup, diff --git a/internal/app/session_refresh.go b/internal/app/session_refresh.go index 23bafcd..9e1d010 100644 --- a/internal/app/session_refresh.go +++ b/internal/app/session_refresh.go @@ -80,7 +80,8 @@ func loadSessionInfoForAccountRefresh(cfg AppConfig, account NotionAccount) (Ses func buildRefreshedSession(ctx context.Context, cfg AppConfig, account NotionAccount, prior 
SessionInfo) (SessionInfo, error) { upstream := cfg.NotionUpstream() - client, err := newNotionLoginHTTPClient(helperTimeout(cfg), upstream) + resolver := NewProxyResolver(cfg) + client, err := newNotionLoginHTTPClient(helperTimeout(cfg), upstream, resolver, account.Email) if err != nil { return SessionInfo{}, err } From fe3766399742b968a527614f0bcd0221ca3dc198 Mon Sep 17 00:00:00 2001 From: DSLZL Date: Sun, 26 Apr 2026 01:33:49 +0800 Subject: [PATCH 3/8] update --- .github/pull.yml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .github/pull.yml diff --git a/.github/pull.yml b/.github/pull.yml new file mode 100644 index 0000000..ecae9fa --- /dev/null +++ b/.github/pull.yml @@ -0,0 +1,6 @@ +version: "1" +rules: + - base: main + upstream: GALIAIS:main # change `wei` to the owner of upstream repo + mergeMethod: merge + mergeUnstable: true \ No newline at end of file From dc7d24c27bf5c8bdb0974f0a1442887e1cee9f9e Mon Sep 17 00:00:00 2001 From: DSLZL Date: Sun, 26 Apr 2026 02:03:50 +0800 Subject: [PATCH 4/8] update --- docker-compose.yml | 27 ++++++++++++++++++++------- internal/app/proxy_resolver.go | 9 ++++----- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 5514c3d..bd35f7f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,14 +2,27 @@ services: notion2api: build: context: . 
- dockerfile: Dockerfile - container_name: notion2api - restart: unless-stopped + dockerfile: ${N2A_DOCKERFILE:-Dockerfile} + image: ${N2A_IMAGE:-notion2api:latest} + container_name: ${N2A_CONTAINER_NAME:-notion2api} + restart: ${N2A_RESTART_POLICY:-unless-stopped} ports: - - "8787:8787" + - "${N2A_PORT:-8787}:8787" volumes: - - ./config.docker.json:/app/config/config.json - - ./data:/app/data + - ${N2A_CONFIG_FILE:-./config.docker.json}:/app/config/config.json:ro + - ${N2A_DATA_DIR:-./data}:/app/data environment: - TZ: Asia/Shanghai + TZ: ${TZ:-Asia/Shanghai} + N2A_PROXY_MODE: ${N2A_PROXY_MODE:-} + N2A_PROXY_URL: ${N2A_PROXY_URL:-} + N2A_PROXY_HTTP_URL: ${N2A_PROXY_HTTP_URL:-} + N2A_PROXY_HTTPS_URL: ${N2A_PROXY_HTTPS_URL:-} + N2A_RESIN_ENABLED: ${N2A_RESIN_ENABLED:-} + N2A_RESIN_URL: ${N2A_RESIN_URL:-} + N2A_RESIN_PLATFORM: ${N2A_RESIN_PLATFORM:-} + N2A_RESIN_MODE: ${N2A_RESIN_MODE:-} + HTTP_PROXY: ${HTTP_PROXY:-} + HTTPS_PROXY: ${HTTPS_PROXY:-} + ALL_PROXY: ${ALL_PROXY:-} + NO_PROXY: ${NO_PROXY:-} command: ["./notion2api", "--config", "/app/config/config.json"] diff --git a/internal/app/proxy_resolver.go b/internal/app/proxy_resolver.go index 5aa51ba..40db4dc 100644 --- a/internal/app/proxy_resolver.go +++ b/internal/app/proxy_resolver.go @@ -83,9 +83,6 @@ func resolveResinForwardProxyURL(policy ProxyPolicy, email string, cfg AppConfig if err != nil { return nil, "", err } - if token == "" { - return nil, "", fmt.Errorf("resin token missing") - } platform := strings.TrimSpace(policy.Resin.Platform) if platform == "" { platform = "Default" @@ -94,9 +91,11 @@ func resolveResinForwardProxyURL(policy ProxyPolicy, email string, cfg AppConfig if stickyAccount == "" { stickyAccount = "account" } - username := fmt.Sprintf("%s.%s", platform, stickyAccount) proxyURL := *baseURL - proxyURL.User = url.UserPassword(username, token) + if strings.TrimSpace(token) != "" { + username := fmt.Sprintf("%s.%s", platform, stickyAccount) + proxyURL.User = url.UserPassword(username, 
token) + } return &proxyURL, stickyAccount, nil } From 970237b75fe170acb8d3d286dae3458005a3c321 Mon Sep 17 00:00:00 2001 From: DSLZL Date: Sun, 26 Apr 2026 15:22:05 +0800 Subject: [PATCH 5/8] =?UTF-8?q?feat(dispatch):=20=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E8=B4=A6=E5=8F=B7=E5=B9=B6=E5=8F=91=E6=A7=BD=E4=BD=8D=E4=B8=8E?= =?UTF-8?q?=E5=AE=B9=E9=87=8F429=E4=BF=9D=E6=8A=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为每个账号引入 max_concurrency(默认 1)并在账号池无空闲槽位时立即返回 429,避免请求超量时继续调度。同步打通 Admin 前后端字段与回归测试,确保配置可编辑、错误映射稳定且行为可验证。 --- frontend/components/admin/accounts-panel.tsx | 27 +++- frontend/lib/services/admin/types.ts | 1 + internal/app/admin_accounts.go | 11 ++ internal/app/config.go | 2 + internal/app/main.go | 140 +++++++++++++++++ internal/app/main_fresh_thread_test.go | 90 +++++++++++ internal/app/request_dispatch.go | 154 ++++++++++++++----- 7 files changed, 381 insertions(+), 44 deletions(-) diff --git a/frontend/components/admin/accounts-panel.tsx b/frontend/components/admin/accounts-panel.tsx index 43124c3..27d29e0 100644 --- a/frontend/components/admin/accounts-panel.tsx +++ b/frontend/components/admin/accounts-panel.tsx @@ -35,6 +35,7 @@ import type { AccountItem, AccountsPayload, JsonResult, ModelItem } from '@/lib/ interface AccountEditState { priority: number; hourlyQuota: number; + maxConcurrency: number; disabled: boolean; } @@ -83,6 +84,7 @@ function buildAccountEditMap(items: AccountItem[]): Record models.filter((item) => item.id), [models]); const selectedEdit = selectedAccount?.email - ? accountEdits[selectedAccount.email] || { priority: 0, hourlyQuota: 0, disabled: false } - : { priority: 0, hourlyQuota: 0, disabled: false }; + ? 
accountEdits[selectedAccount.email] || { priority: 0, hourlyQuota: 0, maxConcurrency: 1, disabled: false } + : { priority: 0, hourlyQuota: 0, maxConcurrency: 1, disabled: false }; const summaryCards = [ { @@ -347,6 +349,7 @@ export function AccountsPanel({ [email]: { priority: current[email]?.priority ?? 0, hourlyQuota: current[email]?.hourlyQuota ?? 0, + maxConcurrency: current[email]?.maxConcurrency ?? 1, disabled: current[email]?.disabled ?? false, ...patch, }, @@ -383,6 +386,7 @@ export function AccountsPanel({ email, priority: edit.priority, hourly_quota: edit.hourlyQuota, + max_concurrency: edit.maxConcurrency, disabled: edit.disabled, }); toast.success(`已保存 ${email}`); @@ -724,7 +728,7 @@ export function AccountsPanel({
调度与限额

保存后直接写回账号池。

-
+
- + updateAccountEdit(selectedAccount.email, { - hourlyQuota: Number(event.target.value || 0), + hourlyQuota: Math.max(0, Number(event.target.value || 0)), + }) + } + className={FIELD_CLASS} + /> + + + + updateAccountEdit(selectedAccount.email, { + maxConcurrency: Math.max(1, Number(event.target.value || 1)), }) } className={FIELD_CLASS} diff --git a/frontend/lib/services/admin/types.ts b/frontend/lib/services/admin/types.ts index 45586bc..5615141 100644 --- a/frontend/lib/services/admin/types.ts +++ b/frontend/lib/services/admin/types.ts @@ -155,6 +155,7 @@ export interface AccountItem { last_relogin_at?: string; priority?: number; hourly_quota?: number; + max_concurrency?: number; quota_limited?: boolean; remaining_quota?: number; cooldown_active?: boolean; diff --git a/internal/app/admin_accounts.go b/internal/app/admin_accounts.go index dbc2a8b..279e811 100644 --- a/internal/app/admin_accounts.go +++ b/internal/app/admin_accounts.go @@ -72,6 +72,7 @@ func (a *App) accountRuntimeSummary(cfg AppConfig, account NotionAccount) map[st "disabled": account.Disabled, "priority": account.Priority, "hourly_quota": account.HourlyQuota, + "max_concurrency": normalizeAccountMaxConcurrency(account.MaxConcurrency), "quota_limited": quotaLimited, "remaining_quota": remainingQuota, "window_started_at": account.WindowStartedAt, @@ -239,6 +240,16 @@ func mergeEditableAccountFields(existing NotionAccount, payload map[string]any) } next.HourlyQuota = quota } + if raw, ok := accountPayload["max_concurrency"]; ok { + limit, err := intFromPayloadValue(raw) + if err != nil { + return NotionAccount{}, false, fmt.Errorf("max_concurrency invalid: %w", err) + } + if limit < 1 { + return NotionAccount{}, false, fmt.Errorf("max_concurrency must be >= 1") + } + next.MaxConcurrency = limit + } makeActive, _ := payload["active"].(bool) return next, makeActive, nil } diff --git a/internal/app/config.go b/internal/app/config.go index 1fe44f4..d54b314 100644 --- a/internal/app/config.go +++ 
b/internal/app/config.go @@ -88,6 +88,7 @@ type NotionAccount struct { Disabled bool `json:"disabled,omitempty"` Priority int `json:"priority,omitempty"` HourlyQuota int `json:"hourly_quota,omitempty"` + MaxConcurrency int `json:"max_concurrency,omitempty"` WindowStartedAt string `json:"window_started_at,omitempty"` WindowRequestCount int `json:"window_request_count,omitempty"` CooldownUntil string `json:"cooldown_until,omitempty"` @@ -587,6 +588,7 @@ func normalizeConfig(cfg AppConfig) AppConfig { cfg.Accounts[i].ResinMode, ) cfg.Accounts[i].StickyProxyAccount = strings.TrimSpace(cfg.Accounts[i].StickyProxyAccount) + cfg.Accounts[i].MaxConcurrency = normalizeAccountMaxConcurrency(cfg.Accounts[i].MaxConcurrency) if cfg.Accounts[i].ProxyMode == "" { cfg.Accounts[i].ProxyMode = cfg.normalizedProxyMode() } diff --git a/internal/app/main.go b/internal/app/main.go index 632ef3e..5e96d4a 100644 --- a/internal/app/main.go +++ b/internal/app/main.go @@ -32,10 +32,16 @@ type ServerState struct { Conversations *ConversationStore AdminTokens map[string]time.Time AdminLoginAttempts map[string]AdminLoginAttempt + AccountDispatchSlots map[string]accountDispatchState LastSessionRefresh time.Time LastSessionRefreshError string } +type accountDispatchState struct { + MaxConcurrency int + InFlight int +} + type App struct { State *ServerState runPromptOverride func(*http.Request, PromptRunRequest) (InferenceResult, error) @@ -99,6 +105,135 @@ func maxInt(a int, b int) int { return b } +func normalizeAccountMaxConcurrency(raw int) int { + if raw <= 0 { + return 1 + } + return raw +} + +func (s *ServerState) initializeAccountDispatchSlotsLocked() { + if s.AccountDispatchSlots == nil { + s.AccountDispatchSlots = map[string]accountDispatchState{} + } + next := map[string]accountDispatchState{} + for _, account := range s.Config.Accounts { + emailKey := canonicalEmailKey(account.Email) + if emailKey == "" { + continue + } + maxConcurrency := 
normalizeAccountMaxConcurrency(account.MaxConcurrency) + state := s.AccountDispatchSlots[emailKey] + state.MaxConcurrency = maxConcurrency + if state.InFlight < 0 { + state.InFlight = 0 + } + if state.InFlight > state.MaxConcurrency { + state.InFlight = state.MaxConcurrency + } + next[emailKey] = state + } + s.AccountDispatchSlots = next +} + +func (s *ServerState) TryAcquireAccountDispatchSlot(email string) bool { + emailKey := canonicalEmailKey(email) + if emailKey == "" { + return false + } + s.mu.Lock() + defer s.mu.Unlock() + s.initializeAccountDispatchSlotsLocked() + state, ok := s.AccountDispatchSlots[emailKey] + if !ok { + return false + } + if state.InFlight >= state.MaxConcurrency { + return false + } + state.InFlight++ + s.AccountDispatchSlots[emailKey] = state + return true +} + +func (s *ServerState) ReleaseAccountDispatchSlot(email string) { + emailKey := canonicalEmailKey(email) + if emailKey == "" { + return + } + s.mu.Lock() + defer s.mu.Unlock() + if s.AccountDispatchSlots == nil { + return + } + state, ok := s.AccountDispatchSlots[emailKey] + if !ok { + return + } + if state.InFlight > 0 { + state.InFlight-- + } + s.AccountDispatchSlots[emailKey] = state +} + +func (s *ServerState) RemainingAccountDispatchSlots(email string) int { + emailKey := canonicalEmailKey(email) + if emailKey == "" { + return 0 + } + s.mu.Lock() + defer s.mu.Unlock() + s.initializeAccountDispatchSlotsLocked() + state, ok := s.AccountDispatchSlots[emailKey] + if !ok { + return 0 + } + remaining := state.MaxConcurrency - state.InFlight + if remaining < 0 { + remaining = 0 + } + return remaining +} + +func (s *ServerState) AvailableDispatchCapacity(emails []string) int { + s.mu.Lock() + defer s.mu.Unlock() + s.initializeAccountDispatchSlotsLocked() + total := 0 + seen := map[string]struct{}{} + for _, email := range emails { + emailKey := canonicalEmailKey(email) + if emailKey == "" { + continue + } + if _, exists := seen[emailKey]; exists { + continue + } + seen[emailKey] = 
struct{}{} + state, ok := s.AccountDispatchSlots[emailKey] + if !ok { + continue + } + remaining := state.MaxConcurrency - state.InFlight + if remaining > 0 { + total += remaining + } + } + return total +} + +func (s *ServerState) AccountDispatchSnapshot() map[string]accountDispatchState { + s.mu.Lock() + defer s.mu.Unlock() + s.initializeAccountDispatchSlotsLocked() + out := make(map[string]accountDispatchState, len(s.AccountDispatchSlots)) + for key, value := range s.AccountDispatchSlots { + out[key] = value + } + return out +} + + func maxFloat(a float64, b float64) float64 { if a > b { return a @@ -134,6 +269,7 @@ func newServerState(cfg AppConfig) (*ServerState, error) { Conversations: newConversationStore(), AdminTokens: map[string]time.Time{}, AdminLoginAttempts: map[string]AdminLoginAttempt{}, + AccountDispatchSlots: map[string]accountDispatchState{}, Store: store, } persistedAccountsLoaded := false @@ -1320,6 +1456,10 @@ func (a *App) handleResponses(w http.ResponseWriter, r *http.Request) { func (a *App) writeUpstreamError(w http.ResponseWriter, err error) { message := err.Error() lower := strings.ToLower(message) + if isDispatchCapacityExceededError(err) { + writeOpenAIError(w, http.StatusTooManyRequests, message, "rate_limit_error", "dispatch_capacity_exceeded") + return + } if strings.Contains(lower, "context deadline exceeded") || strings.Contains(lower, "timeout") { writeOpenAIError(w, http.StatusGatewayTimeout, message, "api_timeout_error", "upstream_timeout") return diff --git a/internal/app/main_fresh_thread_test.go b/internal/app/main_fresh_thread_test.go index 6f7a760..fd2a3ad 100644 --- a/internal/app/main_fresh_thread_test.go +++ b/internal/app/main_fresh_thread_test.go @@ -3,9 +3,12 @@ package app import ( "bytes" "encoding/json" + "errors" "fmt" "net/http" "net/http/httptest" + "os" + "path/filepath" "strings" "testing" ) @@ -397,3 +400,90 @@ func TestHandleChatCompletionsStreamWritesErrorAfterHeadersSent(t *testing.T) { t.Fatalf("expected 
stream done marker, got body=%s", body) } } + +func TestNormalizeConfigDefaultsAccountMaxConcurrencyToOne(t *testing.T) { + cfg := normalizeConfig(AppConfig{ + APIKey: "test-api-key", + Accounts: []NotionAccount{ + {Email: "alice@example.com", MaxConcurrency: 0}, + {Email: "bob@example.com", MaxConcurrency: -3}, + {Email: "carol@example.com", MaxConcurrency: 4}, + }, + }) + if got := cfg.Accounts[0].MaxConcurrency; got != 1 { + t.Fatalf("expected default max concurrency 1 for zero, got %d", got) + } + if got := cfg.Accounts[1].MaxConcurrency; got != 1 { + t.Fatalf("expected default max concurrency 1 for negative, got %d", got) + } + if got := cfg.Accounts[2].MaxConcurrency; got != 4 { + t.Fatalf("expected explicit max concurrency preserved, got %d", got) + } +} + +func TestWriteUpstreamErrorMapsDispatchCapacityTo429(t *testing.T) { + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.Storage.SQLitePath = "" + state, err := newServerState(cfg) + if err != nil { + t.Fatalf("newServerState failed: %v", err) + } + defer func() { + _ = state.Close() + }() + app := &App{State: state} + + rec := httptest.NewRecorder() + app.writeUpstreamError(rec, noDispatchCapacityError()) + + if rec.Code != http.StatusTooManyRequests { + t.Fatalf("expected status 429, got %d body=%s", rec.Code, rec.Body.String()) + } + body := rec.Body.String() + if !strings.Contains(body, `"type":"rate_limit_error"`) { + t.Fatalf("expected rate_limit_error body, got %s", body) + } + if !strings.Contains(body, `"code":"dispatch_capacity_exceeded"`) { + t.Fatalf("expected dispatch_capacity_exceeded code, got %s", body) + } +} + +func TestRunPromptWithAccountPoolReturnsCapacityErrorWhenAllSlotsOccupied(t *testing.T) { + probePath := filepath.Join(t.TempDir(), "probe.json") + if err := os.WriteFile(probePath, []byte(`{"cookies":[{"name":"token_v2","value":"test-cookie"}]}`), 0o644); err != nil { + t.Fatalf("write probe file failed: %v", err) + } + + cfg := normalizeConfig(AppConfig{ + APIKey: 
"test-api-key", + Accounts: []NotionAccount{ + {Email: "alice@example.com", MaxConcurrency: 1, ProbeJSON: probePath}, + }, + }) + state, err := newServerState(cfg) + if err != nil { + t.Fatalf("newServerState failed: %v", err) + } + defer func() { + _ = state.Close() + }() + + if !state.TryAcquireAccountDispatchSlot("alice@example.com") { + t.Fatal("expected pre-acquire slot success") + } + defer state.ReleaseAccountDispatchSlot("alice@example.com") + + app := &App{State: state} + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + _, runErr := app.runPromptWithAccountPool(req, PromptRunRequest{Prompt: "hello"}, nil) + if runErr == nil { + t.Fatal("expected dispatch capacity error, got nil") + } + if !isDispatchCapacityExceededError(runErr) { + t.Fatalf("expected dispatch capacity sentinel error, got %v", runErr) + } + if !errors.Is(runErr, errDispatchCapacityExceeded) { + t.Fatalf("expected wrapped sentinel error, got %v", runErr) + } +} diff --git a/internal/app/request_dispatch.go b/internal/app/request_dispatch.go index fd94585..68e1cb3 100644 --- a/internal/app/request_dispatch.go +++ b/internal/app/request_dispatch.go @@ -14,6 +14,8 @@ const ( dispatchProtocolProbeTimeoutCapSec = 20 ) +var errDispatchCapacityExceeded = errors.New("dispatch capacity exceeded") + func requestTimeout(cfg AppConfig) time.Duration { return time.Duration(maxInt(cfg.TimeoutSec, 10)) * time.Second } @@ -26,6 +28,14 @@ func noEligibleAccountsError() error { return fmt.Errorf("no usable accounts available; check disabled state, local artifacts, or login status") } +func noDispatchCapacityError() error { + return fmt.Errorf("%w: too many concurrent requests for available accounts", errDispatchCapacityExceeded) +} + +func isDispatchCapacityExceededError(err error) bool { + return errors.Is(err, errDispatchCapacityExceeded) +} + func mergeDispatchCandidates(preferred *NotionAccount, candidates []NotionAccount) []NotionAccount { out := make([]NotionAccount, 0, 
len(candidates)+1) seen := map[string]struct{}{} @@ -286,6 +296,13 @@ func (a *App) runPromptWithAccountPool(r *http.Request, request PromptRunRequest if err != nil { return InferenceResult{}, err } + candidateEmails := make([]string, 0, len(candidates)) + for _, candidate := range candidates { + candidateEmails = append(candidateEmails, candidate.Email) + } + if a.State.AvailableDispatchCapacity(candidateEmails) <= 0 { + return InferenceResult{}, noDispatchCapacityError() + } emittedAny := false wrappedDelta := func(delta string) error { @@ -300,11 +317,19 @@ func (a *App) runPromptWithAccountPool(r *http.Request, request PromptRunRequest var lastErr error for _, original := range candidates { + if !a.State.TryAcquireAccountDispatchSlot(original.Email) { + continue + } + slotAcquired := true account := markAccountDispatchStart(original, time.Now()) session, err := a.loadReadyDispatchSession(ctx, cfg, account) if err == nil { result, runErr := a.runPromptWithSession(ctx, cfg, session, account.Email, request, wrappedDelta) if runErr == nil { + if slotAcquired { + a.State.ReleaseAccountDispatchSlot(account.Email) + slotAcquired = false + } result.AccountEmail = account.Email account.UserID = firstNonEmpty(session.UserID, account.UserID) account.UserName = firstNonEmpty(session.UserName, account.UserName) @@ -321,6 +346,10 @@ func (a *App) runPromptWithAccountPool(r *http.Request, request PromptRunRequest } err = runErr } + if slotAcquired { + a.State.ReleaseAccountDispatchSlot(account.Email) + slotAcquired = false + } if isDispatchContextAbort(ctx, err) { return InferenceResult{}, err } @@ -335,24 +364,38 @@ func (a *App) runPromptWithAccountPool(r *http.Request, request PromptRunRequest if ok { refreshedSession, loadErr := a.loadReadyDispatchSession(ctx, cfg, refreshedAccount) if loadErr == nil { - result, retryErr := a.runPromptWithSession(ctx, cfg, refreshedSession, refreshedAccount.Email, request, wrappedDelta) - if retryErr == nil { - result.AccountEmail = 
refreshedAccount.Email - refreshedAccount.UserID = firstNonEmpty(refreshedSession.UserID, refreshedAccount.UserID) - refreshedAccount.UserName = firstNonEmpty(refreshedSession.UserName, refreshedAccount.UserName) - refreshedAccount.SpaceID = firstNonEmpty(refreshedSession.SpaceID, refreshedAccount.SpaceID) - refreshedAccount.SpaceViewID = firstNonEmpty(refreshedSession.SpaceViewID, refreshedAccount.SpaceViewID) - refreshedAccount.SpaceName = firstNonEmpty(refreshedSession.SpaceName, refreshedAccount.SpaceName) - refreshedAccount.ClientVersion = firstNonEmpty(refreshedSession.ClientVersion, refreshedAccount.ClientVersion) - refreshedAccount = markAccountDispatchSuccess(refreshedAccount, time.Now()) - nextCfg := applyAccountUpdate(cfg, refreshedAccount, shouldPersistDispatchedAccountAsActive(cfg, request, refreshedAccount.Email)) - if saveErr := a.State.SaveAndApply(nextCfg); saveErr != nil { - return InferenceResult{}, saveErr + if !a.State.TryAcquireAccountDispatchSlot(refreshedAccount.Email) { + err = noDispatchCapacityError() + retryable = false + } else { + retrySlotAcquired := true + result, retryErr := a.runPromptWithSession(ctx, cfg, refreshedSession, refreshedAccount.Email, request, wrappedDelta) + if retryErr == nil { + if retrySlotAcquired { + a.State.ReleaseAccountDispatchSlot(refreshedAccount.Email) + retrySlotAcquired = false + } + result.AccountEmail = refreshedAccount.Email + refreshedAccount.UserID = firstNonEmpty(refreshedSession.UserID, refreshedAccount.UserID) + refreshedAccount.UserName = firstNonEmpty(refreshedSession.UserName, refreshedAccount.UserName) + refreshedAccount.SpaceID = firstNonEmpty(refreshedSession.SpaceID, refreshedAccount.SpaceID) + refreshedAccount.SpaceViewID = firstNonEmpty(refreshedSession.SpaceViewID, refreshedAccount.SpaceViewID) + refreshedAccount.SpaceName = firstNonEmpty(refreshedSession.SpaceName, refreshedAccount.SpaceName) + refreshedAccount.ClientVersion = firstNonEmpty(refreshedSession.ClientVersion, 
refreshedAccount.ClientVersion) + refreshedAccount = markAccountDispatchSuccess(refreshedAccount, time.Now()) + nextCfg := applyAccountUpdate(cfg, refreshedAccount, shouldPersistDispatchedAccountAsActive(cfg, request, refreshedAccount.Email)) + if saveErr := a.State.SaveAndApply(nextCfg); saveErr != nil { + return InferenceResult{}, saveErr + } + return result, nil } - return result, nil + if retrySlotAcquired { + a.State.ReleaseAccountDispatchSlot(refreshedAccount.Email) + retrySlotAcquired = false + } + err = retryErr + retryable = isSessionRetryableError(err) } - err = retryErr - retryable = isSessionRetryableError(err) } else { err = loadErr retryable = isSessionRetryableError(err) @@ -388,7 +431,7 @@ func (a *App) runPromptWithAccountPool(r *http.Request, request PromptRunRequest if lastErr != nil { return InferenceResult{}, lastErr } - return InferenceResult{}, noEligibleAccountsError() + return InferenceResult{}, noDispatchCapacityError() } func (a *App) runPromptWithAccountPoolWithSink(r *http.Request, request PromptRunRequest, sink InferenceStreamSink) (InferenceResult, error) { @@ -406,6 +449,13 @@ func (a *App) runPromptWithAccountPoolWithSink(r *http.Request, request PromptRu if err != nil { return InferenceResult{}, err } + candidateEmails := make([]string, 0, len(candidates)) + for _, candidate := range candidates { + candidateEmails = append(candidateEmails, candidate.Email) + } + if a.State.AvailableDispatchCapacity(candidateEmails) <= 0 { + return InferenceResult{}, noDispatchCapacityError() + } emittedAny := false wrappedText := func(delta string) error { @@ -429,6 +479,10 @@ func (a *App) runPromptWithAccountPoolWithSink(r *http.Request, request PromptRu var lastErr error for _, original := range candidates { + if !a.State.TryAcquireAccountDispatchSlot(original.Email) { + continue + } + slotAcquired := true account := markAccountDispatchStart(original, time.Now()) session, err := a.loadReadyDispatchSession(ctx, cfg, account) if err == nil { @@ 
-439,6 +493,10 @@ func (a *App) runPromptWithAccountPoolWithSink(r *http.Request, request PromptRu KeepAlive: wrappedKeepAlive, }) if runErr == nil { + if slotAcquired { + a.State.ReleaseAccountDispatchSlot(account.Email) + slotAcquired = false + } result.AccountEmail = account.Email account.UserID = firstNonEmpty(session.UserID, account.UserID) account.UserName = firstNonEmpty(session.UserName, account.UserName) @@ -455,6 +513,10 @@ func (a *App) runPromptWithAccountPoolWithSink(r *http.Request, request PromptRu } err = runErr } + if slotAcquired { + a.State.ReleaseAccountDispatchSlot(account.Email) + slotAcquired = false + } if isDispatchContextAbort(ctx, err) { return InferenceResult{}, err } @@ -468,29 +530,43 @@ func (a *App) runPromptWithAccountPoolWithSink(r *http.Request, request PromptRu if refreshedAccount, _, ok := cfg.FindAccount(account.Email); ok { refreshedSession, loadErr := a.loadReadyDispatchSession(ctx, cfg, refreshedAccount) if loadErr == nil { - result, retryErr := a.runPromptWithSessionWithSink(ctx, cfg, refreshedSession, refreshedAccount.Email, request, InferenceStreamSink{ - Text: wrappedText, - Reasoning: wrappedReasoning, - ReasoningWarmup: wrappedReasoningWarmup, - KeepAlive: wrappedKeepAlive, - }) - if retryErr == nil { - result.AccountEmail = refreshedAccount.Email - refreshedAccount.UserID = firstNonEmpty(refreshedSession.UserID, refreshedAccount.UserID) - refreshedAccount.UserName = firstNonEmpty(refreshedSession.UserName, refreshedAccount.UserName) - refreshedAccount.SpaceID = firstNonEmpty(refreshedSession.SpaceID, refreshedAccount.SpaceID) - refreshedAccount.SpaceViewID = firstNonEmpty(refreshedSession.SpaceViewID, refreshedAccount.SpaceViewID) - refreshedAccount.SpaceName = firstNonEmpty(refreshedSession.SpaceName, refreshedAccount.SpaceName) - refreshedAccount.ClientVersion = firstNonEmpty(refreshedSession.ClientVersion, refreshedAccount.ClientVersion) - refreshedAccount = markAccountDispatchSuccess(refreshedAccount, time.Now()) 
- nextCfg := applyAccountUpdate(cfg, refreshedAccount, shouldPersistDispatchedAccountAsActive(cfg, request, refreshedAccount.Email)) - if saveErr := a.State.SaveAndApply(nextCfg); saveErr != nil { - return InferenceResult{}, saveErr + if !a.State.TryAcquireAccountDispatchSlot(refreshedAccount.Email) { + err = noDispatchCapacityError() + retryable = false + } else { + retrySlotAcquired := true + result, retryErr := a.runPromptWithSessionWithSink(ctx, cfg, refreshedSession, refreshedAccount.Email, request, InferenceStreamSink{ + Text: wrappedText, + Reasoning: wrappedReasoning, + ReasoningWarmup: wrappedReasoningWarmup, + KeepAlive: wrappedKeepAlive, + }) + if retryErr == nil { + if retrySlotAcquired { + a.State.ReleaseAccountDispatchSlot(refreshedAccount.Email) + retrySlotAcquired = false + } + result.AccountEmail = refreshedAccount.Email + refreshedAccount.UserID = firstNonEmpty(refreshedSession.UserID, refreshedAccount.UserID) + refreshedAccount.UserName = firstNonEmpty(refreshedSession.UserName, refreshedAccount.UserName) + refreshedAccount.SpaceID = firstNonEmpty(refreshedSession.SpaceID, refreshedAccount.SpaceID) + refreshedAccount.SpaceViewID = firstNonEmpty(refreshedSession.SpaceViewID, refreshedAccount.SpaceViewID) + refreshedAccount.SpaceName = firstNonEmpty(refreshedSession.SpaceName, refreshedAccount.SpaceName) + refreshedAccount.ClientVersion = firstNonEmpty(refreshedSession.ClientVersion, refreshedAccount.ClientVersion) + refreshedAccount = markAccountDispatchSuccess(refreshedAccount, time.Now()) + nextCfg := applyAccountUpdate(cfg, refreshedAccount, shouldPersistDispatchedAccountAsActive(cfg, request, refreshedAccount.Email)) + if saveErr := a.State.SaveAndApply(nextCfg); saveErr != nil { + return InferenceResult{}, saveErr + } + return result, nil } - return result, nil + if retrySlotAcquired { + a.State.ReleaseAccountDispatchSlot(refreshedAccount.Email) + retrySlotAcquired = false + } + err = retryErr + retryable = isSessionRetryableError(err) } - 
err = retryErr - retryable = isSessionRetryableError(err) } else { err = loadErr retryable = isSessionRetryableError(err) @@ -526,5 +602,5 @@ func (a *App) runPromptWithAccountPoolWithSink(r *http.Request, request PromptRu if lastErr != nil { return InferenceResult{}, lastErr } - return InferenceResult{}, noEligibleAccountsError() + return InferenceResult{}, noDispatchCapacityError() } From e192330084cc8f35f6c2b6c0e52381ff2fca7166 Mon Sep 17 00:00:00 2001 From: DSLZL Date: Sun, 3 May 2026 00:27:51 +0800 Subject: [PATCH 6/8] =?UTF-8?q?=E4=B8=BA=E9=99=8D=E4=BD=8E=E9=AB=98?= =?UTF-8?q?=E5=B9=B6=E5=8F=91=E5=BB=B6=E8=BF=9F=E5=B9=B6=E6=8F=90=E5=8D=87?= =?UTF-8?q?=E5=8F=91=E5=B8=83=E7=A8=B3=E5=AE=9A=E6=80=A7=EF=BC=8C=E7=BB=9F?= =?UTF-8?q?=E4=B8=80=E5=AE=8C=E6=88=90=E8=AF=B7=E6=B1=82=E7=83=AD=E8=B7=AF?= =?UTF-8?q?=E5=BE=84=E5=8E=BB=E9=94=81=E3=80=81=E6=B5=81=E5=BC=8F=E4=BC=A0?= =?UTF-8?q?=E8=BE=93=E4=B8=8E=E5=8F=AF=E8=A7=82=E6=B5=8B=E6=80=A7=E6=B2=BB?= =?UTF-8?q?=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 本次提交聚焦于把 PLAN 中的性能与稳定性改造落到可发布代码:减少请求链路重复分配与锁竞争、缩短流式首包等待、补齐运行期观测与回归测试,并同步容器与配置默认值以降低线上回归风险。 文件级变更说明: - `.gitignore`:允许跟踪 `internal/wreq/wreq_streaming_test.go` 与 `internal/wreq/wreq_streaming_stub_test.go`,避免被通配 `*_test.go` 误忽略。 - `Dockerfile`:Rust target 缓存改为 `sharing=private`;runtime 安装 `node-wreq` 改为 `npm pack + tar`,去掉 `npm init/npm install` 的额外开销与噪音。 - `config.docker.json`:补充 `limits.max_request_body_bytes`、`dispatch.probe_cache_ttl_seconds`、`browser.helper_pool_size`、`debug.pprof_*` 示例默认项。 - `config.example.json`:同步新增 limits/dispatch/browser/debug 配置示例,保证本地与容器配置面一致。 - `internal/app/account_pool.go`:分发候选排序改为复用 `emailKey`;新增基于 snapshot 的候选读取;补充 wreq client 创建计数埋点。 - `internal/app/accounts.go`:账户路径与绝对路径匹配 regex 提升为包级;新增 `getAccountEmailKey` 复用 canonical key,减少重复计算。 - `internal/app/admin.go`:管理端请求体解码切到统一 `decodeBody` 流程,并统一非法请求体错误返回。 - `internal/app/admin_accounts.go`:账号管理读写路径统一使用 `emailKey`;更新/删除/启停账号后触发 dispatch 
probe 缓存失效。 - `internal/app/admin_conversations.go`:CORS header 改为复用统一中间层函数。 - `internal/app/assets/browser-helper.cjs`:新增可复用的浏览器 helper 资产脚本,支持 pool 模式与长度前缀协议处理。 - `internal/app/assets/browser-login-helper.cjs`:新增独立登录 helper 资产脚本,统一 cookie jar 与请求执行逻辑。 - `internal/app/config.go`:新增 `DispatchConfig`/`BrowserConfig`/`DebugConfig`/`LimitsConfig`;补齐默认值与 CLI 参数(pprof、max-request-body);预计算 prompt retry 前缀。 - `internal/app/conversations.go`:ConversationStore 改为 pointer COW + 预览缓存字段,降低列表与读取路径的深拷贝与扫描开销。 - `internal/app/httpclient_audit.md`:补充 T-4-2 审计结论,记录 http client/transport 复用风险与改造建议。 - `internal/app/main.go`:ServerState 引入 atomic snapshot、slot 原子计数、responseStore、sqliteWriter、静态 JSON 缓存、pprof 与统一 CORS/请求体限制入口。 - `internal/app/main_fresh_thread_test.go`:扩展大规模回归测试,覆盖默认配置、dispatch/probe/cache、decode limit、response store、metrics、browser pool 等关键路径。 - `internal/app/metrics.go`:新增 Prometheus 文本导出实现,补齐请求时延、dispatch inflight、wreq 调用、sqlite 操作、browser helper 等指标面。 - `internal/app/models.go`:Probe 模型路径抽取与并行加载,减少多账号场景下模型构建延迟。 - `internal/app/notion_client.go`:引入 HTTP transport cache(按上游/代理/账号策略键控),减少重复建连与连接池碎片。 - `internal/app/notion_client_best_effort_test.go`:新增 dispatch probe TTL 行为测试、失败重探活与缓存失效回归。 - `internal/app/notion_client_browser_fallback_test.go`:补充 helper 错误分类、脚本路径稳定性、环境变量传播与 pool 相关行为测试。 - `internal/app/notion_client_browser_transport.go`:实现 browser helper pool、帧协议读写、worker 生命周期治理、env 优先级与故障重建机制。 - `internal/app/notion_client_protocol_test.go`:新增 transport cache 复用与命中计数回归,验证 client 构建路径一致性。 - `internal/app/notion_client_wreq_transport.go`:移除内嵌 JS 脚本字面量,改为加载资产脚本;对接新的 wreq 流式读取接口。 - `internal/app/openai.go`:chat/responses 规范化逻辑拆分为“typed parts”入口,减少 payload map 热路径访问。 - `internal/app/openai_types.go`:新增 typed request envelope 与字段解析辅助,为 typed-first decode 提供结构化承载。 - `internal/app/prompt_guard.go`:新增 retry 前缀一次性构建函数,运行期复用预计算 slice。 - `internal/app/request_dispatch.go`:新增 probe cache、wreq client new 指标、dispatch 探活缓存与失效策略。 - `internal/app/response_store.go`:新增 map + min-heap 
的响应存储与过期清理结构,替换全表扫描过期策略。 - `internal/app/session_refresh.go`:刷新流程增加 test hook 注入点,并在刷新成功后联动 probe 缓存失效。 - `internal/app/sqlite_store.go`:增加 sqlite pragma、读写分离连接(roDB)与操作耗时指标上报。 - `internal/app/sqlite_writer.go`:新增异步持久化写队列与回退计数,降低请求路径同步写阻塞。 - `internal/wreq/wreq_cgo.go`:cgo 链接参数按平台分支(Windows 去 `-ldl`);引入 begin/read/close 流式 FFI 声明与 Go 侧读关闭实现。 - `internal/wreq/wreq_ffi_compat.h`:新增兼容头,固定 cgo 侧所需的流式 FFI 原型。 - `internal/wreq/wreq_streaming_stub_test.go`:新增非 cgo stub 路径流式接口行为回归。 - `internal/wreq/wreq_streaming_test.go`:新增流式首包与 chunk 间隔行为测试,锁定读写延迟形态。 - `internal/wreq/wreq_stub.go`:stub 实现补齐 `Begin/Read/Close` 接口与字段调整,保持无 cgo 路径可编译一致性。 - `scripts/perf/baseline.sh`:新增一键 baseline 脚本(流式/非流式压测 + pprof 抓取 + summary 汇总)。 - `scripts/perf/payload-chat.json`:新增 baseline 脚本默认请求体模板。 - `wreq-ffi/Cargo.toml`:移除 `once_cell` 依赖,转向标准库 `OnceLock`。 - `wreq-ffi/README.md`:更新 FFI 文档为 `wreq_request_begin/wreq_response_read/wreq_response_close` 流式协议说明。 - `wreq-ffi/build.rs`:生成/刷新仓库内兼容头文件,保障未生成 include 头时的 cgo 编译连通性。 - `wreq-ffi/src/lib.rs`:Rust FFI 主体从 base64 包装迁移到流式句柄读写;新增错误码、响应状态管理与读取超时处理。 Constraint: 本次提交必须排除 `.omx/`、`docs/` 与 `PLAN.md`,并保持现有分支可直接用于发布前验证。 Rejected: 拆分为多次小提交 | 当前改造跨传输层/状态层/测试层强耦合,保持单次一致快照更利于回归与回滚。 Confidence: high Scope-risk: broad Directive: 后续若继续推进 Rust T-3-5,请先安装 `cmake` 并补跑 `cargo test` 后再做 FFI 层迭代。 Tested: `go test ./...`; `go test -race ./...`; `go build ./cmd/notion2api` Not-tested: `cargo test`(当前环境缺少 `cmake`);`scripts/perf/baseline.sh` 的 60s 实压与 profile 采集本次未执行。 --- .gitignore | 2 + Dockerfile | 9 +- config.docker.json | 13 + config.example.json | 13 + internal/app/account_pool.go | 24 +- internal/app/accounts.go | 27 +- internal/app/admin.go | 8 +- internal/app/admin_accounts.go | 39 +- internal/app/admin_conversations.go | 2 +- internal/app/assets/browser-helper.cjs | 229 ++ internal/app/assets/browser-login-helper.cjs | 80 + internal/app/config.go | 63 + internal/app/conversations.go | 246 ++- internal/app/httpclient_audit.md | 69 + internal/app/main.go | 867 
++++++-- internal/app/main_fresh_thread_test.go | 1895 +++++++++++++++++ internal/app/metrics.go | 435 ++++ internal/app/models.go | 73 +- internal/app/notion_client.go | 175 +- .../app/notion_client_best_effort_test.go | 184 +- .../notion_client_browser_fallback_test.go | 348 +++ .../app/notion_client_browser_transport.go | 382 +++- internal/app/notion_client_protocol_test.go | 156 ++ internal/app/notion_client_wreq_transport.go | 242 +-- internal/app/openai.go | 19 +- internal/app/openai_types.go | 339 +++ internal/app/prompt_guard.go | 39 +- internal/app/request_dispatch.go | 214 +- internal/app/response_store.go | 217 ++ internal/app/session_refresh.go | 36 +- internal/app/sqlite_store.go | 117 +- internal/app/sqlite_writer.go | 203 ++ internal/wreq/wreq_cgo.go | 220 +- internal/wreq/wreq_ffi_compat.h | 26 + internal/wreq/wreq_streaming_stub_test.go | 30 + internal/wreq/wreq_streaming_test.go | 85 + internal/wreq/wreq_stub.go | 20 +- scripts/perf/baseline.sh | 299 +++ scripts/perf/payload-chat.json | 11 + wreq-ffi/Cargo.toml | 1 - wreq-ffi/README.md | 52 +- wreq-ffi/build.rs | 38 + wreq-ffi/src/lib.rs | 417 ++-- 43 files changed, 7112 insertions(+), 852 deletions(-) create mode 100644 internal/app/assets/browser-helper.cjs create mode 100644 internal/app/assets/browser-login-helper.cjs create mode 100644 internal/app/httpclient_audit.md create mode 100644 internal/app/metrics.go create mode 100644 internal/app/openai_types.go create mode 100644 internal/app/response_store.go create mode 100644 internal/app/sqlite_writer.go create mode 100644 internal/wreq/wreq_ffi_compat.h create mode 100644 internal/wreq/wreq_streaming_stub_test.go create mode 100644 internal/wreq/wreq_streaming_test.go create mode 100644 scripts/perf/baseline.sh create mode 100644 scripts/perf/payload-chat.json diff --git a/.gitignore b/.gitignore index 1abd7ba..95aa032 100644 --- a/.gitignore +++ b/.gitignore @@ -35,6 +35,8 @@ frontend/out/ # Local test sources stay out of git *_test.go 
+!internal/wreq/wreq_streaming_test.go +!internal/wreq/wreq_streaming_stub_test.go *.test.ts *.test.tsx *.spec.ts diff --git a/Dockerfile b/Dockerfile index a0321c7..80daf2c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -44,7 +44,7 @@ COPY wreq-ffi ./wreq-ffi RUN --mount=type=cache,target=/usr/local/cargo/registry \ --mount=type=cache,target=/usr/local/cargo/git \ - --mount=type=cache,target=/cargo-target,id=cargo-target-${TARGETARCH},sharing=locked \ + --mount=type=cache,target=/cargo-target,id=cargo-target-${TARGETARCH},sharing=private \ set -eux; \ RUST_TARGET=$(cat /tmp/rust_target); \ case "${TARGETARCH}" in \ @@ -133,8 +133,11 @@ RUN apt-get update \ && mkdir -p /opt/notion2api-helper /app/config /app/data/notion_accounts /app/static RUN cd /opt/notion2api-helper \ - && npm init -y >/dev/null 2>&1 \ - && npm install --omit=dev --no-package-lock node-wreq@2.2.1 \ + && npm pack node-wreq@2.2.1 \ + && tar -xf node-wreq-2.2.1.tgz \ + && mkdir -p "$NODE_PATH" \ + && mv package "$NODE_PATH/node-wreq" \ + && rm -f node-wreq-2.2.1.tgz \ && test -d "$NODE_PATH/node-wreq" \ && npm cache clean --force >/dev/null 2>&1 diff --git a/config.docker.json b/config.docker.json index 9bf604d..c18905f 100644 --- a/config.docker.json +++ b/config.docker.json @@ -29,6 +29,9 @@ "persist_continuation_sessions": true, "persist_sillytavern_bindings": true }, + "limits": { + "max_request_body_bytes": 4194304 + }, "admin": { "enabled": true, "password": "change-me-admin-password", @@ -46,6 +49,16 @@ "retry_on_auth_error": true, "auto_switch_account": true }, + "dispatch": { + "probe_cache_ttl_seconds": 45 + }, + "browser": { + "helper_pool_size": 0 + }, + "debug": { + "pprof_enabled": false, + "pprof_addr": "127.0.0.1:6060" + }, "features": { "use_web_search": true, "use_read_only_mode": false, diff --git a/config.example.json b/config.example.json index aebb2b4..d7651ea 100644 --- a/config.example.json +++ b/config.example.json @@ -43,6 +43,9 @@ "persist_continuation_sessions": true, 
"persist_sillytavern_bindings": true }, + "limits": { + "max_request_body_bytes": 4194304 + }, "features": { "use_web_search": true, "use_read_only_mode": false, @@ -69,6 +72,16 @@ "retry_on_auth_error": true, "auto_switch_account": true }, + "dispatch": { + "probe_cache_ttl_seconds": 45 + }, + "browser": { + "helper_pool_size": 0 + }, + "debug": { + "pprof_enabled": false, + "pprof_addr": "127.0.0.1:6060" + }, "accounts": [ { "email": "alice@example.com", diff --git a/internal/app/account_pool.go b/internal/app/account_pool.go index a287faf..3cb4221 100644 --- a/internal/app/account_pool.go +++ b/internal/app/account_pool.go @@ -156,8 +156,10 @@ func sortDispatchCandidates(cfg AppConfig, accounts []NotionAccount, now time.Ti sort.Slice(accounts, func(i, j int) bool { left := accounts[i] right := accounts[j] - leftActive := canonicalEmailKey(left.Email) == activeKey - rightActive := canonicalEmailKey(right.Email) == activeKey + leftKey := getAccountEmailKey(left) + rightKey := getAccountEmailKey(right) + leftActive := leftKey == activeKey + rightActive := rightKey == activeKey if leftActive != rightActive { return leftActive } @@ -183,11 +185,11 @@ func sortDispatchCandidates(cfg AppConfig, accounts []NotionAccount, now time.Ti if !leftUsed.Equal(rightUsed) { return leftUsed.Before(rightUsed) } - return canonicalEmailKey(left.Email) < canonicalEmailKey(right.Email) + return leftKey < rightKey }) } -func pickDispatchCandidates(cfg AppConfig, now time.Time) []NotionAccount { +func buildDispatchCandidateOrder(cfg AppConfig, now time.Time) []NotionAccount { candidates := make([]NotionAccount, 0, len(cfg.Accounts)) for _, account := range cfg.Accounts { account = ensureAccountPaths(cfg, account) @@ -199,6 +201,16 @@ func pickDispatchCandidates(cfg AppConfig, now time.Time) []NotionAccount { return candidates } +func pickDispatchCandidatesFromSnapshot(bundle *snapshotBundle, now time.Time) []NotionAccount { + if bundle == nil { + return nil + } + if 
len(bundle.DispatchOrder) > 0 { + return bundle.DispatchOrder + } + return buildDispatchCandidateOrder(bundle.Config, now) +} + func applyAccountUpdate(cfg AppConfig, account NotionAccount, makeActive bool) AppConfig { account = ensureAccountPaths(cfg, account) cfg.UpsertAccount(account) @@ -241,8 +253,10 @@ func (a *App) runPromptWithSession(ctx context.Context, cfg AppConfig, session S if a.runPromptWithSessionOverride != nil { return a.runPromptWithSessionOverride(ctx, cfg, session, request, onDelta) } + wreqClientNewTotalMetric.Add("standard", 1) client := newNotionAIClient(session, cfg, accountEmail) if onDelta != nil { + wreqClientNewTotalMetric.Add("streaming", 1) client = newNotionAIStreamingClient(session, cfg, accountEmail) } execute := func(ctx context.Context, current PromptRunRequest, forward func(string) error) (InferenceResult, error) { @@ -261,8 +275,10 @@ func (a *App) runPromptWithSessionWithSink(ctx context.Context, cfg AppConfig, s if a.runPromptWithSessionOverride != nil { return a.runPromptWithSessionOverride(ctx, cfg, session, request, sink.Text) } + wreqClientNewTotalMetric.Add("streaming", 1) client := newNotionAIStreamingClient(session, cfg, accountEmail) if sink.Text == nil && sink.Reasoning == nil && sink.ReasoningWarmup == nil && sink.KeepAlive == nil { + wreqClientNewTotalMetric.Add("standard", 1) client = newNotionAIClient(session, cfg, accountEmail) } if sink.Reasoning != nil || sink.ReasoningWarmup != nil || sink.KeepAlive != nil { diff --git a/internal/app/accounts.go b/internal/app/accounts.go index a1c7331..021a2a6 100644 --- a/internal/app/accounts.go +++ b/internal/app/accounts.go @@ -10,6 +10,11 @@ import ( "strings" ) +var ( + accountPathSlugPattern = regexp.MustCompile(`[^a-z0-9]+`) + windowsAbsolutePathPattern = regexp.MustCompile(`^[A-Za-z]:[\\/].*`) +) + type ResolvedLoginHelper struct { SessionsDir string `json:"sessions_dir"` TimeoutSec int `json:"timeout_sec"` @@ -50,13 +55,19 @@ func canonicalEmailKey(email string) 
string { return strings.ToLower(strings.TrimSpace(email)) } +func getAccountEmailKey(account NotionAccount) string { + if account.emailKey != "" { + return account.emailKey + } + return canonicalEmailKey(account.Email) +} + func accountPathSlug(email string) string { clean := canonicalEmailKey(email) if clean == "" { return "account" } - re := regexp.MustCompile(`[^a-z0-9]+`) - clean = re.ReplaceAllString(clean, "_") + clean = accountPathSlugPattern.ReplaceAllString(clean, "_") clean = strings.Trim(clean, "_") if clean == "" { return "account" @@ -99,7 +110,7 @@ func pathLooksAbsoluteAnyOS(value string) bool { if filepath.IsAbs(clean) { return true } - if matched, _ := regexp.MatchString(`^[A-Za-z]:[\\/].*`, clean); matched { + if windowsAbsolutePathPattern.MatchString(clean) { return true } if strings.HasPrefix(clean, `\\`) { @@ -119,7 +130,7 @@ func isForeignAbsolutePath(value string) bool { if runtime.GOOS == "windows" { return strings.HasPrefix(clean, "/") } - if matched, _ := regexp.MatchString(`^[A-Za-z]:[\\/].*`, clean); matched { + if windowsAbsolutePathPattern.MatchString(clean) { return true } if strings.HasPrefix(clean, `\\`) { @@ -150,7 +161,7 @@ func (cfg AppConfig) FindAccount(email string) (NotionAccount, int, bool) { return NotionAccount{}, -1, false } for i, account := range cfg.Accounts { - if canonicalEmailKey(account.Email) == target { + if getAccountEmailKey(account) == target { return account, i, true } } @@ -215,6 +226,7 @@ func (helper ResolvedLoginHelper) ProbePath(profileDir string) string { } func ensureAccountPaths(cfg AppConfig, account NotionAccount) NotionAccount { + account.emailKey = canonicalEmailKey(account.Email) helper := cfg.ResolveLoginHelper() if strings.TrimSpace(account.ProfileDir) == "" || isForeignAbsolutePath(account.ProfileDir) { account.ProfileDir = helper.ProfileDirFor(account.Email) @@ -328,12 +340,13 @@ func (cfg *AppConfig) UpsertAccount(account NotionAccount) (NotionAccount, int) } func (cfg *AppConfig) 
DeleteAccount(email string) bool { - _, index, ok := cfg.FindAccount(email) + target := canonicalEmailKey(email) + _, index, ok := cfg.FindAccount(target) if !ok { return false } cfg.Accounts = append(cfg.Accounts[:index], cfg.Accounts[index+1:]...) - if canonicalEmailKey(cfg.ActiveAccount) == canonicalEmailKey(email) { + if canonicalEmailKey(cfg.ActiveAccount) == target { cfg.ActiveAccount = "" cfg.ProbeJSON = "" } diff --git a/internal/app/admin.go b/internal/app/admin.go index 6888344..19f8ea0 100644 --- a/internal/app/admin.go +++ b/internal/app/admin.go @@ -426,9 +426,9 @@ func (a *App) handleAdminLogin(w http.ResponseWriter, r *http.Request) { }) return } - payload, err := decodeBody(r) + payload, err := a.decodeBody(w, r) if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + writeInvalidBodyError(w, err) return } if !securePasswordEqual(password, stringValue(payload["password"])) { @@ -671,9 +671,9 @@ func (a *App) handleAdminTest(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"detail": "method not allowed"}) return } - payload, err := decodeBody(r) + payload, err := a.decodeBody(w, r) if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + writeInvalidBodyError(w, err) return } cfg, _, registry := a.State.Snapshot() diff --git a/internal/app/admin_accounts.go b/internal/app/admin_accounts.go index 279e811..32a9127 100644 --- a/internal/app/admin_accounts.go +++ b/internal/app/admin_accounts.go @@ -87,7 +87,7 @@ func (a *App) accountRuntimeSummary(cfg AppConfig, account NotionAccount) map[st "consecutive_failures": account.ConsecutiveFailures, "total_successes": account.TotalSuccesses, "total_failures": account.TotalFailures, - "active": canonicalEmailKey(cfg.ActiveAccount) == canonicalEmailKey(account.Email), + "active": canonicalEmailKey(cfg.ActiveAccount) == getAccountEmailKey(account), } if status, err := 
readLoginStatusFile(account.PendingStatePath); err == nil { item["login_status"] = status @@ -262,9 +262,9 @@ func (a *App) handleAdminAccounts(w http.ResponseWriter, r *http.Request) { case http.MethodGet: writeJSON(w, http.StatusOK, a.buildAccountsPayload()) case http.MethodPost: - payload, err := decodeBody(r) + payload, err := a.decodeBody(w, r) if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + writeInvalidBodyError(w, err) return } account, makeActive, err := decodeAccountPayload(payload) @@ -286,11 +286,12 @@ func (a *App) handleAdminAccounts(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) return } + a.invalidateDispatchProbeCache() writeJSON(w, http.StatusOK, a.buildAccountsPayload()) case http.MethodPut: - payload, err := decodeBody(r) + payload, err := a.decodeBody(w, r) if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + writeInvalidBodyError(w, err) return } email := accountEmailFromPayload(payload) @@ -314,7 +315,7 @@ func (a *App) handleAdminAccounts(w http.ResponseWriter, r *http.Request) { return } cfg.Accounts[index] = ensureAccountPaths(cfg, next) - if canonicalEmailKey(cfg.ActiveAccount) == canonicalEmailKey(next.Email) && next.Disabled { + if canonicalEmailKey(cfg.ActiveAccount) == getAccountEmailKey(next) && next.Disabled { cfg.ActiveAccount = "" cfg.ProbeJSON = "" } @@ -330,6 +331,7 @@ func (a *App) handleAdminAccounts(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) return } + a.invalidateDispatchProbeCache() writeJSON(w, http.StatusOK, a.buildAccountsPayload()) default: writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"detail": "method not allowed"}) @@ -362,6 +364,7 @@ func (a *App) handleAdminAccountDelete(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusBadRequest, map[string]any{"detail": 
err.Error()}) return } + a.invalidateDispatchProbeCache() writeJSON(w, http.StatusOK, a.buildAccountsPayload()) } @@ -373,9 +376,9 @@ func (a *App) handleAdminAccountsActivate(w http.ResponseWriter, r *http.Request writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"detail": "method not allowed"}) return } - payload, err := decodeBody(r) + payload, err := a.decodeBody(w, r) if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + writeInvalidBodyError(w, err) return } email := strings.TrimSpace(stringValue(payload["email"])) @@ -400,6 +403,7 @@ func (a *App) handleAdminAccountsActivate(w http.ResponseWriter, r *http.Request writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) return } + a.invalidateDispatchProbeCache() writeJSON(w, http.StatusOK, a.buildAccountsPayload()) } @@ -411,9 +415,9 @@ func (a *App) handleAdminAccountsTest(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"detail": "method not allowed"}) return } - payload, err := decodeBody(r) + payload, err := a.decodeBody(w, r) if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + writeInvalidBodyError(w, err) return } cfg, _, registry := a.State.Snapshot() @@ -676,9 +680,9 @@ func (a *App) handleAdminAccountManualImport(w http.ResponseWriter, r *http.Requ writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"detail": "method not allowed"}) return } - payload, err := decodeBody(r) + payload, err := a.decodeBody(w, r) if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + writeInvalidBodyError(w, err) return } req, err := decodeManualImportRequest(payload) @@ -750,6 +754,7 @@ func (a *App) handleAdminAccountManualImport(w http.ResponseWriter, r *http.Requ writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) return } + a.invalidateDispatchProbeCache() cfg, _, _ = a.State.Snapshot() 
account, _, _ = cfg.FindAccount(accountEmail) writeJSON(w, http.StatusOK, map[string]any{ @@ -767,9 +772,9 @@ func (a *App) handleAdminAccountLoginStart(w http.ResponseWriter, r *http.Reques writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"detail": "method not allowed"}) return } - payload, err := decodeBody(r) + payload, err := a.decodeBody(w, r) if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + writeInvalidBodyError(w, err) return } email := strings.TrimSpace(stringValue(payload["email"])) @@ -819,6 +824,7 @@ func (a *App) handleAdminAccountLoginStart(w http.ResponseWriter, r *http.Reques writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) return } + a.invalidateDispatchProbeCache() writeJSON(w, http.StatusOK, map[string]any{ "success": true, "account": a.accountRuntimeSummary(cfg, account), @@ -834,9 +840,9 @@ func (a *App) handleAdminAccountLoginVerify(w http.ResponseWriter, r *http.Reque writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"detail": "method not allowed"}) return } - payload, err := decodeBody(r) + payload, err := a.decodeBody(w, r) if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + writeInvalidBodyError(w, err) return } email := strings.TrimSpace(stringValue(payload["email"])) @@ -893,6 +899,7 @@ func (a *App) handleAdminAccountLoginVerify(w http.ResponseWriter, r *http.Reque writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) return } + a.invalidateDispatchProbeCache() writeJSON(w, http.StatusOK, map[string]any{ "success": true, "account": a.accountRuntimeSummary(cfg, account), diff --git a/internal/app/admin_conversations.go b/internal/app/admin_conversations.go index fe1721c..da94b3f 100644 --- a/internal/app/admin_conversations.go +++ b/internal/app/admin_conversations.go @@ -334,7 +334,7 @@ func (a *App) handleAdminEvents(w http.ResponseWriter, r *http.Request) { 
w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Connection", "keep-alive") w.Header().Set("X-Accel-Buffering", "no") - w.Header().Set("Access-Control-Allow-Origin", "*") + applyCORSHeaders(w) w.WriteHeader(http.StatusOK) subID, events := a.State.conversations().Subscribe() diff --git a/internal/app/assets/browser-helper.cjs b/internal/app/assets/browser-helper.cjs new file mode 100644 index 0000000..dc0a5f9 --- /dev/null +++ b/internal/app/assets/browser-helper.cjs @@ -0,0 +1,229 @@ +const fs = require('fs'); +const { fetch } = require('node-wreq'); + +const poolMode = String(process.env.NOTION2API_BROWSER_HELPER_MODE || '').trim().toLowerCase() === 'pool' + && String(process.env.NOTION2API_BROWSER_HELPER_PROTOCOL || '').trim() === 'N2A_HELPER_POOL_V1'; + +function formatError(error) { + return error && error.stack ? error.stack : String(error); +} + +function buildCookieJar(items) { + const cookieMap = new Map(); + for (const item of items || []) { + const name = String((item && item.name) || '').trim(); + if (!name) continue; + cookieMap.set(name, String((item && item.value) || '')); + } + return { + getCookies() { + return [...cookieMap.entries()].map(([name, value]) => ({ name, value })); + }, + setCookie(cookie) { + const text = String(cookie || ''); + const semi = text.indexOf(';'); + const pair = semi === -1 ? text : text.slice(0, semi); + const eq = pair.indexOf('='); + if (eq <= 0) return; + const name = pair.slice(0, eq).trim(); + const value = pair.slice(eq + 1).trim(); + if (name) cookieMap.set(name, value); + }, + }; +} + +function buildHeaders(rawHeaders) { + const headers = {}; + for (const [key, value] of Object.entries(rawHeaders || {})) { + if (key === undefined || key === null) continue; + if (String(key).toLowerCase() === 'cookie') continue; + headers[String(key)] = String(value == null ? 
'' : value); + } + return headers; +} + +function markLineState(line, state) { + if (!line || !state) return; + try { + const parsed = JSON.parse(line); + if (String(parsed.type || '').toLowerCase() !== 'agent-inference' || !Array.isArray(parsed.value)) return; + const hasVisibleText = parsed.value.some((entry) => { + const t = String((entry && entry.type) || '').toLowerCase(); + const c = String((entry && entry.content) || ''); + return t === 'text' && c.trim() !== ''; + }); + if (!hasVisibleText) return; + state.sawAnswer = true; + if (parsed.finishedAt != null) state.sawTerminal = true; + } catch (_) {} +} + +async function runSingleRequest(input) { + const cookieJar = buildCookieJar(input.cookies || []); + const headers = buildHeaders(input.headers || {}); + const fetchOptions = { + method: 'POST', + browser: input.browser_profile || 'chrome_142', + headers, + body: JSON.stringify(input.payload || {}), + cookieJar, + timeout: Math.max(Number(input.request_timeout_ms || 0), 30000), + throwHttpErrors: false, + }; + const proxy = String(input.proxy || '').trim(); + if (proxy) fetchOptions.proxy = proxy; + + const result = { status: 0, content_type: '', text: '' }; + const response = await fetch(input.run_url, fetchOptions); + result.status = response.status; + result.content_type = response.headers.get('content-type') || ''; + const isNDJSON = String(result.content_type).toLowerCase().includes('application/x-ndjson'); + if (!isNDJSON) { + result.text = await response.text(); + return result; + } + + const idleAfterAnswerMs = Math.max(Number(input.idle_after_answer_ms || 0), 0); + const readable = response.wreq && typeof response.wreq.readable === 'function' + ? 
response.wreq.readable() + : null; + if (!readable) { + result.text = await response.text(); + return result; + } + + let pending = ''; + const state = { sawAnswer: false, sawTerminal: false }; + let settled = false; + let idleTimer = null; + + await new Promise((resolve, reject) => { + const settle = () => { + if (settled) return; + settled = true; + if (idleTimer) { + clearTimeout(idleTimer); + idleTimer = null; + } + const remaining = pending.trim(); + if (remaining) markLineState(remaining, state); + try { readable.destroy(); } catch (_) {} + resolve(); + }; + const armIdle = () => { + if (idleTimer) { + clearTimeout(idleTimer); + idleTimer = null; + } + if (state.sawAnswer && idleAfterAnswerMs > 0) { + idleTimer = setTimeout(settle, idleAfterAnswerMs); + } + }; + readable.on('data', (chunk) => { + const text = Buffer.isBuffer(chunk) ? chunk.toString('utf8') : String(chunk); + result.text += text; + pending += text; + while (true) { + const newlineIndex = pending.indexOf('\n'); + if (newlineIndex === -1) break; + const line = pending.slice(0, newlineIndex).trim(); + pending = pending.slice(newlineIndex + 1); + markLineState(line, state); + if (state.sawTerminal) { + settle(); + return; + } + } + armIdle(); + }); + readable.on('end', settle); + readable.on('close', settle); + readable.on('error', (err) => { + if (settled) return; + settled = true; + if (idleTimer) clearTimeout(idleTimer); + reject(err); + }); + }); + + return result; +} + +function writeFrame(payloadBuffer) { + const header = Buffer.allocUnsafe(4); + header.writeUInt32LE(payloadBuffer.length, 0); + process.stdout.write(header); + process.stdout.write(payloadBuffer); +} + +function runPoolLoop() { + let pending = Buffer.alloc(0); + const queue = []; + let draining = false; + + const drainQueue = async () => { + if (draining) return; + draining = true; + while (queue.length > 0) { + const payload = queue.shift(); + let input; + try { + input = JSON.parse(payload.toString('utf8')); + } catch (err) 
{ + process.stderr.write(formatError(err) + '\n'); + process.exit(2); + return; + } + let result; + try { + result = await runSingleRequest(input); + } catch (err) { + process.stderr.write(formatError(err) + '\n'); + process.exit(2); + return; + } + const body = Buffer.from(JSON.stringify(result)); + writeFrame(body); + } + draining = false; + }; + + process.stdin.on('data', (chunk) => { + const incoming = Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk); + pending = Buffer.concat([pending, incoming]); + while (pending.length >= 4) { + const bodyLen = pending.readUInt32LE(0); + if (pending.length < 4 + bodyLen) { + break; + } + const body = pending.subarray(4, 4 + bodyLen); + pending = pending.subarray(4 + bodyLen); + queue.push(body); + } + void drainQueue(); + }); + + process.stdin.on('end', () => { + if (pending.length !== 0) { + process.stderr.write('incomplete pool frame at stdin end\n'); + process.exit(2); + return; + } + if (!draining && queue.length === 0) { + process.exit(0); + } + }); +} + +if (poolMode) { + runPoolLoop(); +} else { + (async () => { + const input = JSON.parse(fs.readFileSync(0, 'utf8')); + const result = await runSingleRequest(input); + process.stdout.write(JSON.stringify(result)); + })().catch((error) => { + process.stderr.write(formatError(error) + '\n'); + process.exit(2); + }); +} diff --git a/internal/app/assets/browser-login-helper.cjs b/internal/app/assets/browser-login-helper.cjs new file mode 100644 index 0000000..4b63372 --- /dev/null +++ b/internal/app/assets/browser-login-helper.cjs @@ -0,0 +1,80 @@ +const fs = require('fs'); +const { fetch } = require('node-wreq'); + +(async () => { + const input = JSON.parse(fs.readFileSync(0, 'utf8')); + + const cookieMap = new Map(); + for (const item of input.cookies || []) { + const name = String((item && (item.name || item.Name)) || '').trim(); + if (!name) continue; + const rawValue = item && (item.value !== undefined ? 
item.value : item.Value); + cookieMap.set(name, String(rawValue == null ? '' : rawValue)); + } + const setCookieRecord = new Map(); + const cookieJar = { + getCookies() { + return [...cookieMap.entries()].map(([name, value]) => ({ name, value })); + }, + setCookie(cookie) { + const text = String(cookie || ''); + const semi = text.indexOf(';'); + const pair = semi === -1 ? text : text.slice(0, semi); + const eq = pair.indexOf('='); + if (eq <= 0) return; + const name = pair.slice(0, eq).trim(); + const value = pair.slice(eq + 1).trim(); + if (!name) return; + cookieMap.set(name, value); + setCookieRecord.set(name, value); + }, + }; + + const headers = {}; + for (const [key, value] of Object.entries(input.headers || {})) { + if (key === undefined || key === null) continue; + if (String(key).toLowerCase() === 'cookie') continue; + headers[String(key)] = String(value == null ? '' : value); + } + + const method = String(input.method || 'GET').toUpperCase(); + const fetchOptions = { + method, + browser: input.browser_profile || 'chrome_142', + headers, + cookieJar, + timeout: Math.max(Number(input.request_timeout_ms || 0), 30000), + throwHttpErrors: false, + }; + if (typeof input.body === 'string' && input.body.length > 0) { + fetchOptions.body = input.body; + } + const proxy = String(input.proxy || '').trim(); + if (proxy) fetchOptions.proxy = proxy; + + const result = { status: 0, content_type: '', headers: {}, body: '', set_cookies: [] }; + let response; + try { + response = await fetch(String(input.url || ''), fetchOptions); + } catch (err) { + process.stderr.write((err && err.stack ? 
err.stack : String(err)) + '\n'); + process.exit(2); + return; + } + + result.status = response.status; + if (response.headers && typeof response.headers.forEach === 'function') { + response.headers.forEach((value, key) => { + const lk = String(key).toLowerCase(); + if (lk === 'set-cookie') return; + result.headers[lk] = String(value); + }); + } + result.content_type = result.headers['content-type'] || ''; + result.body = await response.text(); + result.set_cookies = [...setCookieRecord.entries()].map(([name, value]) => ({ Name: name, Value: value })); + process.stdout.write(JSON.stringify(result)); +})().catch((error) => { + process.stderr.write((error && error.stack ? error.stack : String(error)) + '\n'); + process.exit(1); +}); diff --git a/internal/app/config.go b/internal/app/config.go index d54b314..202715f 100644 --- a/internal/app/config.go +++ b/internal/app/config.go @@ -47,6 +47,19 @@ type SessionRefreshConfig struct { AutoSwitch bool `json:"auto_switch_account"` } +type DispatchConfig struct { + ProbeCacheTTLSeconds int `json:"probe_cache_ttl_seconds,omitempty"` +} + +type BrowserConfig struct { + HelperPoolSize int `json:"helper_pool_size,omitempty"` +} + +type DebugConfig struct { + PprofEnabled bool `json:"pprof_enabled"` + PprofAddr string `json:"pprof_addr,omitempty"` +} + type StorageConfig struct { SQLitePath string `json:"sqlite_path,omitempty"` PersistConversations bool `json:"persist_conversations"` @@ -56,6 +69,10 @@ type StorageConfig struct { PersistSillyTavernBindings *bool `json:"persist_sillytavern_bindings,omitempty"` } +type LimitsConfig struct { + MaxRequestBodyBytes int64 `json:"max_request_body_bytes,omitempty"` +} + type PromptConfig struct { Profile string `json:"profile,omitempty"` CustomPrefix string `json:"custom_prefix,omitempty"` @@ -67,10 +84,12 @@ type PromptConfig struct { CodingRetryPrefixes []string `json:"coding_retry_prefixes,omitempty"` GeneralRetryPrefixes []string `json:"general_retry_prefixes,omitempty"` 
DirectAnswerRetryPrefixes []string `json:"direct_answer_retry_prefixes,omitempty"` + precomputedAllRetryPrefixes []string `json:"-"` } type NotionAccount struct { Email string `json:"email"` + emailKey string `json:"-"` ProbeJSON string `json:"probe_json,omitempty"` ProfileDir string `json:"profile_dir,omitempty"` StorageStatePath string `json:"storage_state_path,omitempty"` @@ -153,10 +172,14 @@ type AppConfig struct { Admin AdminConfig `json:"admin"` Responses ResponsesConfig `json:"responses"` Storage StorageConfig `json:"storage"` + Limits LimitsConfig `json:"limits,omitempty"` Prompt PromptConfig `json:"prompt"` Features FeatureConfig `json:"features"` LoginHelper LoginHelperConfig `json:"login_helper"` SessionRefresh SessionRefreshConfig `json:"session_refresh"` + Dispatch DispatchConfig `json:"dispatch"` + Browser BrowserConfig `json:"browser,omitempty"` + Debug DebugConfig `json:"debug"` Accounts []NotionAccount `json:"accounts,omitempty"` Models []ModelDefinition `json:"models,omitempty"` ModelAliases map[string]string `json:"model_aliases,omitempty"` @@ -418,6 +441,9 @@ func defaultConfig() AppConfig { Storage: StorageConfig{ PersistConversations: true, }, + Limits: LimitsConfig{ + MaxRequestBodyBytes: 4 * 1024 * 1024, + }, Prompt: PromptConfig{ Profile: "cognitive_reframing", FallbackProfiles: []string{"toolbox_capability_expansion"}, @@ -440,6 +466,13 @@ func defaultConfig() AppConfig { RetryOnAuthError: true, AutoSwitch: true, }, + Dispatch: DispatchConfig{ + ProbeCacheTTLSeconds: 45, + }, + Debug: DebugConfig{ + PprofEnabled: false, + PprofAddr: "127.0.0.1:6060", + }, Features: FeatureConfig{ UseWebSearch: true, UseReadOnlyMode: false, @@ -500,6 +533,10 @@ func normalizeConfig(cfg AppConfig) AppConfig { if cfg.PollMaxRounds <= 0 { cfg.PollMaxRounds = 40 } + cfg.Debug.PprofAddr = strings.TrimSpace(cfg.Debug.PprofAddr) + if cfg.Debug.PprofAddr == "" { + cfg.Debug.PprofAddr = "127.0.0.1:6060" + } if cfg.StreamChunkRunes <= 0 { cfg.StreamChunkRunes = 24 } 
@@ -512,6 +549,9 @@ func normalizeConfig(cfg AppConfig) AppConfig { if cfg.Responses.StoreTTLSeconds <= 0 { cfg.Responses.StoreTTLSeconds = 3600 } + if cfg.Limits.MaxRequestBodyBytes <= 0 { + cfg.Limits.MaxRequestBodyBytes = 4 * 1024 * 1024 + } cfg.Prompt.Profile = strings.TrimSpace(cfg.Prompt.Profile) if cfg.Prompt.Profile == "" { cfg.Prompt.Profile = "cognitive_reframing" @@ -535,6 +575,7 @@ func normalizeConfig(cfg AppConfig) AppConfig { cfg.Prompt.CodingRetryPrefixes = normalizePromptTextList(cfg.Prompt.CodingRetryPrefixes) cfg.Prompt.GeneralRetryPrefixes = normalizePromptTextList(cfg.Prompt.GeneralRetryPrefixes) cfg.Prompt.DirectAnswerRetryPrefixes = normalizePromptTextList(cfg.Prompt.DirectAnswerRetryPrefixes) + cfg.Prompt.precomputedAllRetryPrefixes = buildPromptGuardAllRetryPrefixes(cfg.Prompt) cfg.Storage.SQLitePath = strings.TrimSpace(cfg.Storage.SQLitePath) if cfg.Storage.SQLitePath == "" && strings.TrimSpace(cfg.ConfigPath) != "" { cfg.Storage.SQLitePath = "data/notion2api.sqlite" @@ -548,6 +589,15 @@ func normalizeConfig(cfg AppConfig) AppConfig { if cfg.SessionRefresh.IntervalSec <= 0 { cfg.SessionRefresh.IntervalSec = 900 } + if cfg.Dispatch.ProbeCacheTTLSeconds < 0 { + cfg.Dispatch.ProbeCacheTTLSeconds = 0 + } + if cfg.Browser.HelperPoolSize < 0 { + cfg.Browser.HelperPoolSize = 0 + } + if cfg.Browser.HelperPoolSize > 8 { + cfg.Browser.HelperPoolSize = 8 + } cfg.Features.SearchScopes = normalizeStringList(cfg.Features.SearchScopes) cfg.Features.AISurface = strings.TrimSpace(cfg.Features.AISurface) if cfg.Features.AISurface == "" { @@ -566,6 +616,7 @@ func normalizeConfig(cfg AppConfig) AppConfig { cfg.ActiveAccount = strings.TrimSpace(cfg.ActiveAccount) for i := range cfg.Accounts { cfg.Accounts[i].Email = strings.TrimSpace(cfg.Accounts[i].Email) + cfg.Accounts[i].emailKey = canonicalEmailKey(cfg.Accounts[i].Email) cfg.Accounts[i].ProbeJSON = strings.TrimSpace(cfg.Accounts[i].ProbeJSON) cfg.Accounts[i].ProfileDir = 
strings.TrimSpace(cfg.Accounts[i].ProfileDir) cfg.Accounts[i].StorageStatePath = strings.TrimSpace(cfg.Accounts[i].StorageStatePath) @@ -797,6 +848,9 @@ func parseCLI() AppConfig { timeoutSec := flag.Int("timeout-sec", 0, "request timeout sec") pollIntervalSec := flag.Float64("poll-interval-sec", 0, "poll interval sec") pollMaxRounds := flag.Int("poll-max-rounds", 0, "poll max rounds") + pprofEnabled := flag.Bool("pprof-enabled", false, "enable pprof debug server") + pprofAddr := flag.String("pprof-addr", "", "pprof listen address") + maxRequestBodyBytes := flag.Int64("max-request-body-bytes", 0, "max request body size in bytes for JSON API endpoints") userName := flag.String("user-name", "", "override user name") spaceName := flag.String("space-name", "", "override space name") flag.Parse() @@ -870,6 +924,15 @@ func parseCLI() AppConfig { if *pollMaxRounds > 0 { cfg.PollMaxRounds = *pollMaxRounds } + if *pprofEnabled { + cfg.Debug.PprofEnabled = true + } + if strings.TrimSpace(*pprofAddr) != "" { + cfg.Debug.PprofAddr = strings.TrimSpace(*pprofAddr) + } + if *maxRequestBodyBytes > 0 { + cfg.Limits.MaxRequestBodyBytes = *maxRequestBodyBytes + } if strings.TrimSpace(*userName) != "" { cfg.UserName = *userName } diff --git a/internal/app/conversations.go b/internal/app/conversations.go index affb6d7..767a529 100644 --- a/internal/app/conversations.go +++ b/internal/app/conversations.go @@ -58,6 +58,7 @@ type ConversationEntry struct { InputAttachments []ConversationAttachment `json:"input_attachments,omitempty"` OutputAttachments []UploadedAttachment `json:"output_attachments,omitempty"` Messages []ConversationMessage `json:"messages,omitempty"` + cachedPreview string `json:"-"` } type ConversationSummary struct { @@ -133,6 +134,7 @@ func newConversationStoreFromEntries(entries []ConversationEntry) *ConversationS store := newConversationStore() for _, entry := range entries { cloned := cloneConversationEntry(&entry) + refreshConversationDerivedFields(&cloned) 
store.items[cloned.ID] = &cloned store.order = append(store.order, cloned.ID) } @@ -285,17 +287,17 @@ func cloneConversationEntry(entry *ConversationEntry) ConversationEntry { return out } +func copyConversationEntryValue(entry *ConversationEntry) ConversationEntry { + if entry == nil { + return ConversationEntry{} + } + return *entry +} + func buildConversationSummary(entry *ConversationEntry) ConversationSummary { - preview := "" - for i := len(entry.Messages) - 1; i >= 0; i-- { - text := collapseWhitespace(entry.Messages[i].Content) - if text == "" && len(entry.Messages[i].Attachments) > 0 { - text = fmt.Sprintf("%d attachments", len(entry.Messages[i].Attachments)) - } - if text != "" { - preview = truncateRunes(text, 96) - break - } + preview := entry.cachedPreview + if preview == "" && len(entry.Messages) > 0 { + preview = conversationPreviewFromMessages(entry.Messages) } return ConversationSummary{ ID: entry.ID, @@ -327,6 +329,26 @@ func buildConversationSummary(entry *ConversationEntry) ConversationSummary { } } +func conversationPreviewFromMessages(messages []ConversationMessage) string { + for i := len(messages) - 1; i >= 0; i-- { + text := collapseWhitespace(messages[i].Content) + if text == "" && len(messages[i].Attachments) > 0 { + text = fmt.Sprintf("%d attachments", len(messages[i].Attachments)) + } + if text != "" { + return truncateRunes(text, 96) + } + } + return "" +} + +func refreshConversationDerivedFields(entry *ConversationEntry) { + if entry == nil { + return + } + entry.cachedPreview = conversationPreviewFromMessages(entry.Messages) +} + func conversationMessageSegments(entry *ConversationEntry) []conversationPromptSegment { if entry == nil || len(entry.Messages) == 0 { return nil @@ -410,7 +432,7 @@ func (s *ConversationStore) Create(req ConversationCreateRequest) ConversationEn if id == "" { id = "conv_" + strings.ReplaceAll(randomUUID(), "-", "") } - entry := &ConversationEntry{ + entry := ConversationEntry{ ID: id, Title: 
conversationTitle(req.Prompt, req.InputAttachments), Origin: "local", @@ -440,17 +462,19 @@ func (s *ConversationStore) Create(req ConversationCreateRequest) ConversationEn Attachments: cloneConversationAttachments(entry.InputAttachments), }) } + refreshConversationDerivedFields(&entry) s.mu.Lock() if s.items[id] != nil { id = "conv_" + strings.ReplaceAll(randomUUID(), "-", "") entry.ID = id } - s.items[id] = entry + entryPtr := &entry + s.items[id] = entryPtr s.order = append([]string{id}, s.order...) s.trimLocked() - cloned := cloneConversationEntry(entry) - summary := buildConversationSummary(entry) + cloned := copyConversationEntryValue(entryPtr) + summary := buildConversationSummary(entryPtr) s.mu.Unlock() s.broadcast(ConversationEvent{ @@ -458,7 +482,7 @@ func (s *ConversationStore) Create(req ConversationCreateRequest) ConversationEn ConversationID: id, At: now, Summary: &summary, - Conversation: &cloned, + Conversation: entryPtr, }) return cloned } @@ -469,39 +493,41 @@ func (s *ConversationStore) Continue(conversationID string, req ConversationCrea cloned ConversationEntry summary ConversationSummary ok bool + entry *ConversationEntry ) s.mu.Lock() - entry := s.items[conversationID] - if entry != nil { - entry.Source = firstNonEmpty(req.Source, entry.Source) - entry.Transport = firstNonEmpty(req.Transport, entry.Transport) + current := s.items[conversationID] + if current != nil { + next := cloneConversationEntry(current) + next.Source = firstNonEmpty(req.Source, next.Source) + next.Transport = firstNonEmpty(req.Transport, next.Transport) if req.Ephemeral { - entry.Ephemeral = true - entry.EphemeralReason = firstNonEmpty(strings.TrimSpace(req.EphemeralReason), entry.EphemeralReason) + next.Ephemeral = true + next.EphemeralReason = firstNonEmpty(strings.TrimSpace(req.EphemeralReason), next.EphemeralReason) if !req.AutoDeleteAt.IsZero() { - entry.AutoDeleteAt = timePointer(req.AutoDeleteAt) + next.AutoDeleteAt = timePointer(req.AutoDeleteAt) } } if clean := 
strings.TrimSpace(req.Model); clean != "" { - entry.Model = clean + next.Model = clean } if clean := strings.TrimSpace(req.NotionModel); clean != "" { - entry.NotionModel = clean + next.NotionModel = clean } - entry.UseWebSearch = req.UseWebSearch - entry.Status = "running" - entry.Error = "" - entry.InputAttachments = cloneConversationAttachments(req.InputAttachments) - entry.UpdatedAt = now - if len(entry.Messages) > 0 { - last := &entry.Messages[len(entry.Messages)-1] + next.UseWebSearch = req.UseWebSearch + next.Status = "running" + next.Error = "" + next.InputAttachments = cloneConversationAttachments(req.InputAttachments) + next.UpdatedAt = now + if len(next.Messages) > 0 { + last := &next.Messages[len(next.Messages)-1] if last.Role == "assistant" && last.Status != "completed" { last.Status = "failed" last.UpdatedAt = now } } if strings.TrimSpace(req.Prompt) != "" || len(req.InputAttachments) > 0 { - entry.Messages = append(entry.Messages, ConversationMessage{ + next.Messages = append(next.Messages, ConversationMessage{ ID: "msg_user_" + strings.ReplaceAll(randomUUID(), "-", ""), Role: "user", Status: "completed", @@ -511,8 +537,11 @@ func (s *ConversationStore) Continue(conversationID string, req ConversationCrea Attachments: cloneConversationAttachments(req.InputAttachments), }) } + refreshConversationDerivedFields(&next) + entry = &next + s.items[conversationID] = entry s.moveToFrontLocked(conversationID) - cloned = cloneConversationEntry(entry) + cloned = copyConversationEntryValue(entry) summary = buildConversationSummary(entry) ok = true } @@ -525,7 +554,7 @@ func (s *ConversationStore) Continue(conversationID string, req ConversationCrea ConversationID: conversationID, At: now, Summary: &summary, - Conversation: &cloned, + Conversation: entry, }) return cloned, nil } @@ -553,17 +582,22 @@ func (s *ConversationStore) SetEnvelopeIDs(conversationID string, responseID str var ( summary ConversationSummary ok bool + entry *ConversationEntry ) s.mu.Lock() - 
entry := s.items[conversationID] - if entry != nil { + current := s.items[conversationID] + if current != nil { + next := cloneConversationEntry(current) if strings.TrimSpace(responseID) != "" { - entry.ResponseID = strings.TrimSpace(responseID) + next.ResponseID = strings.TrimSpace(responseID) } if strings.TrimSpace(completionID) != "" { - entry.CompletionID = strings.TrimSpace(completionID) + next.CompletionID = strings.TrimSpace(completionID) } - entry.UpdatedAt = now + next.UpdatedAt = now + refreshConversationDerivedFields(&next) + entry = &next + s.items[conversationID] = entry s.moveToFrontLocked(conversationID) summary = buildConversationSummary(entry) ok = true @@ -575,6 +609,7 @@ func (s *ConversationStore) SetEnvelopeIDs(conversationID string, responseID str ConversationID: conversationID, At: now, Summary: &summary, + Conversation: entry, }) } } @@ -587,21 +622,26 @@ func (s *ConversationStore) AppendAssistantDelta(conversationID string, delta st now := time.Now().UTC() var ( summary ConversationSummary - msg ConversationMessage + msg *ConversationMessage ok bool + entry *ConversationEntry ) s.mu.Lock() - entry := s.items[conversationID] - if entry != nil { - assistant := s.ensureAssistantMessageLocked(entry, now) + current := s.items[conversationID] + if current != nil { + next := cloneConversationEntry(current) + assistant := s.ensureAssistantMessageLocked(&next, now) assistant.Content += delta assistant.Status = "streaming" assistant.UpdatedAt = now - entry.Status = "running" - entry.UpdatedAt = now + next.Status = "running" + next.UpdatedAt = now + refreshConversationDerivedFields(&next) + entry = &next + s.items[conversationID] = entry s.moveToFrontLocked(conversationID) summary = buildConversationSummary(entry) - msg = cloneConversationMessage(*assistant) + msg = assistant ok = true } s.mu.Unlock() @@ -612,7 +652,8 @@ func (s *ConversationStore) AppendAssistantDelta(conversationID string, delta st At: now, Delta: delta, Summary: &summary, - 
Message: &msg, + Conversation: entry, + Message: msg, }) } } @@ -620,46 +661,49 @@ func (s *ConversationStore) AppendAssistantDelta(conversationID string, delta st func (s *ConversationStore) Complete(conversationID string, result InferenceResult) { now := time.Now().UTC() var ( - cloned ConversationEntry summary ConversationSummary ok bool + entry *ConversationEntry ) s.mu.Lock() - entry := s.items[conversationID] - if entry != nil { - entry.Status = "completed" - entry.UpdatedAt = now - if entry.Ephemeral { - entry.AutoDeleteAt = timePointer(now.Add(sillyTavernQuietConversationTTL)) + current := s.items[conversationID] + if current != nil { + next := cloneConversationEntry(current) + next.Status = "completed" + next.UpdatedAt = now + if next.Ephemeral { + next.AutoDeleteAt = timePointer(now.Add(sillyTavernQuietConversationTTL)) } - entry.ThreadID = strings.TrimSpace(result.ThreadID) - entry.TraceID = strings.TrimSpace(result.TraceID) - entry.MessageID = strings.TrimSpace(result.MessageID) - entry.AccountEmail = strings.TrimSpace(result.AccountEmail) - entry.Error = "" - entry.OutputAttachments = cloneUploadedAttachments(result.Attachments) - assistant := s.ensureAssistantMessageLocked(entry, now) + next.ThreadID = strings.TrimSpace(result.ThreadID) + next.TraceID = strings.TrimSpace(result.TraceID) + next.MessageID = strings.TrimSpace(result.MessageID) + next.AccountEmail = strings.TrimSpace(result.AccountEmail) + next.Error = "" + next.OutputAttachments = cloneUploadedAttachments(result.Attachments) + assistant := s.ensureAssistantMessageLocked(&next, now) assistant.Status = "completed" assistant.Content = sanitizeAssistantVisibleText(result.Text) assistant.Attachments = summarizeUploadedAttachments(result.Attachments) assistant.UpdatedAt = now - if len(entry.Messages) > 0 { - entry.Messages[len(entry.Messages)-1] = cloneConversationMessage(*assistant) + if len(next.Messages) > 0 { + next.Messages[len(next.Messages)-1] = cloneConversationMessage(*assistant) } + 
refreshConversationDerivedFields(&next) + entry = &next + s.items[conversationID] = entry s.moveToFrontLocked(conversationID) - cloned = cloneConversationEntry(entry) summary = buildConversationSummary(entry) ok = true } s.mu.Unlock() if ok { s.broadcast(ConversationEvent{ - Type: "conversation.completed", - ConversationID: conversationID, - At: now, - Summary: &summary, - Conversation: &cloned, - }) + Type: "conversation.completed", + ConversationID: conversationID, + At: now, + Summary: &summary, + Conversation: entry, + }) } } @@ -670,41 +714,44 @@ func (s *ConversationStore) Fail(conversationID string, err error) { now := time.Now().UTC() message := strings.TrimSpace(err.Error()) var ( - cloned ConversationEntry summary ConversationSummary ok bool + entry *ConversationEntry ) s.mu.Lock() - entry := s.items[conversationID] - if entry != nil { - entry.Status = "failed" - entry.Error = message - entry.UpdatedAt = now - if entry.Ephemeral { - entry.AutoDeleteAt = timePointer(now.Add(sillyTavernQuietConversationTTL)) + current := s.items[conversationID] + if current != nil { + next := cloneConversationEntry(current) + next.Status = "failed" + next.Error = message + next.UpdatedAt = now + if next.Ephemeral { + next.AutoDeleteAt = timePointer(now.Add(sillyTavernQuietConversationTTL)) } - if len(entry.Messages) > 0 { - last := &entry.Messages[len(entry.Messages)-1] + if len(next.Messages) > 0 { + last := &next.Messages[len(next.Messages)-1] if last.Role == "assistant" && last.Status != "completed" { last.Status = "failed" last.UpdatedAt = now } } + refreshConversationDerivedFields(&next) + entry = &next + s.items[conversationID] = entry s.moveToFrontLocked(conversationID) - cloned = cloneConversationEntry(entry) summary = buildConversationSummary(entry) ok = true } s.mu.Unlock() if ok { s.broadcast(ConversationEvent{ - Type: "conversation.failed", - ConversationID: conversationID, - At: now, - Error: message, - Summary: &summary, - Conversation: &cloned, - }) + Type: 
"conversation.failed", + ConversationID: conversationID, + At: now, + Error: message, + Summary: &summary, + Conversation: entry, + }) } } @@ -761,7 +808,7 @@ func (s *ConversationStore) ListExpiredEphemeral(now time.Time, limit int) []Con if entry.AutoDeleteAt == nil || entry.AutoDeleteAt.After(now) { continue } - items = append(items, cloneConversationEntry(entry)) + items = append(items, copyConversationEntryValue(entry)) if len(items) >= limit { break } @@ -776,8 +823,7 @@ func (s *ConversationStore) Get(conversationID string) (ConversationEntry, bool) if entry == nil { return ConversationEntry{}, false } - cloned := cloneConversationEntry(entry) - return cloned, true + return copyConversationEntryValue(entry), true } func (s *ConversationStore) FindByThreadID(threadID string) (ConversationEntry, bool) { @@ -795,8 +841,7 @@ func (s *ConversationStore) FindByThreadID(threadID string) (ConversationEntry, if strings.TrimSpace(entry.ThreadID) != threadID { continue } - cloned := cloneConversationEntry(entry) - return cloned, true + return copyConversationEntryValue(entry), true } return ConversationEntry{}, false } @@ -820,8 +865,7 @@ func (s *ConversationStore) FindContinuationBySegments(history []conversationPro if !conversationSegmentsMatchSuffix(entrySegments, normalizedHistory) { continue } - cloned := cloneConversationEntry(entry) - return cloned, true + return copyConversationEntryValue(entry), true } return ConversationEntry{}, false } @@ -881,16 +925,18 @@ func (s *ServerState) deleteResponsesByConversationOrThread(conversationID strin return } s.mu.Lock() - for id, item := range s.ResponsesByID { - if (conversationID != "" && strings.TrimSpace(item.ConversationID) == conversationID) || - (threadID != "" && strings.TrimSpace(item.ThreadID) == threadID) { - delete(s.ResponsesByID, id) - } + if s.ResponseStore != nil { + s.ResponseStore.deleteByConversationOrThread(conversationID, threadID) } + sqliteWriter := s.sqliteWriter store := s.Store storeEnabled := 
store != nil && responsesPersistenceEnabled(s.Config) s.mu.Unlock() if storeEnabled { + if sqliteWriter != nil { + sqliteWriter.EnqueueDeleteResponsesByConversationOrThread(conversationID, threadID) + return + } if err := store.DeleteResponsesByConversationOrThread(conversationID, threadID); err != nil { log.Printf("[sqlite] delete responses conversation=%s thread=%s failed: %v", conversationID, threadID, err) } diff --git a/internal/app/httpclient_audit.md b/internal/app/httpclient_audit.md new file mode 100644 index 0000000..34880ef --- /dev/null +++ b/internal/app/httpclient_audit.md @@ -0,0 +1,69 @@ +# T-4-2 HTTP Client / Transport Audit + +Date: 2026-05-02 + +## Scope checked + +- `internal/app/notion_client.go` +- `internal/app/login_helper.go` +- `internal/app/notion_client_login_transport.go` +- `internal/app/account_discovery.go` +- `internal/app/session_refresh.go` + +## Findings + +### 1) NotionAI request path creates new `http.Client`/`Transport` per `NotionAIClient` + +- Location: `internal/app/notion_client.go:newNotionAIClientWithMode` +- Behavior: + - Builds a fresh `http.Transport` and `http.Client` every time a `NotionAIClient` is created. + - In dispatch paths (`runPromptWithSession*`) this can happen frequently, so connection pools are not reused across those client instances. +- Impact: + - Potentially higher connect/TLS handshake overhead under sustained traffic. + - Extra pressure on upstream and local sockets due to fragmented pools. + +### 2) Login helper path also creates fresh `http.Client` + +- Location: `internal/app/login_helper.go:newNotionLoginSession` +- Behavior: + - Creates a new cookie jar and `http.Client` per login session call. +- Notes: + - This path is less hot than inference path, but still relevant for repeated refresh/login workflows. 
+ +### 3) Proxy/header behavior correctness constraints + +- Proxy resolution and resin headers are request/account dependent: + - `ProxyResolver.ResolveProxyForRequest(accountEmail, targetURL)` can vary by account/policy. + - `postJSONResponse` overlays per-request proxy headers (e.g. resin account header). +- Any reuse strategy must preserve: + - account-aware proxy resolution + - per-request header injection behavior + - stream vs non-stream timeout difference + +## Recommendation + +Introduce a transport cache in `internal/app/notion_client.go`: + +- Cache key dimensions: + - normalized upstream base/origin/host/tls server name + - proxy mode + proxy urls + resin settings + - account email key (for account-specific proxy routing) + - streaming flag is **not** required in transport key (timeout is on `http.Client`, not transport) +- Cache value: + - reusable `*http.Transport` +- Then construct short-lived `http.Client` wrappers over cached transport: + - standard client timeout = request timeout + - streaming client timeout = 0 +- Add evidence: + - metric for transport/client creation count + - benchmark around repeated client creation path if needed + +## Current status + +- Updated 2026-05-02 follow-up: + - Implemented transport cache in `newNotionAIClientWithMode` via keyed map + RWMutex. + - Added runtime visibility metric: `notion2api_http_transport_cache_total` (`hit_rlock`, `hit_lock`, `miss_new`) exposed through `/debug/vars`. + - Added tests validating: + - same account/config => transport reuse + - different account proxy policy => transport separation + - Added benchmark `BenchmarkNewNotionAIClientWithModeTransportCache` showing warm-cache path lower alloc/op and ns/op than forced cold-cache path. 
diff --git a/internal/app/main.go b/internal/app/main.go index 5e96d4a..278b4bc 100644 --- a/internal/app/main.go +++ b/internal/app/main.go @@ -1,14 +1,20 @@ package app import ( + "bytes" "context" "encoding/json" + "errors" + "expvar" "fmt" + "io" "log" "net/http" + _ "net/http/pprof" "runtime/debug" "strings" "sync" + "sync/atomic" "time" ) @@ -20,21 +26,35 @@ type StoredResponse struct { AccountEmail string } +type snapshotBundle struct { + Config AppConfig + Session SessionInfo + ModelRegistry ModelRegistry + DispatchOrder []NotionAccount +} + type ServerState struct { - mu sync.RWMutex - refreshMu sync.Mutex - Config AppConfig - Session SessionInfo - Client *NotionAIClient - Store *SQLiteStore - ModelRegistry ModelRegistry - ResponsesByID map[string]StoredResponse - Conversations *ConversationStore - AdminTokens map[string]time.Time - AdminLoginAttempts map[string]AdminLoginAttempt - AccountDispatchSlots map[string]accountDispatchState - LastSessionRefresh time.Time - LastSessionRefreshError string + mu sync.RWMutex + refreshMu sync.Mutex + Config AppConfig + Session SessionInfo + Client *NotionAIClient + Store *SQLiteStore + ModelRegistry ModelRegistry + ResponseStore *responseStore + Conversations *ConversationStore + AdminTokens map[string]time.Time + AdminLoginAttempts map[string]AdminLoginAttempt + DispatchProbeCache *probeCache + LastSessionRefresh time.Time + LastSessionRefreshError string + responseStoreCleanupCancel context.CancelFunc + sqliteWriter *SQLiteWriter + snap atomic.Pointer[snapshotBundle] + slots atomic.Pointer[map[string]*accountSlot] + cachedHealthzStaticJSON atomic.Pointer[[]byte] + cachedModelsListJSON atomic.Pointer[[]byte] + cachedModelByIDJSON atomic.Pointer[map[string][]byte] } type accountDispatchState struct { @@ -42,6 +62,38 @@ type accountDispatchState struct { InFlight int } +type accountSlot struct { + max atomic.Int32 + inflight atomic.Int32 +} + +type healthzStaticPayload struct { + OK bool `json:"ok"` + DefaultModel 
string `json:"default_model"` + ModelCount int `json:"model_count"` + UserEmail string `json:"user_email"` + SpaceID string `json:"space_id"` + ActiveAccount string `json:"active_account"` + SessionRefreshEnable bool `json:"session_refresh_enabled"` +} + +type publicModelPayload struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + OwnedBy string `json:"owned_by"` + Name string `json:"name"` + Family string `json:"family"` + Group string `json:"group"` + Beta bool `json:"beta"` + NotionModel string `json:"notion_model"` +} + +type publicModelsListPayload struct { + Object string `json:"object"` + Data []publicModelPayload `json:"data"` +} + type App struct { State *ServerState runPromptOverride func(*http.Request, PromptRunRequest) (InferenceResult, error) @@ -56,8 +108,15 @@ const ( ephemeralConversationCleanupInterval = time.Minute ephemeralConversationCleanupBatchSize = 24 sillyTavernQuietConversationTTL = 10 * time.Minute + corsAllowOrigin = "*" + corsAllowHeaders = "Authorization, Content-Type, X-Admin-Token" + corsAllowMethods = "GET, POST, PUT, DELETE, OPTIONS" ) +var errRequestTooLarge = errors.New("request body too large") +var responseStorePruneTotalMetric = expvar.NewMap("notion2api_response_store_prune_total") +var testHookResponseStoreCleanupInterval time.Duration + type continuationTarget struct { Conversation ConversationEntry Session *conversationContinuationState @@ -112,28 +171,68 @@ func normalizeAccountMaxConcurrency(raw int) int { return raw } -func (s *ServerState) initializeAccountDispatchSlotsLocked() { - if s.AccountDispatchSlots == nil { - s.AccountDispatchSlots = map[string]accountDispatchState{} +func clampSlotInFlight(slot *accountSlot, max int32) int32 { + if slot == nil { + return 0 + } + if max <= 0 { + max = 1 + } + for { + current := slot.inflight.Load() + if current < 0 { + if slot.inflight.CompareAndSwap(current, 0) { + return 0 + } + continue + } + if current <= max { + return current 
+ } + if slot.inflight.CompareAndSwap(current, max) { + return max + } } - next := map[string]accountDispatchState{} +} + +func (s *ServerState) rebuildAccountSlotsLocked() { + if s == nil { + return + } + var previous map[string]*accountSlot + if loaded := s.slots.Load(); loaded != nil { + previous = *loaded + } + next := make(map[string]*accountSlot, len(s.Config.Accounts)) for _, account := range s.Config.Accounts { - emailKey := canonicalEmailKey(account.Email) + emailKey := getAccountEmailKey(account) if emailKey == "" { continue } - maxConcurrency := normalizeAccountMaxConcurrency(account.MaxConcurrency) - state := s.AccountDispatchSlots[emailKey] - state.MaxConcurrency = maxConcurrency - if state.InFlight < 0 { - state.InFlight = 0 - } - if state.InFlight > state.MaxConcurrency { - state.InFlight = state.MaxConcurrency + maxConcurrency := int32(normalizeAccountMaxConcurrency(account.MaxConcurrency)) + if existing := previous[emailKey]; existing != nil { + existing.max.Store(maxConcurrency) + clampSlotInFlight(existing, maxConcurrency) + next[emailKey] = existing + continue } - next[emailKey] = state + slot := &accountSlot{} + slot.max.Store(maxConcurrency) + next[emailKey] = slot + } + s.slots.Store(&next) + syncDispatchSlotInflightFromSlots(next) +} + +func (s *ServerState) loadAccountSlots() map[string]*accountSlot { + if s == nil { + return nil } - s.AccountDispatchSlots = next + loaded := s.slots.Load() + if loaded == nil { + return nil + } + return *loaded } func (s *ServerState) TryAcquireAccountDispatchSlot(email string) bool { @@ -141,19 +240,24 @@ func (s *ServerState) TryAcquireAccountDispatchSlot(email string) bool { if emailKey == "" { return false } - s.mu.Lock() - defer s.mu.Unlock() - s.initializeAccountDispatchSlotsLocked() - state, ok := s.AccountDispatchSlots[emailKey] - if !ok { + slot := s.loadAccountSlots()[emailKey] + if slot == nil { return false } - if state.InFlight >= state.MaxConcurrency { - return false + for { + maxConcurrency := 
slot.max.Load() + if maxConcurrency <= 0 { + maxConcurrency = 1 + } + inflight := slot.inflight.Load() + if inflight >= maxConcurrency { + return false + } + if slot.inflight.CompareAndSwap(inflight, inflight+1) { + setDispatchSlotInflight(emailKey, int(inflight+1)) + return true + } } - state.InFlight++ - s.AccountDispatchSlots[emailKey] = state - return true } func (s *ServerState) ReleaseAccountDispatchSlot(email string) { @@ -161,19 +265,21 @@ func (s *ServerState) ReleaseAccountDispatchSlot(email string) { if emailKey == "" { return } - s.mu.Lock() - defer s.mu.Unlock() - if s.AccountDispatchSlots == nil { - return - } - state, ok := s.AccountDispatchSlots[emailKey] - if !ok { + slot := s.loadAccountSlots()[emailKey] + if slot == nil { return } - if state.InFlight > 0 { - state.InFlight-- + for { + inflight := slot.inflight.Load() + if inflight <= 0 { + setDispatchSlotInflight(emailKey, 0) + return + } + if slot.inflight.CompareAndSwap(inflight, inflight-1) { + setDispatchSlotInflight(emailKey, int(inflight-1)) + return + } } - s.AccountDispatchSlots[emailKey] = state } func (s *ServerState) RemainingAccountDispatchSlots(email string) int { @@ -181,24 +287,27 @@ func (s *ServerState) RemainingAccountDispatchSlots(email string) int { if emailKey == "" { return 0 } - s.mu.Lock() - defer s.mu.Unlock() - s.initializeAccountDispatchSlotsLocked() - state, ok := s.AccountDispatchSlots[emailKey] - if !ok { + slot := s.loadAccountSlots()[emailKey] + if slot == nil { return 0 } - remaining := state.MaxConcurrency - state.InFlight + maxConcurrency := slot.max.Load() + if maxConcurrency <= 0 { + maxConcurrency = 1 + } + inflight := slot.inflight.Load() + remaining := int(maxConcurrency - inflight) if remaining < 0 { - remaining = 0 + return 0 } return remaining } func (s *ServerState) AvailableDispatchCapacity(emails []string) int { - s.mu.Lock() - defer s.mu.Unlock() - s.initializeAccountDispatchSlotsLocked() + slots := s.loadAccountSlots() + if len(slots) == 0 { + 
return 0 + } total := 0 seen := map[string]struct{}{} for _, email := range emails { @@ -210,11 +319,16 @@ func (s *ServerState) AvailableDispatchCapacity(emails []string) int { continue } seen[emailKey] = struct{}{} - state, ok := s.AccountDispatchSlots[emailKey] - if !ok { + slot := slots[emailKey] + if slot == nil { continue } - remaining := state.MaxConcurrency - state.InFlight + maxConcurrency := slot.max.Load() + if maxConcurrency <= 0 { + maxConcurrency = 1 + } + inflight := slot.inflight.Load() + remaining := int(maxConcurrency - inflight) if remaining > 0 { total += remaining } @@ -223,17 +337,31 @@ func (s *ServerState) AvailableDispatchCapacity(emails []string) int { } func (s *ServerState) AccountDispatchSnapshot() map[string]accountDispatchState { - s.mu.Lock() - defer s.mu.Unlock() - s.initializeAccountDispatchSlotsLocked() - out := make(map[string]accountDispatchState, len(s.AccountDispatchSlots)) - for key, value := range s.AccountDispatchSlots { - out[key] = value + slots := s.loadAccountSlots() + out := make(map[string]accountDispatchState, len(slots)) + for key, slot := range slots { + if slot == nil { + continue + } + maxConcurrency := int(slot.max.Load()) + if maxConcurrency <= 0 { + maxConcurrency = 1 + } + inflight := int(slot.inflight.Load()) + if inflight < 0 { + inflight = 0 + } + if inflight > maxConcurrency { + inflight = maxConcurrency + } + out[key] = accountDispatchState{ + MaxConcurrency: maxConcurrency, + InFlight: inflight, + } } return out } - func maxFloat(a float64, b float64) float64 { if a > b { return a @@ -265,13 +393,13 @@ func newServerState(cfg AppConfig) (*ServerState, error) { return nil, err } state := &ServerState{ - ResponsesByID: map[string]StoredResponse{}, Conversations: newConversationStore(), AdminTokens: map[string]time.Time{}, AdminLoginAttempts: map[string]AdminLoginAttempt{}, - AccountDispatchSlots: map[string]accountDispatchState{}, + DispatchProbeCache: newProbeCache(), Store: store, } + 
state.ResponseStore = newResponseStore(time.Duration(maxInt(cfg.Responses.StoreTTLSeconds, 1)) * time.Second) persistedAccountsLoaded := false if store != nil { accounts, activeAccount, ok, loadErr := store.LoadAccounts() @@ -305,7 +433,10 @@ func newServerState(cfg AppConfig) (*ServerState, error) { _ = store.Close() return nil, loadErr } - state.ResponsesByID = responses + if state.ResponseStore == nil { + state.ResponseStore = newResponseStore(time.Duration(maxInt(state.Config.Responses.StoreTTLSeconds, 1)) * time.Second) + } + state.ResponseStore.replaceAll(responses) } if conversationSnapshotsPersistenceEnabled(state.Config) { conversations, loadErr := store.LoadConversations() @@ -321,7 +452,9 @@ func newServerState(cfg AppConfig) (*ServerState, error) { return nil, saveErr } } + state.sqliteWriter = newSQLiteWriter(store, time.Duration(maxInt(state.Config.Responses.StoreTTLSeconds, 1))*time.Second) } + state.startResponseStoreCleanupLoop(context.Background()) return state, nil } @@ -352,15 +485,42 @@ func (s *ServerState) ApplyConfig(cfg AppConfig) error { s.Session = session s.ModelRegistry = registry s.Client = client + if s.sqliteWriter != nil { + s.sqliteWriter.SetTTL(time.Duration(maxInt(cfg.Responses.StoreTTLSeconds, 1)) * time.Second) + } + s.rebuildAccountSlotsLocked() + s.updateSnapshotBundleLocked() + s.rebuildStaticJSONCachesLocked() return nil } func (s *ServerState) Snapshot() (AppConfig, SessionInfo, ModelRegistry) { + if s == nil { + return AppConfig{}, SessionInfo{}, ModelRegistry{} + } + if snap := s.snap.Load(); snap != nil { + return snap.Config, snap.Session, snap.ModelRegistry + } s.mu.RLock() defer s.mu.RUnlock() return s.Config, s.Session, s.ModelRegistry } +func (s *ServerState) updateSnapshotBundleLocked() { + if s == nil { + return + } + now := time.Now() + dispatchOrder := buildDispatchCandidateOrder(s.Config, now) + bundle := &snapshotBundle{ + Config: s.Config, + Session: s.Session, + ModelRegistry: s.ModelRegistry, + 
DispatchOrder: dispatchOrder, + } + s.snap.Store(bundle) +} + func (s *ServerState) SaveAndApply(cfg AppConfig) error { cfg = normalizeConfig(cfg) if err := validateConfiguredAPIKey(cfg); err != nil { @@ -382,6 +542,17 @@ func (s *ServerState) SaveAndApply(cfg AppConfig) error { return err } } + s.mu.Lock() + if s.ResponseStore == nil { + s.ResponseStore = newResponseStore(time.Duration(maxInt(cfg.Responses.StoreTTLSeconds, 1)) * time.Second) + } else { + s.ResponseStore.setTTL(time.Duration(maxInt(cfg.Responses.StoreTTLSeconds, 1)) * time.Second) + } + s.updateSnapshotBundleLocked() + s.mu.Unlock() + if canonicalEmailKey(current.ActiveAccount) != canonicalEmailKey(cfg.ActiveAccount) && s.DispatchProbeCache != nil { + s.DispatchProbeCache.invalidateAll() + } return nil } @@ -394,16 +565,6 @@ func (s *ServerState) conversationPersistenceStore() *SQLiteStore { return s.Store } -func (s *ServerState) cleanupExpiredResponsesLocked(now time.Time) { - ttlSeconds := maxInt(s.Config.Responses.StoreTTLSeconds, 1) - ttl := time.Duration(ttlSeconds) * time.Second - for id, item := range s.ResponsesByID { - if now.Sub(item.CreatedAt) > ttl { - delete(s.ResponsesByID, id) - } - } -} - func (s *ServerState) saveResponse(responseID string, payload map[string]any, conversationID string, threadID string) { s.saveResponseWithAccount(responseID, payload, conversationID, threadID, "") } @@ -411,24 +572,33 @@ func (s *ServerState) saveResponse(responseID string, payload map[string]any, co func (s *ServerState) saveResponseWithAccount(responseID string, payload map[string]any, conversationID string, threadID string, accountEmail string) { now := time.Now().UTC() s.mu.Lock() - s.cleanupExpiredResponsesLocked(now) - s.ResponsesByID[responseID] = StoredResponse{ + store := s.ResponseStore + if store == nil { + store = newResponseStore(time.Duration(maxInt(s.Config.Responses.StoreTTLSeconds, 1)) * time.Second) + s.ResponseStore = store + } + store.save(responseID, StoredResponse{ Payload: 
payload, CreatedAt: now, ConversationID: strings.TrimSpace(conversationID), ThreadID: strings.TrimSpace(threadID), AccountEmail: strings.TrimSpace(accountEmail), - } - store := s.Store + }, now) + sqliteWriter := s.sqliteWriter + sqliteStore := s.Store ttl := time.Duration(maxInt(s.Config.Responses.StoreTTLSeconds, 1)) * time.Second - storeEnabled := store != nil && responsesPersistenceEnabled(s.Config) + storeEnabled := sqliteStore != nil && responsesPersistenceEnabled(s.Config) s.mu.Unlock() if storeEnabled { - if err := store.SaveResponse(responseID, payload, now, conversationID, threadID, accountEmail); err != nil { + if sqliteWriter != nil { + sqliteWriter.EnqueueSaveResponse(responseID, payload, now, conversationID, threadID, accountEmail) + return + } + if err := sqliteStore.SaveResponse(responseID, payload, now, conversationID, threadID, accountEmail); err != nil { log.Printf("[sqlite] save response %s failed: %v", responseID, err) return } - if err := store.DeleteExpiredResponses(ttl); err != nil { + if err := sqliteStore.DeleteExpiredResponses(ttl); err != nil { log.Printf("[sqlite] cleanup responses failed: %v", err) } } @@ -445,12 +615,10 @@ func (s *ServerState) getResponse(responseID string) (map[string]any, bool) { func (s *ServerState) getStoredResponse(responseID string) (StoredResponse, bool) { s.mu.Lock() defer s.mu.Unlock() - s.cleanupExpiredResponsesLocked(time.Now()) - payload, ok := s.ResponsesByID[responseID] - if !ok { + if s.ResponseStore == nil { return StoredResponse{}, false } - return payload, true + return s.ResponseStore.get(responseID, time.Now().UTC()) } func (s *ServerState) loadConversationContinuationStateByConversationID(conversationID string) (*conversationContinuationState, error) { @@ -548,24 +716,204 @@ func (s *ServerState) invalidateConversationSession(sessionID string, status str func (s *ServerState) Close() error { s.mu.RLock() store := s.Store + cancelCleanup := s.responseStoreCleanupCancel + sqliteWriter := 
s.sqliteWriter s.mu.RUnlock() + if cancelCleanup != nil { + cancelCleanup() + } + if sqliteWriter != nil { + sqliteWriter.Close() + } if store == nil { return nil } return store.Close() } +func (s *ServerState) startResponseStoreCleanupLoop(parent context.Context) { + if s == nil { + return + } + if parent == nil { + parent = context.Background() + } + interval := responseStoreCleanupInterval + if testHookResponseStoreCleanupInterval > 0 { + interval = testHookResponseStoreCleanupInterval + } + ctx, cancel := context.WithCancel(parent) + s.mu.Lock() + s.responseStoreCleanupCancel = cancel + s.mu.Unlock() + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + s.runResponseStoreCleanupOnce(time.Now().UTC()) + } + } + }() +} + +func (s *ServerState) runResponseStoreCleanupOnce(now time.Time) int { + if s == nil { + return 0 + } + s.mu.Lock() + defer s.mu.Unlock() + if s.ResponseStore == nil { + return 0 + } + removed := s.ResponseStore.pruneExpired(now) + if removed > 0 { + responseStorePruneTotalMetric.Add("expired_entries", int64(removed)) + } + return removed +} + +func buildPublicModelPayload(entry ModelDefinition) publicModelPayload { + return publicModelPayload{ + ID: entry.ID, + Object: "model", + Created: 0, + OwnedBy: "notion2api", + Name: entry.Name, + Family: entry.Family, + Group: entry.Group, + Beta: entry.Beta, + NotionModel: entry.NotionModel, + } +} + +func buildPublicModelsListPayload(registry ModelRegistry) publicModelsListPayload { + items := make([]publicModelPayload, 0, len(registry.Entries)) + for _, entry := range registry.Entries { + if !entry.Enabled { + continue + } + items = append(items, buildPublicModelPayload(entry)) + } + return publicModelsListPayload{ + Object: "list", + Data: items, + } +} + +func cloneBytes(src []byte) []byte { + if len(src) == 0 { + return nil + } + dst := make([]byte, len(src)) + copy(dst, src) + return dst +} + +func 
cloneBytesMap(src map[string][]byte) map[string][]byte { + if len(src) == 0 { + return nil + } + dst := make(map[string][]byte, len(src)) + for key, value := range src { + dst[key] = cloneBytes(value) + } + return dst +} + +func (s *ServerState) rebuildStaticJSONCachesLocked() { + healthPayload := healthzStaticPayload{ + OK: true, + DefaultModel: s.Config.DefaultPublicModel(), + ModelCount: len(s.ModelRegistry.Entries), + UserEmail: s.Session.UserEmail, + SpaceID: s.Session.SpaceID, + ActiveAccount: s.Config.ActiveAccount, + SessionRefreshEnable: s.Config.ResolveSessionRefresh().Enabled, + } + healthBody, err := json.Marshal(healthPayload) + if err == nil { + healthBodyCopy := cloneBytes(healthBody) + s.cachedHealthzStaticJSON.Store(&healthBodyCopy) + } else { + s.cachedHealthzStaticJSON.Store(nil) + } + + modelsPayload := buildPublicModelsListPayload(s.ModelRegistry) + modelsBody, err := json.Marshal(modelsPayload) + if err == nil { + modelsBodyCopy := cloneBytes(modelsBody) + s.cachedModelsListJSON.Store(&modelsBodyCopy) + } else { + s.cachedModelsListJSON.Store(nil) + } + + modelByID := make(map[string][]byte, len(s.ModelRegistry.Entries)) + for _, entry := range s.ModelRegistry.Entries { + if !entry.Enabled { + continue + } + body, marshalErr := json.Marshal(buildPublicModelPayload(entry)) + if marshalErr != nil { + continue + } + modelByID[normalizeLookupKey(entry.ID)] = cloneBytes(body) + } + modelByIDCopy := cloneBytesMap(modelByID) + s.cachedModelByIDJSON.Store(&modelByIDCopy) +} + +func writeJSONBytes(w http.ResponseWriter, status int, body []byte) { + applyCORSHeaders(w) + w.Header().Set("X-Notion2API", "1") + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(status) + _, _ = w.Write(body) +} + +func appendHealthzRuntimeFields(body []byte, sessionReady bool, lastRefresh time.Time, lastRefreshError string) []byte { + trimmed := bytes.TrimSpace(body) + if len(trimmed) == 0 || trimmed[len(trimmed)-1] != '}' { + trimmed = 
[]byte(`{"ok":true}`) + } + trimmed = bytes.TrimSuffix(trimmed, []byte("}")) + tail := map[string]any{ + "session_ready": sessionReady, + "last_session_refresh": formatTimeOrEmpty(lastRefresh), + "last_session_refresh_error": lastRefreshError, + } + tailBody, err := json.Marshal(tail) + if err != nil { + return body + } + tailBody = bytes.TrimPrefix(tailBody, []byte("{")) + out := make([]byte, 0, len(trimmed)+1+len(tailBody)) + out = append(out, trimmed...) + if len(trimmed) > 1 { + out = append(out, ',') + } + out = append(out, tailBody...) + return out +} + +func applyCORSHeaders(w http.ResponseWriter) { + w.Header().Set("Access-Control-Allow-Origin", corsAllowOrigin) + w.Header().Set("Access-Control-Allow-Headers", corsAllowHeaders) + w.Header().Set("Access-Control-Allow-Methods", corsAllowMethods) +} + func writeJSON(w http.ResponseWriter, status int, payload any) { body, err := json.Marshal(payload) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } + applyCORSHeaders(w) w.Header().Set("X-Notion2API", "1") w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Admin-Token") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.WriteHeader(status) _, _ = w.Write(body) } @@ -581,13 +929,28 @@ func writeOpenAIError(w http.ResponseWriter, status int, message string, errorTy }) } +func writeInvalidBodyError(w http.ResponseWriter, err error) { + if errors.Is(err, errRequestTooLarge) { + writeOpenAIError(w, http.StatusRequestEntityTooLarge, "request body exceeds configured limit", "invalid_request_error", "request_too_large") + return + } + writeOpenAIError(w, http.StatusBadRequest, err.Error(), "invalid_request_error", nilString()) +} + func nilString() string { return "" } -func decodeBody(r *http.Request) (map[string]any, error) { - defer r.Body.Close() - 
decoder := json.NewDecoder(r.Body) +func decodeBodyWithLimit(w http.ResponseWriter, r *http.Request, maxBytes int64) (map[string]any, error) { + raw, err := decodeBodyRawWithLimit(w, r, maxBytes) + if err != nil { + return nil, err + } + return decodeBodyMapFromRaw(raw) +} + +func decodeBodyMapFromRaw(raw []byte) (map[string]any, error) { + decoder := json.NewDecoder(bytes.NewReader(raw)) decoder.UseNumber() var payload map[string]any if err := decoder.Decode(&payload); err != nil { @@ -599,6 +962,65 @@ func decodeBody(r *http.Request) (map[string]any, error) { return payload, nil } +func decodeBodyRawWithLimit(w http.ResponseWriter, r *http.Request, maxBytes int64) ([]byte, error) { + if maxBytes > 0 && w != nil { + r.Body = http.MaxBytesReader(w, r.Body, maxBytes) + } + defer r.Body.Close() + body, err := io.ReadAll(r.Body) + if err != nil { + var maxErr *http.MaxBytesError + if errors.As(err, &maxErr) { + return nil, errRequestTooLarge + } + return nil, fmt.Errorf("invalid json: %w", err) + } + trimmed := bytes.TrimSpace(body) + if len(trimmed) == 0 { + return []byte("{}"), nil + } + var raw json.RawMessage + if err := json.Unmarshal(trimmed, &raw); err != nil { + var maxErr *http.MaxBytesError + if errors.As(err, &maxErr) { + return nil, errRequestTooLarge + } + return nil, fmt.Errorf("invalid json: %w", err) + } + normalized := bytes.TrimSpace(raw) + if len(normalized) == 0 { + return []byte("{}"), nil + } + return normalized, nil +} + +func (a *App) decodeBody(w http.ResponseWriter, r *http.Request) (map[string]any, error) { + raw, err := a.decodeBodyRaw(w, r) + if err != nil { + return nil, err + } + return decodeBodyMapFromRaw(raw) +} + +func (a *App) decodeBodyRaw(w http.ResponseWriter, r *http.Request) ([]byte, error) { + maxBytes := int64(0) + if a != nil && a.State != nil { + cfg, _, _ := a.State.Snapshot() + maxBytes = cfg.Limits.MaxRequestBodyBytes + } + return decodeBodyRawWithLimit(w, r, maxBytes) +} + +func decodeTypedBodyFromRaw[T any](raw []byte) 
(T, error) { + var typed T + decoder := json.NewDecoder(bytes.NewReader(raw)) + decoder.UseNumber() + if err := decoder.Decode(&typed); err != nil { + return typed, fmt.Errorf("invalid json: %w", err) + } + return typed, nil +} + func (a *App) authOK(w http.ResponseWriter, r *http.Request) bool { cfg, _, _ := a.State.Snapshot() expected := strings.TrimSpace(cfg.APIKey) @@ -614,12 +1036,18 @@ func (a *App) authOK(w http.ResponseWriter, r *http.Request) bool { } func (a *App) serveHealthz(w http.ResponseWriter) { - cfg, session, registry := a.State.Snapshot() a.State.mu.RLock() sessionReady := a.State.Client != nil lastRefresh := a.State.LastSessionRefresh lastRefreshError := a.State.LastSessionRefreshError + cached := a.State.cachedHealthzStaticJSON.Load() a.State.mu.RUnlock() + if cached != nil { + body := appendHealthzRuntimeFields(*cached, sessionReady, lastRefresh, lastRefreshError) + writeJSONBytes(w, http.StatusOK, body) + return + } + cfg, session, registry := a.State.Snapshot() writeJSON(w, http.StatusOK, map[string]any{ "ok": true, "default_model": cfg.DefaultPublicModel(), @@ -635,28 +1063,13 @@ func (a *App) serveHealthz(w http.ResponseWriter) { } func (a *App) serveModels(w http.ResponseWriter) { - _, _, registry := a.State.Snapshot() - items := make([]map[string]any, 0, len(registry.Entries)) - for _, entry := range registry.Entries { - if !entry.Enabled { - continue - } - items = append(items, map[string]any{ - "id": entry.ID, - "object": "model", - "created": 0, - "owned_by": "notion2api", - "name": entry.Name, - "family": entry.Family, - "group": entry.Group, - "beta": entry.Beta, - "notion_model": entry.NotionModel, - }) + cached := a.State.cachedModelsListJSON.Load() + if cached != nil { + writeJSONBytes(w, http.StatusOK, *cached) + return } - writeJSON(w, http.StatusOK, map[string]any{ - "object": "list", - "data": items, - }) + _, _, registry := a.State.Snapshot() + writeJSON(w, http.StatusOK, buildPublicModelsListPayload(registry)) } func (a 
*App) serveModelByID(w http.ResponseWriter, path string) { @@ -667,17 +1080,13 @@ func (a *App) serveModelByID(w http.ResponseWriter, path string) { writeOpenAIError(w, http.StatusNotFound, "model not found", "invalid_request_error", "model_not_found") return } - writeJSON(w, http.StatusOK, map[string]any{ - "id": entry.ID, - "object": "model", - "created": 0, - "owned_by": "notion2api", - "name": entry.Name, - "family": entry.Family, - "group": entry.Group, - "beta": entry.Beta, - "notion_model": entry.NotionModel, - }) + if cached := a.State.cachedModelByIDJSON.Load(); cached != nil { + if body, ok := (*cached)[normalizeLookupKey(entry.ID)]; ok && len(body) > 0 { + writeJSONBytes(w, http.StatusOK, body) + return + } + } + writeJSON(w, http.StatusOK, buildPublicModelPayload(entry)) } func (a *App) serveResponseByID(w http.ResponseWriter, path string) { @@ -913,8 +1322,13 @@ func attachConversationResponseMetadata(payload map[string]any, conversationID s } func (a *App) resolveContinuationConversation(r *http.Request, payload map[string]any, previousResponseID string, hiddenPrompt string, segments []conversationPromptSegment) (continuationTarget, bool) { - rawCount := sessionRawMessageCount(segments) explicitConversationID := requestedConversationID(r, payload) + explicitThreadID := requestedThreadID(r, payload) + return a.resolveContinuationConversationWithExplicit(previousResponseID, hiddenPrompt, segments, explicitConversationID, explicitThreadID) +} + +func (a *App) resolveContinuationConversationWithExplicit(previousResponseID string, hiddenPrompt string, segments []conversationPromptSegment, explicitConversationID string, explicitThreadID string) (continuationTarget, bool) { + rawCount := sessionRawMessageCount(segments) validateState := func(state *conversationContinuationState) bool { if state == nil { return true @@ -986,9 +1400,9 @@ func (a *App) resolveContinuationConversation(r *http.Request, payload map[strin } } } - if threadID := requestedThreadID(r, 
payload); threadID != "" { - if entry, ok := a.State.conversations().FindByThreadID(threadID); ok { - state, err := a.State.loadConversationContinuationStateByThreadID(threadID) + if explicitThreadID != "" { + if entry, ok := a.State.conversations().FindByThreadID(explicitThreadID); ok { + state, err := a.State.loadConversationContinuationStateByThreadID(explicitThreadID) if err == nil && !validateState(state) { return continuationTarget{}, false } @@ -998,9 +1412,9 @@ func (a *App) resolveContinuationConversation(r *http.Request, payload map[strin return continuationTarget{Conversation: entry}, true } target := continuationTarget{Conversation: ConversationEntry{ - ThreadID: threadID, + ThreadID: explicitThreadID, }} - if state, err := a.State.loadConversationContinuationStateByThreadID(threadID); err == nil { + if state, err := a.State.loadConversationContinuationStateByThreadID(explicitThreadID); err == nil { if !validateState(state) { return continuationTarget{}, false } @@ -1111,6 +1525,70 @@ func includeUsageInStream(payload map[string]any) bool { return includeUsage } +func decodeChatCompletionsRequestBodyFromRaw(raw []byte) (chatCompletionsRequestBody, map[string]any, error) { + typed, err := decodeTypedBodyFromRaw[chatCompletionsRequestBody](raw) + if err == nil { + return normalizeTypedChatCompletionsRequestBody(typed), nil, nil + } + payload, mapErr := decodeBodyMapFromRaw(raw) + if mapErr != nil { + return chatCompletionsRequestBody{}, nil, mapErr + } + return extractChatCompletionsRequestBody(payload), payload, nil +} + +func decodeResponsesRequestBodyFromRaw(raw []byte) (responsesRequestBody, map[string]any, error) { + typed, err := decodeTypedBodyFromRaw[responsesRequestBody](raw) + if err == nil { + return normalizeTypedResponsesRequestBody(typed), nil, nil + } + payload, mapErr := decodeBodyMapFromRaw(raw) + if mapErr != nil { + return responsesRequestBody{}, nil, mapErr + } + return extractResponsesRequestBody(payload), payload, nil +} + +func 
maybeSillyTavernByTypedMessages(rawMessages any) bool { + items := sliceValue(rawMessages) + if len(items) == 0 { + return false + } + systemPrompts := make([]string, 0, len(items)) + for _, raw := range items { + msg := mapValue(raw) + if msg == nil { + continue + } + if strings.TrimSpace(strings.ToLower(stringValue(msg["role"]))) != "system" { + continue + } + text := collapseWhitespace(flattenContent(msg["content"])) + if text != "" { + systemPrompts = append(systemPrompts, text) + } + } + if len(systemPrompts) == 0 { + return false + } + if looksLikeSillyTavernImpersonate(systemPrompts) || looksLikeSillyTavernQuiet(systemPrompts, nil) { + return true + } + for _, prompt := range systemPrompts { + lower := strings.ToLower(collapseWhitespace(prompt)) + if strings.Contains(lower, "fictional chat between") || + strings.Contains(lower, "[start a new chat]") || + strings.Contains(lower, "[continue your last message without repeating its original content.]") { + return true + } + } + return false +} + +func rawMayNeedSillyTavernPayloadFallback(raw []byte) bool { + return bytes.Contains(raw, []byte(`"continue_prefill"`)) || bytes.Contains(raw, []byte(`"show_thoughts"`)) +} + func chatCompletionInitialFlushDelayForRequest(request PromptRunRequest) time.Duration { if request.ClientProfile == sillyTavernClientProfile || request.StreamReasoningWarmup { return 0 @@ -1152,16 +1630,33 @@ func (a *App) runPromptStreamWithSink(r *http.Request, request PromptRunRequest, } func (a *App) handleChatCompletions(w http.ResponseWriter, r *http.Request) { - payload, err := decodeBody(r) + raw, err := a.decodeBodyRaw(w, r) if err != nil { - writeOpenAIError(w, http.StatusBadRequest, err.Error(), "invalid_request_error", nilString()) + writeInvalidBodyError(w, err) return } - if isLikelySillyTavernPayload(payload) { + typed, payload, err := decodeChatCompletionsRequestBodyFromRaw(raw) + if err != nil { + writeInvalidBodyError(w, err) + return + } + if payload == nil && 
(typed.likelySillyTavernByEnvelope() || maybeSillyTavernByTypedMessages(typed.Messages) || rawMayNeedSillyTavernPayloadFallback(raw)) { + payload, err = decodeBodyMapFromRaw(raw) + if err != nil { + writeInvalidBodyError(w, err) + return + } + } + if payload != nil && (typed.likelySillyTavernByEnvelope() || isLikelySillyTavernPayload(payload)) { a.handleSillyTavernChatCompletionsPayload(w, r, payload) return } - normalized, err := normalizeChatInput(payload) + messages := sliceValue(typed.Messages) + if len(messages) == 0 { + writeOpenAIError(w, http.StatusBadRequest, "messages must be an array", "invalid_request_error", nilString()) + return + } + normalized, err := normalizeChatInputFromParts(messages, typed.Attachments) if err != nil { writeOpenAIError(w, http.StatusBadRequest, err.Error(), "invalid_request_error", nilString()) return @@ -1171,7 +1666,12 @@ func (a *App) handleChatCompletions(w http.ResponseWriter, r *http.Request) { return } cfg, _, registry := a.State.Snapshot() - entry, err := registry.Resolve(requestedModel(payload, cfg.DefaultPublicModel()), cfg.DefaultPublicModel()) + requestedModelID := requestedModelFromTyped(typed.Model, cfg.DefaultPublicModel()) + useWebSearch := requestedWebSearchFromTyped(typed.UseWebSearch, typed.Metadata, typed.Tools, cfg.Features.UseWebSearch) + preferredConversationID := requestedConversationIDFromTyped(r, typed.ConversationID, typed.Conversation, typed.Metadata) + explicitThreadID := requestedThreadIDFromTyped(r, typed.ThreadID, typed.Thread, typed.NotionThreadID, typed.Metadata) + requestedAccount := requestedAccountEmailFromTyped(r, typed.AccountEmail, typed.NotionAccountEmail, typed.Metadata) + entry, err := registry.Resolve(requestedModelID, cfg.DefaultPublicModel()) if err != nil { writeOpenAIError(w, http.StatusBadRequest, err.Error(), "invalid_request_error", "model_not_found") return @@ -1187,17 +1687,16 @@ func (a *App) handleChatCompletions(w http.ResponseWriter, r *http.Request) { HiddenPrompt: 
hiddenPrompt, PublicModel: entry.ID, NotionModel: entry.NotionModel, - UseWebSearch: requestedWebSearch(payload, cfg.Features.UseWebSearch), + UseWebSearch: useWebSearch, Attachments: normalized.Attachments, SessionFingerprint: originalFingerprint, RawMessageCount: originalRawMessageCount, } freshThreadMode := forceFreshThreadPerRequest(cfg) - preferredConversationID := requestedConversationID(r, payload) conversation := ConversationEntry{} - if matched, ok := a.resolveContinuationConversation(r, payload, "", hiddenPrompt, normalized.Segments); ok { + if matched, ok := a.resolveContinuationConversationWithExplicit("", hiddenPrompt, normalized.Segments, preferredConversationID, explicitThreadID); ok { conversation = matched.Conversation - request.PinnedAccountEmail = firstNonEmpty(strings.TrimSpace(conversation.AccountEmail), requestedAccountEmail(r, payload)) + request.PinnedAccountEmail = firstNonEmpty(strings.TrimSpace(conversation.AccountEmail), requestedAccount) if freshThreadMode { request.ForceLocalConversationContinue = strings.TrimSpace(conversation.ID) != "" request.Prompt = buildFreshThreadReplayPromptFromConversation(conversation, latestPrompt, normalized.Attachments, promptText) @@ -1210,14 +1709,18 @@ func (a *App) handleChatCompletions(w http.ResponseWriter, r *http.Request) { request.Prompt = latestPrompt } } else { - request.PinnedAccountEmail = requestedAccountEmail(r, payload) + request.PinnedAccountEmail = requestedAccount } request.ConversationID = firstNonEmpty(strings.TrimSpace(conversation.ID), preferredConversationID) conversationID := a.startConversationTurn(conversation.ID, preferredConversationID, "api", "chat_completions", resolveRequestPromptForContinuation(normalized), request) setConversationIDHeader(w, conversationID) - stream, _ := payload["stream"].(bool) + stream := typed.Stream if stream { - a.writeChatCompletionLiveStream(w, r, request, entry.ID, includeUsageInStream(payload), conversationID) + includeUsage := false + if 
typed.StreamIncludeUsage != nil { + includeUsage = *typed.StreamIncludeUsage + } + a.writeChatCompletionLiveStream(w, r, request, entry.ID, includeUsage, conversationID) return } result, err := a.runPrompt(r, request) @@ -1237,9 +1740,9 @@ func (a *App) handleChatCompletions(w http.ResponseWriter, r *http.Request) { } func (a *App) handleSillyTavernChatCompletions(w http.ResponseWriter, r *http.Request) { - payload, err := decodeBody(r) + payload, err := a.decodeBody(w, r) if err != nil { - writeOpenAIError(w, http.StatusBadRequest, err.Error(), "invalid_request_error", nilString()) + writeInvalidBodyError(w, err) return } a.handleSillyTavernChatCompletionsPayload(w, r, payload) @@ -1349,14 +1852,19 @@ func (a *App) handleSillyTavernChatCompletionsPayload(w http.ResponseWriter, r * } func (a *App) handleResponses(w http.ResponseWriter, r *http.Request) { - payload, err := decodeBody(r) + raw, err := a.decodeBodyRaw(w, r) if err != nil { - writeOpenAIError(w, http.StatusBadRequest, err.Error(), "invalid_request_error", nilString()) + writeInvalidBodyError(w, err) return } - stream, _ := payload["stream"].(bool) + typed, _, err := decodeResponsesRequestBodyFromRaw(raw) + if err != nil { + writeInvalidBodyError(w, err) + return + } + stream := typed.Stream var previousResponse map[string]any - previousResponseID := strings.TrimSpace(stringValue(payload["previous_response_id"])) + previousResponseID := strings.TrimSpace(typed.PreviousResponseID) if previousResponseID != "" { var ok bool previousResponse, ok = a.State.getResponse(previousResponseID) @@ -1365,7 +1873,7 @@ func (a *App) handleResponses(w http.ResponseWriter, r *http.Request) { return } } - normalized, err := normalizeResponsesInput(payload, previousResponse) + normalized, err := normalizeResponsesInputFromParts(typed.Input, typed.Attachments, previousResponse) if err != nil { writeOpenAIError(w, http.StatusBadRequest, err.Error(), "invalid_request_error", nilString()) return @@ -1375,7 +1883,12 @@ func (a 
*App) handleResponses(w http.ResponseWriter, r *http.Request) { return } cfg, _, registry := a.State.Snapshot() - entry, err := registry.Resolve(requestedModel(payload, cfg.DefaultPublicModel()), cfg.DefaultPublicModel()) + requestedModelID := requestedModelFromTyped(typed.Model, cfg.DefaultPublicModel()) + useWebSearch := requestedWebSearchFromTyped(typed.UseWebSearch, typed.Metadata, typed.Tools, cfg.Features.UseWebSearch) + preferredConversationID := requestedConversationIDFromTyped(r, typed.ConversationID, typed.Conversation, typed.Metadata) + explicitThreadID := requestedThreadIDFromTyped(r, typed.ThreadID, typed.Thread, typed.NotionThreadID, typed.Metadata) + requestedAccount := requestedAccountEmailFromTyped(r, typed.AccountEmail, typed.NotionAccountEmail, typed.Metadata) + entry, err := registry.Resolve(requestedModelID, cfg.DefaultPublicModel()) if err != nil { writeOpenAIError(w, http.StatusBadRequest, err.Error(), "invalid_request_error", "model_not_found") return @@ -1391,17 +1904,16 @@ func (a *App) handleResponses(w http.ResponseWriter, r *http.Request) { HiddenPrompt: hiddenPrompt, PublicModel: entry.ID, NotionModel: entry.NotionModel, - UseWebSearch: requestedWebSearch(payload, cfg.Features.UseWebSearch), + UseWebSearch: useWebSearch, Attachments: normalized.Attachments, SessionFingerprint: originalFingerprint, RawMessageCount: originalRawMessageCount, } freshThreadMode := forceFreshThreadPerRequest(cfg) - preferredConversationID := requestedConversationID(r, payload) conversation := ConversationEntry{} - if matched, ok := a.resolveContinuationConversation(r, payload, previousResponseID, hiddenPrompt, normalized.Segments); ok { + if matched, ok := a.resolveContinuationConversationWithExplicit(previousResponseID, hiddenPrompt, normalized.Segments, preferredConversationID, explicitThreadID); ok { conversation = matched.Conversation - request.PinnedAccountEmail = firstNonEmpty(strings.TrimSpace(conversation.AccountEmail), requestedAccountEmail(r, 
payload)) + request.PinnedAccountEmail = firstNonEmpty(strings.TrimSpace(conversation.AccountEmail), requestedAccount) if freshThreadMode { request.ForceLocalConversationContinue = strings.TrimSpace(conversation.ID) != "" request.Prompt = buildFreshThreadReplayPromptFromConversation(conversation, latestPrompt, normalized.Attachments, promptText) @@ -1414,7 +1926,7 @@ func (a *App) handleResponses(w http.ResponseWriter, r *http.Request) { request.Prompt = latestPrompt } } else { - request.PinnedAccountEmail = requestedAccountEmail(r, payload) + request.PinnedAccountEmail = requestedAccount } if freshThreadMode && strings.TrimSpace(conversation.ID) == "" { request.Prompt = buildFreshThreadReplayPromptFromStoredResponse(normalized.PreviousResponsePrompt, latestPrompt, normalized.Attachments, request.Prompt) @@ -1468,11 +1980,11 @@ func (a *App) writeUpstreamError(w http.ResponseWriter, err error) { } func prepareOpenAISSEHeaders(w http.ResponseWriter) { + applyCORSHeaders(w) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Connection", "keep-alive") w.Header().Set("X-Accel-Buffering", "no") - w.Header().Set("Access-Control-Allow-Origin", "*") w.WriteHeader(http.StatusOK) } @@ -2106,7 +2618,13 @@ func (a *App) writeResponsesStream(w http.ResponseWriter, r *http.Request, resul } func (a *App) ServeHTTP(w http.ResponseWriter, r *http.Request) { + startedAt := time.Now() + statusCode := http.StatusOK + defer func() { + observeRequestDuration(r.URL.Path, r.Method, statusCode, time.Since(startedAt)) + }() safeWriter := &panicSafeResponseWriter{ResponseWriter: w} + applyCORSHeaders(safeWriter) defer func() { if recovered := recover(); recovered != nil { stack := strings.TrimSpace(string(debug.Stack())) @@ -2139,10 +2657,8 @@ func (a *App) ServeHTTP(w http.ResponseWriter, r *http.Request) { }() if r.Method == http.MethodOptions { - safeWriter.Header().Set("Access-Control-Allow-Origin", "*") - 
safeWriter.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Admin-Token") - safeWriter.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") safeWriter.WriteHeader(http.StatusNoContent) + statusCode = safeWriter.status return } @@ -2150,16 +2666,20 @@ func (a *App) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch { case r.Method == http.MethodGet && path == "/": a.serveIndex(safeWriter) + statusCode = safeWriter.status return case strings.HasPrefix(path, "/admin"): a.handleAdmin(safeWriter, r) + statusCode = safeWriter.status return case r.Method == http.MethodGet && path == "/healthz": a.serveHealthz(safeWriter) + statusCode = safeWriter.status return } if !a.authOK(safeWriter, r) { + statusCode = safeWriter.status return } @@ -2168,6 +2688,10 @@ func (a *App) ServeHTTP(w http.ResponseWriter, r *http.Request) { a.serveModels(safeWriter) case r.Method == http.MethodGet && strings.HasPrefix(path, "/v1/models/"): a.serveModelByID(safeWriter, path) + case r.Method == http.MethodGet && path == "/debug/vars": + expvar.Handler().ServeHTTP(safeWriter, r) + case r.Method == http.MethodGet && path == "/metrics": + writePrometheusMetrics(safeWriter) case r.Method == http.MethodGet && strings.HasPrefix(path, "/v1/responses/"): a.serveResponseByID(safeWriter, path) case r.Method == http.MethodPost && path == "/v1/st/chat/completions": @@ -2179,6 +2703,7 @@ func (a *App) ServeHTTP(w http.ResponseWriter, r *http.Request) { default: writeOpenAIError(safeWriter, http.StatusNotFound, "route not found", "invalid_request_error", "not_found") } + statusCode = safeWriter.status } func Main() { @@ -2190,6 +2715,14 @@ func Main() { app := &App{State: state} state.StartSessionRefreshLoop(context.Background()) app.StartEphemeralConversationCleanupLoop(context.Background()) + if cfg.Debug.PprofEnabled { + go func(addr string) { + log.Printf("[pprof] listening on http://%s/debug/pprof/ (local debug endpoint; avoid public 
exposure)", addr) + if err := http.ListenAndServe(addr, nil); err != nil { + log.Printf("[pprof] server stopped: %v", err) + } + }(cfg.Debug.PprofAddr) + } addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) server := &http.Server{ Addr: addr, diff --git a/internal/app/main_fresh_thread_test.go b/internal/app/main_fresh_thread_test.go index fd2a3ad..31a8379 100644 --- a/internal/app/main_fresh_thread_test.go +++ b/internal/app/main_fresh_thread_test.go @@ -2,15 +2,21 @@ package app import ( "bytes" + "context" "encoding/json" "errors" + "expvar" "fmt" "net/http" "net/http/httptest" "os" "path/filepath" + "runtime" + "sort" + "strconv" "strings" "testing" + "time" ) func newFreshThreadTestApp(t *testing.T) *App { @@ -244,6 +250,1248 @@ func TestHandleSillyTavernFreshThreadReplaysLocalConversation(t *testing.T) { assertConversationContinued(t, app, seeded.ID, "thread-new-st", "The story continues.") } +func TestNormalizeConfigSetsPprofDefaults(t *testing.T) { + cfg := normalizeConfig(AppConfig{}) + if cfg.Debug.PprofEnabled { + t.Fatalf("expected pprof disabled by default") + } + if cfg.Debug.PprofAddr != "127.0.0.1:6060" { + t.Fatalf("unexpected default pprof addr: %q", cfg.Debug.PprofAddr) + } +} + +func TestDefaultConfigSetsDispatchProbeCacheTTLDefault(t *testing.T) { + cfg := defaultConfig() + if cfg.Dispatch.ProbeCacheTTLSeconds != 45 { + t.Fatalf("unexpected default dispatch probe cache ttl: %d", cfg.Dispatch.ProbeCacheTTLSeconds) + } +} + +func TestNormalizeConfigClampsNegativeDispatchProbeCacheTTL(t *testing.T) { + cfg := normalizeConfig(AppConfig{ + Dispatch: DispatchConfig{ProbeCacheTTLSeconds: -3}, + }) + if cfg.Dispatch.ProbeCacheTTLSeconds != 0 { + t.Fatalf("expected negative dispatch probe cache ttl to clamp to 0, got %d", cfg.Dispatch.ProbeCacheTTLSeconds) + } +} + +func TestDefaultConfigBrowserHelperPoolSizeDefaultZero(t *testing.T) { + cfg := defaultConfig() + if got := cfg.Browser.HelperPoolSize; got != 0 { + t.Fatalf("unexpected default browser helper 
pool size: got %d want %d", got, 0) + } +} + +func TestNormalizeConfigClampsBrowserHelperPoolSizeBounds(t *testing.T) { + negative := normalizeConfig(AppConfig{ + Browser: BrowserConfig{HelperPoolSize: -2}, + }) + if got := negative.Browser.HelperPoolSize; got != 0 { + t.Fatalf("expected negative helper pool size clamp to 0, got %d", got) + } + tooLarge := normalizeConfig(AppConfig{ + Browser: BrowserConfig{HelperPoolSize: 99}, + }) + if got := tooLarge.Browser.HelperPoolSize; got != 8 { + t.Fatalf("expected oversized helper pool size clamp to 8, got %d", got) + } +} + +func TestEmbeddedBrowserHelperAssetsLoaded(t *testing.T) { + helper := strings.TrimSpace(nodeWreqHelperScript()) + if helper == "" { + t.Fatalf("expected embedded browser helper script to be non-empty") + } + for _, needle := range []string{ + "const { fetch } = require('node-wreq');", + "process.stdout.write(JSON.stringify(result));", + } { + if !strings.Contains(helper, needle) { + t.Fatalf("embedded browser helper script missing %q", needle) + } + } + + login := strings.TrimSpace(nodeWreqLoginHelperScript()) + if login == "" { + t.Fatalf("expected embedded browser login helper script to be non-empty") + } + for _, needle := range []string{ + "const { fetch } = require('node-wreq');", + "result.set_cookies = [...setCookieRecord.entries()]", + } { + if !strings.Contains(login, needle) { + t.Fatalf("embedded browser login helper script missing %q", needle) + } + } +} + +func TestNormalizeConfigPrecomputesRetryPrefixes(t *testing.T) { + cfg := normalizeConfig(AppConfig{ + Prompt: PromptConfig{ + CodingRetryPrefixes: []string{"custom-coding-prefix"}, + GeneralRetryPrefixes: []string{"custom-general-prefix"}, + DirectAnswerRetryPrefixes: []string{"custom-direct-prefix"}, + }, + }) + if len(cfg.Prompt.precomputedAllRetryPrefixes) == 0 { + t.Fatalf("expected precomputed retry prefixes") + } + joined := strings.Join(cfg.Prompt.precomputedAllRetryPrefixes, "\n") + for _, required := range []string{ + 
"custom-coding-prefix", + "custom-general-prefix", + "custom-direct-prefix", + } { + if !strings.Contains(joined, required) { + t.Fatalf("precomputed retry prefixes missing %q", required) + } + } +} + +func TestEnsureAccountPathsSetsEmailKey(t *testing.T) { + cfg := normalizeConfig(AppConfig{ + LoginHelper: LoginHelperConfig{SessionsDir: "probe_files/notion_accounts"}, + }) + account := ensureAccountPaths(cfg, NotionAccount{Email: " Alice@Example.COM "}) + if account.emailKey != "alice@example.com" { + t.Fatalf("unexpected cached email key: %q", account.emailKey) + } +} + +func BenchmarkPromptGuardLooksLikeCodingRequest(b *testing.B) { + text := "Please help debug this golang function and refactor the docker deployment script." + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = promptGuardLooksLikeCodingRequest(text) + } +} + +func BenchmarkPromptGuardStripRetryPrefixes(b *testing.B) { + cfg := normalizeConfig(AppConfig{ + Prompt: PromptConfig{ + CodingRetryPrefixes: []string{"custom-coding-prefix"}, + GeneralRetryPrefixes: []string{"custom-general-prefix"}, + DirectAnswerRetryPrefixes: []string{"custom-direct-prefix"}, + }, + }) + base := "this is a coding request body" + input := cfg.Prompt.CodingRetryPrefixes[0] + cfg.Prompt.GeneralRetryPrefixes[0] + base + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = promptGuardStripRetryPrefixes(cfg, input) + } +} + +func BenchmarkServeModelsCaching(b *testing.B) { + cfg := defaultConfig() + cfg.APIKey = "bench-api-key" + cfg.Storage.SQLitePath = "" + state, err := newServerState(cfg) + if err != nil { + b.Fatalf("newServerState failed: %v", err) + } + defer func() { + _ = state.Close() + }() + app := &App{State: state} + req := httptest.NewRequest(http.MethodGet, "/v1/models", nil) + req.Header.Set("Authorization", "Bearer bench-api-key") + + b.Run("cached", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { 
+ b.Fatalf("unexpected status: got %d want %d", rec.Code, http.StatusOK) + } + } + }) + + b.Run("uncached_fallback", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + state.cachedModelsListJSON.Store(nil) + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + b.Fatalf("unexpected status: got %d want %d", rec.Code, http.StatusOK) + } + } + }) +} + +func BenchmarkDecodeChatCompletionsTypedFirst(b *testing.B) { + raw := []byte(`{ + "model":"gpt-5.4", + "stream":true, + "stream_options":{"include_usage":"1"}, + "messages":[ + {"role":"system","content":"You are helpful."}, + {"role":"user","content":"请总结这段文本并给出要点。"} + ], + "metadata":{"use_web_search":false} + }`) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + typed, payload, err := decodeChatCompletionsRequestBodyFromRaw(raw) + if err != nil { + b.Fatalf("decodeChatCompletionsRequestBodyFromRaw failed: %v", err) + } + if payload != nil { + b.Fatalf("unexpected map fallback on typed benchmark path") + } + if len(sliceValue(typed.Messages)) == 0 { + b.Fatalf("expected typed messages") + } + } +} + +func BenchmarkDecodeChatCompletionsMapOnly(b *testing.B) { + raw := []byte(`{ + "model":"gpt-5.4", + "stream":true, + "stream_options":{"include_usage":"1"}, + "messages":[ + {"role":"system","content":"You are helpful."}, + {"role":"user","content":"请总结这段文本并给出要点。"} + ], + "metadata":{"use_web_search":false} + }`) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + payload, err := decodeBodyMapFromRaw(raw) + if err != nil { + b.Fatalf("decodeBodyMapFromRaw failed: %v", err) + } + typed := extractChatCompletionsRequestBody(payload) + if len(sliceValue(typed.Messages)) == 0 { + b.Fatalf("expected map-extracted messages") + } + } +} + +func BenchmarkNormalizeChatInputFromTypedMessages(b *testing.B) { + raw := []byte(`{ + "messages":[ + {"role":"system","content":"You are helpful."}, + 
{"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":"world"}]} + ], + "attachments":[{"type":"image_url","url":"https://example.com/a.png"}] + }`) + typed, _, err := decodeChatCompletionsRequestBodyFromRaw(raw) + if err != nil { + b.Fatalf("decodeChatCompletionsRequestBodyFromRaw failed: %v", err) + } + messages := sliceValue(typed.Messages) + if len(messages) == 0 { + b.Fatalf("expected typed messages") + } + b.ReportAllocs() + for i := 0; i < b.N; i++ { + normalized, err := normalizeChatInputFromParts(messages, typed.Attachments) + if err != nil { + b.Fatalf("normalizeChatInputFromParts failed: %v", err) + } + if normalized.Prompt == "" { + b.Fatalf("expected normalized prompt") + } + } +} + +func BenchmarkNormalizeChatInputFromMapMessages(b *testing.B) { + raw := []byte(`{ + "messages":[ + {"role":"system","content":"You are helpful."}, + {"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":"world"}]} + ], + "attachments":[{"type":"image_url","url":"https://example.com/a.png"}] + }`) + payload, err := decodeBodyMapFromRaw(raw) + if err != nil { + b.Fatalf("decodeBodyMapFromRaw failed: %v", err) + } + messages := sliceValue(payload["messages"]) + if len(messages) == 0 { + b.Fatalf("expected map messages") + } + attachments := payload["attachments"] + b.ReportAllocs() + for i := 0; i < b.N; i++ { + normalized, err := normalizeChatInputFromParts(messages, attachments) + if err != nil { + b.Fatalf("normalizeChatInputFromParts failed: %v", err) + } + if normalized.Prompt == "" { + b.Fatalf("expected normalized prompt") + } + } +} + +func BenchmarkDecodeResponsesTypedFirst(b *testing.B) { + raw := []byte(`{ + "model":"gpt-5.4", + "stream":false, + "previous_response_id":"resp_123", + "input":[ + {"type":"input_text","text":"hello"}, + {"type":"input_text","text":"world"} + ], + "metadata":{"use_web_search":"1"}, + "attachments":[{"type":"file","file_url":"https://example.com/f.txt"}] + }`) + b.ReportAllocs() + for i 
:= 0; i < b.N; i++ { + typed, payload, err := decodeResponsesRequestBodyFromRaw(raw) + if err != nil { + b.Fatalf("decodeResponsesRequestBodyFromRaw failed: %v", err) + } + if payload != nil { + b.Fatalf("unexpected map fallback on typed benchmark path") + } + if len(sliceValue(typed.Input)) == 0 { + b.Fatalf("expected typed input items") + } + } +} + +func BenchmarkDecodeResponsesMapOnly(b *testing.B) { + raw := []byte(`{ + "model":"gpt-5.4", + "stream":false, + "previous_response_id":"resp_123", + "input":[ + {"type":"input_text","text":"hello"}, + {"type":"input_text","text":"world"} + ], + "metadata":{"use_web_search":"1"}, + "attachments":[{"type":"file","file_url":"https://example.com/f.txt"}] + }`) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + payload, err := decodeBodyMapFromRaw(raw) + if err != nil { + b.Fatalf("decodeBodyMapFromRaw failed: %v", err) + } + typed := extractResponsesRequestBody(payload) + if len(sliceValue(typed.Input)) == 0 { + b.Fatalf("expected map-extracted responses input") + } + } +} + +func BenchmarkNormalizeResponsesInputFromTyped(b *testing.B) { + raw := []byte(`{ + "input":[ + {"type":"input_text","text":"hello"}, + {"type":"input_text","text":"world"} + ], + "attachments":[{"type":"file","file_url":"https://example.com/f.txt"}] + }`) + typed, _, err := decodeResponsesRequestBodyFromRaw(raw) + if err != nil { + b.Fatalf("decodeResponsesRequestBodyFromRaw failed: %v", err) + } + b.ReportAllocs() + for i := 0; i < b.N; i++ { + normalized, err := normalizeResponsesInputFromParts(typed.Input, typed.Attachments, nil) + if err != nil { + b.Fatalf("normalizeResponsesInputFromParts failed: %v", err) + } + if normalized.Prompt == "" { + b.Fatalf("expected normalized prompt") + } + } +} + +func BenchmarkNormalizeResponsesInputFromMap(b *testing.B) { + raw := []byte(`{ + "input":[ + {"type":"input_text","text":"hello"}, + {"type":"input_text","text":"world"} + ], + "attachments":[{"type":"file","file_url":"https://example.com/f.txt"}] + 
}`) + payload, err := decodeBodyMapFromRaw(raw) + if err != nil { + b.Fatalf("decodeBodyMapFromRaw failed: %v", err) + } + b.ReportAllocs() + for i := 0; i < b.N; i++ { + normalized, err := normalizeResponsesInputFromParts(payload["input"], payload["attachments"], nil) + if err != nil { + b.Fatalf("normalizeResponsesInputFromParts failed: %v", err) + } + if normalized.Prompt == "" { + b.Fatalf("expected normalized prompt") + } + } +} + +func BenchmarkChatDecodeAndNormalizeTypedFirst(b *testing.B) { + raw := []byte(`{ + "model":"gpt-5.4", + "stream":false, + "messages":[ + {"role":"system","content":"You are helpful."}, + {"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":"world"}]} + ], + "attachments":[{"type":"image_url","url":"https://example.com/a.png"}] + }`) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + typed, payload, err := decodeChatCompletionsRequestBodyFromRaw(raw) + if err != nil { + b.Fatalf("decodeChatCompletionsRequestBodyFromRaw failed: %v", err) + } + if payload != nil { + b.Fatalf("unexpected map fallback on typed benchmark path") + } + normalized, err := normalizeChatInputFromParts(sliceValue(typed.Messages), typed.Attachments) + if err != nil { + b.Fatalf("normalizeChatInputFromParts failed: %v", err) + } + if normalized.Prompt == "" { + b.Fatalf("expected normalized prompt") + } + } +} + +func BenchmarkChatDecodeAndNormalizeMapOnly(b *testing.B) { + raw := []byte(`{ + "model":"gpt-5.4", + "stream":false, + "messages":[ + {"role":"system","content":"You are helpful."}, + {"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":"world"}]} + ], + "attachments":[{"type":"image_url","url":"https://example.com/a.png"}] + }`) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + payload, err := decodeBodyMapFromRaw(raw) + if err != nil { + b.Fatalf("decodeBodyMapFromRaw failed: %v", err) + } + typed := extractChatCompletionsRequestBody(payload) + normalized, err := 
normalizeChatInputFromParts(sliceValue(typed.Messages), typed.Attachments) + if err != nil { + b.Fatalf("normalizeChatInputFromParts failed: %v", err) + } + if normalized.Prompt == "" { + b.Fatalf("expected normalized prompt") + } + } +} + +func BenchmarkResponsesDecodeAndNormalizeTypedFirst(b *testing.B) { + raw := []byte(`{ + "model":"gpt-5.4", + "input":[ + {"type":"input_text","text":"hello"}, + {"type":"input_text","text":"world"} + ], + "attachments":[{"type":"file","file_url":"https://example.com/f.txt"}] + }`) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + typed, payload, err := decodeResponsesRequestBodyFromRaw(raw) + if err != nil { + b.Fatalf("decodeResponsesRequestBodyFromRaw failed: %v", err) + } + if payload != nil { + b.Fatalf("unexpected map fallback on typed benchmark path") + } + normalized, err := normalizeResponsesInputFromParts(typed.Input, typed.Attachments, nil) + if err != nil { + b.Fatalf("normalizeResponsesInputFromParts failed: %v", err) + } + if normalized.Prompt == "" { + b.Fatalf("expected normalized prompt") + } + } +} + +func BenchmarkResponsesDecodeAndNormalizeMapOnly(b *testing.B) { + raw := []byte(`{ + "model":"gpt-5.4", + "input":[ + {"type":"input_text","text":"hello"}, + {"type":"input_text","text":"world"} + ], + "attachments":[{"type":"file","file_url":"https://example.com/f.txt"}] + }`) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + payload, err := decodeBodyMapFromRaw(raw) + if err != nil { + b.Fatalf("decodeBodyMapFromRaw failed: %v", err) + } + typed := extractResponsesRequestBody(payload) + normalized, err := normalizeResponsesInputFromParts(typed.Input, typed.Attachments, nil) + if err != nil { + b.Fatalf("normalizeResponsesInputFromParts failed: %v", err) + } + if normalized.Prompt == "" { + b.Fatalf("expected normalized prompt") + } + } +} + +func TestServeModelsUsesStaticJSONCache(t *testing.T) { + app := newFreshThreadTestApp(t) + raw := 
[]byte(`{"object":"list","data":[{"id":"cached-model","object":"model"}]}`) + ready := append([]byte(nil), raw...) + app.State.cachedModelsListJSON.Store(&ready) + req := httptest.NewRequest(http.MethodGet, "/v1/models", nil) + req.Header.Set("Authorization", "Bearer test-api-key") + rec := httptest.NewRecorder() + + app.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", rec.Code, http.StatusOK) + } + if got := strings.TrimSpace(rec.Body.String()); got != string(raw) { + t.Fatalf("expected cached body, got %s", got) + } +} + +func TestServeModelByIDUsesStaticJSONCache(t *testing.T) { + app := newFreshThreadTestApp(t) + _, _, registry := app.State.Snapshot() + entry, err := registry.Resolve("gpt-5.4", "auto") + if err != nil { + t.Fatalf("resolve model failed: %v", err) + } + body := []byte(`{"id":"gpt-5.4","object":"model","cached":true}`) + cache := map[string][]byte{ + normalizeLookupKey(entry.ID): append([]byte(nil), body...), + } + app.State.cachedModelByIDJSON.Store(&cache) + req := httptest.NewRequest(http.MethodGet, "/v1/models/"+entry.ID, nil) + req.Header.Set("Authorization", "Bearer test-api-key") + rec := httptest.NewRecorder() + + app.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", rec.Code, http.StatusOK) + } + if got := strings.TrimSpace(rec.Body.String()); got != string(body) { + t.Fatalf("expected cached body, got %s", got) + } +} + +func TestServeHealthzIncludesRefreshRuntimeFieldsWhenStaticCacheExists(t *testing.T) { + app := newFreshThreadTestApp(t) + static := []byte(`{"ok":true,"default_model":"gpt-5.4","model_count":3,"user_email":"user@example.com","space_id":"space-id","active_account":"acc@example.com","session_refresh_enabled":true}`) + staticCopy := append([]byte(nil), static...) 
+ app.State.cachedHealthzStaticJSON.Store(&staticCopy) + app.State.mu.Lock() + app.State.LastSessionRefresh = time.Date(2026, time.January, 2, 3, 4, 5, 0, time.UTC) + app.State.LastSessionRefreshError = "refresh failed" + app.State.mu.Unlock() + req := httptest.NewRequest(http.MethodGet, "/healthz", nil) + rec := httptest.NewRecorder() + + app.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", rec.Code, http.StatusOK) + } + var payload map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil { + t.Fatalf("unmarshal healthz failed: %v", err) + } + if got, _ := payload["default_model"].(string); got != "gpt-5.4" { + t.Fatalf("unexpected default_model: %q", got) + } + if got, ok := payload["session_ready"].(bool); !ok || got { + t.Fatalf("unexpected session_ready: %#v", payload["session_ready"]) + } + if got, _ := payload["last_session_refresh"].(string); got != "2026-01-02T03:04:05Z" { + t.Fatalf("unexpected last_session_refresh: %q", got) + } + if got, _ := payload["last_session_refresh_error"].(string); got != "refresh failed" { + t.Fatalf("unexpected last_session_refresh_error: %q", got) + } +} + +func TestServeHTTPDebugVarsExposesWreqClientMetric(t *testing.T) { + app := newFreshThreadTestApp(t) + before := int64(0) + if value := wreqClientNewTotalMetric.Get("standard"); value != nil { + before = value.(*expvar.Int).Value() + } + wreqClientNewTotalMetric.Add("standard", 1) + req := httptest.NewRequest(http.MethodGet, "/debug/vars", nil) + req.Header.Set("Authorization", "Bearer test-api-key") + rec := httptest.NewRecorder() + + app.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + body := rec.Body.String() + if !strings.Contains(body, `"notion2api_wreq_client_new_total"`) { + t.Fatalf("expected metrics payload to include wreq client metric, got %s", body) + } + if 
!strings.Contains(body, `"notion2api_http_transport_cache_total"`) { + t.Fatalf("expected metrics payload to include transport cache metric, got %s", body) + } + after := int64(0) + if value := wreqClientNewTotalMetric.Get("standard"); value != nil { + after = value.(*expvar.Int).Value() + } + if after < before+1 { + t.Fatalf("expected metric value to be incremented, before=%d after=%d", before, after) + } +} + +func TestServeHTTPMetricsExposesCorePrometheusSeries(t *testing.T) { + resetMetricsForTest() + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.Storage.SQLitePath = "" + state, err := newServerState(cfg) + if err != nil { + t.Fatalf("newServerState failed: %v", err) + } + defer func() { + _ = state.Close() + }() + app := &App{State: state} + + setDispatchSlotInflight("alice@example.com", 2) + observeWreqFFICallDuration(25 * time.Millisecond) + observeSQLiteOpDuration("save_response", 2*time.Millisecond) + addBrowserHelperSpawn() + addBrowserHelperPoolWorkerSpawn() + + warmReq := httptest.NewRequest(http.MethodGet, "/healthz", nil) + warmRec := httptest.NewRecorder() + app.ServeHTTP(warmRec, warmReq) + if warmRec.Code != http.StatusOK { + t.Fatalf("unexpected warm-up status: got %d want %d", warmRec.Code, http.StatusOK) + } + + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) + req.Header.Set("Authorization", "Bearer test-api-key") + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + body := rec.Body.String() + for _, want := range []string{ + "notion2api_request_duration_seconds_bucket", + "notion2api_dispatch_slot_inflight", + "notion2api_wreq_ffi_call_duration_seconds_bucket", + "notion2api_browser_helper_spawn_total", + "notion2api_browser_helper_pool_worker_spawn_total", + "notion2api_sqlite_op_duration_seconds_bucket", + "notion2api_response_store_prune_total", + } { + if 
!strings.Contains(body, want) { + t.Fatalf("expected /metrics output to include %q, got: %s", want, body) + } + } + if !strings.Contains(body, "notion2api_browser_helper_pool_worker_spawn_total 1") { + t.Fatalf("expected pool worker spawn counter value to be 1, got: %s", body) + } +} + +func TestSnapshotReadsFromAtomicBundle(t *testing.T) { + state := &ServerState{} + cfg := defaultConfig() + cfg.APIKey = "snapshot-api-key" + session := SessionInfo{UserID: "user-1", SpaceID: "space-1"} + registry := ModelRegistry{ + Entries: []ModelDefinition{ + {ID: "gpt-5.4", Enabled: true}, + }, + } + + state.mu.Lock() + state.Config = cfg + state.Session = session + state.ModelRegistry = registry + state.updateSnapshotBundleLocked() + state.mu.Unlock() + + gotCfg, gotSession, gotRegistry := state.Snapshot() + if gotCfg.APIKey != cfg.APIKey { + t.Fatalf("snapshot cfg mismatch: got %q want %q", gotCfg.APIKey, cfg.APIKey) + } + if gotSession.UserID != session.UserID || gotSession.SpaceID != session.SpaceID { + t.Fatalf("snapshot session mismatch: got %+v want %+v", gotSession, session) + } + if len(gotRegistry.Entries) != 1 || gotRegistry.Entries[0].ID != "gpt-5.4" { + t.Fatalf("snapshot registry mismatch: %+v", gotRegistry.Entries) + } + if len(state.snap.Load().DispatchOrder) != 0 { + t.Fatalf("expected empty dispatch order for empty accounts") + } +} + +func TestSnapshotDispatchOrderPrecomputed(t *testing.T) { + tempDir := t.TempDir() + aliceProbe := filepath.Join(tempDir, "alice-probe.json") + bobProbe := filepath.Join(tempDir, "bob-probe.json") + if err := os.WriteFile(aliceProbe, []byte(`{"ok":true}`), 0o600); err != nil { + t.Fatalf("write alice probe failed: %v", err) + } + if err := os.WriteFile(bobProbe, []byte(`{"ok":true}`), 0o600); err != nil { + t.Fatalf("write bob probe failed: %v", err) + } + + cfg := defaultConfig() + cfg.APIKey = "snapshot-dispatch-order-api-key" + cfg.ActiveAccount = "bob@example.com" + cfg.Accounts = []NotionAccount{ + {Email: 
"alice@example.com", Priority: 10, MaxConcurrency: 1, ProbeJSON: aliceProbe}, + {Email: "bob@example.com", Priority: 1, MaxConcurrency: 1, ProbeJSON: bobProbe}, + {Email: "carol@example.com", Priority: 50, MaxConcurrency: 1, Disabled: true}, + } + cfg = normalizeConfig(cfg) + state, err := newServerState(cfg) + if err != nil { + t.Fatalf("newServerState failed: %v", err) + } + defer func() { + _ = state.Close() + }() + + snap := state.snap.Load() + if snap == nil { + t.Fatalf("expected non-nil snapshot bundle") + } + if len(snap.DispatchOrder) != 2 { + t.Fatalf("unexpected dispatch order length: got %d want 2", len(snap.DispatchOrder)) + } + if getAccountEmailKey(snap.DispatchOrder[0]) != "bob@example.com" { + t.Fatalf("expected active account first in precomputed dispatch order, got %q", snap.DispatchOrder[0].Email) + } + if getAccountEmailKey(snap.DispatchOrder[1]) != "alice@example.com" { + t.Fatalf("expected second candidate to be alice, got %q", snap.DispatchOrder[1].Email) + } +} + +func TestResolveDispatchCandidatesFromSnapshotUsesPrecomputedOrder(t *testing.T) { + cfg := normalizeConfig(AppConfig{ + APIKey: "test-api-key", + Accounts: []NotionAccount{ + {Email: "first@example.com", Priority: 10, MaxConcurrency: 1}, + {Email: "second@example.com", Priority: 20, MaxConcurrency: 1}, + }, + ActiveAccount: "first@example.com", + }) + now := time.Now() + bundle := &snapshotBundle{ + Config: cfg, + DispatchOrder: []NotionAccount{ + {Email: "second@example.com", Priority: 20, MaxConcurrency: 1}, + {Email: "first@example.com", Priority: 10, MaxConcurrency: 1}, + }, + } + candidates, err := resolveDispatchCandidatesFromSnapshot(bundle, PromptRunRequest{}, now) + if err != nil { + t.Fatalf("resolveDispatchCandidatesFromSnapshot failed: %v", err) + } + if len(candidates) != 2 { + t.Fatalf("unexpected candidates length: got %d want 2", len(candidates)) + } + if getAccountEmailKey(candidates[0]) != "second@example.com" || getAccountEmailKey(candidates[1]) != 
"first@example.com" { + t.Fatalf("unexpected candidate order from snapshot: %+v", candidates) + } +} + +func TestConversationStoreGetReturnsValueSnapshotAfterMutation(t *testing.T) { + store := newConversationStore() + created := store.Create(ConversationCreateRequest{ + PreferredID: "conv-value-snapshot", + Source: "api", + Transport: "chat_completions", + Model: "gpt-5.4", + Prompt: "hello", + }) + got1, ok := store.Get(created.ID) + if !ok { + t.Fatalf("expected created conversation to exist") + } + if got1.Status != "running" { + t.Fatalf("unexpected initial status: %q", got1.Status) + } + + store.Complete(created.ID, InferenceResult{ + Text: "done", + ThreadID: "thread-1", + AccountEmail: "alice@example.com", + }) + + got2, ok := store.Get(created.ID) + if !ok { + t.Fatalf("expected conversation after completion") + } + if got2.Status != "completed" { + t.Fatalf("unexpected status after complete: %q", got2.Status) + } + if got1.Status == got2.Status { + t.Fatalf("expected old value snapshot to remain unchanged, got1=%q got2=%q", got1.Status, got2.Status) + } +} + +func TestConversationStoreSummaryUsesCachedPreviewAfterMutations(t *testing.T) { + store := newConversationStore() + created := store.Create(ConversationCreateRequest{ + PreferredID: "conv-preview-cache", + Source: "api", + Transport: "chat_completions", + Model: "gpt-5.4", + Prompt: "first question", + }) + list1 := store.List() + if len(list1) == 0 { + t.Fatalf("expected list to have one entry") + } + if !strings.Contains(list1[0].Preview, "first question") { + t.Fatalf("unexpected initial preview: %q", list1[0].Preview) + } + + store.AppendAssistantDelta(created.ID, "assistant draft") + list2 := store.List() + if len(list2) == 0 { + t.Fatalf("expected list to have one entry after delta") + } + if !strings.Contains(list2[0].Preview, "assistant draft") { + t.Fatalf("expected preview to reflect assistant delta, got %q", list2[0].Preview) + } + + store.Complete(created.ID, InferenceResult{ + Text: 
"final assistant reply", + ThreadID: "thread-preview", + AccountEmail: "preview@example.com", + }) + list3 := store.List() + if len(list3) == 0 { + t.Fatalf("expected list to have one entry after complete") + } + if !strings.Contains(list3[0].Preview, "final assistant reply") { + t.Fatalf("expected preview to reflect completed assistant text, got %q", list3[0].Preview) + } +} + +func TestRequestedWebSearchFromTypedMetadataAndTools(t *testing.T) { + if got := requestedWebSearchFromTyped(nil, json.RawMessage(`{"use_web_search": true}`), nil, false); !got { + t.Fatalf("expected use_web_search=true from metadata to enable web search") + } + if got := requestedWebSearchFromTyped(nil, json.RawMessage(`{"notion_use_web_search":"false"}`), nil, true); got { + t.Fatalf("expected notion_use_web_search=false metadata to disable web search") + } + if got := requestedWebSearchFromTyped(nil, nil, json.RawMessage(`[{"type":"web_search_preview"}]`), false); !got { + t.Fatalf("expected web_search tool to enable web search") + } + if got := requestedWebSearchFromTyped(nil, map[string]any{"use_web_search": "1"}, nil, false); !got { + t.Fatalf("expected use_web_search=1 map metadata to enable web search") + } + if got := requestedWebSearchFromTyped(nil, nil, []map[string]any{{"type": "web_search_legacy"}}, false); !got { + t.Fatalf("expected web_search tool map slice to enable web search") + } +} + +func TestExtractTypedRequestBodies(t *testing.T) { + chatPayload := map[string]any{ + "model": "gpt-5.4", + "stream": true, + "stream_options": map[string]any{"include_usage": true}, + "conversation_id": "conv-typed-chat", + "account_email": "typed@example.com", + "use_web_search": "true", + "metadata": map[string]any{"notion_use_web_search": false}, + "attachments": []any{map[string]any{"type": "image_url", "url": "https://example.com/image.png"}}, + "messages": []any{map[string]any{"role": "user", "content": "hello"}}, + "type": "continue", + "user_name": "user", + "char_name": "char", + 
"group_names": []any{"g1"}, + "continue_prefill": "next", + "show_thoughts": true, + "notion_account_email": "typed2@example.com", + } + chatTyped := extractChatCompletionsRequestBody(chatPayload) + if chatTyped.Model != "gpt-5.4" || !chatTyped.Stream { + t.Fatalf("unexpected typed chat body: %+v", chatTyped) + } + if chatTyped.UseWebSearch == nil || !*chatTyped.UseWebSearch { + t.Fatalf("expected typed chat use_web_search=true") + } + if chatTyped.StreamIncludeUsage == nil || !*chatTyped.StreamIncludeUsage { + t.Fatalf("expected typed chat stream include_usage=true") + } + if _, ok := chatTyped.Attachments.([]any); !ok { + t.Fatalf("expected typed chat attachments to keep raw array type") + } + if _, ok := chatTyped.Messages.([]any); !ok { + t.Fatalf("expected typed chat messages to keep raw array type") + } + if !chatTyped.likelySillyTavernByEnvelope() { + t.Fatalf("expected chat body to be identified as likely sillytavern by envelope") + } + + respPayload := map[string]any{ + "model": "gpt-5.4", + "stream": false, + "previous_response_id": "resp_123", + "conversation_id": "conv-typed-responses", + "thread_id": "thread-typed", + "account_email": "resp@example.com", + "use_web_search": true, + "metadata": map[string]any{"use_web_search": true}, + "input": []any{map[string]any{"type": "text", "text": "input payload"}}, + "attachments": []any{map[string]any{"type": "file", "file_url": "https://example.com/file.txt"}}, + } + respTyped := extractResponsesRequestBody(respPayload) + if respTyped.Model != "gpt-5.4" || respTyped.Stream { + t.Fatalf("unexpected typed responses body: %+v", respTyped) + } + if respTyped.PreviousResponseID != "resp_123" || respTyped.ConversationID != "conv-typed-responses" { + t.Fatalf("unexpected typed responses ids: %+v", respTyped) + } + if respTyped.UseWebSearch == nil || !*respTyped.UseWebSearch { + t.Fatalf("expected typed responses use_web_search=true") + } + if _, ok := respTyped.Input.([]any); !ok { + t.Fatalf("expected typed 
responses input to keep raw array type") + } + if _, ok := respTyped.Attachments.([]any); !ok { + t.Fatalf("expected typed responses attachments to keep raw array type") + } +} + +func TestExtractChatTypedStreamIncludeUsageParsing(t *testing.T) { + fromRaw := extractChatCompletionsRequestBody(map[string]any{ + "stream_options": json.RawMessage(`{"include_usage":"1"}`), + }) + if fromRaw.StreamIncludeUsage == nil || !*fromRaw.StreamIncludeUsage { + t.Fatalf("expected stream include_usage to parse true from raw json string flag") + } + + fromMapFalse := extractChatCompletionsRequestBody(map[string]any{ + "stream_options": map[string]any{"include_usage": false}, + }) + if fromMapFalse.StreamIncludeUsage == nil { + t.Fatalf("expected stream include_usage pointer to be populated for explicit false") + } + if *fromMapFalse.StreamIncludeUsage { + t.Fatalf("expected stream include_usage=false from typed stream_options map") + } +} + +func TestRequestedIdentifiersFromTypedRespectHeaders(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + req.Header.Set("X-Conversation-ID", "header-conv") + req.Header.Set("X-Thread-ID", "header-thread") + req.Header.Set("X-Account-Email", "header@example.com") + + if got := requestedConversationIDFromTyped(req, "body-conv", "body-conv2", map[string]any{"conversation_id": "meta-conv"}); got != "header-conv" { + t.Fatalf("conversation id should prefer header, got %q", got) + } + if got := requestedThreadIDFromTyped(req, "body-thread", "body-thread2", "body-thread3", map[string]any{"thread_id": "meta-thread"}); got != "header-thread" { + t.Fatalf("thread id should prefer header, got %q", got) + } + if got := requestedAccountEmailFromTyped(req, "body@example.com", "body2@example.com", map[string]any{"account_email": "meta@example.com"}); got != "header@example.com" { + t.Fatalf("account email should prefer header, got %q", got) + } +} + +func TestRequestedIdentifiersFromTypedFallbackToMetadata(t 
*testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + metadata := json.RawMessage(`{"conversation_id":"meta-conv","thread_id":"meta-thread","account_email":"meta@example.com"}`) + + if got := requestedConversationIDFromTyped(req, "", "", metadata); got != "meta-conv" { + t.Fatalf("conversation id should fallback to metadata, got %q", got) + } + if got := requestedThreadIDFromTyped(req, "", "", "", metadata); got != "meta-thread" { + t.Fatalf("thread id should fallback to metadata, got %q", got) + } + if got := requestedAccountEmailFromTyped(req, "", "", metadata); got != "meta@example.com" { + t.Fatalf("account email should fallback to metadata, got %q", got) + } +} + +func TestResolveContinuationConversationWithExplicitUsesTypedThreadID(t *testing.T) { + app := newFreshThreadTestApp(t) + seeded := seedCompletedConversation(t, app, "conv-typed-explicit", "Seed question", "Seed answer", "thread-explicit") + + segments := []conversationPromptSegment{ + {Role: "user", Text: "follow up"}, + } + + target, ok := app.resolveContinuationConversationWithExplicit("", "", segments, "", "thread-explicit") + if !ok { + t.Fatalf("expected explicit typed thread id to resolve continuation target") + } + if strings.TrimSpace(target.Conversation.ID) != seeded.ID { + t.Fatalf("unexpected resolved conversation id: got %q want %q", target.Conversation.ID, seeded.ID) + } + if strings.TrimSpace(target.Conversation.ThreadID) != "thread-explicit" { + t.Fatalf("unexpected resolved thread id: got %q", target.Conversation.ThreadID) + } +} + +func TestTypedEnvelopeExtractionFallsBackToLegacyWhenTypedFieldsMissing(t *testing.T) { + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.Storage.SQLitePath = "" + state, err := newServerState(cfg) + if err != nil { + t.Fatalf("newServerState failed: %v", err) + } + defer func() { + _ = state.Close() + }() + app := &App{State: state} + + var captured PromptRunRequest + app.runPromptOverride = func(_ 
*http.Request, request PromptRunRequest) (InferenceResult, error) { + captured = request + return InferenceResult{ + Text: "typed fallback ok", + ThreadID: "thread-typed-fallback", + MessageID: "msg-typed-fallback", + TraceID: "trace-typed-fallback", + AccountEmail: "header@example.com", + }, nil + } + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", mustJSONBody(t, map[string]any{ + "model": "gpt-5.4", + "messages": []map[string]any{ + {"role": "user", "content": "hello"}, + }, + })) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-api-key") + req.Header.Set("X-Account-Email", "header@example.com") + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d body=%s", rec.Code, rec.Body.String()) + } + if captured.PinnedAccountEmail != "header@example.com" { + t.Fatalf("expected pinned account from header fallback path, got %q", captured.PinnedAccountEmail) + } + if captured.PublicModel != "gpt-5.4" { + t.Fatalf("expected resolved model from legacy payload path, got %q", captured.PublicModel) + } +} + +func sqliteWriterFallbackValue(reason string) int64 { + if strings.TrimSpace(reason) == "" { + return 0 + } + value := sqliteWriterFallbackTotalMetric.Get(reason) + if value == nil { + return 0 + } + counter, ok := value.(*expvar.Int) + if !ok || counter == nil { + return 0 + } + return counter.Value() +} + +func boolPtr(value bool) *bool { + return &value +} + +func TestSaveResponseWithAccountPersistsViaAsyncSQLiteWriter(t *testing.T) { + tempDir := t.TempDir() + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.Storage.SQLitePath = filepath.Join(tempDir, "responses.sqlite") + cfg.Storage.PersistConversations = true + cfg.Storage.PersistResponses = boolPtr(true) + cfg.Responses.StoreTTLSeconds = 3600 + + state, err := newServerState(cfg) + if err != nil { + t.Fatalf("newServerState failed: %v", err) + } + defer 
func() { + _ = state.Close() + }() + + responseID := "resp_async_test_1" + payload := map[string]any{ + "id": responseID, + "object": "response", + "output": []any{ + map[string]any{ + "type": "message", + "content": []any{ + map[string]any{ + "type": "output_text", + "text": "hello from async sqlite writer", + }, + }, + }, + }, + } + state.saveResponseWithAccount(responseID, payload, "conv-async", "thread-async", "async@example.com") + + deadline := time.Now().Add(3 * time.Second) + for { + record, ok := state.getStoredResponse(responseID) + if ok && strings.TrimSpace(record.ThreadID) == "thread-async" { + break + } + if time.Now().After(deadline) { + t.Fatalf("response not visible in in-memory store before deadline") + } + time.Sleep(25 * time.Millisecond) + } + + readStore, err := openSQLiteStore(cfg) + if err != nil { + t.Fatalf("openSQLiteStore(read) failed: %v", err) + } + defer func() { + _ = readStore.Close() + }() + + waitUntil := time.Now().Add(3 * time.Second) + for { + rows, queryErr := readStore.db.Query(`SELECT payload_json, conversation_id, thread_id, account_email FROM responses WHERE response_id = ?`, responseID) + if queryErr != nil { + t.Fatalf("query persisted response failed: %v", queryErr) + } + found := false + var rawPayload string + var conversationID string + var threadID string + var accountEmail string + for rows.Next() { + found = true + if scanErr := rows.Scan(&rawPayload, &conversationID, &threadID, &accountEmail); scanErr != nil { + _ = rows.Close() + t.Fatalf("scan persisted response failed: %v", scanErr) + } + } + _ = rows.Close() + if found { + if strings.TrimSpace(conversationID) != "conv-async" { + t.Fatalf("conversation_id mismatch: got %q want %q", conversationID, "conv-async") + } + if strings.TrimSpace(threadID) != "thread-async" { + t.Fatalf("thread_id mismatch: got %q want %q", threadID, "thread-async") + } + if strings.TrimSpace(accountEmail) != "async@example.com" { + t.Fatalf("account_email mismatch: got %q want %q", 
accountEmail, "async@example.com") + } + if !strings.Contains(rawPayload, "hello from async sqlite writer") { + t.Fatalf("unexpected payload_json: %s", rawPayload) + } + break + } + if time.Now().After(waitUntil) { + t.Fatalf("persisted response not found before deadline") + } + time.Sleep(25 * time.Millisecond) + } +} + +func TestSQLiteWriterCloseFlushesQueuedResponseWrites(t *testing.T) { + tempDir := t.TempDir() + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.Storage.SQLitePath = filepath.Join(tempDir, "close-flush.sqlite") + cfg.Storage.PersistConversations = true + cfg.Storage.PersistResponses = boolPtr(true) + cfg.Responses.StoreTTLSeconds = 3600 + + state, err := newServerState(cfg) + if err != nil { + t.Fatalf("newServerState failed: %v", err) + } + + total := 12 + for i := 0; i < total; i++ { + responseID := "resp_flush_" + strconv.Itoa(i) + state.saveResponseWithAccount(responseID, map[string]any{ + "id": responseID, + "object": "response", + "idx": i, + }, "conv-flush", "thread-flush", "flush@example.com") + } + + if err := state.Close(); err != nil { + t.Fatalf("state.Close failed: %v", err) + } + + readStore, err := openSQLiteStore(cfg) + if err != nil { + t.Fatalf("openSQLiteStore(read) failed: %v", err) + } + defer func() { + _ = readStore.Close() + }() + + row := readStore.db.QueryRow(`SELECT COUNT(1) FROM responses WHERE conversation_id = ? 
AND thread_id = ?`, "conv-flush", "thread-flush") + var persisted int + if scanErr := row.Scan(&persisted); scanErr != nil { + t.Fatalf("scan persisted count failed: %v", scanErr) + } + if persisted != total { + t.Fatalf("persisted response count mismatch after close flush: got %d want %d", persisted, total) + } +} + +func TestSQLiteWriterFallbackMetricRemainsStableUnderNormalLoad(t *testing.T) { + tempDir := t.TempDir() + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.Storage.SQLitePath = filepath.Join(tempDir, "fallback-metric.sqlite") + cfg.Storage.PersistConversations = true + cfg.Storage.PersistResponses = boolPtr(true) + cfg.Responses.StoreTTLSeconds = 3600 + + beforeChannelFull := sqliteWriterFallbackValue("channel_full") + beforeUnavailable := sqliteWriterFallbackValue("writer_unavailable") + + state, err := newServerState(cfg) + if err != nil { + t.Fatalf("newServerState failed: %v", err) + } + defer func() { + _ = state.Close() + }() + + for i := 0; i < 8; i++ { + responseID := "resp_metric_" + strconv.Itoa(i) + state.saveResponseWithAccount(responseID, map[string]any{ + "id": responseID, + "object": "response", + "idx": i, + }, "conv-metric", "thread-metric", "metric@example.com") + } + + time.Sleep(250 * time.Millisecond) + + afterChannelFull := sqliteWriterFallbackValue("channel_full") + afterUnavailable := sqliteWriterFallbackValue("writer_unavailable") + if afterChannelFull != beforeChannelFull { + t.Fatalf("expected no channel_full fallback in normal load; before=%d after=%d", beforeChannelFull, afterChannelFull) + } + if afterUnavailable != beforeUnavailable { + t.Fatalf("expected no writer_unavailable fallback in normal load; before=%d after=%d", beforeUnavailable, afterUnavailable) + } +} + func TestHandleChatCompletionsFreshThreadContinuesExplicitConversationIDWithLatestUserOnly(t *testing.T) { app := newFreshThreadTestApp(t) @@ -369,6 +1617,53 @@ func TestServerStateSaveAndApplyRejectsEmptyAPIKey(t *testing.T) { } } +func 
TestServerStateSaveAndApplyInvalidatesDispatchProbeCacheOnActiveAccountChange(t *testing.T) { + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.Storage.SQLitePath = "" + cfg.Accounts = []NotionAccount{ + { + Email: "alice@example.com", + ProbeJSON: "probe_files/notion_accounts/alice/probe.json", + UserID: "alice-user", + SpaceID: "alice-space", + ClientVersion: "v1", + }, + { + Email: "bob@example.com", + ProbeJSON: "probe_files/notion_accounts/bob/probe.json", + UserID: "bob-user", + SpaceID: "bob-space", + ClientVersion: "v1", + }, + } + cfg.ActiveAccount = "alice@example.com" + + state, err := newServerState(cfg) + if err != nil { + t.Fatalf("newServerState failed: %v", err) + } + defer func() { + _ = state.Close() + }() + if state.DispatchProbeCache == nil { + t.Fatalf("expected dispatch probe cache to be initialized") + } + state.DispatchProbeCache.markSuccess("alice@example.com", time.Now()) + if state.DispatchProbeCache.shouldProbe("alice@example.com", 45*time.Second, time.Now()) { + t.Fatalf("expected warm cache entry before active-account change") + } + + next := state.Config + next.ActiveAccount = "bob@example.com" + if err := state.SaveAndApply(next); err != nil { + t.Fatalf("SaveAndApply failed: %v", err) + } + if !state.DispatchProbeCache.shouldProbe("alice@example.com", 45*time.Second, time.Now()) { + t.Fatalf("expected cache invalidation after active-account switch") + } +} + func TestHandleChatCompletionsStreamWritesErrorAfterHeadersSent(t *testing.T) { app := newFreshThreadTestApp(t) app.runPromptStreamSinkOverride = func(_ *http.Request, _ PromptRunRequest, sink InferenceStreamSink) (InferenceResult, error) { @@ -401,6 +1696,47 @@ func TestHandleChatCompletionsStreamWritesErrorAfterHeadersSent(t *testing.T) { } } +func TestHandleChatCompletionsStreamIncludeUsageFromTypedMessages(t *testing.T) { + app := newFreshThreadTestApp(t) + app.runPromptStreamSinkOverride = func(_ *http.Request, _ PromptRunRequest, sink InferenceStreamSink) 
(InferenceResult, error) { + if sink.Text != nil { + if err := sink.Text("hello "); err != nil { + t.Fatalf("stream text write failed: %v", err) + } + if err := sink.Text("world"); err != nil { + t.Fatalf("stream text write failed: %v", err) + } + } + return InferenceResult{ + Text: "hello world", + Prompt: "hello world", + ThreadID: "thread-stream-usage", + MessageID: "msg-stream-usage", + }, nil + } + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", mustJSONBody(t, map[string]any{ + "model": "gpt-5.4", + "stream": true, + "stream_options": map[string]any{"include_usage": true}, + "messages": []map[string]any{ + {"role": "user", "content": "hello"}, + }, + })) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-api-key") + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + body := rec.Body.String() + if !strings.Contains(body, "\"usage\"") { + t.Fatalf("expected stream output to include usage chunk, got body=%s", body) + } + if !strings.Contains(body, "data: [DONE]") { + t.Fatalf("expected stream done marker, got body=%s", body) + } +} + func TestNormalizeConfigDefaultsAccountMaxConcurrencyToOne(t *testing.T) { cfg := normalizeConfig(AppConfig{ APIKey: "test-api-key", @@ -487,3 +1823,562 @@ func TestRunPromptWithAccountPoolReturnsCapacityErrorWhenAllSlotsOccupied(t *tes t.Fatalf("expected wrapped sentinel error, got %v", runErr) } } + +func TestRefreshSessionInvalidatesDispatchProbeCacheOnSuccess(t *testing.T) { + cfg := normalizeConfig(AppConfig{ + APIKey: "test-api-key", + Accounts: []NotionAccount{ + { + Email: "alice@example.com", + ProbeJSON: "/tmp/alice/probe.json", + StorageStatePath: "/tmp/alice/storage_state.json", + PendingStatePath: "/tmp/alice/pending_login.json", + UserID: "alice-user", + SpaceID: "alice-space", + UserName: "alice", + SpaceName: "alice-space-name", + ClientVersion: "v1", + Status: "ready", + }, + }, + ActiveAccount: "alice@example.com", + SessionRefresh: 
SessionRefreshConfig{ + Enabled: true, + RetryOnAuthError: true, + AutoSwitch: true, + }, + }) + state := &ServerState{ + Config: cfg, + Session: SessionInfo{UserID: "alice-user", SpaceID: "alice-space"}, + DispatchProbeCache: newProbeCache(), + ResponseStore: newResponseStore(45 * time.Second), + Conversations: newConversationStore(), + AdminTokens: map[string]time.Time{}, + AdminLoginAttempts: map[string]AdminLoginAttempt{}, + } + slot := &accountSlot{} + slot.max.Store(1) + slot.inflight.Store(0) + slotMap := map[string]*accountSlot{ + "alice@example.com": slot, + } + state.slots.Store(&slotMap) + syncDispatchSlotInflightFromSlots(slotMap) + state.DispatchProbeCache.markSuccess("alice@example.com", time.Now()) + + originalTryRefresh := testHookTryRefreshAccount + originalSaveAndApply := testHookSaveAndApply + defer func() { + testHookTryRefreshAccount = originalTryRefresh + testHookSaveAndApply = originalSaveAndApply + }() + + testHookTryRefreshAccount = func(ctx context.Context, cfg AppConfig, account NotionAccount) (AppConfig, error) { + account.Status = "ready" + account.LastError = "" + account.LastRefreshAt = time.Now().Format(time.RFC3339) + cfg.UpsertAccount(account) + return cfg, nil + } + testHookSaveAndApply = func(s *ServerState, cfg AppConfig) error { + s.mu.Lock() + defer s.mu.Unlock() + s.Config = cfg + s.updateSnapshotBundleLocked() + return nil + } + + if err := state.RefreshSession(context.Background(), "test_refresh_success"); err != nil { + t.Fatalf("refresh session failed: %v", err) + } + if !state.DispatchProbeCache.shouldProbe("alice@example.com", 45*time.Second, time.Now()) { + t.Fatalf("expected probe cache to be invalidated after refresh success") + } +} + +func newSQLiteStoreTestConfig(path string) AppConfig { + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.Storage.SQLitePath = path + return cfg +} + +func TestOpenSQLiteStoreConfiguresReadWriteAndReadOnlyPools(t *testing.T) { + cfg := 
newSQLiteStoreTestConfig(filepath.Join(t.TempDir(), "notion2api.sqlite")) + store, err := openSQLiteStore(cfg) + if err != nil { + t.Fatalf("openSQLiteStore failed: %v", err) + } + defer func() { + _ = store.Close() + }() + if store.db == nil { + t.Fatalf("expected writable sqlite connection") + } + if store.roDB == nil { + t.Fatalf("expected read-only sqlite connection") + } + if got := store.db.Stats().MaxOpenConnections; got != 1 { + t.Fatalf("unexpected write db max open conns: got %d want 1", got) + } + wantReadConns := maxInt(2, runtime.NumCPU()) + if got := store.roDB.Stats().MaxOpenConnections; got != wantReadConns { + t.Fatalf("unexpected read db max open conns: got %d want %d", got, wantReadConns) + } +} + +func TestSQLiteStoreInitAppliesExtendedPragmas(t *testing.T) { + cfg := newSQLiteStoreTestConfig(filepath.Join(t.TempDir(), "notion2api.sqlite")) + store, err := openSQLiteStore(cfg) + if err != nil { + t.Fatalf("openSQLiteStore failed: %v", err) + } + defer func() { + _ = store.Close() + }() + + var mmapSize int64 + if err := store.db.QueryRow("PRAGMA mmap_size;").Scan(&mmapSize); err != nil { + t.Fatalf("query mmap_size failed: %v", err) + } + if mmapSize != 268435456 { + t.Fatalf("unexpected mmap_size: got %d want %d", mmapSize, int64(268435456)) + } + + var cacheSize int64 + if err := store.db.QueryRow("PRAGMA cache_size;").Scan(&cacheSize); err != nil { + t.Fatalf("query cache_size failed: %v", err) + } + if cacheSize != -65536 { + t.Fatalf("unexpected cache_size: got %d want %d", cacheSize, int64(-65536)) + } + + var tempStore int64 + if err := store.db.QueryRow("PRAGMA temp_store;").Scan(&tempStore); err != nil { + t.Fatalf("query temp_store failed: %v", err) + } + if tempStore != 2 { + t.Fatalf("unexpected temp_store: got %d want 2(memory)", tempStore) + } + + var autoCheckpoint int64 + if err := store.db.QueryRow("PRAGMA wal_autocheckpoint;").Scan(&autoCheckpoint); err != nil { + t.Fatalf("query wal_autocheckpoint failed: %v", err) + } + if 
autoCheckpoint != 1000 { + t.Fatalf("unexpected wal_autocheckpoint: got %d want 1000", autoCheckpoint) + } +} + +func TestSQLiteStoreReadOnlyConnectionRejectsWrites(t *testing.T) { + cfg := newSQLiteStoreTestConfig(filepath.Join(t.TempDir(), "notion2api.sqlite")) + store, err := openSQLiteStore(cfg) + if err != nil { + t.Fatalf("openSQLiteStore failed: %v", err) + } + defer func() { + _ = store.Close() + }() + _, err = store.roDB.Exec("CREATE TABLE read_only_write_should_fail(id INTEGER)") + if err == nil { + t.Fatalf("expected write on read-only connection to fail") + } + if !strings.Contains(strings.ToLower(err.Error()), "readonly") { + t.Fatalf("expected readonly error, got: %v", err) + } +} + +func TestSQLiteStoreLoadAccountsUsesReadOnlyConnection(t *testing.T) { + cfg := newSQLiteStoreTestConfig(filepath.Join(t.TempDir(), "notion2api.sqlite")) + store, err := openSQLiteStore(cfg) + if err != nil { + t.Fatalf("openSQLiteStore failed: %v", err) + } + defer func() { + _ = store.Close() + }() + + saveCfg := normalizeConfig(AppConfig{ + APIKey: "test-api-key", + Storage: StorageConfig{SQLitePath: cfg.Storage.SQLitePath}, + LoginHelper: LoginHelperConfig{SessionsDir: "probe_files/notion_accounts"}, + Accounts: []NotionAccount{{Email: "alice@example.com"}}, + ActiveAccount: "alice@example.com", + }) + if err := store.SaveAccounts(saveCfg); err != nil { + t.Fatalf("SaveAccounts failed: %v", err) + } + + if err := store.db.Close(); err != nil { + t.Fatalf("close write db failed: %v", err) + } + store.db = nil + accounts, activeAccount, ok, err := store.LoadAccounts() + if err != nil { + t.Fatalf("LoadAccounts failed: %v", err) + } + if !ok { + t.Fatalf("expected persisted accounts to be available") + } + if len(accounts) != 1 { + t.Fatalf("unexpected account count: got %d want 1", len(accounts)) + } + if getAccountEmailKey(accounts[0]) != "alice@example.com" { + t.Fatalf("unexpected loaded account email: %q", accounts[0].Email) + } + if canonicalEmailKey(activeAccount) 
!= "alice@example.com" { + t.Fatalf("unexpected active account: %q", activeAccount) + } +} + +func TestServeHTTPOptionsReturnsCORSNoContent(t *testing.T) { + app := newFreshThreadTestApp(t) + req := httptest.NewRequest(http.MethodOptions, "/v1/models", nil) + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Fatalf("unexpected options status: got %d want %d", rec.Code, http.StatusNoContent) + } + if got := rec.Header().Get("Access-Control-Allow-Origin"); got != corsAllowOrigin { + t.Fatalf("unexpected Access-Control-Allow-Origin: got %q want %q", got, corsAllowOrigin) + } + if got := rec.Header().Get("Access-Control-Allow-Headers"); got != corsAllowHeaders { + t.Fatalf("unexpected Access-Control-Allow-Headers: got %q want %q", got, corsAllowHeaders) + } + if got := rec.Header().Get("Access-Control-Allow-Methods"); got != corsAllowMethods { + t.Fatalf("unexpected Access-Control-Allow-Methods: got %q want %q", got, corsAllowMethods) + } +} + +func TestServeIndexIncludesCORSHeaders(t *testing.T) { + app := newFreshThreadTestApp(t) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", rec.Code, http.StatusOK) + } + if got := rec.Header().Get("Access-Control-Allow-Origin"); got != corsAllowOrigin { + t.Fatalf("unexpected Access-Control-Allow-Origin: got %q want %q", got, corsAllowOrigin) + } + if got := rec.Header().Get("Access-Control-Allow-Headers"); got != corsAllowHeaders { + t.Fatalf("unexpected Access-Control-Allow-Headers: got %q want %q", got, corsAllowHeaders) + } + if got := rec.Header().Get("Access-Control-Allow-Methods"); got != corsAllowMethods { + t.Fatalf("unexpected Access-Control-Allow-Methods: got %q want %q", got, corsAllowMethods) + } +} + +func TestNormalizeConfigSetsMaxRequestBodyBytesDefault(t *testing.T) { + cfg := normalizeConfig(AppConfig{}) + if got := 
cfg.Limits.MaxRequestBodyBytes; got != 4*1024*1024 { + t.Fatalf("unexpected max request body bytes default: got %d want %d", got, int64(4*1024*1024)) + } +} + +func TestNormalizeConfigClampsNonPositiveMaxRequestBodyBytes(t *testing.T) { + cfg := normalizeConfig(AppConfig{Limits: LimitsConfig{MaxRequestBodyBytes: -1}}) + if got := cfg.Limits.MaxRequestBodyBytes; got != 4*1024*1024 { + t.Fatalf("unexpected max request body bytes clamp: got %d want %d", got, int64(4*1024*1024)) + } +} + +func TestHandleChatCompletionsRejectsTooLargeBody(t *testing.T) { + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.Storage.SQLitePath = "" + cfg.Limits.MaxRequestBodyBytes = 128 + state, err := newServerState(cfg) + if err != nil { + t.Fatalf("newServerState failed: %v", err) + } + defer func() { + _ = state.Close() + }() + app := &App{State: state} + + oversizeText := strings.Repeat("x", 512) + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", mustJSONBody(t, map[string]any{ + "model": "gpt-5.4", + "messages": []map[string]any{ + {"role": "user", "content": oversizeText}, + }, + })) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-api-key") + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + if rec.Code != http.StatusRequestEntityTooLarge { + t.Fatalf("unexpected status: got %d want %d body=%s", rec.Code, http.StatusRequestEntityTooLarge, rec.Body.String()) + } + body := rec.Body.String() + if !strings.Contains(body, `"code":"request_too_large"`) { + t.Fatalf("expected request_too_large code, got %s", body) + } + if !strings.Contains(body, `"type":"invalid_request_error"`) { + t.Fatalf("expected invalid_request_error type, got %s", body) + } +} + +func TestDecodeBodyRawWithLimitRejectsTrailingContent(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"a":1} {"b":2}`)) + raw, err := decodeBodyRawWithLimit(nil, req, 0) + if err == nil { + 
t.Fatalf("expected trailing content error, got raw=%q", string(raw)) + } + if !strings.Contains(err.Error(), "invalid json") { + t.Fatalf("expected invalid json error, got %v", err) + } +} + +func TestDecodeBodyRawWithLimitNormalizesWhitespace(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(" \n\t {\"a\":1}\n\t ")) + raw, err := decodeBodyRawWithLimit(nil, req, 0) + if err != nil { + t.Fatalf("decodeBodyRawWithLimit failed: %v", err) + } + if got := strings.TrimSpace(string(raw)); got != "{\"a\":1}" { + t.Fatalf("unexpected normalized raw body: got %q", got) + } +} + +func TestDecodeBodyRawWithLimitTreatsEmptyBodyAsEmptyObject(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(" \n\t ")) + raw, err := decodeBodyRawWithLimit(nil, req, 0) + if err != nil { + t.Fatalf("decodeBodyRawWithLimit failed: %v", err) + } + if string(raw) != "{}" { + t.Fatalf("expected empty object for empty body, got %q", string(raw)) + } +} + +func TestDecodeChatCompletionsRequestBodyFromRawFallsBackToMapOnTypedDecodeMismatch(t *testing.T) { + raw := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"group_names":[1]}`) + typed, payload, err := decodeChatCompletionsRequestBodyFromRaw(raw) + if err != nil { + t.Fatalf("decodeChatCompletionsRequestBodyFromRaw failed: %v", err) + } + if payload == nil { + t.Fatalf("expected payload fallback map to be populated") + } + messages := sliceValue(typed.Messages) + if len(messages) != 1 { + t.Fatalf("expected typed messages recovered via map fallback, got len=%d", len(messages)) + } + msg := mapValue(messages[0]) + if strings.TrimSpace(stringValue(msg["content"])) != "hello" { + t.Fatalf("expected fallback-typed message content 'hello', got %#v", msg["content"]) + } +} + +func TestDecodeChatCompletionsRequestBodyFromRawParsesStreamIncludeUsageWithoutMapFallback(t *testing.T) { + raw := 
[]byte(`{"model":"gpt-5.4","stream_options":{"include_usage":"1"},"messages":[{"role":"user","content":"hello"}]}`) + typed, payload, err := decodeChatCompletionsRequestBodyFromRaw(raw) + if err != nil { + t.Fatalf("decodeChatCompletionsRequestBodyFromRaw failed: %v", err) + } + if payload != nil { + t.Fatalf("expected typed decode path without map fallback") + } + if typed.StreamIncludeUsage == nil || !*typed.StreamIncludeUsage { + t.Fatalf("expected stream include_usage=true from typed decode path") + } +} + +func TestHandleChatCompletionsSillyTavernFallbackOnContinuePrefillKey(t *testing.T) { + app := newFreshThreadTestApp(t) + captured := PromptRunRequest{} + app.runPromptOverride = func(_ *http.Request, request PromptRunRequest) (InferenceResult, error) { + captured = request + return InferenceResult{ + Text: "ok", + ThreadID: "thread-st-fallback", + AccountEmail: "seed@example.com", + }, nil + } + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{ + "model":"gpt-5.4", + "messages":[{"role":"user","content":"Hello there"}], + "continue_prefill":"...", + "group_names":[1] + }`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-api-key") + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d body=%s", rec.Code, rec.Body.String()) + } + if captured.ClientProfile != sillyTavernClientProfile { + t.Fatalf("expected sillytavern client profile, got %q", captured.ClientProfile) + } + if strings.TrimSpace(captured.Prompt) == "" { + t.Fatalf("expected non-empty prompt for sillytavern fallback") + } +} + +func TestDecodeResponsesRequestBodyFromRawFallsBackToMapOnTypedDecodeMismatch(t *testing.T) { + raw := []byte(`{"model":"gpt-5.4","input":"hello","attachments":[{"type":"file","file_url":"https://example.com/f.txt"}],"conversation_id":1}`) + typed, payload, err := decodeResponsesRequestBodyFromRaw(raw) + if 
err != nil { + t.Fatalf("decodeResponsesRequestBodyFromRaw failed: %v", err) + } + if payload == nil { + t.Fatalf("expected payload fallback map to be populated") + } + if strings.TrimSpace(typed.Model) != "gpt-5.4" { + t.Fatalf("unexpected model after fallback: %q", typed.Model) + } + if strings.TrimSpace(flattenContent(typed.Input)) != "hello" { + t.Fatalf("expected fallback-typed input 'hello', got %#v", typed.Input) + } + atts := sliceValue(typed.Attachments) + if len(atts) != 1 { + t.Fatalf("expected one attachment after fallback, got %d", len(atts)) + } +} + +func TestHandleResponsesTypedFirstDecodeFallbackOnConversationIDTypeMismatch(t *testing.T) { + app := newFreshThreadTestApp(t) + seeded := seedCompletedConversation(t, app, "conv-responses-fallback", "Please remember this.", "Remembered.", "thread-old-responses-fallback") + + var captured PromptRunRequest + app.runPromptOverride = func(_ *http.Request, request PromptRunRequest) (InferenceResult, error) { + captured = request + return InferenceResult{ + Text: "Summary ready.", + ThreadID: "thread-new-responses-fallback", + MessageID: "msg-new-responses-fallback", + TraceID: "trace-new-responses-fallback", + AccountEmail: "seed@example.com", + }, nil + } + + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{ + "model":"gpt-5.4", + "input":"Summarize that.", + "attachments":[{"type":"file","file_url":"https://example.com/f.txt"}], + "conversation_id":1 + }`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-api-key") + req.Header.Set("X-Conversation-ID", seeded.ID) + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status mismatch: got %d body=%s", rec.Code, rec.Body.String()) + } + if got := rec.Header().Get("X-Conversation-ID"); got != seeded.ID { + t.Fatalf("conversation header mismatch: got %q want %q", got, seeded.ID) + } + if !captured.ForceLocalConversationContinue { 
+ t.Fatalf("expected ForceLocalConversationContinue to be enabled") + } + assertPromptContains(t, captured.Prompt, + "Continue the conversation using the transcript below.", + "[user]\nPlease remember this.", + "[assistant]\nRemembered.", + "[user]\nSummarize that.", + ) + assertConversationContinued(t, app, seeded.ID, "thread-new-responses-fallback", "Summary ready.") +} + +func TestCollectProbeModelPathsIncludesActiveAndAccountProbeJSON(t *testing.T) { + cfg := normalizeConfig(AppConfig{ + ProbeJSON: " probe_files/notion_accounts/active/probe.json ", + Accounts: []NotionAccount{ + {ProbeJSON: "probe_files/notion_accounts/alpha/probe.json"}, + {ProbeJSON: "probe_files/notion_accounts/alpha/probe.json"}, + {ProbeJSON: " probe_files/notion_accounts/beta/probe.json "}, + }, + }) + paths := collectProbeModelPaths(cfg) + for i := range paths { + paths[i] = strings.ReplaceAll(paths[i], "\\", "/") + } + sort.Strings(paths) + expected := []string{ + "probe_files/notion_accounts/active/probe.json", + "probe_files/notion_accounts/alpha/probe.json", + "probe_files/notion_accounts/beta/probe.json", + } + if len(paths) != len(expected) { + t.Fatalf("unexpected path count: got %d want %d (%v)", len(paths), len(expected), paths) + } + for i := range expected { + if paths[i] != expected[i] { + t.Fatalf("unexpected path[%d]: got %q want %q", i, paths[i], expected[i]) + } + } +} + +func TestBuildModelRegistryLoadsProbeModelsFromActiveAndAccountPaths(t *testing.T) { + dir := t.TempDir() + activeProbe := filepath.Join(dir, "active-probe.json") + accountProbe := filepath.Join(dir, "account-probe.json") + activeBlob := `{"models":[{"model":"active-model-raw","modelMessage":"Active Model","modelFamily":"openai","displayGroup":"fast","isDisabled":false,"markdownChat":{"beta":false},"workflow":{"finalModelName":"active-notion-model","beta":false},"customAgent":{"finalModelName":"","beta":false}}]}` + accountBlob := `{"models":[{"model":"account-model-raw","modelMessage":"Account 
Model","modelFamily":"anthropic","displayGroup":"intelligent","isDisabled":false,"markdownChat":{"beta":false},"workflow":{"finalModelName":"account-notion-model","beta":false},"customAgent":{"finalModelName":"","beta":false}}]}` + writeProbeFile := func(path string, blob string) { + payload := map[string]any{ + "email": "tester@example.com", + "userId": "user-id", + "spaceId": "space-id", + "clientVersion": "v1", + "embeddedModels": blob, + } + encoded, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal probe payload failed: %v", err) + } + if err := os.WriteFile(path, encoded, 0o600); err != nil { + t.Fatalf("write probe payload failed: %v", err) + } + } + writeProbeFile(activeProbe, activeBlob) + writeProbeFile(accountProbe, accountBlob) + + cfg := normalizeConfig(AppConfig{ + ProbeJSON: activeProbe, + Accounts: []NotionAccount{ + {Email: "alpha@example.com", ProbeJSON: accountProbe}, + }, + }) + registry := buildModelRegistry(cfg) + if _, err := registry.Resolve("active-model", ""); err != nil { + t.Fatalf("expected active probe model to be loaded, got err=%v", err) + } + if _, err := registry.Resolve("account-model", ""); err != nil { + t.Fatalf("expected account probe model to be loaded, got err=%v", err) + } +} + +func TestDeleteAccountUsesCanonicalKeyComparison(t *testing.T) { + cfg := normalizeConfig(AppConfig{ + ActiveAccount: " Alice@Example.com ", + ProbeJSON: "probe_files/notion_accounts/alice/probe.json", + Accounts: []NotionAccount{ + {Email: "alice@example.com"}, + }, + }) + ok := cfg.DeleteAccount("ALICE@example.com") + if !ok { + t.Fatalf("expected delete to succeed") + } + if len(cfg.Accounts) != 0 { + t.Fatalf("expected accounts to be empty after delete, got %d", len(cfg.Accounts)) + } + if cfg.ActiveAccount != "" { + t.Fatalf("expected active account to be cleared, got %q", cfg.ActiveAccount) + } + if cfg.ProbeJSON != "" { + t.Fatalf("expected probe json to be cleared, got %q", cfg.ProbeJSON) + } +} diff --git 
a/internal/app/metrics.go b/internal/app/metrics.go new file mode 100644 index 0000000..b1ed9b6 --- /dev/null +++ b/internal/app/metrics.go @@ -0,0 +1,435 @@ +package app + +import ( + "expvar" + "fmt" + "net/http" + "sort" + "strconv" + "strings" + "sync" + "time" +) + +type histogramSeries struct { + count uint64 + sum float64 + buckets []uint64 +} + +func newHistogramSeries(bucketCount int) *histogramSeries { + if bucketCount < 0 { + bucketCount = 0 + } + return &histogramSeries{ + buckets: make([]uint64, bucketCount), + } +} + +func (s *histogramSeries) observe(seconds float64, bounds []float64) { + if s == nil { + return + } + if seconds < 0 { + seconds = 0 + } + s.count++ + s.sum += seconds + for idx, bound := range bounds { + if seconds <= bound { + s.buckets[idx]++ + } + } +} + +type requestDurationKey struct { + Path string + Method string + Status string +} + +type sqliteDurationKey struct { + Op string +} + +var requestDurationBuckets = []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10} +var wreqFFICallDurationBuckets = []float64{0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5} +var sqliteOpDurationBuckets = []float64{0.0005, 0.001, 0.0025, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1} + +var ( + requestDurationMu sync.Mutex + requestDurationSeries = map[requestDurationKey]*histogramSeries{} + + dispatchInflightMu sync.Mutex + dispatchInflight = map[string]int64{} + + wreqFFICallMu sync.Mutex + wreqFFICallSeries = newHistogramSeries(len(wreqFFICallDurationBuckets)) + + browserSpawnMu sync.Mutex + browserSpawnTotal uint64 + + browserPoolWorkerMu sync.Mutex + browserPoolWorkerTotal uint64 + + sqliteDurationMu sync.Mutex + sqliteDurationSeries = map[sqliteDurationKey]*histogramSeries{} +) + +func resetMetricsForTest() { + requestDurationMu.Lock() + requestDurationSeries = map[requestDurationKey]*histogramSeries{} + requestDurationMu.Unlock() + + dispatchInflightMu.Lock() + dispatchInflight = map[string]int64{} + 
dispatchInflightMu.Unlock() + + wreqFFICallMu.Lock() + wreqFFICallSeries = newHistogramSeries(len(wreqFFICallDurationBuckets)) + wreqFFICallMu.Unlock() + + browserSpawnMu.Lock() + browserSpawnTotal = 0 + browserSpawnMu.Unlock() + + browserPoolWorkerMu.Lock() + browserPoolWorkerTotal = 0 + browserPoolWorkerMu.Unlock() + + sqliteDurationMu.Lock() + sqliteDurationSeries = map[sqliteDurationKey]*histogramSeries{} + sqliteDurationMu.Unlock() +} + +func observeRequestDuration(path string, method string, status int, elapsed time.Duration) { + seconds := elapsed.Seconds() + if seconds < 0 { + seconds = 0 + } + key := requestDurationKey{ + Path: normalizeMetricsPathLabel(path), + Method: strings.ToUpper(strings.TrimSpace(method)), + Status: strconv.Itoa(status), + } + if key.Method == "" { + key.Method = "UNKNOWN" + } + requestDurationMu.Lock() + series := requestDurationSeries[key] + if series == nil { + series = newHistogramSeries(len(requestDurationBuckets)) + requestDurationSeries[key] = series + } + series.observe(seconds, requestDurationBuckets) + requestDurationMu.Unlock() +} + +func setDispatchSlotInflight(email string, inflight int) { + key := canonicalEmailKey(email) + if key == "" { + return + } + if inflight < 0 { + inflight = 0 + } + dispatchInflightMu.Lock() + dispatchInflight[key] = int64(inflight) + dispatchInflightMu.Unlock() +} + +func syncDispatchSlotInflightFromSlots(next map[string]*accountSlot) { + dispatchInflightMu.Lock() + defer dispatchInflightMu.Unlock() + for key := range dispatchInflight { + if _, ok := next[key]; !ok { + delete(dispatchInflight, key) + } + } + for key, slot := range next { + if slot == nil { + continue + } + inflight := slot.inflight.Load() + if inflight < 0 { + inflight = 0 + } + dispatchInflight[key] = int64(inflight) + } +} + +func observeWreqFFICallDuration(elapsed time.Duration) { + seconds := elapsed.Seconds() + if seconds < 0 { + seconds = 0 + } + wreqFFICallMu.Lock() + wreqFFICallSeries.observe(seconds, 
wreqFFICallDurationBuckets) + wreqFFICallMu.Unlock() +} + +func addBrowserHelperSpawn() { + browserSpawnMu.Lock() + browserSpawnTotal++ + browserSpawnMu.Unlock() +} + +func addBrowserHelperPoolWorkerSpawn() { + browserPoolWorkerMu.Lock() + browserPoolWorkerTotal++ + browserPoolWorkerMu.Unlock() +} + +func observeSQLiteOpDuration(op string, elapsed time.Duration) { + op = strings.TrimSpace(strings.ToLower(op)) + if op == "" { + op = "unknown" + } + seconds := elapsed.Seconds() + if seconds < 0 { + seconds = 0 + } + key := sqliteDurationKey{Op: op} + sqliteDurationMu.Lock() + series := sqliteDurationSeries[key] + if series == nil { + series = newHistogramSeries(len(sqliteOpDurationBuckets)) + sqliteDurationSeries[key] = series + } + series.observe(seconds, sqliteOpDurationBuckets) + sqliteDurationMu.Unlock() +} + +func normalizeMetricsPathLabel(path string) string { + clean := strings.TrimSpace(path) + if clean == "" { + return "unknown" + } + switch { + case clean == "/": + return "/" + case clean == "/healthz": + return "/healthz" + case clean == "/metrics": + return "/metrics" + case clean == "/debug/vars": + return "/debug/vars" + case strings.HasPrefix(clean, "/v1/models/"): + return "/v1/models/:id" + case clean == "/v1/models": + return "/v1/models" + case strings.HasPrefix(clean, "/v1/responses/"): + return "/v1/responses/:id" + case clean == "/v1/responses": + return "/v1/responses" + case clean == "/v1/chat/completions": + return "/v1/chat/completions" + case clean == "/v1/st/chat/completions": + return "/v1/st/chat/completions" + case strings.HasPrefix(clean, "/admin/accounts/"): + return "/admin/accounts/:id" + case strings.HasPrefix(clean, "/admin/conversations/"): + return "/admin/conversations/:id" + case strings.HasPrefix(clean, "/admin"): + return "/admin/*" + } + if strings.Count(clean, "/") >= 2 { + parts := strings.Split(clean, "/") + if len(parts) >= 3 { + return "/" + parts[1] + "/" + parts[2] + "/*" + } + } + return clean +} + +func 
writePrometheusMetrics(w http.ResponseWriter) { + if w == nil { + return + } + w.Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8") + w.WriteHeader(http.StatusOK) + + _, _ = fmt.Fprintln(w, "# HELP notion2api_request_duration_seconds HTTP request duration seconds by path/method/status.") + _, _ = fmt.Fprintln(w, "# TYPE notion2api_request_duration_seconds histogram") + writeRequestDurationHistogram(w) + + _, _ = fmt.Fprintln(w, "# HELP notion2api_dispatch_slot_inflight Current in-flight dispatch slots per account email.") + _, _ = fmt.Fprintln(w, "# TYPE notion2api_dispatch_slot_inflight gauge") + writeDispatchInflightGauge(w) + + _, _ = fmt.Fprintln(w, "# HELP notion2api_wreq_ffi_call_duration_seconds Duration of wreq-based helper calls in seconds.") + _, _ = fmt.Fprintln(w, "# TYPE notion2api_wreq_ffi_call_duration_seconds histogram") + writeWreqFFICallHistogram(w) + + _, _ = fmt.Fprintln(w, "# HELP notion2api_browser_helper_spawn_total Total spawned browser helper subprocesses.") + _, _ = fmt.Fprintln(w, "# TYPE notion2api_browser_helper_spawn_total counter") + writeBrowserHelperSpawnCounter(w) + + _, _ = fmt.Fprintln(w, "# HELP notion2api_browser_helper_pool_worker_spawn_total Total spawned persistent browser helper pool workers.") + _, _ = fmt.Fprintln(w, "# TYPE notion2api_browser_helper_pool_worker_spawn_total counter") + writeBrowserHelperPoolWorkerSpawnCounter(w) + + _, _ = fmt.Fprintln(w, "# HELP notion2api_sqlite_op_duration_seconds SQLite operation durations in seconds by operation.") + _, _ = fmt.Fprintln(w, "# TYPE notion2api_sqlite_op_duration_seconds histogram") + writeSQLiteDurationHistogram(w) + + _, _ = fmt.Fprintln(w, "# HELP notion2api_response_store_prune_total Total number of pruned in-memory response entries by reason.") + _, _ = fmt.Fprintln(w, "# TYPE notion2api_response_store_prune_total counter") + writeResponseStorePruneCounter(w) +} + +func writeRequestDurationHistogram(w http.ResponseWriter) { + 
requestDurationMu.Lock() + seriesMap := make(map[requestDurationKey]*histogramSeries, len(requestDurationSeries)) + keys := make([]requestDurationKey, 0, len(requestDurationSeries)) + for key, series := range requestDurationSeries { + copySeries := *series + copySeries.buckets = append([]uint64(nil), series.buckets...) + seriesMap[key] = &copySeries + keys = append(keys, key) + } + requestDurationMu.Unlock() + + sort.Slice(keys, func(i, j int) bool { + if keys[i].Path != keys[j].Path { + return keys[i].Path < keys[j].Path + } + if keys[i].Method != keys[j].Method { + return keys[i].Method < keys[j].Method + } + return keys[i].Status < keys[j].Status + }) + + for _, key := range keys { + series := seriesMap[key] + if series == nil { + continue + } + labelPrefix := fmt.Sprintf("path=\"%s\",method=\"%s\",status=\"%s\"", + escapePrometheusLabelValue(key.Path), + escapePrometheusLabelValue(key.Method), + escapePrometheusLabelValue(key.Status), + ) + writeHistogramSeries(w, "notion2api_request_duration_seconds", labelPrefix, requestDurationBuckets, series) + } +} + +func writeDispatchInflightGauge(w http.ResponseWriter) { + dispatchInflightMu.Lock() + type pair struct { + email string + value int64 + } + items := make([]pair, 0, len(dispatchInflight)) + for email, value := range dispatchInflight { + items = append(items, pair{email: email, value: value}) + } + dispatchInflightMu.Unlock() + + sort.Slice(items, func(i, j int) bool { return items[i].email < items[j].email }) + for _, item := range items { + _, _ = fmt.Fprintf(w, "notion2api_dispatch_slot_inflight{email=\"%s\"} %d\n", + escapePrometheusLabelValue(item.email), item.value) + } +} + +func writeWreqFFICallHistogram(w http.ResponseWriter) { + wreqFFICallMu.Lock() + series := *wreqFFICallSeries + series.buckets = append([]uint64(nil), wreqFFICallSeries.buckets...) 
+ wreqFFICallMu.Unlock() + writeHistogramSeries(w, "notion2api_wreq_ffi_call_duration_seconds", "", wreqFFICallDurationBuckets, &series) +} + +func writeBrowserHelperSpawnCounter(w http.ResponseWriter) { + browserSpawnMu.Lock() + total := browserSpawnTotal + browserSpawnMu.Unlock() + _, _ = fmt.Fprintf(w, "notion2api_browser_helper_spawn_total %d\n", total) +} + +func writeBrowserHelperPoolWorkerSpawnCounter(w http.ResponseWriter) { + browserPoolWorkerMu.Lock() + total := browserPoolWorkerTotal + browserPoolWorkerMu.Unlock() + _, _ = fmt.Fprintf(w, "notion2api_browser_helper_pool_worker_spawn_total %d\n", total) +} + +func writeSQLiteDurationHistogram(w http.ResponseWriter) { + sqliteDurationMu.Lock() + seriesMap := make(map[sqliteDurationKey]*histogramSeries, len(sqliteDurationSeries)) + keys := make([]sqliteDurationKey, 0, len(sqliteDurationSeries)) + for key, series := range sqliteDurationSeries { + copySeries := *series + copySeries.buckets = append([]uint64(nil), series.buckets...) + seriesMap[key] = &copySeries + keys = append(keys, key) + } + sqliteDurationMu.Unlock() + + sort.Slice(keys, func(i, j int) bool { return keys[i].Op < keys[j].Op }) + for _, key := range keys { + series := seriesMap[key] + if series == nil { + continue + } + labelPrefix := fmt.Sprintf("op=\"%s\"", escapePrometheusLabelValue(key.Op)) + writeHistogramSeries(w, "notion2api_sqlite_op_duration_seconds", labelPrefix, sqliteOpDurationBuckets, series) + } +} + +func writeResponseStorePruneCounter(w http.ResponseWriter) { + if w == nil { + return + } + entryVar := responseStorePruneTotalMetric.Get("expired_entries") + if entryVar == nil { + _, _ = fmt.Fprintln(w, "notion2api_response_store_prune_total{reason=\"expired_entries\"} 0") + return + } + entryValue, ok := entryVar.(*expvar.Int) + if !ok || entryValue == nil { + _, _ = fmt.Fprintln(w, "notion2api_response_store_prune_total{reason=\"expired_entries\"} 0") + return + } + _, _ = fmt.Fprintf(w, 
"notion2api_response_store_prune_total{reason=\"expired_entries\"} %d\n", entryValue.Value()) +} + +func writeHistogramSeries(w http.ResponseWriter, metricName string, baseLabels string, bounds []float64, series *histogramSeries) { + if w == nil || series == nil { + return + } + for idx, bound := range bounds { + le := strconv.FormatFloat(bound, 'g', -1, 64) + labels := withExtraLabel(baseLabels, "le", le) + _, _ = fmt.Fprintf(w, "%s_bucket{%s} %d\n", metricName, labels, series.buckets[idx]) + } + infLabels := withExtraLabel(baseLabels, "le", "+Inf") + _, _ = fmt.Fprintf(w, "%s_bucket{%s} %d\n", metricName, infLabels, series.count) + if baseLabels == "" { + _, _ = fmt.Fprintf(w, "%s_sum %s\n", metricName, formatFloat(series.sum)) + _, _ = fmt.Fprintf(w, "%s_count %d\n", metricName, series.count) + return + } + _, _ = fmt.Fprintf(w, "%s_sum{%s} %s\n", metricName, baseLabels, formatFloat(series.sum)) + _, _ = fmt.Fprintf(w, "%s_count{%s} %d\n", metricName, baseLabels, series.count) +} + +func withExtraLabel(base string, name string, value string) string { + extra := fmt.Sprintf("%s=\"%s\"", name, escapePrometheusLabelValue(value)) + if strings.TrimSpace(base) == "" { + return extra + } + return base + "," + extra +} + +func escapePrometheusLabelValue(value string) string { + value = strings.ReplaceAll(value, `\`, `\\`) + value = strings.ReplaceAll(value, "\n", `\n`) + value = strings.ReplaceAll(value, `"`, `\"`) + return value +} + +func formatFloat(value float64) string { + return strconv.FormatFloat(value, 'g', -1, 64) +} diff --git a/internal/app/models.go b/internal/app/models.go index 3ad7ae0..537290c 100644 --- a/internal/app/models.go +++ b/internal/app/models.go @@ -6,6 +6,7 @@ import ( "os" "sort" "strings" + "sync" "unicode" ) @@ -56,7 +57,7 @@ func builtinModelDefinitions() []ModelDefinition { func buildModelRegistry(cfg AppConfig) ModelRegistry { entries := builtinModelDefinitions() - if probeEntries := extractProbeModelDefinitions(cfg.ProbeJSON); 
len(probeEntries) > 0 { + if probeEntries := extractProbeModelDefinitions(collectProbeModelPaths(cfg)); len(probeEntries) > 0 { entries = mergeModelDefinitions(entries, probeEntries) } if len(cfg.Models) > 0 { @@ -98,6 +99,27 @@ func buildModelRegistry(cfg AppConfig) ModelRegistry { return ModelRegistry{Entries: entries, ByID: byID, AliasToID: aliasToID} } +func collectProbeModelPaths(cfg AppConfig) []string { + paths := make([]string, 0, len(cfg.Accounts)+1) + seen := map[string]struct{}{} + appendPath := func(path string) { + clean := strings.TrimSpace(path) + if clean == "" { + return + } + if _, exists := seen[clean]; exists { + return + } + seen[clean] = struct{}{} + paths = append(paths, clean) + } + appendPath(cfg.ProbeJSON) + for _, account := range cfg.Accounts { + appendPath(account.ProbeJSON) + } + return paths +} + func (r ModelRegistry) Resolve(value string, fallback string) (ModelDefinition, error) { candidate := strings.TrimSpace(value) if candidate == "" { @@ -118,7 +140,54 @@ func (r ModelRegistry) Resolve(value string, fallback string) (ModelDefinition, return ModelDefinition{}, fmt.Errorf("unknown model: %s", candidate) } -func extractProbeModelDefinitions(path string) []ModelDefinition { +func extractProbeModelDefinitions(paths []string) []ModelDefinition { + if len(paths) == 0 { + return nil + } + parseConcurrency := len(paths) + if parseConcurrency > 4 { + parseConcurrency = 4 + } + if parseConcurrency < 1 { + parseConcurrency = 1 + } + type indexedResult struct { + index int + items []ModelDefinition + } + results := make([]indexedResult, len(paths)) + sem := make(chan struct{}, parseConcurrency) + var wg sync.WaitGroup + for i, path := range paths { + wg.Add(1) + go func(index int, probePath string) { + defer wg.Done() + sem <- struct{}{} + defer func() { <-sem }() + results[index] = indexedResult{ + index: index, + items: extractProbeModelDefinitionsFromPath(probePath), + } + }(i, path) + } + wg.Wait() + + seen := map[string]struct{}{} + 
out := make([]ModelDefinition, 0) + for _, result := range results { + for _, item := range result.items { + key := item.ID + "|" + item.NotionModel + if _, exists := seen[key]; exists { + continue + } + seen[key] = struct{}{} + out = append(out, item) + } + } + return out +} + +func extractProbeModelDefinitionsFromPath(path string) []ModelDefinition { if strings.TrimSpace(path) == "" { return nil } diff --git a/internal/app/notion_client.go b/internal/app/notion_client.go index 38f6087..9bab2a1 100644 --- a/internal/app/notion_client.go +++ b/internal/app/notion_client.go @@ -8,6 +8,7 @@ import ( "crypto/tls" "encoding/json" "errors" + "expvar" "fmt" "io" "log" @@ -19,6 +20,7 @@ import ( "regexp" "strconv" "strings" + "sync" "time" ) @@ -36,6 +38,31 @@ const ( var leadingLangTagPattern = regexp.MustCompile(`(?is)^\s*(?:]*>|)\s*`) var prefixedTranscriptStepIDPattern = regexp.MustCompile(`^(?:cfg|ctx|upd)_([0-9a-fA-F]{32})$`) +var notionHTTPTransportCacheMetric = expvar.NewMap("notion2api_http_transport_cache_total") + +type notionHTTPTransportCacheKey struct { + UpstreamBaseURL string + UpstreamOriginURL string + UpstreamHostHeader string + UpstreamTLSServerName string + UpstreamUseEnvProxy bool + ProxyMode string + ProxyURL string + ProxyHTTPURL string + ProxyHTTPSURL string + ResinEnabled bool + ResinURL string + ResinPlatform string + ResinMode string + AccountEmailKey string +} + +var notionTransportCache = struct { + mu sync.RWMutex + items map[notionHTTPTransportCacheKey]*http.Transport +}{ + items: map[notionHTTPTransportCacheKey]*http.Transport{}, +} func bestEffortTimeout(parent context.Context, cap time.Duration) time.Duration { if cap <= 0 { @@ -385,12 +412,13 @@ type ndjsonPatchOperation struct { V any `json:"v"` } -type ndjsonEnvelope struct { - Type string `json:"type"` - Data map[string]any `json:"data,omitempty"` - Version int `json:"version,omitempty"` - V []ndjsonPatchOperation `json:"v,omitempty"` - RecordMap map[string]any 
`json:"recordMap,omitempty"` +type ndjsonStreamLine struct { + Type string `json:"type"` + V []ndjsonPatchOperation `json:"v,omitempty"` + RecordMap map[string]any `json:"recordMap,omitempty"` + ID string `json:"id,omitempty"` + FinishedAt any `json:"finishedAt,omitempty"` + Value []ndjsonAgentInferenceValue `json:"value,omitempty"` } type ndjsonAgentInferenceValue struct { @@ -623,10 +651,37 @@ func newNotionAIStreamingClient(session SessionInfo, cfg AppConfig, accountEmail return newNotionAIClientWithMode(session, cfg, accountEmail, true) } -func newNotionAIClientWithMode(session SessionInfo, cfg AppConfig, accountEmail string, streaming bool) *NotionAIClient { +func buildNotionHTTPTransportCacheKey(cfg AppConfig, accountEmail string) notionHTTPTransportCacheKey { normalizedCfg := normalizeConfig(cfg) - resolver := NewProxyResolver(normalizedCfg) upstream := normalizedCfg.NotionUpstream() + policy := normalizedCfg.ResolveProxyPolicyForAccount(accountEmail) + return notionHTTPTransportCacheKey{ + UpstreamBaseURL: strings.TrimSpace(upstream.BaseURL), + UpstreamOriginURL: strings.TrimSpace(upstream.OriginURL), + UpstreamHostHeader: strings.TrimSpace(upstream.HostHeader), + UpstreamTLSServerName: strings.TrimSpace(upstream.TLSServerName), + UpstreamUseEnvProxy: upstream.UseEnvProxy, + ProxyMode: strings.TrimSpace(policy.Mode), + ProxyURL: strings.TrimSpace(policy.URL), + ProxyHTTPURL: strings.TrimSpace(policy.HTTPURL), + ProxyHTTPSURL: strings.TrimSpace(policy.HTTPSURL), + ResinEnabled: policy.Resin.Enabled, + ResinURL: strings.TrimSpace(policy.Resin.URL), + ResinPlatform: strings.TrimSpace(policy.Resin.Platform), + ResinMode: strings.TrimSpace(policy.Resin.Mode), + AccountEmailKey: canonicalEmailKey(accountEmail), + } +} + +func cachedNotionHTTPTransport(cfg AppConfig, accountEmail string, resolver *ProxyResolver, upstream NotionUpstream) *http.Transport { + key := buildNotionHTTPTransportCacheKey(cfg, accountEmail) + notionTransportCache.mu.RLock() + cached := 
notionTransportCache.items[key] + notionTransportCache.mu.RUnlock() + if cached != nil { + notionHTTPTransportCacheMetric.Add("hit_rlock", 1) + return cached + } tlsConfig := &tls.Config{InsecureSkipVerify: true} if strings.TrimSpace(upstream.TLSServerName) != "" { tlsConfig.ServerName = strings.TrimSpace(upstream.TLSServerName) @@ -650,6 +705,23 @@ func newNotionAIClientWithMode(session SessionInfo, cfg AppConfig, accountEmail return proxyFunc(req) }, } + notionTransportCache.mu.Lock() + if existing := notionTransportCache.items[key]; existing != nil { + notionTransportCache.mu.Unlock() + notionHTTPTransportCacheMetric.Add("hit_lock", 1) + return existing + } + notionTransportCache.items[key] = transport + notionTransportCache.mu.Unlock() + notionHTTPTransportCacheMetric.Add("miss_new", 1) + return transport +} + +func newNotionAIClientWithMode(session SessionInfo, cfg AppConfig, accountEmail string, streaming bool) *NotionAIClient { + normalizedCfg := normalizeConfig(cfg) + resolver := NewProxyResolver(normalizedCfg) + upstream := normalizedCfg.NotionUpstream() + transport := cachedNotionHTTPTransport(normalizedCfg, accountEmail, resolver, upstream) timeout := requestTimeout(normalizedCfg) clientTimeout := timeout if streaming { @@ -1083,7 +1155,9 @@ func (c *NotionAIClient) runInferenceTranscriptWithFallback(ctx context.Context, if c.Config.DebugUpstream { log.Printf("[debug_upstream] runInferenceTranscript http start thread_id=%s", threadID) } + callStartedAt := time.Now() parsed, err := c.runInferenceTranscriptHTTP(ctx, payload, threadID, sink) + observeWreqFFICallDuration(time.Since(callStartedAt)) if c.Config.DebugUpstream { log.Printf("[debug_upstream] runInferenceTranscript http done thread_id=%s line_count=%d message_ids=%d err=%v", threadID, parsed.LineCount, len(parsed.MessageIDs), err) } @@ -2336,26 +2410,28 @@ func (s *ndjsonTranscriptState) handleLine(line []byte, threadID string, sink In if len(line) == 0 { return nil } - var envelope ndjsonEnvelope 
- if err := json.Unmarshal(line, &envelope); err != nil { + var streamLine ndjsonStreamLine + if err := json.Unmarshal(line, &streamLine); err != nil { return err } s.LineCount++ - switch envelope.Type { + switch streamLine.Type { case "patch": - for _, op := range envelope.V { + for _, op := range streamLine.V { if err := s.applyPatchOperation(op, sink); err != nil { return err } } case "agent-inference": - var event ndjsonAgentInferenceEvent - if err := json.Unmarshal(line, &event); err != nil { - return err + event := ndjsonAgentInferenceEvent{ + Type: streamLine.Type, + ID: streamLine.ID, + FinishedAt: streamLine.FinishedAt, + Value: streamLine.Value, } return s.mergeAgentInferenceEvent(event, sink) case "record-map": - messageIDs, agent, outcomeErr, ok := finalThreadOutcomeFromRecordMap(envelope.RecordMap, threadID) + messageIDs, agent, outcomeErr, ok := finalThreadOutcomeFromRecordMap(streamLine.RecordMap, threadID) if len(messageIDs) > 0 { s.MessageIDs = messageIDs } @@ -2395,52 +2471,79 @@ func (s *ndjsonTranscriptState) result() ndjsonParseResult { func consumeNDJSONStream(reader io.Reader, threadID string, sink InferenceStreamSink) (ndjsonParseResult, error) { state := &ndjsonTranscriptState{ActiveAgentIndex: -1} - buffered := bufio.NewReader(reader) - for { - line, err := buffered.ReadBytes('\n') - if len(line) > 0 { - if handleErr := state.handleLine(line, threadID, sink); handleErr != nil { - return state.result(), handleErr - } - if state.hasTerminalAnswer() { - return state.result(), nil - } + scanner := newNDJSONScanner(reader) + for scanner.Scan() { + line := scanner.Bytes() + if handleErr := state.handleLine(line, threadID, sink); handleErr != nil { + return state.result(), handleErr } - if err != nil { - if errors.Is(err, io.EOF) { - break - } - return state.result(), err + if state.hasTerminalAnswer() { + return state.result(), nil } } + if err := normalizeNDJSONScanError(scanner.Err()); err != nil { + return state.result(), err + } return 
state.result(), nil } var ndjsonIdleAfterAnswerTimeout = 5 * time.Second +var errNDJSONLineTooLarge = errors.New("ndjson line too large") + +const ( + ndjsonScannerInitialBuffer = 64 * 1024 + ndjsonMaxLineBytes = 16 * 1024 * 1024 +) type ndjsonReadEvent struct { line []byte err error } +func newNDJSONScanner(reader io.Reader) *bufio.Scanner { + scanner := bufio.NewScanner(reader) + scanner.Buffer(make([]byte, 0, ndjsonScannerInitialBuffer), ndjsonMaxLineBytes) + return scanner +} + +func normalizeNDJSONScanError(err error) error { + if err == nil { + return nil + } + if errors.Is(err, bufio.ErrTooLong) { + return fmt.Errorf("%w: exceeds %d bytes", errNDJSONLineTooLarge, ndjsonMaxLineBytes) + } + return err +} + func consumeNDJSONStreamWithIdleClose(reader io.ReadCloser, threadID string, sink InferenceStreamSink, idleAfterAnswer time.Duration) (ndjsonParseResult, error) { state := &ndjsonTranscriptState{ActiveAgentIndex: -1} - buffered := bufio.NewReader(reader) events := make(chan ndjsonReadEvent, 1) done := make(chan struct{}) defer close(done) go func() { - for { - line, err := buffered.ReadBytes('\n') + scanner := newNDJSONScanner(reader) + for scanner.Scan() { + line := append([]byte(nil), scanner.Bytes()...) 
select { - case events <- ndjsonReadEvent{line: line, err: err}: + case events <- ndjsonReadEvent{line: line}: case <-done: return } - if err != nil { + } + if err := normalizeNDJSONScanError(scanner.Err()); err != nil { + select { + case events <- ndjsonReadEvent{err: err}: + case <-done: return } + return + } + select { + case events <- ndjsonReadEvent{err: io.EOF}: + case <-done: + return } }() diff --git a/internal/app/notion_client_best_effort_test.go b/internal/app/notion_client_best_effort_test.go index bb04158..d998daf 100644 --- a/internal/app/notion_client_best_effort_test.go +++ b/internal/app/notion_client_best_effort_test.go @@ -1,8 +1,11 @@ package app import ( + "bytes" "context" "encoding/json" + "errors" + "expvar" "io" "net/http" "net/http/httptest" @@ -124,11 +127,128 @@ func TestProbeAccountProtocolHealthIgnoresContextAbort(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) defer cancel() - if err := app.probeAccountProtocolHealth(ctx, cfg, session); err != nil { + if err := app.probeAccountProtocolHealth(ctx, cfg, session, ""); err != nil { t.Fatalf("expected context abort probe to be ignored, got %v", err) } } +func TestProbeAccountProtocolHealthCachesProbeSuccessWithinTTL(t *testing.T) { + cfg := defaultConfig() + cfg.Dispatch.ProbeCacheTTLSeconds = 45 + state := &ServerState{DispatchProbeCache: newProbeCache()} + app := &App{ + State: state, + } + callCount := 0 + app.accountProtocolProbeOverride = func(ctx context.Context, cfg AppConfig, session SessionInfo) error { + callCount++ + return nil + } + + session := SessionInfo{ + UserEmail: "alice@example.com", + } + ctx := context.Background() + + for i := 0; i < 10; i++ { + if err := app.probeAccountProtocolHealth(ctx, cfg, session, "alice@example.com"); err != nil { + t.Fatalf("probe call %d failed: %v", i+1, err) + } + } + if callCount != 1 { + t.Fatalf("expected one upstream probe call within ttl window, got %d", callCount) + } +} + +func 
TestProbeAccountProtocolHealthReprobesAfterFailure(t *testing.T) { + cfg := defaultConfig() + cfg.Dispatch.ProbeCacheTTLSeconds = 45 + state := &ServerState{DispatchProbeCache: newProbeCache()} + app := &App{ + State: state, + } + callCount := 0 + app.accountProtocolProbeOverride = func(ctx context.Context, cfg AppConfig, session SessionInfo) error { + callCount++ + if callCount == 1 { + return errors.New("probe failed once") + } + return nil + } + + session := SessionInfo{ + UserEmail: "alice@example.com", + } + ctx := context.Background() + + if err := app.probeAccountProtocolHealth(ctx, cfg, session, "alice@example.com"); err == nil { + t.Fatalf("expected first probe failure") + } + if err := app.probeAccountProtocolHealth(ctx, cfg, session, "alice@example.com"); err != nil { + t.Fatalf("expected second probe to run and succeed, got %v", err) + } + if callCount != 2 { + t.Fatalf("expected second request to reprobe after failure, got callCount=%d", callCount) + } +} + +func TestRunPromptWithSessionIncrementsWreqClientMetric(t *testing.T) { + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.Storage.SQLitePath = "" + state, err := newServerState(cfg) + if err != nil { + t.Fatalf("newServerState failed: %v", err) + } + defer func() { + _ = state.Close() + }() + app := &App{State: state} + session := SessionInfo{ + ClientVersion: "test-client-version", + UserID: "test-user", + SpaceID: "test-space", + Cookies: []ProbeCookie{{ + Name: "token_v2", + Value: "test-cookie", + }}, + } + beforeStandard := int64(0) + if v := wreqClientNewTotalMetric.Get("standard"); v != nil { + beforeStandard = v.(*expvar.Int).Value() + } + beforeStreaming := int64(0) + if v := wreqClientNewTotalMetric.Get("streaming"); v != nil { + beforeStreaming = v.(*expvar.Int).Value() + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err = app.runPromptWithSession(ctx, cfg, session, "", PromptRunRequest{Prompt: "hi"}, nil) + if !errors.Is(err, context.Canceled) 
{ + t.Fatalf("expected context canceled, got %v", err) + } + _, err = app.runPromptWithSession(ctx, cfg, session, "", PromptRunRequest{Prompt: "hi"}, func(string) error { return nil }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context canceled for streaming run, got %v", err) + } + + afterStandard := int64(0) + if v := wreqClientNewTotalMetric.Get("standard"); v != nil { + afterStandard = v.(*expvar.Int).Value() + } + afterStreaming := int64(0) + if v := wreqClientNewTotalMetric.Get("streaming"); v != nil { + afterStreaming = v.(*expvar.Int).Value() + } + if afterStandard-beforeStandard < 1 { + t.Fatalf("expected standard metric increment, before=%d after=%d", beforeStandard, afterStandard) + } + if afterStreaming-beforeStreaming < 1 { + t.Fatalf("expected streaming metric increment, before=%d after=%d", beforeStreaming, afterStreaming) + } +} + func TestConsumeNDJSONStreamWithIdleCloseReturnsUpstreamErrorStep(t *testing.T) { threadID := "thread-error" messageID := "msg-error" @@ -147,6 +267,68 @@ func TestConsumeNDJSONStreamWithIdleCloseReturnsUpstreamErrorStep(t *testing.T) } } +func TestConsumeNDJSONStreamWithIdleCloseParsesFinalLineWithoutTrailingNewline(t *testing.T) { + threadID := "thread-error-no-newline" + messageID := "msg-error-no-newline" + recordMap := buildThreadErrorRecordMap(threadID, "test-space", messageID, "AI inference is not allowed.", "trust-rule-denied", "trace-error") + line, err := json.Marshal(map[string]any{ + "type": "record-map", + "recordMap": recordMap, + }) + if err != nil { + t.Fatalf("marshal ndjson line failed: %v", err) + } + + _, gotErr := consumeNDJSONStreamWithIdleClose(io.NopCloser(strings.NewReader(string(line))), threadID, InferenceStreamSink{}, 0) + if gotErr == nil || !strings.Contains(gotErr.Error(), "AI inference is not allowed") { + t.Fatalf("expected upstream error step without trailing newline, got %v", gotErr) + } +} + +func TestConsumeNDJSONStreamWithIdleCloseRejectsOversizedLine(t *testing.T) { 
+ threadID := "thread-oversized-line" + oversizedLine := append(bytes.Repeat([]byte("a"), ndjsonMaxLineBytes+1), '\n') + + _, gotErr := consumeNDJSONStreamWithIdleClose(io.NopCloser(bytes.NewReader(oversizedLine)), threadID, InferenceStreamSink{}, 0) + if gotErr == nil { + t.Fatalf("expected oversized NDJSON line error, got nil") + } + if !errors.Is(gotErr, errNDJSONLineTooLarge) { + t.Fatalf("expected errNDJSONLineTooLarge, got %v", gotErr) + } +} + +func TestConsumeNDJSONStreamParsesFinalLineWithoutTrailingNewline(t *testing.T) { + threadID := "thread-error-no-newline-fallback" + messageID := "msg-error-no-newline-fallback" + recordMap := buildThreadErrorRecordMap(threadID, "test-space", messageID, "AI inference is not allowed.", "trust-rule-denied", "trace-error") + line, err := json.Marshal(map[string]any{ + "type": "record-map", + "recordMap": recordMap, + }) + if err != nil { + t.Fatalf("marshal ndjson line failed: %v", err) + } + + _, gotErr := consumeNDJSONStream(strings.NewReader(string(line)), threadID, InferenceStreamSink{}) + if gotErr == nil || !strings.Contains(gotErr.Error(), "AI inference is not allowed") { + t.Fatalf("expected upstream error step without trailing newline, got %v", gotErr) + } +} + +func TestConsumeNDJSONStreamRejectsOversizedLine(t *testing.T) { + threadID := "thread-oversized-line-fallback" + oversizedLine := append(bytes.Repeat([]byte("a"), ndjsonMaxLineBytes+1), '\n') + + _, gotErr := consumeNDJSONStream(bytes.NewReader(oversizedLine), threadID, InferenceStreamSink{}) + if gotErr == nil { + t.Fatalf("expected oversized NDJSON line error, got nil") + } + if !errors.Is(gotErr, errNDJSONLineTooLarge) { + t.Fatalf("expected errNDJSONLineTooLarge, got %v", gotErr) + } +} + func TestRunPromptReturnsUpstreamErrorStep(t *testing.T) { messageID := "msg-error" var recordMap map[string]any diff --git a/internal/app/notion_client_browser_fallback_test.go b/internal/app/notion_client_browser_fallback_test.go index 34e607c..e3f4164 100644 --- 
a/internal/app/notion_client_browser_fallback_test.go +++ b/internal/app/notion_client_browser_fallback_test.go @@ -2,9 +2,15 @@ package app import ( "context" + "crypto/sha256" "encoding/json" + "errors" + "fmt" "net/http" "net/http/httptest" + "os" + "os/exec" + "path/filepath" "strings" "testing" "time" @@ -359,3 +365,345 @@ func TestBrowserFallbackTimeoutForPayloadHonorsParentDeadline(t *testing.T) { t.Fatalf("timeout = %s, want <= 40ms", got) } } + +func TestClassifyBrowserHelperExecErrorBranches(t *testing.T) { + t.Run("err not found maps to unavailable", func(t *testing.T) { + err := classifyBrowserHelperExecError(context.Background(), "node", exec.ErrNotFound, "") + var unavailable *browserHelperUnavailableError + if !errors.As(err, &unavailable) { + t.Fatalf("expected browserHelperUnavailableError, got %T %v", err, err) + } + if !strings.Contains(err.Error(), "not found") { + t.Fatalf("expected not found message, got %v", err) + } + }) + + t.Run("context cancellation takes precedence", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err := classifyBrowserHelperExecError(ctx, "node", errors.New("exec failed"), "stderr text") + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context canceled, got %v", err) + } + }) + + t.Run("missing node-wreq maps to unavailable", func(t *testing.T) { + err := classifyBrowserHelperExecError(context.Background(), "node", errors.New("exec failed"), "Error: Cannot find module 'node-wreq'") + var unavailable *browserHelperUnavailableError + if !errors.As(err, &unavailable) { + t.Fatalf("expected browserHelperUnavailableError, got %T %v", err, err) + } + if !strings.Contains(err.Error(), "missing node-wreq module") { + t.Fatalf("unexpected unavailable message: %v", err) + } + }) + + t.Run("generic stderr path", func(t *testing.T) { + err := classifyBrowserHelperExecError(context.Background(), "node", errors.New("exec failed"), "boom") + if err == nil || 
!strings.Contains(err.Error(), "node helper failed: boom") { + t.Fatalf("expected generic helper failed error, got %v", err) + } + }) +} + +func TestSupportsBrowserRunInferenceFallbackMatrix(t *testing.T) { + baseCfg := defaultConfig() + baseCfg.APIKey = "test-key" + baseCfg.UpstreamBaseURL = "https://www.notion.so" + baseCfg.UpstreamOrigin = "https://www.notion.so" + client := newNotionAIClient(SessionInfo{ + ClientVersion: "test-client-version", + UserID: "test-user", + SpaceID: "test-space", + Cookies: []ProbeCookie{{Name: "token_v2", Value: "cookie"}}, + }, baseCfg, "") + + if !client.supportsBrowserRunInferenceFallback() { + t.Fatalf("expected fallback support for default https notion upstream") + } + + clientWithOverride := *client + clientWithOverride.browserRunInferenceFallback = func(ctx context.Context, payload map[string]any) (string, error) { + return "", nil + } + if !clientWithOverride.supportsBrowserRunInferenceFallback() { + t.Fatalf("expected explicit fallback override to force support") + } + + hostHeaderCfg := baseCfg + hostHeaderCfg.UpstreamHost = "example.com" + hostHeaderClient := newNotionAIClient(client.Session, hostHeaderCfg, "") + if hostHeaderClient.supportsBrowserRunInferenceFallback() { + t.Fatalf("expected fallback disabled when upstream host header override is set") + } + + localCfg := baseCfg + localCfg.UpstreamBaseURL = "https://127.0.0.1:8443" + localCfg.UpstreamOrigin = "https://127.0.0.1:8443" + localClient := newNotionAIClient(client.Session, localCfg, "") + if localClient.supportsBrowserRunInferenceFallback() { + t.Fatalf("expected fallback disabled for local upstream") + } + + httpCfg := baseCfg + httpCfg.UpstreamBaseURL = "http://www.notion.so" + httpCfg.UpstreamOrigin = "http://www.notion.so" + httpClient := newNotionAIClient(client.Session, httpCfg, "") + if httpClient.supportsBrowserRunInferenceFallback() { + t.Fatalf("expected fallback disabled for non-https upstream") + } +} + +func 
TestEnsureHelperScriptFileStageAStablePath(t *testing.T) { + tempDir := t.TempDir() + script := "console.log('stage-a-stable')\n" + + gotPath, err := ensureHelperScriptFile(tempDir, ".cjs", script) + if err != nil { + t.Fatalf("ensureHelperScriptFile failed: %v", err) + } + + sum := sha256.Sum256([]byte(script)) + wantPath := filepath.Join(tempDir, fmt.Sprintf("notion-helper-%x.cjs", sum)) + if gotPath != wantPath { + t.Fatalf("script path mismatch: got %q want %q", gotPath, wantPath) + } + + content, err := os.ReadFile(gotPath) + if err != nil { + t.Fatalf("read script failed: %v", err) + } + if string(content) != script { + t.Fatalf("script content mismatch: got %q want %q", string(content), script) + } +} + +func TestEnsureHelperScriptFileStageARepeatedCallsKeepSingleScript(t *testing.T) { + tempDir := t.TempDir() + script := "console.log('stage-a-repeat')\n" + + for i := 0; i < 100; i++ { + if _, err := ensureHelperScriptFile(tempDir, ".cjs", script); err != nil { + t.Fatalf("iteration %d ensureHelperScriptFile failed: %v", i+1, err) + } + } + + matches, err := filepath.Glob(filepath.Join(tempDir, "notion-helper-*.cjs")) + if err != nil { + t.Fatalf("glob helper scripts failed: %v", err) + } + if len(matches) != 1 { + t.Fatalf("expected exactly 1 helper script, got %d (%v)", len(matches), matches) + } +} + +func TestBrowserHelperNodeEnvForConfigIncludesPoolSize(t *testing.T) { + cfg := defaultConfig() + cfg.Browser.HelperPoolSize = 3 + env := browserHelperNodeEnvForConfig(cfg) + joined := strings.Join(env, "\n") + if !strings.Contains(joined, "NOTION2API_BROWSER_HELPER_POOL_SIZE=3") { + t.Fatalf("expected helper pool size env to include configured size, got %v", env) + } +} + +func TestConfiguredBrowserHelperPoolSizeBoundsAndFallback(t *testing.T) { + if got := configuredBrowserHelperPoolSize(normalizeConfig(AppConfig{ + Browser: BrowserConfig{HelperPoolSize: 99}, + })); got != 8 { + t.Fatalf("expected oversized config to clamp to 8, got %d", got) + } + if got 
:= configuredBrowserHelperPoolSize(normalizeConfig(AppConfig{ + Browser: BrowserConfig{HelperPoolSize: -2}, + })); got < 1 || got > 8 { + t.Fatalf("expected fallback cpu-based pool size in [1,8], got %d", got) + } +} + +func TestBrowserHelperNodeEnvForConfigOmitsPoolSizeWhenDisabled(t *testing.T) { + cfg := defaultConfig() + cfg.Browser.HelperPoolSize = 0 + env := browserHelperNodeEnvForConfig(cfg) + for _, item := range env { + if strings.Contains(item, "NOTION2API_BROWSER_HELPER_POOL_SIZE=") { + t.Fatalf("expected pool-size env to be omitted for disabled pool, got %v", env) + } + } +} + +func TestBrowserHelperNodeEnvForConfigHonorsEnvOverride(t *testing.T) { + t.Setenv("NOTION2API_BROWSER_HELPER_POOL_SIZE", "5") + cfg := defaultConfig() + cfg.Browser.HelperPoolSize = 2 + env := browserHelperNodeEnvForConfig(cfg) + joined := strings.Join(env, "\n") + if !strings.Contains(joined, "NOTION2API_BROWSER_HELPER_POOL_SIZE=5") { + t.Fatalf("expected env override to win, got %v", env) + } + if strings.Contains(joined, "NOTION2API_BROWSER_HELPER_POOL_SIZE=2") { + t.Fatalf("expected config value to be ignored when env override is set, got %v", env) + } +} + +func TestExecuteHelperSubprocessPooledSpawnsWorkersAndReusesPool(t *testing.T) { + if _, err := exec.LookPath("node"); err != nil { + t.Skipf("skip pooled helper runtime test: node unavailable: %v", err) + } + + resetMetricsForTest() + resetBrowserHelperPoolsForTest() + t.Cleanup(func() { + resetBrowserHelperPoolsForTest() + }) + + script := ` +const poolMode = String(process.env.NOTION2API_BROWSER_HELPER_MODE || '').trim().toLowerCase() === 'pool'; +function writeFrame(body) { + const header = Buffer.allocUnsafe(4); + header.writeUInt32LE(body.length, 0); + process.stdout.write(header); + process.stdout.write(body); +} +if (!poolMode) { + process.stdin.setEncoding('utf8'); + let raw = ''; + process.stdin.on('data', (chunk) => { raw += chunk; }); + process.stdin.on('end', () => { + const parsed = JSON.parse(raw || '{}'); 
+ process.stdout.write(JSON.stringify({ ok: true, echo: parsed.echo || '' })); + }); +} else { + let pending = Buffer.alloc(0); + const queue = []; + let draining = false; + async function drain() { + if (draining) return; + draining = true; + while (queue.length > 0) { + const payload = queue.shift(); + const parsed = JSON.parse(payload.toString('utf8')); + const out = Buffer.from(JSON.stringify({ + status: 200, + content_type: 'application/x-ndjson', + text: JSON.stringify({ type: 'record-map', echo: parsed.echo || '' }), + })); + writeFrame(out); + } + draining = false; + } + process.stdin.on('data', (chunk) => { + const incoming = Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk); + pending = Buffer.concat([pending, incoming]); + while (pending.length >= 4) { + const n = pending.readUInt32LE(0); + if (pending.length < 4 + n) break; + queue.push(pending.subarray(4, 4 + n)); + pending = pending.subarray(4 + n); + } + void drain(); + }); + process.stdin.on('end', () => process.exit(0)); +} +` + + extraEnv := []string{browserHelperPoolSizeEnvKey + "=2"} + + out1, err := executeHelperSubprocess(context.Background(), "node", ".cjs", script, []byte(`{"echo":"one"}`), extraEnv) + if err != nil { + t.Fatalf("first pooled executeHelperSubprocess failed: %v", err) + } + var resp1 map[string]any + if err := json.Unmarshal(out1, &resp1); err != nil { + t.Fatalf("unmarshal first pooled response failed: %v", err) + } + text1 := stringValue(resp1["text"]) + if !strings.Contains(text1, `"echo":"one"`) { + t.Fatalf("unexpected first pooled response text: %q", text1) + } + + browserPoolWorkerMu.Lock() + spawnAfterFirst := browserPoolWorkerTotal + browserPoolWorkerMu.Unlock() + if spawnAfterFirst != 2 { + t.Fatalf("expected two pool workers after first call, got %d", spawnAfterFirst) + } + t.Logf("pooled helper worker spawns after first call: %d", spawnAfterFirst) + + out2, err := executeHelperSubprocess(context.Background(), "node", ".cjs", script, []byte(`{"echo":"two"}`), 
extraEnv) + if err != nil { + t.Fatalf("second pooled executeHelperSubprocess failed: %v", err) + } + var resp2 map[string]any + if err := json.Unmarshal(out2, &resp2); err != nil { + t.Fatalf("unmarshal second pooled response failed: %v", err) + } + text2 := stringValue(resp2["text"]) + if !strings.Contains(text2, `"echo":"two"`) { + t.Fatalf("unexpected second pooled response text: %q", text2) + } + + browserPoolWorkerMu.Lock() + spawnAfterSecond := browserPoolWorkerTotal + browserPoolWorkerMu.Unlock() + if spawnAfterSecond != spawnAfterFirst { + t.Fatalf("expected pooled workers to be reused (no extra spawns), first=%d second=%d", spawnAfterFirst, spawnAfterSecond) + } + t.Logf("pooled helper worker spawns after second call: %d", spawnAfterSecond) +} + +func TestExecuteHelperSubprocessPooledAllowsNilContext(t *testing.T) { + if _, err := exec.LookPath("node"); err != nil { + t.Skipf("skip nil-context pooled helper test: node unavailable: %v", err) + } + + resetBrowserHelperPoolsForTest() + t.Cleanup(func() { + resetBrowserHelperPoolsForTest() + }) + + script := ` +const poolMode = String(process.env.NOTION2API_BROWSER_HELPER_MODE || '').trim().toLowerCase() === 'pool'; +function writeFrame(body) { + const header = Buffer.allocUnsafe(4); + header.writeUInt32LE(body.length, 0); + process.stdout.write(header); + process.stdout.write(body); +} +if (poolMode) { + let pending = Buffer.alloc(0); + process.stdin.on('data', (chunk) => { + const incoming = Buffer.isBuffer(chunk) ? 
chunk : Buffer.from(chunk); + pending = Buffer.concat([pending, incoming]); + while (pending.length >= 4) { + const n = pending.readUInt32LE(0); + if (pending.length < 4 + n) break; + pending = pending.subarray(4 + n); + const out = Buffer.from(JSON.stringify({ + status: 200, + content_type: 'application/x-ndjson', + text: JSON.stringify({ type: 'record-map', ok: true }), + })); + writeFrame(out); + } + }); + process.stdin.on('end', () => process.exit(0)); +} else { + process.exit(2); +} +` + extraEnv := []string{browserHelperPoolSizeEnvKey + "=2"} + out, err := executeHelperSubprocess(nil, "node", ".cjs", script, []byte(`{"x":1}`), extraEnv) + if err != nil { + t.Fatalf("executeHelperSubprocess with nil context failed: %v", err) + } + var resp map[string]any + if err := json.Unmarshal(out, &resp); err != nil { + t.Fatalf("unmarshal pooled nil-context response failed: %v", err) + } + if got := stringValue(resp["content_type"]); got != "application/x-ndjson" { + t.Fatalf("unexpected content_type: %q", got) + } + if !strings.Contains(stringValue(resp["text"]), `"ok":true`) { + t.Fatalf("unexpected text payload: %q", stringValue(resp["text"])) + } +} diff --git a/internal/app/notion_client_browser_transport.go b/internal/app/notion_client_browser_transport.go index caf2be0..64de5a2 100644 --- a/internal/app/notion_client_browser_transport.go +++ b/internal/app/notion_client_browser_transport.go @@ -3,14 +3,21 @@ package app import ( "bytes" "context" + "crypto/sha256" + "encoding/binary" "encoding/json" "errors" "fmt" + "io" "net/url" "os" "os/exec" "path/filepath" + "runtime" + "sort" + "strconv" "strings" + "sync" "time" ) @@ -18,6 +25,12 @@ const ( browserHelperCancelWaitDelay = 2 * time.Second notionWreqDefaultBrowserProfile = "chrome_142" notionWreqDefaultRequestTimeout = 120 * time.Second + maxBrowserHelperPoolSize = 8 + browserHelperPoolSizeEnvKey = "NOTION2API_BROWSER_HELPER_POOL_SIZE" + browserHelperPoolModeEnvKey = "NOTION2API_BROWSER_HELPER_MODE" + 
browserHelperPoolMode = "pool" + browserHelperPoolProtoEnvKey = "NOTION2API_BROWSER_HELPER_PROTOCOL" + browserHelperPoolProtoV1 = "N2A_HELPER_POOL_V1" ) type browserTransportRequest struct { @@ -43,8 +56,35 @@ type browserHelperUnavailableError struct { Message string } +type browserHelperPool struct { + runtimeName string + scriptPath string + extraEnv []string + size int + workers chan *browserHelperPoolWorker +} + +type browserHelperPoolWorker struct { + cmd *exec.Cmd + stdin io.WriteCloser + stdout io.ReadCloser +} + +type browserHelperPoolKey struct { + runtimeName string + scriptPath string + envKey string + size int +} + var ( runBrowserFallback = runInferenceTranscriptInBrowserWithNodeWreq + browserHelperPools = struct { + mu sync.Mutex + pools map[browserHelperPoolKey]*browserHelperPool + }{ + pools: map[browserHelperPoolKey]*browserHelperPool{}, + } ) func (e *browserHelperUnavailableError) Error() string { @@ -89,7 +129,8 @@ func runInferenceTranscriptInBrowserWithNodeWreq(ctx context.Context, client *No if err != nil { return "", err } - return runHelperScript(ctx, "node", ".cjs", nodeWreqHelperScript(), request, browserHelperNodeEnv()) + helperEnv := browserHelperNodeEnvForConfig(client.Config) + return runHelperScript(ctx, "node", ".cjs", nodeWreqHelperScript(), request, helperEnv) } func runHelperScript(ctx context.Context, runtimeName string, extension string, script string, request browserTransportRequest, extraEnv []string) (string, error) { @@ -115,32 +156,320 @@ func runHelperScript(ctx context.Context, runtimeName string, extension string, } func executeHelperSubprocess(ctx context.Context, runtimeName string, extension string, script string, requestPayload []byte, extraEnv []string) ([]byte, error) { + startedAt := time.Now() + defer func() { + observeWreqFFICallDuration(time.Since(startedAt)) + }() if _, err := exec.LookPath(runtimeName); err != nil { return nil, &browserHelperUnavailableError{Message: fmt.Sprintf("%s not found", 
runtimeName)} } - scriptFile, err := os.CreateTemp("", "notion-browser-helper-*"+extension) + scriptPath, err := ensureHelperScriptFile("", extension, script) if err != nil { return nil, err } - scriptPath := scriptFile.Name() - defer os.Remove(scriptPath) - if _, err := scriptFile.WriteString(script); err != nil { - _ = scriptFile.Close() - return nil, err - } - if err := scriptFile.Close(); err != nil { - return nil, err + if size, ok := browserHelperPoolSizeFromEnv(extraEnv); ok && size > 1 { + pooled, pooledErr := executeHelperSubprocessPooled(ctx, runtimeName, scriptPath, requestPayload, extraEnv, size) + if pooledErr == nil { + return pooled, nil + } + return nil, classifyBrowserHelperExecError(ctx, runtimeName, pooledErr, "") } cmd := newBrowserHelperCommand(ctx, runtimeName, scriptPath, requestPayload, extraEnv) var stdout bytes.Buffer var stderr bytes.Buffer cmd.Stdout = &stdout cmd.Stderr = &stderr + addBrowserHelperSpawn() if err := runBrowserHelperCommand(ctx, cmd); err != nil { return nil, classifyBrowserHelperExecError(ctx, runtimeName, err, stderr.String()) } return stdout.Bytes(), nil } + +func browserHelperPoolSizeFromEnv(extraEnv []string) (int, bool) { + for _, entry := range extraEnv { + pair := strings.SplitN(strings.TrimSpace(entry), "=", 2) + if len(pair) != 2 { + continue + } + if strings.TrimSpace(pair[0]) != browserHelperPoolSizeEnvKey { + continue + } + n, err := strconv.Atoi(strings.TrimSpace(pair[1])) + if err != nil || n <= 1 { + return 0, false + } + if n > maxBrowserHelperPoolSize { + n = maxBrowserHelperPoolSize + } + return n, true + } + return 0, false +} + +func browserHelperPoolEnvKey(extraEnv []string) string { + if len(extraEnv) == 0 { + return "" + } + filtered := make([]string, 0, len(extraEnv)) + for _, entry := range extraEnv { + trimmed := strings.TrimSpace(entry) + if trimmed == "" { + continue + } + filtered = append(filtered, trimmed) + } + sort.Strings(filtered) + return strings.Join(filtered, "\x00") +} + +func 
getOrCreateBrowserHelperPool(runtimeName string, scriptPath string, extraEnv []string, size int) (*browserHelperPool, error) { + if size <= 1 { + return nil, fmt.Errorf("invalid browser helper pool size: %d", size) + } + if size > maxBrowserHelperPoolSize { + size = maxBrowserHelperPoolSize + } + key := browserHelperPoolKey{ + runtimeName: strings.TrimSpace(runtimeName), + scriptPath: scriptPath, + envKey: browserHelperPoolEnvKey(extraEnv), + size: size, + } + browserHelperPools.mu.Lock() + if existing := browserHelperPools.pools[key]; existing != nil { + browserHelperPools.mu.Unlock() + return existing, nil + } + pool := &browserHelperPool{ + runtimeName: runtimeName, + scriptPath: scriptPath, + extraEnv: append([]string(nil), extraEnv...), + size: size, + workers: make(chan *browserHelperPoolWorker, size), + } + startedWorkers := make([]*browserHelperPoolWorker, 0, size) + for i := 0; i < size; i++ { + worker, err := startBrowserHelperPoolWorker(runtimeName, scriptPath, pool.extraEnv) + if err != nil { + for _, started := range startedWorkers { + stopBrowserHelperPoolWorker(started) + } + browserHelperPools.mu.Unlock() + return nil, err + } + startedWorkers = append(startedWorkers, worker) + pool.workers <- worker + } + browserHelperPools.pools[key] = pool + browserHelperPools.mu.Unlock() + return pool, nil +} + +func startBrowserHelperPoolWorker(runtimeName string, scriptPath string, extraEnv []string) (*browserHelperPoolWorker, error) { + cmd := exec.Command(runtimeName, scriptPath) + cmd.Env = append(os.Environ(), extraEnv...) 
+ cmd.Env = append(cmd.Env, browserHelperPoolModeEnvKey+"="+browserHelperPoolMode, browserHelperPoolProtoEnvKey+"="+browserHelperPoolProtoV1) + configureBrowserHelperCommand(cmd) + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, err + } + stdout, err := cmd.StdoutPipe() + if err != nil { + _ = stdin.Close() + return nil, err + } + var stderr bytes.Buffer + cmd.Stderr = &stderr + if err := cmd.Start(); err != nil { + _ = stdin.Close() + _ = stdout.Close() + return nil, err + } + addBrowserHelperSpawn() + addBrowserHelperPoolWorkerSpawn() + return &browserHelperPoolWorker{ + cmd: cmd, + stdin: stdin, + stdout: stdout, + }, nil +} + +func stopBrowserHelperPoolWorker(worker *browserHelperPoolWorker) { + if worker == nil { + return + } + if worker.cmd != nil { + _ = cancelBrowserHelperCommand(worker.cmd) + } + if worker.stdin != nil { + _ = worker.stdin.Close() + } + if worker.stdout != nil { + _ = worker.stdout.Close() + } + if worker.cmd != nil { + _ = worker.cmd.Wait() + } +} + +func writePoolFrame(writer io.Writer, payload []byte) error { + if writer == nil { + return fmt.Errorf("pool writer unavailable") + } + frame := make([]byte, 4+len(payload)) + binary.LittleEndian.PutUint32(frame[:4], uint32(len(payload))) + copy(frame[4:], payload) + _, err := writer.Write(frame) + return err +} + +func readPoolFrame(reader io.Reader) ([]byte, error) { + if reader == nil { + return nil, fmt.Errorf("pool reader unavailable") + } + header := make([]byte, 4) + if _, err := io.ReadFull(reader, header); err != nil { + return nil, err + } + length := binary.LittleEndian.Uint32(header) + if length == 0 { + return []byte("{}"), nil + } + body := make([]byte, int(length)) + if _, err := io.ReadFull(reader, body); err != nil { + return nil, err + } + return body, nil +} + +func replaceBrowserHelperPoolWorker(worker *browserHelperPoolWorker, runtimeName string, scriptPath string, extraEnv []string) error { + fresh, err := startBrowserHelperPoolWorker(runtimeName, 
scriptPath, extraEnv) + if err != nil { + return err + } + if worker != nil { + stopBrowserHelperPoolWorker(worker) + *worker = *fresh + } else { + stopBrowserHelperPoolWorker(fresh) + } + return nil +} + +func executeHelperSubprocessPooled(ctx context.Context, runtimeName string, scriptPath string, requestPayload []byte, extraEnv []string, size int) ([]byte, error) { + if ctx == nil { + ctx = context.Background() + } + pool, err := getOrCreateBrowserHelperPool(runtimeName, scriptPath, extraEnv, size) + if err != nil { + return nil, err + } + var worker *browserHelperPoolWorker + select { + case worker = <-pool.workers: + case <-ctx.Done(): + return nil, ctx.Err() + } + defer func() { + if worker != nil { + pool.workers <- worker + } + }() + + type poolResult struct { + body []byte + err error + } + done := make(chan poolResult, 1) + go func(w *browserHelperPoolWorker) { + if err := writePoolFrame(w.stdin, requestPayload); err != nil { + done <- poolResult{err: err} + return + } + body, err := readPoolFrame(w.stdout) + done <- poolResult{body: body, err: err} + }(worker) + + select { + case res := <-done: + if res.err != nil { + _ = replaceBrowserHelperPoolWorker(worker, pool.runtimeName, pool.scriptPath, pool.extraEnv) + return nil, classifyBrowserHelperExecError(ctx, runtimeName, res.err, "") + } + return res.body, nil + case <-ctx.Done(): + _ = replaceBrowserHelperPoolWorker(worker, pool.runtimeName, pool.scriptPath, pool.extraEnv) + return nil, ctx.Err() + } +} + +func resetBrowserHelperPoolsForTest() { + browserHelperPools.mu.Lock() + defer browserHelperPools.mu.Unlock() + for key, pool := range browserHelperPools.pools { + if pool == nil { + delete(browserHelperPools.pools, key) + continue + } + close(pool.workers) + for worker := range pool.workers { + stopBrowserHelperPoolWorker(worker) + } + delete(browserHelperPools.pools, key) + } + browserHelperPools.pools = map[browserHelperPoolKey]*browserHelperPool{} +} + +func ensureHelperScriptFile(tempDir string, 
extension string, script string) (string, error) { + if strings.TrimSpace(tempDir) == "" { + tempDir = os.TempDir() + } + if err := os.MkdirAll(tempDir, 0o755); err != nil { + return "", err + } + + scriptHash := sha256.Sum256([]byte(script)) + scriptPath := filepath.Join(tempDir, fmt.Sprintf("notion-helper-%x%s", scriptHash, extension)) + if existing, err := os.ReadFile(scriptPath); err == nil { + if string(existing) == script { + return scriptPath, nil + } + } else if !errors.Is(err, os.ErrNotExist) { + return "", err + } + + tmpFile, err := os.CreateTemp(tempDir, "notion-helper-write-*"+extension) + if err != nil { + return "", err + } + tmpPath := tmpFile.Name() + cleanupTmp := true + defer func() { + if cleanupTmp { + _ = os.Remove(tmpPath) + } + }() + if _, err := tmpFile.WriteString(script); err != nil { + _ = tmpFile.Close() + return "", err + } + if err := tmpFile.Close(); err != nil { + return "", err + } + if err := os.Rename(tmpPath, scriptPath); err != nil { + if existing, readErr := os.ReadFile(scriptPath); readErr == nil && string(existing) == script { + cleanupTmp = false + _ = os.Remove(tmpPath) + return scriptPath, nil + } + return "", err + } + cleanupTmp = false + return scriptPath, nil +} + func newBrowserHelperCommand(ctx context.Context, runtimeName string, scriptPath string, requestPayload []byte, extraEnv []string) *exec.Cmd { _ = ctx cmd := exec.CommandContext(context.Background(), runtimeName, scriptPath) @@ -291,6 +620,39 @@ func browserHelperNodeEnv() []string { return []string{"NODE_PATH=" + joined} } +func browserHelperNodeEnvForConfig(cfg AppConfig) []string { + base := browserHelperNodeEnv() + size := strings.TrimSpace(os.Getenv(browserHelperPoolSizeEnvKey)) + if size != "" { + return append(base, browserHelperPoolSizeEnvKey+"="+size) + } + if cfg.Browser.HelperPoolSize <= 1 { + return base + } + sizeNum := configuredBrowserHelperPoolSize(cfg) + if sizeNum <= 1 { + return base + } + return append(base, 
browserHelperPoolSizeEnvKey+"="+strconv.Itoa(sizeNum)) +} + +func configuredBrowserHelperPoolSize(cfg AppConfig) int { + if cfg.Browser.HelperPoolSize > 0 { + if cfg.Browser.HelperPoolSize > maxBrowserHelperPoolSize { + return maxBrowserHelperPoolSize + } + return cfg.Browser.HelperPoolSize + } + size := runtime.NumCPU() + if size < 1 { + size = 1 + } + if size > maxBrowserHelperPoolSize { + size = maxBrowserHelperPoolSize + } + return size +} + func browserHelperNodeModuleCandidates() []string { candidates := []string{ os.Getenv("NODE_PATH"), diff --git a/internal/app/notion_client_protocol_test.go b/internal/app/notion_client_protocol_test.go index 6ee2765..83f7926 100644 --- a/internal/app/notion_client_protocol_test.go +++ b/internal/app/notion_client_protocol_test.go @@ -3,11 +3,24 @@ package app import ( "context" "encoding/json" + "expvar" "net/http" "net/http/httptest" "testing" ) +func resetNotionTransportCacheForTest() { + notionTransportCache.mu.Lock() + defer notionTransportCache.mu.Unlock() + for _, transport := range notionTransportCache.items { + if transport != nil { + transport.CloseIdleConnections() + } + } + notionTransportCache.items = map[notionHTTPTransportCacheKey]*http.Transport{} + notionHTTPTransportCacheMetric.Init() +} + func newProtocolTestClient(cfg AppConfig) *NotionAIClient { cfg.APIKey = "test-api-key" if cfg.UpstreamBaseURL == "" { @@ -217,3 +230,146 @@ func TestPostJSONResponseAddsResinAccountHeaderWhenEnabled(t *testing.T) { t.Fatalf("%s = %q, want %q", defaultResinAccountHeader, got, want) } } + +func TestNewNotionAIClientWithModeReusesTransportForSameConfigAndAccount(t *testing.T) { + resetNotionTransportCacheForTest() + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + session := SessionInfo{ + ClientVersion: "test-client-version", + UserID: "test-user", + SpaceID: "test-space", + Cookies: []ProbeCookie{{ + Name: "token_v2", + Value: "test-cookie", + }}, + } + first := newNotionAIClientWithMode(session, cfg, 
"alice@example.com", false) + second := newNotionAIClientWithMode(session, cfg, "alice@example.com", false) + streaming := newNotionAIClientWithMode(session, cfg, "alice@example.com", true) + + if first.HTTPClient == nil || second.HTTPClient == nil || streaming.HTTPClient == nil { + t.Fatalf("expected HTTP clients to be initialized") + } + if first.HTTPClient.Transport == nil || second.HTTPClient.Transport == nil || streaming.HTTPClient.Transport == nil { + t.Fatalf("expected transports to be initialized") + } + if first.HTTPClient.Transport != second.HTTPClient.Transport { + t.Fatalf("expected transport reuse for same account/config") + } + if first.HTTPClient.Transport != streaming.HTTPClient.Transport { + t.Fatalf("expected streaming and standard clients to share transport cache") + } + if first.HTTPClient.Timeout <= 0 { + t.Fatalf("expected non-streaming timeout to be configured") + } + if streaming.HTTPClient.Timeout != 0 { + t.Fatalf("expected streaming client timeout to be disabled, got %s", streaming.HTTPClient.Timeout) + } +} + +func TestNewNotionAIClientWithModeSeparatesTransportWhenProxyPolicyDiffers(t *testing.T) { + resetNotionTransportCacheForTest() + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.Accounts = []NotionAccount{ + { + Email: "alice@example.com", + ProxyMode: proxyModeHTTP, + ProxyURL: "http://127.0.0.1:18080", + }, + { + Email: "bob@example.com", + ProxyMode: proxyModeHTTP, + ProxyURL: "http://127.0.0.1:28080", + }, + } + session := SessionInfo{ + ClientVersion: "test-client-version", + UserID: "test-user", + SpaceID: "test-space", + Cookies: []ProbeCookie{{ + Name: "token_v2", + Value: "test-cookie", + }}, + } + alice := newNotionAIClientWithMode(session, cfg, "alice@example.com", false) + bob := newNotionAIClientWithMode(session, cfg, "bob@example.com", false) + + if alice.HTTPClient == nil || bob.HTTPClient == nil { + t.Fatalf("expected HTTP clients to be initialized") + } + if alice.HTTPClient.Transport == nil || 
bob.HTTPClient.Transport == nil { + t.Fatalf("expected transports to be initialized") + } + if alice.HTTPClient.Transport == bob.HTTPClient.Transport { + t.Fatalf("expected separate transports when account proxy policy differs") + } +} + +func TestCachedNotionHTTPTransportRecordsCacheMetrics(t *testing.T) { + resetNotionTransportCacheForTest() + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + session := SessionInfo{ + ClientVersion: "test-client-version", + UserID: "test-user", + SpaceID: "test-space", + Cookies: []ProbeCookie{{ + Name: "token_v2", + Value: "test-cookie", + }}, + } + _ = newNotionAIClientWithMode(session, cfg, "alice@example.com", false) + _ = newNotionAIClientWithMode(session, cfg, "alice@example.com", false) + _ = newNotionAIClientWithMode(session, cfg, "alice@example.com", true) + + mustAtLeast := func(label string, wantMin int64) { + var got int64 + if v := notionHTTPTransportCacheMetric.Get(label); v != nil { + got = v.(*expvar.Int).Value() + } + if got < wantMin { + t.Fatalf("metric %s too small: got %d want >= %d", label, got, wantMin) + } + } + mustAtLeast("miss_new", 1) + mustAtLeast("hit_rlock", 1) +} + +func BenchmarkNewNotionAIClientWithModeTransportCache(b *testing.B) { + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + session := SessionInfo{ + ClientVersion: "test-client-version", + UserID: "test-user", + SpaceID: "test-space", + Cookies: []ProbeCookie{{ + Name: "token_v2", + Value: "test-cookie", + }}, + } + + b.Run("warm_cache", func(b *testing.B) { + resetNotionTransportCacheForTest() + _ = newNotionAIClientWithMode(session, cfg, "alice@example.com", false) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + client := newNotionAIClientWithMode(session, cfg, "alice@example.com", false) + if client == nil || client.HTTPClient == nil || client.HTTPClient.Transport == nil { + b.Fatalf("expected client with transport") + } + } + }) + + b.Run("cold_cache_reset_each_iter", func(b *testing.B) { + b.ReportAllocs() + for i := 0; 
i < b.N; i++ { + resetNotionTransportCacheForTest() + client := newNotionAIClientWithMode(session, cfg, "alice@example.com", false) + if client == nil || client.HTTPClient == nil || client.HTTPClient.Transport == nil { + b.Fatalf("expected client with transport") + } + } + }) +} diff --git a/internal/app/notion_client_wreq_transport.go b/internal/app/notion_client_wreq_transport.go index 7a5884c..22d5129 100644 --- a/internal/app/notion_client_wreq_transport.go +++ b/internal/app/notion_client_wreq_transport.go @@ -1,241 +1,19 @@ package app -func nodeWreqHelperScript() string { - return `const fs = require('fs'); -const { fetch } = require('node-wreq'); - -(async () => { - const input = JSON.parse(fs.readFileSync(0, 'utf8')); - - const cookieMap = new Map(); - for (const item of input.cookies || []) { - const name = String((item && item.name) || '').trim(); - if (!name) continue; - cookieMap.set(name, String((item && item.value) || '')); - } - const cookieJar = { - getCookies() { - return [...cookieMap.entries()].map(([name, value]) => ({ name, value })); - }, - setCookie(cookie) { - const text = String(cookie || ''); - const semi = text.indexOf(';'); - const pair = semi === -1 ? text : text.slice(0, semi); - const eq = pair.indexOf('='); - if (eq <= 0) return; - const name = pair.slice(0, eq).trim(); - const value = pair.slice(eq + 1).trim(); - if (name) cookieMap.set(name, value); - }, - }; - - const headers = {}; - for (const [key, value] of Object.entries(input.headers || {})) { - if (key === undefined || key === null) continue; - if (String(key).toLowerCase() === 'cookie') continue; - headers[String(key)] = String(value == null ? 
'' : value); - } - - const fetchOptions = { - method: 'POST', - browser: input.browser_profile || 'chrome_142', - headers, - body: JSON.stringify(input.payload || {}), - cookieJar, - timeout: Math.max(Number(input.request_timeout_ms || 0), 30000), - throwHttpErrors: false, - }; - const proxy = String(input.proxy || '').trim(); - if (proxy) fetchOptions.proxy = proxy; - - const result = { status: 0, content_type: '', text: '' }; - let response; - try { - response = await fetch(input.run_url, fetchOptions); - } catch (err) { - process.stderr.write((err && err.stack ? err.stack : String(err)) + '\n'); - process.exit(2); - return; - } - - result.status = response.status; - result.content_type = response.headers.get('content-type') || ''; - const isNDJSON = String(result.content_type).toLowerCase().includes('application/x-ndjson'); - if (!isNDJSON) { - result.text = await response.text(); - process.stdout.write(JSON.stringify(result)); - return; - } - - const idleAfterAnswerMs = Math.max(Number(input.idle_after_answer_ms || 0), 0); - const readable = response.wreq && typeof response.wreq.readable === 'function' - ? 
response.wreq.readable() - : null; - if (!readable) { - result.text = await response.text(); - process.stdout.write(JSON.stringify(result)); - return; - } +import _ "embed" - let pending = ''; - let sawAnswer = false; - let sawTerminal = false; - let settled = false; - let idleTimer = null; +var ( + //go:embed assets/browser-helper.cjs + embeddedNodeWreqHelperScript string - const markLineState = (line) => { - if (!line) return; - try { - const parsed = JSON.parse(line); - if (String(parsed.type || '').toLowerCase() !== 'agent-inference' || !Array.isArray(parsed.value)) return; - const hasVisibleText = parsed.value.some((entry) => { - const t = String((entry && entry.type) || '').toLowerCase(); - const c = String((entry && entry.content) || ''); - return t === 'text' && c.trim() !== ''; - }); - if (!hasVisibleText) return; - sawAnswer = true; - if (parsed.finishedAt != null) sawTerminal = true; - } catch (_) {} - }; + //go:embed assets/browser-login-helper.cjs + embeddedNodeWreqLoginHelperScript string +) - await new Promise((resolve, reject) => { - const settle = () => { - if (settled) return; - settled = true; - if (idleTimer) { - clearTimeout(idleTimer); - idleTimer = null; - } - const remaining = pending.trim(); - if (remaining) markLineState(remaining); - try { readable.destroy(); } catch (_) {} - resolve(); - }; - const armIdle = () => { - if (idleTimer) { - clearTimeout(idleTimer); - idleTimer = null; - } - if (sawAnswer && idleAfterAnswerMs > 0) { - idleTimer = setTimeout(settle, idleAfterAnswerMs); - } - }; - readable.on('data', (chunk) => { - const text = Buffer.isBuffer(chunk) ? 
chunk.toString('utf8') : String(chunk); - result.text += text; - pending += text; - while (true) { - const newlineIndex = pending.indexOf('\n'); - if (newlineIndex === -1) break; - const line = pending.slice(0, newlineIndex).trim(); - pending = pending.slice(newlineIndex + 1); - markLineState(line); - if (sawTerminal) { - settle(); - return; - } - } - armIdle(); - }); - readable.on('end', settle); - readable.on('close', settle); - readable.on('error', (err) => { - if (settled) return; - settled = true; - if (idleTimer) clearTimeout(idleTimer); - reject(err); - }); - }); - - process.stdout.write(JSON.stringify(result)); -})().catch((error) => { - process.stderr.write((error && error.stack ? error.stack : String(error)) + '\n'); - process.exit(1); -}); -` +func nodeWreqHelperScript() string { + return embeddedNodeWreqHelperScript } func nodeWreqLoginHelperScript() string { - return `const fs = require('fs'); -const { fetch } = require('node-wreq'); - -(async () => { - const input = JSON.parse(fs.readFileSync(0, 'utf8')); - - const cookieMap = new Map(); - for (const item of input.cookies || []) { - const name = String((item && (item.name || item.Name)) || '').trim(); - if (!name) continue; - const rawValue = item && (item.value !== undefined ? item.value : item.Value); - cookieMap.set(name, String(rawValue == null ? '' : rawValue)); - } - const setCookieRecord = new Map(); - const cookieJar = { - getCookies() { - return [...cookieMap.entries()].map(([name, value]) => ({ name, value })); - }, - setCookie(cookie) { - const text = String(cookie || ''); - const semi = text.indexOf(';'); - const pair = semi === -1 ? 
text : text.slice(0, semi); - const eq = pair.indexOf('='); - if (eq <= 0) return; - const name = pair.slice(0, eq).trim(); - const value = pair.slice(eq + 1).trim(); - if (!name) return; - cookieMap.set(name, value); - setCookieRecord.set(name, value); - }, - }; - - const headers = {}; - for (const [key, value] of Object.entries(input.headers || {})) { - if (key === undefined || key === null) continue; - if (String(key).toLowerCase() === 'cookie') continue; - headers[String(key)] = String(value == null ? '' : value); - } - - const method = String(input.method || 'GET').toUpperCase(); - const fetchOptions = { - method, - browser: input.browser_profile || 'chrome_142', - headers, - cookieJar, - timeout: Math.max(Number(input.request_timeout_ms || 0), 30000), - throwHttpErrors: false, - }; - if (typeof input.body === 'string' && input.body.length > 0) { - fetchOptions.body = input.body; - } - const proxy = String(input.proxy || '').trim(); - if (proxy) fetchOptions.proxy = proxy; - - const result = { status: 0, content_type: '', headers: {}, body: '', set_cookies: [] }; - let response; - try { - response = await fetch(String(input.url || ''), fetchOptions); - } catch (err) { - process.stderr.write((err && err.stack ? err.stack : String(err)) + '\n'); - process.exit(2); - return; - } - - result.status = response.status; - if (response.headers && typeof response.headers.forEach === 'function') { - response.headers.forEach((value, key) => { - const lk = String(key).toLowerCase(); - if (lk === 'set-cookie') return; - result.headers[lk] = String(value); - }); - } - result.content_type = result.headers['content-type'] || ''; - result.body = await response.text(); - result.set_cookies = [...setCookieRecord.entries()].map(([name, value]) => ({ Name: name, Value: value })); - process.stdout.write(JSON.stringify(result)); -})().catch((error) => { - process.stderr.write((error && error.stack ? 
error.stack : String(error)) + '\n'); - process.exit(1); -}); -` + return embeddedNodeWreqLoginHelperScript } diff --git a/internal/app/openai.go b/internal/app/openai.go index 27a3661..e2b5b4f 100644 --- a/internal/app/openai.go +++ b/internal/app/openai.go @@ -103,6 +103,13 @@ func normalizeChatInput(payload map[string]any) (NormalizedInput, error) { if !ok { return NormalizedInput{}, fmt.Errorf("messages must be an array") } + return normalizeChatInputFromParts(rawMessages, payload["attachments"]) +} + +func normalizeChatInputFromParts(rawMessages []any, attachmentsRaw any) (NormalizedInput, error) { + if rawMessages == nil { + return NormalizedInput{}, fmt.Errorf("messages must be an array") + } segments := make([]conversationPromptSegment, 0, len(rawMessages)) hiddenParts := make([]string, 0, len(rawMessages)) attachments := []InputAttachment{} @@ -125,7 +132,7 @@ func normalizeChatInput(payload map[string]any) (NormalizedInput, error) { hiddenParts = append(hiddenParts, hiddenSegments...) attachments = append(attachments, atts...) 
} - extra, err := extractAttachmentsFromAny(payload["attachments"]) + extra, err := extractAttachmentsFromAny(attachmentsRaw) if err != nil { return NormalizedInput{}, err } @@ -221,6 +228,10 @@ func buildConversationTranscriptPrompt(segments []conversationPromptSegment) str } func normalizeResponsesInput(payload map[string]any, previousResponse map[string]any) (NormalizedInput, error) { + return normalizeResponsesInputFromParts(payload["input"], payload["attachments"], previousResponse) +} + +func normalizeResponsesInputFromParts(rawInput any, attachmentsRaw any, previousResponse map[string]any) (NormalizedInput, error) { var ( prompt string hiddenPrompt string @@ -228,7 +239,7 @@ func normalizeResponsesInput(payload map[string]any, previousResponse map[string segments []conversationPromptSegment err error ) - switch x := payload["input"].(type) { + switch x := rawInput.(type) { case string: prompt = strings.TrimSpace(x) segments = appendConversationPromptSegment(segments, "user", prompt) @@ -242,10 +253,10 @@ func normalizeResponsesInput(payload map[string]any, previousResponse map[string return NormalizedInput{}, err } default: - prompt = strings.TrimSpace(flattenContent(payload["input"])) + prompt = strings.TrimSpace(flattenContent(rawInput)) segments = appendConversationPromptSegment(segments, "user", prompt) } - extra, err := extractAttachmentsFromAny(payload["attachments"]) + extra, err := extractAttachmentsFromAny(attachmentsRaw) if err != nil { return NormalizedInput{}, err } diff --git a/internal/app/openai_types.go b/internal/app/openai_types.go new file mode 100644 index 0000000..c0db59a --- /dev/null +++ b/internal/app/openai_types.go @@ -0,0 +1,339 @@ +package app + +import ( + "encoding/json" + "net/http" + "strings" +) + +type chatCompletionsRequestBody struct { + Model string `json:"model,omitempty"` + Stream bool `json:"stream,omitempty"` + ConversationID string `json:"conversation_id,omitempty"` + Conversation string 
`json:"conversation,omitempty"` + ThreadID string `json:"thread_id,omitempty"` + Thread string `json:"thread,omitempty"` + NotionThreadID string `json:"notion_thread_id,omitempty"` + AccountEmail string `json:"account_email,omitempty"` + NotionAccountEmail string `json:"notion_account_email,omitempty"` + UseWebSearch *bool `json:"use_web_search,omitempty"` + Metadata any `json:"metadata,omitempty"` + Tools any `json:"tools,omitempty"` + StreamOptions any `json:"stream_options,omitempty"` + Messages any `json:"messages,omitempty"` + Attachments any `json:"attachments,omitempty"` + StreamIncludeUsage *bool `json:"-"` + Type string `json:"type,omitempty"` + UserName string `json:"user_name,omitempty"` + CharName string `json:"char_name,omitempty"` + GroupNames []string `json:"group_names,omitempty"` + ContinuePrefill string `json:"continue_prefill,omitempty"` + ShowThoughts *bool `json:"show_thoughts,omitempty"` +} + +type responsesRequestBody struct { + Model string `json:"model,omitempty"` + Stream bool `json:"stream,omitempty"` + PreviousResponseID string `json:"previous_response_id,omitempty"` + ConversationID string `json:"conversation_id,omitempty"` + Conversation string `json:"conversation,omitempty"` + ThreadID string `json:"thread_id,omitempty"` + Thread string `json:"thread,omitempty"` + NotionThreadID string `json:"notion_thread_id,omitempty"` + AccountEmail string `json:"account_email,omitempty"` + NotionAccountEmail string `json:"notion_account_email,omitempty"` + UseWebSearch *bool `json:"use_web_search,omitempty"` + Metadata any `json:"metadata,omitempty"` + Tools any `json:"tools,omitempty"` + Input any `json:"input,omitempty"` + Attachments any `json:"attachments,omitempty"` +} + +func trimStringSlice(values []string) []string { + if len(values) == 0 { + return nil + } + out := make([]string, 0, len(values)) + for _, value := range values { + if clean := strings.TrimSpace(value); clean != "" { + out = append(out, clean) + } + } + if len(out) == 0 { + 
return nil + } + return out +} + +func normalizeTypedChatCompletionsRequestBody(body chatCompletionsRequestBody) chatCompletionsRequestBody { + body.Model = strings.TrimSpace(body.Model) + body.ConversationID = strings.TrimSpace(body.ConversationID) + body.Conversation = strings.TrimSpace(body.Conversation) + body.ThreadID = strings.TrimSpace(body.ThreadID) + body.Thread = strings.TrimSpace(body.Thread) + body.NotionThreadID = strings.TrimSpace(body.NotionThreadID) + body.AccountEmail = strings.TrimSpace(body.AccountEmail) + body.NotionAccountEmail = strings.TrimSpace(body.NotionAccountEmail) + body.Type = strings.TrimSpace(body.Type) + body.UserName = strings.TrimSpace(body.UserName) + body.CharName = strings.TrimSpace(body.CharName) + body.ContinuePrefill = strings.TrimSpace(body.ContinuePrefill) + body.GroupNames = trimStringSlice(body.GroupNames) + if value, ok := parseIncludeUsageFromStreamOptionsAny(body.StreamOptions); ok { + copyValue := value + body.StreamIncludeUsage = ©Value + } + return body +} + +func normalizeTypedResponsesRequestBody(body responsesRequestBody) responsesRequestBody { + body.Model = strings.TrimSpace(body.Model) + body.PreviousResponseID = strings.TrimSpace(body.PreviousResponseID) + body.ConversationID = strings.TrimSpace(body.ConversationID) + body.Conversation = strings.TrimSpace(body.Conversation) + body.ThreadID = strings.TrimSpace(body.ThreadID) + body.Thread = strings.TrimSpace(body.Thread) + body.NotionThreadID = strings.TrimSpace(body.NotionThreadID) + body.AccountEmail = strings.TrimSpace(body.AccountEmail) + body.NotionAccountEmail = strings.TrimSpace(body.NotionAccountEmail) + return body +} + +func requestedModelFromTyped(model string, fallback string) string { + modelID := strings.TrimSpace(model) + if modelID == "" { + return fallback + } + return modelID +} + +func extractChatCompletionsRequestBody(payload map[string]any) chatCompletionsRequestBody { + if payload == nil { + return chatCompletionsRequestBody{} + } + body 
:= chatCompletionsRequestBody{ + Model: strings.TrimSpace(stringValue(payload["model"])), + ConversationID: strings.TrimSpace(stringValue(payload["conversation_id"])), + Conversation: strings.TrimSpace(stringValue(payload["conversation"])), + ThreadID: strings.TrimSpace(stringValue(payload["thread_id"])), + Thread: strings.TrimSpace(stringValue(payload["thread"])), + NotionThreadID: strings.TrimSpace(stringValue(payload["notion_thread_id"])), + AccountEmail: strings.TrimSpace(stringValue(payload["account_email"])), + NotionAccountEmail: strings.TrimSpace(stringValue(payload["notion_account_email"])), + Type: strings.TrimSpace(stringValue(payload["type"])), + UserName: strings.TrimSpace(stringValue(payload["user_name"])), + CharName: strings.TrimSpace(stringValue(payload["char_name"])), + ContinuePrefill: strings.TrimSpace(stringValue(payload["continue_prefill"])), + GroupNames: stringSliceValue(payload["group_names"]), + } + body.Stream, _ = payload["stream"].(bool) + if value, ok := parseBoolField(payload["use_web_search"]); ok { + copyValue := value + body.UseWebSearch = ©Value + } + if value, ok := parseBoolField(payload["show_thoughts"]); ok { + copyValue := value + body.ShowThoughts = ©Value + } + body.Metadata = payload["metadata"] + body.Tools = payload["tools"] + body.StreamOptions = payload["stream_options"] + body.Messages = payload["messages"] + body.Attachments = payload["attachments"] + if value, ok := parseIncludeUsageFromStreamOptionsAny(body.StreamOptions); ok { + copyValue := value + body.StreamIncludeUsage = ©Value + } + return normalizeTypedChatCompletionsRequestBody(body) +} + +func extractResponsesRequestBody(payload map[string]any) responsesRequestBody { + if payload == nil { + return responsesRequestBody{} + } + body := responsesRequestBody{ + Model: strings.TrimSpace(stringValue(payload["model"])), + PreviousResponseID: strings.TrimSpace(stringValue(payload["previous_response_id"])), + ConversationID: 
strings.TrimSpace(stringValue(payload["conversation_id"])), + Conversation: strings.TrimSpace(stringValue(payload["conversation"])), + ThreadID: strings.TrimSpace(stringValue(payload["thread_id"])), + Thread: strings.TrimSpace(stringValue(payload["thread"])), + NotionThreadID: strings.TrimSpace(stringValue(payload["notion_thread_id"])), + AccountEmail: strings.TrimSpace(stringValue(payload["account_email"])), + NotionAccountEmail: strings.TrimSpace(stringValue(payload["notion_account_email"])), + } + body.Stream, _ = payload["stream"].(bool) + if value, ok := parseBoolField(payload["use_web_search"]); ok { + copyValue := value + body.UseWebSearch = ©Value + } + body.Metadata = payload["metadata"] + body.Tools = payload["tools"] + body.Input = payload["input"] + body.Attachments = payload["attachments"] + return normalizeTypedResponsesRequestBody(body) +} + +func requestedConversationIDFromTyped(r *http.Request, conversationID string, conversation string, metadata any) string { + if fromHeader := firstRequestValue(r, "X-Conversation-ID", "X-Notion-Conversation-ID"); fromHeader != "" { + return fromHeader + } + if value := strings.TrimSpace(conversationID); value != "" { + return value + } + if value := strings.TrimSpace(conversation); value != "" { + return value + } + return parseStringFieldFromMetadataAny(metadata, "conversation_id", "notion_conversation_id") +} + +func requestedThreadIDFromTyped(r *http.Request, threadID string, thread string, notionThreadID string, metadata any) string { + if fromHeader := firstRequestValue(r, "X-Thread-ID", "X-Notion-Thread-ID"); fromHeader != "" { + return fromHeader + } + for _, value := range []string{threadID, thread, notionThreadID} { + if clean := strings.TrimSpace(value); clean != "" { + return clean + } + } + return parseStringFieldFromMetadataAny(metadata, "thread_id", "notion_thread_id") +} + +func requestedAccountEmailFromTyped(r *http.Request, accountEmail string, notionAccountEmail string, metadata any) string { + 
if fromHeader := firstRequestValue(r, "X-Account-Email", "X-Notion-Account-Email"); fromHeader != "" { + return fromHeader + } + for _, value := range []string{accountEmail, notionAccountEmail} { + if clean := strings.TrimSpace(value); clean != "" { + return clean + } + } + return parseStringFieldFromMetadataAny(metadata, "account_email", "notion_account_email") +} + +func requestedWebSearchFromTyped(useWebSearch *bool, metadata any, tools any, fallback bool) bool { + if useWebSearch != nil { + return *useWebSearch + } + if value, ok := parseWebSearchFromMetadataAny(metadata); ok { + return value + } + if value, ok := parseWebSearchFromToolsAny(tools); ok { + return value + } + return fallback +} + +func parseWebSearchFromMetadataAny(raw any) (bool, bool) { + meta := decodeJSONObjectAny(raw) + if meta == nil { + return false, false + } + for _, key := range []string{"use_web_search", "notion_use_web_search"} { + if value, ok := meta[key]; ok { + if parsed, parsedOK := parseBoolField(value); parsedOK { + return parsed, true + } + } + } + return false, false +} + +func parseStringFieldFromMetadataAny(raw any, keys ...string) string { + meta := decodeJSONObjectAny(raw) + if meta == nil { + return "" + } + for _, key := range keys { + if value := strings.TrimSpace(stringValue(meta[key])); value != "" { + return value + } + } + return "" +} + +func decodeJSONObjectAny(raw any) map[string]any { + if raw == nil { + return nil + } + if meta := mapValue(raw); meta != nil { + return meta + } + var decoded map[string]any + switch value := raw.(type) { + case json.RawMessage: + if err := json.Unmarshal(value, &decoded); err == nil { + return decoded + } + case []byte: + if err := json.Unmarshal(value, &decoded); err == nil { + return decoded + } + case string: + if err := json.Unmarshal([]byte(value), &decoded); err == nil { + return decoded + } + } + return nil +} + +func parseWebSearchFromToolsAny(raw any) (bool, bool) { + if raw == nil { + return false, false + } + 
toolItems := sliceValue(raw) + if len(toolItems) == 0 { + switch value := raw.(type) { + case json.RawMessage: + var decoded []map[string]any + if err := json.Unmarshal(value, &decoded); err == nil { + toolItems = sliceValue(decoded) + } + case []byte: + var decoded []map[string]any + if err := json.Unmarshal(value, &decoded); err == nil { + toolItems = sliceValue(decoded) + } + case string: + var decoded []map[string]any + if err := json.Unmarshal([]byte(value), &decoded); err == nil { + toolItems = sliceValue(decoded) + } + } + } + for _, item := range toolItems { + tool := mapValue(item) + if tool == nil { + continue + } + toolType := strings.TrimSpace(stringValue(tool["type"])) + if strings.Contains(toolType, "web_search") { + return true, true + } + } + return false, false +} + +func parseIncludeUsageFromStreamOptionsAny(raw any) (bool, bool) { + options := decodeJSONObjectAny(raw) + if options == nil { + return false, false + } + return parseBoolField(options["include_usage"]) +} + +func (body chatCompletionsRequestBody) likelySillyTavernByEnvelope() bool { + if strings.TrimSpace(body.Type) != "" { + return true + } + if strings.TrimSpace(body.UserName) != "" && strings.TrimSpace(body.CharName) != "" { + return true + } + if len(body.GroupNames) > 0 { + return true + } + if strings.TrimSpace(body.ContinuePrefill) != "" { + return true + } + return body.ShowThoughts != nil +} diff --git a/internal/app/prompt_guard.go b/internal/app/prompt_guard.go index 0f8ea74..88d0576 100644 --- a/internal/app/prompt_guard.go +++ b/internal/app/prompt_guard.go @@ -96,6 +96,23 @@ func promptGuardProfileChain(cfg AppConfig, hasTools bool) []promptProfile { return chain[:limit] } +func buildPromptGuardAllRetryPrefixes(promptCfg PromptConfig) []string { + totalCap := len(defaultPromptCodingRetryPrefixes()) + + len(defaultPromptGeneralRetryPrefixes()) + + len(defaultPromptDirectAnswerRetryPrefixes()) + + len(promptCfg.CodingRetryPrefixes) + + len(promptCfg.GeneralRetryPrefixes) + 
+ len(promptCfg.DirectAnswerRetryPrefixes) + out := make([]string, 0, totalCap) + out = append(out, defaultPromptCodingRetryPrefixes()...) + out = append(out, defaultPromptGeneralRetryPrefixes()...) + out = append(out, defaultPromptDirectAnswerRetryPrefixes()...) + out = append(out, promptCfg.CodingRetryPrefixes...) + out = append(out, promptCfg.GeneralRetryPrefixes...) + out = append(out, promptCfg.DirectAnswerRetryPrefixes...) + return out +} + func resolvePromptGuardProfile(cfg AppConfig, request PromptRunRequest) (promptProfile, int) { chain := promptGuardProfileChain(cfg, false) if len(chain) == 0 { @@ -175,11 +192,10 @@ func promptGuardPrepareRequest(cfg AppConfig, request PromptRunRequest) PromptRu func promptGuardStripRetryPrefixes(cfg AppConfig, text string) string { current := text - all := append(append([]string{}, defaultPromptCodingRetryPrefixes()...), defaultPromptGeneralRetryPrefixes()...) - all = append(all, defaultPromptDirectAnswerRetryPrefixes()...) - all = append(all, cfg.Prompt.CodingRetryPrefixes...) - all = append(all, cfg.Prompt.GeneralRetryPrefixes...) - all = append(all, cfg.Prompt.DirectAnswerRetryPrefixes...) 
+ all := cfg.Prompt.precomputedAllRetryPrefixes + if len(all) == 0 { + all = buildPromptGuardAllRetryPrefixes(cfg.Prompt) + } matched := true for matched { matched = false @@ -193,13 +209,14 @@ func promptGuardStripRetryPrefixes(cfg AppConfig, text string) string { return current } +var promptGuardCodingRequestPatterns = []*regexp.Regexp{ + regexp.MustCompile(`(?i)\b(code|coding|program|function|class|bug|debug|refactor|api|sdk|javascript|typescript|python|golang|rust|docker|sql|bash|shell|json|yaml|repository|repo|frontend|backend|server|client)\b`), + regexp.MustCompile(`代码|编程|开发|函数|脚本|调试|报错|异常|接口|部署|构建|数据库|仓库|前端|后端|服务端|客户端|测试|日志`), + regexp.MustCompile("```"), +} + func promptGuardLooksLikeCodingRequest(text string) bool { - patterns := []*regexp.Regexp{ - regexp.MustCompile(`(?i)\b(code|coding|program|function|class|bug|debug|refactor|api|sdk|javascript|typescript|python|golang|rust|docker|sql|bash|shell|json|yaml|repository|repo|frontend|backend|server|client)\b`), - regexp.MustCompile(`代码|编程|开发|函数|脚本|调试|报错|异常|接口|部署|构建|数据库|仓库|前端|后端|服务端|客户端|测试|日志`), - regexp.MustCompile("```"), - } - for _, pattern := range patterns { + for _, pattern := range promptGuardCodingRequestPatterns { if pattern.MatchString(text) { return true } diff --git a/internal/app/request_dispatch.go b/internal/app/request_dispatch.go index 68e1cb3..5b364bd 100644 --- a/internal/app/request_dispatch.go +++ b/internal/app/request_dispatch.go @@ -3,9 +3,11 @@ package app import ( "context" "errors" + "expvar" "fmt" "net/http" "strings" + "sync" "time" ) @@ -16,6 +18,88 @@ const ( var errDispatchCapacityExceeded = errors.New("dispatch capacity exceeded") +var wreqClientNewTotalMetric = expvar.NewMap("notion2api_wreq_client_new_total") + +type probeCacheEntry struct { + lastChecked time.Time + lastOK bool +} + +type probeCache struct { + mu sync.Mutex + entries map[string]probeCacheEntry +} + +func newProbeCache() *probeCache { + return &probeCache{ + entries: map[string]probeCacheEntry{}, + } +} + 
+func (c *probeCache) shouldProbe(accountKey string, ttl time.Duration, now time.Time) bool { + if c == nil { + return true + } + if strings.TrimSpace(accountKey) == "" { + return true + } + if ttl <= 0 { + return true + } + c.mu.Lock() + defer c.mu.Unlock() + if c.entries == nil { + c.entries = map[string]probeCacheEntry{} + return true + } + entry, ok := c.entries[accountKey] + if !ok { + return true + } + if !entry.lastOK { + return true + } + return now.Sub(entry.lastChecked) >= ttl +} + +func (c *probeCache) markSuccess(accountKey string, now time.Time) { + if c == nil { + return + } + accountKey = strings.TrimSpace(accountKey) + if accountKey == "" { + return + } + c.mu.Lock() + defer c.mu.Unlock() + if c.entries == nil { + c.entries = map[string]probeCacheEntry{} + } + c.entries[accountKey] = probeCacheEntry{lastChecked: now, lastOK: true} +} + +func (c *probeCache) markFailure(accountKey string) { + if c == nil { + return + } + accountKey = strings.TrimSpace(accountKey) + if accountKey == "" { + return + } + c.mu.Lock() + defer c.mu.Unlock() + delete(c.entries, accountKey) +} + +func (c *probeCache) invalidateAll() { + if c == nil { + return + } + c.mu.Lock() + defer c.mu.Unlock() + c.entries = map[string]probeCacheEntry{} +} + func requestTimeout(cfg AppConfig) time.Duration { return time.Duration(maxInt(cfg.TimeoutSec, 10)) * time.Second } @@ -40,7 +124,7 @@ func mergeDispatchCandidates(preferred *NotionAccount, candidates []NotionAccoun out := make([]NotionAccount, 0, len(candidates)+1) seen := map[string]struct{}{} appendCandidate := func(account NotionAccount) { - key := canonicalEmailKey(account.Email) + key := getAccountEmailKey(account) if key == "" { return } @@ -60,8 +144,19 @@ func mergeDispatchCandidates(preferred *NotionAccount, candidates []NotionAccoun } func resolveDispatchCandidates(cfg AppConfig, request PromptRunRequest, now time.Time) ([]NotionAccount, error) { + poolCandidates := buildDispatchCandidateOrder(cfg, now) + return 
resolveDispatchCandidatesWithPool(cfg, poolCandidates, request, now) +} + +func resolveDispatchCandidatesFromSnapshot(bundle *snapshotBundle, request PromptRunRequest, now time.Time) ([]NotionAccount, error) { + if bundle == nil { + return nil, noEligibleAccountsError() + } + return resolveDispatchCandidatesWithPool(bundle.Config, pickDispatchCandidatesFromSnapshot(bundle, now), request, now) +} + +func resolveDispatchCandidatesWithPool(cfg AppConfig, poolCandidates []NotionAccount, request PromptRunRequest, now time.Time) ([]NotionAccount, error) { pinnedEmail := strings.TrimSpace(request.PinnedAccountEmail) - poolCandidates := pickDispatchCandidates(cfg, now) if pinnedEmail == "" { if len(poolCandidates) == 0 { return nil, noEligibleAccountsError() @@ -116,6 +211,14 @@ func dispatchProtocolProbeTimeout(cfg AppConfig) time.Duration { return time.Duration(seconds) * time.Second } +func dispatchProbeCacheTTL(cfg AppConfig) time.Duration { + seconds := cfg.Dispatch.ProbeCacheTTLSeconds + if seconds <= 0 { + return 0 + } + return time.Duration(seconds) * time.Second +} + func isDispatchContextAbort(ctx context.Context, err error) bool { if err == nil { return false @@ -126,17 +229,82 @@ func isDispatchContextAbort(ctx context.Context, err error) bool { return ctx != nil && ctx.Err() != nil } -func (a *App) probeAccountProtocolHealth(ctx context.Context, cfg AppConfig, session SessionInfo) error { +func (a *App) shouldProbeAccountProtocolHealth(accountKey string, ttl time.Duration, now time.Time) bool { + if a == nil { + return true + } + if a.State == nil || a.State.DispatchProbeCache == nil { + return true + } + return a.State.DispatchProbeCache.shouldProbe(accountKey, ttl, now) +} + +func (a *App) markAccountProtocolProbeSuccess(accountKey string, now time.Time) { + if a == nil { + return + } + if a.State == nil || a.State.DispatchProbeCache == nil { + return + } + a.State.DispatchProbeCache.markSuccess(accountKey, now) +} + +func (a *App) 
markAccountProtocolProbeFailure(accountKey string) { + if a == nil { + return + } + if a.State == nil || a.State.DispatchProbeCache == nil { + return + } + a.State.DispatchProbeCache.markFailure(accountKey) +} + +func (a *App) invalidateDispatchProbeCache() { + if a == nil { + return + } + if a.State == nil || a.State.DispatchProbeCache == nil { + return + } + a.State.DispatchProbeCache.invalidateAll() +} + +func (a *App) probeAccountProtocolHealth(ctx context.Context, cfg AppConfig, session SessionInfo, accountEmail string) error { + accountKey := canonicalEmailKey(accountEmail) + if accountKey == "" { + accountKey = canonicalEmailKey(session.UserEmail) + } + now := time.Now() + ttl := dispatchProbeCacheTTL(cfg) + if !a.shouldProbeAccountProtocolHealth(accountKey, ttl, now) { + return nil + } if a.accountProtocolProbeOverride != nil { - return a.accountProtocolProbeOverride(ctx, cfg, session) + err := a.accountProtocolProbeOverride(ctx, cfg, session) + if err == nil { + a.markAccountProtocolProbeSuccess(accountKey, now) + return nil + } + if isDispatchContextAbort(ctx, err) { + a.markAccountProtocolProbeSuccess(accountKey, now) + return nil + } + a.markAccountProtocolProbeFailure(accountKey) + return err } probeCtx, cancel := context.WithTimeout(ctx, dispatchProtocolProbeTimeout(cfg)) defer cancel() client := newNotionAIClient(session, cfg, "") _, err := client.listInferenceTranscripts(probeCtx) if isDispatchContextAbort(probeCtx, err) { + a.markAccountProtocolProbeSuccess(accountKey, now) return nil } + if err != nil { + a.markAccountProtocolProbeFailure(accountKey) + return err + } + a.markAccountProtocolProbeSuccess(accountKey, now) return err } @@ -145,7 +313,7 @@ func (a *App) loadReadyDispatchSession(ctx context.Context, cfg AppConfig, accou if err != nil { return SessionInfo{}, err } - if err := a.probeAccountProtocolHealth(ctx, cfg, session); err != nil { + if err := a.probeAccountProtocolHealth(ctx, cfg, session, account.Email); err != nil { return 
SessionInfo{}, err } return session, nil @@ -183,7 +351,7 @@ func (a *App) runPromptActiveFallback(r *http.Request, request PromptRunRequest, if err != nil { return InferenceResult{}, err } - if err := a.probeAccountProtocolHealth(ctx, cfg, session); err != nil { + if err := a.probeAccountProtocolHealth(ctx, cfg, session, ""); err != nil { return InferenceResult{}, err } @@ -204,9 +372,10 @@ func (a *App) runPromptActiveFallback(r *http.Request, request PromptRunRequest, } if cfg.ResolveSessionRefresh().RetryOnAuthError && isSessionRetryableError(err) && !emittedAny { if refreshErr := a.State.RefreshSession(ctx, "prompt_retry_fallback"); refreshErr == nil { + a.invalidateDispatchProbeCache() _, refreshed, _ := a.State.Snapshot() if strings.TrimSpace(refreshed.UserID) != "" && strings.TrimSpace(refreshed.SpaceID) != "" && len(refreshed.Cookies) > 0 { - if probeErr := a.probeAccountProtocolHealth(ctx, cfg, refreshed); probeErr != nil { + if probeErr := a.probeAccountProtocolHealth(ctx, cfg, refreshed, ""); probeErr != nil { return InferenceResult{}, probeErr } return a.runPromptWithSession(ctx, cfg, refreshed, "", request, wrappedDelta) @@ -226,7 +395,7 @@ func (a *App) runPromptActiveFallbackWithSink(r *http.Request, request PromptRun if err != nil { return InferenceResult{}, err } - if err := a.probeAccountProtocolHealth(ctx, cfg, session); err != nil { + if err := a.probeAccountProtocolHealth(ctx, cfg, session, ""); err != nil { return InferenceResult{}, err } @@ -261,9 +430,10 @@ func (a *App) runPromptActiveFallbackWithSink(r *http.Request, request PromptRun } if cfg.ResolveSessionRefresh().RetryOnAuthError && isSessionRetryableError(err) && !emittedAny { if refreshErr := a.State.RefreshSession(ctx, "prompt_retry_fallback"); refreshErr == nil { + a.invalidateDispatchProbeCache() _, refreshed, _ := a.State.Snapshot() if strings.TrimSpace(refreshed.UserID) != "" && strings.TrimSpace(refreshed.SpaceID) != "" && len(refreshed.Cookies) > 0 { - if probeErr := 
a.probeAccountProtocolHealth(ctx, cfg, refreshed); probeErr != nil { + if probeErr := a.probeAccountProtocolHealth(ctx, cfg, refreshed, ""); probeErr != nil { return InferenceResult{}, probeErr } return a.runPromptWithSessionWithSink(ctx, cfg, refreshed, "", request, InferenceStreamSink{ @@ -292,7 +462,17 @@ func (a *App) runPromptWithAccountPool(r *http.Request, request PromptRunRequest defer cancel() now := time.Now() - candidates, err := resolveDispatchCandidates(cfg, request, now) + var candidates []NotionAccount + var err error + if a != nil && a.State != nil { + if snap := a.State.snap.Load(); snap != nil { + candidates, err = resolveDispatchCandidatesFromSnapshot(snap, request, now) + } else { + candidates, err = resolveDispatchCandidates(cfg, request, now) + } + } else { + candidates, err = resolveDispatchCandidates(cfg, request, now) + } if err != nil { return InferenceResult{}, err } @@ -359,6 +539,7 @@ func (a *App) runPromptWithAccountPool(r *http.Request, request PromptRunRequest refreshedCfg, refreshErr := a.State.tryRefreshAccount(ctx, cfg, account) if refreshErr == nil { if saveErr := a.State.SaveAndApply(refreshedCfg); saveErr == nil { + a.invalidateDispatchProbeCache() cfg = refreshedCfg refreshedAccount, _, ok := cfg.FindAccount(account.Email) if ok { @@ -445,7 +626,17 @@ func (a *App) runPromptWithAccountPoolWithSink(r *http.Request, request PromptRu defer cancel() now := time.Now() - candidates, err := resolveDispatchCandidates(cfg, request, now) + var candidates []NotionAccount + var err error + if a != nil && a.State != nil { + if snap := a.State.snap.Load(); snap != nil { + candidates, err = resolveDispatchCandidatesFromSnapshot(snap, request, now) + } else { + candidates, err = resolveDispatchCandidates(cfg, request, now) + } + } else { + candidates, err = resolveDispatchCandidates(cfg, request, now) + } if err != nil { return InferenceResult{}, err } @@ -526,6 +717,7 @@ func (a *App) runPromptWithAccountPoolWithSink(r *http.Request, 
request PromptRu refreshedCfg, refreshErr := a.State.tryRefreshAccount(ctx, cfg, account) if refreshErr == nil { if saveErr := a.State.SaveAndApply(refreshedCfg); saveErr == nil { + a.invalidateDispatchProbeCache() cfg = refreshedCfg if refreshedAccount, _, ok := cfg.FindAccount(account.Email); ok { refreshedSession, loadErr := a.loadReadyDispatchSession(ctx, cfg, refreshedAccount) diff --git a/internal/app/response_store.go b/internal/app/response_store.go new file mode 100644 index 0000000..0209cd7 --- /dev/null +++ b/internal/app/response_store.go @@ -0,0 +1,217 @@ +package app + +import ( + "container/heap" + "strings" + "time" +) + +const responseStoreCleanupInterval = 30 * time.Second + +type responseExpiryEntry struct { + responseID string + createdAt time.Time +} + +type responseExpiryHeap []responseExpiryEntry + +func (h responseExpiryHeap) Len() int { + return len(h) +} + +func (h responseExpiryHeap) Less(i, j int) bool { + return h[i].createdAt.Before(h[j].createdAt) +} + +func (h responseExpiryHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +func (h *responseExpiryHeap) Push(x any) { + entry, _ := x.(responseExpiryEntry) + *h = append(*h, entry) +} + +func (h *responseExpiryHeap) Pop() any { + if h == nil || len(*h) == 0 { + return responseExpiryEntry{} + } + old := *h + last := old[len(old)-1] + *h = old[:len(old)-1] + return last +} + +type responseStore struct { + ttl time.Duration + items map[string]StoredResponse + expirations responseExpiryHeap +} + +var testHookResponseStorePrunePop func() + +func normalizeResponseStoreTTL(ttl time.Duration) time.Duration { + if ttl <= 0 { + return time.Second + } + return ttl +} + +func newResponseStore(ttl time.Duration) *responseStore { + store := &responseStore{ + ttl: normalizeResponseStoreTTL(ttl), + items: map[string]StoredResponse{}, + expirations: responseExpiryHeap{}, + } + heap.Init(&store.expirations) + return store +} + +func (s *responseStore) setTTL(ttl time.Duration) { + if s == nil { + 
return + } + s.ttl = normalizeResponseStoreTTL(ttl) +} + +func (s *responseStore) ensureInitialized() { + if s == nil { + return + } + if s.items == nil { + s.items = map[string]StoredResponse{} + } + if s.expirations == nil { + s.expirations = responseExpiryHeap{} + heap.Init(&s.expirations) + } +} + +func (s *responseStore) save(responseID string, record StoredResponse, now time.Time) { + if s == nil { + return + } + responseID = strings.TrimSpace(responseID) + if responseID == "" { + return + } + s.ensureInitialized() + now = now.UTC() + s.pruneExpired(now) + + createdAt := record.CreatedAt.UTC() + if createdAt.IsZero() { + createdAt = now + } + record.CreatedAt = createdAt + record.ConversationID = strings.TrimSpace(record.ConversationID) + record.ThreadID = strings.TrimSpace(record.ThreadID) + record.AccountEmail = strings.TrimSpace(record.AccountEmail) + + s.items[responseID] = record + heap.Push(&s.expirations, responseExpiryEntry{ + responseID: responseID, + createdAt: createdAt, + }) +} + +func (s *responseStore) get(responseID string, now time.Time) (StoredResponse, bool) { + if s == nil { + return StoredResponse{}, false + } + responseID = strings.TrimSpace(responseID) + if responseID == "" { + return StoredResponse{}, false + } + s.ensureInitialized() + now = now.UTC() + s.pruneExpired(now) + record, ok := s.items[responseID] + if !ok { + return StoredResponse{}, false + } + if now.Sub(record.CreatedAt) > s.ttl { + delete(s.items, responseID) + return StoredResponse{}, false + } + return record, true +} + +func (s *responseStore) replaceAll(records map[string]StoredResponse) { + if s == nil { + return + } + s.ensureInitialized() + s.items = map[string]StoredResponse{} + s.expirations = responseExpiryHeap{} + heap.Init(&s.expirations) + for responseID, record := range records { + cleanID := strings.TrimSpace(responseID) + if cleanID == "" { + continue + } + createdAt := record.CreatedAt.UTC() + record.CreatedAt = createdAt + record.ConversationID = 
strings.TrimSpace(record.ConversationID) + record.ThreadID = strings.TrimSpace(record.ThreadID) + record.AccountEmail = strings.TrimSpace(record.AccountEmail) + s.items[cleanID] = record + heap.Push(&s.expirations, responseExpiryEntry{ + responseID: cleanID, + createdAt: createdAt, + }) + } +} + +func (s *responseStore) pruneExpired(now time.Time) int { + if s == nil { + return 0 + } + s.ensureInitialized() + if len(s.items) == 0 || len(s.expirations) == 0 { + return 0 + } + now = now.UTC() + removed := 0 + for len(s.expirations) > 0 { + top := s.expirations[0] + if now.Sub(top.createdAt) <= s.ttl { + break + } + entry, _ := heap.Pop(&s.expirations).(responseExpiryEntry) + if testHookResponseStorePrunePop != nil { + testHookResponseStorePrunePop() + } + current, ok := s.items[entry.responseID] + if !ok { + continue + } + if !current.CreatedAt.UTC().Equal(entry.createdAt) { + continue + } + delete(s.items, entry.responseID) + removed++ + } + return removed +} + +func (s *responseStore) deleteByConversationOrThread(conversationID string, threadID string) int { + if s == nil { + return 0 + } + conversationID = strings.TrimSpace(conversationID) + threadID = strings.TrimSpace(threadID) + if conversationID == "" && threadID == "" { + return 0 + } + s.ensureInitialized() + removed := 0 + for responseID, record := range s.items { + if (conversationID != "" && strings.TrimSpace(record.ConversationID) == conversationID) || + (threadID != "" && strings.TrimSpace(record.ThreadID) == threadID) { + delete(s.items, responseID) + removed++ + } + } + return removed +} diff --git a/internal/app/session_refresh.go b/internal/app/session_refresh.go index 85c0b2c..770fb36 100644 --- a/internal/app/session_refresh.go +++ b/internal/app/session_refresh.go @@ -9,6 +9,11 @@ import ( "time" ) +var ( + testHookTryRefreshAccount func(context.Context, AppConfig, NotionAccount) (AppConfig, error) + testHookSaveAndApply func(*ServerState, AppConfig) error +) + func sessionRefreshNowISO() string 
{ return time.Now().Format(time.RFC3339) } @@ -278,45 +283,62 @@ func (s *ServerState) RefreshSession(ctx context.Context, reason string) error { return fmt.Errorf("no active account configured for session refresh") } - updatedCfg, err := s.tryRefreshAccount(ctx, cfg, account) + tryRefresh := s.tryRefreshAccount + if testHookTryRefreshAccount != nil { + tryRefresh = testHookTryRefreshAccount + } + saveAndApply := s.SaveAndApply + if testHookSaveAndApply != nil { + saveAndApply = func(cfg AppConfig) error { + return testHookSaveAndApply(s, cfg) + } + } + + updatedCfg, err := tryRefresh(ctx, cfg, account) if err == nil { - if saveErr := s.SaveAndApply(updatedCfg); saveErr != nil { + if saveErr := saveAndApply(updatedCfg); saveErr != nil { s.setSessionRefreshRuntime(saveErr) return saveErr } + if s.DispatchProbeCache != nil { + s.DispatchProbeCache.invalidateAll() + } s.setSessionRefreshRuntime(nil) return nil } if !refreshCfg.AutoSwitch { s.setSessionRefreshRuntime(err) - _ = s.SaveAndApply(updatedCfg) + _ = saveAndApply(updatedCfg) return fmt.Errorf("refresh active account %s failed (%s): %w", account.Email, reason, err) } lastErr := err for _, candidate := range cfg.Accounts { - if canonicalEmailKey(candidate.Email) == canonicalEmailKey(account.Email) { + if getAccountEmailKey(candidate) == getAccountEmailKey(account) { continue } if !fileExists(ensureAccountPaths(cfg, candidate).ProbeJSON) { continue } - nextCfg, nextErr := s.tryRefreshAccount(ctx, updatedCfg, candidate) + nextCfg, nextErr := tryRefresh(ctx, updatedCfg, candidate) if nextErr != nil { lastErr = nextErr updatedCfg = nextCfg continue } - if saveErr := s.SaveAndApply(nextCfg); saveErr != nil { + if saveErr := saveAndApply(nextCfg); saveErr != nil { s.setSessionRefreshRuntime(saveErr) return saveErr } + if s.DispatchProbeCache != nil { + s.DispatchProbeCache.invalidateAll() + } s.setSessionRefreshRuntime(nil) return nil } - _ = s.SaveAndApply(updatedCfg) + _ = saveAndApply(updatedCfg) 
s.setSessionRefreshRuntime(lastErr) return fmt.Errorf("session refresh failed after trying active account and fallbacks (%s): %w", reason, lastErr) } diff --git a/internal/app/sqlite_store.go b/internal/app/sqlite_store.go index de3f91d..6dfb596 100644 --- a/internal/app/sqlite_store.go +++ b/internal/app/sqlite_store.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "path/filepath" + "runtime" "strings" "time" @@ -14,9 +15,18 @@ import ( type SQLiteStore struct { db *sql.DB + roDB *sql.DB path string } +func observeSQLiteDuration(op string, startedAt time.Time) { + if startedAt.IsZero() { + return + } + elapsed := time.Since(startedAt) + observeSQLiteOpDuration(op, elapsed) +} + func openSQLiteStore(cfg AppConfig) (*SQLiteStore, error) { path := strings.TrimSpace(cfg.ResolveSQLitePath()) if path == "" { @@ -33,11 +43,21 @@ func openSQLiteStore(cfg AppConfig) (*SQLiteStore, error) { return nil, fmt.Errorf("open sqlite: %w", err) } db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) store := &SQLiteStore{db: db, path: path} if err := store.init(); err != nil { _ = db.Close() return nil, err } + roDB, err := sql.Open("sqlite", fmt.Sprintf("file:%s?mode=ro&_journal=WAL", path)) + if err != nil { + _ = db.Close() + return nil, fmt.Errorf("open sqlite read-only: %w", err) + } + readers := maxInt(2, runtime.NumCPU()) + roDB.SetMaxOpenConns(readers) + roDB.SetMaxIdleConns(readers) + store.roDB = roDB return store, nil } @@ -49,10 +69,21 @@ func (s *SQLiteStore) Path() string { } func (s *SQLiteStore) Close() error { - if s == nil || s.db == nil { + if s == nil { return nil } - return s.db.Close() + var closeErr error + if s.roDB != nil { + if err := s.roDB.Close(); err != nil && closeErr == nil { + closeErr = err + } + } + if s.db != nil { + if err := s.db.Close(); err != nil && closeErr == nil { + closeErr = err + } + } + return closeErr } func (s *SQLiteStore) init() error { @@ -64,6 +95,10 @@ func (s *SQLiteStore) init() error { `PRAGMA busy_timeout=5000;`, `PRAGMA 
synchronous=NORMAL;`, `PRAGMA foreign_keys=ON;`, + `PRAGMA mmap_size=268435456;`, + `PRAGMA cache_size=-65536;`, + `PRAGMA temp_store=MEMORY;`, + `PRAGMA wal_autocheckpoint=1000;`, } for _, stmt := range pragmas { if _, err := s.db.Exec(stmt); err != nil { @@ -164,7 +199,19 @@ func (s *SQLiteStore) init() error { return nil } +func (s *SQLiteStore) readDB() *sql.DB { + if s == nil { + return nil + } + if s.roDB != nil { + return s.roDB + } + return s.db +} + func (s *SQLiteStore) SaveAccounts(cfg AppConfig) error { + startedAt := time.Now() + defer observeSQLiteDuration("save_accounts", startedAt) if s == nil || s.db == nil { return nil } @@ -190,7 +237,7 @@ func (s *SQLiteStore) SaveAccounts(cfg AppConfig) error { return err } active := 0 - if canonicalEmailKey(account.Email) == activeKey { + if getAccountEmailKey(account) == activeKey { active = 1 } if _, err = tx.Exec( @@ -211,10 +258,13 @@ func (s *SQLiteStore) SaveAccounts(cfg AppConfig) error { } func (s *SQLiteStore) LoadAccounts() ([]NotionAccount, string, bool, error) { - if s == nil || s.db == nil { + startedAt := time.Now() + defer observeSQLiteDuration("load_accounts", startedAt) + db := s.readDB() + if db == nil { return nil, "", false, nil } - rows, err := s.db.Query(`SELECT data_json, active FROM accounts ORDER BY position ASC, email ASC`) + rows, err := db.Query(`SELECT data_json, active FROM accounts ORDER BY position ASC, email ASC`) if err != nil { return nil, "", false, err } @@ -243,6 +293,8 @@ func (s *SQLiteStore) LoadAccounts() ([]NotionAccount, string, bool, error) { } func (s *SQLiteStore) SaveConversation(entry ConversationEntry) error { + startedAt := time.Now() + defer observeSQLiteDuration("save_conversation", startedAt) if s == nil || s.db == nil { return nil } @@ -268,6 +320,8 @@ func (s *SQLiteStore) SaveConversation(entry ConversationEntry) error { } func (s *SQLiteStore) DeleteConversation(id string) error { + startedAt := time.Now() + defer 
observeSQLiteDuration("delete_conversation", startedAt) if s == nil || s.db == nil || strings.TrimSpace(id) == "" { return nil } @@ -276,6 +330,8 @@ func (s *SQLiteStore) DeleteConversation(id string) error { } func (s *SQLiteStore) DeleteResponsesByConversationOrThread(conversationID string, threadID string) error { + startedAt := time.Now() + defer observeSQLiteDuration("delete_responses_by_conversation_or_thread", startedAt) if s == nil || s.db == nil { return nil } @@ -297,10 +353,13 @@ func (s *SQLiteStore) DeleteResponsesByConversationOrThread(conversationID strin } func (s *SQLiteStore) LoadConversations() ([]ConversationEntry, error) { - if s == nil || s.db == nil { + startedAt := time.Now() + defer observeSQLiteDuration("load_conversations", startedAt) + db := s.readDB() + if db == nil { return nil, nil } - rows, err := s.db.Query(`SELECT data_json FROM conversations ORDER BY updated_at DESC, created_at DESC LIMIT ?`, maxConversationEntries) + rows, err := db.Query(`SELECT data_json FROM conversations ORDER BY updated_at DESC, created_at DESC LIMIT ?`, maxConversationEntries) if err != nil { return nil, err } @@ -324,6 +383,8 @@ func (s *SQLiteStore) LoadConversations() ([]ConversationEntry, error) { } func (s *SQLiteStore) SaveResponse(responseID string, payload map[string]any, createdAt time.Time, conversationID string, threadID string, accountEmail string) error { + startedAt := time.Now() + defer observeSQLiteDuration("save_response", startedAt) if s == nil || s.db == nil || strings.TrimSpace(responseID) == "" { return nil } @@ -351,6 +412,8 @@ func (s *SQLiteStore) SaveResponse(responseID string, payload map[string]any, cr } func (s *SQLiteStore) DeleteExpiredResponses(ttl time.Duration) error { + startedAt := time.Now() + defer observeSQLiteDuration("delete_expired_responses", startedAt) if s == nil || s.db == nil || ttl <= 0 { return nil } @@ -360,13 +423,16 @@ func (s *SQLiteStore) DeleteExpiredResponses(ttl time.Duration) error { } func (s 
*SQLiteStore) LoadResponses(ttl time.Duration) (map[string]StoredResponse, error) { - if s == nil || s.db == nil { + startedAt := time.Now() + defer observeSQLiteDuration("load_responses", startedAt) + db := s.readDB() + if db == nil { return map[string]StoredResponse{}, nil } if err := s.DeleteExpiredResponses(ttl); err != nil { return nil, err } - rows, err := s.db.Query(`SELECT response_id, created_at, payload_json, conversation_id, thread_id, account_email FROM responses ORDER BY created_at DESC`) + rows, err := db.Query(`SELECT response_id, created_at, payload_json, conversation_id, thread_id, account_email FROM responses ORDER BY created_at DESC`) if err != nil { return nil, err } @@ -405,6 +471,8 @@ func (s *SQLiteStore) LoadResponses(ttl time.Duration) (map[string]StoredRespons } func (s *SQLiteStore) SaveConversationSession(session ConversationSession) error { + startedAt := time.Now() + defer observeSQLiteDuration("save_conversation_session", startedAt) if s == nil || s.db == nil || strings.TrimSpace(session.ID) == "" { return nil } @@ -451,6 +519,8 @@ func (s *SQLiteStore) SaveConversationSession(session ConversationSession) error } func (s *SQLiteStore) SaveConversationSessionStep(step ConversationSessionStep) error { + startedAt := time.Now() + defer observeSQLiteDuration("save_conversation_session_step", startedAt) if s == nil || s.db == nil || strings.TrimSpace(step.SessionID) == "" || strings.TrimSpace(step.UpdatedConfigID) == "" { return nil } @@ -489,10 +559,13 @@ func (s *SQLiteStore) LoadConversationSessionBySessionID(sessionID string) (Conv } func (s *SQLiteStore) loadConversationSession(query string, arg string) (ConversationSession, bool, error) { - if s == nil || s.db == nil || strings.TrimSpace(arg) == "" { + startedAt := time.Now() + defer observeSQLiteDuration("load_conversation_session", startedAt) + db := s.readDB() + if db == nil || strings.TrimSpace(arg) == "" { return ConversationSession{}, false, nil } - row := s.db.QueryRow(query, 
arg) + row := db.QueryRow(query, arg) var ( session ConversationSession createdAtText, updatedAtText, lastUsedAtText, deletedAtText string @@ -543,10 +616,13 @@ func (s *SQLiteStore) loadConversationSession(query string, arg string) (Convers } func (s *SQLiteStore) LoadConversationSessionStepIDs(sessionID string) ([]string, error) { - if s == nil || s.db == nil || strings.TrimSpace(sessionID) == "" { + startedAt := time.Now() + defer observeSQLiteDuration("load_conversation_session_step_ids", startedAt) + db := s.readDB() + if db == nil || strings.TrimSpace(sessionID) == "" { return nil, nil } - rows, err := s.db.Query(`SELECT updated_config_id FROM conversation_session_steps WHERE session_id = ? ORDER BY step_index ASC`, strings.TrimSpace(sessionID)) + rows, err := db.Query(`SELECT updated_config_id FROM conversation_session_steps WHERE session_id = ? ORDER BY step_index ASC`, strings.TrimSpace(sessionID)) if err != nil { return nil, err } @@ -563,6 +639,8 @@ func (s *SQLiteStore) LoadConversationSessionStepIDs(sessionID string) ([]string } func (s *SQLiteStore) MarkConversationSessionStatus(sessionID string, status string) error { + startedAt := time.Now() + defer observeSQLiteDuration("mark_conversation_session_status", startedAt) if s == nil || s.db == nil || strings.TrimSpace(sessionID) == "" { return nil } @@ -581,6 +659,8 @@ func (s *SQLiteStore) MarkConversationSessionStatus(sessionID string, status str } func (s *SQLiteStore) DeleteConversationSessionByConversationOrThread(conversationID string, threadID string) error { + startedAt := time.Now() + defer observeSQLiteDuration("delete_conversation_session_by_conversation_or_thread", startedAt) if s == nil || s.db == nil { return nil } @@ -644,6 +724,8 @@ func (s *SQLiteStore) DeleteConversationSessionByConversationOrThread(conversati } func (s *SQLiteStore) SaveSillyTavernBinding(binding SillyTavernBinding) error { + startedAt := time.Now() + defer observeSQLiteDuration("save_sillytavern_binding", startedAt) 
if s == nil || s.db == nil || strings.TrimSpace(binding.ConversationID) == "" { return nil } @@ -676,13 +758,16 @@ func (s *SQLiteStore) SaveSillyTavernBinding(binding SillyTavernBinding) error { } func (s *SQLiteStore) LoadRecentSillyTavernBindings(profileKey string, limit int) ([]SillyTavernBinding, error) { - if s == nil || s.db == nil || strings.TrimSpace(profileKey) == "" { + startedAt := time.Now() + defer observeSQLiteDuration("load_recent_sillytavern_bindings", startedAt) + db := s.readDB() + if db == nil || strings.TrimSpace(profileKey) == "" { return nil, nil } if limit <= 0 { limit = 12 } - rows, err := s.db.Query( + rows, err := db.Query( `SELECT conversation_id, profile_key, thread_id, account_email, mode, transcript_json, raw_message_count, updated_at FROM sillytavern_bindings WHERE profile_key = ? @@ -729,6 +814,8 @@ func (s *SQLiteStore) LoadRecentSillyTavernBindings(profileKey string, limit int } func (s *SQLiteStore) DeleteSillyTavernBinding(conversationID string) error { + startedAt := time.Now() + defer observeSQLiteDuration("delete_sillytavern_binding", startedAt) if s == nil || s.db == nil || strings.TrimSpace(conversationID) == "" { return nil } diff --git a/internal/app/sqlite_writer.go b/internal/app/sqlite_writer.go new file mode 100644 index 0000000..c6d28ef --- /dev/null +++ b/internal/app/sqlite_writer.go @@ -0,0 +1,203 @@ +package app + +import ( + "expvar" + "log" + "strings" + "sync" + "sync/atomic" + "time" +) + +const defaultSQLiteWriterQueueSize = 1024 + +var sqliteWriterFallbackTotalMetric = expvar.NewMap("notion2api_sqlite_writer_fallback_total") + +type sqlitePersistOpKind uint8 + +const ( + sqlitePersistOpSaveResponse sqlitePersistOpKind = iota + 1 + sqlitePersistOpDeleteResponsesByConversationOrThread +) + +type sqlitePersistOp struct { + kind sqlitePersistOpKind + responseID string + payload map[string]any + createdAt time.Time + conversationID string + threadID string + accountEmail string +} + +type SQLiteWriter struct { + 
store *SQLiteStore + queue chan sqlitePersistOp + done chan struct{} + ttlNanos atomic.Int64 + + mu sync.RWMutex + closed bool +} + +func newSQLiteWriter(store *SQLiteStore, ttl time.Duration) *SQLiteWriter { + if store == nil { + return nil + } + writer := &SQLiteWriter{ + store: store, + queue: make(chan sqlitePersistOp, defaultSQLiteWriterQueueSize), + done: make(chan struct{}), + } + writer.SetTTL(ttl) + go writer.run() + return writer +} + +func (w *SQLiteWriter) SetTTL(ttl time.Duration) { + if w == nil { + return + } + if ttl <= 0 { + ttl = time.Second + } + w.ttlNanos.Store(int64(ttl)) +} + +func (w *SQLiteWriter) EnqueueSaveResponse(responseID string, payload map[string]any, createdAt time.Time, conversationID string, threadID string, accountEmail string) { + if w == nil || w.store == nil { + return + } + responseID = strings.TrimSpace(responseID) + if responseID == "" { + return + } + op := sqlitePersistOp{ + kind: sqlitePersistOpSaveResponse, + responseID: responseID, + payload: clonePersistPayload(payload), + createdAt: createdAt, + conversationID: strings.TrimSpace(conversationID), + threadID: strings.TrimSpace(threadID), + accountEmail: strings.TrimSpace(accountEmail), + } + if w.tryEnqueue(op) { + return + } + sqliteWriterFallbackTotalMetric.Add("channel_full", 1) + w.apply(op) +} + +func (w *SQLiteWriter) EnqueueDeleteResponsesByConversationOrThread(conversationID string, threadID string) { + if w == nil || w.store == nil { + return + } + conversationID = strings.TrimSpace(conversationID) + threadID = strings.TrimSpace(threadID) + if conversationID == "" && threadID == "" { + return + } + op := sqlitePersistOp{ + kind: sqlitePersistOpDeleteResponsesByConversationOrThread, + conversationID: conversationID, + threadID: threadID, + } + if w.enqueueBlocking(op) { + return + } + sqliteWriterFallbackTotalMetric.Add("writer_unavailable", 1) + w.apply(op) +} + +func (w *SQLiteWriter) Close() { + if w == nil { + return + } + w.mu.Lock() + if w.closed { + 
w.mu.Unlock() + return + } + w.closed = true + close(w.queue) + w.mu.Unlock() + <-w.done +} + +func (w *SQLiteWriter) tryEnqueue(op sqlitePersistOp) bool { + w.mu.RLock() + defer w.mu.RUnlock() + if w.closed { + return false + } + select { + case w.queue <- op: + return true + default: + return false + } +} + +func (w *SQLiteWriter) enqueueBlocking(op sqlitePersistOp) bool { + w.mu.RLock() + defer w.mu.RUnlock() + if w.closed { + return false + } + select { + case w.queue <- op: + return true + default: + w.queue <- op + return true + } +} + +func (w *SQLiteWriter) run() { + defer close(w.done) + for op := range w.queue { + w.apply(op) + } +} + +func (w *SQLiteWriter) apply(op sqlitePersistOp) { + if w == nil || w.store == nil { + return + } + switch op.kind { + case sqlitePersistOpSaveResponse: + if err := w.store.SaveResponse(op.responseID, op.payload, op.createdAt, op.conversationID, op.threadID, op.accountEmail); err != nil { + log.Printf("[sqlite-writer] save response %s failed: %v", op.responseID, err) + return + } + if err := w.store.DeleteExpiredResponses(w.ttl()); err != nil { + log.Printf("[sqlite-writer] cleanup responses failed: %v", err) + } + case sqlitePersistOpDeleteResponsesByConversationOrThread: + if err := w.store.DeleteResponsesByConversationOrThread(op.conversationID, op.threadID); err != nil { + log.Printf("[sqlite-writer] delete responses conversation=%s thread=%s failed: %v", op.conversationID, op.threadID, err) + } + } +} + +func (w *SQLiteWriter) ttl() time.Duration { + if w == nil { + return time.Second + } + ttlNanos := w.ttlNanos.Load() + if ttlNanos <= 0 { + return time.Second + } + return time.Duration(ttlNanos) +} + +func clonePersistPayload(src map[string]any) map[string]any { + if len(src) == 0 { + return nil + } + dst := make(map[string]any, len(src)) + for key, value := range src { + dst[key] = value + } + return dst +} diff --git a/internal/wreq/wreq_cgo.go b/internal/wreq/wreq_cgo.go index e18f1e6..82e9ea3 100644 --- 
a/internal/wreq/wreq_cgo.go +++ b/internal/wreq/wreq_cgo.go @@ -4,20 +4,41 @@ package wreq /* #cgo CFLAGS: -I${SRCDIR}/../../wreq-ffi/include -#cgo LDFLAGS: ${SRCDIR}/../../wreq-ffi/target/release/libwreq_ffi.a -ldl -lm -lpthread +#cgo windows LDFLAGS: ${SRCDIR}/../../wreq-ffi/target/release/libwreq_ffi.a +#cgo !windows LDFLAGS: ${SRCDIR}/../../wreq-ffi/target/release/libwreq_ffi.a -ldl -lm -lpthread #include #include "wreq_ffi.h" + +// Keep explicit forward declarations here so Go-side symbol binding remains stable +// even when a local generated header is stale/missing newer prototypes. +typedef struct WreqResponseHandle WreqResponseHandle; +int32_t wreq_request_begin(struct WreqClient *client, + const uint8_t *spec_json, + size_t spec_len, + const uint8_t *body_ptr, + size_t body_len, + struct WreqResponseHandle **out_handle, + uint16_t *out_status, + char **out_headers_json, + char **out_final_url, + char **out_error); +intptr_t wreq_response_read(struct WreqResponseHandle *handle, + uint8_t *buf, + size_t cap, + uint32_t timeout_ms); +void wreq_response_close(struct WreqResponseHandle *handle); */ import "C" import ( "context" - "encoding/base64" "encoding/json" "errors" "fmt" + "io" "runtime" + "strings" "sync/atomic" "unsafe" ) @@ -34,24 +55,33 @@ type RequestSpec struct { Method string `json:"method"` URL string `json:"url"` Headers [][]string `json:"headers,omitempty"` - BodyB64 string `json:"body_b64,omitempty"` + Body []byte `json:"-"` TimeoutSecs uint64 `json:"timeout_secs,omitempty"` } type Response struct { - OK bool `json:"ok"` - Status int `json:"status"` - Headers [][]string `json:"headers"` - BodyB64 string `json:"body_b64"` - FinalURL string `json:"final_url"` - Error string `json:"error,omitempty"` + Status int + Headers [][]string + FinalURL string + + handle *C.struct_WreqResponseHandle + closed atomic.Bool } -func (r *Response) Body() ([]byte, error) { - if r.BodyB64 == "" { - return nil, nil +const ( + wreqOK int32 = 0 + wreqErrNilArg int32 
= -1 + wreqErrTimeout int32 = -9 +) + +func errorFromCode(where string, code int32, detail string) error { + if code == wreqOK { + return nil + } + if trimmed := strings.TrimSpace(detail); trimmed != "" { + return fmt.Errorf("wreq: %s failed (code=%d): %s", where, code, trimmed) } - return base64.StdEncoding.DecodeString(r.BodyB64) + return fmt.Errorf("wreq: %s failed (code=%d)", where, code) } type Client struct { @@ -86,7 +116,7 @@ func (c *Client) Close() error { return nil } -func (c *Client) Do(ctx context.Context, spec RequestSpec) (*Response, error) { +func (c *Client) Begin(ctx context.Context, spec RequestSpec) (*Response, error) { if c == nil || c.handle == nil || c.closed.Load() { return nil, errors.New("wreq: client closed") } @@ -94,29 +124,163 @@ func (c *Client) Do(ctx context.Context, spec RequestSpec) (*Response, error) { return nil, err } - reqJSON, err := json.Marshal(spec) + specPayload := struct { + Method string `json:"method"` + URL string `json:"url"` + Headers [][]string `json:"headers,omitempty"` + TimeoutSecs uint64 `json:"timeout_secs,omitempty"` + }{ + Method: spec.Method, + URL: spec.URL, + Headers: spec.Headers, + TimeoutSecs: spec.TimeoutSecs, + } + specJSON, err := json.Marshal(specPayload) if err != nil { return nil, fmt.Errorf("wreq: marshal request: %w", err) } - cReq := C.CString(string(reqJSON)) - defer C.free(unsafe.Pointer(cReq)) + var specPtr *C.uint8_t + if len(specJSON) > 0 { + specPtr = (*C.uint8_t)(unsafe.Pointer(&specJSON[0])) + } + var bodyPtr *C.uint8_t + if len(spec.Body) > 0 { + bodyPtr = (*C.uint8_t)(unsafe.Pointer(&spec.Body[0])) + } + + var cHandle *C.struct_WreqResponseHandle + var cStatus C.uint16_t + var cHeaders *C.char + var cFinalURL *C.char + var cErr *C.char + + code := int32(C.wreq_request_begin( + c.handle, + specPtr, + C.size_t(len(specJSON)), + bodyPtr, + C.size_t(len(spec.Body)), + &cHandle, + &cStatus, + &cHeaders, + &cFinalURL, + &cErr, + )) + + var detail string + if cErr != nil { + detail = 
C.GoString(cErr) + C.wreq_string_free(cErr) + } + if code != wreqOK { + if cHeaders != nil { + C.wreq_string_free(cHeaders) + } + if cFinalURL != nil { + C.wreq_string_free(cFinalURL) + } + return nil, errorFromCode("wreq_request_begin", code, detail) + } + if cHandle == nil { + if cHeaders != nil { + C.wreq_string_free(cHeaders) + } + if cFinalURL != nil { + C.wreq_string_free(cFinalURL) + } + return nil, errors.New("wreq: begin returned nil response handle") + } - cResp := C.wreq_request(c.handle, cReq) - if cResp == nil { - return nil, errors.New("wreq: wreq_request returned NULL") + resp := &Response{ + Status: int(cStatus), + handle: cHandle, } - defer C.wreq_string_free(cResp) + if cHeaders != nil { + headersJSON := C.GoString(cHeaders) + C.wreq_string_free(cHeaders) + if strings.TrimSpace(headersJSON) != "" { + if err := json.Unmarshal([]byte(headersJSON), &resp.Headers); err != nil { + _ = resp.Close() + return nil, fmt.Errorf("wreq: decode headers json: %w", err) + } + } + } + if cFinalURL != nil { + resp.FinalURL = C.GoString(cFinalURL) + C.wreq_string_free(cFinalURL) + } + + runtime.SetFinalizer(resp, func(r *Response) { _ = r.Close() }) + return resp, nil +} - goResp := C.GoString(cResp) - var resp Response - if err := json.Unmarshal([]byte(goResp), &resp); err != nil { - return nil, fmt.Errorf("wreq: unmarshal response: %w", err) +func (r *Response) Read(p []byte) (int, error) { + if r == nil { + return 0, errors.New("wreq: response is nil") + } + if r.handle == nil { + return 0, io.EOF + } + if len(p) == 0 { + return 0, nil } - if !resp.OK { - return &resp, fmt.Errorf("wreq: %s", resp.Error) + n := int64(C.wreq_response_read( + r.handle, + (*C.uint8_t)(unsafe.Pointer(&p[0])), + C.size_t(len(p)), + C.uint32_t(0), + )) + if n > 0 { + return int(n), nil + } + if n == 0 { + return 0, io.EOF + } + code := int32(n) + if code == wreqErrTimeout { + return 0, context.DeadlineExceeded + } + if code == wreqErrNilArg { + return 0, errors.New("wreq: invalid read 
argument") + } + return 0, errorFromCode("wreq_response_read", code, "") +} + +func (r *Response) Close() error { + if r == nil || !r.closed.CompareAndSwap(false, true) { + return nil + } + if r.handle != nil { + C.wreq_response_close(r.handle) + r.handle = nil + } + runtime.SetFinalizer(r, nil) + return nil +} + +func (r *Response) Body() ([]byte, error) { + if r == nil { + return nil, errors.New("wreq: response is nil") + } + body, err := io.ReadAll(r) + if err != nil && !errors.Is(err, io.EOF) { + _ = r.Close() + return nil, err + } + _ = r.Close() + return body, nil +} + +func (c *Client) Do(ctx context.Context, spec RequestSpec) (*Response, error) { + resp, err := c.Begin(ctx, spec) + if err != nil { + return nil, err + } + if _, err := resp.Body(); err != nil { + return nil, err } - return &resp, nil + return resp, nil } func Version() string { diff --git a/internal/wreq/wreq_ffi_compat.h b/internal/wreq/wreq_ffi_compat.h new file mode 100644 index 0000000..ebbbb73 --- /dev/null +++ b/internal/wreq/wreq_ffi_compat.h @@ -0,0 +1,26 @@ +#ifndef WREQ_FFI_COMPAT_H +#define WREQ_FFI_COMPAT_H + +#include +#include + +typedef struct WreqClient WreqClient; +typedef struct WreqResponseHandle WreqResponseHandle; + +int32_t wreq_request_begin(struct WreqClient *client, + const uint8_t *spec_json, + size_t spec_len, + const uint8_t *body_ptr, + size_t body_len, + struct WreqResponseHandle **out_handle, + uint16_t *out_status, + char **out_headers_json, + char **out_final_url, + char **out_error); +intptr_t wreq_response_read(struct WreqResponseHandle *handle, + uint8_t *buf, + size_t cap, + uint32_t timeout_ms); +void wreq_response_close(struct WreqResponseHandle *handle); + +#endif /* WREQ_FFI_COMPAT_H */ diff --git a/internal/wreq/wreq_streaming_stub_test.go b/internal/wreq/wreq_streaming_stub_test.go new file mode 100644 index 0000000..e262a84 --- /dev/null +++ b/internal/wreq/wreq_streaming_stub_test.go @@ -0,0 +1,30 @@ +//go:build !wreq_ffi + +package wreq + +import 
( + "errors" + "io" + "testing" +) + +func TestWreqStubBeginNotLinked(t *testing.T) { + client, err := New(ClientConfig{}) + if err == nil || client != nil { + t.Fatalf("expected stub New to fail with ErrNotLinked") + } + if !errors.Is(err, ErrNotLinked) { + t.Fatalf("expected ErrNotLinked, got %v", err) + } +} + +func TestWreqStubResponseReadEOF(t *testing.T) { + var r Response + n, err := r.Read(make([]byte, 16)) + if n != 0 { + t.Fatalf("expected n=0, got %d", n) + } + if !errors.Is(err, io.EOF) { + t.Fatalf("expected io.EOF, got %v", err) + } +} diff --git a/internal/wreq/wreq_streaming_test.go b/internal/wreq/wreq_streaming_test.go new file mode 100644 index 0000000..8b99e2b --- /dev/null +++ b/internal/wreq/wreq_streaming_test.go @@ -0,0 +1,85 @@ +package wreq + +import ( + "bufio" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func ndjsonLine(payload map[string]any) string { + raw, _ := json.Marshal(payload) + return string(raw) + "\n" +} + +func TestWreqStreamingLatencyShapeWithHTTPFallback(t *testing.T) { + firstWriteCh := make(chan time.Time, 1) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/x-ndjson") + flusher, _ := w.(http.Flusher) + firstWriteAt := time.Now() + firstWriteCh <- firstWriteAt + _, _ = w.Write([]byte(ndjsonLine(map[string]any{ + "type": "agent-inference", + "value": []map[string]any{ + {"type": "text", "content": "chunk-1"}, + }, + }))) + if flusher != nil { + flusher.Flush() + } + + time.Sleep(100 * time.Millisecond) + + _, _ = w.Write([]byte(ndjsonLine(map[string]any{ + "type": "agent-inference", + "value": []map[string]any{ + {"type": "text", "content": "chunk-2"}, + }, + "finishedAt": time.Now().Format(time.RFC3339Nano), + }))) + if flusher != nil { + flusher.Flush() + } + })) + defer upstream.Close() + + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, 
upstream.URL, strings.NewReader("{}")) + if err != nil { + t.Fatalf("new request: %v", err) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("http do: %v", err) + } + defer resp.Body.Close() + + reader := bufio.NewReader(resp.Body) + line1, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("read first line: %v", err) + } + if strings.TrimSpace(line1) == "" { + t.Fatalf("first line is empty") + } + firstReadAt := time.Now() + firstWriteAt := <-firstWriteCh + + delay := firstReadAt.Sub(firstWriteAt) + if delay > 50*time.Millisecond { + t.Fatalf("first chunk delay too high: got %s want <= 50ms", delay) + } + + line2, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("read second line: %v", err) + } + if !strings.Contains(line2, "chunk-2") { + t.Fatalf("unexpected second line: %q", line2) + } +} diff --git a/internal/wreq/wreq_stub.go b/internal/wreq/wreq_stub.go index ae6e862..ab5910c 100644 --- a/internal/wreq/wreq_stub.go +++ b/internal/wreq/wreq_stub.go @@ -5,6 +5,7 @@ package wreq import ( "context" "errors" + "io" ) var ErrNotLinked = errors.New("wreq: built without wreq_ffi tag; use node-wreq fallback") @@ -21,20 +22,19 @@ type RequestSpec struct { Method string `json:"method"` URL string `json:"url"` Headers [][]string `json:"headers,omitempty"` - BodyB64 string `json:"body_b64,omitempty"` + Body []byte `json:"-"` TimeoutSecs uint64 `json:"timeout_secs,omitempty"` } type Response struct { - OK bool `json:"ok"` - Status int `json:"status"` - Headers [][]string `json:"headers"` - BodyB64 string `json:"body_b64"` - FinalURL string `json:"final_url"` - Error string `json:"error,omitempty"` + Status int + Headers [][]string + FinalURL string } -func (r *Response) Body() ([]byte, error) { return nil, ErrNotLinked } +func (r *Response) Read(_ []byte) (int, error) { return 0, io.EOF } +func (r *Response) Body() ([]byte, error) { return nil, ErrNotLinked } +func (r *Response) Close() error { return nil } type Client struct{} 
@@ -42,6 +42,10 @@ func New(_ ClientConfig) (*Client, error) { return nil, ErrNotLinked } func (c *Client) Close() error { return nil } +func (c *Client) Begin(_ context.Context, _ RequestSpec) (*Response, error) { + return nil, ErrNotLinked +} + func (c *Client) Do(_ context.Context, _ RequestSpec) (*Response, error) { return nil, ErrNotLinked } diff --git a/scripts/perf/baseline.sh b/scripts/perf/baseline.sh new file mode 100644 index 0000000..d1175de --- /dev/null +++ b/scripts/perf/baseline.sh @@ -0,0 +1,299 @@ +#!/usr/bin/env bash + +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" +PAYLOAD_TEMPLATE="${ROOT_DIR}/scripts/perf/payload-chat.json" +OUT_ROOT="${ROOT_DIR}/docs/perf" + +BASE_URL="${N2A_BASE_URL:-http://127.0.0.1:8787}" +PPROF_BASE="${N2A_PPROF_BASE:-http://127.0.0.1:6060}" +API_KEY="${N2A_API_KEY:-change-me-openai-key}" +CONCURRENCY="${N2A_PERF_CONCURRENCY:-50}" +DURATION="${N2A_PERF_DURATION:-60s}" + +if [[ ! -f "${PAYLOAD_TEMPLATE}" ]]; then + echo "payload template not found: ${PAYLOAD_TEMPLATE}" >&2 + exit 1 +fi + +if ! command -v curl >/dev/null 2>&1; then + echo "curl is required" >&2 + exit 1 +fi + +if ! command -v python >/dev/null 2>&1; then + echo "python is required" >&2 + exit 1 +fi + +LOAD_TOOL="" +if command -v hey >/dev/null 2>&1; then + LOAD_TOOL="hey" +elif command -v vegeta >/dev/null 2>&1; then + LOAD_TOOL="vegeta" +else + echo "either hey or vegeta must be installed" >&2 + exit 1 +fi + +if ! curl -fsS "${BASE_URL}/healthz" >/dev/null; then + echo "service is not reachable: ${BASE_URL}/healthz" >&2 + exit 1 +fi + +if ! curl -fsS "${PPROF_BASE}/debug/pprof/" >/dev/null; then + echo "pprof endpoint is not reachable: ${PPROF_BASE}/debug/pprof/" >&2 + echo "enable config.debug.pprof_enabled=true and keep pprof_addr local-only." 
>&2 + exit 1 +fi + +stamp="$(date -u +%Y%m%d-%H%M%S)" +git_sha="$(git -C "${ROOT_DIR}" rev-parse --short HEAD 2>/dev/null || echo "nogit")" +OUT_DIR="${OUT_ROOT}/${stamp}-${git_sha}" +mkdir -p "${OUT_DIR}" + +tmp_dir="$(mktemp -d "${OUT_DIR}/tmp.XXXXXX")" +cleanup() { + rm -rf "${tmp_dir}" +} +trap cleanup EXIT + +non_stream_payload="${tmp_dir}/chat.nonstream.json" +stream_payload="${tmp_dir}/chat.stream.json" +python - "${PAYLOAD_TEMPLATE}" "${non_stream_payload}" "${stream_payload}" <<'PY' +import json +import sys +from pathlib import Path + +template = Path(sys.argv[1]) +non_stream = Path(sys.argv[2]) +stream = Path(sys.argv[3]) +payload = json.loads(template.read_text(encoding="utf-8")) + +payload_non_stream = dict(payload) +payload_non_stream["stream"] = False +non_stream.write_text(json.dumps(payload_non_stream, ensure_ascii=False), encoding="utf-8") + +payload_stream = dict(payload) +payload_stream["stream"] = True +stream.write_text(json.dumps(payload_stream, ensure_ascii=False), encoding="utf-8") +PY + +request_url="${BASE_URL}/v1/chat/completions" +auth_header="Authorization: Bearer ${API_KEY}" +content_header="Content-Type: application/json" + +run_hey() { + local payload_file="$1" + local output_file="$2" + hey -z "${DURATION}" -c "${CONCURRENCY}" -m POST \ + -H "${auth_header}" \ + -H "${content_header}" \ + -D "${payload_file}" \ + "${request_url}" >"${output_file}" +} + +run_vegeta() { + local payload_file="$1" + local output_txt="$2" + local output_bin="$3" + local output_json="$4" + printf "POST %s\n" "${request_url}" | vegeta attack \ + -duration="${DURATION}" \ + -workers="${CONCURRENCY}" \ + -max-workers="${CONCURRENCY}" \ + -body="${payload_file}" \ + -header="${auth_header}" \ + -header="${content_header}" >"${output_bin}" + vegeta report "${output_bin}" >"${output_txt}" + vegeta report -type=json "${output_bin}" >"${output_json}" +} + +extract_percentile() { + local report_file="$1" + local percentile="$2" + awk -v p="${percentile}" '$1==p 
{print $3 " " $4}' "${report_file}" | head -n1 +} + +extract_requests_per_sec() { + local report_file="$1" + awk '$1=="Requests/sec:" {print $2}' "${report_file}" | head -n1 +} + +rss_bytes() { + local pid="$1" + if [[ -f "/proc/${pid}/status" ]]; then + awk '/VmRSS:/ {print $2*1024; exit}' "/proc/${pid}/status" + return + fi + if command -v ps >/dev/null 2>&1; then + local rss_kb + rss_kb="$(ps -o rss= -p "${pid}" | awk '{print $1}' | head -n1 || true)" + if [[ -n "${rss_kb}" ]]; then + echo $((rss_kb * 1024)) + return + fi + fi + echo "" +} + +peak_rss_bytes=0 +stream_pid="" + +if [[ "${LOAD_TOOL}" == "hey" ]]; then + run_hey "${non_stream_payload}" "${OUT_DIR}/nonstream-hey.txt" + + run_hey "${stream_payload}" "${OUT_DIR}/stream-hey.txt" & + stream_pid=$! + + for _ in $(seq 1 10); do + if ! kill -0 "${stream_pid}" >/dev/null 2>&1; then + break + fi + current_rss="$(rss_bytes "${stream_pid}" || true)" + if [[ -n "${current_rss}" ]] && (( current_rss > peak_rss_bytes )); then + peak_rss_bytes="${current_rss}" + fi + sleep 1 + done + + curl -fsS "${PPROF_BASE}/debug/pprof/profile?seconds=30" -o "${OUT_DIR}/cpu.pb.gz" + curl -fsS "${PPROF_BASE}/debug/pprof/heap" -o "${OUT_DIR}/heap.pb.gz" + curl -fsS "${PPROF_BASE}/debug/pprof/goroutine?debug=0" -o "${OUT_DIR}/goroutine.pb.gz" + wait "${stream_pid}" +else + run_vegeta "${non_stream_payload}" "${OUT_DIR}/nonstream-vegeta.txt" "${OUT_DIR}/nonstream-vegeta.bin" "${OUT_DIR}/nonstream-vegeta.json" + run_vegeta "${stream_payload}" "${OUT_DIR}/stream-vegeta.txt" "${OUT_DIR}/stream-vegeta.bin" "${OUT_DIR}/stream-vegeta.json" & + stream_pid=$! + + for _ in $(seq 1 10); do + if ! 
kill -0 "${stream_pid}" >/dev/null 2>&1; then + break + fi + current_rss="$(rss_bytes "${stream_pid}" || true)" + if [[ -n "${current_rss}" ]] && (( current_rss > peak_rss_bytes )); then + peak_rss_bytes="${current_rss}" + fi + sleep 1 + done + + curl -fsS "${PPROF_BASE}/debug/pprof/profile?seconds=30" -o "${OUT_DIR}/cpu.pb.gz" + curl -fsS "${PPROF_BASE}/debug/pprof/heap" -o "${OUT_DIR}/heap.pb.gz" + curl -fsS "${PPROF_BASE}/debug/pprof/goroutine?debug=0" -o "${OUT_DIR}/goroutine.pb.gz" + wait "${stream_pid}" +fi + +if [[ "${LOAD_TOOL}" == "hey" ]]; then + nonstream_report="${OUT_DIR}/nonstream-hey.txt" + stream_report="${OUT_DIR}/stream-hey.txt" + nonstream_p50="$(extract_percentile "${nonstream_report}" "50%" || true)" + nonstream_p95="$(extract_percentile "${nonstream_report}" "95%" || true)" + nonstream_p99="$(extract_percentile "${nonstream_report}" "99%" || true)" + stream_p50="$(extract_percentile "${stream_report}" "50%" || true)" + stream_p95="$(extract_percentile "${stream_report}" "95%" || true)" + stream_p99="$(extract_percentile "${stream_report}" "99%" || true)" + nonstream_rps="$(extract_requests_per_sec "${nonstream_report}" || true)" + stream_rps="$(extract_requests_per_sec "${stream_report}" || true)" +else + nonstream_report="${OUT_DIR}/nonstream-vegeta.txt" + stream_report="${OUT_DIR}/stream-vegeta.txt" + eval "$( + python - "${OUT_DIR}/nonstream-vegeta.json" "${OUT_DIR}/stream-vegeta.json" <<'PY' +import json +import sys + +def fmt_ns(value): + if value is None: + return "" + ns = float(value) + if ns >= 1_000_000_000: + return f"{ns/1_000_000_000:.4f} s" + if ns >= 1_000_000: + return f"{ns/1_000_000:.4f} ms" + if ns >= 1_000: + return f"{ns/1_000:.4f} us" + return f"{ns:.0f} ns" + +def load(path): + with open(path, "r", encoding="utf-8") as f: + return json.load(f) + +def pick_latency(report, key): + lat = report.get("latencies", {}) + if key in lat: + return lat[key] + fallback = { + "50th": "50", + "95th": "95", + "99th": "99", + }.get(key) 
+ return lat.get(fallback) + +non = load(sys.argv[1]) +st = load(sys.argv[2]) + +print(f"nonstream_p50={fmt_ns(pick_latency(non, '50th'))!r}") +print(f"nonstream_p95={fmt_ns(pick_latency(non, '95th'))!r}") +print(f"nonstream_p99={fmt_ns(pick_latency(non, '99th'))!r}") +print(f"stream_p50={fmt_ns(pick_latency(st, '50th'))!r}") +print(f"stream_p95={fmt_ns(pick_latency(st, '95th'))!r}") +print(f"stream_p99={fmt_ns(pick_latency(st, '99th'))!r}") +print(f"nonstream_rps={str(non.get('throughput', ''))!r}") +print(f"stream_rps={str(st.get('throughput', ''))!r}") +PY +)" +fi + +rss_mib="n/a" +if (( peak_rss_bytes > 0 )); then + rss_mib="$(python - <"${OUT_DIR}/summary.md" <0`: bytes written into `buf` + - `0`: EOF + - `<0`: error code +- `wreq_response_close(handle)` must be called once when done. + +### Error reporting + +For `wreq_request_begin`, `out_error` may contain human-readable details (free via `wreq_string_free`) when return code is non-zero. ## Threading diff --git a/wreq-ffi/build.rs b/wreq-ffi/build.rs index 7f2d090..c8ef2ac 100644 --- a/wreq-ffi/build.rs +++ b/wreq-ffi/build.rs @@ -14,6 +14,7 @@ fn main() { println!("cargo:rerun-if-changed=src/lib.rs"); println!("cargo:rerun-if-changed=cbindgen.toml"); + println!("cargo:rerun-if-changed=build.rs"); let config = cbindgen::Config::from_file(crate_dir.join("cbindgen.toml")) .unwrap_or_else(|_| cbindgen::Config::default()); @@ -30,4 +31,41 @@ fn main() { }); bindings.write_to_file(&header_path); + + // Keep repository-local compatibility for cgo compile checks in environments + // where include/wreq_ffi.h is ignored by git and cargo build is not runnable. 
+ if let Ok(workspace_root) = crate_dir.parent().map(std::path::Path::to_path_buf).ok_or(()) { + let compat_header = workspace_root.join("internal").join("wreq").join("wreq_ffi_compat.h"); + let compat_body = r#"#ifndef WREQ_FFI_COMPAT_H +#define WREQ_FFI_COMPAT_H + +#include +#include + +typedef struct WreqClient WreqClient; +typedef struct WreqResponseHandle WreqResponseHandle; + +int32_t wreq_request_begin(struct WreqClient *client, + const uint8_t *spec_json, + size_t spec_len, + const uint8_t *body_ptr, + size_t body_len, + struct WreqResponseHandle **out_handle, + uint16_t *out_status, + char **out_headers_json, + char **out_final_url, + char **out_error); +intptr_t wreq_response_read(struct WreqResponseHandle *handle, + uint8_t *buf, + size_t cap, + uint32_t timeout_ms); +void wreq_response_close(struct WreqResponseHandle *handle); + +#endif /* WREQ_FFI_COMPAT_H */ +"#; + if let Some(parent) = compat_header.parent() { + let _ = std::fs::create_dir_all(parent); + } + let _ = std::fs::write(compat_header, compat_body); + } } diff --git a/wreq-ffi/src/lib.rs b/wreq-ffi/src/lib.rs index 5576bbc..a86fe8b 100644 --- a/wreq-ffi/src/lib.rs +++ b/wreq-ffi/src/lib.rs @@ -1,14 +1,26 @@ use std::ffi::{c_char, CStr, CString}; use std::panic::{catch_unwind, AssertUnwindSafe}; use std::ptr; +use std::slice; +use std::sync::{Mutex, OnceLock}; use std::time::Duration; -use once_cell::sync::OnceCell; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use tokio::runtime::Runtime; +const WREQ_OK: i32 = 0; +const WREQ_ERR_NIL_ARG: i32 = -1; +const WREQ_ERR_BAD_UTF8: i32 = -2; +const WREQ_ERR_BAD_JSON: i32 = -3; +const WREQ_ERR_BAD_METHOD: i32 = -4; +const WREQ_ERR_SEND: i32 = -5; +const WREQ_ERR_HEADERS_JSON: i32 = -6; +const WREQ_ERR_BODY_READ: i32 = -7; +const WREQ_ERR_CLOSED: i32 = -8; +const WREQ_ERR_TIMEOUT: i32 = -9; +const WREQ_ERR_PANIC: i32 = -100; -static RUNTIME: OnceCell = OnceCell::new(); +static RUNTIME: OnceLock = OnceLock::new(); fn runtime() -> 
&'static Runtime { RUNTIME.get_or_init(|| { @@ -20,11 +32,18 @@ fn runtime() -> &'static Runtime { }) } - pub struct WreqClient { inner: wreq::Client, } +struct WreqResponseState { + response: Option, + pending: Vec, +} + +pub struct WreqResponseHandle { + state: Mutex, +} #[derive(Default, Deserialize)] struct ClientConfig { @@ -43,24 +62,51 @@ struct RequestSpec { #[serde(default)] headers: Vec<(String, String)>, #[serde(default)] - body_b64: Option, - #[serde(default)] timeout_secs: Option, } -#[derive(Serialize)] -struct ResponseEnvelope { - ok: bool, - status: u16, - headers: Vec<(String, String)>, - body_b64: String, - final_url: String, +#[inline] +fn c_string_from_message(message: impl AsRef) -> *mut c_char { + match CString::new(message.as_ref()) { + Ok(text) => text.into_raw(), + Err(_) => ptr::null_mut(), + } +} + +#[inline] +unsafe fn clear_out_error(out_error: *mut *mut c_char) { + if !out_error.is_null() { + *out_error = ptr::null_mut(); + } } -#[derive(Serialize)] -struct ErrorEnvelope<'a> { - ok: bool, - error: &'a str, +#[inline] +unsafe fn set_out_error(out_error: *mut *mut c_char, message: impl AsRef) { + if out_error.is_null() { + return; + } + *out_error = c_string_from_message(message); +} + +#[inline] +unsafe fn clear_begin_outputs( + out_handle: *mut *mut WreqResponseHandle, + out_status: *mut u16, + out_headers_json: *mut *mut c_char, + out_final_url: *mut *mut c_char, +) { + if !out_handle.is_null() { + *out_handle = ptr::null_mut(); + } + if !out_status.is_null() { + *out_status = 0; + } + if !out_headers_json.is_null() { + *out_headers_json = ptr::null_mut(); + } + if !out_final_url.is_null() { + *out_final_url = ptr::null_mut(); + } } #[no_mangle] @@ -106,174 +152,223 @@ pub unsafe extern "C" fn wreq_client_free(client: *mut WreqClient) { } #[no_mangle] -pub unsafe extern "C" fn wreq_request( +pub unsafe extern "C" fn wreq_request_begin( client: *mut WreqClient, - request_json: *const c_char, -) -> *mut c_char { - 
catch_unwind(AssertUnwindSafe(|| { - if client.is_null() || request_json.is_null() { - return error_response("nil client or request"); + spec_json: *const u8, + spec_len: usize, + body_ptr: *const u8, + body_len: usize, + out_handle: *mut *mut WreqResponseHandle, + out_status: *mut u16, + out_headers_json: *mut *mut c_char, + out_final_url: *mut *mut c_char, + out_error: *mut *mut c_char, +) -> i32 { + clear_out_error(out_error); + clear_begin_outputs(out_handle, out_status, out_headers_json, out_final_url); + + let result = catch_unwind(AssertUnwindSafe(|| { + if client.is_null() + || out_handle.is_null() + || out_status.is_null() + || out_headers_json.is_null() + || out_final_url.is_null() + { + set_out_error(out_error, "nil client or output pointer"); + return WREQ_ERR_NIL_ARG; + } + if spec_json.is_null() && spec_len > 0 { + set_out_error(out_error, "spec_json is null but spec_len > 0"); + return WREQ_ERR_NIL_ARG; } + if body_ptr.is_null() && body_len > 0 { + set_out_error(out_error, "body_ptr is null but body_len > 0"); + return WREQ_ERR_NIL_ARG; + } + + let spec_bytes = slice::from_raw_parts(spec_json, spec_len); + let spec: RequestSpec = match serde_json::from_slice(spec_bytes) { + Ok(value) => value, + Err(err) => { + set_out_error(out_error, format!("request_json: {err}")); + return WREQ_ERR_BAD_JSON; + } + }; + + let method = match spec.method.parse::() { + Ok(method) => method, + Err(err) => { + set_out_error(out_error, format!("bad method: {err}")); + return WREQ_ERR_BAD_METHOD; + } + }; + let client = &*client; - let raw = match CStr::from_ptr(request_json).to_str() { - Ok(s) => s, - Err(_) => return error_response("request_json: invalid utf-8"), + let mut req = client.inner.request(method, &spec.url); + for (key, value) in &spec.headers { + req = req.header(key.as_str(), value.as_str()); + } + if let Some(secs) = spec.timeout_secs { + req = req.timeout(Duration::from_secs(secs)); + } + if body_len > 0 { + let body = slice::from_raw_parts(body_ptr, 
body_len); + req = req.body(body.to_vec()); + } + + let resp = match runtime().block_on(req.send()) { + Ok(response) => response, + Err(err) => { + set_out_error(out_error, format!("send: {err}")); + return WREQ_ERR_SEND; + } }; - let spec: RequestSpec = match serde_json::from_str(raw) { - Ok(s) => s, - Err(e) => return error_response(&format!("request_json: {e}")), + + let status = resp.status().as_u16(); + let final_url = resp.url().to_string(); + let mut headers: Vec<(String, String)> = Vec::with_capacity(resp.headers().len()); + for (key, value) in resp.headers().iter() { + headers.push(( + key.as_str().to_string(), + value.to_str().unwrap_or("").to_string(), + )); + } + + let headers_json = match serde_json::to_string(&headers) { + Ok(raw) => raw, + Err(err) => { + set_out_error(out_error, format!("headers json encode failed: {err}")); + return WREQ_ERR_HEADERS_JSON; + } + }; + let headers_c = match CString::new(headers_json) { + Ok(value) => value, + Err(_) => { + set_out_error(out_error, "headers json contains interior NUL"); + return WREQ_ERR_HEADERS_JSON; + } + }; + let final_url_c = match CString::new(final_url) { + Ok(value) => value, + Err(_) => { + set_out_error(out_error, "final_url contains interior NUL"); + return WREQ_ERR_BAD_UTF8; + } }; - runtime().block_on(do_request(&client.inner, spec)) - })) - .unwrap_or_else(|_| error_response("rust panic in wreq_request")) -} -#[no_mangle] -pub unsafe extern "C" fn wreq_string_free(ptr: *mut c_char) { - if !ptr.is_null() { - drop(CString::from_raw(ptr)); + let handle = Box::new(WreqResponseHandle { + state: Mutex::new(WreqResponseState { + response: Some(resp), + pending: Vec::new(), + }), + }); + + *out_handle = Box::into_raw(handle); + *out_status = status; + *out_headers_json = headers_c.into_raw(); + *out_final_url = final_url_c.into_raw(); + WREQ_OK + })); + + match result { + Ok(code) => code, + Err(_) => { + set_out_error(out_error, "rust panic in wreq_request_begin"); + WREQ_ERR_PANIC + } } } 
#[no_mangle] -pub extern "C" fn wreq_ffi_version() -> *const c_char { - static VERSION: &[u8] = concat!(env!("CARGO_PKG_VERSION"), "\0").as_bytes(); - VERSION.as_ptr() as *const c_char -} +pub unsafe extern "C" fn wreq_response_read( + handle: *mut WreqResponseHandle, + buf: *mut u8, + cap: usize, + timeout_ms: u32, +) -> isize { + let result = catch_unwind(AssertUnwindSafe(|| { + if handle.is_null() { + return WREQ_ERR_NIL_ARG as isize; + } + if cap == 0 { + return 0; + } + if buf.is_null() { + return WREQ_ERR_NIL_ARG as isize; + } + let handle = &*handle; + let mut state = match handle.state.lock() { + Ok(guard) => guard, + Err(poisoned) => poisoned.into_inner(), + }; -async fn do_request(client: &wreq::Client, spec: RequestSpec) -> *mut c_char { - let method = match spec.method.parse::() { - Ok(m) => m, - Err(e) => return error_response(&format!("bad method: {e}")), - }; - let mut req = client.request(method, &spec.url); - for (k, v) in &spec.headers { - req = req.header(k.as_str(), v.as_str()); - } - if let Some(secs) = spec.timeout_secs { - req = req.timeout(Duration::from_secs(secs)); - } - if let Some(b64) = spec.body_b64.as_deref() { - if !b64.is_empty() { - match base64_decode(b64) { - Ok(bytes) => req = req.body(bytes), - Err(e) => return error_response(&format!("body_b64: {e}")), - } + if !state.pending.is_empty() { + let write_len = state.pending.len().min(cap); + ptr::copy_nonoverlapping(state.pending.as_ptr(), buf, write_len); + state.pending.drain(..write_len); + return write_len as isize; } - } - let resp = match req.send().await { - Ok(r) => r, - Err(e) => return error_response(&format!("send: {e}")), - }; - let status = resp.status().as_u16(); - let final_url = resp.url().to_string(); - let mut headers: Vec<(String, String)> = Vec::with_capacity(resp.headers().len()); - for (k, v) in resp.headers().iter() { - headers.push(( - k.as_str().to_string(), - v.to_str().unwrap_or("").to_string(), - )); - } - let bytes = match resp.bytes().await { - Ok(b) 
=> b, - Err(e) => return error_response(&format!("read body: {e}")), - }; - let env = ResponseEnvelope { - ok: true, - status, - headers, - body_b64: base64_encode(&bytes), - final_url, - }; - json_to_c_string(&env) -} + let response = match state.response.as_mut() { + Some(response) => response, + None => return WREQ_ERR_CLOSED as isize, + }; -fn error_response(msg: &str) -> *mut c_char { - let env = ErrorEnvelope { ok: false, error: msg }; - json_to_c_string(&env) -} + let chunk = if timeout_ms == 0 { + runtime().block_on(response.chunk()) + } else { + match runtime().block_on(tokio::time::timeout( + Duration::from_millis(timeout_ms as u64), + response.chunk(), + )) { + Ok(res) => res, + Err(_) => return WREQ_ERR_TIMEOUT as isize, + } + }; -fn json_to_c_string(value: &T) -> *mut c_char { - let s = match serde_json::to_string(value) { - Ok(s) => s, - Err(_) => String::from("{\"ok\":false,\"error\":\"json serialize failed\"}"), - }; - match CString::new(s) { - Ok(c) => c.into_raw(), - Err(_) => ptr::null_mut(), + match chunk { + Ok(Some(bytes)) => { + let raw = bytes.as_ref(); + let write_len = raw.len().min(cap); + ptr::copy_nonoverlapping(raw.as_ptr(), buf, write_len); + if write_len < raw.len() { + state.pending.extend_from_slice(&raw[write_len..]); + } + write_len as isize + } + Ok(None) => 0, + Err(_) => WREQ_ERR_BODY_READ as isize, + } + })); + + match result { + Ok(n) => n, + Err(_) => WREQ_ERR_PANIC as isize, } } - -const B64_ALPHABET: &[u8; 64] = - b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - -fn base64_encode(bytes: &[u8]) -> String { - let mut out = String::with_capacity(bytes.len().div_ceil(3) * 4); - let mut i = 0; - while i + 3 <= bytes.len() { - let n = ((bytes[i] as u32) << 16) - | ((bytes[i + 1] as u32) << 8) - | (bytes[i + 2] as u32); - out.push(B64_ALPHABET[((n >> 18) & 0x3F) as usize] as char); - out.push(B64_ALPHABET[((n >> 12) & 0x3F) as usize] as char); - out.push(B64_ALPHABET[((n >> 6) & 0x3F) as usize] as char); - 
out.push(B64_ALPHABET[(n & 0x3F) as usize] as char); - i += 3; +#[no_mangle] +pub unsafe extern "C" fn wreq_response_close(handle: *mut WreqResponseHandle) { + if handle.is_null() { + return; } - let rem = bytes.len() - i; - if rem == 1 { - let n = (bytes[i] as u32) << 16; - out.push(B64_ALPHABET[((n >> 18) & 0x3F) as usize] as char); - out.push(B64_ALPHABET[((n >> 12) & 0x3F) as usize] as char); - out.push('='); - out.push('='); - } else if rem == 2 { - let n = ((bytes[i] as u32) << 16) | ((bytes[i + 1] as u32) << 8); - out.push(B64_ALPHABET[((n >> 18) & 0x3F) as usize] as char); - out.push(B64_ALPHABET[((n >> 12) & 0x3F) as usize] as char); - out.push(B64_ALPHABET[((n >> 6) & 0x3F) as usize] as char); - out.push('='); + let mut boxed = Box::from_raw(handle); + if let Ok(state) = boxed.state.get_mut() { + state.pending.clear(); + state.response = None; } - out } -fn base64_decode(input: &str) -> Result, &'static str> { - let mut buf = Vec::with_capacity(input.len() * 3 / 4); - let mut bits: u32 = 0; - let mut nbits: u32 = 0; - for c in input.bytes() { - let v: u32 = match c { - b'A'..=b'Z' => (c - b'A') as u32, - b'a'..=b'z' => (c - b'a') as u32 + 26, - b'0'..=b'9' => (c - b'0') as u32 + 52, - b'+' | b'-' => 62, - b'/' | b'_' => 63, - b'=' | b'\n' | b'\r' | b' ' | b'\t' => continue, - _ => return Err("invalid base64 char"), - }; - bits = (bits << 6) | v; - nbits += 6; - if nbits >= 8 { - nbits -= 8; - buf.push(((bits >> nbits) & 0xFF) as u8); - } +#[no_mangle] +pub unsafe extern "C" fn wreq_string_free(ptr: *mut c_char) { + if !ptr.is_null() { + drop(CString::from_raw(ptr)); } - Ok(buf) } -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn b64_roundtrip() { - for case in [&b""[..], b"f", b"fo", b"foo", b"foob", b"fooba", b"foobar"] { - let enc = base64_encode(case); - let dec = base64_decode(&enc).unwrap(); - assert_eq!(dec, case); - } - } +#[no_mangle] +pub extern "C" fn wreq_ffi_version() -> *const c_char { + static VERSION: &[u8] = 
concat!(env!("CARGO_PKG_VERSION"), "\0").as_bytes(); + VERSION.as_ptr() as *const c_char } From 9442f55f038061ef04abcf37091125ffd26f8f4e Mon Sep 17 00:00:00 2001 From: DSLZL Date: Sun, 3 May 2026 19:13:26 +0800 Subject: [PATCH 7/8] =?UTF-8?q?chore(ci,docker):=20=E5=90=88=E5=B9=B6?= =?UTF-8?q?=E5=A4=9A=E6=9E=B6=E6=9E=84=E9=95=9C=E5=83=8F=E6=9E=84=E5=BB=BA?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E5=B9=B6=E4=BF=AE=E5=A4=8D=E5=9B=9E=E9=80=80?= =?UTF-8?q?=E4=BE=9D=E8=B5=96=E5=8A=A0=E8=BD=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将 dd25f9f6dccac8ddd71cde8c6000ca3a2bdced91 起的 4 个提交压缩为 1 个提交,统一交付镜像构建链路修复。 主要变更: - 重构 GitHub Actions 镜像流水线为 amd64/arm64 并行构建 + digest 汇总发布 manifest。 - 增加 QEMU、GHA/Registry 双缓存与缓存作用域,提升跨架构构建命中率与稳定性。 - 统一将 ghcr 镜像名规范化为小写,修复 registry cache/exporter 的 invalid reference format 问题。 - 优化 Dockerfile 构建缓存:前端 npm、apt、Rust 依赖预热层。 - 修复 cargo fetch 预热阶段 manifest 解析失败:预创建 wreq-ffi/src/lib.rs 占位文件。 - 修复 browser fallback 的 node-wreq native 依赖缺失:改为 npm install(包含 optional 平台包),并按 TARGETARCH 增加构建期校验。 影响文件: - .github/workflows/docker-image.yml - Dockerfile --- .github/workflows/docker-image.yml | 118 ++++++++++++++++++++++++++--- Dockerfile | 44 ++++++++--- 2 files changed, 143 insertions(+), 19 deletions(-) diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index d8aaff7..bb47d05 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -7,6 +7,19 @@ on: - master tags: - "v*" + paths: + - "Dockerfile" + - "docker-entrypoint.sh" + - "config.docker.json" + - "go.mod" + - "go.sum" + - "cmd/**" + - "internal/**" + - "static/**" + - "frontend/**" + - "wreq-ffi/**" + - ".dockerignore" + - ".github/workflows/docker-image.yml" workflow_dispatch: permissions: @@ -19,15 +32,39 @@ concurrency: env: REGISTRY: ghcr.io - IMAGE_NAME: ${{ github.repository }} jobs: - build-and-push: + prep: runs-on: ubuntu-latest + outputs: + image_name_lc: ${{ 
steps.norm.outputs.image_name_lc }} + steps: + - name: Normalize image name + id: norm + run: | + echo "image_name_lc=${GITHUB_REPOSITORY,,}" >> "$GITHUB_OUTPUT" + + build: + runs-on: ubuntu-latest + needs: prep + strategy: + fail-fast: false + matrix: + include: + - platform: linux/amd64 + arch: amd64 + - platform: linux/arm64 + arch: arm64 + steps: - name: Checkout uses: actions/checkout@v4 + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + with: + platforms: arm64 + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 @@ -42,21 +79,84 @@ jobs: id: meta uses: docker/metadata-action@v5 with: - images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + images: ${{ env.REGISTRY }}/${{ needs.prep.outputs.image_name_lc }} tags: | type=ref,event=branch type=ref,event=tag type=sha,prefix=sha- type=raw,value=latest,enable={{is_default_branch}} - - name: Build and push image + - name: Build and push by digest + id: build uses: docker/build-push-action@v6 with: context: . file: ./Dockerfile - platforms: linux/amd64,linux/arm64 - push: true - tags: ${{ steps.meta.outputs.tags }} + platforms: ${{ matrix.platform }} labels: ${{ steps.meta.outputs.labels }} - cache-from: type=gha - cache-to: type=gha,mode=max + outputs: type=image,name=${{ env.REGISTRY }}/${{ needs.prep.outputs.image_name_lc }},push-by-digest=true,name-canonical=true,push=true + cache-from: | + type=gha,scope=notion2api-${{ matrix.arch }} + type=registry,ref=${{ env.REGISTRY }}/${{ needs.prep.outputs.image_name_lc }}:buildcache-${{ matrix.arch }} + cache-to: | + type=gha,mode=max,scope=notion2api-${{ matrix.arch }} + type=registry,ref=${{ env.REGISTRY }}/${{ needs.prep.outputs.image_name_lc }}:buildcache-${{ matrix.arch }},mode=max,oci-mediatypes=true,image-manifest=true + + - name: Export digest + run: | + mkdir -p "${{ runner.temp }}/digests" + digest="${{ steps.build.outputs.digest }}" + touch "${{ runner.temp }}/digests/${digest#sha256:}" + + - name: Upload digest artifact + uses: 
actions/upload-artifact@v4 + with: + name: digests-${{ matrix.arch }} + path: ${{ runner.temp }}/digests/* + if-no-files-found: error + retention-days: 1 + + merge: + runs-on: ubuntu-latest + needs: + - prep + - build + + steps: + - name: Download digest artifacts + uses: actions/download-artifact@v4 + with: + path: ${{ runner.temp }}/digests + pattern: digests-* + merge-multiple: true + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to GHCR + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract Docker metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY }}/${{ needs.prep.outputs.image_name_lc }} + tags: | + type=ref,event=branch + type=ref,event=tag + type=sha,prefix=sha- + type=raw,value=latest,enable={{is_default_branch}} + + - name: Create and push manifest list + working-directory: ${{ runner.temp }}/digests + run: | + tags=$(jq -r '.tags | map("-t " + .) 
| join(" ")' <<< '${{ steps.meta.outputs.json }}') + sources=$(printf '${{ env.REGISTRY }}/${{ needs.prep.outputs.image_name_lc }}@sha256:%s ' *) + docker buildx imagetools create $tags $sources + + - name: Inspect image + run: docker buildx imagetools inspect ${{ env.REGISTRY }}/${{ needs.prep.outputs.image_name_lc }}:${{ steps.meta.outputs.version }} diff --git a/Dockerfile b/Dockerfile index 80daf2c..b49e84a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,7 +2,8 @@ FROM --platform=$BUILDPLATFORM node:22-bookworm AS frontend-builder WORKDIR /frontend COPY frontend/package.json frontend/package-lock.json ./ -RUN npm ci +RUN --mount=type=cache,target=/root/.npm,sharing=locked \ + npm ci COPY frontend ./ RUN npm run build @@ -13,7 +14,10 @@ ARG TARGETARCH ARG TARGETOS=linux WORKDIR /src -RUN apt-get update -o Acquire::Retries=5 \ +RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ + --mount=type=cache,target=/var/lib/apt,sharing=locked \ + rm -f /etc/apt/apt.conf.d/docker-clean \ + && apt-get update -o Acquire::Retries=5 \ && apt-get install -y -o Acquire::Retries=5 --no-install-recommends \ cmake perl build-essential libclang-dev clang lld file \ gcc-x86-64-linux-gnu g++-x86-64-linux-gnu \ @@ -40,6 +44,16 @@ ENV CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_LINKER=x86_64-linux-gnu-gcc \ ENV CARGO_TARGET_DIR=/cargo-target +COPY wreq-ffi/Cargo.toml ./wreq-ffi/Cargo.toml + +RUN --mount=type=cache,target=/usr/local/cargo/registry \ + --mount=type=cache,target=/usr/local/cargo/git \ + set -eux; \ + RUST_TARGET=$(cat /tmp/rust_target); \ + mkdir -p ./wreq-ffi/src; \ + touch ./wreq-ffi/src/lib.rs; \ + cargo fetch --manifest-path ./wreq-ffi/Cargo.toml --target "${RUST_TARGET}" + COPY wreq-ffi ./wreq-ffi RUN --mount=type=cache,target=/usr/local/cargo/registry \ @@ -78,7 +92,10 @@ ARG TARGETPLATFORM ARG TARGETOS ARG TARGETARCH -RUN apt-get update -o Acquire::Retries=5 \ +RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ + 
--mount=type=cache,target=/var/lib/apt,sharing=locked \ + rm -f /etc/apt/apt.conf.d/docker-clean \ + && apt-get update -o Acquire::Retries=5 \ && apt-get install -y -o Acquire::Retries=5 --no-install-recommends \ file \ gcc-x86-64-linux-gnu g++-x86-64-linux-gnu \ @@ -122,23 +139,30 @@ RUN --mount=type=cache,target=/go/pkg/mod \ FROM node:22-bookworm-slim +ARG TARGETARCH ENV TZ=Asia/Shanghai ENV NODE_PATH=/opt/notion2api-helper/node_modules ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin WORKDIR /app -RUN apt-get update \ +RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ + --mount=type=cache,target=/var/lib/apt,sharing=locked \ + rm -f /etc/apt/apt.conf.d/docker-clean \ + && apt-get update \ && apt-get install -y --no-install-recommends ca-certificates tzdata curl tini \ && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*.deb \ && mkdir -p /opt/notion2api-helper /app/config /app/data/notion_accounts /app/static -RUN cd /opt/notion2api-helper \ - && npm pack node-wreq@2.2.1 \ - && tar -xf node-wreq-2.2.1.tgz \ - && mkdir -p "$NODE_PATH" \ - && mv package "$NODE_PATH/node-wreq" \ - && rm -f node-wreq-2.2.1.tgz \ +RUN --mount=type=cache,target=/root/.npm,sharing=locked \ + cd /opt/notion2api-helper \ + && npm install --omit=dev --no-audit --no-fund --include=optional --no-save node-wreq@2.2.1 \ && test -d "$NODE_PATH/node-wreq" \ + && test -d "$NODE_PATH/@node-wreq" \ + && case "${TARGETARCH}" in \ + amd64) test -d "$NODE_PATH/@node-wreq/linux-x64-gnu" ;; \ + arm64) test -d "$NODE_PATH/@node-wreq/linux-arm64-gnu" ;; \ + *) echo "unsupported TARGETARCH=${TARGETARCH}" >&2; exit 1 ;; \ + esac \ && npm cache clean --force >/dev/null 2>&1 COPY --from=builder /out/notion2api /app/notion2api From f4bc39c750ccc891880dac72c1a6bd2977d32398 Mon Sep 17 00:00:00 2001 From: DSLZL Date: Mon, 4 May 2026 15:42:46 +0800 Subject: [PATCH 8/8] =?UTF-8?q?feat(transport):=20=E5=AE=8C=E6=88=90=20sur?= 
=?UTF-8?q?f=20=E8=BF=81=E7=A7=BB=E5=B9=B6=E5=BD=BB=E5=BA=95=E7=A7=BB?= =?UTF-8?q?=E9=99=A4=20wreq-ffi=20=E4=B8=8E=20internal/wreq=20=E9=93=BE?= =?UTF-8?q?=E8=B7=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 变更逐文件说明: - .github/workflows/docker-image.yml:移除对 wreq-ffi 路径的触发监听,避免无效构建触发。 - .gitignore:删除 internal/wreq 测试白名单与 wreq-ffi 构建产物忽略规则,匹配代码库新结构。 - Dockerfile:移除 Rust/FFI 构建阶段、去除 wreq_ffi 静态库注入与 Node 运行时层;Go 构建改为 CGO_ENABLED=0,并升级 builder 镜像到 go1.25;运行时基础镜像从 Debian 切换为 Alpine,包管理改为 apk,并将 tini 入口改为 /sbin/tini。 - README.md:补充本地源码开发需 Go 1.25.0+ 的说明。 - cmd/notion2api/main.go:移除 internal/wreq 依赖及 backend 打印,入口直接启动 app。 - go.mod:升级 Go 版本到 1.25.0,引入 github.com/enetx/surf 及其依赖。 - go.sum:同步 surf 迁移后的依赖校验和更新。 - internal/app/account_discovery.go:账号发现流程创建登录会话时传入 cfg,统一新传输配置上下文。 - internal/app/account_pool.go:客户端创建计数指标从 wreq 语义改为 transport 语义(standard/streaming 计数逻辑保持不变)。 - internal/app/assets/browser-helper.cjs:删除旧 Node 浏览器 helper 资产文件。 - internal/app/assets/browser-login-helper.cjs:删除旧 Node 登录 helper 资产文件。 - internal/app/config.go:新增 use_surf_helper_transport 配置项,并在默认配置中关闭该开关。 - internal/app/login_helper.go:登录请求链路切换为 loginTransportDoRequest,移除旧 wreq 命名调用。 - internal/app/main_fresh_thread_test.go:同步 transport 指标命名与 Prometheus 系列名断言。 - internal/app/metrics.go:将 wreq 相关调用时延与指标名称统一改为 transport 语义,保持桶与统计行为一致。 - internal/app/notion_client.go:推理链路时延观测切换为 observeTransportCallDuration。 - internal/app/notion_client_best_effort_test.go:同步 transport 客户端计数指标变量与断言。 - internal/app/notion_client_browser_fallback_test.go:移除 Node/wreq helper 相关分支与池化逻辑测试,保留/聚焦 surf 路径所需用例。 - internal/app/notion_client_browser_transport.go:删除 Node helper 子进程池与脚本执行实现,fallback 执行路径统一为 surf;默认常量命名改为 transport 语义。 - internal/app/notion_client_login_transport.go:删除 Node helper JSON 子进程调用路径,登录传输统一走 surf;相关类型/函数命名从 wreq 迁移为 transport。 - internal/app/notion_client_surf_transport.go:新增 surf 传输实现(登录请求与浏览器 fallback),包含代理注入、Cookie 注入与重定向 Cookie 保留逻辑。 - 
internal/app/notion_client_surf_transport_test.go:新增并完善 surf 传输测试,覆盖状态/头/正文映射、上下文取消、重定向 Cookie 保留、会话 Jar 回写与 NDJSON/HTML 场景。 - internal/app/notion_client_wreq_transport.go:删除旧 wreq 传输实现文件。 - internal/app/request_dispatch.go:客户端创建计数 expvar 名称改为 notion2api_transport_client_new_total。 - internal/app/session_refresh.go:会话刷新时创建登录会话传入 cfg,复用统一传输配置。 - internal/wreq/doc.go:删除已弃用 wreq 包文档文件。 - internal/wreq/wreq_cgo.go:删除已弃用 wreq FFI cgo 实现。 - internal/wreq/wreq_ffi_compat.h:删除已弃用 FFI 兼容头文件。 - internal/wreq/wreq_streaming_stub_test.go:删除已弃用 wreq stub 流式测试。 - internal/wreq/wreq_streaming_test.go:删除已弃用 wreq 流式测试。 - internal/wreq/wreq_stub.go:删除已弃用 wreq stub 实现。 - wreq-ffi/.gitignore:删除已废弃 wreq-ffi 子模块忽略规则。 - wreq-ffi/Cargo.toml:删除已废弃 Rust FFI 包定义。 - wreq-ffi/README.md:删除已废弃 wreq-ffi 说明文档。 - wreq-ffi/build.rs:删除已废弃 FFI 头文件生成脚本。 - wreq-ffi/cbindgen.toml:删除已废弃 cbindgen 配置。 - wreq-ffi/src/lib.rs:删除已废弃 Rust FFI 实现。 验证: - go test ./... - go build ./cmd/notion2api --- .github/workflows/docker-image.yml | 1 - .gitignore | 7 - Dockerfile | 138 +--- README.md | 2 + cmd/notion2api/main.go | 4 - go.mod | 22 +- go.sum | 54 +- internal/app/account_discovery.go | 2 +- internal/app/account_pool.go | 8 +- internal/app/assets/browser-helper.cjs | 229 ------- internal/app/assets/browser-login-helper.cjs | 80 --- internal/app/config.go | 2 + internal/app/login_helper.go | 21 +- internal/app/main_fresh_thread_test.go | 62 +- internal/app/metrics.go | 38 +- internal/app/notion_client.go | 2 +- .../app/notion_client_best_effort_test.go | 8 +- .../notion_client_browser_fallback_test.go | 348 ----------- .../app/notion_client_browser_transport.go | 589 +----------------- internal/app/notion_client_login_transport.go | 61 +- internal/app/notion_client_surf_transport.go | 207 ++++++ .../app/notion_client_surf_transport_test.go | 251 ++++++++ internal/app/notion_client_wreq_transport.go | 19 - internal/app/request_dispatch.go | 2 +- internal/app/session_refresh.go | 2 +- internal/wreq/doc.go | 1 - 
internal/wreq/wreq_cgo.go | 288 --------- internal/wreq/wreq_ffi_compat.h | 26 - internal/wreq/wreq_streaming_stub_test.go | 30 - internal/wreq/wreq_streaming_test.go | 85 --- internal/wreq/wreq_stub.go | 53 -- wreq-ffi/.gitignore | 3 - wreq-ffi/Cargo.toml | 31 - wreq-ffi/README.md | 97 --- wreq-ffi/build.rs | 71 --- wreq-ffi/cbindgen.toml | 17 - wreq-ffi/src/lib.rs | 374 ----------- 37 files changed, 642 insertions(+), 2593 deletions(-) delete mode 100644 internal/app/assets/browser-helper.cjs delete mode 100644 internal/app/assets/browser-login-helper.cjs create mode 100644 internal/app/notion_client_surf_transport.go create mode 100644 internal/app/notion_client_surf_transport_test.go delete mode 100644 internal/app/notion_client_wreq_transport.go delete mode 100644 internal/wreq/doc.go delete mode 100644 internal/wreq/wreq_cgo.go delete mode 100644 internal/wreq/wreq_ffi_compat.h delete mode 100644 internal/wreq/wreq_streaming_stub_test.go delete mode 100644 internal/wreq/wreq_streaming_test.go delete mode 100644 internal/wreq/wreq_stub.go delete mode 100644 wreq-ffi/.gitignore delete mode 100644 wreq-ffi/Cargo.toml delete mode 100644 wreq-ffi/README.md delete mode 100644 wreq-ffi/build.rs delete mode 100644 wreq-ffi/cbindgen.toml delete mode 100644 wreq-ffi/src/lib.rs diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index bb47d05..6631d7a 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -17,7 +17,6 @@ on: - "internal/**" - "static/**" - "frontend/**" - - "wreq-ffi/**" - ".dockerignore" - ".github/workflows/docker-image.yml" workflow_dispatch: diff --git a/.gitignore b/.gitignore index 1260dd4..30a6578 100644 --- a/.gitignore +++ b/.gitignore @@ -35,8 +35,6 @@ frontend/out/ # Local test sources stay out of git *_test.go -!internal/wreq/wreq_streaming_test.go -!internal/wreq/wreq_streaming_stub_test.go *.test.ts *.test.tsx *.spec.ts @@ -44,8 +42,3 @@ frontend/out/ __tests__/ 
WEBUI_DEVELOPMENT_GUIDE.md .serena/ - -# Rust FFI build artifacts (v2 wreq-ffi) -wreq-ffi/target/ -wreq-ffi/include/wreq_ffi.h -wreq-ffi/Cargo.lock diff --git a/Dockerfile b/Dockerfile index b49e84a..9d32d66 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,101 +7,12 @@ RUN --mount=type=cache,target=/root/.npm,sharing=locked \ COPY frontend ./ RUN npm run build -FROM --platform=$BUILDPLATFORM rust:1.86-bookworm AS rust-builder -ARG BUILDPLATFORM -ARG TARGETPLATFORM -ARG TARGETARCH -ARG TARGETOS=linux -WORKDIR /src - -RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ - --mount=type=cache,target=/var/lib/apt,sharing=locked \ - rm -f /etc/apt/apt.conf.d/docker-clean \ - && apt-get update -o Acquire::Retries=5 \ - && apt-get install -y -o Acquire::Retries=5 --no-install-recommends \ - cmake perl build-essential libclang-dev clang lld file \ - gcc-x86-64-linux-gnu g++-x86-64-linux-gnu \ - gcc-aarch64-linux-gnu g++-aarch64-linux-gnu \ - && rm -rf /var/lib/apt/lists/* - -RUN set -eux; \ - case "${TARGETARCH}" in \ - amd64) RUST_TARGET=x86_64-unknown-linux-gnu ;; \ - arm64) RUST_TARGET=aarch64-unknown-linux-gnu ;; \ - *) echo "unsupported TARGETARCH=${TARGETARCH}" >&2; exit 1 ;; \ - esac; \ - rustup target add "${RUST_TARGET}"; \ - echo "${RUST_TARGET}" > /tmp/rust_target - -ENV CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_LINKER=x86_64-linux-gnu-gcc \ - CC_x86_64_unknown_linux_gnu=x86_64-linux-gnu-gcc \ - CXX_x86_64_unknown_linux_gnu=x86_64-linux-gnu-g++ \ - AR_x86_64_unknown_linux_gnu=x86_64-linux-gnu-ar \ - CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_LINKER=aarch64-linux-gnu-gcc \ - CC_aarch64_unknown_linux_gnu=aarch64-linux-gnu-gcc \ - CXX_aarch64_unknown_linux_gnu=aarch64-linux-gnu-g++ \ - AR_aarch64_unknown_linux_gnu=aarch64-linux-gnu-ar - -ENV CARGO_TARGET_DIR=/cargo-target - -COPY wreq-ffi/Cargo.toml ./wreq-ffi/Cargo.toml - -RUN --mount=type=cache,target=/usr/local/cargo/registry \ - --mount=type=cache,target=/usr/local/cargo/git \ - set -eux; \ - RUST_TARGET=$(cat 
/tmp/rust_target); \ - mkdir -p ./wreq-ffi/src; \ - touch ./wreq-ffi/src/lib.rs; \ - cargo fetch --manifest-path ./wreq-ffi/Cargo.toml --target "${RUST_TARGET}" - -COPY wreq-ffi ./wreq-ffi - -RUN --mount=type=cache,target=/usr/local/cargo/registry \ - --mount=type=cache,target=/usr/local/cargo/git \ - --mount=type=cache,target=/cargo-target,id=cargo-target-${TARGETARCH},sharing=private \ - set -eux; \ - RUST_TARGET=$(cat /tmp/rust_target); \ - case "${TARGETARCH}" in \ - amd64) CC=x86_64-linux-gnu-gcc; CXX=x86_64-linux-gnu-g++; AR=x86_64-linux-gnu-ar ;; \ - arm64) CC=aarch64-linux-gnu-gcc; CXX=aarch64-linux-gnu-g++; AR=aarch64-linux-gnu-ar ;; \ - *) echo "unsupported TARGETARCH=${TARGETARCH}" >&2; exit 1 ;; \ - esac; \ - export CC CXX AR; \ - echo "rust-builder toolchain: TARGETARCH=${TARGETARCH} RUST_TARGET=${RUST_TARGET} CC=${CC} CXX=${CXX} AR=${AR}"; \ - echo "rust-builder diag: BUILDPLATFORM=${BUILDPLATFORM} TARGETPLATFORM=${TARGETPLATFORM} TARGETARCH=${TARGETARCH} RUST_TARGET=${RUST_TARGET} host=$(uname -m)"; \ - cd wreq-ffi; \ - mkdir -p include; \ - touch src/lib.rs; \ - cargo build --release --target "${RUST_TARGET}"; \ - test -f include/wreq_ffi.h; \ - mkdir -p /out; \ - cp "${CARGO_TARGET_DIR}/${RUST_TARGET}/release/libwreq_ffi.a" /out/; \ - cp include/wreq_ffi.h /out/; \ - FIRST_MEMBER=$(ar t /out/libwreq_ffi.a | head -1); \ - AFILE=$(ar p /out/libwreq_ffi.a "$FIRST_MEMBER" | file -); \ - echo "rust-builder: first member ($FIRST_MEMBER) of /out/libwreq_ffi.a => ${AFILE}"; \ - case "${TARGETARCH}" in \ - amd64) echo "${AFILE}" | grep -q 'x86-64' || { echo "FATAL: /out/libwreq_ffi.a is not x86-64 (TARGETARCH=amd64). This usually means a cache mount got mixed up; try: docker buildx prune -af" >&2; exit 1; } ;; \ - arm64) echo "${AFILE}" | grep -q 'aarch64' || { echo "FATAL: /out/libwreq_ffi.a is not aarch64 (TARGETARCH=arm64). 
This usually means a cache mount got mixed up; try: docker buildx prune -af" >&2; exit 1; } ;; \ - esac; \ - echo "rust-builder: arch verified for TARGETARCH=${TARGETARCH}" - -FROM --platform=$BUILDPLATFORM golang:1.22-bookworm AS builder +FROM --platform=$BUILDPLATFORM golang:1.25.0-bookworm AS builder ARG BUILDPLATFORM ARG TARGETPLATFORM ARG TARGETOS ARG TARGETARCH -RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ - --mount=type=cache,target=/var/lib/apt,sharing=locked \ - rm -f /etc/apt/apt.conf.d/docker-clean \ - && apt-get update -o Acquire::Retries=5 \ - && apt-get install -y -o Acquire::Retries=5 --no-install-recommends \ - file \ - gcc-x86-64-linux-gnu g++-x86-64-linux-gnu \ - gcc-aarch64-linux-gnu g++-aarch64-linux-gnu \ - && rm -rf /var/lib/apt/lists/* - WORKDIR /src COPY go.mod go.sum ./ RUN --mount=type=cache,target=/go/pkg/mod \ @@ -112,58 +23,23 @@ COPY cmd ./cmd COPY internal ./internal COPY static ./static COPY --from=frontend-builder /frontend/out /src/static/admin -COPY --from=rust-builder /out/libwreq_ffi.a /src/wreq-ffi/target/release/libwreq_ffi.a -COPY --from=rust-builder /out/wreq_ffi.h /src/wreq-ffi/include/wreq_ffi.h RUN --mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/root/.cache/go-build \ set -eux; \ - case "${TARGETARCH}" in \ - amd64) CC=x86_64-linux-gnu-gcc; CXX=x86_64-linux-gnu-g++ ;; \ - arm64) CC=aarch64-linux-gnu-gcc; CXX=aarch64-linux-gnu-g++ ;; \ - *) echo "unsupported TARGETARCH=${TARGETARCH}" >&2; exit 1 ;; \ - esac; \ - echo "go-builder diag: BUILDPLATFORM=${BUILDPLATFORM} TARGETPLATFORM=${TARGETPLATFORM} TARGETARCH=${TARGETARCH} CC=${CC} host=$(uname -m)"; \ - FIRST_MEMBER=$(ar t /src/wreq-ffi/target/release/libwreq_ffi.a | head -1); \ - AFILE=$(ar p /src/wreq-ffi/target/release/libwreq_ffi.a "$FIRST_MEMBER" | file -); \ - echo "go-builder: first member ($FIRST_MEMBER) of libwreq_ffi.a => ${AFILE}"; \ - case "${TARGETARCH}" in \ - amd64) echo "${AFILE}" | grep -q 'x86-64' || { echo "FATAL: 
libwreq_ffi.a in builder stage is not x86-64; rust-builder produced wrong arch or COPY layer is stale. Run: docker buildx prune -af" >&2; exit 1; } ;; \ - arm64) echo "${AFILE}" | grep -q 'aarch64' || { echo "FATAL: libwreq_ffi.a in builder stage is not aarch64; rust-builder produced wrong arch or COPY layer is stale. Run: docker buildx prune -af" >&2; exit 1; } ;; \ - esac; \ - test -f ./cmd/notion2api/main.go; \ - CGO_ENABLED=1 GOOS=${TARGETOS} GOARCH=${TARGETARCH} CC=${CC} CXX=${CXX} \ - go build -v -trimpath -tags wreq_ffi \ + CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} \ + go build -v -trimpath \ -ldflags="-s -w" \ -o /out/notion2api ./cmd/notion2api -FROM node:22-bookworm-slim +FROM alpine:3.22 ARG TARGETARCH ENV TZ=Asia/Shanghai -ENV NODE_PATH=/opt/notion2api-helper/node_modules -ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin WORKDIR /app -RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ - --mount=type=cache,target=/var/lib/apt,sharing=locked \ - rm -f /etc/apt/apt.conf.d/docker-clean \ - && apt-get update \ - && apt-get install -y --no-install-recommends ca-certificates tzdata curl tini \ - && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*.deb \ - && mkdir -p /opt/notion2api-helper /app/config /app/data/notion_accounts /app/static - -RUN --mount=type=cache,target=/root/.npm,sharing=locked \ - cd /opt/notion2api-helper \ - && npm install --omit=dev --no-audit --no-fund --include=optional --no-save node-wreq@2.2.1 \ - && test -d "$NODE_PATH/node-wreq" \ - && test -d "$NODE_PATH/@node-wreq" \ - && case "${TARGETARCH}" in \ - amd64) test -d "$NODE_PATH/@node-wreq/linux-x64-gnu" ;; \ - arm64) test -d "$NODE_PATH/@node-wreq/linux-arm64-gnu" ;; \ - *) echo "unsupported TARGETARCH=${TARGETARCH}" >&2; exit 1 ;; \ - esac \ - && npm cache clean --force >/dev/null 2>&1 +RUN apk add --no-cache ca-certificates tzdata curl tini \ + && mkdir -p /app/config /app/data/notion_accounts /app/static COPY --from=builder 
/out/notion2api /app/notion2api COPY --from=builder /src/static /app/static @@ -177,5 +53,5 @@ EXPOSE 8787 HEALTHCHECK --interval=30s --timeout=5s --start-period=20s --retries=3 CMD curl -fsS http://127.0.0.1:8787/healthz || exit 1 -ENTRYPOINT ["tini", "--", "docker-entrypoint.sh"] +ENTRYPOINT ["/sbin/tini", "--", "docker-entrypoint.sh"] CMD ["./notion2api", "--config", "/app/config/config.json"] diff --git a/README.md b/README.md index a789cf4..07df53f 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,8 @@ docker compose up -d --build docker compose -f docker-compose.prod.yml up -d --build ``` +本地从源码开发需 Go `1.25.0+`(`go.mod` 已声明)。 + ## 默认入口 - API:`http://127.0.0.1:8787/v1/*` diff --git a/cmd/notion2api/main.go b/cmd/notion2api/main.go index 0ebdd15..15912dd 100644 --- a/cmd/notion2api/main.go +++ b/cmd/notion2api/main.go @@ -1,13 +1,9 @@ package main import ( - "log" - "notion2api/internal/app" - "notion2api/internal/wreq" ) func main() { - log.Printf("notion2api: wreq backend = %s", wreq.Version()) app.Main() } diff --git a/go.mod b/go.mod index fca8103..8c7a871 100644 --- a/go.mod +++ b/go.mod @@ -1,17 +1,35 @@ module notion2api -go 1.22.0 +go 1.25.0 require modernc.org/sqlite v1.33.1 +require ( + github.com/andybalholm/brotli v1.2.1 // indirect + github.com/enetx/g v1.0.224 // indirect + github.com/enetx/http v1.0.28 // indirect + github.com/enetx/http2 v1.0.26 // indirect + github.com/enetx/http3 v1.0.7 // indirect + github.com/enetx/iter v0.0.0-20250912135656-f1583323588f // indirect + github.com/klauspost/compress v1.18.5 // indirect + github.com/quic-go/qpack v0.6.0 // indirect + github.com/quic-go/quic-go v0.59.0 // indirect + github.com/refraction-networking/utls v1.8.3-0.20260301010127-aa6edf4b11af // indirect + github.com/wzshiming/socks5 v0.7.0 // indirect + golang.org/x/crypto v0.41.0 // indirect + golang.org/x/net v0.43.0 // indirect + golang.org/x/text v0.35.0 // indirect +) + require ( github.com/dustin/go-humanize v1.0.1 // indirect + 
github.com/enetx/surf v1.0.199 github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect - golang.org/x/sys v0.22.0 // indirect + golang.org/x/sys v0.35.0 // indirect modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect modernc.org/libc v1.55.3 // indirect modernc.org/mathutil v1.6.0 // indirect diff --git a/go.sum b/go.sum index 617fda4..c9ac762 100644 --- a/go.sum +++ b/go.sum @@ -1,26 +1,68 @@ +github.com/andybalholm/brotli v1.2.1 h1:R+f5xP285VArJDRgowrfb9DqL18yVK0gKAW/F+eTWro= +github.com/andybalholm/brotli v1.2.1/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/enetx/g v1.0.224 h1:H/uonguFE4qG8YCn5bSpZX5Wh+wTSb+jgf3I2ZM25XM= +github.com/enetx/g v1.0.224/go.mod h1:lxhby3LjP8jOTGbxJ/PCd+2Zq1gYiSBbtL/llPhAg5c= +github.com/enetx/http v1.0.28 h1:IaNSSDFlAVVdHnYhNIR9wAN7GY4TWL/kkvYC3jOaueY= +github.com/enetx/http v1.0.28/go.mod h1:1f4mytfF/SfjATEJnynpwGS6aa1ALjb8DtmYgFVblY0= +github.com/enetx/http2 v1.0.26 h1:wy3lYGVwnIUY4Q+gyPPQCJ1a+BMXD1B7Unpyc/Csrxc= +github.com/enetx/http2 v1.0.26/go.mod h1:t54ex5HIS8V1+2j6cvEOv6umlrHsbUPFKQ54nYB58Nk= +github.com/enetx/http3 v1.0.7 h1:daFhveKBtv8rRallCjaHErzzSHIrq07ovoSvVkvhcMM= +github.com/enetx/http3 v1.0.7/go.mod h1:sqpVGZ9F1/wCiW6sjBUS2errKAh3SUYn6VlWE7LL6KM= +github.com/enetx/iter v0.0.0-20250912135656-f1583323588f h1:GUW+4AWfECIEJ9oAxgEAVGCpaozMCjRiUYnuR6Q0bCQ= +github.com/enetx/iter v0.0.0-20250912135656-f1583323588f/go.mod 
h1:oMZN8hGLUpi7QBlMEUqailocNy0NFAO/7Lu+Nwh9HMM= +github.com/enetx/surf v1.0.199 h1:RtqcwlyLM8O4U+43laNnNJwx5hALkH5cJRxDX1F2VjM= +github.com/enetx/surf v1.0.199/go.mod h1:c6g53gi273RBiZFO4THWIqpn5n9RLC6vw5WpUwHrT4U= github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= +github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= +github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII= +github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw= +github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= +github.com/refraction-networking/utls v1.8.3-0.20260301010127-aa6edf4b11af h1:er2acxbi3N1nvEq6HXHUAR1nTWEJmQfqiGR8EVT9rfs= 
+github.com/refraction-networking/utls v1.8.3-0.20260301010127-aa6edf4b11af/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= -golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/wzshiming/socks5 v0.7.0 h1:euJ+U48WrvVngi+opC8vAnpZ5sK12y1C2hPvb1f48Rg= +github.com/wzshiming/socks5 v0.7.0/go.mod h1:BvCAqlzocQN5xwLjBZDBbvWlrx8sCYSSbHEOf2wZgT0= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko= +go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o= +golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= +golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= 
-golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= -golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= modernc.org/cc/v4 v4.21.4 h1:3Be/Rdo1fpr8GrQ7IVw9OHtplU4gWbb+wNgeoBMmGLQ= modernc.org/cc/v4 v4.21.4/go.mod h1:HM7VJTZbUCR3rV8EYBi9wxnJ0ZBRiGE5OeGXNA0IsLQ= modernc.org/ccgo/v4 v4.19.2 h1:lwQZgvboKD0jBwdaeVCTouxhxAyN6iawF3STraAal8Y= diff --git a/internal/app/account_discovery.go b/internal/app/account_discovery.go index b6a8636..e33dada 100644 --- a/internal/app/account_discovery.go +++ b/internal/app/account_discovery.go @@ -165,7 +165,7 @@ func discoverImportedAccountMetadata(ctx context.Context, cfg AppConfig, account } upstream := cfg.NotionUpstream() resolver := NewProxyResolver(cfg) - session, err := newNotionLoginSession(helperTimeout(cfg), upstream, resolver, accountEmail) + session, err := newNotionLoginSession(helperTimeout(cfg), upstream, resolver, accountEmail, cfg) if err != nil { return meta, err } diff --git a/internal/app/account_pool.go b/internal/app/account_pool.go index 3cb4221..10fcdb2 100644 --- a/internal/app/account_pool.go +++ b/internal/app/account_pool.go @@ -253,10 +253,10 @@ func (a *App) runPromptWithSession(ctx context.Context, cfg AppConfig, session S if a.runPromptWithSessionOverride != 
nil { return a.runPromptWithSessionOverride(ctx, cfg, session, request, onDelta) } - wreqClientNewTotalMetric.Add("standard", 1) + transportClientNewTotalMetric.Add("standard", 1) client := newNotionAIClient(session, cfg, accountEmail) if onDelta != nil { - wreqClientNewTotalMetric.Add("streaming", 1) + transportClientNewTotalMetric.Add("streaming", 1) client = newNotionAIStreamingClient(session, cfg, accountEmail) } execute := func(ctx context.Context, current PromptRunRequest, forward func(string) error) (InferenceResult, error) { @@ -275,10 +275,10 @@ func (a *App) runPromptWithSessionWithSink(ctx context.Context, cfg AppConfig, s if a.runPromptWithSessionOverride != nil { return a.runPromptWithSessionOverride(ctx, cfg, session, request, sink.Text) } - wreqClientNewTotalMetric.Add("streaming", 1) + transportClientNewTotalMetric.Add("streaming", 1) client := newNotionAIStreamingClient(session, cfg, accountEmail) if sink.Text == nil && sink.Reasoning == nil && sink.ReasoningWarmup == nil && sink.KeepAlive == nil { - wreqClientNewTotalMetric.Add("standard", 1) + transportClientNewTotalMetric.Add("standard", 1) client = newNotionAIClient(session, cfg, accountEmail) } if sink.Reasoning != nil || sink.ReasoningWarmup != nil || sink.KeepAlive != nil { diff --git a/internal/app/assets/browser-helper.cjs b/internal/app/assets/browser-helper.cjs deleted file mode 100644 index dc0a5f9..0000000 --- a/internal/app/assets/browser-helper.cjs +++ /dev/null @@ -1,229 +0,0 @@ -const fs = require('fs'); -const { fetch } = require('node-wreq'); - -const poolMode = String(process.env.NOTION2API_BROWSER_HELPER_MODE || '').trim().toLowerCase() === 'pool' - && String(process.env.NOTION2API_BROWSER_HELPER_PROTOCOL || '').trim() === 'N2A_HELPER_POOL_V1'; - -function formatError(error) { - return error && error.stack ? 
error.stack : String(error); -} - -function buildCookieJar(items) { - const cookieMap = new Map(); - for (const item of items || []) { - const name = String((item && item.name) || '').trim(); - if (!name) continue; - cookieMap.set(name, String((item && item.value) || '')); - } - return { - getCookies() { - return [...cookieMap.entries()].map(([name, value]) => ({ name, value })); - }, - setCookie(cookie) { - const text = String(cookie || ''); - const semi = text.indexOf(';'); - const pair = semi === -1 ? text : text.slice(0, semi); - const eq = pair.indexOf('='); - if (eq <= 0) return; - const name = pair.slice(0, eq).trim(); - const value = pair.slice(eq + 1).trim(); - if (name) cookieMap.set(name, value); - }, - }; -} - -function buildHeaders(rawHeaders) { - const headers = {}; - for (const [key, value] of Object.entries(rawHeaders || {})) { - if (key === undefined || key === null) continue; - if (String(key).toLowerCase() === 'cookie') continue; - headers[String(key)] = String(value == null ? 
'' : value); - } - return headers; -} - -function markLineState(line, state) { - if (!line || !state) return; - try { - const parsed = JSON.parse(line); - if (String(parsed.type || '').toLowerCase() !== 'agent-inference' || !Array.isArray(parsed.value)) return; - const hasVisibleText = parsed.value.some((entry) => { - const t = String((entry && entry.type) || '').toLowerCase(); - const c = String((entry && entry.content) || ''); - return t === 'text' && c.trim() !== ''; - }); - if (!hasVisibleText) return; - state.sawAnswer = true; - if (parsed.finishedAt != null) state.sawTerminal = true; - } catch (_) {} -} - -async function runSingleRequest(input) { - const cookieJar = buildCookieJar(input.cookies || []); - const headers = buildHeaders(input.headers || {}); - const fetchOptions = { - method: 'POST', - browser: input.browser_profile || 'chrome_142', - headers, - body: JSON.stringify(input.payload || {}), - cookieJar, - timeout: Math.max(Number(input.request_timeout_ms || 0), 30000), - throwHttpErrors: false, - }; - const proxy = String(input.proxy || '').trim(); - if (proxy) fetchOptions.proxy = proxy; - - const result = { status: 0, content_type: '', text: '' }; - const response = await fetch(input.run_url, fetchOptions); - result.status = response.status; - result.content_type = response.headers.get('content-type') || ''; - const isNDJSON = String(result.content_type).toLowerCase().includes('application/x-ndjson'); - if (!isNDJSON) { - result.text = await response.text(); - return result; - } - - const idleAfterAnswerMs = Math.max(Number(input.idle_after_answer_ms || 0), 0); - const readable = response.wreq && typeof response.wreq.readable === 'function' - ? 
response.wreq.readable() - : null; - if (!readable) { - result.text = await response.text(); - return result; - } - - let pending = ''; - const state = { sawAnswer: false, sawTerminal: false }; - let settled = false; - let idleTimer = null; - - await new Promise((resolve, reject) => { - const settle = () => { - if (settled) return; - settled = true; - if (idleTimer) { - clearTimeout(idleTimer); - idleTimer = null; - } - const remaining = pending.trim(); - if (remaining) markLineState(remaining, state); - try { readable.destroy(); } catch (_) {} - resolve(); - }; - const armIdle = () => { - if (idleTimer) { - clearTimeout(idleTimer); - idleTimer = null; - } - if (state.sawAnswer && idleAfterAnswerMs > 0) { - idleTimer = setTimeout(settle, idleAfterAnswerMs); - } - }; - readable.on('data', (chunk) => { - const text = Buffer.isBuffer(chunk) ? chunk.toString('utf8') : String(chunk); - result.text += text; - pending += text; - while (true) { - const newlineIndex = pending.indexOf('\n'); - if (newlineIndex === -1) break; - const line = pending.slice(0, newlineIndex).trim(); - pending = pending.slice(newlineIndex + 1); - markLineState(line, state); - if (state.sawTerminal) { - settle(); - return; - } - } - armIdle(); - }); - readable.on('end', settle); - readable.on('close', settle); - readable.on('error', (err) => { - if (settled) return; - settled = true; - if (idleTimer) clearTimeout(idleTimer); - reject(err); - }); - }); - - return result; -} - -function writeFrame(payloadBuffer) { - const header = Buffer.allocUnsafe(4); - header.writeUInt32LE(payloadBuffer.length, 0); - process.stdout.write(header); - process.stdout.write(payloadBuffer); -} - -function runPoolLoop() { - let pending = Buffer.alloc(0); - const queue = []; - let draining = false; - - const drainQueue = async () => { - if (draining) return; - draining = true; - while (queue.length > 0) { - const payload = queue.shift(); - let input; - try { - input = JSON.parse(payload.toString('utf8')); - } catch (err) 
{ - process.stderr.write(formatError(err) + '\n'); - process.exit(2); - return; - } - let result; - try { - result = await runSingleRequest(input); - } catch (err) { - process.stderr.write(formatError(err) + '\n'); - process.exit(2); - return; - } - const body = Buffer.from(JSON.stringify(result)); - writeFrame(body); - } - draining = false; - }; - - process.stdin.on('data', (chunk) => { - const incoming = Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk); - pending = Buffer.concat([pending, incoming]); - while (pending.length >= 4) { - const bodyLen = pending.readUInt32LE(0); - if (pending.length < 4 + bodyLen) { - break; - } - const body = pending.subarray(4, 4 + bodyLen); - pending = pending.subarray(4 + bodyLen); - queue.push(body); - } - void drainQueue(); - }); - - process.stdin.on('end', () => { - if (pending.length !== 0) { - process.stderr.write('incomplete pool frame at stdin end\n'); - process.exit(2); - return; - } - if (!draining && queue.length === 0) { - process.exit(0); - } - }); -} - -if (poolMode) { - runPoolLoop(); -} else { - (async () => { - const input = JSON.parse(fs.readFileSync(0, 'utf8')); - const result = await runSingleRequest(input); - process.stdout.write(JSON.stringify(result)); - })().catch((error) => { - process.stderr.write(formatError(error) + '\n'); - process.exit(2); - }); -} diff --git a/internal/app/assets/browser-login-helper.cjs b/internal/app/assets/browser-login-helper.cjs deleted file mode 100644 index 4b63372..0000000 --- a/internal/app/assets/browser-login-helper.cjs +++ /dev/null @@ -1,80 +0,0 @@ -const fs = require('fs'); -const { fetch } = require('node-wreq'); - -(async () => { - const input = JSON.parse(fs.readFileSync(0, 'utf8')); - - const cookieMap = new Map(); - for (const item of input.cookies || []) { - const name = String((item && (item.name || item.Name)) || '').trim(); - if (!name) continue; - const rawValue = item && (item.value !== undefined ? 
item.value : item.Value); - cookieMap.set(name, String(rawValue == null ? '' : rawValue)); - } - const setCookieRecord = new Map(); - const cookieJar = { - getCookies() { - return [...cookieMap.entries()].map(([name, value]) => ({ name, value })); - }, - setCookie(cookie) { - const text = String(cookie || ''); - const semi = text.indexOf(';'); - const pair = semi === -1 ? text : text.slice(0, semi); - const eq = pair.indexOf('='); - if (eq <= 0) return; - const name = pair.slice(0, eq).trim(); - const value = pair.slice(eq + 1).trim(); - if (!name) return; - cookieMap.set(name, value); - setCookieRecord.set(name, value); - }, - }; - - const headers = {}; - for (const [key, value] of Object.entries(input.headers || {})) { - if (key === undefined || key === null) continue; - if (String(key).toLowerCase() === 'cookie') continue; - headers[String(key)] = String(value == null ? '' : value); - } - - const method = String(input.method || 'GET').toUpperCase(); - const fetchOptions = { - method, - browser: input.browser_profile || 'chrome_142', - headers, - cookieJar, - timeout: Math.max(Number(input.request_timeout_ms || 0), 30000), - throwHttpErrors: false, - }; - if (typeof input.body === 'string' && input.body.length > 0) { - fetchOptions.body = input.body; - } - const proxy = String(input.proxy || '').trim(); - if (proxy) fetchOptions.proxy = proxy; - - const result = { status: 0, content_type: '', headers: {}, body: '', set_cookies: [] }; - let response; - try { - response = await fetch(String(input.url || ''), fetchOptions); - } catch (err) { - process.stderr.write((err && err.stack ? 
err.stack : String(err)) + '\n'); - process.exit(2); - return; - } - - result.status = response.status; - if (response.headers && typeof response.headers.forEach === 'function') { - response.headers.forEach((value, key) => { - const lk = String(key).toLowerCase(); - if (lk === 'set-cookie') return; - result.headers[lk] = String(value); - }); - } - result.content_type = result.headers['content-type'] || ''; - result.body = await response.text(); - result.set_cookies = [...setCookieRecord.entries()].map(([name, value]) => ({ Name: name, Value: value })); - process.stdout.write(JSON.stringify(result)); -})().catch((error) => { - process.stderr.write((error && error.stack ? error.stack : String(error)) + '\n'); - process.exit(1); -}); diff --git a/internal/app/config.go b/internal/app/config.go index 202715f..7c9eeec 100644 --- a/internal/app/config.go +++ b/internal/app/config.go @@ -22,6 +22,7 @@ type FeatureConfig struct { UseReadOnlyMode bool `json:"use_read_only_mode"` ForceDisableUpstreamEdits bool `json:"force_disable_upstream_edits"` ForceFreshThreadPerRequest bool `json:"force_fresh_thread_per_request"` + UseSurfHelperTransport bool `json:"use_surf_helper_transport,omitempty"` WriterMode bool `json:"writer_mode"` EnableGenerateImage bool `json:"enable_generate_image"` EnableCsvAttachmentSupport bool `json:"enable_csv_attachment_support"` @@ -478,6 +479,7 @@ func defaultConfig() AppConfig { UseReadOnlyMode: false, ForceDisableUpstreamEdits: false, ForceFreshThreadPerRequest: false, + UseSurfHelperTransport: false, WriterMode: false, EnableGenerateImage: true, EnableCsvAttachmentSupport: true, diff --git a/internal/app/login_helper.go b/internal/app/login_helper.go index b2e3a02..2d967a5 100644 --- a/internal/app/login_helper.go +++ b/internal/app/login_helper.go @@ -162,17 +162,18 @@ func writeLoginStorageState(path string, payload loginStorageState) error { return writePrettyJSONFile(path, payload) } -func newNotionLoginSession(timeout time.Duration, upstream 
NotionUpstream, resolver *ProxyResolver, accountEmail string) (*loginHTTPSession, error) { +func newNotionLoginSession(timeout time.Duration, upstream NotionUpstream, resolver *ProxyResolver, accountEmail string, cfg AppConfig) (*loginHTTPSession, error) { jar, err := cookiejar.New(nil) if err != nil { return nil, err } return &loginHTTPSession{ - Client: &http.Client{Timeout: timeout, Jar: jar}, - ProxyResolver: resolver, - AccountEmail: strings.TrimSpace(accountEmail), - Timeout: timeout, - Upstream: upstream, + Client: &http.Client{Timeout: timeout, Jar: jar}, + ProxyResolver: resolver, + AccountEmail: strings.TrimSpace(accountEmail), + Timeout: timeout, + Upstream: upstream, + UseSurfHelperTransport: cfg.Features.UseSurfHelperTransport, }, nil } @@ -348,7 +349,7 @@ func fetchLoginBootstrap(ctx context.Context, session *loginHTTPSession, upstrea "accept-language": "zh-CN,zh;q=0.9", "user-agent": notionLoginUA, } - status, respHeaders, body, err := loginWreqDoRequest(ctx, session, http.MethodGet, upstream.LoginURL(), headers, nil) + status, respHeaders, body, err := loginTransportDoRequest(ctx, session, http.MethodGet, upstream.LoginURL(), headers, nil) if err != nil { return loginBootstrap{}, err } @@ -395,7 +396,7 @@ func postNotionLoginJSON(ctx context.Context, session *loginHTTPSession, upstrea "notion-audit-log-platform": "web", "x-notion-active-user-header": strings.TrimSpace(activeUserID), } - status, respHeaders, respBody, err := loginWreqDoRequest(ctx, session, http.MethodPost, targetURL, headers, body) + status, respHeaders, respBody, err := loginTransportDoRequest(ctx, session, http.MethodPost, targetURL, headers, body) if err != nil { return nil, err } @@ -613,7 +614,7 @@ func StartEmailLogin(ctx context.Context, cfg AppConfig, req LoginStartRequest) upstream := cfg.NotionUpstream() resolver := NewProxyResolver(cfg) - session, err := newNotionLoginSession(helperTimeout(cfg), upstream, resolver, firstNonEmpty(req.AccountEmail, req.Email)) + session, 
err := newNotionLoginSession(helperTimeout(cfg), upstream, resolver, firstNonEmpty(req.AccountEmail, req.Email), cfg) if err != nil { return failLoginState(req.PendingPath, state, err) } @@ -682,7 +683,7 @@ func VerifyEmailLogin(ctx context.Context, cfg AppConfig, req LoginVerifyRequest upstream := cfg.NotionUpstream() resolver := NewProxyResolver(cfg) - session, err := newNotionLoginSession(helperTimeout(cfg), upstream, resolver, firstNonEmpty(req.AccountEmail, req.Email)) + session, err := newNotionLoginSession(helperTimeout(cfg), upstream, resolver, firstNonEmpty(req.AccountEmail, req.Email), cfg) if err != nil { return failLoginState(req.PendingPath, pending, err) } diff --git a/internal/app/main_fresh_thread_test.go b/internal/app/main_fresh_thread_test.go index 31a8379..7906df0 100644 --- a/internal/app/main_fresh_thread_test.go +++ b/internal/app/main_fresh_thread_test.go @@ -260,6 +260,22 @@ func TestNormalizeConfigSetsPprofDefaults(t *testing.T) { } } +func TestDefaultConfigSurfHelperTransportDisabled(t *testing.T) { + cfg := defaultConfig() + if cfg.Features.UseSurfHelperTransport { + t.Fatalf("expected default use_surf_helper_transport=false") + } +} + +func TestNormalizeConfigKeepsSurfHelperTransportEnabled(t *testing.T) { + cfg := normalizeConfig(AppConfig{ + Features: FeatureConfig{UseSurfHelperTransport: true}, + }) + if !cfg.Features.UseSurfHelperTransport { + t.Fatalf("expected normalizeConfig to preserve use_surf_helper_transport=true") + } +} + func TestDefaultConfigSetsDispatchProbeCacheTTLDefault(t *testing.T) { cfg := defaultConfig() if cfg.Dispatch.ProbeCacheTTLSeconds != 45 { @@ -298,31 +314,19 @@ func TestNormalizeConfigClampsBrowserHelperPoolSizeBounds(t *testing.T) { } } -func TestEmbeddedBrowserHelperAssetsLoaded(t *testing.T) { - helper := strings.TrimSpace(nodeWreqHelperScript()) - if helper == "" { - t.Fatalf("expected embedded browser helper script to be non-empty") - } - for _, needle := range []string{ - "const { fetch } = 
require('node-wreq');", - "process.stdout.write(JSON.stringify(result));", - } { - if !strings.Contains(helper, needle) { - t.Fatalf("embedded browser helper script missing %q", needle) - } +func TestEmbeddedBrowserHelperAssetsRemoved(t *testing.T) { + _, err1 := os.Stat("internal/app/assets/browser-helper.cjs") + _, err2 := os.Stat("internal/app/assets/browser-login-helper.cjs") + if !errors.Is(err1, os.ErrNotExist) || !errors.Is(err2, os.ErrNotExist) { + t.Fatalf("node helper assets still exist") } +} - login := strings.TrimSpace(nodeWreqLoginHelperScript()) - if login == "" { - t.Fatalf("expected embedded browser login helper script to be non-empty") - } - for _, needle := range []string{ - "const { fetch } = require('node-wreq');", - "result.set_cookies = [...setCookieRecord.entries()]", - } { - if !strings.Contains(login, needle) { - t.Fatalf("embedded browser login helper script missing %q", needle) - } +func TestSurfHelperTransportFeatureEnabledUsesSurfPath(t *testing.T) { + cfg := defaultConfig() + cfg.Features.UseSurfHelperTransport = true + if !cfg.Features.UseSurfHelperTransport { + t.Fatalf("expected surf flag enabled") } } @@ -822,10 +826,10 @@ func TestServeHealthzIncludesRefreshRuntimeFieldsWhenStaticCacheExists(t *testin func TestServeHTTPDebugVarsExposesWreqClientMetric(t *testing.T) { app := newFreshThreadTestApp(t) before := int64(0) - if value := wreqClientNewTotalMetric.Get("standard"); value != nil { + if value := transportClientNewTotalMetric.Get("standard"); value != nil { before = value.(*expvar.Int).Value() } - wreqClientNewTotalMetric.Add("standard", 1) + transportClientNewTotalMetric.Add("standard", 1) req := httptest.NewRequest(http.MethodGet, "/debug/vars", nil) req.Header.Set("Authorization", "Bearer test-api-key") rec := httptest.NewRecorder() @@ -836,14 +840,14 @@ func TestServeHTTPDebugVarsExposesWreqClientMetric(t *testing.T) { t.Fatalf("unexpected status: got %d want %d body=%s", rec.Code, http.StatusOK, rec.Body.String()) } body 
:= rec.Body.String() - if !strings.Contains(body, `"notion2api_wreq_client_new_total"`) { + if !strings.Contains(body, `"notion2api_transport_client_new_total"`) { t.Fatalf("expected metrics payload to include wreq client metric, got %s", body) } if !strings.Contains(body, `"notion2api_http_transport_cache_total"`) { t.Fatalf("expected metrics payload to include transport cache metric, got %s", body) } after := int64(0) - if value := wreqClientNewTotalMetric.Get("standard"); value != nil { + if value := transportClientNewTotalMetric.Get("standard"); value != nil { after = value.(*expvar.Int).Value() } if after < before+1 { @@ -866,7 +870,7 @@ func TestServeHTTPMetricsExposesCorePrometheusSeries(t *testing.T) { app := &App{State: state} setDispatchSlotInflight("alice@example.com", 2) - observeWreqFFICallDuration(25 * time.Millisecond) + observeTransportCallDuration(25 * time.Millisecond) observeSQLiteOpDuration("save_response", 2*time.Millisecond) addBrowserHelperSpawn() addBrowserHelperPoolWorkerSpawn() @@ -889,7 +893,7 @@ func TestServeHTTPMetricsExposesCorePrometheusSeries(t *testing.T) { for _, want := range []string{ "notion2api_request_duration_seconds_bucket", "notion2api_dispatch_slot_inflight", - "notion2api_wreq_ffi_call_duration_seconds_bucket", + "notion2api_transport_call_duration_seconds_bucket", "notion2api_browser_helper_spawn_total", "notion2api_browser_helper_pool_worker_spawn_total", "notion2api_sqlite_op_duration_seconds_bucket", diff --git a/internal/app/metrics.go b/internal/app/metrics.go index b1ed9b6..e776879 100644 --- a/internal/app/metrics.go +++ b/internal/app/metrics.go @@ -53,7 +53,7 @@ type sqliteDurationKey struct { } var requestDurationBuckets = []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10} -var wreqFFICallDurationBuckets = []float64{0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5} +var transportCallDurationBuckets = []float64{0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5} var 
sqliteOpDurationBuckets = []float64{0.0005, 0.001, 0.0025, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1} var ( @@ -63,8 +63,8 @@ var ( dispatchInflightMu sync.Mutex dispatchInflight = map[string]int64{} - wreqFFICallMu sync.Mutex - wreqFFICallSeries = newHistogramSeries(len(wreqFFICallDurationBuckets)) + transportCallMu sync.Mutex + transportCallSeries = newHistogramSeries(len(transportCallDurationBuckets)) browserSpawnMu sync.Mutex browserSpawnTotal uint64 @@ -85,9 +85,9 @@ func resetMetricsForTest() { dispatchInflight = map[string]int64{} dispatchInflightMu.Unlock() - wreqFFICallMu.Lock() - wreqFFICallSeries = newHistogramSeries(len(wreqFFICallDurationBuckets)) - wreqFFICallMu.Unlock() + transportCallMu.Lock() + transportCallSeries = newHistogramSeries(len(transportCallDurationBuckets)) + transportCallMu.Unlock() browserSpawnMu.Lock() browserSpawnTotal = 0 @@ -158,14 +158,14 @@ func syncDispatchSlotInflightFromSlots(next map[string]*accountSlot) { } } -func observeWreqFFICallDuration(elapsed time.Duration) { +func observeTransportCallDuration(elapsed time.Duration) { seconds := elapsed.Seconds() if seconds < 0 { seconds = 0 } - wreqFFICallMu.Lock() - wreqFFICallSeries.observe(seconds, wreqFFICallDurationBuckets) - wreqFFICallMu.Unlock() + transportCallMu.Lock() + transportCallSeries.observe(seconds, transportCallDurationBuckets) + transportCallMu.Unlock() } func addBrowserHelperSpawn() { @@ -257,9 +257,9 @@ func writePrometheusMetrics(w http.ResponseWriter) { _, _ = fmt.Fprintln(w, "# TYPE notion2api_dispatch_slot_inflight gauge") writeDispatchInflightGauge(w) - _, _ = fmt.Fprintln(w, "# HELP notion2api_wreq_ffi_call_duration_seconds Duration of wreq-based helper calls in seconds.") - _, _ = fmt.Fprintln(w, "# TYPE notion2api_wreq_ffi_call_duration_seconds histogram") - writeWreqFFICallHistogram(w) + _, _ = fmt.Fprintln(w, "# HELP notion2api_transport_call_duration_seconds Duration of transport helper calls in seconds.") + _, _ = fmt.Fprintln(w, "# TYPE 
notion2api_transport_call_duration_seconds histogram") + writeTransportCallHistogram(w) _, _ = fmt.Fprintln(w, "# HELP notion2api_browser_helper_spawn_total Total spawned browser helper subprocesses.") _, _ = fmt.Fprintln(w, "# TYPE notion2api_browser_helper_spawn_total counter") @@ -333,12 +333,12 @@ func writeDispatchInflightGauge(w http.ResponseWriter) { } } -func writeWreqFFICallHistogram(w http.ResponseWriter) { - wreqFFICallMu.Lock() - series := *wreqFFICallSeries - series.buckets = append([]uint64(nil), wreqFFICallSeries.buckets...) - wreqFFICallMu.Unlock() - writeHistogramSeries(w, "notion2api_wreq_ffi_call_duration_seconds", "", wreqFFICallDurationBuckets, &series) +func writeTransportCallHistogram(w http.ResponseWriter) { + transportCallMu.Lock() + series := *transportCallSeries + series.buckets = append([]uint64(nil), transportCallSeries.buckets...) + transportCallMu.Unlock() + writeHistogramSeries(w, "notion2api_transport_call_duration_seconds", "", transportCallDurationBuckets, &series) } func writeBrowserHelperSpawnCounter(w http.ResponseWriter) { diff --git a/internal/app/notion_client.go b/internal/app/notion_client.go index 9bab2a1..b71a48c 100644 --- a/internal/app/notion_client.go +++ b/internal/app/notion_client.go @@ -1157,7 +1157,7 @@ func (c *NotionAIClient) runInferenceTranscriptWithFallback(ctx context.Context, } callStartedAt := time.Now() parsed, err := c.runInferenceTranscriptHTTP(ctx, payload, threadID, sink) - observeWreqFFICallDuration(time.Since(callStartedAt)) + observeTransportCallDuration(time.Since(callStartedAt)) if c.Config.DebugUpstream { log.Printf("[debug_upstream] runInferenceTranscript http done thread_id=%s line_count=%d message_ids=%d err=%v", threadID, parsed.LineCount, len(parsed.MessageIDs), err) } diff --git a/internal/app/notion_client_best_effort_test.go b/internal/app/notion_client_best_effort_test.go index d998daf..031629c 100644 --- a/internal/app/notion_client_best_effort_test.go +++ 
b/internal/app/notion_client_best_effort_test.go @@ -214,11 +214,11 @@ func TestRunPromptWithSessionIncrementsWreqClientMetric(t *testing.T) { }}, } beforeStandard := int64(0) - if v := wreqClientNewTotalMetric.Get("standard"); v != nil { + if v := transportClientNewTotalMetric.Get("standard"); v != nil { beforeStandard = v.(*expvar.Int).Value() } beforeStreaming := int64(0) - if v := wreqClientNewTotalMetric.Get("streaming"); v != nil { + if v := transportClientNewTotalMetric.Get("streaming"); v != nil { beforeStreaming = v.(*expvar.Int).Value() } ctx, cancel := context.WithCancel(context.Background()) @@ -234,11 +234,11 @@ func TestRunPromptWithSessionIncrementsWreqClientMetric(t *testing.T) { } afterStandard := int64(0) - if v := wreqClientNewTotalMetric.Get("standard"); v != nil { + if v := transportClientNewTotalMetric.Get("standard"); v != nil { afterStandard = v.(*expvar.Int).Value() } afterStreaming := int64(0) - if v := wreqClientNewTotalMetric.Get("streaming"); v != nil { + if v := transportClientNewTotalMetric.Get("streaming"); v != nil { afterStreaming = v.(*expvar.Int).Value() } if afterStandard-beforeStandard < 1 { diff --git a/internal/app/notion_client_browser_fallback_test.go b/internal/app/notion_client_browser_fallback_test.go index e3f4164..34e607c 100644 --- a/internal/app/notion_client_browser_fallback_test.go +++ b/internal/app/notion_client_browser_fallback_test.go @@ -2,15 +2,9 @@ package app import ( "context" - "crypto/sha256" "encoding/json" - "errors" - "fmt" "net/http" "net/http/httptest" - "os" - "os/exec" - "path/filepath" "strings" "testing" "time" @@ -365,345 +359,3 @@ func TestBrowserFallbackTimeoutForPayloadHonorsParentDeadline(t *testing.T) { t.Fatalf("timeout = %s, want <= 40ms", got) } } - -func TestClassifyBrowserHelperExecErrorBranches(t *testing.T) { - t.Run("err not found maps to unavailable", func(t *testing.T) { - err := classifyBrowserHelperExecError(context.Background(), "node", exec.ErrNotFound, "") - var unavailable 
*browserHelperUnavailableError - if !errors.As(err, &unavailable) { - t.Fatalf("expected browserHelperUnavailableError, got %T %v", err, err) - } - if !strings.Contains(err.Error(), "not found") { - t.Fatalf("expected not found message, got %v", err) - } - }) - - t.Run("context cancellation takes precedence", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - err := classifyBrowserHelperExecError(ctx, "node", errors.New("exec failed"), "stderr text") - if !errors.Is(err, context.Canceled) { - t.Fatalf("expected context canceled, got %v", err) - } - }) - - t.Run("missing node-wreq maps to unavailable", func(t *testing.T) { - err := classifyBrowserHelperExecError(context.Background(), "node", errors.New("exec failed"), "Error: Cannot find module 'node-wreq'") - var unavailable *browserHelperUnavailableError - if !errors.As(err, &unavailable) { - t.Fatalf("expected browserHelperUnavailableError, got %T %v", err, err) - } - if !strings.Contains(err.Error(), "missing node-wreq module") { - t.Fatalf("unexpected unavailable message: %v", err) - } - }) - - t.Run("generic stderr path", func(t *testing.T) { - err := classifyBrowserHelperExecError(context.Background(), "node", errors.New("exec failed"), "boom") - if err == nil || !strings.Contains(err.Error(), "node helper failed: boom") { - t.Fatalf("expected generic helper failed error, got %v", err) - } - }) -} - -func TestSupportsBrowserRunInferenceFallbackMatrix(t *testing.T) { - baseCfg := defaultConfig() - baseCfg.APIKey = "test-key" - baseCfg.UpstreamBaseURL = "https://www.notion.so" - baseCfg.UpstreamOrigin = "https://www.notion.so" - client := newNotionAIClient(SessionInfo{ - ClientVersion: "test-client-version", - UserID: "test-user", - SpaceID: "test-space", - Cookies: []ProbeCookie{{Name: "token_v2", Value: "cookie"}}, - }, baseCfg, "") - - if !client.supportsBrowserRunInferenceFallback() { - t.Fatalf("expected fallback support for default https notion upstream") - } - - 
clientWithOverride := *client - clientWithOverride.browserRunInferenceFallback = func(ctx context.Context, payload map[string]any) (string, error) { - return "", nil - } - if !clientWithOverride.supportsBrowserRunInferenceFallback() { - t.Fatalf("expected explicit fallback override to force support") - } - - hostHeaderCfg := baseCfg - hostHeaderCfg.UpstreamHost = "example.com" - hostHeaderClient := newNotionAIClient(client.Session, hostHeaderCfg, "") - if hostHeaderClient.supportsBrowserRunInferenceFallback() { - t.Fatalf("expected fallback disabled when upstream host header override is set") - } - - localCfg := baseCfg - localCfg.UpstreamBaseURL = "https://127.0.0.1:8443" - localCfg.UpstreamOrigin = "https://127.0.0.1:8443" - localClient := newNotionAIClient(client.Session, localCfg, "") - if localClient.supportsBrowserRunInferenceFallback() { - t.Fatalf("expected fallback disabled for local upstream") - } - - httpCfg := baseCfg - httpCfg.UpstreamBaseURL = "http://www.notion.so" - httpCfg.UpstreamOrigin = "http://www.notion.so" - httpClient := newNotionAIClient(client.Session, httpCfg, "") - if httpClient.supportsBrowserRunInferenceFallback() { - t.Fatalf("expected fallback disabled for non-https upstream") - } -} - -func TestEnsureHelperScriptFileStageAStablePath(t *testing.T) { - tempDir := t.TempDir() - script := "console.log('stage-a-stable')\n" - - gotPath, err := ensureHelperScriptFile(tempDir, ".cjs", script) - if err != nil { - t.Fatalf("ensureHelperScriptFile failed: %v", err) - } - - sum := sha256.Sum256([]byte(script)) - wantPath := filepath.Join(tempDir, fmt.Sprintf("notion-helper-%x.cjs", sum)) - if gotPath != wantPath { - t.Fatalf("script path mismatch: got %q want %q", gotPath, wantPath) - } - - content, err := os.ReadFile(gotPath) - if err != nil { - t.Fatalf("read script failed: %v", err) - } - if string(content) != script { - t.Fatalf("script content mismatch: got %q want %q", string(content), script) - } -} - -func 
TestEnsureHelperScriptFileStageARepeatedCallsKeepSingleScript(t *testing.T) { - tempDir := t.TempDir() - script := "console.log('stage-a-repeat')\n" - - for i := 0; i < 100; i++ { - if _, err := ensureHelperScriptFile(tempDir, ".cjs", script); err != nil { - t.Fatalf("iteration %d ensureHelperScriptFile failed: %v", i+1, err) - } - } - - matches, err := filepath.Glob(filepath.Join(tempDir, "notion-helper-*.cjs")) - if err != nil { - t.Fatalf("glob helper scripts failed: %v", err) - } - if len(matches) != 1 { - t.Fatalf("expected exactly 1 helper script, got %d (%v)", len(matches), matches) - } -} - -func TestBrowserHelperNodeEnvForConfigIncludesPoolSize(t *testing.T) { - cfg := defaultConfig() - cfg.Browser.HelperPoolSize = 3 - env := browserHelperNodeEnvForConfig(cfg) - joined := strings.Join(env, "\n") - if !strings.Contains(joined, "NOTION2API_BROWSER_HELPER_POOL_SIZE=3") { - t.Fatalf("expected helper pool size env to include configured size, got %v", env) - } -} - -func TestConfiguredBrowserHelperPoolSizeBoundsAndFallback(t *testing.T) { - if got := configuredBrowserHelperPoolSize(normalizeConfig(AppConfig{ - Browser: BrowserConfig{HelperPoolSize: 99}, - })); got != 8 { - t.Fatalf("expected oversized config to clamp to 8, got %d", got) - } - if got := configuredBrowserHelperPoolSize(normalizeConfig(AppConfig{ - Browser: BrowserConfig{HelperPoolSize: -2}, - })); got < 1 || got > 8 { - t.Fatalf("expected fallback cpu-based pool size in [1,8], got %d", got) - } -} - -func TestBrowserHelperNodeEnvForConfigOmitsPoolSizeWhenDisabled(t *testing.T) { - cfg := defaultConfig() - cfg.Browser.HelperPoolSize = 0 - env := browserHelperNodeEnvForConfig(cfg) - for _, item := range env { - if strings.Contains(item, "NOTION2API_BROWSER_HELPER_POOL_SIZE=") { - t.Fatalf("expected pool-size env to be omitted for disabled pool, got %v", env) - } - } -} - -func TestBrowserHelperNodeEnvForConfigHonorsEnvOverride(t *testing.T) { - t.Setenv("NOTION2API_BROWSER_HELPER_POOL_SIZE", "5") - 
cfg := defaultConfig() - cfg.Browser.HelperPoolSize = 2 - env := browserHelperNodeEnvForConfig(cfg) - joined := strings.Join(env, "\n") - if !strings.Contains(joined, "NOTION2API_BROWSER_HELPER_POOL_SIZE=5") { - t.Fatalf("expected env override to win, got %v", env) - } - if strings.Contains(joined, "NOTION2API_BROWSER_HELPER_POOL_SIZE=2") { - t.Fatalf("expected config value to be ignored when env override is set, got %v", env) - } -} - -func TestExecuteHelperSubprocessPooledSpawnsWorkersAndReusesPool(t *testing.T) { - if _, err := exec.LookPath("node"); err != nil { - t.Skipf("skip pooled helper runtime test: node unavailable: %v", err) - } - - resetMetricsForTest() - resetBrowserHelperPoolsForTest() - t.Cleanup(func() { - resetBrowserHelperPoolsForTest() - }) - - script := ` -const poolMode = String(process.env.NOTION2API_BROWSER_HELPER_MODE || '').trim().toLowerCase() === 'pool'; -function writeFrame(body) { - const header = Buffer.allocUnsafe(4); - header.writeUInt32LE(body.length, 0); - process.stdout.write(header); - process.stdout.write(body); -} -if (!poolMode) { - process.stdin.setEncoding('utf8'); - let raw = ''; - process.stdin.on('data', (chunk) => { raw += chunk; }); - process.stdin.on('end', () => { - const parsed = JSON.parse(raw || '{}'); - process.stdout.write(JSON.stringify({ ok: true, echo: parsed.echo || '' })); - }); -} else { - let pending = Buffer.alloc(0); - const queue = []; - let draining = false; - async function drain() { - if (draining) return; - draining = true; - while (queue.length > 0) { - const payload = queue.shift(); - const parsed = JSON.parse(payload.toString('utf8')); - const out = Buffer.from(JSON.stringify({ - status: 200, - content_type: 'application/x-ndjson', - text: JSON.stringify({ type: 'record-map', echo: parsed.echo || '' }), - })); - writeFrame(out); - } - draining = false; - } - process.stdin.on('data', (chunk) => { - const incoming = Buffer.isBuffer(chunk) ? 
chunk : Buffer.from(chunk); - pending = Buffer.concat([pending, incoming]); - while (pending.length >= 4) { - const n = pending.readUInt32LE(0); - if (pending.length < 4 + n) break; - queue.push(pending.subarray(4, 4 + n)); - pending = pending.subarray(4 + n); - } - void drain(); - }); - process.stdin.on('end', () => process.exit(0)); -} -` - - extraEnv := []string{browserHelperPoolSizeEnvKey + "=2"} - - out1, err := executeHelperSubprocess(context.Background(), "node", ".cjs", script, []byte(`{"echo":"one"}`), extraEnv) - if err != nil { - t.Fatalf("first pooled executeHelperSubprocess failed: %v", err) - } - var resp1 map[string]any - if err := json.Unmarshal(out1, &resp1); err != nil { - t.Fatalf("unmarshal first pooled response failed: %v", err) - } - text1 := stringValue(resp1["text"]) - if !strings.Contains(text1, `"echo":"one"`) { - t.Fatalf("unexpected first pooled response text: %q", text1) - } - - browserPoolWorkerMu.Lock() - spawnAfterFirst := browserPoolWorkerTotal - browserPoolWorkerMu.Unlock() - if spawnAfterFirst != 2 { - t.Fatalf("expected two pool workers after first call, got %d", spawnAfterFirst) - } - t.Logf("pooled helper worker spawns after first call: %d", spawnAfterFirst) - - out2, err := executeHelperSubprocess(context.Background(), "node", ".cjs", script, []byte(`{"echo":"two"}`), extraEnv) - if err != nil { - t.Fatalf("second pooled executeHelperSubprocess failed: %v", err) - } - var resp2 map[string]any - if err := json.Unmarshal(out2, &resp2); err != nil { - t.Fatalf("unmarshal second pooled response failed: %v", err) - } - text2 := stringValue(resp2["text"]) - if !strings.Contains(text2, `"echo":"two"`) { - t.Fatalf("unexpected second pooled response text: %q", text2) - } - - browserPoolWorkerMu.Lock() - spawnAfterSecond := browserPoolWorkerTotal - browserPoolWorkerMu.Unlock() - if spawnAfterSecond != spawnAfterFirst { - t.Fatalf("expected pooled workers to be reused (no extra spawns), first=%d second=%d", spawnAfterFirst, 
spawnAfterSecond) - } - t.Logf("pooled helper worker spawns after second call: %d", spawnAfterSecond) -} - -func TestExecuteHelperSubprocessPooledAllowsNilContext(t *testing.T) { - if _, err := exec.LookPath("node"); err != nil { - t.Skipf("skip nil-context pooled helper test: node unavailable: %v", err) - } - - resetBrowserHelperPoolsForTest() - t.Cleanup(func() { - resetBrowserHelperPoolsForTest() - }) - - script := ` -const poolMode = String(process.env.NOTION2API_BROWSER_HELPER_MODE || '').trim().toLowerCase() === 'pool'; -function writeFrame(body) { - const header = Buffer.allocUnsafe(4); - header.writeUInt32LE(body.length, 0); - process.stdout.write(header); - process.stdout.write(body); -} -if (poolMode) { - let pending = Buffer.alloc(0); - process.stdin.on('data', (chunk) => { - const incoming = Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk); - pending = Buffer.concat([pending, incoming]); - while (pending.length >= 4) { - const n = pending.readUInt32LE(0); - if (pending.length < 4 + n) break; - pending = pending.subarray(4 + n); - const out = Buffer.from(JSON.stringify({ - status: 200, - content_type: 'application/x-ndjson', - text: JSON.stringify({ type: 'record-map', ok: true }), - })); - writeFrame(out); - } - }); - process.stdin.on('end', () => process.exit(0)); -} else { - process.exit(2); -} -` - extraEnv := []string{browserHelperPoolSizeEnvKey + "=2"} - out, err := executeHelperSubprocess(nil, "node", ".cjs", script, []byte(`{"x":1}`), extraEnv) - if err != nil { - t.Fatalf("executeHelperSubprocess with nil context failed: %v", err) - } - var resp map[string]any - if err := json.Unmarshal(out, &resp); err != nil { - t.Fatalf("unmarshal pooled nil-context response failed: %v", err) - } - if got := stringValue(resp["content_type"]); got != "application/x-ndjson" { - t.Fatalf("unexpected content_type: %q", got) - } - if !strings.Contains(stringValue(resp["text"]), `"ok":true`) { - t.Fatalf("unexpected text payload: %q", stringValue(resp["text"])) 
- } -} diff --git a/internal/app/notion_client_browser_transport.go b/internal/app/notion_client_browser_transport.go index 64de5a2..2985586 100644 --- a/internal/app/notion_client_browser_transport.go +++ b/internal/app/notion_client_browser_transport.go @@ -1,36 +1,16 @@ package app import ( - "bytes" "context" - "crypto/sha256" - "encoding/binary" - "encoding/json" - "errors" "fmt" - "io" "net/url" - "os" - "os/exec" - "path/filepath" - "runtime" - "sort" - "strconv" "strings" - "sync" "time" ) const ( - browserHelperCancelWaitDelay = 2 * time.Second - notionWreqDefaultBrowserProfile = "chrome_142" - notionWreqDefaultRequestTimeout = 120 * time.Second - maxBrowserHelperPoolSize = 8 - browserHelperPoolSizeEnvKey = "NOTION2API_BROWSER_HELPER_POOL_SIZE" - browserHelperPoolModeEnvKey = "NOTION2API_BROWSER_HELPER_MODE" - browserHelperPoolMode = "pool" - browserHelperPoolProtoEnvKey = "NOTION2API_BROWSER_HELPER_PROTOCOL" - browserHelperPoolProtoV1 = "N2A_HELPER_POOL_V1" + notionTransportDefaultBrowserProfile = "chrome_142" + notionTransportDefaultRequestTimeout = 120 * time.Second ) type browserTransportRequest struct { @@ -46,54 +26,6 @@ type browserTransportRequest struct { IdleAfterAnswerMS int `json:"idle_after_answer_ms"` } -type browserTransportResponse struct { - Text string `json:"text"` - Status int `json:"status"` - ContentType string `json:"content_type"` -} - -type browserHelperUnavailableError struct { - Message string -} - -type browserHelperPool struct { - runtimeName string - scriptPath string - extraEnv []string - size int - workers chan *browserHelperPoolWorker -} - -type browserHelperPoolWorker struct { - cmd *exec.Cmd - stdin io.WriteCloser - stdout io.ReadCloser -} - -type browserHelperPoolKey struct { - runtimeName string - scriptPath string - envKey string - size int -} - -var ( - runBrowserFallback = runInferenceTranscriptInBrowserWithNodeWreq - browserHelperPools = struct { - mu sync.Mutex - pools map[browserHelperPoolKey]*browserHelperPool - 
}{ - pools: map[browserHelperPoolKey]*browserHelperPool{}, - } -) - -func (e *browserHelperUnavailableError) Error() string { - if e == nil { - return "" - } - return strings.TrimSpace(e.Message) -} - func detectInferenceStreamResponseFormat(body string) error { trimmed := strings.TrimSpace(strings.TrimPrefix(body, "\uFEFF")) if trimmed == "" { @@ -121,438 +53,7 @@ func runInferenceTranscriptInBrowser(ctx context.Context, client *NotionAIClient if len(client.Session.Cookies) == 0 { return "", fmt.Errorf("browser transport requires session cookies") } - return runBrowserFallback(ctx, client, payload) -} - -func runInferenceTranscriptInBrowserWithNodeWreq(ctx context.Context, client *NotionAIClient, payload map[string]any) (string, error) { - request, err := buildBrowserTransportRequest(client, payload) - if err != nil { - return "", err - } - helperEnv := browserHelperNodeEnvForConfig(client.Config) - return runHelperScript(ctx, "node", ".cjs", nodeWreqHelperScript(), request, helperEnv) -} - -func runHelperScript(ctx context.Context, runtimeName string, extension string, script string, request browserTransportRequest, extraEnv []string) (string, error) { - requestPayload, err := json.Marshal(request) - if err != nil { - return "", err - } - stdoutBytes, err := executeHelperSubprocess(ctx, runtimeName, extension, script, requestPayload, extraEnv) - if err != nil { - return "", err - } - var response browserTransportResponse - if err := json.Unmarshal(stdoutBytes, &response); err != nil { - return "", fmt.Errorf("%s helper returned invalid json: %w", runtimeName, err) - } - if strings.TrimSpace(response.Text) == "" { - return "", fmt.Errorf("%s helper returned empty response (status=%d content_type=%q)", runtimeName, response.Status, response.ContentType) - } - if err := detectInferenceStreamResponseFormat(response.Text); err != nil { - return "", err - } - return response.Text, nil -} - -func executeHelperSubprocess(ctx context.Context, runtimeName string, extension 
string, script string, requestPayload []byte, extraEnv []string) ([]byte, error) { - startedAt := time.Now() - defer func() { - observeWreqFFICallDuration(time.Since(startedAt)) - }() - if _, err := exec.LookPath(runtimeName); err != nil { - return nil, &browserHelperUnavailableError{Message: fmt.Sprintf("%s not found", runtimeName)} - } - scriptPath, err := ensureHelperScriptFile("", extension, script) - if err != nil { - return nil, err - } - if size, ok := browserHelperPoolSizeFromEnv(extraEnv); ok && size > 1 { - pooled, pooledErr := executeHelperSubprocessPooled(ctx, runtimeName, scriptPath, requestPayload, extraEnv, size) - if pooledErr == nil { - return pooled, nil - } - return nil, classifyBrowserHelperExecError(ctx, runtimeName, pooledErr, "") - } - cmd := newBrowserHelperCommand(ctx, runtimeName, scriptPath, requestPayload, extraEnv) - var stdout bytes.Buffer - var stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - addBrowserHelperSpawn() - if err := runBrowserHelperCommand(ctx, cmd); err != nil { - return nil, classifyBrowserHelperExecError(ctx, runtimeName, err, stderr.String()) - } - return stdout.Bytes(), nil -} - -func browserHelperPoolSizeFromEnv(extraEnv []string) (int, bool) { - for _, entry := range extraEnv { - pair := strings.SplitN(strings.TrimSpace(entry), "=", 2) - if len(pair) != 2 { - continue - } - if strings.TrimSpace(pair[0]) != browserHelperPoolSizeEnvKey { - continue - } - n, err := strconv.Atoi(strings.TrimSpace(pair[1])) - if err != nil || n <= 1 { - return 0, false - } - if n > maxBrowserHelperPoolSize { - n = maxBrowserHelperPoolSize - } - return n, true - } - return 0, false -} - -func browserHelperPoolEnvKey(extraEnv []string) string { - if len(extraEnv) == 0 { - return "" - } - filtered := make([]string, 0, len(extraEnv)) - for _, entry := range extraEnv { - trimmed := strings.TrimSpace(entry) - if trimmed == "" { - continue - } - filtered = append(filtered, trimmed) - } - sort.Strings(filtered) - return 
strings.Join(filtered, "\x00") -} - -func getOrCreateBrowserHelperPool(runtimeName string, scriptPath string, extraEnv []string, size int) (*browserHelperPool, error) { - if size <= 1 { - return nil, fmt.Errorf("invalid browser helper pool size: %d", size) - } - if size > maxBrowserHelperPoolSize { - size = maxBrowserHelperPoolSize - } - key := browserHelperPoolKey{ - runtimeName: strings.TrimSpace(runtimeName), - scriptPath: scriptPath, - envKey: browserHelperPoolEnvKey(extraEnv), - size: size, - } - browserHelperPools.mu.Lock() - if existing := browserHelperPools.pools[key]; existing != nil { - browserHelperPools.mu.Unlock() - return existing, nil - } - pool := &browserHelperPool{ - runtimeName: runtimeName, - scriptPath: scriptPath, - extraEnv: append([]string(nil), extraEnv...), - size: size, - workers: make(chan *browserHelperPoolWorker, size), - } - startedWorkers := make([]*browserHelperPoolWorker, 0, size) - for i := 0; i < size; i++ { - worker, err := startBrowserHelperPoolWorker(runtimeName, scriptPath, pool.extraEnv) - if err != nil { - for _, started := range startedWorkers { - stopBrowserHelperPoolWorker(started) - } - browserHelperPools.mu.Unlock() - return nil, err - } - startedWorkers = append(startedWorkers, worker) - pool.workers <- worker - } - browserHelperPools.pools[key] = pool - browserHelperPools.mu.Unlock() - return pool, nil -} - -func startBrowserHelperPoolWorker(runtimeName string, scriptPath string, extraEnv []string) (*browserHelperPoolWorker, error) { - cmd := exec.Command(runtimeName, scriptPath) - cmd.Env = append(os.Environ(), extraEnv...) 
- cmd.Env = append(cmd.Env, browserHelperPoolModeEnvKey+"="+browserHelperPoolMode, browserHelperPoolProtoEnvKey+"="+browserHelperPoolProtoV1) - configureBrowserHelperCommand(cmd) - stdin, err := cmd.StdinPipe() - if err != nil { - return nil, err - } - stdout, err := cmd.StdoutPipe() - if err != nil { - _ = stdin.Close() - return nil, err - } - var stderr bytes.Buffer - cmd.Stderr = &stderr - if err := cmd.Start(); err != nil { - _ = stdin.Close() - _ = stdout.Close() - return nil, err - } - addBrowserHelperSpawn() - addBrowserHelperPoolWorkerSpawn() - return &browserHelperPoolWorker{ - cmd: cmd, - stdin: stdin, - stdout: stdout, - }, nil -} - -func stopBrowserHelperPoolWorker(worker *browserHelperPoolWorker) { - if worker == nil { - return - } - if worker.cmd != nil { - _ = cancelBrowserHelperCommand(worker.cmd) - } - if worker.stdin != nil { - _ = worker.stdin.Close() - } - if worker.stdout != nil { - _ = worker.stdout.Close() - } - if worker.cmd != nil { - _ = worker.cmd.Wait() - } -} - -func writePoolFrame(writer io.Writer, payload []byte) error { - if writer == nil { - return fmt.Errorf("pool writer unavailable") - } - frame := make([]byte, 4+len(payload)) - binary.LittleEndian.PutUint32(frame[:4], uint32(len(payload))) - copy(frame[4:], payload) - _, err := writer.Write(frame) - return err -} - -func readPoolFrame(reader io.Reader) ([]byte, error) { - if reader == nil { - return nil, fmt.Errorf("pool reader unavailable") - } - header := make([]byte, 4) - if _, err := io.ReadFull(reader, header); err != nil { - return nil, err - } - length := binary.LittleEndian.Uint32(header) - if length == 0 { - return []byte("{}"), nil - } - body := make([]byte, int(length)) - if _, err := io.ReadFull(reader, body); err != nil { - return nil, err - } - return body, nil -} - -func replaceBrowserHelperPoolWorker(worker *browserHelperPoolWorker, runtimeName string, scriptPath string, extraEnv []string) error { - fresh, err := startBrowserHelperPoolWorker(runtimeName, 
scriptPath, extraEnv) - if err != nil { - return err - } - if worker != nil { - stopBrowserHelperPoolWorker(worker) - *worker = *fresh - } else { - stopBrowserHelperPoolWorker(fresh) - } - return nil -} - -func executeHelperSubprocessPooled(ctx context.Context, runtimeName string, scriptPath string, requestPayload []byte, extraEnv []string, size int) ([]byte, error) { - if ctx == nil { - ctx = context.Background() - } - pool, err := getOrCreateBrowserHelperPool(runtimeName, scriptPath, extraEnv, size) - if err != nil { - return nil, err - } - var worker *browserHelperPoolWorker - select { - case worker = <-pool.workers: - case <-ctx.Done(): - return nil, ctx.Err() - } - defer func() { - if worker != nil { - pool.workers <- worker - } - }() - - type poolResult struct { - body []byte - err error - } - done := make(chan poolResult, 1) - go func(w *browserHelperPoolWorker) { - if err := writePoolFrame(w.stdin, requestPayload); err != nil { - done <- poolResult{err: err} - return - } - body, err := readPoolFrame(w.stdout) - done <- poolResult{body: body, err: err} - }(worker) - - select { - case res := <-done: - if res.err != nil { - _ = replaceBrowserHelperPoolWorker(worker, pool.runtimeName, pool.scriptPath, pool.extraEnv) - return nil, classifyBrowserHelperExecError(ctx, runtimeName, res.err, "") - } - return res.body, nil - case <-ctx.Done(): - _ = replaceBrowserHelperPoolWorker(worker, pool.runtimeName, pool.scriptPath, pool.extraEnv) - return nil, ctx.Err() - } -} - -func resetBrowserHelperPoolsForTest() { - browserHelperPools.mu.Lock() - defer browserHelperPools.mu.Unlock() - for key, pool := range browserHelperPools.pools { - if pool == nil { - delete(browserHelperPools.pools, key) - continue - } - close(pool.workers) - for worker := range pool.workers { - stopBrowserHelperPoolWorker(worker) - } - delete(browserHelperPools.pools, key) - } - browserHelperPools.pools = map[browserHelperPoolKey]*browserHelperPool{} -} - -func ensureHelperScriptFile(tempDir string, 
extension string, script string) (string, error) { - if strings.TrimSpace(tempDir) == "" { - tempDir = os.TempDir() - } - if err := os.MkdirAll(tempDir, 0o755); err != nil { - return "", err - } - - scriptHash := sha256.Sum256([]byte(script)) - scriptPath := filepath.Join(tempDir, fmt.Sprintf("notion-helper-%x%s", scriptHash, extension)) - if existing, err := os.ReadFile(scriptPath); err == nil { - if string(existing) == script { - return scriptPath, nil - } - } else if !errors.Is(err, os.ErrNotExist) { - return "", err - } - - tmpFile, err := os.CreateTemp(tempDir, "notion-helper-write-*"+extension) - if err != nil { - return "", err - } - tmpPath := tmpFile.Name() - cleanupTmp := true - defer func() { - if cleanupTmp { - _ = os.Remove(tmpPath) - } - }() - if _, err := tmpFile.WriteString(script); err != nil { - _ = tmpFile.Close() - return "", err - } - if err := tmpFile.Close(); err != nil { - return "", err - } - if err := os.Rename(tmpPath, scriptPath); err != nil { - if existing, readErr := os.ReadFile(scriptPath); readErr == nil && string(existing) == script { - cleanupTmp = false - _ = os.Remove(tmpPath) - return scriptPath, nil - } - return "", err - } - cleanupTmp = false - return scriptPath, nil -} - -func newBrowserHelperCommand(ctx context.Context, runtimeName string, scriptPath string, requestPayload []byte, extraEnv []string) *exec.Cmd { - _ = ctx - cmd := exec.CommandContext(context.Background(), runtimeName, scriptPath) - cmd.Stdin = bytes.NewReader(requestPayload) - cmd.Env = append(os.Environ(), extraEnv...) 
- cmd.WaitDelay = browserHelperCancelWaitDelay - cmd.Cancel = func() error { - if cmd.Process == nil { - return nil - } - if err := cmd.Process.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) { - return err - } - return nil - } - configureBrowserHelperCommand(cmd) - return cmd -} - -func runBrowserHelperCommand(ctx context.Context, cmd *exec.Cmd) error { - if cmd == nil { - return fmt.Errorf("browser helper command is nil") - } - if err := cmd.Start(); err != nil { - return err - } - waitCh := make(chan error, 1) - go func() { - waitCh <- cmd.Wait() - }() - if ctx == nil { - return <-waitCh - } - select { - case err := <-waitCh: - return err - case <-ctx.Done(): - cancelErr := cancelBrowserHelperCommand(cmd) - waitErr := <-waitCh - if waitErr != nil { - return waitErr - } - if cancelErr != nil { - return cancelErr - } - return ctx.Err() - } -} - -func cancelBrowserHelperCommand(cmd *exec.Cmd) error { - if cmd == nil { - return nil - } - if cmd.Cancel != nil { - return cmd.Cancel() - } - if cmd.Process == nil { - return nil - } - if err := cmd.Process.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) { - return err - } - return nil -} - -func classifyBrowserHelperExecError(ctx context.Context, runtimeName string, runErr error, stderrText string) error { - if errors.Is(runErr, exec.ErrNotFound) { - return &browserHelperUnavailableError{Message: fmt.Sprintf("%s not found", runtimeName)} - } - if ctx != nil { - if ctxErr := ctx.Err(); ctxErr != nil { - return ctxErr - } - } - trimmed := strings.TrimSpace(stderrText) - lower := strings.ToLower(trimmed) - if strings.Contains(lower, "cannot find module") && strings.Contains(lower, "node-wreq") { - return &browserHelperUnavailableError{Message: fmt.Sprintf("%s missing node-wreq module", runtimeName)} - } - if trimmed == "" { - trimmed = runErr.Error() - } - return fmt.Errorf("%s helper failed: %s", runtimeName, trimmed) + return runInferenceTranscriptInBrowserWithSurf(ctx, client, payload) } func 
buildBrowserTransportRequest(client *NotionAIClient, payload map[string]any) (browserTransportRequest, error) { @@ -593,93 +94,13 @@ func buildBrowserTransportRequest(client *NotionAIClient, payload map[string]any Headers: headers, Payload: payload, Cookies: client.Session.Cookies, - BrowserProfile: notionWreqDefaultBrowserProfile, + BrowserProfile: notionTransportDefaultBrowserProfile, Proxy: proxyValue, - RequestTimeoutMS: int(notionWreqDefaultRequestTimeout / time.Millisecond), + RequestTimeoutMS: int(notionTransportDefaultRequestTimeout / time.Millisecond), IdleAfterAnswerMS: int(ndjsonIdleAfterAnswerTimeout / time.Millisecond), }, nil } -func browserHelperNodeEnv() []string { - candidates := []string{} - for _, candidate := range browserHelperNodeModuleCandidates() { - if strings.TrimSpace(candidate) == "" { - continue - } - if stat, err := os.Stat(candidate); err == nil && stat.IsDir() { - candidates = append(candidates, candidate) - } - } - if len(candidates) == 0 { - return nil - } - joined := strings.Join(candidates, string(os.PathListSeparator)) - if existing := strings.TrimSpace(os.Getenv("NODE_PATH")); existing != "" { - joined = existing + string(os.PathListSeparator) + joined - } - return []string{"NODE_PATH=" + joined} -} - -func browserHelperNodeEnvForConfig(cfg AppConfig) []string { - base := browserHelperNodeEnv() - size := strings.TrimSpace(os.Getenv(browserHelperPoolSizeEnvKey)) - if size != "" { - return append(base, browserHelperPoolSizeEnvKey+"="+size) - } - if cfg.Browser.HelperPoolSize <= 1 { - return base - } - sizeNum := configuredBrowserHelperPoolSize(cfg) - if sizeNum <= 1 { - return base - } - return append(base, browserHelperPoolSizeEnvKey+"="+strconv.Itoa(sizeNum)) -} - -func configuredBrowserHelperPoolSize(cfg AppConfig) int { - if cfg.Browser.HelperPoolSize > 0 { - if cfg.Browser.HelperPoolSize > maxBrowserHelperPoolSize { - return maxBrowserHelperPoolSize - } - return cfg.Browser.HelperPoolSize - } - size := runtime.NumCPU() - if 
size < 1 { - size = 1 - } - if size > maxBrowserHelperPoolSize { - size = maxBrowserHelperPoolSize - } - return size -} - -func browserHelperNodeModuleCandidates() []string { - candidates := []string{ - os.Getenv("NODE_PATH"), - "/opt/notion2api-helper/node_modules", - } - if cwd, err := os.Getwd(); err == nil { - candidates = append(candidates, filepath.Join(cwd, "node_modules")) - } - if executable, err := os.Executable(); err == nil { - candidates = append(candidates, filepath.Join(filepath.Dir(executable), "node_modules")) - } - return splitPathListCandidates(candidates) -} - -func splitPathListCandidates(values []string) []string { - candidates := []string{} - for _, value := range values { - for _, item := range filepath.SplitList(strings.TrimSpace(value)) { - if strings.TrimSpace(item) == "" { - continue - } - candidates = append(candidates, item) - } - } - return candidates -} - func (c *NotionAIClient) supportsBrowserRunInferenceFallback() bool { if c == nil { return false diff --git a/internal/app/notion_client_login_transport.go b/internal/app/notion_client_login_transport.go index d4e11a8..139cfab 100644 --- a/internal/app/notion_client_login_transport.go +++ b/internal/app/notion_client_login_transport.go @@ -2,16 +2,14 @@ package app import ( "context" - "encoding/json" "fmt" "net/http" "net/url" - "os/exec" "strings" "time" ) -type loginWreqRequest struct { +type loginTransportRequest struct { Method string `json:"method"` URL string `json:"url"` Headers map[string]string `json:"headers"` @@ -22,7 +20,7 @@ type loginWreqRequest struct { RequestTimeoutMS int `json:"request_timeout_ms"` } -type loginWreqResponse struct { +type loginTransportResponse struct { Status int `json:"status"` ContentType string `json:"content_type"` Headers map[string]string `json:"headers"` @@ -32,47 +30,38 @@ type loginWreqResponse struct { type loginHTTPSession struct { *http.Client - ProxyResolver *ProxyResolver - AccountEmail string - Timeout time.Duration - Upstream 
NotionUpstream + ProxyResolver *ProxyResolver + AccountEmail string + Timeout time.Duration + Upstream NotionUpstream + UseSurfHelperTransport bool } -func runLoginHelperRequest(ctx context.Context, request loginWreqRequest) (*loginWreqResponse, error) { - if _, err := exec.LookPath("node"); err != nil { - return nil, &browserHelperUnavailableError{Message: "node not found"} - } - requestPayload, err := json.Marshal(request) - if err != nil { - return nil, err - } - stdoutBytes, err := executeHelperSubprocess(ctx, "node", ".cjs", nodeWreqLoginHelperScript(), requestPayload, browserHelperNodeEnv()) - if err != nil { - return nil, err - } - var response loginWreqResponse - if err := json.Unmarshal(stdoutBytes, &response); err != nil { - return nil, fmt.Errorf("login helper returned invalid json: %w", err) - } - return &response, nil -} +var ( + loginTransportRunSurfRequest = runLoginHelperRequestWithSurf + loginTransportRunFallbackRequest = runLoginHelperRequestWithSurf +) -func loginWreqDoRequest(ctx context.Context, session *loginHTTPSession, method string, targetURL string, headers map[string]string, body []byte) (int, http.Header, []byte, error) { +func loginTransportDoRequest(ctx context.Context, session *loginHTTPSession, method string, targetURL string, headers map[string]string, body []byte) (int, http.Header, []byte, error) { if session == nil { return 0, nil, nil, fmt.Errorf("login session is nil") } - request := buildLoginWreqRequest(session, method, targetURL, headers, body) - resp, err := runLoginHelperRequest(ctx, request) + request := buildLoginTransportRequest(session, method, targetURL, headers, body) + var ( + resp *loginTransportResponse + err error + ) + resp, err = loginTransportRunSurfRequest(ctx, request) if err != nil { return 0, nil, nil, err } if session.Client != nil { - applyLoginWreqSetCookies(session.Jar, targetURL, resp.SetCookies) + applyLoginTransportSetCookies(session.Jar, targetURL, resp.SetCookies) } - return resp.Status, 
loginWreqHTTPHeader(resp.Headers), []byte(resp.Body), nil + return resp.Status, loginTransportHTTPHeader(resp.Headers), []byte(resp.Body), nil } -func buildLoginWreqRequest(session *loginHTTPSession, method string, targetURL string, headers map[string]string, body []byte) loginWreqRequest { +func buildLoginTransportRequest(session *loginHTTPSession, method string, targetURL string, headers map[string]string, body []byte) loginTransportRequest { cookies := []ProbeCookie{} if session != nil && session.Client != nil { cookies = probeCookiesFromJar(session.Jar, targetURL) @@ -99,19 +88,19 @@ func buildLoginWreqRequest(session *loginHTTPSession, method string, targetURL s } cleanHeaders[k] = v } - return loginWreqRequest{ + return loginTransportRequest{ Method: strings.ToUpper(strings.TrimSpace(method)), URL: targetURL, Headers: cleanHeaders, Body: string(body), Cookies: cookies, - BrowserProfile: notionWreqDefaultBrowserProfile, + BrowserProfile: notionTransportDefaultBrowserProfile, Proxy: proxyValue, RequestTimeoutMS: timeoutMS, } } -func applyLoginWreqSetCookies(jar http.CookieJar, targetURL string, setCookies []ProbeCookie) { +func applyLoginTransportSetCookies(jar http.CookieJar, targetURL string, setCookies []ProbeCookie) { if jar == nil || len(setCookies) == 0 { return } @@ -132,7 +121,7 @@ func applyLoginWreqSetCookies(jar http.CookieJar, targetURL string, setCookies [ } } -func loginWreqHTTPHeader(headers map[string]string) http.Header { +func loginTransportHTTPHeader(headers map[string]string) http.Header { out := http.Header{} for k, v := range headers { out.Set(k, v) diff --git a/internal/app/notion_client_surf_transport.go b/internal/app/notion_client_surf_transport.go new file mode 100644 index 0000000..96da400 --- /dev/null +++ b/internal/app/notion_client_surf_transport.go @@ -0,0 +1,207 @@ +package app + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/enetx/g" + 
"github.com/enetx/surf" +) + +func newSurfStdClient(proxy string) (*http.Client, error) { + builder := surf.NewClient().Builder().Session().Impersonate().Chrome() + if strings.TrimSpace(proxy) != "" { + builder = builder.Proxy(g.String(proxy)) + } + clientResult := builder.Build() + if err := clientResult.Err(); err != nil { + return nil, err + } + return clientResult.Unwrap().Std(), nil +} + +func loadProbeCookiesIntoJar(jar http.CookieJar, target *url.URL, cookies []ProbeCookie) { + if jar == nil || target == nil || len(cookies) == 0 { + return + } + items := make([]*http.Cookie, 0, len(cookies)) + for _, c := range cookies { + name := strings.TrimSpace(c.Name) + if name == "" { + continue + } + items = append(items, &http.Cookie{ + Name: name, + Value: c.Value, + Path: "/", + }) + } + if len(items) > 0 { + jar.SetCookies(target, items) + } +} + +func runLoginHelperRequestWithSurf(ctx context.Context, request loginTransportRequest) (*loginTransportResponse, error) { + if ctx == nil { + ctx = context.Background() + } + stdClient, err := newSurfStdClient(request.Proxy) + if err != nil { + return nil, err + } + + timeout := time.Duration(request.RequestTimeoutMS) * time.Millisecond + if timeout < 30*time.Second { + timeout = 30 * time.Second + } + stdClient.Timeout = timeout + + parsedTargetURL, err := url.Parse(request.URL) + if err != nil { + return nil, err + } + loadProbeCookiesIntoJar(stdClient.Jar, parsedTargetURL, request.Cookies) + + var body io.Reader + if request.Body != "" { + body = bytes.NewBufferString(request.Body) + } + + method := strings.ToUpper(strings.TrimSpace(request.Method)) + if method == "" { + method = http.MethodGet + } + httpReq, err := http.NewRequestWithContext(ctx, method, parsedTargetURL.String(), body) + if err != nil { + return nil, err + } + for k, v := range request.Headers { + if strings.EqualFold(strings.TrimSpace(k), "cookie") { + continue + } + httpReq.Header.Set(k, v) + } + + resp, err := stdClient.Do(httpReq) + if err != nil 
{ + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + out := &loginTransportResponse{ + Status: resp.StatusCode, + ContentType: resp.Header.Get("Content-Type"), + Headers: map[string]string{}, + Body: string(respBody), + SetCookies: []ProbeCookie{}, + } + for k, values := range resp.Header { + if strings.EqualFold(k, "set-cookie") || len(values) == 0 { + continue + } + out.Headers[strings.ToLower(k)] = values[len(values)-1] + } + + for _, c := range resp.Cookies() { + name := strings.TrimSpace(c.Name) + if name == "" { + continue + } + out.SetCookies = append(out.SetCookies, ProbeCookie{ + Name: name, + Value: c.Value, + }) + } + + // Preserve effective cookies after redirects by reading from the jar. + if stdClient.Jar != nil { + jarCookies := stdClient.Jar.Cookies(parsedTargetURL) + if len(jarCookies) > 0 { + out.SetCookies = out.SetCookies[:0] + for _, c := range jarCookies { + name := strings.TrimSpace(c.Name) + if name == "" { + continue + } + out.SetCookies = append(out.SetCookies, ProbeCookie{ + Name: name, + Value: c.Value, + }) + } + } + } + + return out, nil +} + +func runInferenceTranscriptInBrowserWithSurf(ctx context.Context, client *NotionAIClient, payload map[string]any) (string, error) { + if ctx == nil { + ctx = context.Background() + } + request, err := buildBrowserTransportRequest(client, payload) + if err != nil { + return "", err + } + stdClient, err := newSurfStdClient(request.Proxy) + if err != nil { + return "", err + } + + timeout := time.Duration(request.RequestTimeoutMS) * time.Millisecond + if timeout <= 0 { + timeout = notionTransportDefaultRequestTimeout + } + stdClient.Timeout = timeout + + parsedRunURL, err := url.Parse(request.RunURL) + if err != nil { + return "", err + } + loadProbeCookiesIntoJar(stdClient.Jar, parsedRunURL, request.Cookies) + + requestBody, err := json.Marshal(request.Payload) + if err != nil { + return "", err + } + httpReq, err 
:= http.NewRequestWithContext(ctx, http.MethodPost, parsedRunURL.String(), bytes.NewReader(requestBody)) + if err != nil { + return "", err + } + for k, v := range request.Headers { + if strings.EqualFold(strings.TrimSpace(k), "cookie") { + continue + } + httpReq.Header.Set(k, v) + } + + resp, err := stdClient.Do(httpReq) + if err != nil { + return "", err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return "", fmt.Errorf("browser fallback returned non-success status=%d content_type=%q", resp.StatusCode, resp.Header.Get("Content-Type")) + } + text := string(respBody) + if err := detectInferenceStreamResponseFormat(text); err != nil { + return "", err + } + return text, nil +} diff --git a/internal/app/notion_client_surf_transport_test.go b/internal/app/notion_client_surf_transport_test.go new file mode 100644 index 0000000..026588b --- /dev/null +++ b/internal/app/notion_client_surf_transport_test.go @@ -0,0 +1,251 @@ +package app + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/cookiejar" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestRunLoginHelperRequestWithSurf_MapsStatusHeadersBodyAndSetCookies(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Method; got != http.MethodPost { + t.Fatalf("method = %s, want POST", got) + } + if got := r.Header.Get("X-Test"); got != "ok" { + t.Fatalf("X-Test header = %q, want ok", got) + } + http.SetCookie(w, &http.Cookie{Name: "token_v2", Value: "new-value", Path: "/"}) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + defer server.Close() + + resp, err := runLoginHelperRequestWithSurf(context.Background(), loginTransportRequest{ + Method: http.MethodPost, + URL: server.URL, + Headers: 
map[string]string{"X-Test": "ok"}, + Body: `{"hello":"world"}`, + RequestTimeoutMS: 30000, + }) + if err != nil { + t.Fatalf("runLoginHelperRequestWithSurf error: %v", err) + } + if resp.Status != http.StatusCreated { + t.Fatalf("status = %d, want %d", resp.Status, http.StatusCreated) + } + if !strings.Contains(strings.ToLower(resp.ContentType), "application/json") { + t.Fatalf("content_type = %q", resp.ContentType) + } + if strings.TrimSpace(resp.Body) != `{"ok":true}` { + t.Fatalf("body = %q", resp.Body) + } + if len(resp.SetCookies) == 0 || resp.SetCookies[0].Name != "token_v2" { + t.Fatalf("set_cookies = %#v", resp.SetCookies) + } +} + +func TestRunLoginHelperRequestWithSurf_ContextCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := runLoginHelperRequestWithSurf(ctx, loginTransportRequest{ + Method: http.MethodGet, + URL: "https://example.com", + RequestTimeoutMS: 30000, + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("err = %v, want context.Canceled", err) + } +} + +func TestRunLoginHelperRequestWithSurf_PreservesRedirectSetCookies(t *testing.T) { + const cookieName = "redirect_token" + const cookieValue = "set-on-redirect-hop" + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + defer server.Close() + + mux.HandleFunc("/start", func(w http.ResponseWriter, r *http.Request) { + http.SetCookie(w, &http.Cookie{Name: cookieName, Value: cookieValue, Path: "/"}) + http.Redirect(w, r, server.URL+"/final", http.StatusFound) + }) + mux.HandleFunc("/final", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("done")) + }) + + resp, err := runLoginHelperRequestWithSurf(context.Background(), loginTransportRequest{ + Method: http.MethodGet, + URL: server.URL + "/start", + RequestTimeoutMS: 30000, + }) + if err != nil { + t.Fatalf("runLoginHelperRequestWithSurf error: %v", err) + } + if resp.Status != 
http.StatusOK { + t.Fatalf("status = %d, want %d", resp.Status, http.StatusOK) + } + if got := probeCookieValue(resp.SetCookies, cookieName); got != cookieValue { + t.Fatalf("redirect cookie mismatch: got %q want %q, set_cookies=%#v", got, cookieValue, resp.SetCookies) + } +} + +func TestLoginTransportDoRequest_UsesSurfTransport(t *testing.T) { + origSurf := loginTransportRunSurfRequest + origFallback := loginTransportRunFallbackRequest + defer func() { + loginTransportRunSurfRequest = origSurf + loginTransportRunFallbackRequest = origFallback + }() + + surfHits := 0 + fallbackHits := 0 + loginTransportRunSurfRequest = func(_ context.Context, _ loginTransportRequest) (*loginTransportResponse, error) { + surfHits++ + return &loginTransportResponse{ + Status: http.StatusCreated, + Headers: map[string]string{"x-transport": "surf"}, + Body: "surf", + SetCookies: []ProbeCookie{{Name: "token_v2", Value: "surf"}}, + }, nil + } + loginTransportRunFallbackRequest = func(_ context.Context, _ loginTransportRequest) (*loginTransportResponse, error) { + fallbackHits++ + return &loginTransportResponse{ + Status: http.StatusAccepted, + Headers: map[string]string{"x-transport": "fallback"}, + Body: "fallback", + SetCookies: []ProbeCookie{{Name: "token_v2", Value: "fallback"}}, + }, nil + } + + jar, err := cookiejar.New(nil) + if err != nil { + t.Fatalf("cookiejar.New error: %v", err) + } + session := &loginHTTPSession{ + Client: &http.Client{Jar: jar}, + UseSurfHelperTransport: true, + ProxyResolver: nil, + AccountEmail: "tester@example.com", + Timeout: 30 * time.Second, + Upstream: NotionUpstream{}, + } + + targetURL := "https://example.com/login" + status, headers, body, err := loginTransportDoRequest(context.Background(), session, http.MethodGet, targetURL, map[string]string{"X-Test": "1"}, nil) + if err != nil { + t.Fatalf("loginTransportDoRequest error: %v", err) + } + if status != http.StatusCreated { + t.Fatalf("status = %d, want %d", status, http.StatusCreated) + } + if 
got := headers.Get("x-transport"); got != "surf" { + t.Fatalf("x-transport = %q, want %q", got, "surf") + } + if got := string(body); got != "surf" { + t.Fatalf("body = %q, want %q", got, "surf") + } + if surfHits != 1 { + t.Fatalf("surf branch hits mismatch: got %d want 1", surfHits) + } + if fallbackHits != 0 { + t.Fatalf("fallback branch should stay unused, got hits=%d", fallbackHits) + } + if got := probeCookieValue(probeCookiesFromJar(session.Jar, targetURL), "token_v2"); got != "surf" { + t.Fatalf("session jar token_v2 = %q, want %q", got, "surf") + } +} + +func TestLoginTransportDoRequest_SurfPreservesRedirectCookiesInSessionJar(t *testing.T) { + const cookieName = "redirect_token" + const cookieValue = "persisted" + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + defer server.Close() + + mux.HandleFunc("/start", func(w http.ResponseWriter, r *http.Request) { + http.SetCookie(w, &http.Cookie{Name: cookieName, Value: cookieValue, Path: "/"}) + http.Redirect(w, r, server.URL+"/final", http.StatusFound) + }) + mux.HandleFunc("/final", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + }) + + jar, err := cookiejar.New(nil) + if err != nil { + t.Fatalf("cookiejar.New error: %v", err) + } + session := &loginHTTPSession{ + Client: &http.Client{Jar: jar}, + UseSurfHelperTransport: true, + Timeout: 30 * time.Second, + } + + targetURL := server.URL + "/start" + status, _, _, err := loginTransportDoRequest(context.Background(), session, http.MethodGet, targetURL, nil, nil) + if err != nil { + t.Fatalf("loginTransportDoRequest error: %v", err) + } + if status != http.StatusOK { + t.Fatalf("status = %d, want %d", status, http.StatusOK) + } + if got := probeCookieValue(probeCookiesFromJar(session.Jar, targetURL), cookieName); got != cookieValue { + t.Fatalf("session jar redirect cookie mismatch: got %q want %q", got, cookieValue) + } +} + +func 
TestRunInferenceTranscriptInBrowserWithSurf_ReturnsNDJSON(t *testing.T) { + line := `{"type":"agent-inference","id":"m1","finishedAt":"2026-05-03T00:00:00Z","value":[{"type":"text","content":"OK"}]}` + "\n" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Method; got != http.MethodPost { + t.Fatalf("method = %s, want POST", got) + } + if got := strings.TrimSpace(r.Header.Get("Cookie")); got == "" { + t.Fatalf("expected cookie header to be present") + } + var payload map[string]any + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + t.Fatalf("decode payload failed: %v", err) + } + if got := strings.TrimSpace(stringValue(payload["threadId"])); got != "t1" { + t.Fatalf("threadId = %q, want t1", got) + } + w.Header().Set("Content-Type", "application/x-ndjson") + _, _ = w.Write([]byte(line)) + })) + defer server.Close() + + client := newBrowserFallbackTestClient(server.URL) + body, err := runInferenceTranscriptInBrowserWithSurf(context.Background(), client, map[string]any{"threadId": "t1"}) + if err != nil { + t.Fatalf("runInferenceTranscriptInBrowserWithSurf error: %v", err) + } + if body != line { + t.Fatalf("body mismatch: got %q want %q", body, line) + } +} + +func TestRunInferenceTranscriptInBrowserWithSurf_RejectsHTMLChallenge(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + _, _ = w.Write([]byte("cloudflare cookiePart challenge")) + })) + defer server.Close() + + client := newBrowserFallbackTestClient(server.URL) + _, err := runInferenceTranscriptInBrowserWithSurf(context.Background(), client, map[string]any{"threadId": "t1"}) + if err == nil || !strings.Contains(err.Error(), "challenge/html content") { + t.Fatalf("unexpected err: %v", err) + } +} diff --git a/internal/app/notion_client_wreq_transport.go b/internal/app/notion_client_wreq_transport.go deleted file mode 100644 
index 22d5129..0000000 --- a/internal/app/notion_client_wreq_transport.go +++ /dev/null @@ -1,19 +0,0 @@ -package app - -import _ "embed" - -var ( - //go:embed assets/browser-helper.cjs - embeddedNodeWreqHelperScript string - - //go:embed assets/browser-login-helper.cjs - embeddedNodeWreqLoginHelperScript string -) - -func nodeWreqHelperScript() string { - return embeddedNodeWreqHelperScript -} - -func nodeWreqLoginHelperScript() string { - return embeddedNodeWreqLoginHelperScript -} diff --git a/internal/app/request_dispatch.go b/internal/app/request_dispatch.go index 5b364bd..b8f152f 100644 --- a/internal/app/request_dispatch.go +++ b/internal/app/request_dispatch.go @@ -18,7 +18,7 @@ const ( var errDispatchCapacityExceeded = errors.New("dispatch capacity exceeded") -var wreqClientNewTotalMetric = expvar.NewMap("notion2api_wreq_client_new_total") +var transportClientNewTotalMetric = expvar.NewMap("notion2api_transport_client_new_total") type probeCacheEntry struct { lastChecked time.Time diff --git a/internal/app/session_refresh.go b/internal/app/session_refresh.go index 770fb36..013b981 100644 --- a/internal/app/session_refresh.go +++ b/internal/app/session_refresh.go @@ -86,7 +86,7 @@ func loadSessionInfoForAccountRefresh(cfg AppConfig, account NotionAccount) (Ses func buildRefreshedSession(ctx context.Context, cfg AppConfig, account NotionAccount, prior SessionInfo) (SessionInfo, error) { upstream := cfg.NotionUpstream() resolver := NewProxyResolver(cfg) - session, err := newNotionLoginSession(helperTimeout(cfg), upstream, resolver, account.Email) + session, err := newNotionLoginSession(helperTimeout(cfg), upstream, resolver, account.Email, cfg) if err != nil { return SessionInfo{}, err } diff --git a/internal/wreq/doc.go b/internal/wreq/doc.go deleted file mode 100644 index 614a60b..0000000 --- a/internal/wreq/doc.go +++ /dev/null @@ -1 +0,0 @@ -package wreq diff --git a/internal/wreq/wreq_cgo.go b/internal/wreq/wreq_cgo.go deleted file mode 100644 index 
82e9ea3..0000000 --- a/internal/wreq/wreq_cgo.go +++ /dev/null @@ -1,288 +0,0 @@ -//go:build wreq_ffi - -package wreq - -/* -#cgo CFLAGS: -I${SRCDIR}/../../wreq-ffi/include -#cgo windows LDFLAGS: ${SRCDIR}/../../wreq-ffi/target/release/libwreq_ffi.a -#cgo !windows LDFLAGS: ${SRCDIR}/../../wreq-ffi/target/release/libwreq_ffi.a -ldl -lm -lpthread - -#include -#include "wreq_ffi.h" - -// Keep explicit forward declarations here so Go-side symbol binding remains stable -// even when a local generated header is stale/missing newer prototypes. -typedef struct WreqResponseHandle WreqResponseHandle; -int32_t wreq_request_begin(struct WreqClient *client, - const uint8_t *spec_json, - size_t spec_len, - const uint8_t *body_ptr, - size_t body_len, - struct WreqResponseHandle **out_handle, - uint16_t *out_status, - char **out_headers_json, - char **out_final_url, - char **out_error); -intptr_t wreq_response_read(struct WreqResponseHandle *handle, - uint8_t *buf, - size_t cap, - uint32_t timeout_ms); -void wreq_response_close(struct WreqResponseHandle *handle); -*/ -import "C" - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "runtime" - "strings" - "sync/atomic" - "unsafe" -) - -type ClientConfig struct { - Emulation string `json:"emulation,omitempty"` - TimeoutSecs uint64 `json:"timeout_secs,omitempty"` - CookieStore *bool `json:"cookie_store,omitempty"` - ProxyURL string `json:"proxy_url,omitempty"` - AcceptInvalidCerts *bool `json:"accept_invalid_certs,omitempty"` -} - -type RequestSpec struct { - Method string `json:"method"` - URL string `json:"url"` - Headers [][]string `json:"headers,omitempty"` - Body []byte `json:"-"` - TimeoutSecs uint64 `json:"timeout_secs,omitempty"` -} - -type Response struct { - Status int - Headers [][]string - FinalURL string - - handle *C.struct_WreqResponseHandle - closed atomic.Bool -} - -const ( - wreqOK int32 = 0 - wreqErrNilArg int32 = -1 - wreqErrTimeout int32 = -9 -) - -func errorFromCode(where string, code int32, 
detail string) error { - if code == wreqOK { - return nil - } - if trimmed := strings.TrimSpace(detail); trimmed != "" { - return fmt.Errorf("wreq: %s failed (code=%d): %s", where, code, trimmed) - } - return fmt.Errorf("wreq: %s failed (code=%d)", where, code) -} - -type Client struct { - handle *C.struct_WreqClient - closed atomic.Bool -} - -func New(cfg ClientConfig) (*Client, error) { - profile, err := json.Marshal(cfg) - if err != nil { - return nil, fmt.Errorf("wreq: marshal config: %w", err) - } - cProfile := C.CString(string(profile)) - defer C.free(unsafe.Pointer(cProfile)) - - handle := C.wreq_client_new(cProfile) - if handle == nil { - return nil, errors.New("wreq: wreq_client_new returned NULL (bad config?)") - } - c := &Client{handle: handle} - runtime.SetFinalizer(c, func(c *Client) { _ = c.Close() }) - return c, nil -} - -func (c *Client) Close() error { - if c == nil || !c.closed.CompareAndSwap(false, true) { - return nil - } - C.wreq_client_free(c.handle) - c.handle = nil - runtime.SetFinalizer(c, nil) - return nil -} - -func (c *Client) Begin(ctx context.Context, spec RequestSpec) (*Response, error) { - if c == nil || c.handle == nil || c.closed.Load() { - return nil, errors.New("wreq: client closed") - } - if err := ctx.Err(); err != nil { - return nil, err - } - - specPayload := struct { - Method string `json:"method"` - URL string `json:"url"` - Headers [][]string `json:"headers,omitempty"` - TimeoutSecs uint64 `json:"timeout_secs,omitempty"` - }{ - Method: spec.Method, - URL: spec.URL, - Headers: spec.Headers, - TimeoutSecs: spec.TimeoutSecs, - } - specJSON, err := json.Marshal(specPayload) - if err != nil { - return nil, fmt.Errorf("wreq: marshal request: %w", err) - } - - var specPtr *C.uint8_t - if len(specJSON) > 0 { - specPtr = (*C.uint8_t)(unsafe.Pointer(&specJSON[0])) - } - var bodyPtr *C.uint8_t - if len(spec.Body) > 0 { - bodyPtr = (*C.uint8_t)(unsafe.Pointer(&spec.Body[0])) - } - - var cHandle *C.struct_WreqResponseHandle - var 
cStatus C.uint16_t - var cHeaders *C.char - var cFinalURL *C.char - var cErr *C.char - - code := int32(C.wreq_request_begin( - c.handle, - specPtr, - C.size_t(len(specJSON)), - bodyPtr, - C.size_t(len(spec.Body)), - &cHandle, - &cStatus, - &cHeaders, - &cFinalURL, - &cErr, - )) - - var detail string - if cErr != nil { - detail = C.GoString(cErr) - C.wreq_string_free(cErr) - } - if code != wreqOK { - if cHeaders != nil { - C.wreq_string_free(cHeaders) - } - if cFinalURL != nil { - C.wreq_string_free(cFinalURL) - } - return nil, errorFromCode("wreq_request_begin", code, detail) - } - if cHandle == nil { - if cHeaders != nil { - C.wreq_string_free(cHeaders) - } - if cFinalURL != nil { - C.wreq_string_free(cFinalURL) - } - return nil, errors.New("wreq: begin returned nil response handle") - } - - resp := &Response{ - Status: int(cStatus), - handle: cHandle, - } - if cHeaders != nil { - headersJSON := C.GoString(cHeaders) - C.wreq_string_free(cHeaders) - if strings.TrimSpace(headersJSON) != "" { - if err := json.Unmarshal([]byte(headersJSON), &resp.Headers); err != nil { - _ = resp.Close() - return nil, fmt.Errorf("wreq: decode headers json: %w", err) - } - } - } - if cFinalURL != nil { - resp.FinalURL = C.GoString(cFinalURL) - C.wreq_string_free(cFinalURL) - } - - runtime.SetFinalizer(resp, func(r *Response) { _ = r.Close() }) - return resp, nil -} - -func (r *Response) Read(p []byte) (int, error) { - if r == nil { - return 0, errors.New("wreq: response is nil") - } - if r.handle == nil { - return 0, io.EOF - } - if len(p) == 0 { - return 0, nil - } - n := int64(C.wreq_response_read( - r.handle, - (*C.uint8_t)(unsafe.Pointer(&p[0])), - C.size_t(len(p)), - C.uint32_t(0), - )) - if n > 0 { - return int(n), nil - } - if n == 0 { - return 0, io.EOF - } - code := int32(n) - if code == wreqErrTimeout { - return 0, context.DeadlineExceeded - } - if code == wreqErrNilArg { - return 0, errors.New("wreq: invalid read argument") - } - return 0, errorFromCode("wreq_response_read", 
code, "") -} - -func (r *Response) Close() error { - if r == nil || !r.closed.CompareAndSwap(false, true) { - return nil - } - if r.handle != nil { - C.wreq_response_close(r.handle) - r.handle = nil - } - runtime.SetFinalizer(r, nil) - return nil -} - -func (r *Response) Body() ([]byte, error) { - if r == nil { - return nil, errors.New("wreq: response is nil") - } - body, err := io.ReadAll(r) - if err != nil && !errors.Is(err, io.EOF) { - _ = r.Close() - return nil, err - } - _ = r.Close() - return body, nil -} - -func (c *Client) Do(ctx context.Context, spec RequestSpec) (*Response, error) { - resp, err := c.Begin(ctx, spec) - if err != nil { - return nil, err - } - if _, err := resp.Body(); err != nil { - return nil, err - } - return resp, nil -} - -func Version() string { - return C.GoString(C.wreq_ffi_version()) -} diff --git a/internal/wreq/wreq_ffi_compat.h b/internal/wreq/wreq_ffi_compat.h deleted file mode 100644 index ebbbb73..0000000 --- a/internal/wreq/wreq_ffi_compat.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef WREQ_FFI_COMPAT_H -#define WREQ_FFI_COMPAT_H - -#include -#include - -typedef struct WreqClient WreqClient; -typedef struct WreqResponseHandle WreqResponseHandle; - -int32_t wreq_request_begin(struct WreqClient *client, - const uint8_t *spec_json, - size_t spec_len, - const uint8_t *body_ptr, - size_t body_len, - struct WreqResponseHandle **out_handle, - uint16_t *out_status, - char **out_headers_json, - char **out_final_url, - char **out_error); -intptr_t wreq_response_read(struct WreqResponseHandle *handle, - uint8_t *buf, - size_t cap, - uint32_t timeout_ms); -void wreq_response_close(struct WreqResponseHandle *handle); - -#endif /* WREQ_FFI_COMPAT_H */ diff --git a/internal/wreq/wreq_streaming_stub_test.go b/internal/wreq/wreq_streaming_stub_test.go deleted file mode 100644 index e262a84..0000000 --- a/internal/wreq/wreq_streaming_stub_test.go +++ /dev/null @@ -1,30 +0,0 @@ -//go:build !wreq_ffi - -package wreq - -import ( - "errors" - "io" - 
"testing" -) - -func TestWreqStubBeginNotLinked(t *testing.T) { - client, err := New(ClientConfig{}) - if err == nil || client != nil { - t.Fatalf("expected stub New to fail with ErrNotLinked") - } - if !errors.Is(err, ErrNotLinked) { - t.Fatalf("expected ErrNotLinked, got %v", err) - } -} - -func TestWreqStubResponseReadEOF(t *testing.T) { - var r Response - n, err := r.Read(make([]byte, 16)) - if n != 0 { - t.Fatalf("expected n=0, got %d", n) - } - if !errors.Is(err, io.EOF) { - t.Fatalf("expected io.EOF, got %v", err) - } -} diff --git a/internal/wreq/wreq_streaming_test.go b/internal/wreq/wreq_streaming_test.go deleted file mode 100644 index 8b99e2b..0000000 --- a/internal/wreq/wreq_streaming_test.go +++ /dev/null @@ -1,85 +0,0 @@ -package wreq - -import ( - "bufio" - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" -) - -func ndjsonLine(payload map[string]any) string { - raw, _ := json.Marshal(payload) - return string(raw) + "\n" -} - -func TestWreqStreamingLatencyShapeWithHTTPFallback(t *testing.T) { - firstWriteCh := make(chan time.Time, 1) - - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/x-ndjson") - flusher, _ := w.(http.Flusher) - firstWriteAt := time.Now() - firstWriteCh <- firstWriteAt - _, _ = w.Write([]byte(ndjsonLine(map[string]any{ - "type": "agent-inference", - "value": []map[string]any{ - {"type": "text", "content": "chunk-1"}, - }, - }))) - if flusher != nil { - flusher.Flush() - } - - time.Sleep(100 * time.Millisecond) - - _, _ = w.Write([]byte(ndjsonLine(map[string]any{ - "type": "agent-inference", - "value": []map[string]any{ - {"type": "text", "content": "chunk-2"}, - }, - "finishedAt": time.Now().Format(time.RFC3339Nano), - }))) - if flusher != nil { - flusher.Flush() - } - })) - defer upstream.Close() - - req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, upstream.URL, 
strings.NewReader("{}")) - if err != nil { - t.Fatalf("new request: %v", err) - } - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("http do: %v", err) - } - defer resp.Body.Close() - - reader := bufio.NewReader(resp.Body) - line1, err := reader.ReadString('\n') - if err != nil { - t.Fatalf("read first line: %v", err) - } - if strings.TrimSpace(line1) == "" { - t.Fatalf("first line is empty") - } - firstReadAt := time.Now() - firstWriteAt := <-firstWriteCh - - delay := firstReadAt.Sub(firstWriteAt) - if delay > 50*time.Millisecond { - t.Fatalf("first chunk delay too high: got %s want <= 50ms", delay) - } - - line2, err := reader.ReadString('\n') - if err != nil { - t.Fatalf("read second line: %v", err) - } - if !strings.Contains(line2, "chunk-2") { - t.Fatalf("unexpected second line: %q", line2) - } -} diff --git a/internal/wreq/wreq_stub.go b/internal/wreq/wreq_stub.go deleted file mode 100644 index ab5910c..0000000 --- a/internal/wreq/wreq_stub.go +++ /dev/null @@ -1,53 +0,0 @@ -//go:build !wreq_ffi - -package wreq - -import ( - "context" - "errors" - "io" -) - -var ErrNotLinked = errors.New("wreq: built without wreq_ffi tag; use node-wreq fallback") - -type ClientConfig struct { - Emulation string `json:"emulation,omitempty"` - TimeoutSecs uint64 `json:"timeout_secs,omitempty"` - CookieStore *bool `json:"cookie_store,omitempty"` - ProxyURL string `json:"proxy_url,omitempty"` - AcceptInvalidCerts *bool `json:"accept_invalid_certs,omitempty"` -} - -type RequestSpec struct { - Method string `json:"method"` - URL string `json:"url"` - Headers [][]string `json:"headers,omitempty"` - Body []byte `json:"-"` - TimeoutSecs uint64 `json:"timeout_secs,omitempty"` -} - -type Response struct { - Status int - Headers [][]string - FinalURL string -} - -func (r *Response) Read(_ []byte) (int, error) { return 0, io.EOF } -func (r *Response) Body() ([]byte, error) { return nil, ErrNotLinked } -func (r *Response) Close() error { return nil } - -type Client 
struct{} - -func New(_ ClientConfig) (*Client, error) { return nil, ErrNotLinked } - -func (c *Client) Close() error { return nil } - -func (c *Client) Begin(_ context.Context, _ RequestSpec) (*Response, error) { - return nil, ErrNotLinked -} - -func (c *Client) Do(_ context.Context, _ RequestSpec) (*Response, error) { - return nil, ErrNotLinked -} - -func Version() string { return "unlinked" } diff --git a/wreq-ffi/.gitignore b/wreq-ffi/.gitignore deleted file mode 100644 index bf78f89..0000000 --- a/wreq-ffi/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -/target -/include/wreq_ffi.h -Cargo.lock.bak diff --git a/wreq-ffi/Cargo.toml b/wreq-ffi/Cargo.toml deleted file mode 100644 index 0e785b1..0000000 --- a/wreq-ffi/Cargo.toml +++ /dev/null @@ -1,31 +0,0 @@ -[package] -name = "wreq-ffi" -version = "0.1.0" -edition = "2021" -publish = false -license = "Apache-2.0" -description = "C ABI shim around the wreq Rust HTTP client for embedding into Go (notion2api v2)" - -[lib] -name = "wreq_ffi" -crate-type = ["staticlib", "rlib"] - -[dependencies] -wreq = "5.3" -tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync", "time"] } -serde = { version = "1", features = ["derive"] } -serde_json = "1" - -[build-dependencies] -cbindgen = "0.27" - -[profile.release] -opt-level = 3 -lto = "thin" -codegen-units = 16 -strip = "symbols" -panic = "abort" - -[profile.dev] -opt-level = 0 -debug = 1 diff --git a/wreq-ffi/README.md b/wreq-ffi/README.md deleted file mode 100644 index 4c6b9b5..0000000 --- a/wreq-ffi/README.md +++ /dev/null @@ -1,97 +0,0 @@ -# wreq-ffi - -C ABI shim around the [`wreq`](https://crates.io/crates/wreq) Rust HTTP client, -built as a `staticlib` for embedding into the Go `notion2api` binary via cgo. -Replaces the v1 Node.js helper subprocess (`feat/node-wreq-fallback`) with a -single statically-linked process. 
- -## Build - -```sh -cargo build --release -# produces target/release/libwreq_ffi.a + include/wreq_ffi.h -``` - -For production Linux containers, build against musl for a fully static binary: - -```sh -rustup target add x86_64-unknown-linux-musl -cargo build --release --target x86_64-unknown-linux-musl -``` - -## C surface - -See `include/wreq_ffi.h` (regenerated by `build.rs` via `cbindgen`). - -| Function | Purpose | -| --------------------- | ------------------------------------------------------ | -| `wreq_client_new` | Build a client from a JSON profile. | -| `wreq_client_free` | Drop a client. | -| `wreq_request_begin` | Start one request and return status/headers/body handle. | -| `wreq_response_read` | Pull next body chunk into caller-provided buffer. | -| `wreq_response_close` | Release response handle resources. | -| `wreq_string_free` | Free any `*mut c_char` from this lib. | -| `wreq_ffi_version` | Linkage sanity check. | - -### Memory ownership - -Every pointer returned from this crate must be returned to it for -deallocation. **Do not** call `free()` from Go / C — allocators across the -FFI boundary are not interchangeable. - -## JSON shapes - -### Client config (input to `wreq_client_new`) - -```json -{ - "emulation": "chrome131", - "timeout_secs": 30, - "cookie_store": true, - "proxy_url": "http://127.0.0.1:7890", - "accept_invalid_certs": false -} -``` - -`emulation` is currently a no-op placeholder; will be wired to `wreq-util` in a -follow-up commit (tracking: notion2api#v2-1). - -### Request spec (input to `wreq_request_begin` as JSON bytes) - -```json -{ - "method": "POST", - "url": "https://www.notion.so/api/v3/getLoginOptions", - "headers": [["content-type", "application/json"], ["user-agent", "..."]], - "timeout_secs": 60 -} -``` - -Request body bytes are passed separately as raw binary (`body_ptr` + `body_len`) and are **not** base64-encoded. 
- -### Begin outputs - -On success (`wreq_request_begin` returns `0`), caller receives: - -- `out_status` (HTTP status code), -- `out_headers_json` (JSON encoded `[[name, value], ...]` header list), -- `out_final_url` (effective URL after redirects), -- `out_handle` (opaque response body stream handle). - -### Streaming read contract - -- `wreq_response_read(handle, buf, cap, timeout_ms)`: - - `>0`: bytes written into `buf` - - `0`: EOF - - `<0`: error code -- `wreq_response_close(handle)` must be called once when done. - -### Error reporting - -For `wreq_request_begin`, `out_error` may contain human-readable details (free via `wreq_string_free`) when return code is non-zero. - -## Threading - -A single process-global Tokio multi-thread runtime is created lazily on first -request. cgo callers may invoke `wreq_request` concurrently from many -goroutines. diff --git a/wreq-ffi/build.rs b/wreq-ffi/build.rs deleted file mode 100644 index c8ef2ac..0000000 --- a/wreq-ffi/build.rs +++ /dev/null @@ -1,71 +0,0 @@ - -use std::env; -use std::path::PathBuf; - -fn main() { - let crate_dir = env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR set by cargo"); - let crate_dir = PathBuf::from(crate_dir); - let lib_rs = crate_dir.join("src").join("lib.rs"); - let header_path = crate_dir.join("include").join("wreq_ffi.h"); - - if let Some(parent) = header_path.parent() { - std::fs::create_dir_all(parent).expect("create include dir"); - } - - println!("cargo:rerun-if-changed=src/lib.rs"); - println!("cargo:rerun-if-changed=cbindgen.toml"); - println!("cargo:rerun-if-changed=build.rs"); - - let config = cbindgen::Config::from_file(crate_dir.join("cbindgen.toml")) - .unwrap_or_else(|_| cbindgen::Config::default()); - - let bindings = cbindgen::Builder::new() - .with_src(&lib_rs) - .with_config(config) - .generate() - .unwrap_or_else(|e| { - panic!( - "wreq-ffi: cbindgen failed to generate header from {}: {e}", - lib_rs.display() - ) - }); - - bindings.write_to_file(&header_path); 
- - // Keep repository-local compatibility for cgo compile checks in environments - // where include/wreq_ffi.h is ignored by git and cargo build is not runnable. - if let Ok(workspace_root) = crate_dir.parent().map(std::path::Path::to_path_buf).ok_or(()) { - let compat_header = workspace_root.join("internal").join("wreq").join("wreq_ffi_compat.h"); - let compat_body = r#"#ifndef WREQ_FFI_COMPAT_H -#define WREQ_FFI_COMPAT_H - -#include -#include - -typedef struct WreqClient WreqClient; -typedef struct WreqResponseHandle WreqResponseHandle; - -int32_t wreq_request_begin(struct WreqClient *client, - const uint8_t *spec_json, - size_t spec_len, - const uint8_t *body_ptr, - size_t body_len, - struct WreqResponseHandle **out_handle, - uint16_t *out_status, - char **out_headers_json, - char **out_final_url, - char **out_error); -intptr_t wreq_response_read(struct WreqResponseHandle *handle, - uint8_t *buf, - size_t cap, - uint32_t timeout_ms); -void wreq_response_close(struct WreqResponseHandle *handle); - -#endif /* WREQ_FFI_COMPAT_H */ -"#; - if let Some(parent) = compat_header.parent() { - let _ = std::fs::create_dir_all(parent); - } - let _ = std::fs::write(compat_header, compat_body); - } -} diff --git a/wreq-ffi/cbindgen.toml b/wreq-ffi/cbindgen.toml deleted file mode 100644 index b3effa2..0000000 --- a/wreq-ffi/cbindgen.toml +++ /dev/null @@ -1,17 +0,0 @@ -language = "C" -include_guard = "WREQ_FFI_H" -autogen_warning = "/* Generated by cbindgen. Do not edit by hand. 
*/" -style = "both" -cpp_compat = true - -[parse] -parse_deps = false - -[export] -prefix = "" - -[fn] -prefix = "" - -[enum] -prefix_with_name = true diff --git a/wreq-ffi/src/lib.rs b/wreq-ffi/src/lib.rs deleted file mode 100644 index a86fe8b..0000000 --- a/wreq-ffi/src/lib.rs +++ /dev/null @@ -1,374 +0,0 @@ -use std::ffi::{c_char, CStr, CString}; -use std::panic::{catch_unwind, AssertUnwindSafe}; -use std::ptr; -use std::slice; -use std::sync::{Mutex, OnceLock}; -use std::time::Duration; - -use serde::Deserialize; -use tokio::runtime::Runtime; - -const WREQ_OK: i32 = 0; -const WREQ_ERR_NIL_ARG: i32 = -1; -const WREQ_ERR_BAD_UTF8: i32 = -2; -const WREQ_ERR_BAD_JSON: i32 = -3; -const WREQ_ERR_BAD_METHOD: i32 = -4; -const WREQ_ERR_SEND: i32 = -5; -const WREQ_ERR_HEADERS_JSON: i32 = -6; -const WREQ_ERR_BODY_READ: i32 = -7; -const WREQ_ERR_CLOSED: i32 = -8; -const WREQ_ERR_TIMEOUT: i32 = -9; -const WREQ_ERR_PANIC: i32 = -100; - -static RUNTIME: OnceLock = OnceLock::new(); - -fn runtime() -> &'static Runtime { - RUNTIME.get_or_init(|| { - tokio::runtime::Builder::new_multi_thread() - .enable_all() - .thread_name("wreq-ffi") - .build() - .expect("wreq-ffi: failed to create tokio runtime") - }) -} - -pub struct WreqClient { - inner: wreq::Client, -} - -struct WreqResponseState { - response: Option, - pending: Vec, -} - -pub struct WreqResponseHandle { - state: Mutex, -} - -#[derive(Default, Deserialize)] -struct ClientConfig { - #[serde(default)] - emulation: Option, - #[serde(default)] - timeout_secs: Option, - #[serde(default)] - proxy_url: Option, -} - -#[derive(Deserialize)] -struct RequestSpec { - method: String, - url: String, - #[serde(default)] - headers: Vec<(String, String)>, - #[serde(default)] - timeout_secs: Option, -} - -#[inline] -fn c_string_from_message(message: impl AsRef) -> *mut c_char { - match CString::new(message.as_ref()) { - Ok(text) => text.into_raw(), - Err(_) => ptr::null_mut(), - } -} - -#[inline] -unsafe fn clear_out_error(out_error: *mut 
*mut c_char) { - if !out_error.is_null() { - *out_error = ptr::null_mut(); - } -} - -#[inline] -unsafe fn set_out_error(out_error: *mut *mut c_char, message: impl AsRef) { - if out_error.is_null() { - return; - } - *out_error = c_string_from_message(message); -} - -#[inline] -unsafe fn clear_begin_outputs( - out_handle: *mut *mut WreqResponseHandle, - out_status: *mut u16, - out_headers_json: *mut *mut c_char, - out_final_url: *mut *mut c_char, -) { - if !out_handle.is_null() { - *out_handle = ptr::null_mut(); - } - if !out_status.is_null() { - *out_status = 0; - } - if !out_headers_json.is_null() { - *out_headers_json = ptr::null_mut(); - } - if !out_final_url.is_null() { - *out_final_url = ptr::null_mut(); - } -} - -#[no_mangle] -pub unsafe extern "C" fn wreq_client_new(profile_json: *const c_char) -> *mut WreqClient { - catch_unwind(AssertUnwindSafe(|| { - let cfg: ClientConfig = if profile_json.is_null() { - ClientConfig::default() - } else { - let s = match CStr::from_ptr(profile_json).to_str() { - Ok(s) => s, - Err(_) => return ptr::null_mut::(), - }; - match serde_json::from_str(s) { - Ok(c) => c, - Err(_) => return ptr::null_mut(), - } - }; - - let mut builder = wreq::Client::builder(); - if let Some(secs) = cfg.timeout_secs { - builder = builder.timeout(Duration::from_secs(secs)); - } - if let Some(p) = cfg.proxy_url.as_deref() { - if let Ok(proxy) = wreq::Proxy::all(p) { - builder = builder.proxy(proxy); - } - } - let _ = cfg.emulation; - - match builder.build() { - Ok(client) => Box::into_raw(Box::new(WreqClient { inner: client })), - Err(_) => ptr::null_mut(), - } - })) - .unwrap_or(ptr::null_mut()) -} - -#[no_mangle] -pub unsafe extern "C" fn wreq_client_free(client: *mut WreqClient) { - if !client.is_null() { - drop(Box::from_raw(client)); - } -} - -#[no_mangle] -pub unsafe extern "C" fn wreq_request_begin( - client: *mut WreqClient, - spec_json: *const u8, - spec_len: usize, - body_ptr: *const u8, - body_len: usize, - out_handle: *mut *mut 
WreqResponseHandle, - out_status: *mut u16, - out_headers_json: *mut *mut c_char, - out_final_url: *mut *mut c_char, - out_error: *mut *mut c_char, -) -> i32 { - clear_out_error(out_error); - clear_begin_outputs(out_handle, out_status, out_headers_json, out_final_url); - - let result = catch_unwind(AssertUnwindSafe(|| { - if client.is_null() - || out_handle.is_null() - || out_status.is_null() - || out_headers_json.is_null() - || out_final_url.is_null() - { - set_out_error(out_error, "nil client or output pointer"); - return WREQ_ERR_NIL_ARG; - } - if spec_json.is_null() && spec_len > 0 { - set_out_error(out_error, "spec_json is null but spec_len > 0"); - return WREQ_ERR_NIL_ARG; - } - if body_ptr.is_null() && body_len > 0 { - set_out_error(out_error, "body_ptr is null but body_len > 0"); - return WREQ_ERR_NIL_ARG; - } - - let spec_bytes = slice::from_raw_parts(spec_json, spec_len); - let spec: RequestSpec = match serde_json::from_slice(spec_bytes) { - Ok(value) => value, - Err(err) => { - set_out_error(out_error, format!("request_json: {err}")); - return WREQ_ERR_BAD_JSON; - } - }; - - let method = match spec.method.parse::() { - Ok(method) => method, - Err(err) => { - set_out_error(out_error, format!("bad method: {err}")); - return WREQ_ERR_BAD_METHOD; - } - }; - - let client = &*client; - let mut req = client.inner.request(method, &spec.url); - for (key, value) in &spec.headers { - req = req.header(key.as_str(), value.as_str()); - } - if let Some(secs) = spec.timeout_secs { - req = req.timeout(Duration::from_secs(secs)); - } - if body_len > 0 { - let body = slice::from_raw_parts(body_ptr, body_len); - req = req.body(body.to_vec()); - } - - let resp = match runtime().block_on(req.send()) { - Ok(response) => response, - Err(err) => { - set_out_error(out_error, format!("send: {err}")); - return WREQ_ERR_SEND; - } - }; - - let status = resp.status().as_u16(); - let final_url = resp.url().to_string(); - let mut headers: Vec<(String, String)> = 
Vec::with_capacity(resp.headers().len()); - for (key, value) in resp.headers().iter() { - headers.push(( - key.as_str().to_string(), - value.to_str().unwrap_or("").to_string(), - )); - } - - let headers_json = match serde_json::to_string(&headers) { - Ok(raw) => raw, - Err(err) => { - set_out_error(out_error, format!("headers json encode failed: {err}")); - return WREQ_ERR_HEADERS_JSON; - } - }; - let headers_c = match CString::new(headers_json) { - Ok(value) => value, - Err(_) => { - set_out_error(out_error, "headers json contains interior NUL"); - return WREQ_ERR_HEADERS_JSON; - } - }; - let final_url_c = match CString::new(final_url) { - Ok(value) => value, - Err(_) => { - set_out_error(out_error, "final_url contains interior NUL"); - return WREQ_ERR_BAD_UTF8; - } - }; - - let handle = Box::new(WreqResponseHandle { - state: Mutex::new(WreqResponseState { - response: Some(resp), - pending: Vec::new(), - }), - }); - - *out_handle = Box::into_raw(handle); - *out_status = status; - *out_headers_json = headers_c.into_raw(); - *out_final_url = final_url_c.into_raw(); - WREQ_OK - })); - - match result { - Ok(code) => code, - Err(_) => { - set_out_error(out_error, "rust panic in wreq_request_begin"); - WREQ_ERR_PANIC - } - } -} - -#[no_mangle] -pub unsafe extern "C" fn wreq_response_read( - handle: *mut WreqResponseHandle, - buf: *mut u8, - cap: usize, - timeout_ms: u32, -) -> isize { - let result = catch_unwind(AssertUnwindSafe(|| { - if handle.is_null() { - return WREQ_ERR_NIL_ARG as isize; - } - if cap == 0 { - return 0; - } - if buf.is_null() { - return WREQ_ERR_NIL_ARG as isize; - } - - let handle = &*handle; - let mut state = match handle.state.lock() { - Ok(guard) => guard, - Err(poisoned) => poisoned.into_inner(), - }; - - if !state.pending.is_empty() { - let write_len = state.pending.len().min(cap); - ptr::copy_nonoverlapping(state.pending.as_ptr(), buf, write_len); - state.pending.drain(..write_len); - return write_len as isize; - } - - let response = match 
state.response.as_mut() { - Some(response) => response, - None => return WREQ_ERR_CLOSED as isize, - }; - - let chunk = if timeout_ms == 0 { - runtime().block_on(response.chunk()) - } else { - match runtime().block_on(tokio::time::timeout( - Duration::from_millis(timeout_ms as u64), - response.chunk(), - )) { - Ok(res) => res, - Err(_) => return WREQ_ERR_TIMEOUT as isize, - } - }; - - match chunk { - Ok(Some(bytes)) => { - let raw = bytes.as_ref(); - let write_len = raw.len().min(cap); - ptr::copy_nonoverlapping(raw.as_ptr(), buf, write_len); - if write_len < raw.len() { - state.pending.extend_from_slice(&raw[write_len..]); - } - write_len as isize - } - Ok(None) => 0, - Err(_) => WREQ_ERR_BODY_READ as isize, - } - })); - - match result { - Ok(n) => n, - Err(_) => WREQ_ERR_PANIC as isize, - } -} - -#[no_mangle] -pub unsafe extern "C" fn wreq_response_close(handle: *mut WreqResponseHandle) { - if handle.is_null() { - return; - } - let mut boxed = Box::from_raw(handle); - if let Ok(state) = boxed.state.get_mut() { - state.pending.clear(); - state.response = None; - } -} - -#[no_mangle] -pub unsafe extern "C" fn wreq_string_free(ptr: *mut c_char) { - if !ptr.is_null() { - drop(CString::from_raw(ptr)); - } -} - -#[no_mangle] -pub extern "C" fn wreq_ffi_version() -> *const c_char { - static VERSION: &[u8] = concat!(env!("CARGO_PKG_VERSION"), "\0").as_bytes(); - VERSION.as_ptr() as *const c_char -}