diff --git a/cmd/agentledger/serve.go b/cmd/agentledger/serve.go index 41fa9ca..13031b8 100644 --- a/cmd/agentledger/serve.go +++ b/cmd/agentledger/serve.go @@ -2,6 +2,7 @@ package main import ( "context" + "crypto/tls" "database/sql" "fmt" "log/slog" @@ -110,6 +111,7 @@ func runServe(configPath string) error { }) } budgetMgr = budget.NewManager(store, budgetCfg, logger) + defer budgetMgr.Close() if budgetMgr.Enabled() { logger.Info("budget enforcement enabled") } @@ -269,6 +271,7 @@ func runServe(configPath string) error { }) } limiter = ratelimit.New(rlCfg) + defer limiter.Close() if limiter.Enabled() { logger.Info("rate limiting enabled", "default_rpm", rlCfg.Default.RequestsPerMinute, @@ -323,29 +326,38 @@ func runServe(configPath string) error { // Admin API (optional). if adminStore != nil && cfg.Admin.Token != "" { - adminHandler := admin.NewHandler(adminStore, store, budgetMgr, cfg.Admin.Token, blocklist) + adminHandler := admin.NewHandler(adminStore, store, budgetMgr, cfg.Admin.Token, blocklist, logger) adminHandler.RegisterRoutes(mux) logger.Info("admin API enabled") } if cfg.Dashboard.Enabled { - dashHandler := dashboard.NewHandler(store, tracker) + dashHandler := dashboard.NewHandler(store, tracker, logger) dashHandler.RegisterRoutes(mux) mux.Handle("/", dashboard.StaticHandler()) logger.Info("dashboard enabled") } + // Apply CORS middleware. 
+ handler := corsMiddleware(cfg.CORS.AllowOrigins, mux) + srv := &http.Server{ Addr: cfg.Listen, - Handler: mux, + Handler: handler, ReadHeaderTimeout: 10 * time.Second, } // Graceful shutdown errCh := make(chan error, 1) go func() { - logger.Info("proxy listening", "addr", cfg.Listen) - errCh <- srv.ListenAndServe() + if cfg.TLS.CertFile != "" && cfg.TLS.KeyFile != "" { + srv.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} + logger.Info("starting HTTPS server", "listen", cfg.Listen) + errCh <- srv.ListenAndServeTLS(cfg.TLS.CertFile, cfg.TLS.KeyFile) + } else { + logger.Info("starting HTTP server", "listen", cfg.Listen) + errCh <- srv.ListenAndServe() + } }() quit := make(chan os.Signal, 1) @@ -372,6 +384,37 @@ func runServe(configPath string) error { return nil } +func corsMiddleware(origins []string, next http.Handler) http.Handler { + if len(origins) == 0 { + return next + } + + allowed := make(map[string]bool, len(origins)) + for _, o := range origins { + allowed[o] = true + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + if origin != "" && allowed[origin] { + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Vary", "Origin") + } + + if r.Method == http.MethodOptions { + if origin != "" && allowed[origin] { + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Requested-With") + w.Header().Set("Access-Control-Max-Age", "86400") + } + w.WriteHeader(http.StatusNoContent) + return + } + + next.ServeHTTP(w, r) + }) +} + func newLogger(cfg config.LogConfig) *slog.Logger { var level slog.Level switch cfg.Level { diff --git a/configs/agentledger.example.yaml b/configs/agentledger.example.yaml index f9e1d0a..1fe93fd 100644 --- a/configs/agentledger.example.yaml +++ b/configs/agentledger.example.yaml @@ -164,6 +164,17 @@ recording: # enabled: true # token: 
"your-secret-admin-token" # Bearer token for auth +# CORS (optional — omit to use same-origin only) +# cors: +# allow_origins: +# - "https://dashboard.example.com" +# - "https://admin.example.com" + +# TLS (optional — omit for plain HTTP) +# tls: +# cert_file: "/path/to/cert.pem" +# key_file: "/path/to/key.pem" + # MCP (Model Context Protocol) tool call metering (optional — omit to disable) # mcp: # enabled: true # enable HTTP proxy for MCP servers diff --git a/configs/demo.yaml b/configs/demo.yaml index 75f2aaa..ff7d06b 100644 --- a/configs/demo.yaml +++ b/configs/demo.yaml @@ -31,6 +31,7 @@ agent: admin: enabled: true + # WARNING: Demo token only. Generate a real one: openssl rand -hex 32 token: "demo-admin-token" budgets: diff --git a/docs/assets/favicon.svg b/docs/assets/favicon.svg new file mode 100644 index 0000000..1d2e125 --- /dev/null +++ b/docs/assets/favicon.svg @@ -0,0 +1,12 @@ + + + + + + + + + + + + diff --git a/docs/index.md b/docs/index.md index a04086d..a038597 100644 --- a/docs/index.md +++ b/docs/index.md @@ -9,7 +9,37 @@ hide:
-
AL
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
# AgentLedger diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css index dd10c8f..7bfe899 100644 --- a/docs/stylesheets/extra.css +++ b/docs/stylesheets/extra.css @@ -56,15 +56,8 @@ .al-logo-mark { display: inline-block; - font-family: "JetBrains Mono", monospace; - font-size: 1.4rem; - font-weight: 700; - color: #0a0e14; - background: #3fb950; - border-radius: 8px; - padding: 0.4rem 0.8rem; margin-bottom: 1.5rem; - letter-spacing: 0.05em; + line-height: 0; } .al-hero h1 { diff --git a/internal/admin/admin_test.go b/internal/admin/admin_test.go index 7c7fc8e..c2e7013 100644 --- a/internal/admin/admin_test.go +++ b/internal/admin/admin_test.go @@ -5,6 +5,8 @@ import ( "context" "database/sql" "encoding/json" + "io" + "log/slog" "net/http" "net/http/httptest" "testing" @@ -34,6 +36,10 @@ func setupTestDB(t *testing.T) *sql.DB { return db } +func testLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + func TestStore_GetSetDelete(t *testing.T) { db := setupTestDB(t) s := admin.NewStore(db) @@ -127,7 +133,7 @@ func TestStore_ListAll(t *testing.T) { func TestHandler_RequiresAuth(t *testing.T) { db := setupTestDB(t) store := admin.NewStore(db) - handler := admin.NewHandler(store, nil, nil, "secret-token", nil) + handler := admin.NewHandler(store, nil, nil, "secret-token", nil, testLogger()) mux := http.NewServeMux() handler.RegisterRoutes(mux) @@ -162,7 +168,7 @@ func TestHandler_RequiresAuth(t *testing.T) { func TestHandler_CRUDRules(t *testing.T) { db := setupTestDB(t) store := admin.NewStore(db) - handler := admin.NewHandler(store, nil, nil, "token", nil) + handler := admin.NewHandler(store, nil, nil, "token", nil, testLogger()) mux := http.NewServeMux() handler.RegisterRoutes(mux) @@ -185,6 +191,7 @@ func TestHandler_CRUDRules(t *testing.T) { body, _ := json.Marshal(rule) req = httptest.NewRequest("POST", "/api/admin/budgets/rules", bytes.NewReader(body)) auth(req) + req.Header.Set("X-Requested-With", "XMLHttpRequest") 
rec = httptest.NewRecorder() mux.ServeHTTP(rec, req) if rec.Code != http.StatusCreated { @@ -205,6 +212,7 @@ func TestHandler_CRUDRules(t *testing.T) { // Delete. req = httptest.NewRequest("DELETE", "/api/admin/budgets/rules?pattern=sk-prod-*", nil) auth(req) + req.Header.Set("X-Requested-With", "XMLHttpRequest") rec = httptest.NewRecorder() mux.ServeHTTP(rec, req) if rec.Code != http.StatusNoContent { @@ -223,16 +231,55 @@ func TestHandler_CRUDRules(t *testing.T) { } } +func TestHandler_CSRFProtection(t *testing.T) { + db := setupTestDB(t) + store := admin.NewStore(db) + handler := admin.NewHandler(store, nil, nil, "token", nil, testLogger()) + + mux := http.NewServeMux() + handler.RegisterRoutes(mux) + + // POST without X-Requested-With should be rejected. + rule := budget.Rule{APIKeyPattern: "sk-*", DailyLimitUSD: 10.0, Action: "block"} + body, _ := json.Marshal(rule) + req := httptest.NewRequest("POST", "/api/admin/budgets/rules", bytes.NewReader(body)) + req.Header.Set("Authorization", "Bearer token") + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + if rec.Code != http.StatusForbidden { + t.Fatalf("expected 403 without X-Requested-With, got %d", rec.Code) + } + + // DELETE without X-Requested-With should be rejected. + req = httptest.NewRequest("DELETE", "/api/admin/budgets/rules?pattern=sk-*", nil) + req.Header.Set("Authorization", "Bearer token") + rec = httptest.NewRecorder() + mux.ServeHTTP(rec, req) + if rec.Code != http.StatusForbidden { + t.Fatalf("expected 403 without X-Requested-With on DELETE, got %d", rec.Code) + } + + // GET without X-Requested-With should be allowed. 
+ req = httptest.NewRequest("GET", "/api/admin/budgets/rules", nil) + req.Header.Set("Authorization", "Bearer token") + rec = httptest.NewRecorder() + mux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200 for GET without X-Requested-With, got %d", rec.Code) + } +} + func TestHandler_DeleteNonExistent(t *testing.T) { db := setupTestDB(t) store := admin.NewStore(db) - handler := admin.NewHandler(store, nil, nil, "token", nil) + handler := admin.NewHandler(store, nil, nil, "token", nil, testLogger()) mux := http.NewServeMux() handler.RegisterRoutes(mux) req := httptest.NewRequest("DELETE", "/api/admin/budgets/rules?pattern=nonexistent", nil) req.Header.Set("Authorization", "Bearer token") + req.Header.Set("X-Requested-With", "XMLHttpRequest") rec := httptest.NewRecorder() mux.ServeHTTP(rec, req) if rec.Code != http.StatusNotFound { @@ -240,10 +287,36 @@ func TestHandler_DeleteNonExistent(t *testing.T) { } } +func TestHandler_BudgetStatus(t *testing.T) { + db := setupTestDB(t) + store := admin.NewStore(db) + handler := admin.NewHandler(store, nil, nil, "token", nil, testLogger()) + + mux := http.NewServeMux() + handler.RegisterRoutes(mux) + + // Budget status should return an empty array when no ledger is configured. 
+ req := httptest.NewRequest("GET", "/api/admin/budgets/status", nil) + req.Header.Set("Authorization", "Bearer token") + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } + + var statuses []json.RawMessage + if err := json.NewDecoder(rec.Body).Decode(&statuses); err != nil { + t.Fatal(err) + } + if len(statuses) != 0 { + t.Fatalf("expected 0 statuses with nil ledger, got %d", len(statuses)) + } +} + func TestHandler_NoToken(t *testing.T) { db := setupTestDB(t) store := admin.NewStore(db) - handler := admin.NewHandler(store, nil, nil, "", nil) + handler := admin.NewHandler(store, nil, nil, "", nil, testLogger()) mux := http.NewServeMux() handler.RegisterRoutes(mux) diff --git a/internal/admin/handlers.go b/internal/admin/handlers.go index 5f4d4f7..d5558be 100644 --- a/internal/admin/handlers.go +++ b/internal/admin/handlers.go @@ -1,7 +1,9 @@ package admin import ( + "crypto/subtle" "encoding/json" + "log/slog" "net/http" "time" @@ -16,16 +18,18 @@ type Handler struct { budgetMgr *budget.Manager token string // admin authentication token blocklist *Blocklist + logger *slog.Logger } // NewHandler creates an admin API handler. 
-func NewHandler(store *Store, l ledger.Ledger, budgetMgr *budget.Manager, token string, blocklist *Blocklist) *Handler { +func NewHandler(store *Store, l ledger.Ledger, budgetMgr *budget.Manager, token string, blocklist *Blocklist, logger *slog.Logger) *Handler { return &Handler{ store: store, ledger: l, budgetMgr: budgetMgr, token: token, blocklist: blocklist, + logger: logger, } } @@ -39,6 +43,7 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) { mux.HandleFunc("POST /api/admin/api-keys/block", h.requireAuth(h.handleBlockKey)) mux.HandleFunc("DELETE /api/admin/api-keys/block", h.requireAuth(h.handleUnblockKey)) mux.HandleFunc("GET /api/admin/providers", h.requireAuth(h.handleListProviders)) + mux.HandleFunc("GET /api/admin/budgets/status", h.requireAuth(h.handleBudgetStatus)) } func (h *Handler) requireAuth(next http.HandlerFunc) http.HandlerFunc { @@ -48,10 +53,17 @@ func (h *Handler) requireAuth(next http.HandlerFunc) http.HandlerFunc { return } auth := r.Header.Get("Authorization") - if auth != "Bearer "+h.token { + expected := "Bearer " + h.token + if subtle.ConstantTimeCompare([]byte(auth), []byte(expected)) != 1 { writeAdminError(w, http.StatusUnauthorized, "invalid admin token") return } + if r.Method != http.MethodGet { + if r.Header.Get("X-Requested-With") != "XMLHttpRequest" { + writeAdminError(w, http.StatusForbidden, "missing required header") + return + } + } next(w, r) } } @@ -60,7 +72,8 @@ func (h *Handler) requireAuth(next http.HandlerFunc) http.HandlerFunc { func (h *Handler) handleListRules(w http.ResponseWriter, r *http.Request) { var rules []budget.Rule if err := h.store.GetJSON(r.Context(), "budget_rules", &rules); err != nil { - writeAdminError(w, http.StatusInternalServerError, err.Error()) + h.logger.Error("listing budget rules", "error", err) + writeAdminError(w, http.StatusInternalServerError, "internal server error") return } writeAdminJSON(w, rules) @@ -80,7 +93,8 @@ func (h *Handler) handleCreateRule(w http.ResponseWriter, r 
*http.Request) { rules = append(rules, rule) if err := h.store.SetJSON(r.Context(), "budget_rules", rules); err != nil { - writeAdminError(w, http.StatusInternalServerError, err.Error()) + h.logger.Error("saving budget rules", "error", err) + writeAdminError(w, http.StatusInternalServerError, "internal server error") return } @@ -120,7 +134,8 @@ func (h *Handler) handleDeleteRule(w http.ResponseWriter, r *http.Request) { } if err := h.store.SetJSON(r.Context(), "budget_rules", filtered); err != nil { - writeAdminError(w, http.StatusInternalServerError, err.Error()) + h.logger.Error("deleting budget rule", "error", err) + writeAdminError(w, http.StatusInternalServerError, "internal server error") return } @@ -150,7 +165,8 @@ func (h *Handler) handleBlockKey(w http.ResponseWriter, r *http.Request) { patterns = append(patterns, req.Pattern) if err := h.store.SetJSON(r.Context(), "blocked_keys", patterns); err != nil { - writeAdminError(w, http.StatusInternalServerError, err.Error()) + h.logger.Error("blocking API key", "error", err) + writeAdminError(w, http.StatusInternalServerError, "internal server error") return } @@ -189,7 +205,8 @@ func (h *Handler) handleUnblockKey(w http.ResponseWriter, r *http.Request) { } if err := h.store.SetJSON(r.Context(), "blocked_keys", filtered); err != nil { - writeAdminError(w, http.StatusInternalServerError, err.Error()) + h.logger.Error("unblocking API key", "error", err) + writeAdminError(w, http.StatusInternalServerError, "internal server error") return } @@ -204,7 +221,8 @@ func (h *Handler) handleUnblockKey(w http.ResponseWriter, r *http.Request) { func (h *Handler) handleListBlocked(w http.ResponseWriter, r *http.Request) { var patterns []string if err := h.store.GetJSON(r.Context(), "blocked_keys", &patterns); err != nil { - writeAdminError(w, http.StatusInternalServerError, err.Error()) + h.logger.Error("listing blocked keys", "error", err) + writeAdminError(w, http.StatusInternalServerError, "internal server error") return 
} if patterns == nil { @@ -224,7 +242,8 @@ func (h *Handler) handleListAPIKeys(w http.ResponseWriter, r *http.Request) { GroupBy: "key", }) if err != nil { - writeAdminError(w, http.StatusInternalServerError, err.Error()) + h.logger.Error("listing API keys", "error", err) + writeAdminError(w, http.StatusInternalServerError, "internal server error") return } @@ -251,7 +270,8 @@ func (h *Handler) handleListProviders(w http.ResponseWriter, r *http.Request) { // Return from runtime config if available. var providers map[string]bool if err := h.store.GetJSON(r.Context(), "providers_enabled", &providers); err != nil { - writeAdminError(w, http.StatusInternalServerError, err.Error()) + h.logger.Error("listing providers", "error", err) + writeAdminError(w, http.StatusInternalServerError, "internal server error") return } if providers == nil { @@ -260,6 +280,75 @@ func (h *Handler) handleListProviders(w http.ResponseWriter, r *http.Request) { writeAdminJSON(w, providers) } +// BudgetStatus shows current utilization of a budget rule. +type BudgetStatus struct { + Pattern string `json:"pattern"` + DailySpent float64 `json:"daily_spent"` + DailyLimit float64 `json:"daily_limit"` + MonthlySpent float64 `json:"monthly_spent"` + MonthlyLimit float64 `json:"monthly_limit"` + Action string `json:"action"` +} + +// handleBudgetStatus returns budget utilization for all configured rules. +func (h *Handler) handleBudgetStatus(w http.ResponseWriter, r *http.Request) { + now := time.Now().UTC() + dayStart := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) + monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC) + + var rules []budget.Rule + _ = h.store.GetJSON(r.Context(), "budget_rules", &rules) + + var statuses []BudgetStatus + + // Per-key rules: aggregate spend for each key matching the pattern. + // Since we can't glob-match key hashes, we use total spend grouped by key. 
+ if h.ledger != nil { + dayEntries, _ := h.ledger.QueryCosts(r.Context(), ledger.CostFilter{ + Since: dayStart, Until: now, GroupBy: "key", + }) + monthEntries, _ := h.ledger.QueryCosts(r.Context(), ledger.CostFilter{ + Since: monthStart, Until: now, GroupBy: "key", + }) + + // Calculate total spend across all keys for the default rule. + var totalDailySpend, totalMonthlySpend float64 + for _, e := range dayEntries { + totalDailySpend += e.TotalCostUSD + } + for _, e := range monthEntries { + totalMonthlySpend += e.TotalCostUSD + } + + // Per-rule entries. + for _, rule := range rules { + statuses = append(statuses, BudgetStatus{ + Pattern: rule.APIKeyPattern, + DailySpent: totalDailySpend, + DailyLimit: rule.DailyLimitUSD, + MonthlySpent: totalMonthlySpend, + MonthlyLimit: rule.MonthlyLimitUSD, + Action: rule.Action, + }) + } + + // Default rule entry. + if h.budgetMgr != nil && h.budgetMgr.Enabled() { + statuses = append(statuses, BudgetStatus{ + Pattern: "(default)", + DailySpent: totalDailySpend, + MonthlySpent: totalMonthlySpend, + Action: "default", + }) + } + } + + if statuses == nil { + statuses = []BudgetStatus{} + } + writeAdminJSON(w, statuses) +} + func writeAdminJSON(w http.ResponseWriter, data any) { w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(data) diff --git a/internal/agent/session.go b/internal/agent/session.go index 53f1d74..eb39ae5 100644 --- a/internal/agent/session.go +++ b/internal/agent/session.go @@ -81,7 +81,8 @@ type trackedSession struct { } const ( - flushInterval = 10 * time.Second + flushInterval = 10 * time.Second + maxActiveSessions = 10000 StatusActive = "active" StatusCompleted = "completed" @@ -126,6 +127,9 @@ func (t *Tracker) TrackCall(sessionID, agentID, userID, task, model, path string ts, ok := t.sessions[sessionID] if !ok { + if len(t.sessions) >= maxActiveSessions { + t.evictOldest() + } ts = &trackedSession{ session: Session{ ID: sessionID, @@ -267,6 +271,32 @@ func (t *Tracker) 
Close() { }) } +func (t *Tracker) evictOldest() { + var oldestID string + var oldestTime time.Time + first := true + + for id, ts := range t.sessions { + lastActivity := ts.session.StartedAt + if len(ts.calls) > 0 { + lastActivity = ts.calls[len(ts.calls)-1].Timestamp + } + if first || lastActivity.Before(oldestTime) { + oldestID = id + oldestTime = lastActivity + first = false + } + } + + if oldestID != "" { + t.logger.Warn("evicting oldest session due to max sessions cap", + "session_id", oldestID, + "last_activity", oldestTime, + ) + delete(t.sessions, oldestID) + } +} + func (t *Tracker) backgroundLoop() { ticker := time.NewTicker(flushInterval) defer ticker.Stop() diff --git a/internal/agent/session_test.go b/internal/agent/session_test.go index c9c0cd5..5ded3e6 100644 --- a/internal/agent/session_test.go +++ b/internal/agent/session_test.go @@ -285,3 +285,53 @@ func TestExpireIdleSessions(t *testing.T) { t.Errorf("status = %q, want completed (should be expired)", ts.session.Status) } } + +func TestTrackerEvictsOldestSession(t *testing.T) { + store := newStubStore() + cfg := Config{SessionTimeoutMins: 30} + tracker := NewTracker(store, cfg, nil, testLogger()) + defer tracker.Close() + + // Fill to max capacity. + tracker.mu.Lock() + for i := 0; i < maxActiveSessions; i++ { + id := "sess-" + time.Now().Add(time.Duration(i)*time.Millisecond).Format("150405.000000") + tracker.sessions[id] = &trackedSession{ + session: Session{ + ID: id, + Status: StatusActive, + StartedAt: time.Now(), + }, + calls: []CallRecord{ + {Timestamp: time.Now()}, + }, + } + } + // Add one old session that should be evicted. + tracker.sessions["oldest"] = &trackedSession{ + session: Session{ + ID: "oldest", + Status: StatusActive, + StartedAt: time.Now().Add(-1 * time.Hour), + }, + calls: []CallRecord{ + {Timestamp: time.Now().Add(-1 * time.Hour)}, + }, + } + tracker.mu.Unlock() + + // This should evict the oldest session to make room. 
+ tracker.TrackCall("new-session", "agent1", "user1", "task", "gpt-4o", "/v1/chat/completions") + + tracker.mu.RLock() + _, oldestExists := tracker.sessions["oldest"] + _, newExists := tracker.sessions["new-session"] + tracker.mu.RUnlock() + + if oldestExists { + t.Error("oldest session should have been evicted") + } + if !newExists { + t.Error("new session should have been created") + } +} diff --git a/internal/budget/budget.go b/internal/budget/budget.go index 6beed50..f511eb5 100644 --- a/internal/budget/budget.go +++ b/internal/budget/budget.go @@ -58,6 +58,10 @@ type Manager struct { logger *slog.Logger onWarn func(ctx context.Context, apiKeyHash string, result Result) onBlock func(ctx context.Context, apiKeyHash string, result Result) + + mu sync.RWMutex // protects config.Rules + done chan struct{} + closed sync.Once } type spendEntry struct { @@ -70,12 +74,22 @@ const defaultCacheTTL = 30 * time.Second // NewManager creates a budget enforcement manager. func NewManager(l ledger.Ledger, cfg Config, logger *slog.Logger) *Manager { - return &Manager{ + m := &Manager{ ledger: l, config: cfg, cacheTTL: defaultCacheTTL, logger: logger, + done: make(chan struct{}), } + go m.cacheCleanupLoop() + return m +} + +// Close stops the background cache cleanup goroutine. +func (m *Manager) Close() { + m.closed.Do(func() { + close(m.done) + }) } // SetCallbacks configures alert callbacks for budget events. @@ -90,9 +104,14 @@ func (m *Manager) SetCallbacks( // UpdateRules replaces the per-key rules at runtime (hot-reload from admin API). // The default rule is not changed. func (m *Manager) UpdateRules(rules []Rule) { + m.mu.Lock() m.config.Rules = rules + m.mu.Unlock() // Invalidate cache so new rules take effect immediately. - m.cache = sync.Map{} + m.cache.Range(func(key, _ any) bool { + m.cache.Delete(key) + return true + }) } // Enabled returns true if any budget limits are configured. 
@@ -100,7 +119,10 @@ func (m *Manager) Enabled() bool { if m.config.Default.DailyLimitUSD > 0 || m.config.Default.MonthlyLimitUSD > 0 { return true } - return len(m.config.Rules) > 0 + m.mu.RLock() + n := len(m.config.Rules) + m.mu.RUnlock() + return n > 0 } // Check evaluates budget for a request. rawKey is used for rule pattern @@ -198,7 +220,12 @@ func (m *Manager) evaluateRule(rule Rule, daily, monthly float64) Result { // matchRule returns the most specific rule matching the raw API key, // falling back to the default rule. func (m *Manager) matchRule(rawKey string) Rule { - for _, r := range m.config.Rules { + m.mu.RLock() + rules := make([]Rule, len(m.config.Rules)) + copy(rules, m.config.Rules) + m.mu.RUnlock() + + for _, r := range rules { if matched, _ := filepath.Match(r.APIKeyPattern, rawKey); matched { return m.mergeWithDefault(r) } @@ -225,7 +252,12 @@ func (m *Manager) mergeWithDefault(r Rule) Rule { // matchTenantRule finds a rule that targets a specific tenant (no API key pattern). 
func (m *Manager) matchTenantRule(tenantID string) *Rule { - for _, r := range m.config.Rules { + m.mu.RLock() + rules := make([]Rule, len(m.config.Rules)) + copy(rules, m.config.Rules) + m.mu.RUnlock() + + for _, r := range rules { if r.TenantID == tenantID && r.APIKeyPattern == "" { merged := m.mergeWithDefault(r) return &merged @@ -300,3 +332,28 @@ func (m *Manager) getSpend(ctx context.Context, apiKeyHash string) (daily, month return daily, monthly } + +func (m *Manager) cacheCleanupLoop() { + ticker := time.NewTicker(60 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + m.evictStaleCache() + case <-m.done: + return + } + } +} + +func (m *Manager) evictStaleCache() { + now := time.Now() + m.cache.Range(func(key, value any) bool { + e := value.(*spendEntry) + if now.Sub(e.fetched) > m.cacheTTL { + m.cache.Delete(key) + } + return true + }) +} diff --git a/internal/budget/budget_test.go b/internal/budget/budget_test.go index 4793687..ed0786c 100644 --- a/internal/budget/budget_test.go +++ b/internal/budget/budget_test.go @@ -4,6 +4,7 @@ import ( "context" "io" "log/slog" + "sync" "testing" "time" @@ -40,6 +41,15 @@ func (s *stubLedger) QueryRecentExpensive(_ context.Context, _, _ time.Time, _ s func (s *stubLedger) QueryErrorStats(_ context.Context, _, _ time.Time, _ string) (*ledger.ErrorStats, error) { return &ledger.ErrorStats{}, nil } +func (s *stubLedger) QueryRecentSessions(_ context.Context, _, _ time.Time, _ string, _ int) ([]ledger.SessionRecord, error) { + return nil, nil +} +func (s *stubLedger) QueryLatencyPercentiles(_ context.Context, _, _ time.Time, _ string) (*ledger.LatencyStats, error) { + return &ledger.LatencyStats{}, nil +} +func (s *stubLedger) QueryTokenTimeseries(_ context.Context, _ string, _, _ time.Time, _ string) ([]ledger.TokenTimeseriesPoint, error) { + return nil, nil +} func (s *stubLedger) Close() error { return nil } func newTestLogger() *slog.Logger { @@ -48,6 +58,7 @@ func newTestLogger() 
*slog.Logger { func TestBudgetAllowWhenNoLimits(t *testing.T) { mgr := NewManager(&stubLedger{}, Config{}, newTestLogger()) + defer mgr.Close() if mgr.Enabled() { t.Error("should not be enabled with no limits") @@ -70,6 +81,7 @@ func TestBudgetAllowUnderLimit(t *testing.T) { }, } mgr := NewManager(store, cfg, newTestLogger()) + defer mgr.Close() if !mgr.Enabled() { t.Fatal("should be enabled") @@ -92,6 +104,7 @@ func TestBudgetWarnAtSoftLimit(t *testing.T) { }, } mgr := NewManager(store, cfg, newTestLogger()) + defer mgr.Close() result := mgr.Check(context.Background(), "sk-test-key", "hash123", "") if result.Decision != Warn { @@ -110,6 +123,7 @@ func TestBudgetBlockAtHardLimit(t *testing.T) { }, } mgr := NewManager(store, cfg, newTestLogger()) + defer mgr.Close() result := mgr.Check(context.Background(), "sk-test-key", "hash123", "") if result.Decision != Block { @@ -132,6 +146,7 @@ func TestBudgetWarnActionAtHardLimit(t *testing.T) { }, } mgr := NewManager(store, cfg, newTestLogger()) + defer mgr.Close() result := mgr.Check(context.Background(), "sk-test-key", "hash123", "") if result.Decision != Warn { @@ -149,6 +164,7 @@ func TestBudgetMonthlyBlock(t *testing.T) { }, } mgr := NewManager(store, cfg, newTestLogger()) + defer mgr.Close() result := mgr.Check(context.Background(), "sk-test-key", "hash123", "") if result.Decision != Block { @@ -172,6 +188,7 @@ func TestBudgetRulePatternMatch(t *testing.T) { }, } mgr := NewManager(store, cfg, newTestLogger()) + defer mgr.Close() // Dev key should be blocked at 8.0 > 5.0. result := mgr.Check(context.Background(), "sk-proj-dev-abc123", "hash-dev", "") @@ -204,6 +221,7 @@ func TestBudgetRuleMergesDefaults(t *testing.T) { }, } mgr := NewManager(store, cfg, newTestLogger()) + defer mgr.Close() // Should block because monthly 600 > default 500. 
result := mgr.Check(context.Background(), "sk-proj-dev-abc", "hash-dev", "") @@ -221,6 +239,7 @@ func TestBudgetSpendCaching(t *testing.T) { }, } mgr := NewManager(store, cfg, newTestLogger()) + defer mgr.Close() // First call populates cache. mgr.Check(context.Background(), "sk-key", "hash1", "") @@ -248,9 +267,71 @@ func TestBudgetEnabled(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mgr := NewManager(&stubLedger{}, tt.cfg, newTestLogger()) + defer mgr.Close() if got := mgr.Enabled(); got != tt.want { t.Errorf("Enabled() = %v, want %v", got, tt.want) } }) } } + +func TestBudgetCacheEviction(t *testing.T) { + store := &stubLedger{dailySpend: 1.0, monthlySpend: 5.0} + cfg := Config{ + Default: Rule{ + DailyLimitUSD: 50.0, + Action: "block", + }, + } + mgr := NewManager(store, cfg, newTestLogger()) + defer mgr.Close() + + // Populate cache. + mgr.Check(context.Background(), "sk-key", "hash-evict", "") + + // Mark the entry as stale by setting fetched time in the past. + mgr.cache.Store("hash-evict", &spendEntry{ + daily: 1.0, + monthly: 5.0, + fetched: time.Now().Add(-2 * mgr.cacheTTL), + }) + + mgr.evictStaleCache() + + // Entry should be gone. + _, exists := mgr.cache.Load("hash-evict") + if exists { + t.Error("expected stale cache entry to be evicted") + } +} + +func TestUpdateRulesRace(t *testing.T) { + store := &stubLedger{dailySpend: 1.0, monthlySpend: 5.0} + cfg := Config{ + Default: Rule{ + DailyLimitUSD: 50.0, + Action: "block", + }, + } + mgr := NewManager(store, cfg, newTestLogger()) + defer mgr.Close() + + // Concurrently update rules and check budget to detect races. 
+ var wg sync.WaitGroup + + for i := 0; i < 10; i++ { + wg.Add(2) + go func() { + defer wg.Done() + mgr.UpdateRules([]Rule{ + {APIKeyPattern: "sk-*", DailyLimitUSD: 10.0, Action: "block"}, + }) + }() + go func() { + defer wg.Done() + mgr.Check(context.Background(), "sk-test", "hash-race", "") + }() + } + + wg.Wait() +} diff --git a/internal/config/config.go b/internal/config/config.go index 7dd05fa..d7d77da 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -23,6 +23,19 @@ type Config struct { RateLimits RateLimitsConfig `mapstructure:"rate_limits"` Admin AdminConfig `mapstructure:"admin"` MCP MCPConfig `mapstructure:"mcp"` + CORS CORSConfig `mapstructure:"cors"` + TLS TLSConfig `mapstructure:"tls"` +} + +// CORSConfig holds CORS settings. +type CORSConfig struct { + AllowOrigins []string `mapstructure:"allow_origins"` +} + +// TLSConfig holds TLS settings. +type TLSConfig struct { + CertFile string `mapstructure:"cert_file"` + KeyFile string `mapstructure:"key_file"` } // ProvidersConfig holds per-provider settings. diff --git a/internal/dashboard/handlers.go b/internal/dashboard/handlers.go index f0a37b0..01be22e 100644 --- a/internal/dashboard/handlers.go +++ b/internal/dashboard/handlers.go @@ -4,6 +4,7 @@ import ( "encoding/csv" "encoding/json" "fmt" + "log/slog" "net/http" "strconv" "time" @@ -16,11 +17,12 @@ import ( type Handler struct { ledger ledger.Ledger tracker *agent.Tracker + logger *slog.Logger } // NewHandler creates a dashboard API handler. -func NewHandler(l ledger.Ledger, tracker *agent.Tracker) *Handler { - return &Handler{ledger: l, tracker: tracker} +func NewHandler(l ledger.Ledger, tracker *agent.Tracker, logger *slog.Logger) *Handler { + return &Handler{ledger: l, tracker: tracker, logger: logger} } // RegisterRoutes registers dashboard API routes on the given mux. 
@@ -29,9 +31,12 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) { mux.HandleFunc("GET /api/dashboard/timeseries", h.handleTimeseries) mux.HandleFunc("GET /api/dashboard/costs", h.handleCosts) mux.HandleFunc("GET /api/dashboard/sessions", h.handleSessions) + mux.HandleFunc("GET /api/dashboard/sessions/history", h.handleSessionHistory) mux.HandleFunc("GET /api/dashboard/export", h.handleExport) mux.HandleFunc("GET /api/dashboard/expensive", h.handleExpensive) mux.HandleFunc("GET /api/dashboard/stats", h.handleStats) + mux.HandleFunc("GET /api/dashboard/latency", h.handleLatency) + mux.HandleFunc("GET /api/dashboard/timeseries/tokens", h.handleTokenTimeseries) } func (h *Handler) handleSummary(w http.ResponseWriter, r *http.Request) { @@ -48,7 +53,8 @@ func (h *Handler) handleSummary(w http.ResponseWriter, r *http.Request) { TenantID: tenantID, }) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + h.logger.Error("querying today costs", "error", err) + writeError(w, http.StatusInternalServerError, "internal server error") return } @@ -67,7 +73,8 @@ func (h *Handler) handleSummary(w http.ResponseWriter, r *http.Request) { TenantID: tenantID, }) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + h.logger.Error("querying month costs", "error", err) + writeError(w, http.StatusInternalServerError, "internal server error") return } @@ -108,7 +115,8 @@ func (h *Handler) handleTimeseries(w http.ResponseWriter, r *http.Request) { points, err := h.ledger.QueryCostTimeseries(r.Context(), interval, since, now, tenantID) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + h.logger.Error("querying timeseries", "error", err) + writeError(w, http.StatusInternalServerError, "internal server error") return } @@ -138,7 +146,8 @@ func (h *Handler) handleCosts(w http.ResponseWriter, r *http.Request) { TenantID: tenantID, }) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) 
+ h.logger.Error("querying costs", "error", err) + writeError(w, http.StatusInternalServerError, "internal server error") return } @@ -183,7 +192,8 @@ func (h *Handler) handleExport(w http.ResponseWriter, r *http.Request) { TenantID: tenantID, }) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + h.logger.Error("exporting costs", "error", err) + writeError(w, http.StatusInternalServerError, "internal server error") return } @@ -239,7 +249,8 @@ func (h *Handler) handleExpensive(w http.ResponseWriter, r *http.Request) { results, err := h.ledger.QueryRecentExpensive(r.Context(), since, now, tenantID, limit) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + h.logger.Error("querying expensive requests", "error", err) + writeError(w, http.StatusInternalServerError, "internal server error") return } writeJSON(w, results) @@ -256,12 +267,84 @@ func (h *Handler) handleStats(w http.ResponseWriter, r *http.Request) { stats, err := h.ledger.QueryErrorStats(r.Context(), since, now, tenantID) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + h.logger.Error("querying error stats", "error", err) + writeError(w, http.StatusInternalServerError, "internal server error") return } writeJSON(w, stats) } +func (h *Handler) handleSessionHistory(w http.ResponseWriter, r *http.Request) { + hours, _ := strconv.Atoi(r.URL.Query().Get("hours")) + if hours <= 0 { + hours = 24 + } + limit, _ := strconv.Atoi(r.URL.Query().Get("limit")) + if limit <= 0 { + limit = 50 + } + status := r.URL.Query().Get("status") + + now := time.Now().UTC() + since := now.Add(-time.Duration(hours) * time.Hour) + + records, err := h.ledger.QueryRecentSessions(r.Context(), since, now, status, limit) + if err != nil { + h.logger.Error("querying session history", "error", err) + writeError(w, http.StatusInternalServerError, "internal server error") + return + } + if records == nil { + records = []ledger.SessionRecord{} + } + 
writeJSON(w, records) +} + +func (h *Handler) handleLatency(w http.ResponseWriter, r *http.Request) { + hours, _ := strconv.Atoi(r.URL.Query().Get("hours")) + if hours <= 0 { + hours = 24 + } + tenantID := r.URL.Query().Get("tenant") + + now := time.Now().UTC() + since := now.Add(-time.Duration(hours) * time.Hour) + + stats, err := h.ledger.QueryLatencyPercentiles(r.Context(), since, now, tenantID) + if err != nil { + h.logger.Error("querying latency", "error", err) + writeError(w, http.StatusInternalServerError, "internal server error") + return + } + writeJSON(w, stats) +} + +func (h *Handler) handleTokenTimeseries(w http.ResponseWriter, r *http.Request) { + interval := r.URL.Query().Get("interval") + if interval == "" { + interval = "hour" + } + + hoursF, _ := strconv.ParseFloat(r.URL.Query().Get("hours"), 64) + if hoursF <= 0 { + hoursF = 24 + } + + tenantID := r.URL.Query().Get("tenant") + + now := time.Now().UTC() + since := now.Add(-time.Duration(hoursF * float64(time.Hour))) + + points, err := h.ledger.QueryTokenTimeseries(r.Context(), interval, since, now, tenantID) + if err != nil { + h.logger.Error("querying token timeseries", "error", err) + writeError(w, http.StatusInternalServerError, "internal server error") + return + } + + writeJSON(w, points) +} + func writeJSON(w http.ResponseWriter, data any) { w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(data) diff --git a/internal/dashboard/handlers_test.go b/internal/dashboard/handlers_test.go index ed0cbec..4b44b08 100644 --- a/internal/dashboard/handlers_test.go +++ b/internal/dashboard/handlers_test.go @@ -3,6 +3,8 @@ package dashboard import ( "context" "encoding/json" + "io" + "log/slog" "net/http" "net/http/httptest" "testing" @@ -12,8 +14,11 @@ import ( ) type stubLedger struct { - costs []ledger.CostEntry - timeseries []ledger.TimeseriesPoint + costs []ledger.CostEntry + timeseries []ledger.TimeseriesPoint + sessions []ledger.SessionRecord + latency 
*ledger.LatencyStats + tokenTimeseries []ledger.TokenTimeseriesPoint } func (s *stubLedger) RecordUsage(_ context.Context, _ *ledger.UsageRecord) error { return nil } @@ -35,8 +40,24 @@ func (s *stubLedger) QueryRecentExpensive(_ context.Context, _, _ time.Time, _ s func (s *stubLedger) QueryErrorStats(_ context.Context, _, _ time.Time, _ string) (*ledger.ErrorStats, error) { return &ledger.ErrorStats{}, nil } +func (s *stubLedger) QueryRecentSessions(_ context.Context, _, _ time.Time, _ string, _ int) ([]ledger.SessionRecord, error) { + return s.sessions, nil +} +func (s *stubLedger) QueryLatencyPercentiles(_ context.Context, _, _ time.Time, _ string) (*ledger.LatencyStats, error) { + if s.latency != nil { + return s.latency, nil + } + return &ledger.LatencyStats{Buckets: []ledger.LatencyBucket{}}, nil +} +func (s *stubLedger) QueryTokenTimeseries(_ context.Context, _ string, _, _ time.Time, _ string) ([]ledger.TokenTimeseriesPoint, error) { + return s.tokenTimeseries, nil +} func (s *stubLedger) Close() error { return nil } +func testLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + func TestHandleSummary(t *testing.T) { store := &stubLedger{ costs: []ledger.CostEntry{ @@ -44,7 +65,7 @@ func TestHandleSummary(t *testing.T) { {Model: "claude-sonnet-4-6", Requests: 5, TotalCostUSD: 1.20}, }, } - h := NewHandler(store, nil) + h := NewHandler(store, nil, testLogger()) mux := http.NewServeMux() h.RegisterRoutes(mux) @@ -76,7 +97,7 @@ func TestHandleTimeseries(t *testing.T) { {Timestamp: time.Now(), CostUSD: 0.50, Requests: 10}, }, } - h := NewHandler(store, nil) + h := NewHandler(store, nil, testLogger()) mux := http.NewServeMux() h.RegisterRoutes(mux) @@ -104,7 +125,7 @@ func TestHandleCosts(t *testing.T) { {Model: "gpt-4o-mini", Requests: 10, InputTokens: 100, OutputTokens: 50, TotalCostUSD: 0.50}, }, } - h := NewHandler(store, nil) + h := NewHandler(store, nil, testLogger()) mux := http.NewServeMux() h.RegisterRoutes(mux) @@ 
-132,7 +153,7 @@ func TestHandleCostsWithTenant(t *testing.T) { {Model: "gpt-4o-mini", Requests: 3, TotalCostUSD: 0.15}, }, } - h := NewHandler(store, nil) + h := NewHandler(store, nil, testLogger()) mux := http.NewServeMux() h.RegisterRoutes(mux) @@ -152,7 +173,7 @@ func TestHandleSummaryWithTenant(t *testing.T) { {Model: "gpt-4o-mini", Requests: 5, TotalCostUSD: 0.25}, }, } - h := NewHandler(store, nil) + h := NewHandler(store, nil, testLogger()) mux := http.NewServeMux() h.RegisterRoutes(mux) @@ -172,7 +193,7 @@ func TestHandleTimeseriesWithTenant(t *testing.T) { {Timestamp: time.Now(), CostUSD: 0.10, Requests: 2}, }, } - h := NewHandler(store, nil) + h := NewHandler(store, nil, testLogger()) mux := http.NewServeMux() h.RegisterRoutes(mux) @@ -188,7 +209,7 @@ func TestHandleTimeseriesWithTenant(t *testing.T) { func TestHandleSessionsWithoutTracker(t *testing.T) { store := &stubLedger{} - h := NewHandler(store, nil) + h := NewHandler(store, nil, testLogger()) mux := http.NewServeMux() h.RegisterRoutes(mux) @@ -202,6 +223,98 @@ func TestHandleSessionsWithoutTracker(t *testing.T) { } } +func TestHandleSessionHistory(t *testing.T) { + now := time.Now() + store := &stubLedger{ + sessions: []ledger.SessionRecord{ + {ID: "sess-1", AgentID: "agent-a", Status: "completed", StartedAt: now.Add(-time.Hour), CallCount: 5, TotalCostUSD: 0.50}, + }, + } + h := NewHandler(store, nil, testLogger()) + + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + req := httptest.NewRequest("GET", "/api/dashboard/sessions/history?hours=24&limit=10", nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + + if w.Code != 200 { + t.Fatalf("status = %d, want 200", w.Code) + } + + var records []ledger.SessionRecord + if err := json.NewDecoder(w.Body).Decode(&records); err != nil { + t.Fatal(err) + } + if len(records) != 1 { + t.Errorf("expected 1 record, got %d", len(records)) + } +} + +func TestHandleLatency(t *testing.T) { + store := &stubLedger{ + latency: &ledger.LatencyStats{ + P50: 
150, P90: 500, P99: 2000, + Buckets: []ledger.LatencyBucket{ + {Label: "<100ms", Count: 10}, + {Label: "100-500ms", Count: 20}, + }, + }, + } + h := NewHandler(store, nil, testLogger()) + + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + req := httptest.NewRequest("GET", "/api/dashboard/latency?hours=24", nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + + if w.Code != 200 { + t.Fatalf("status = %d, want 200", w.Code) + } + + var stats ledger.LatencyStats + if err := json.NewDecoder(w.Body).Decode(&stats); err != nil { + t.Fatal(err) + } + if stats.P50 != 150 { + t.Errorf("P50 = %v, want 150", stats.P50) + } + if len(stats.Buckets) != 2 { + t.Errorf("expected 2 buckets, got %d", len(stats.Buckets)) + } +} + +func TestHandleTokenTimeseries(t *testing.T) { + store := &stubLedger{ + tokenTimeseries: []ledger.TokenTimeseriesPoint{ + {Timestamp: time.Now(), InputTokens: 1000, OutputTokens: 500}, + }, + } + h := NewHandler(store, nil, testLogger()) + + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + req := httptest.NewRequest("GET", "/api/dashboard/timeseries/tokens?interval=hour&hours=24", nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + + if w.Code != 200 { + t.Fatalf("status = %d, want 200", w.Code) + } + + var points []ledger.TokenTimeseriesPoint + if err := json.NewDecoder(w.Body).Decode(&points); err != nil { + t.Fatal(err) + } + if len(points) != 1 { + t.Errorf("expected 1 point, got %d", len(points)) + } +} + func TestStaticHandler(t *testing.T) { handler := StaticHandler() diff --git a/internal/dashboard/static/app.js b/internal/dashboard/static/app.js index 2f43545..2c1a8f3 100644 --- a/internal/dashboard/static/app.js +++ b/internal/dashboard/static/app.js @@ -26,6 +26,27 @@ return (n * 100).toFixed(1) + "%"; } + function fmtDuration(startedAt, endedAt) { + var start = new Date(startedAt).getTime(); + var end = endedAt ? 
new Date(endedAt).getTime() : Date.now(); + var ms = end - start; + if (ms < 0) return "--"; + var secs = Math.floor(ms / 1000); + if (secs < 60) return secs + "s"; + var mins = Math.floor(secs / 60); + secs = secs % 60; + if (mins < 60) return mins + "m " + secs + "s"; + var hrs = Math.floor(mins / 60); + mins = mins % 60; + return hrs + "h " + mins + "m"; + } + + function fmtTokenCount(n) { + if (n >= 1000000) return (n / 1000000).toFixed(1) + "M"; + if (n >= 1000) return (n / 1000).toFixed(1) + "k"; + return String(n); + } + const DONUT_COLORS = [ "#388bfd", "#3fb950", "#d29922", "#f85149", "#a371f7", "#79c0ff", "#56d364", "#e3b341", "#ff7b72", "#bc8cff", @@ -49,11 +70,17 @@ adminToken = $("#admin-token").value.trim(); localStorage.setItem("agentledger_admin_token", adminToken); loadRules(); + loadBudgetStatus(); + loadBlocked(); }); async function adminFetch(url, opts = {}) { if (!adminToken) return null; - opts.headers = { ...opts.headers, Authorization: "Bearer " + adminToken }; + opts.headers = { + ...opts.headers, + Authorization: "Bearer " + adminToken, + "X-Requested-With": "XMLHttpRequest", + }; const resp = await fetch(url, opts); if (!resp.ok) return null; return resp.json(); @@ -104,6 +131,7 @@ // ── Timeseries chart ── let timeseriesChart = null; + let currentTimeseriesTab = "cost"; function formatLabel(ts, interval) { const d = new Date(ts); @@ -189,6 +217,102 @@ } } + // ── Token timeseries chart ── + let tokenChart = null; + + async function loadTokenTimeseries() { + const hours = parseFloat($("#timeseries-hours").value); + let interval = "hour"; + if (hours <= 6) interval = "minute"; + else if (hours > 24) interval = "day"; + + try { + const points = await fetchJSON( + tenantQS(`/api/dashboard/timeseries/tokens?interval=${interval}&hours=${hours}`) + ); + const data = points || []; + const labels = data.map((p) => formatLabel(p.Timestamp, interval)); + const ctx = document.getElementById("token-chart").getContext("2d"); + + if (tokenChart) 
tokenChart.destroy(); + + tokenChart = new Chart(ctx, { + type: "line", + data: { + labels, + datasets: [ + { + label: "Input Tokens", + data: data.map((p) => p.InputTokens), + backgroundColor: "rgba(56, 139, 253, 0.15)", + borderColor: "rgba(56, 139, 253, 1)", + borderWidth: 2, + fill: true, + tension: 0.35, + pointRadius: data.length > 60 ? 0 : 3, + }, + { + label: "Output Tokens", + data: data.map((p) => p.OutputTokens), + backgroundColor: "rgba(63, 185, 80, 0.15)", + borderColor: "rgba(63, 185, 80, 1)", + borderWidth: 2, + fill: true, + tension: 0.35, + pointRadius: data.length > 60 ? 0 : 3, + }, + ], + }, + options: { + responsive: true, + maintainAspectRatio: false, + interaction: { intersect: false, mode: "index" }, + plugins: { + legend: { display: true, position: "top", labels: { color: "#8b949e", font: { size: 11 } } }, + tooltip: { + backgroundColor: "#1c2128", borderColor: "#30363d", borderWidth: 1, + titleColor: "#e1e4e8", bodyColor: "#c9d1d9", + padding: 12, cornerRadius: 8, + callbacks: { label: (ctx) => " " + ctx.dataset.label + ": " + fmtTokenCount(ctx.parsed.y) }, + }, + }, + scales: { + x: { + grid: { color: "rgba(33,38,45,0.5)", drawBorder: false }, + ticks: { color: "#8b949e", font: { size: 11 }, maxRotation: 0, autoSkip: true, maxTicksLimit: 10 }, + }, + y: { + beginAtZero: true, + stacked: true, + grid: { color: "rgba(33,38,45,0.5)", drawBorder: false }, + ticks: { color: "#8b949e", font: { size: 11 }, maxTicksLimit: 6, callback: (v) => fmtTokenCount(v) }, + }, + }, + }, + }); + } catch (e) { + console.error("token timeseries:", e); + } + } + + // Cost/Token tab switching + document.querySelectorAll("[data-tab]").forEach((btn) => { + btn.addEventListener("click", () => { + document.querySelectorAll("[data-tab]").forEach((b) => b.classList.remove("tab-active")); + btn.classList.add("tab-active"); + currentTimeseriesTab = btn.dataset.tab; + if (currentTimeseriesTab === "cost") { + 
document.getElementById("timeseries-chart").parentElement.style.display = ""; + document.getElementById("token-chart-container").style.display = "none"; + loadTimeseries(); + } else { + document.getElementById("timeseries-chart").parentElement.style.display = "none"; + document.getElementById("token-chart-container").style.display = ""; + loadTokenTimeseries(); + } + }); + }); + // ── Provider donut chart ── let providerChart = null; @@ -263,6 +387,124 @@ } } + // ── Agent cost leaderboard ── + let agentChart = null; + + async function loadAgentChart() { + try { + const entries = await fetchJSON(tenantQS("/api/dashboard/costs?group_by=agent&hours=168")); + if (!entries || !entries.length) return; + + const top10 = entries.slice(0, 10); + const labels = top10.map((e) => e.AgentID || "(none)"); + const values = top10.map((e) => e.TotalCostUSD); + const ctx = document.getElementById("agent-chart").getContext("2d"); + + if (agentChart) agentChart.destroy(); + + agentChart = new Chart(ctx, { + type: "bar", + data: { + labels, + datasets: [{ + label: "Cost (USD)", + data: values, + backgroundColor: DONUT_COLORS.slice(0, labels.length), + borderColor: "#161b22", + borderWidth: 1, + borderRadius: 4, + }], + }, + options: { + responsive: true, + maintainAspectRatio: false, + indexAxis: "y", + plugins: { + legend: { display: false }, + tooltip: { + backgroundColor: "#1c2128", borderColor: "#30363d", borderWidth: 1, + titleColor: "#e1e4e8", bodyColor: "#c9d1d9", + padding: 12, cornerRadius: 8, + callbacks: { label: (ctx) => " " + fmtCost(ctx.parsed.x) }, + }, + }, + scales: { + x: { + beginAtZero: true, + grid: { color: "rgba(33,38,45,0.5)", drawBorder: false }, + ticks: { color: "#8b949e", font: { size: 11 }, callback: (v) => fmtAxis(v) }, + }, + y: { + grid: { display: false }, + ticks: { color: "#c9d1d9", font: { size: 11 } }, + }, + }, + }, + }); + } catch (e) { + console.error("agent chart:", e); + } + } + + // ── Model usage chart ── + let modelChart = null; + + async 
function loadModelChart() { + try { + const entries = await fetchJSON(tenantQS("/api/dashboard/costs?group_by=model&hours=168")); + if (!entries || !entries.length) return; + + const top10 = entries.slice(0, 10); + const labels = top10.map((e) => e.Model || "(unknown)"); + const values = top10.map((e) => e.TotalCostUSD); + const ctx = document.getElementById("model-chart").getContext("2d"); + + if (modelChart) modelChart.destroy(); + + modelChart = new Chart(ctx, { + type: "bar", + data: { + labels, + datasets: [{ + label: "Cost (USD)", + data: values, + backgroundColor: DONUT_COLORS.slice(0, labels.length), + borderColor: "#161b22", + borderWidth: 1, + borderRadius: 4, + }], + }, + options: { + responsive: true, + maintainAspectRatio: false, + indexAxis: "y", + plugins: { + legend: { display: false }, + tooltip: { + backgroundColor: "#1c2128", borderColor: "#30363d", borderWidth: 1, + titleColor: "#e1e4e8", bodyColor: "#c9d1d9", + padding: 12, cornerRadius: 8, + callbacks: { label: (ctx) => " " + fmtCost(ctx.parsed.x) }, + }, + }, + scales: { + x: { + beginAtZero: true, + grid: { color: "rgba(33,38,45,0.5)", drawBorder: false }, + ticks: { color: "#8b949e", font: { size: 11 }, callback: (v) => fmtAxis(v) }, + }, + y: { + grid: { display: false }, + ticks: { color: "#c9d1d9", font: { size: 11 } }, + }, + }, + }, + }); + } catch (e) { + console.error("model chart:", e); + } + } + // ── Cost breakdown table ── async function loadCosts() { const groupBy = $("#costs-group").value; @@ -326,43 +568,170 @@ } // ── Sessions table ── + let currentSessionTab = "active"; + + function renderSessionRows(sessions, tbody) { + tbody.innerHTML = ""; + if (!sessions || !sessions.length) { + tbody.innerHTML = 'No sessions'; + return; + } + + // Anomaly detection: compute mean cost and calls + var totalCost = 0, totalCalls = 0; + for (var i = 0; i < sessions.length; i++) { + totalCost += (sessions[i].TotalCostUSD || sessions[i].total_cost_usd || 0); + totalCalls += 
(sessions[i].CallCount || sessions[i].call_count || 0); + } + var meanCost = totalCost / sessions.length; + var meanCalls = totalCalls / sessions.length; + + for (const s of sessions) { + var id = s.ID || s.id || ""; + var agentID = s.AgentID || s.agent_id || "(none)"; + var userID = s.UserID || s.user_id || "(none)"; + var task = s.Task || s.task || ""; + var callCount = s.CallCount || s.call_count || 0; + var costUSD = s.TotalCostUSD || s.total_cost_usd || 0; + var totalTokens = s.TotalTokens || s.total_tokens || 0; + var status = s.Status || s.status || ""; + var startedAt = s.StartedAt || s.started_at || ""; + var endedAt = s.EndedAt || s.ended_at || null; + + var isAnomaly = (meanCost > 0 && costUSD > meanCost * 3) || + (meanCalls > 0 && callCount > meanCalls * 3); + + const tr = document.createElement("tr"); + if (isAnomaly) tr.classList.add("session-anomaly"); + const statusClass = "status-" + status; + const started = new Date(startedAt).toLocaleString("en-US", { + month: "short", day: "numeric", hour: "numeric", minute: "2-digit", + }); + var duration = fmtDuration(startedAt, endedAt); + var taskDisplay = task ? 
task.substring(0, 40) : ""; + if (task && task.length > 40) taskDisplay += "..."; + + tr.innerHTML = ` + ${esc(id.slice(0, 12))} + ${esc(agentID)} + ${esc(userID)} + ${esc(taskDisplay)} + ${callCount} + ${fmtCost(costUSD)} + ${totalTokens.toLocaleString()} + ${status} + ${duration} + ${started} + `; + tbody.appendChild(tr); + } + } + async function loadSessions() { try { const sessions = await fetchJSON("/api/dashboard/sessions"); - const tbody = $("#sessions-body"); - tbody.innerHTML = ""; - if (!sessions || !sessions.length) { - tbody.innerHTML = 'No active sessions'; - return; - } - for (const s of sessions) { - const tr = document.createElement("tr"); - const statusClass = "status-" + s.Status; - const started = new Date(s.StartedAt).toLocaleString("en-US", { - month: "short", day: "numeric", hour: "numeric", minute: "2-digit", - }); - tr.innerHTML = ` - ${esc(s.ID.slice(0, 12))} - ${esc(s.AgentID || "(none)")} - ${esc(s.UserID || "(none)")} - ${s.CallCount} - ${fmtCost(s.TotalCostUSD)} - ${s.Status} - ${started} - `; - tbody.appendChild(tr); + if (currentSessionTab === "active") { + renderSessionRows(sessions, $("#sessions-body")); } } catch (e) { console.error("sessions:", e); } } + async function loadSessionHistory(hours) { + try { + const sessions = await fetchJSON(`/api/dashboard/sessions/history?hours=${hours}&limit=50`); + renderSessionRows(sessions, $("#sessions-body")); + } catch (e) { + console.error("session history:", e); + } + } + + // Session tab switching + document.querySelectorAll("[data-session-tab]").forEach((btn) => { + btn.addEventListener("click", () => { + document.querySelectorAll("[data-session-tab]").forEach((b) => b.classList.remove("tab-active")); + btn.classList.add("tab-active"); + currentSessionTab = btn.dataset.sessionTab; + if (currentSessionTab === "active") { + loadSessions(); + } else { + loadSessionHistory(parseInt(currentSessionTab, 10)); + } + }); + }); + function esc(s) { const el = document.createElement("span"); 
el.textContent = s; return el.innerHTML; } + // ── Latency stats + chart ── + let latencyChart = null; + + async function loadLatency() { + try { + const data = await fetchJSON(tenantQS("/api/dashboard/latency?hours=24")); + if (!data) return; + + // Update stat items + $("#stat-p50").textContent = data.p50_ms.toFixed(0) + "ms"; + $("#stat-p90").textContent = data.p90_ms.toFixed(0) + "ms"; + $("#stat-p99").textContent = data.p99_ms.toFixed(0) + "ms"; + + // Render bucket chart + var buckets = data.buckets || []; + if (!buckets.length) return; + + var labels = buckets.map(function(b) { return b.label; }); + var values = buckets.map(function(b) { return b.count; }); + var ctx = document.getElementById("latency-chart").getContext("2d"); + + if (latencyChart) latencyChart.destroy(); + + latencyChart = new Chart(ctx, { + type: "bar", + data: { + labels: labels, + datasets: [{ + label: "Requests", + data: values, + backgroundColor: "rgba(163, 113, 247, 0.6)", + borderColor: "#a371f7", + borderWidth: 1, + borderRadius: 4, + }], + }, + options: { + responsive: true, + maintainAspectRatio: false, + plugins: { + legend: { display: false }, + tooltip: { + backgroundColor: "#1c2128", borderColor: "#30363d", borderWidth: 1, + titleColor: "#e1e4e8", bodyColor: "#c9d1d9", + padding: 12, cornerRadius: 8, + }, + }, + scales: { + x: { + grid: { display: false }, + ticks: { color: "#8b949e", font: { size: 10 } }, + }, + y: { + beginAtZero: true, + grid: { color: "rgba(33,38,45,0.5)", drawBorder: false }, + ticks: { color: "#8b949e", font: { size: 10 } }, + }, + }, + }, + }); + } catch (e) { + console.error("latency:", e); + } + } + // ── API Keys table ── async function loadAPIKeys() { try { @@ -411,6 +780,88 @@ } catch (e) { /* admin API may not be enabled */ } } + // ── Budget Status gauges ── + async function loadBudgetStatus() { + try { + const statuses = await adminFetch("/api/admin/budgets/status"); + const container = $("#budget-status-body"); + container.innerHTML = ""; + if 
(!statuses || !statuses.length) { + container.innerHTML = '
No budget data
'; + return; + } + for (const s of statuses) { + var item = document.createElement("div"); + item.className = "budget-item"; + + var dailyPct = s.daily_limit > 0 ? Math.min((s.daily_spent / s.daily_limit) * 100, 100) : 0; + var monthlyPct = s.monthly_limit > 0 ? Math.min((s.monthly_spent / s.monthly_limit) * 100, 100) : 0; + + var dailyClass = "budget-fill"; + if (dailyPct > 90) dailyClass += " budget-fill-danger"; + else if (dailyPct > 70) dailyClass += " budget-fill-warn"; + + var monthlyClass = "budget-fill"; + if (monthlyPct > 90) monthlyClass += " budget-fill-danger"; + else if (monthlyPct > 70) monthlyClass += " budget-fill-warn"; + + var html = '
' + esc(s.pattern) + '' + esc(s.action) + '
'; + + if (s.daily_limit > 0) { + html += '
Daily' + fmtCost(s.daily_spent) + ' / ' + fmtCost(s.daily_limit) + '
'; + html += '
'; + } + if (s.monthly_limit > 0) { + html += '
Monthly' + fmtCost(s.monthly_spent) + ' / ' + fmtCost(s.monthly_limit) + '
'; + html += '
'; + } + if (s.daily_limit <= 0 && s.monthly_limit <= 0) { + html += '
Spend: ' + fmtCost(s.daily_spent) + ' today / ' + fmtCost(s.monthly_spent) + ' month
'; + } + + item.innerHTML = html; + container.appendChild(item); + } + } catch (e) { /* admin API may not be enabled */ } + } + + // ── Blocked Keys table ── + async function loadBlocked() { + try { + const patterns = await adminFetch("/api/admin/api-keys/blocked"); + const tbody = $("#blocked-body"); + tbody.innerHTML = ""; + if (!patterns || !patterns.length) { + tbody.innerHTML = 'No blocked keys'; + return; + } + for (const p of patterns) { + const tr = document.createElement("tr"); + tr.innerHTML = `${esc(p)}`; + tbody.appendChild(tr); + } + for (const btn of tbody.querySelectorAll(".btn-delete")) { + btn.addEventListener("click", async () => { + await adminFetch("/api/admin/api-keys/block?pattern=" + encodeURIComponent(btn.dataset.pattern), { method: "DELETE" }); + loadBlocked(); + }); + } + } catch (e) { /* admin API may not be enabled */ } + } + + // Block key button + $("#block-btn").addEventListener("click", async () => { + var pattern = $("#block-pattern").value.trim(); + if (!pattern) return; + await adminFetch("/api/admin/api-keys/block", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ pattern: pattern }), + }); + $("#block-pattern").value = ""; + loadBlocked(); + }); + // Add Rule button $("#add-rule-btn").addEventListener("click", async () => { const rule = { @@ -428,6 +879,7 @@ $("#rule-daily").value = ""; $("#rule-monthly").value = ""; loadRules(); + loadBudgetStatus(); }); // Tenant filter @@ -439,16 +891,35 @@ function loadAll() { loadSummary(); loadStats(); - loadTimeseries(); + if (currentTimeseriesTab === "cost") { + loadTimeseries(); + } else { + loadTokenTimeseries(); + } loadProviderChart(); + loadAgentChart(); + loadModelChart(); loadCosts(); loadExpensive(); - loadSessions(); + if (currentSessionTab === "active") { + loadSessions(); + } else { + loadSessionHistory(parseInt(currentSessionTab, 10)); + } + loadLatency(); loadAPIKeys(); loadRules(); + loadBudgetStatus(); + loadBlocked(); } - 
$("#timeseries-hours").addEventListener("change", loadTimeseries); + $("#timeseries-hours").addEventListener("change", () => { + if (currentTimeseriesTab === "cost") { + loadTimeseries(); + } else { + loadTokenTimeseries(); + } + }); $("#costs-group").addEventListener("change", loadCosts); loadAll(); diff --git a/internal/dashboard/static/index.html b/internal/dashboard/static/index.html index fdefbeb..3eaca89 100644 --- a/internal/dashboard/static/index.html +++ b/internal/dashboard/static/index.html @@ -4,14 +4,44 @@ AgentLedger Dashboard + - +
- + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Know what your agents cost.
@@ -61,6 +91,10 @@

Cost Over Time

+
+ + +
+ +
+
+
diff --git a/internal/dashboard/static/style.css b/internal/dashboard/static/style.css index 2879521..48055ca 100644 --- a/internal/dashboard/static/style.css +++ b/internal/dashboard/static/style.css @@ -21,8 +21,11 @@ header { } .header-title { display: flex; - align-items: baseline; - gap: 1rem; + align-items: center; + gap: 0.75rem; +} +.header-logo-mark { + flex-shrink: 0; } .logo { font-size: 1.35rem; @@ -31,6 +34,7 @@ header { letter-spacing: -0.02em; font-family: "SF Mono", "Fira Code", "Cascadia Code", Menlo, monospace; } +.logo-accent { color: #388bfd; } .tagline { font-size: 0.8rem; color: #8b949e; } main { padding: 1.5rem 2rem; max-width: 1400px; margin: 0 auto; } @@ -103,7 +107,7 @@ main { padding: 1.5rem 2rem; max-width: 1400px; margin: 0 auto; } gap: 1.5rem; margin-bottom: 1.5rem; } -.grid-2col > .panel { margin-bottom: 0; } +.grid-2col > .panel { margin-bottom: 0; min-height: 200px; } @media (max-width: 1024px) { .grid-2col { grid-template-columns: 1fr; } } @@ -166,6 +170,10 @@ th { text-transform: uppercase; font-size: 0.7rem; letter-spacing: 0.06em; + position: sticky; + top: 0; + background: #161b22; + z-index: 1; } td { color: #c9d1d9; } td:last-child, th:last-child { text-align: right; } @@ -244,3 +252,69 @@ tr:hover td { background: rgba(56, 139, 253, 0.04); } transition: all 0.15s; } .btn-delete:hover { background: #f8514922; border-color: #f85149; } + +/* Tab bar */ +.tab-bar { + display: flex; + gap: 0; + border: 1px solid #30363d; + border-radius: 6px; + overflow: hidden; +} +.tab-btn { + background: #0d1117; + color: #8b949e; + border: none; + border-right: 1px solid #30363d; + padding: 0.3rem 0.7rem; + font-size: 0.75rem; + font-weight: 500; + cursor: pointer; + transition: all 0.15s; +} +.tab-btn:last-child { border-right: none; } +.tab-btn:hover { color: #c9d1d9; background: #161b22; } +.tab-btn.tab-active { + background: #238636; + color: #fff; +} + +/* Budget bars */ +.budget-item { + margin-bottom: 1rem; + padding: 0.75rem; + 
background: #0d1117; + border: 1px solid #21262d; + border-radius: 8px; +} +.budget-item-label { + font-size: 0.8rem; + color: #c9d1d9; + margin-bottom: 0.5rem; + display: flex; + justify-content: space-between; +} +.budget-bar { + height: 8px; + background: #21262d; + border-radius: 4px; + overflow: hidden; + margin-bottom: 0.25rem; +} +.budget-fill { + height: 100%; + border-radius: 4px; + transition: width 0.3s ease; + background: #3fb950; +} +.budget-fill-warn { background: #d29922; } +.budget-fill-danger { background: #f85149; } +.budget-bar-label { + font-size: 0.65rem; + color: #8b949e; + display: flex; + justify-content: space-between; +} + +/* Session anomaly highlight */ +.session-anomaly td:first-child { border-left: 3px solid #f85149; } diff --git a/internal/ledger/ledger.go b/internal/ledger/ledger.go index 83b924c..4c7dc8a 100644 --- a/internal/ledger/ledger.go +++ b/internal/ledger/ledger.go @@ -31,6 +31,15 @@ type Ledger interface { // QueryErrorStats returns error counts and average metrics for the time window. QueryErrorStats(ctx context.Context, since, until time.Time, tenantID string) (*ErrorStats, error) + // QueryRecentSessions returns sessions within the time window, optionally filtered by status. + QueryRecentSessions(ctx context.Context, since, until time.Time, status string, limit int) ([]SessionRecord, error) + + // QueryLatencyPercentiles returns P50/P90/P99 latency and a histogram distribution. + QueryLatencyPercentiles(ctx context.Context, since, until time.Time, tenantID string) (*LatencyStats, error) + + // QueryTokenTimeseries returns token counts bucketed by time interval. + QueryTokenTimeseries(ctx context.Context, interval string, since, until time.Time, tenantID string) ([]TokenTimeseriesPoint, error) + // Close releases any held resources. 
Close() error } diff --git a/internal/ledger/models.go b/internal/ledger/models.go index b93039f..44f168d 100644 --- a/internal/ledger/models.go +++ b/internal/ledger/models.go @@ -73,3 +73,38 @@ type ErrorStats struct { AvgDurationMS float64 `json:"avg_duration_ms"` AvgCostPerReq float64 `json:"avg_cost_per_request"` } + +// SessionRecord represents a completed or active agent session for the dashboard API. +type SessionRecord struct { + ID string `json:"id"` + AgentID string `json:"agent_id"` + UserID string `json:"user_id"` + Task string `json:"task"` + StartedAt time.Time `json:"started_at"` + EndedAt *time.Time `json:"ended_at"` + Status string `json:"status"` + CallCount int `json:"call_count"` + TotalCostUSD float64 `json:"total_cost_usd"` + TotalTokens int `json:"total_tokens"` +} + +// LatencyStats holds percentile and distribution data for request latencies. +type LatencyStats struct { + P50 float64 `json:"p50_ms"` + P90 float64 `json:"p90_ms"` + P99 float64 `json:"p99_ms"` + Buckets []LatencyBucket `json:"buckets"` +} + +// LatencyBucket represents a single bucket in a latency distribution histogram. +type LatencyBucket struct { + Label string `json:"label"` + Count int `json:"count"` +} + +// TokenTimeseriesPoint represents a single time-bucketed token usage data point. 
+type TokenTimeseriesPoint struct {
+	Timestamp    time.Time `json:"timestamp"`
+	InputTokens  int64     `json:"input_tokens"`
+	OutputTokens int64     `json:"output_tokens"`
+}
diff --git a/internal/ledger/postgres.go b/internal/ledger/postgres.go
index 5590611..694c213 100644
--- a/internal/ledger/postgres.go
+++ b/internal/ledger/postgres.go
@@ -130,9 +130,9 @@ func (p *Postgres) QueryCosts(ctx context.Context, filter CostFilter) ([]CostEnt
 func (p *Postgres) QueryCostTimeseries(ctx context.Context, interval string, since, until time.Time, tenantID string) ([]TimeseriesPoint, error) {
 	bucket := "date_trunc('hour', timestamp)"
 	switch interval {
-	case "minute":
+	case "minute": //nolint:goconst
 		bucket = "date_trunc('minute', timestamp)"
-	case "day":
+	case "day": //nolint:goconst
 		bucket = "date_trunc('day', timestamp)"
 	}
@@ -336,6 +336,143 @@ func (p *Postgres) ListActiveSessions(ctx context.Context) ([]agent.Session, err
 	return sessions, rows.Err()
 }
+
+// QueryRecentSessions returns sessions within the time window, optionally filtered by status.
+func (p *Postgres) QueryRecentSessions(ctx context.Context, since, until time.Time, status string, limit int) ([]SessionRecord, error) {
+	where := "started_at >= $1 AND started_at <= $2"
+	args := []any{since.UTC(), until.UTC()}
+	if status != "" {
+		args = append(args, status)
+		where += fmt.Sprintf(" AND status = $%d", len(args))
+	}
+	args = append(args, limit)
+
+	q := fmt.Sprintf(`SELECT id, agent_id, user_id, task, started_at, ended_at, status, `+ //nolint:gosec // where clause built from trusted code, not user input
+		`call_count, total_cost_usd, total_tokens
+		FROM agent_sessions WHERE %s
+		ORDER BY started_at DESC LIMIT $%d`, where, len(args))
+
+	rows, err := p.db.QueryContext(ctx, q, args...)
+ if err != nil { + return nil, fmt.Errorf("querying recent sessions: %w", err) + } + defer func() { _ = rows.Close() }() + + var records []SessionRecord + for rows.Next() { + var r SessionRecord + var endedAt sql.NullTime + if err := rows.Scan(&r.ID, &r.AgentID, &r.UserID, &r.Task, + &r.StartedAt, &endedAt, &r.Status, + &r.CallCount, &r.TotalCostUSD, &r.TotalTokens); err != nil { + return nil, fmt.Errorf("scanning session record: %w", err) + } + if endedAt.Valid { + r.EndedAt = &endedAt.Time + } + records = append(records, r) + } + return records, rows.Err() +} + +// QueryLatencyPercentiles returns P50/P90/P99 latency and a histogram distribution. +func (p *Postgres) QueryLatencyPercentiles(ctx context.Context, since, until time.Time, tenantID string) (*LatencyStats, error) { + where := "timestamp >= $1 AND timestamp <= $2" + args := []any{since.UTC(), until.UTC()} + if tenantID != "" { + args = append(args, tenantID) + where += fmt.Sprintf(" AND tenant_id = $%d", len(args)) + } + + // Bucket distribution. + bucketQ := fmt.Sprintf(`SELECT + CASE + WHEN duration_ms < 100 THEN '<100ms' + WHEN duration_ms < 500 THEN '100-500ms' + WHEN duration_ms < 1000 THEN '500ms-1s' + WHEN duration_ms < 3000 THEN '1-3s' + WHEN duration_ms < 10000 THEN '3-10s' + ELSE '>10s' + END as bucket, + COUNT(*) as cnt + FROM usage_records WHERE %s + GROUP BY bucket + ORDER BY MIN(duration_ms) ASC`, where) + + rows, err := p.db.QueryContext(ctx, bucketQ, args...) + if err != nil { + return nil, fmt.Errorf("querying latency buckets: %w", err) + } + defer func() { _ = rows.Close() }() + + var buckets []LatencyBucket + for rows.Next() { + var b LatencyBucket + if err := rows.Scan(&b.Label, &b.Count); err != nil { + return nil, fmt.Errorf("scanning latency bucket: %w", err) + } + buckets = append(buckets, b) + } + if err := rows.Err(); err != nil { + return nil, err + } + + // Percentiles using native PostgreSQL PERCENTILE_CONT. 
+ percQ := fmt.Sprintf(`SELECT + COALESCE(PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY duration_ms), 0), + COALESCE(PERCENTILE_CONT(0.9) WITHIN GROUP (ORDER BY duration_ms), 0), + COALESCE(PERCENTILE_CONT(0.99) WITHIN GROUP (ORDER BY duration_ms), 0) + FROM usage_records WHERE %s AND duration_ms > 0`, where) + + stats := &LatencyStats{Buckets: buckets} + if err := p.db.QueryRowContext(ctx, percQ, args...).Scan(&stats.P50, &stats.P90, &stats.P99); err != nil { + return nil, fmt.Errorf("querying latency percentiles: %w", err) + } + return stats, nil +} + +// QueryTokenTimeseries returns token counts bucketed by time interval. +func (p *Postgres) QueryTokenTimeseries(ctx context.Context, interval string, since, until time.Time, tenantID string) ([]TokenTimeseriesPoint, error) { + bucket := "date_trunc('hour', timestamp)" + switch interval { + case "minute": //nolint:goconst + bucket = "date_trunc('minute', timestamp)" + case "day": //nolint:goconst + bucket = "date_trunc('day', timestamp)" + } + + where := "timestamp >= $1 AND timestamp <= $2" //nolint:goconst + args := []any{since.UTC(), until.UTC()} + if tenantID != "" { + args = append(args, tenantID) + where += fmt.Sprintf(" AND tenant_id = $%d", len(args)) + } + + q := fmt.Sprintf(`SELECT + %s as bucket, + COALESCE(SUM(input_tokens), 0), + COALESCE(SUM(output_tokens), 0) + FROM usage_records + WHERE %s + GROUP BY bucket + ORDER BY bucket ASC`, bucket, where) + + rows, err := p.db.QueryContext(ctx, q, args...) 
+ if err != nil { + return nil, fmt.Errorf("querying token timeseries: %w", err) + } + defer func() { _ = rows.Close() }() + + var points []TokenTimeseriesPoint + for rows.Next() { + var pt TokenTimeseriesPoint + if err := rows.Scan(&pt.Timestamp, &pt.InputTokens, &pt.OutputTokens); err != nil { + return nil, fmt.Errorf("scanning token timeseries point: %w", err) + } + points = append(points, pt) + } + return points, rows.Err() +} + // DB returns the underlying database connection for use by other packages // (e.g., admin config store). func (p *Postgres) DB() *sql.DB { diff --git a/internal/ledger/recorder_test.go b/internal/ledger/recorder_test.go index 1ef4493..e7c3716 100644 --- a/internal/ledger/recorder_test.go +++ b/internal/ledger/recorder_test.go @@ -39,6 +39,15 @@ func (c *countingLedger) QueryRecentExpensive(_ context.Context, _, _ time.Time, func (c *countingLedger) QueryErrorStats(_ context.Context, _, _ time.Time, _ string) (*ErrorStats, error) { return &ErrorStats{}, nil } +func (c *countingLedger) QueryRecentSessions(_ context.Context, _, _ time.Time, _ string, _ int) ([]SessionRecord, error) { + return nil, nil +} +func (c *countingLedger) QueryLatencyPercentiles(_ context.Context, _, _ time.Time, _ string) (*LatencyStats, error) { + return &LatencyStats{}, nil +} +func (c *countingLedger) QueryTokenTimeseries(_ context.Context, _ string, _, _ time.Time, _ string) ([]TokenTimeseriesPoint, error) { + return nil, nil +} func (c *countingLedger) Close() error { return nil } @@ -68,6 +77,15 @@ func (f *failingLedger) QueryRecentExpensive(_ context.Context, _, _ time.Time, func (f *failingLedger) QueryErrorStats(_ context.Context, _, _ time.Time, _ string) (*ErrorStats, error) { return &ErrorStats{}, nil } +func (f *failingLedger) QueryRecentSessions(_ context.Context, _, _ time.Time, _ string, _ int) ([]SessionRecord, error) { + return nil, nil +} +func (f *failingLedger) QueryLatencyPercentiles(_ context.Context, _, _ time.Time, _ string) 
(*LatencyStats, error) { + return &LatencyStats{}, nil +} +func (f *failingLedger) QueryTokenTimeseries(_ context.Context, _ string, _, _ time.Time, _ string) ([]TokenTimeseriesPoint, error) { + return nil, nil +} func (f *failingLedger) Close() error { return nil } diff --git a/internal/ledger/sqlite.go b/internal/ledger/sqlite.go index 2e7e1f0..368a955 100644 --- a/internal/ledger/sqlite.go +++ b/internal/ledger/sqlite.go @@ -134,9 +134,9 @@ func (s *SQLite) QueryCostTimeseries(ctx context.Context, interval string, since // but strftime only parses ISO8601. Use substr to extract the datetime portion. bucket := "strftime('%Y-%m-%d %H:00:00', substr(timestamp, 1, 19))" switch interval { - case "minute": + case "minute": //nolint:goconst bucket = "strftime('%Y-%m-%d %H:%M:00', substr(timestamp, 1, 19))" - case "day": + case "day": //nolint:goconst bucket = "strftime('%Y-%m-%d 00:00:00', substr(timestamp, 1, 19))" } @@ -342,6 +342,179 @@ func (s *SQLite) ListActiveSessions(ctx context.Context) ([]agent.Session, error return sessions, rows.Err() } +// QueryRecentSessions returns sessions within the time window, optionally filtered by status. +func (s *SQLite) QueryRecentSessions(ctx context.Context, since, until time.Time, status string, limit int) ([]SessionRecord, error) { + where := "started_at >= ? AND started_at <= ?" + args := []any{since.UTC(), until.UTC()} + if status != "" { + where += " AND status = ?" + args = append(args, status) + } + args = append(args, limit) + + q := fmt.Sprintf(`SELECT id, agent_id, user_id, task, started_at, ended_at, status, + call_count, total_cost_usd, total_tokens + FROM agent_sessions WHERE %s + ORDER BY started_at DESC LIMIT ?`, where) + + rows, err := s.db.QueryContext(ctx, q, args...) 
+ if err != nil { + return nil, fmt.Errorf("querying recent sessions: %w", err) + } + defer func() { _ = rows.Close() }() + + var records []SessionRecord + for rows.Next() { + var r SessionRecord + var endedAt sql.NullTime + if err := rows.Scan(&r.ID, &r.AgentID, &r.UserID, &r.Task, + &r.StartedAt, &endedAt, &r.Status, + &r.CallCount, &r.TotalCostUSD, &r.TotalTokens); err != nil { + return nil, fmt.Errorf("scanning session record: %w", err) + } + if endedAt.Valid { + r.EndedAt = &endedAt.Time + } + records = append(records, r) + } + return records, rows.Err() +} + +// QueryLatencyPercentiles returns P50/P90/P99 latency and a histogram distribution. +func (s *SQLite) QueryLatencyPercentiles(ctx context.Context, since, until time.Time, tenantID string) (*LatencyStats, error) { + where := "timestamp >= ? AND timestamp <= ?" //nolint:goconst + args := []any{since.UTC(), until.UTC()} + if tenantID != "" { + where += " AND tenant_id = ?" //nolint:goconst + args = append(args, tenantID) + } + + // Bucket distribution. + bucketQ := fmt.Sprintf(`SELECT + CASE + WHEN duration_ms < 100 THEN '<100ms' + WHEN duration_ms < 500 THEN '100-500ms' + WHEN duration_ms < 1000 THEN '500ms-1s' + WHEN duration_ms < 3000 THEN '1-3s' + WHEN duration_ms < 10000 THEN '3-10s' + ELSE '>10s' + END as bucket, + COUNT(*) as cnt + FROM usage_records WHERE %s + GROUP BY bucket + ORDER BY MIN(duration_ms) ASC`, where) + + rows, err := s.db.QueryContext(ctx, bucketQ, args...) + if err != nil { + return nil, fmt.Errorf("querying latency buckets: %w", err) + } + defer func() { _ = rows.Close() }() + + var buckets []LatencyBucket + for rows.Next() { + var b LatencyBucket + if scanErr := rows.Scan(&b.Label, &b.Count); scanErr != nil { + return nil, fmt.Errorf("scanning latency bucket: %w", scanErr) + } + buckets = append(buckets, b) + } + if rowsErr := rows.Err(); rowsErr != nil { + return nil, rowsErr + } + + // Percentiles: fetch sorted durations and compute in Go. 
+	percQ := fmt.Sprintf(`SELECT duration_ms FROM usage_records
+		WHERE %s AND duration_ms > 0
+		ORDER BY duration_ms ASC LIMIT 10000`, where)
+
+	pRows, err := s.db.QueryContext(ctx, percQ, args...)
+	if err != nil {
+		return nil, fmt.Errorf("querying latency percentiles: %w", err)
+	}
+	defer func() { _ = pRows.Close() }()
+
+	var durations []float64
+	for pRows.Next() {
+		var d float64
+		if scanErr := pRows.Scan(&d); scanErr != nil {
+			return nil, fmt.Errorf("scanning duration: %w", scanErr)
+		}
+		durations = append(durations, d)
+	}
+	if rowsErr := pRows.Err(); rowsErr != nil {
+		return nil, rowsErr
+	}
+
+	stats := &LatencyStats{Buckets: buckets}
+	if len(durations) > 0 {
+		stats.P50 = percentile(durations, 0.50)
+		stats.P90 = percentile(durations, 0.90)
+		stats.P99 = percentile(durations, 0.99)
+	}
+	return stats, nil
+}
+
+// QueryTokenTimeseries returns token counts bucketed by time interval.
+func (s *SQLite) QueryTokenTimeseries(ctx context.Context, interval string, since, until time.Time, tenantID string) ([]TokenTimeseriesPoint, error) {
+	bucket := "strftime('%Y-%m-%d %H:00:00', substr(timestamp, 1, 19))"
+	switch interval {
+	case "minute": //nolint:goconst
+		bucket = "strftime('%Y-%m-%d %H:%M:00', substr(timestamp, 1, 19))"
+	case "day": //nolint:goconst
+		bucket = "strftime('%Y-%m-%d 00:00:00', substr(timestamp, 1, 19))"
+	}
+
+	where := "timestamp >= ? AND timestamp <= ?" //nolint:goconst
+	args := []any{since.UTC(), until.UTC()}
+	if tenantID != "" {
+		where += " AND tenant_id = ?" //nolint:goconst
+		args = append(args, tenantID)
+	}
+
+	q := fmt.Sprintf(`SELECT
+		%s as bucket,
+		COALESCE(SUM(input_tokens), 0),
+		COALESCE(SUM(output_tokens), 0)
+		FROM usage_records
+		WHERE %s
+		GROUP BY bucket
+		ORDER BY bucket ASC`, bucket, where)
+
+	rows, err := s.db.QueryContext(ctx, q, args...)
+ if err != nil { + return nil, fmt.Errorf("querying token timeseries: %w", err) + } + defer func() { _ = rows.Close() }() + + var points []TokenTimeseriesPoint + for rows.Next() { + var p TokenTimeseriesPoint + var ts string + if err := rows.Scan(&ts, &p.InputTokens, &p.OutputTokens); err != nil { + return nil, fmt.Errorf("scanning token timeseries point: %w", err) + } + p.Timestamp, _ = time.Parse("2006-01-02 15:04:05", ts) + points = append(points, p) + } + return points, rows.Err() +} + +// percentile computes the p-th percentile from a sorted slice of float64 values. +func percentile(sorted []float64, p float64) float64 { + n := len(sorted) + if n == 0 { + return 0 + } + idx := p * float64(n-1) + lower := int(idx) + upper := lower + 1 + if upper >= n { + return sorted[n-1] + } + frac := idx - float64(lower) + return sorted[lower]*(1-frac) + sorted[upper]*frac +} + // DB returns the underlying database connection for use by other packages // (e.g., admin config store). func (s *SQLite) DB() *sql.DB { diff --git a/internal/mcp/httpproxy.go b/internal/mcp/httpproxy.go index b70ccb2..d0540e8 100644 --- a/internal/mcp/httpproxy.go +++ b/internal/mcp/httpproxy.go @@ -7,6 +7,8 @@ import ( "io" "log/slog" "net/http" + "net/url" + "path" "strings" "time" @@ -46,6 +48,13 @@ func (p *HTTPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + // Path sanitization: reject path traversal attempts. + if strings.Contains(r.URL.Path, "..") { + http.Error(w, "invalid path", http.StatusBadRequest) + return + } + cleanPath := path.Clean(r.URL.Path) + // Read request body with size limit. r.Body = http.MaxBytesReader(w, r.Body, maxRequestBody) body, err := io.ReadAll(r.Body) @@ -71,13 +80,20 @@ func (p *HTTPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Intercept outbound request. interceptor.HandleMessage(body, true, agentCtx) - // Build upstream request. 
- upstreamURL := p.upstream + r.URL.Path + // Build upstream URL safely using url.Parse + path.Join. + upstreamBase, err := url.Parse(p.upstream) + if err != nil { + p.logger.Error("failed to parse upstream URL", "error", err) + http.Error(w, "internal server error", http.StatusInternalServerError) + return + } + upstreamBase.Path = path.Join(upstreamBase.Path, cleanPath) if r.URL.RawQuery != "" { - upstreamURL += "?" + r.URL.RawQuery + upstreamBase.RawQuery = r.URL.RawQuery } + upstreamURL := upstreamBase.String() - upReq, err := http.NewRequestWithContext(r.Context(), http.MethodPost, upstreamURL, bytes.NewReader(body)) //nolint:gosec // upstream URL is from trusted server config + upReq, err := http.NewRequestWithContext(r.Context(), http.MethodPost, upstreamURL, bytes.NewReader(body)) if err != nil { p.logger.Error("failed to create upstream request", "error", err) http.Error(w, "internal server error", http.StatusInternalServerError) diff --git a/internal/mcp/httpproxy_test.go b/internal/mcp/httpproxy_test.go index 3f34335..662dd05 100644 --- a/internal/mcp/httpproxy_test.go +++ b/internal/mcp/httpproxy_test.go @@ -141,3 +141,24 @@ func TestHTTPProxy_NonToolCallPassthrough(t *testing.T) { t.Errorf("expected 0 records for tools/list, got %d", len(records)) } } + +func TestHTTPProxy_PathTraversal(t *testing.T) { + store := &recordingLedger{} + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + rec := ledger.NewRecorder(store, 100, 1, logger) + defer rec.Close() + + proxy := NewHTTPProxy("http://localhost:9999", NewPricer(nil), rec, logger) + + // Attempt path traversal. 
+ reqBody := `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"test"}}` + req := httptest.NewRequest(http.MethodPost, "/mcp/../../../etc/passwd", strings.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + + rr := httptest.NewRecorder() + proxy.ServeHTTP(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Errorf("path traversal: status = %d, want 400", rr.Code) + } +} diff --git a/internal/mcp/interceptor_test.go b/internal/mcp/interceptor_test.go index 5fdd721..fd9fdd5 100644 --- a/internal/mcp/interceptor_test.go +++ b/internal/mcp/interceptor_test.go @@ -47,6 +47,15 @@ func (r *recordingLedger) QueryRecentExpensive(_ context.Context, _, _ time.Time func (r *recordingLedger) QueryErrorStats(_ context.Context, _, _ time.Time, _ string) (*ledger.ErrorStats, error) { return &ledger.ErrorStats{}, nil } +func (r *recordingLedger) QueryRecentSessions(_ context.Context, _, _ time.Time, _ string, _ int) ([]ledger.SessionRecord, error) { + return nil, nil +} +func (r *recordingLedger) QueryLatencyPercentiles(_ context.Context, _, _ time.Time, _ string) (*ledger.LatencyStats, error) { + return &ledger.LatencyStats{}, nil +} +func (r *recordingLedger) QueryTokenTimeseries(_ context.Context, _ string, _, _ time.Time, _ string) ([]ledger.TokenTimeseriesPoint, error) { + return nil, nil +} func (r *recordingLedger) Close() error { return nil } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index af9092e..ef24976 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -26,6 +26,11 @@ import ( "github.com/WDZ-Dev/agent-ledger/internal/tenant" ) +const ( + maxRequestBodyBytes = 10 << 20 // 10 MB + maxResponseBodyBytes = 50 << 20 // 50 MB +) + // context keys for passing data between Rewrite and ModifyResponse. type ctxKey int @@ -113,9 +118,14 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // Read request body for metadata extraction. 
+ r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodyBytes) body, err := io.ReadAll(r.Body) if err != nil { p.logger.Error("reading request body", "error", err) + if err.Error() == "http: request body too large" { + writeJSONError(w, http.StatusRequestEntityTooLarge, "request body too large") + return + } writeJSONError(w, http.StatusBadRequest, "failed to read request body") return } @@ -295,7 +305,7 @@ func (p *Proxy) modifyResponse(resp *http.Response) error { } // Non-streaming: read, parse, record, replace body. - body, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodyBytes)) _ = resp.Body.Close() if err != nil { p.logger.Error("reading response body", "error", err) @@ -370,7 +380,7 @@ func (p *Proxy) modifyResponse(resp *http.Response) error { func (p *Proxy) errorHandler(w http.ResponseWriter, _ *http.Request, err error) { p.logger.Error("proxy error", "error", err) - writeJSONError(w, http.StatusBadGateway, "upstream request failed: "+err.Error()) + writeJSONError(w, http.StatusBadGateway, "upstream request failed") } func writeJSONError(w http.ResponseWriter, status int, msg string) { diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index a38f4ed..e5b0740 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -54,6 +54,15 @@ func (m *mockStore) QueryRecentExpensive(_ context.Context, _, _ time.Time, _ st func (m *mockStore) QueryErrorStats(_ context.Context, _, _ time.Time, _ string) (*ledger.ErrorStats, error) { return &ledger.ErrorStats{}, nil } +func (m *mockStore) QueryRecentSessions(_ context.Context, _, _ time.Time, _ string, _ int) ([]ledger.SessionRecord, error) { + return nil, nil +} +func (m *mockStore) QueryLatencyPercentiles(_ context.Context, _, _ time.Time, _ string) (*ledger.LatencyStats, error) { + return &ledger.LatencyStats{}, nil +} +func (m *mockStore) QueryTokenTimeseries(_ context.Context, _ string, _, _ time.Time, _ string) 
([]ledger.TokenTimeseriesPoint, error) { + return nil, nil +} func (m *mockStore) Close() error { return nil } diff --git a/internal/ratelimit/limiter.go b/internal/ratelimit/limiter.go index 9c54095..c5b1321 100644 --- a/internal/ratelimit/limiter.go +++ b/internal/ratelimit/limiter.go @@ -26,6 +26,9 @@ type Limiter struct { // key -> window start -> count minuteCounters map[string]*slidingWindow hourCounters map[string]*slidingWindow + + done chan struct{} + closed sync.Once } type slidingWindow struct { @@ -35,11 +38,21 @@ type slidingWindow struct { // New creates a rate limiter from configuration. func New(cfg Config) *Limiter { - return &Limiter{ + l := &Limiter{ config: cfg, minuteCounters: make(map[string]*slidingWindow), hourCounters: make(map[string]*slidingWindow), + done: make(chan struct{}), } + go l.cleanupLoop() + return l +} + +// Close stops the background cleanup goroutine. +func (l *Limiter) Close() { + l.closed.Do(func() { + close(l.done) + }) } // Enabled returns true if any rate limits are configured. 
@@ -121,3 +134,34 @@ func (l *Limiter) mergeWithDefault(r Rule) Rule { } return r } + +func (l *Limiter) cleanupLoop() { + ticker := time.NewTicker(60 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + l.evictExpired() + case <-l.done: + return + } + } +} + +func (l *Limiter) evictExpired() { + l.mu.Lock() + defer l.mu.Unlock() + + now := time.Now() + for key, w := range l.minuteCounters { + if now.After(w.windowEnd) { + delete(l.minuteCounters, key) + } + } + for key, w := range l.hourCounters { + if now.After(w.windowEnd) { + delete(l.hourCounters, key) + } + } +} diff --git a/internal/ratelimit/limiter_test.go b/internal/ratelimit/limiter_test.go index 00b8d37..67def2f 100644 --- a/internal/ratelimit/limiter_test.go +++ b/internal/ratelimit/limiter_test.go @@ -9,6 +9,7 @@ func TestLimiterAllows(t *testing.T) { l := New(Config{ Default: Rule{RequestsPerMinute: 3}, }) + defer l.Close() for i := 0; i < 3; i++ { ok, _ := l.Allow("sk-test", "hash1") @@ -29,12 +30,14 @@ func TestLimiterAllows(t *testing.T) { func TestLimiterEnabled(t *testing.T) { l := New(Config{}) + defer l.Close() if l.Enabled() { t.Error("should not be enabled with zero config") } - l = New(Config{Default: Rule{RequestsPerMinute: 10}}) - if !l.Enabled() { + l2 := New(Config{Default: Rule{RequestsPerMinute: 10}}) + defer l2.Close() + if !l2.Enabled() { t.Error("should be enabled with default rule") } } @@ -46,6 +49,7 @@ func TestLimiterPerKeyRules(t *testing.T) { {APIKeyPattern: "sk-dev-*", RequestsPerMinute: 2}, }, }) + defer l.Close() // Dev key — limited to 2/min. 
l.Allow("sk-dev-abc", "dev-hash") @@ -66,6 +70,7 @@ func TestLimiterWindowResets(t *testing.T) { l := New(Config{ Default: Rule{RequestsPerMinute: 1}, }) + defer l.Close() ok, _ := l.Allow("sk-test", "hash1") if !ok { @@ -86,3 +91,28 @@ func TestLimiterWindowResets(t *testing.T) { t.Error("should be allowed after window reset") } } + +func TestLimiterEvictsExpired(t *testing.T) { + l := New(Config{ + Default: Rule{RequestsPerMinute: 10}, + }) + defer l.Close() + + // Add a request to create a window. + l.Allow("sk-test", "hash1") + + // Expire the window manually. + l.mu.Lock() + l.minuteCounters["hash1"].windowEnd = time.Now().Add(-time.Second) + l.mu.Unlock() + + l.evictExpired() + + l.mu.Lock() + _, exists := l.minuteCounters["hash1"] + l.mu.Unlock() + + if exists { + t.Error("expected expired window to be evicted") + } +} diff --git a/mkdocs.yml b/mkdocs.yml index 5eabe87..f13d786 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -4,6 +4,7 @@ site_description: "Know what your agents cost. Real-time cost attribution, budge theme: name: material + favicon: assets/favicon.svg palette: - scheme: slate primary: custom