Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ func startDaemonBackground() error {
fmt.Println("zend is now running (started by another process).")
return nil
}
return fmt.Errorf("daemon failed to start (concurrent startup attempt failed, please retry)")
} else if err != nil {
return fmt.Errorf("cannot acquire daemon lock: %w", err)
}
Expand Down
11 changes: 3 additions & 8 deletions internal/daemon/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,12 +180,7 @@ func (d *Daemon) handleDaemonShutdown(w http.ResponseWriter, r *http.Request) {
// Trigger shutdown in background so the response is sent first
go func() {
time.Sleep(100 * time.Millisecond)
select {
case <-d.shutdownCh:
// Already closed
default:
close(d.shutdownCh)
}
d.shutdownOnce.Do(func() { close(d.shutdownCh) })
}()
}

Expand Down Expand Up @@ -232,9 +227,9 @@ func (d *Daemon) handleDaemonSessions(w http.ResponseWriter, r *http.Request) {

func (d *Daemon) handleGetSessions(w http.ResponseWriter, r *http.Request) {
d.mu.RLock()
sessions := make([]*SessionInfo, 0, len(d.sessions))
sessions := make([]SessionInfo, 0, len(d.sessions))
for _, s := range d.sessions {
sessions = append(sessions, s)
sessions = append(sessions, *s) // copy struct to avoid data race after unlock
}
d.mu.RUnlock()

Expand Down
9 changes: 5 additions & 4 deletions internal/daemon/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,11 @@ type Daemon struct {
currentGates *config.FeatureGates

// Shutdown channel - closed when shutdown is requested via API
shutdownCh chan struct{}
runCtx context.Context
runCancel context.CancelFunc
bgWG sync.WaitGroup
shutdownCh chan struct{}
shutdownOnce sync.Once
runCtx context.Context
runCancel context.CancelFunc
bgWG sync.WaitGroup

// Goroutine leak detection
baselineGoroutines int
Expand Down
4 changes: 2 additions & 2 deletions internal/proxy/loadbalancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1087,8 +1087,8 @@ func TestLoadBalancer_RoundRobinPerProfileIsolation(t *testing.T) {
// Profile A and B should have the same rotation sequence (both start from counter=0)
for i := 0; i < 3; i++ {
if profileAResults[i] != profileBResults[i] {
// This is the key assertion: both profiles should independently cycle
// through the same sequence since they start from their own counter=0
t.Errorf("profile isolation broken at index %d: profile-a=%s, profile-b=%s (counters should be independent)",
i, profileAResults[i], profileBResults[i])
}
}

Expand Down
124 changes: 115 additions & 9 deletions internal/proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,13 @@ func (s *ProxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
strategy = *decision.StrategyOverride
}

providers = s.LoadBalancer.Select(providers, strategy, model, s.Profile, modelOverrides, weights)
// Use a scenario-specific counter key so scenario route round-robin
// does not advance the default profile's counter.
rrKey := s.Profile
if usingScenarioRoute && decision.Scenario != "" {
rrKey = s.Profile + ":scenario:" + decision.Scenario
}
providers = s.LoadBalancer.Select(providers, strategy, model, rrKey, modelOverrides, weights)
}

// Track provider failure details for error reporting
Expand Down Expand Up @@ -961,8 +967,14 @@ func (s *ProxyServer) tryProviders(w http.ResponseWriter, r *http.Request, provi
s.Logger.Printf("[%s] %s", p.Name, msg)
s.logStructured(p.Name, r.Method, r.URL.Path, resp.StatusCode, LogLevelInfo, msg, sessionID, clientType)

// Update session cache with token usage from response
s.updateSessionCache(sessionID, resp)
// Update session cache with token usage from response.
// For SSE (streaming), wrap the body with an extractor that parses
// usage events in-flight so longContext routing stays accurate.
if sessionID != "" && strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
resp.Body = &sseUsageExtractor{r: resp.Body, sessionID: sessionID}
} else {
s.updateSessionCache(sessionID, resp)
}

// Record usage and metrics
s.recordUsageAndMetrics(p.Name, sessionID, clientType, bodyBytes, resp, requestID, requestStart, requestFormat, failures)
Expand Down Expand Up @@ -1195,16 +1207,26 @@ func (s *ProxyServer) copyResponseFromResponsesAPI(w http.ResponseWriter, resp *
return
}

// Transform Responses API → Anthropic
// Check if client expects Anthropic format (anthropic-messages or legacy anthropic)
if transform.NormalizeFormat(requestFormat) == config.ProviderTypeAnthropic {
// Transform Responses API → client format
switch transform.NormalizeFormat(requestFormat) {
case config.ProviderTypeAnthropic:
transformed, err := transform.ResponsesAPIToAnthropic(body)
if err != nil {
s.Logger.Printf("[%s] Responses API response transform error: %v", p.Name, err)
} else {
s.Logger.Printf("[%s] transformed Responses API → Anthropic", p.Name)
body = transformed
}
case config.ProviderTypeOpenAI:
if requestFormat == transform.FormatOpenAIChat {
transformed, err := transform.ResponsesAPIToOpenAIChat(body)
if err != nil {
s.Logger.Printf("[%s] Responses API → Chat Completions transform error: %v", p.Name, err)
} else {
s.Logger.Printf("[%s] transformed Responses API → OpenAI Chat Completions", p.Name)
body = transformed
}
}
}

// Copy headers (except Content-Length which may have changed)
Expand Down Expand Up @@ -1232,14 +1254,23 @@ func (s *ProxyServer) copyResponseFromResponsesAPI(w http.ResponseWriter, resp *
flusher, ok := w.(http.Flusher)

var reader io.Reader = resp.Body
// Check if client expects Anthropic format (anthropic-messages or legacy anthropic)
if transform.NormalizeFormat(requestFormat) == config.ProviderTypeAnthropic {
switch transform.NormalizeFormat(requestFormat) {
case config.ProviderTypeAnthropic:
st := &transform.StreamTransformer{
ClientFormat: "anthropic",
ProviderFormat: "openai-responses",
ProviderFormat: transform.FormatOpenAIResponses,
}
reader = st.TransformSSEStream(resp.Body)
s.Logger.Printf("[%s] transforming Responses API SSE stream → Anthropic", p.Name)
case config.ProviderTypeOpenAI:
if requestFormat == transform.FormatOpenAIChat {
st := &transform.StreamTransformer{
ClientFormat: transform.FormatOpenAIChat,
ProviderFormat: transform.FormatOpenAIResponses,
}
reader = st.TransformSSEStream(resp.Body)
s.Logger.Printf("[%s] transforming Responses API SSE stream → OpenAI Chat Completions", p.Name)
}
}

buf := make([]byte, 4096)
Expand Down Expand Up @@ -1850,3 +1881,78 @@ func GetDaemonLogger() daemonLogger {
}
return nil
}

// sseUsageExtractor wraps an SSE response body, parsing Anthropic SSE events
// in-flight to extract token usage. When the stream ends, it updates the
// session cache so longContext routing has accurate usage for streaming turns.
type sseUsageExtractor struct {
	r         io.ReadCloser // underlying SSE response body being proxied through
	sessionID string        // session whose usage cache is updated at EOF; empty disables the update
	partial   []byte        // incomplete line buffer carried across Read chunks
	inputTok  int           // accumulated input_tokens from message_start events
	outputTok int           // accumulated output_tokens from message_delta events
}

// Read proxies reads from the wrapped body, scanning each chunk for usage
// events as it passes through. When the stream reaches EOF, accumulated
// token counts are flushed to the session cache exactly once.
func (e *sseUsageExtractor) Read(p []byte) (n int, err error) {
	n, err = e.r.Read(p)
	if n > 0 {
		e.processChunk(p[:n])
	}
	if err == io.EOF {
		// Terminate any dangling final line so a stream that does not end
		// with a newline still has its last event parsed.
		if len(e.partial) > 0 {
			e.processChunk([]byte("\n"))
		}
		if e.sessionID != "" && (e.inputTok > 0 || e.outputTok > 0) {
			UpdateSessionUsage(e.sessionID, &SessionUsage{
				InputTokens:  e.inputTok,
				OutputTokens: e.outputTok,
			})
			// Clear the session ID so repeated Reads after EOF cannot
			// report the same usage twice.
			e.sessionID = ""
		}
	}
	return
}

func (e *sseUsageExtractor) Close() error { return e.r.Close() }

// processChunk scans raw SSE bytes for Anthropic usage events, accumulating
// token counts in e.inputTok / e.outputTok. Incomplete trailing lines are
// buffered in e.partial and re-processed when the next chunk arrives.
func (e *sseUsageExtractor) processChunk(data []byte) {
	// Append to partial buffer and process line by line
	buf := append(e.partial, data...)
	for {
		idx := bytes.IndexByte(buf, '\n')
		if idx < 0 {
			break
		}
		line := string(bytes.TrimRight(buf[:idx], "\r"))
		buf = buf[idx+1:]

		// Per the SSE spec, the "data" field is "data:" followed by an
		// OPTIONAL single space — accept both "data: x" and "data:x".
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		payload := strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " ")
		if payload == "[DONE]" {
			continue
		}
		var ev map[string]interface{}
		if json.Unmarshal([]byte(payload), &ev) != nil {
			// Not JSON (or malformed) — only usage-bearing events matter here.
			continue
		}
		evType, _ := ev["type"].(string)
		switch evType {
		case "message_start":
			// {"type":"message_start","message":{"usage":{"input_tokens":N}}}
			if msg, ok := ev["message"].(map[string]interface{}); ok {
				if u, ok := msg["usage"].(map[string]interface{}); ok {
					if v, ok := u["input_tokens"].(float64); ok {
						e.inputTok += int(v)
					}
				}
			}
		case "message_delta":
			// {"type":"message_delta","usage":{"output_tokens":N}}
			if u, ok := ev["usage"].(map[string]interface{}); ok {
				if v, ok := u["output_tokens"].(float64); ok {
					e.outputTok += int(v)
				}
			}
		}
	}
	e.partial = buf
}
123 changes: 123 additions & 0 deletions internal/proxy/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2852,6 +2852,69 @@ func TestResponsesAPIRetry(t *testing.T) {
})
}

func TestResponsesAPIRetryOpenAIChat(t *testing.T) {
	// Regression test for: OpenAI Chat client receives wrong Responses API payload after retry.
	// When the client is openai-chat and the provider returns "input is required", the proxy
	// should retry via /responses AND transform the Responses API response back to
	// Chat Completions format before returning it to the client.
	mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case strings.Contains(r.URL.Path, "/chat/completions"):
			// First attempt fails with the retryable "input is required" error.
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(500)
			w.Write([]byte(`{"error":{"message":"input is required (request id: abc)","type":"new_api_error"}}`))
		case strings.Contains(r.URL.Path, "/responses"):
			// Retry succeeds with a Responses API payload.
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(200)
			w.Write([]byte(`{"id":"resp_1","object":"response","status":"completed","model":"gpt-5","output":[{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"output_text","text":"Hello from Responses!"}]}],"usage":{"input_tokens":10,"output_tokens":5,"total_tokens":15}}`))
		default:
			w.WriteHeader(404)
		}
	}))
	defer mock.Close()

	base, _ := url.Parse(mock.URL)
	providers := []*Provider{{
		Name:    "openai-provider",
		Type:    config.ProviderTypeOpenAI,
		BaseURL: base,
		Token:   "test-token",
		Model:   "gpt-5",
		Healthy: true,
	}}

	proxy := NewProxyServer(providers, discardLogger(), config.LoadBalanceFailover, nil)
	// Client sends an OpenAI Chat Completions request
	chatBody := `{"model":"gpt-5","messages":[{"role":"user","content":"hi"}],"max_tokens":1024}`
	req := httptest.NewRequest("POST", "/v1/chat/completions", strings.NewReader(chatBody))
	req.Header.Set("X-Zen-Request-Format", "openai-chat")
	rec := httptest.NewRecorder()
	proxy.ServeHTTP(rec, req)

	if rec.Code != 200 {
		t.Fatalf("status = %d, want 200; body: %s", rec.Code, rec.Body.String())
	}

	// Response MUST be Chat Completions format, NOT Responses API format
	var parsed map[string]interface{}
	if err := json.Unmarshal(rec.Body.Bytes(), &parsed); err != nil {
		t.Fatalf("response is not valid JSON: %v", err)
	}
	if parsed["object"] != "chat.completion" {
		t.Errorf("response object = %v, want chat.completion (got Responses API payload?)", parsed["object"])
	}
	choices, ok := parsed["choices"].([]interface{})
	if !ok || len(choices) == 0 {
		t.Fatal("response should have choices")
	}
	firstChoice := choices[0].(map[string]interface{})
	msg := firstChoice["message"].(map[string]interface{})
	if msg["content"] != "Hello from Responses!" {
		t.Errorf("content = %v, want Hello from Responses!", msg["content"])
	}
}

func TestResponsesAPIRetryStreaming(t *testing.T) {
// Mock server: 500 "input is required" on /chat/completions,
// SSE Responses API stream on /responses
Expand Down Expand Up @@ -3269,3 +3332,63 @@ func TestTransformError_ProperJSONResponse(t *testing.T) {
t.Errorf("expected message to contain error text, got: %s", message)
}
}

func TestSSEUsageExtractor(t *testing.T) {
t.Run("extracts_usage_from_anthropic_sse", func(t *testing.T) {
sse := strings.Join([]string{
`data: {"type":"message_start","message":{"id":"msg_1","usage":{"input_tokens":25,"output_tokens":0}}}`,
``,
`data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`,
``,
`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":42}}`,
``,
`data: {"type":"message_stop"}`,
``,
}, "\n")

var got *SessionUsage
// Temporarily wire a capture via a fake session
sessionID := "test-sse-usage-session"
ClearSessionUsage(sessionID)

extractor := &sseUsageExtractor{
r: io.NopCloser(strings.NewReader(sse)),
sessionID: sessionID,
}
_, err := io.ReadAll(extractor)
if err != nil {
t.Fatalf("unexpected read error: %v", err)
}

got = GetSessionUsage(sessionID)
if got == nil {
t.Fatal("expected session usage to be updated, got nil")
}
if got.InputTokens != 25 {
t.Errorf("InputTokens = %d, want 25", got.InputTokens)
}
if got.OutputTokens != 42 {
t.Errorf("OutputTokens = %d, want 42", got.OutputTokens)
}
})

t.Run("no_update_on_empty_session", func(t *testing.T) {
sse := "data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":10}}}\n\n"
extractor := &sseUsageExtractor{
r: io.NopCloser(strings.NewReader(sse)),
sessionID: "", // empty → should not call UpdateSessionUsage
}
// Should complete without panic
io.ReadAll(extractor)
})

t.Run("close_delegates_to_inner", func(t *testing.T) {
extractor := &sseUsageExtractor{
r: io.NopCloser(strings.NewReader("")),
sessionID: "",
}
if err := extractor.Close(); err != nil {
t.Errorf("Close() error: %v", err)
}
})
}
Loading
Loading