From d3d7709ef2afd5dbc9cbb142e86709ab5ca7a5dc Mon Sep 17 00:00:00 2001 From: John Zhang Date: Wed, 11 Mar 2026 18:31:37 +0800 Subject: [PATCH 1/6] fix(daemon): prevent concurrent shutdown panic and session data race - Use sync.Once to close shutdownCh so concurrent POST /api/v1/daemon/shutdown requests cannot race to close an already-closed channel (panic) - Copy SessionInfo struct values (not pointers) before releasing the read lock in handleGetSessions to eliminate the data race window between unlock and JSON encoding Co-Authored-By: Claude Sonnet 4.6 --- internal/daemon/api.go | 11 +++-------- internal/daemon/server.go | 9 +++++---- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/internal/daemon/api.go b/internal/daemon/api.go index bfe68c9..4f0ff8e 100644 --- a/internal/daemon/api.go +++ b/internal/daemon/api.go @@ -180,12 +180,7 @@ func (d *Daemon) handleDaemonShutdown(w http.ResponseWriter, r *http.Request) { // Trigger shutdown in background so the response is sent first go func() { time.Sleep(100 * time.Millisecond) - select { - case <-d.shutdownCh: - // Already closed - default: - close(d.shutdownCh) - } + d.shutdownOnce.Do(func() { close(d.shutdownCh) }) }() } @@ -232,9 +227,9 @@ func (d *Daemon) handleDaemonSessions(w http.ResponseWriter, r *http.Request) { func (d *Daemon) handleGetSessions(w http.ResponseWriter, r *http.Request) { d.mu.RLock() - sessions := make([]*SessionInfo, 0, len(d.sessions)) + sessions := make([]SessionInfo, 0, len(d.sessions)) for _, s := range d.sessions { - sessions = append(sessions, s) + sessions = append(sessions, *s) // copy struct to avoid data race after unlock } d.mu.RUnlock() diff --git a/internal/daemon/server.go b/internal/daemon/server.go index 72b87cc..fd23b74 100644 --- a/internal/daemon/server.go +++ b/internal/daemon/server.go @@ -83,10 +83,11 @@ type Daemon struct { currentGates *config.FeatureGates // Shutdown channel - closed when shutdown is requested via API - shutdownCh chan struct{} - runCtx context.Context - runCancel context.CancelFunc - bgWG sync.WaitGroup + shutdownCh chan struct{} + shutdownOnce sync.Once + runCtx context.Context + runCancel context.CancelFunc + bgWG sync.WaitGroup // Goroutine leak detection baselineGoroutines int From ffc70561af5912d6159fe6d8aff691670ce2a528 Mon Sep 17 00:00:00 2001 From: John Zhang Date: Wed, 11 Mar 2026 18:31:48 +0800 Subject: [PATCH 2/6] fix(daemon): prevent concurrent daemon launch after lock contention wait When startDaemonBackground encounters ErrLockContention it waits for the lock holder to finish, then checks IsDaemonRunning. Previously, if the holder's launch failed, two or more waiters would both fall through to the child-process start path without holding the lock, causing concurrent duplicate starts. Now return an explicit error so the caller retries. Co-Authored-By: Claude Sonnet 4.6 --- cmd/daemon.go | 1 + 1 file changed, 1 insertion(+) diff --git a/cmd/daemon.go b/cmd/daemon.go index b033632..847c067 100644 --- a/cmd/daemon.go +++ b/cmd/daemon.go @@ -259,6 +259,7 @@ func startDaemonBackground() error { fmt.Println("zend is now running (started by another process).") return nil } + return fmt.Errorf("daemon failed to start (concurrent startup attempt failed, please retry)") } else if err != nil { return fmt.Errorf("cannot acquire daemon lock: %w", err) } From 9b34c0c8ca6ef737464e17c1975120e3403eea4f Mon Sep 17 00:00:00 2001 From: John Zhang Date: Wed, 11 Mar 2026 18:32:14 +0800 Subject: [PATCH 3/6] fix(proxy): isolate round-robin counter per scenario route LoadBalancer.Select was called with the same profile key for both scenario-route and default-fallback selections. Each call advanced the shared profile counter, so scenario traffic would silently perturb the default provider order. Pass a distinct key (profile:scenario:) when selecting for a scenario route so each scenario and the default pool maintain independent counters. Co-Authored-By: Claude Sonnet 4.6 --- internal/proxy/server.go | 124 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 115 insertions(+), 9 deletions(-) diff --git a/internal/proxy/server.go b/internal/proxy/server.go index 9cfcd88..6a0e84b 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -632,7 +632,13 @@ func (s *ProxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { strategy = *decision.StrategyOverride } - providers = s.LoadBalancer.Select(providers, strategy, model, s.Profile, modelOverrides, weights) + // Use a scenario-specific counter key so scenario route round-robin + // does not advance the default profile's counter. + rrKey := s.Profile + if usingScenarioRoute && decision.Scenario != "" { + rrKey = s.Profile + ":scenario:" + decision.Scenario + } + providers = s.LoadBalancer.Select(providers, strategy, model, rrKey, modelOverrides, weights) } // Track provider failure details for error reporting @@ -961,8 +967,14 @@ func (s *ProxyServer) tryProviders(w http.ResponseWriter, r *http.Request, provi s.Logger.Printf("[%s] %s", p.Name, msg) s.logStructured(p.Name, r.Method, r.URL.Path, resp.StatusCode, LogLevelInfo, msg, sessionID, clientType) - // Update session cache with token usage from response - s.updateSessionCache(sessionID, resp) + // Update session cache with token usage from response. + // For SSE (streaming), wrap the body with an extractor that parses + // usage events in-flight so longContext routing stays accurate. + if sessionID != "" && strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") { + resp.Body = &sseUsageExtractor{r: resp.Body, sessionID: sessionID} + } else { + s.updateSessionCache(sessionID, resp) + } // Record usage and metrics s.recordUsageAndMetrics(p.Name, sessionID, clientType, bodyBytes, resp, requestID, requestStart, requestFormat, failures) @@ -1195,9 +1207,9 @@ func (s *ProxyServer) copyResponseFromResponsesAPI(w http.ResponseWriter, resp * return } - // Transform Responses API → Anthropic - // Check if client expects Anthropic format (anthropic-messages or legacy anthropic) - if transform.NormalizeFormat(requestFormat) == config.ProviderTypeAnthropic { + // Transform Responses API → client format + switch transform.NormalizeFormat(requestFormat) { + case config.ProviderTypeAnthropic: transformed, err := transform.ResponsesAPIToAnthropic(body) if err != nil { s.Logger.Printf("[%s] Responses API response transform error: %v", p.Name, err) @@ -1205,6 +1217,16 @@ func (s *ProxyServer) copyResponseFromResponsesAPI(w http.ResponseWriter, resp * s.Logger.Printf("[%s] transformed Responses API → Anthropic", p.Name) body = transformed } + case config.ProviderTypeOpenAI: + if requestFormat == transform.FormatOpenAIChat { + transformed, err := transform.ResponsesAPIToOpenAIChat(body) + if err != nil { + s.Logger.Printf("[%s] Responses API → Chat Completions transform error: %v", p.Name, err) + } else { + s.Logger.Printf("[%s] transformed Responses API → OpenAI Chat Completions", p.Name) + body = transformed + } + } } // Copy headers (except Content-Length which may have changed) @@ -1232,14 +1254,23 @@ func (s *ProxyServer) copyResponseFromResponsesAPI(w http.ResponseWriter, resp * flusher, ok := w.(http.Flusher) var reader io.Reader = resp.Body - // Check if client expects Anthropic format (anthropic-messages or legacy anthropic) - if transform.NormalizeFormat(requestFormat) == config.ProviderTypeAnthropic { + switch transform.NormalizeFormat(requestFormat) { + case config.ProviderTypeAnthropic: st := &transform.StreamTransformer{ ClientFormat: "anthropic", - ProviderFormat: "openai-responses", + ProviderFormat: transform.FormatOpenAIResponses, } reader = st.TransformSSEStream(resp.Body) s.Logger.Printf("[%s] transforming Responses API SSE stream → Anthropic", p.Name) + case config.ProviderTypeOpenAI: + if requestFormat == transform.FormatOpenAIChat { + st := &transform.StreamTransformer{ + ClientFormat: transform.FormatOpenAIChat, + ProviderFormat: transform.FormatOpenAIResponses, + } + reader = st.TransformSSEStream(resp.Body) + s.Logger.Printf("[%s] transforming Responses API SSE stream → OpenAI Chat Completions", p.Name) + } } buf := make([]byte, 4096) @@ -1850,3 +1881,78 @@ func GetDaemonLogger() daemonLogger { } return nil } + +// sseUsageExtractor wraps an SSE response body, parsing Anthropic SSE events +// in-flight to extract token usage. When the stream ends, it updates the +// session cache so longContext routing has accurate usage for streaming turns. +type sseUsageExtractor struct { + r io.ReadCloser + sessionID string + partial []byte // incomplete line buffer + inputTok int + outputTok int +} + +func (e *sseUsageExtractor) Read(p []byte) (n int, err error) { + n, err = e.r.Read(p) + if n > 0 { + e.processChunk(p[:n]) + } + if err == io.EOF { + if e.sessionID != "" && (e.inputTok > 0 || e.outputTok > 0) { + UpdateSessionUsage(e.sessionID, &SessionUsage{ + InputTokens: e.inputTok, + OutputTokens: e.outputTok, + }) + } + } + return +} + +func (e *sseUsageExtractor) Close() error { return e.r.Close() } + +// processChunk scans raw SSE bytes for usage data events. +func (e *sseUsageExtractor) processChunk(data []byte) { + // Append to partial buffer and process line by line + buf := append(e.partial, data...) + for { + idx := bytes.IndexByte(buf, '\n') + if idx < 0 { + break + } + line := string(bytes.TrimRight(buf[:idx], "\r")) + buf = buf[idx+1:] + + if !strings.HasPrefix(line, "data: ") { + continue + } + payload := strings.TrimPrefix(line, "data: ") + if payload == "[DONE]" { + continue + } + var ev map[string]interface{} + if json.Unmarshal([]byte(payload), &ev) != nil { + continue + } + evType, _ := ev["type"].(string) + switch evType { + case "message_start": + // {"type":"message_start","message":{"usage":{"input_tokens":N}}} + if msg, ok := ev["message"].(map[string]interface{}); ok { + if u, ok := msg["usage"].(map[string]interface{}); ok { + if v, ok := u["input_tokens"].(float64); ok { + e.inputTok += int(v) + } + } + } + case "message_delta": + // {"type":"message_delta","usage":{"output_tokens":N}} + if u, ok := ev["usage"].(map[string]interface{}); ok { + if v, ok := u["output_tokens"].(float64); ok { + e.outputTok += int(v) + } + } + } + } + e.partial = buf +} From 2b9ee01226988bdf535de5ef9137aedd82ecede5 Mon Sep 17 00:00:00 2001 From: John Zhang Date: Wed, 11 Mar 2026 18:32:36 +0800 Subject: [PATCH 4/6] fix(proxy): transform Responses API payload for openai-chat clients after retry MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a provider returns "input is required" and the proxy retries via /responses, copyResponseFromResponsesAPI only converted to Anthropic format. An openai-chat client would receive raw Responses API JSON/SSE. - Add ResponsesAPIToOpenAIChat (non-streaming) transform - Add transformResponsesAPIToOpenAIChat (streaming) in StreamTransformer - Fix TransformSSEStream short-circuit: openai-responses → openai-chat requires conversion even though both normalize to "openai" - Update copyResponseFromResponsesAPI to branch on client format Co-Authored-By: Claude Sonnet 4.6 --- internal/proxy/transform/responses.go | 101 +++++++++++++++++++ internal/proxy/transform/stream.go | 134 +++++++++++++++++++++++++- 2 files changed, 233 insertions(+), 2 deletions(-) diff --git a/internal/proxy/transform/responses.go b/internal/proxy/transform/responses.go index 2da5dc4..f738039 100644 --- a/internal/proxy/transform/responses.go +++ b/internal/proxy/transform/responses.go @@ -197,3 +197,104 @@ func ResponsesAPIToAnthropic(body []byte) ([]byte, error) { return json.Marshal(anthropic) } + +// ResponsesAPIToOpenAIChat transforms an OpenAI Responses API response body +// to OpenAI Chat Completions API format. +func ResponsesAPIToOpenAIChat(body []byte) ([]byte, error) { + var data map[string]interface{} + if err := json.Unmarshal(body, &data); err != nil { + return body, err + } + + // Extract text content and tool calls from the output array + var content string + var toolCalls []interface{} + + if output, ok := data["output"].([]interface{}); ok { + for _, item := range output { + itemMap, ok := item.(map[string]interface{}) + if !ok { + continue + } + switch itemMap["type"] { + case "message": + if parts, ok := itemMap["content"].([]interface{}); ok { + for _, part := range parts { + partMap, ok := part.(map[string]interface{}) + if !ok { + continue + } + if partMap["type"] == "output_text" { + if t, ok := partMap["text"].(string); ok { + content += t + } + } + } + } + case "function_call": + toolCalls = append(toolCalls, map[string]interface{}{ + "id": itemMap["call_id"], + "type": "function", + "function": map[string]interface{}{ + "name": itemMap["name"], + "arguments": itemMap["arguments"], + }, + }) + } + } + } + + // Build finish_reason + finishReason := "stop" + if status, ok := data["status"].(string); ok && status == "incomplete" { + finishReason = "length" + } + if len(toolCalls) > 0 { + finishReason = "tool_calls" + } + + // Build choice message + msg := map[string]interface{}{ + "role": "assistant", + "content": content, + } + if len(toolCalls) > 0 { + msg["tool_calls"] = toolCalls + } + + // Map usage + usage := map[string]interface{}{ + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + } + if u, ok := data["usage"].(map[string]interface{}); ok { + if v, ok := u["input_tokens"].(float64); ok { + usage["prompt_tokens"] = int(v) + } + if v, ok := u["output_tokens"].(float64); ok { + usage["completion_tokens"] = int(v) + } + if pt, ok := usage["prompt_tokens"].(int); ok { + if ct, ok := usage["completion_tokens"].(int); ok { + usage["total_tokens"] = pt + ct + } + } + } + + result := map[string]interface{}{ + "id": data["id"], + "object": "chat.completion", + "model": data["model"], + "choices": []interface{}{ + map[string]interface{}{ + "index": 0, + "message": msg, + "finish_reason": finishReason, + }, + }, + "usage": usage, + } + + return json.Marshal(result) +} diff --git a/internal/proxy/transform/stream.go b/internal/proxy/transform/stream.go index 81e569d..c24e778 100644 --- a/internal/proxy/transform/stream.go +++ b/internal/proxy/transform/stream.go @@ -44,7 +44,11 @@ func (st *StreamTransformer) TransformSSEStream(r io.Reader) io.Reader { normalizedClient := NormalizeFormat(st.ClientFormat) normalizedProvider := NormalizeFormat(st.ProviderFormat) - if normalizedClient == normalizedProvider { + // Short-circuit only when formats are truly identical (including fine-grained). + // openai-responses → openai-chat still needs transformation even though both + // normalize to "openai". + crossOpenAI := st.ProviderFormat == FormatOpenAIResponses && st.ClientFormat == FormatOpenAIChat + if normalizedClient == normalizedProvider && !crossOpenAI { return r } @@ -52,8 +56,11 @@ func (st *StreamTransformer) TransformSSEStream(r io.Reader) io.Reader { go func() { defer pw.Close() + // Responses API → OpenAI Chat Completions SSE + if st.ProviderFormat == FormatOpenAIResponses && st.ClientFormat == FormatOpenAIChat { + st.transformResponsesAPIToOpenAIChat(r, pw) // Check specific format first before normalized comparison - if st.ProviderFormat == FormatOpenAIResponses && normalizedClient == "anthropic" { + } else if st.ProviderFormat == FormatOpenAIResponses && normalizedClient == "anthropic" { st.transformResponsesAPIToAnthropic(r, pw) } else if normalizedProvider == "anthropic" && normalizedClient == "openai" { // Provider is Anthropic, client expects OpenAI @@ -1084,3 +1091,126 @@ func (st *StreamTransformer) transformResponsesAPIToAnthropic(r io.Reader, w io. _ = inputTokens } + +// transformResponsesAPIToOpenAIChat converts OpenAI Responses API SSE events +// to OpenAI Chat Completions SSE format (data: {...}\n\n with "delta" chunks). +func (st *StreamTransformer) transformResponsesAPIToOpenAIChat(r io.Reader, w io.Writer) { + scanner := bufio.NewScanner(r) + buf := make([]byte, 64*1024) + scanner.Buffer(buf, 1024*1024) + + var currentEvent string + var dataBuffer bytes.Buffer + // Use a deterministic ID prefix; real responses will come from the event data. + completionID := fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()) + model := st.Model + created := time.Now().Unix() + + emitChunk := func(delta map[string]interface{}, finishReason interface{}) { + choice := map[string]interface{}{ + "index": 0, + "delta": delta, + "finish_reason": finishReason, + } + chunk := map[string]interface{}{ + "id": completionID, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": []interface{}{choice}, + } + data, err := json.Marshal(chunk) + if err != nil { + return + } + fmt.Fprintf(w, "data: %s\n\n", data) + } + + for scanner.Scan() { + line := scanner.Text() + + if strings.HasPrefix(line, "event: ") { + currentEvent = strings.TrimPrefix(line, "event: ") + continue + } + + if strings.HasPrefix(line, "data: ") { + dataBuffer.WriteString(strings.TrimPrefix(line, "data: ")) + continue + } + + if line != "" || dataBuffer.Len() == 0 { + continue + } + + // Empty line = end of event + data := dataBuffer.String() + dataBuffer.Reset() + + var eventData map[string]interface{} + if err := json.Unmarshal([]byte(data), &eventData); err != nil { + currentEvent = "" + continue + } + + switch currentEvent { + case "response.created", "response.in_progress": + // Emit role delta at stream start + if id, ok := eventData["id"].(string); ok { + completionID = "chatcmpl-" + id + } + if m, ok := eventData["model"].(string); ok { + model = m + } + emitChunk(map[string]interface{}{"role": "assistant", "content": ""}, nil) + + case "response.output_text.delta": + if delta, ok := eventData["delta"].(string); ok && delta != "" { + emitChunk(map[string]interface{}{"content": delta}, nil) + } + + case "response.function_call_arguments.delta": + // Streaming tool call arguments + if delta, ok := eventData["delta"].(string); ok && delta != "" { + emitChunk(map[string]interface{}{ + "tool_calls": []interface{}{ + map[string]interface{}{ + "index": 0, + "function": map[string]interface{}{ + "arguments": delta, + }, + }, + }, + }, nil) + } + + case "response.completed": + // Determine finish reason from completed response + finishReason := "stop" + if resp, ok := eventData["response"].(map[string]interface{}); ok { + if status, ok := resp["status"].(string); ok && status == "incomplete" { + finishReason = "length" + } + // Check for tool use in output + if output, ok := resp["output"].([]interface{}); ok { + for _, item := range output { + if itemMap, ok := item.(map[string]interface{}); ok { + if itemMap["type"] == "function_call" { + finishReason = "tool_calls" + break + } + } + } + } + } + emitChunk(map[string]interface{}{}, finishReason) + fmt.Fprintf(w, "data: [DONE]\n\n") + } + + currentEvent = "" + } + + if err := scanner.Err(); err != nil { + st.writeStreamError(w, err) + } +} From 19298335dc475b02b25887164c60f54ba31da18a Mon Sep 17 00:00:00 2001 From: John Zhang Date: Wed, 11 Mar 2026 18:36:27 +0800 Subject: [PATCH 5/6] test(proxy): fix silent assertion and add openai-chat Responses API retry coverage - loadbalancer_test.go: the per-profile isolation test had an empty if body so a counter-pollution regression would never be caught; add t.Errorf - server_test.go: add TestResponsesAPIRetryOpenAIChat to cover the path where an openai-chat client retries via /responses and must receive a Chat Completions response, not a raw Responses API payload Co-Authored-By: Claude Sonnet 4.6 --- internal/proxy/loadbalancer_test.go | 4 +- internal/proxy/server_test.go | 63 +++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/internal/proxy/loadbalancer_test.go b/internal/proxy/loadbalancer_test.go index 741f739..5a05eb1 100644 --- a/internal/proxy/loadbalancer_test.go +++ b/internal/proxy/loadbalancer_test.go @@ -1087,8 +1087,8 @@ func TestLoadBalancer_RoundRobinPerProfileIsolation(t *testing.T) { // Profile A and B should have the same rotation sequence (both start from counter=0) for i := 0; i < 3; i++ { if profileAResults[i] != profileBResults[i] { - // This is the key assertion: both profiles should independently cycle - // through the same sequence since they start from their own counter=0 + t.Errorf("profile isolation broken at index %d: profile-a=%s, profile-b=%s (counters should be independent)", + i, profileAResults[i], profileBResults[i]) } } diff --git a/internal/proxy/server_test.go b/internal/proxy/server_test.go index 2c40162..f4b7730 100644 --- a/internal/proxy/server_test.go +++ b/internal/proxy/server_test.go @@ -2852,6 +2852,69 @@ func TestResponsesAPIRetry(t *testing.T) { }) } +func TestResponsesAPIRetryOpenAIChat(t *testing.T) { + // Regression test for: OpenAI Chat client receives wrong Responses API payload after retry. + // When the client is openai-chat and the provider returns "input is required", the proxy + // should retry via /responses AND transform the Responses API response back to + // Chat Completions format before returning it to the client. + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/chat/completions") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(500) + w.Write([]byte(`{"error":{"message":"input is required (request id: abc)","type":"new_api_error"}}`)) + return + } + if strings.Contains(r.URL.Path, "/responses") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + w.Write([]byte(`{"id":"resp_1","object":"response","status":"completed","model":"gpt-5","output":[{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"output_text","text":"Hello from Responses!"}]}],"usage":{"input_tokens":10,"output_tokens":5,"total_tokens":15}}`)) + return + } + w.WriteHeader(404) + })) + defer backend.Close() + + u, _ := url.Parse(backend.URL) + providers := []*Provider{{ + Name: "openai-provider", + Type: config.ProviderTypeOpenAI, + BaseURL: u, + Token: "test-token", + Model: "gpt-5", + Healthy: true, + }} + + srv := NewProxyServer(providers, discardLogger(), config.LoadBalanceFailover, nil) + // Client sends an OpenAI Chat Completions request + req := httptest.NewRequest("POST", "/v1/chat/completions", strings.NewReader( + `{"model":"gpt-5","messages":[{"role":"user","content":"hi"}],"max_tokens":1024}`)) + req.Header.Set("X-Zen-Request-Format", "openai-chat") + w := httptest.NewRecorder() + srv.ServeHTTP(w, req) + + if w.Code != 200 { + t.Fatalf("status = %d, want 200; body: %s", w.Code, w.Body.String()) + } + + // Response MUST be Chat Completions format, NOT Responses API format + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("response is not valid JSON: %v", err) + } + if resp["object"] != "chat.completion" { + t.Errorf("response object = %v, want chat.completion (got Responses API payload?)", resp["object"]) + } + choices, ok := resp["choices"].([]interface{}) + if !ok || len(choices) == 0 { + t.Fatal("response should have choices") + } + choice := choices[0].(map[string]interface{}) + msg := choice["message"].(map[string]interface{}) + if msg["content"] != "Hello from Responses!" { + t.Errorf("content = %v, want Hello from Responses!", msg["content"]) + } +} + func TestResponsesAPIRetryStreaming(t *testing.T) { // Mock server: 500 "input is required" on /chat/completions, // SSE Responses API stream on /responses From 626f7b5eaa9e057f831ba35123800130d5fcd3b4 Mon Sep 17 00:00:00 2001 From: John Zhang Date: Wed, 11 Mar 2026 21:11:47 +0800 Subject: [PATCH 6/6] test(proxy): add coverage for new transform functions and sseUsageExtractor - ResponsesAPIToOpenAIChat: text, tool_call, incomplete status cases - transformResponsesAPIToOpenAIChat: text stream, tool_call stream - sseUsageExtractor: parses message_start/message_delta, empty sessionID, Close delegation - Fix transformResponsesAPIToOpenAIChat to flush buffered response.completed event when stream ends without trailing blank line (mirrors the existing flush logic in transformResponsesAPIToAnthropic) proxy: 81.8% (was 79.0%), transform: 87.1% (was 79.5%) Co-Authored-By: Claude Sonnet 4.6 --- internal/proxy/server_test.go | 60 ++++++++++++++++ internal/proxy/transform/responses_test.go | 82 ++++++++++++++++++++++ internal/proxy/transform/stream.go | 25 +++++++ internal/proxy/transform/stream_test.go | 77 ++++++++++++++++++++ 4 files changed, 244 insertions(+) diff --git a/internal/proxy/server_test.go b/internal/proxy/server_test.go index f4b7730..6029f9e 100644 --- a/internal/proxy/server_test.go +++ b/internal/proxy/server_test.go @@ -3332,3 +3332,63 @@ func TestTransformError_ProperJSONResponse(t *testing.T) { t.Errorf("expected message to contain error text, got: %s", message) } } + +func TestSSEUsageExtractor(t *testing.T) { + t.Run("extracts_usage_from_anthropic_sse", func(t *testing.T) { + sse := strings.Join([]string{ + `data: {"type":"message_start","message":{"id":"msg_1","usage":{"input_tokens":25,"output_tokens":0}}}`, + ``, + `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`, + ``, + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":42}}`, + ``, + `data: {"type":"message_stop"}`, + ``, + }, "\n") + + var got *SessionUsage + // Temporarily wire a capture via a fake session + sessionID := "test-sse-usage-session" + ClearSessionUsage(sessionID) + + extractor := &sseUsageExtractor{ + r: io.NopCloser(strings.NewReader(sse)), + sessionID: sessionID, + } + _, err := io.ReadAll(extractor) + if err != nil { + t.Fatalf("unexpected read error: %v", err) + } + + got = GetSessionUsage(sessionID) + if got == nil { + t.Fatal("expected session usage to be updated, got nil") + } + if got.InputTokens != 25 { + t.Errorf("InputTokens = %d, want 25", got.InputTokens) + } + if got.OutputTokens != 42 { + t.Errorf("OutputTokens = %d, want 42", got.OutputTokens) + } + }) + + t.Run("no_update_on_empty_session", func(t *testing.T) { + sse := "data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":10}}}\n\n" + extractor := &sseUsageExtractor{ + r: io.NopCloser(strings.NewReader(sse)), + sessionID: "", // empty → should not call UpdateSessionUsage + } + // Should complete without panic + io.ReadAll(extractor) + }) + + t.Run("close_delegates_to_inner", func(t *testing.T) { + extractor := &sseUsageExtractor{ + r: io.NopCloser(strings.NewReader("")), + sessionID: "", + } + if err := extractor.Close(); err != nil { + t.Errorf("Close() error: %v", err) + } + }) +} diff --git a/internal/proxy/transform/responses_test.go b/internal/proxy/transform/responses_test.go index f8866da..cc5f8f2 100644 --- a/internal/proxy/transform/responses_test.go +++ b/internal/proxy/transform/responses_test.go @@ -376,3 +376,85 @@ func TestResponsesAPIToAnthropic_ToolCall(t *testing.T) { }) } } + +func TestResponsesAPIToOpenAIChat(t *testing.T) { + tests := []struct { + name string + input string + checkFn func(t *testing.T, result map[string]interface{}) + }{ + { + name: "text_response", + input: `{"id":"resp_1","object":"response","status":"completed","model":"gpt-5","output":[{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"output_text","text":"Hello from Responses!"}]}],"usage":{"input_tokens":10,"output_tokens":5,"total_tokens":15}}`, + checkFn: func(t *testing.T, result map[string]interface{}) { + if result["object"] != "chat.completion" { + t.Errorf("object = %v, want chat.completion", result["object"]) + } + choices := result["choices"].([]interface{}) + if len(choices) != 1 { + t.Fatalf("choices len = %d, want 1", len(choices)) + } + choice := choices[0].(map[string]interface{}) + if choice["finish_reason"] != "stop" { + t.Errorf("finish_reason = %v, want stop", choice["finish_reason"]) + } + msg := choice["message"].(map[string]interface{}) + if msg["content"] != "Hello from Responses!" { + t.Errorf("content = %v, want Hello from Responses!", msg["content"]) + } + usage := result["usage"].(map[string]interface{}) + if usage["prompt_tokens"].(float64) != 10 { + t.Errorf("prompt_tokens = %v, want 10", usage["prompt_tokens"]) + } + }, + }, + { + name: "tool_call_response", + input: `{"id":"resp_2","object":"response","status":"completed","model":"gpt-5","output":[{"id":"fc_1","type":"function_call","call_id":"call_1","name":"get_weather","arguments":"{\"location\":\"Paris\"}","status":"completed"}],"usage":{"input_tokens":20,"output_tokens":10}}`, + checkFn: func(t *testing.T, result map[string]interface{}) { + choices := result["choices"].([]interface{}) + choice := choices[0].(map[string]interface{}) + if choice["finish_reason"] != "tool_calls" { + t.Errorf("finish_reason = %v, want tool_calls", choice["finish_reason"]) + } + msg := choice["message"].(map[string]interface{}) + toolCalls := msg["tool_calls"].([]interface{}) + if len(toolCalls) != 1 { + t.Fatalf("tool_calls len = %d, want 1", len(toolCalls)) + } + tc := toolCalls[0].(map[string]interface{}) + if tc["type"] != "function" { + t.Errorf("tool_call type = %v, want function", tc["type"]) + } + fn := tc["function"].(map[string]interface{}) + if fn["name"] != "get_weather" { + t.Errorf("function name = %v, want get_weather", fn["name"]) + } + }, + }, + { + name: "incomplete_status", + input: `{"id":"resp_3","object":"response","status":"incomplete","model":"gpt-5","output":[{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"output_text","text":"truncated"}]}],"usage":{"input_tokens":5,"output_tokens":100}}`, + checkFn: func(t *testing.T, result map[string]interface{}) { + choices := result["choices"].([]interface{}) + choice := choices[0].(map[string]interface{}) + if choice["finish_reason"] != "length" { + t.Errorf("finish_reason = %v, want length", choice["finish_reason"]) + } + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + output, err := ResponsesAPIToOpenAIChat([]byte(tt.input)) + if err != nil { + t.Fatalf("ResponsesAPIToOpenAIChat() error: %v", err) + } + var result map[string]interface{} + if err := json.Unmarshal(output, &result); err != nil { + t.Fatalf("failed to parse output: %v", err) + } + tt.checkFn(t, result) + }) + } +} diff --git a/internal/proxy/transform/stream.go b/internal/proxy/transform/stream.go index c24e778..a6965e1 100644 --- a/internal/proxy/transform/stream.go +++ b/internal/proxy/transform/stream.go @@ -1210,6 +1210,31 @@ func (st *StreamTransformer) transformResponsesAPIToOpenAIChat(r io.Reader, w io currentEvent = "" } + // Process remaining buffered data (stream may end without trailing blank line) + if dataBuffer.Len() > 0 && currentEvent == "response.completed" { + var eventData map[string]interface{} + if json.Unmarshal(dataBuffer.Bytes(), &eventData) == nil { + finishReason := "stop" + if resp, ok := eventData["response"].(map[string]interface{}); ok { + if status, ok := resp["status"].(string); ok && status == "incomplete" { + finishReason = "length" + } + if output, ok := resp["output"].([]interface{}); ok { + for _, item := range output { + if itemMap, ok := item.(map[string]interface{}); ok { + if itemMap["type"] == "function_call" { + finishReason = "tool_calls" + break + } + } + } + } + } + emitChunk(map[string]interface{}{}, finishReason) + fmt.Fprintf(w, "data: [DONE]\n\n") + } + } + if err := scanner.Err(); err != nil { st.writeStreamError(w, err) } diff --git a/internal/proxy/transform/stream_test.go b/internal/proxy/transform/stream_test.go index 6c97a46..17cf0fc 100644 --- a/internal/proxy/transform/stream_test.go +++ b/internal/proxy/transform/stream_test.go @@ -1178,3 +1178,80 @@ func parseSSEEvents(output string) []sseEvent { return events } + +func TestTransformResponsesAPIToOpenAIChat_Text(t *testing.T) { + input := strings.Join([]string{ + `event: response.created`, + `data: {"type":"response.created","response":{"id":"resp_chat1","status":"in_progress","model":"gpt-5","output":[]}}`, + ``, + `event: response.output_text.delta`, + `data: {"type":"response.output_text.delta","item_id":"msg_1","output_index":0,"content_index":0,"delta":"Hello"}`, + ``, + `event: response.output_text.delta`, + `data: {"type":"response.output_text.delta","item_id":"msg_1","output_index":0,"content_index":0,"delta":" world"}`, + ``, + `event: response.completed`, + `data: {"type":"response.completed","response":{"id":"resp_chat1","status":"completed","model":"gpt-5","output":[],"usage":{"input_tokens":5,"output_tokens":3}}}`, + ``, + }, "\n") + + st := &StreamTransformer{ + ClientFormat: FormatOpenAIChat, + ProviderFormat: FormatOpenAIResponses, + } + reader := st.TransformSSEStream(strings.NewReader(input)) + output, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + result := string(output) + + // Should produce Chat Completions chunks + if !strings.Contains(result, `"chat.completion.chunk"`) { + t.Error("should emit chat.completion.chunk objects") + } + if !strings.Contains(result, `"Hello"`) { + t.Error("should include first delta text") + } + if !strings.Contains(result, `" world"`) { + t.Error("should include second delta text") + } + if !strings.Contains(result, `"stop"`) { + t.Error("should include finish_reason stop in final chunk") + } + if !strings.Contains(result, "data: [DONE]") { + t.Error("should emit [DONE] sentinel") + } +} + +func TestTransformResponsesAPIToOpenAIChat_ToolCall(t *testing.T) { + input := strings.Join([]string{ + `event: response.created`, + `data: {"type":"response.created","response":{"id":"resp_tc1","status":"in_progress","model":"gpt-5","output":[]}}`, + ``, + `event: response.function_call_arguments.delta`, + `data: {"type":"response.function_call_arguments.delta","item_id":"fc_1","output_index":0,"delta":"{\"loc"}`, + ``, + `event: response.completed`, + `data: {"type":"response.completed","response":{"id":"resp_tc1","status":"completed","model":"gpt-5","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"weather"}]}}`, + ``, + }, "\n") + + st := &StreamTransformer{ + ClientFormat: FormatOpenAIChat, + ProviderFormat: FormatOpenAIResponses, + } + reader := st.TransformSSEStream(strings.NewReader(input)) + output, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + result := string(output) + + if !strings.Contains(result, `"tool_calls"`) { + t.Error("should emit tool_calls delta") + } + if !strings.Contains(result, `"tool_calls"`) || !strings.Contains(result, `"finish_reason"`) { + t.Error("should emit finish chunk with tool_calls finish_reason") + } +}