diff --git a/cmd/daemon.go b/cmd/daemon.go index b033632..847c067 100644 --- a/cmd/daemon.go +++ b/cmd/daemon.go @@ -259,6 +259,7 @@ func startDaemonBackground() error { fmt.Println("zend is now running (started by another process).") return nil } + return fmt.Errorf("daemon failed to start (concurrent startup attempt failed, please retry)") } else if err != nil { return fmt.Errorf("cannot acquire daemon lock: %w", err) } diff --git a/internal/daemon/api.go b/internal/daemon/api.go index bfe68c9..4f0ff8e 100644 --- a/internal/daemon/api.go +++ b/internal/daemon/api.go @@ -180,12 +180,7 @@ func (d *Daemon) handleDaemonShutdown(w http.ResponseWriter, r *http.Request) { // Trigger shutdown in background so the response is sent first go func() { time.Sleep(100 * time.Millisecond) - select { - case <-d.shutdownCh: - // Already closed - default: - close(d.shutdownCh) - } + d.shutdownOnce.Do(func() { close(d.shutdownCh) }) }() } @@ -232,9 +227,9 @@ func (d *Daemon) handleDaemonSessions(w http.ResponseWriter, r *http.Request) { func (d *Daemon) handleGetSessions(w http.ResponseWriter, r *http.Request) { d.mu.RLock() - sessions := make([]*SessionInfo, 0, len(d.sessions)) + sessions := make([]SessionInfo, 0, len(d.sessions)) for _, s := range d.sessions { - sessions = append(sessions, s) + sessions = append(sessions, *s) // copy struct to avoid data race after unlock } d.mu.RUnlock() diff --git a/internal/daemon/server.go b/internal/daemon/server.go index 72b87cc..fd23b74 100644 --- a/internal/daemon/server.go +++ b/internal/daemon/server.go @@ -83,10 +83,11 @@ type Daemon struct { currentGates *config.FeatureGates // Shutdown channel - closed when shutdown is requested via API - shutdownCh chan struct{} - runCtx context.Context - runCancel context.CancelFunc - bgWG sync.WaitGroup + shutdownCh chan struct{} + shutdownOnce sync.Once + runCtx context.Context + runCancel context.CancelFunc + bgWG sync.WaitGroup // Goroutine leak detection baselineGoroutines int diff --git 
a/internal/proxy/loadbalancer_test.go b/internal/proxy/loadbalancer_test.go index 741f739..5a05eb1 100644 --- a/internal/proxy/loadbalancer_test.go +++ b/internal/proxy/loadbalancer_test.go @@ -1087,8 +1087,8 @@ func TestLoadBalancer_RoundRobinPerProfileIsolation(t *testing.T) { // Profile A and B should have the same rotation sequence (both start from counter=0) for i := 0; i < 3; i++ { if profileAResults[i] != profileBResults[i] { - // This is the key assertion: both profiles should independently cycle - // through the same sequence since they start from their own counter=0 + t.Errorf("profile isolation broken at index %d: profile-a=%s, profile-b=%s (counters should be independent)", + i, profileAResults[i], profileBResults[i]) } } diff --git a/internal/proxy/server.go b/internal/proxy/server.go index 9cfcd88..6a0e84b 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -632,7 +632,13 @@ func (s *ProxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { strategy = *decision.StrategyOverride } - providers = s.LoadBalancer.Select(providers, strategy, model, s.Profile, modelOverrides, weights) + // Use a scenario-specific counter key so scenario route round-robin + // does not advance the default profile's counter. + rrKey := s.Profile + if usingScenarioRoute && decision.Scenario != "" { + rrKey = s.Profile + ":scenario:" + decision.Scenario + } + providers = s.LoadBalancer.Select(providers, strategy, model, rrKey, modelOverrides, weights) } // Track provider failure details for error reporting @@ -961,8 +967,14 @@ func (s *ProxyServer) tryProviders(w http.ResponseWriter, r *http.Request, provi s.Logger.Printf("[%s] %s", p.Name, msg) s.logStructured(p.Name, r.Method, r.URL.Path, resp.StatusCode, LogLevelInfo, msg, sessionID, clientType) - // Update session cache with token usage from response - s.updateSessionCache(sessionID, resp) + // Update session cache with token usage from response. 
+ // For SSE (streaming), wrap the body with an extractor that parses + // usage events in-flight so longContext routing stays accurate. + if sessionID != "" && strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") { + resp.Body = &sseUsageExtractor{r: resp.Body, sessionID: sessionID} + } else { + s.updateSessionCache(sessionID, resp) + } // Record usage and metrics s.recordUsageAndMetrics(p.Name, sessionID, clientType, bodyBytes, resp, requestID, requestStart, requestFormat, failures) @@ -1195,9 +1207,9 @@ func (s *ProxyServer) copyResponseFromResponsesAPI(w http.ResponseWriter, resp * return } - // Transform Responses API → Anthropic - // Check if client expects Anthropic format (anthropic-messages or legacy anthropic) - if transform.NormalizeFormat(requestFormat) == config.ProviderTypeAnthropic { + // Transform Responses API → client format + switch transform.NormalizeFormat(requestFormat) { + case config.ProviderTypeAnthropic: transformed, err := transform.ResponsesAPIToAnthropic(body) if err != nil { s.Logger.Printf("[%s] Responses API response transform error: %v", p.Name, err) @@ -1205,6 +1217,16 @@ func (s *ProxyServer) copyResponseFromResponsesAPI(w http.ResponseWriter, resp * s.Logger.Printf("[%s] transformed Responses API → Anthropic", p.Name) body = transformed } + case config.ProviderTypeOpenAI: + if requestFormat == transform.FormatOpenAIChat { + transformed, err := transform.ResponsesAPIToOpenAIChat(body) + if err != nil { + s.Logger.Printf("[%s] Responses API → Chat Completions transform error: %v", p.Name, err) + } else { + s.Logger.Printf("[%s] transformed Responses API → OpenAI Chat Completions", p.Name) + body = transformed + } + } } // Copy headers (except Content-Length which may have changed) @@ -1232,14 +1254,23 @@ func (s *ProxyServer) copyResponseFromResponsesAPI(w http.ResponseWriter, resp * flusher, ok := w.(http.Flusher) var reader io.Reader = resp.Body - // Check if client expects Anthropic format (anthropic-messages 
or legacy anthropic)
-	if transform.NormalizeFormat(requestFormat) == config.ProviderTypeAnthropic {
+	switch transform.NormalizeFormat(requestFormat) {
+	case config.ProviderTypeAnthropic:
 		st := &transform.StreamTransformer{
 			ClientFormat:   "anthropic",
-			ProviderFormat: "openai-responses",
+			ProviderFormat: transform.FormatOpenAIResponses,
 		}
 		reader = st.TransformSSEStream(resp.Body)
 		s.Logger.Printf("[%s] transforming Responses API SSE stream → Anthropic", p.Name)
+	case config.ProviderTypeOpenAI:
+		if requestFormat == transform.FormatOpenAIChat {
+			st := &transform.StreamTransformer{
+				ClientFormat:   transform.FormatOpenAIChat,
+				ProviderFormat: transform.FormatOpenAIResponses,
+			}
+			reader = st.TransformSSEStream(resp.Body)
+			s.Logger.Printf("[%s] transforming Responses API SSE stream → OpenAI Chat Completions", p.Name)
+		}
 	}
 
 	buf := make([]byte, 4096)
@@ -1850,3 +1881,119 @@ func GetDaemonLogger() daemonLogger {
 	}
 	return nil
 }
+
+// sseUsageCounts holds the usage fields read from Anthropic SSE events. Pointers
+// distinguish an absent field from an explicit zero.
+type sseUsageCounts struct {
+	InputTokens  *float64 `json:"input_tokens"`
+	OutputTokens *float64 `json:"output_tokens"`
+}
+
+// sseUsageExtractor wraps an SSE response body and parses Anthropic SSE events
+// as they pass through Read, collecting token usage without altering the bytes
+// handed to the caller. The usage is pushed to the session cache at most once:
+// when the wrapped body reports io.EOF, or on Close if the body is closed
+// before EOF (for example after a client disconnect), so longContext routing
+// has usage for streaming turns.
+type sseUsageExtractor struct {
+	r         io.ReadCloser
+	sessionID string
+	partial   []byte // incomplete trailing line carried over between reads
+	inputTok  int
+	outputTok int
+	reported  bool // usage has already been pushed to the session cache
+}
+
+// Read forwards to the wrapped body, feeds the bytes it returned through the
+// usage parser, and reports the usage when the wrapped body returns io.EOF.
+func (e *sseUsageExtractor) Read(p []byte) (n int, err error) {
+	n, err = e.r.Read(p)
+	if n > 0 {
+		e.processChunk(p[:n])
+	}
+	if err == io.EOF {
+		e.report()
+	}
+	return n, err
+}
+
+// Close reports any usage not yet reported and closes the wrapped body.
+func (e *sseUsageExtractor) Close() error {
+	e.report()
+	return e.r.Close()
+}
+
+// report parses a final line that arrived without a trailing newline and then
+// calls UpdateSessionUsage with the collected counts. Only the first call has
+// any effect, so Read-at-EOF followed by Close reports once.
+func (e *sseUsageExtractor) report() {
+	if e.reported {
+		return
+	}
+	e.reported = true
+	if len(e.partial) > 0 {
+		e.processLine(e.partial)
+		e.partial = nil
+	}
+	if e.sessionID != "" && (e.inputTok > 0 || e.outputTok > 0) {
+		UpdateSessionUsage(e.sessionID, &SessionUsage{
+			InputTokens:  e.inputTok,
+			OutputTokens: e.outputTok,
+		})
+	}
+}
+
+// processChunk splits raw SSE bytes into lines and hands every complete line to
+// processLine. Bytes after the last newline are kept in e.partial and joined
+// with the next chunk.
+func (e *sseUsageExtractor) processChunk(data []byte) {
+	buf := append(e.partial, data...)
+	for {
+		idx := bytes.IndexByte(buf, '\n')
+		if idx < 0 {
+			break
+		}
+		e.processLine(buf[:idx])
+		buf = buf[idx+1:]
+	}
+	e.partial = buf
+}
+
+// processLine records usage from one SSE line. Only "data:" lines carrying a
+// JSON message_start or message_delta event are considered; everything else,
+// including the "[DONE]" sentinel and malformed JSON, is ignored.
+func (e *sseUsageExtractor) processLine(line []byte) {
+	line = bytes.TrimRight(line, "\r")
+	if !bytes.HasPrefix(line, []byte("data:")) {
+		return
+	}
+	// The SSE format allows one optional space after the colon.
+	payload := bytes.TrimPrefix(line[len("data:"):], []byte(" "))
+	if string(payload) == "[DONE]" {
+		return
+	}
+	var ev struct {
+		Type    string `json:"type"`
+		Message struct {
+			Usage sseUsageCounts `json:"usage"`
+		} `json:"message"`
+		Usage sseUsageCounts `json:"usage"`
+	}
+	if json.Unmarshal(payload, &ev) != nil {
+		return
+	}
+	switch ev.Type {
+	case "message_start":
+		// {"type":"message_start","message":{"usage":{"input_tokens":N}}}
+		if v := ev.Message.Usage.InputTokens; v != nil {
+			e.inputTok = int(*v)
+		}
+	case "message_delta":
+		// {"type":"message_delta","usage":{"output_tokens":N}}
+		// The count is cumulative for the message, so it replaces the
+		// previous value rather than being added to it.
+		if v := ev.Usage.OutputTokens; v != nil {
+			e.outputTok = int(*v)
+		}
+	}
+}
diff --git a/internal/proxy/server_test.go b/internal/proxy/server_test.go
index 2c40162..6029f9e 100644
--- a/internal/proxy/server_test.go
+++ b/internal/proxy/server_test.go
@@ -2852,6 +2852,69 @@ func TestResponsesAPIRetry(t *testing.T) {
 	})
 }
 
+func TestResponsesAPIRetryOpenAIChat(t *testing.T) {
+	// Regression test for: OpenAI Chat client receives wrong Responses API payload after retry.
+	// When the client is openai-chat and the provider returns "input is required", the proxy
+	// should retry via /responses AND transform the Responses API response back to
+	// Chat Completions format before returning it to the client.
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/chat/completions") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(500) + w.Write([]byte(`{"error":{"message":"input is required (request id: abc)","type":"new_api_error"}}`)) + return + } + if strings.Contains(r.URL.Path, "/responses") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + w.Write([]byte(`{"id":"resp_1","object":"response","status":"completed","model":"gpt-5","output":[{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"output_text","text":"Hello from Responses!"}]}],"usage":{"input_tokens":10,"output_tokens":5,"total_tokens":15}}`)) + return + } + w.WriteHeader(404) + })) + defer backend.Close() + + u, _ := url.Parse(backend.URL) + providers := []*Provider{{ + Name: "openai-provider", + Type: config.ProviderTypeOpenAI, + BaseURL: u, + Token: "test-token", + Model: "gpt-5", + Healthy: true, + }} + + srv := NewProxyServer(providers, discardLogger(), config.LoadBalanceFailover, nil) + // Client sends an OpenAI Chat Completions request + req := httptest.NewRequest("POST", "/v1/chat/completions", strings.NewReader( + `{"model":"gpt-5","messages":[{"role":"user","content":"hi"}],"max_tokens":1024}`)) + req.Header.Set("X-Zen-Request-Format", "openai-chat") + w := httptest.NewRecorder() + srv.ServeHTTP(w, req) + + if w.Code != 200 { + t.Fatalf("status = %d, want 200; body: %s", w.Code, w.Body.String()) + } + + // Response MUST be Chat Completions format, NOT Responses API format + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("response is not valid JSON: %v", err) + } + if resp["object"] != "chat.completion" { + t.Errorf("response object = %v, want chat.completion (got Responses API payload?)", resp["object"]) + } + choices, ok := resp["choices"].([]interface{}) + if !ok || len(choices) == 0 { + 
t.Fatal("response should have choices") + } + choice := choices[0].(map[string]interface{}) + msg := choice["message"].(map[string]interface{}) + if msg["content"] != "Hello from Responses!" { + t.Errorf("content = %v, want Hello from Responses!", msg["content"]) + } +} + func TestResponsesAPIRetryStreaming(t *testing.T) { // Mock server: 500 "input is required" on /chat/completions, // SSE Responses API stream on /responses @@ -3269,3 +3332,63 @@ func TestTransformError_ProperJSONResponse(t *testing.T) { t.Errorf("expected message to contain error text, got: %s", message) } } + +func TestSSEUsageExtractor(t *testing.T) { + t.Run("extracts_usage_from_anthropic_sse", func(t *testing.T) { + sse := strings.Join([]string{ + `data: {"type":"message_start","message":{"id":"msg_1","usage":{"input_tokens":25,"output_tokens":0}}}`, + ``, + `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`, + ``, + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":42}}`, + ``, + `data: {"type":"message_stop"}`, + ``, + }, "\n") + + var got *SessionUsage + // Temporarily wire a capture via a fake session + sessionID := "test-sse-usage-session" + ClearSessionUsage(sessionID) + + extractor := &sseUsageExtractor{ + r: io.NopCloser(strings.NewReader(sse)), + sessionID: sessionID, + } + _, err := io.ReadAll(extractor) + if err != nil { + t.Fatalf("unexpected read error: %v", err) + } + + got = GetSessionUsage(sessionID) + if got == nil { + t.Fatal("expected session usage to be updated, got nil") + } + if got.InputTokens != 25 { + t.Errorf("InputTokens = %d, want 25", got.InputTokens) + } + if got.OutputTokens != 42 { + t.Errorf("OutputTokens = %d, want 42", got.OutputTokens) + } + }) + + t.Run("no_update_on_empty_session", func(t *testing.T) { + sse := "data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":10}}}\n\n" + extractor := &sseUsageExtractor{ + r: io.NopCloser(strings.NewReader(sse)), 
+			sessionID: "", // empty → should not call UpdateSessionUsage
+		}
+		// Should complete without panic
+		io.ReadAll(extractor)
+	})
+
+	t.Run("close_delegates_to_inner", func(t *testing.T) {
+		extractor := &sseUsageExtractor{
+			r:         io.NopCloser(strings.NewReader("")),
+			sessionID: "",
+		}
+		if err := extractor.Close(); err != nil {
+			t.Errorf("Close() error: %v", err)
+		}
+	})
+}
diff --git a/internal/proxy/transform/responses.go b/internal/proxy/transform/responses.go
index 2da5dc4..f738039 100644
--- a/internal/proxy/transform/responses.go
+++ b/internal/proxy/transform/responses.go
@@ -197,3 +197,127 @@ func ResponsesAPIToAnthropic(body []byte) ([]byte, error) {
 
 	return json.Marshal(anthropic)
 }
+
+// ResponsesAPIToOpenAIChat transforms a non-streaming OpenAI Responses API
+// response body into the OpenAI Chat Completions response format.
+//
+// Mapping:
+//   - "message" output items: "output_text" parts are concatenated into
+//     choices[0].message.content and "refusal" parts into message.refusal.
+//   - "function_call" output items become message.tool_calls; when there is no
+//     text alongside them, content is null as in native Chat Completions output.
+//   - finish_reason is "tool_calls" if any tool call is present; otherwise an
+//     "incomplete" status maps to "content_filter" or "length"; otherwise "stop".
+//   - usage.input_tokens / output_tokens become prompt_tokens /
+//     completion_tokens, and total_tokens is their sum.
+//   - created_at, when present, is copied to created.
+//
+// If body is not valid JSON, the unmodified body is returned with the error.
+func ResponsesAPIToOpenAIChat(body []byte) ([]byte, error) {
+	var data map[string]interface{}
+	if err := json.Unmarshal(body, &data); err != nil {
+		return body, err
+	}
+
+	// Collect text, refusals and tool calls from the output array. A missing
+	// or malformed array simply yields an empty message.
+	var content, refusal string
+	var toolCalls []interface{}
+
+	output, _ := data["output"].([]interface{})
+	for _, item := range output {
+		itemMap, ok := item.(map[string]interface{})
+		if !ok {
+			continue
+		}
+		switch itemMap["type"] {
+		case "message":
+			parts, _ := itemMap["content"].([]interface{})
+			for _, part := range parts {
+				partMap, ok := part.(map[string]interface{})
+				if !ok {
+					continue
+				}
+				switch partMap["type"] {
+				case "output_text":
+					if t, ok := partMap["text"].(string); ok {
+						content += t
+					}
+				case "refusal":
+					if t, ok := partMap["refusal"].(string); ok {
+						refusal += t
+					}
+				}
+			}
+		case "function_call":
+			toolCalls = append(toolCalls, map[string]interface{}{
+				"id":   itemMap["call_id"],
+				"type": "function",
+				"function": map[string]interface{}{
+					"name":      itemMap["name"],
+					"arguments": itemMap["arguments"],
+				},
+			})
+		}
+	}
+
+	// Tool calls take precedence over truncation when choosing finish_reason.
+	finishReason := "stop"
+	if status, _ := data["status"].(string); status == "incomplete" {
+		finishReason = "length"
+		if details, ok := data["incomplete_details"].(map[string]interface{}); ok && details["reason"] == "content_filter" {
+			finishReason = "content_filter"
+		}
+	}
+	if len(toolCalls) > 0 {
+		finishReason = "tool_calls"
+	}
+
+	msg := map[string]interface{}{
+		"role":    "assistant",
+		"content": content,
+	}
+	if len(toolCalls) > 0 {
+		msg["tool_calls"] = toolCalls
+		if content == "" {
+			msg["content"] = nil
+		}
+	}
+	if refusal != "" {
+		msg["refusal"] = refusal
+	}
+
+	var promptTokens, completionTokens int
+	if u, ok := data["usage"].(map[string]interface{}); ok {
+		if v, ok := u["input_tokens"].(float64); ok {
+			promptTokens = int(v)
+		}
+		if v, ok := u["output_tokens"].(float64); ok {
+			completionTokens = int(v)
+		}
+	}
+
+	result := map[string]interface{}{
+		"id":     data["id"],
+		"object": "chat.completion",
+		"model":  data["model"],
+		"choices": []interface{}{
+			map[string]interface{}{
+				"index":         0,
+				"message":       msg,
+				"finish_reason": finishReason,
+			},
+		},
+		"usage": map[string]interface{}{
+			"prompt_tokens":     promptTokens,
+			"completion_tokens": completionTokens,
+			"total_tokens":      promptTokens + completionTokens,
+		},
+	}
+	// Chat Completions carries the creation time as "created" (Unix seconds).
+	if createdAt, ok := data["created_at"].(float64); ok {
+		result["created"] = int64(createdAt)
+	}
+
+	return json.Marshal(result)
+}
diff --git a/internal/proxy/transform/responses_test.go b/internal/proxy/transform/responses_test.go
index f8866da..cc5f8f2 100644
--- a/internal/proxy/transform/responses_test.go
+++ b/internal/proxy/transform/responses_test.go
@@ -376,3 +376,85 @@ func TestResponsesAPIToAnthropic_ToolCall(t *testing.T) {
 		})
 	}
 }
+
+func TestResponsesAPIToOpenAIChat(t *testing.T) {
+	tests := []struct {
+		name    string
+		input   string
+		checkFn func(t *testing.T, result map[string]interface{})
+	}{
+		{
+			name:  "text_response",
+			input: `{"id":"resp_1","object":"response","status":"completed","model":"gpt-5","output":[{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"output_text","text":"Hello from Responses!"}]}],"usage":{"input_tokens":10,"output_tokens":5,"total_tokens":15}}`,
+			checkFn: func(t *testing.T, result map[string]interface{}) {
+				if result["object"] != "chat.completion" {
+					t.Errorf("object = %v, want
chat.completion", result["object"]) + } + choices := result["choices"].([]interface{}) + if len(choices) != 1 { + t.Fatalf("choices len = %d, want 1", len(choices)) + } + choice := choices[0].(map[string]interface{}) + if choice["finish_reason"] != "stop" { + t.Errorf("finish_reason = %v, want stop", choice["finish_reason"]) + } + msg := choice["message"].(map[string]interface{}) + if msg["content"] != "Hello from Responses!" { + t.Errorf("content = %v, want Hello from Responses!", msg["content"]) + } + usage := result["usage"].(map[string]interface{}) + if usage["prompt_tokens"].(float64) != 10 { + t.Errorf("prompt_tokens = %v, want 10", usage["prompt_tokens"]) + } + }, + }, + { + name: "tool_call_response", + input: `{"id":"resp_2","object":"response","status":"completed","model":"gpt-5","output":[{"id":"fc_1","type":"function_call","call_id":"call_1","name":"get_weather","arguments":"{\"location\":\"Paris\"}","status":"completed"}],"usage":{"input_tokens":20,"output_tokens":10}}`, + checkFn: func(t *testing.T, result map[string]interface{}) { + choices := result["choices"].([]interface{}) + choice := choices[0].(map[string]interface{}) + if choice["finish_reason"] != "tool_calls" { + t.Errorf("finish_reason = %v, want tool_calls", choice["finish_reason"]) + } + msg := choice["message"].(map[string]interface{}) + toolCalls := msg["tool_calls"].([]interface{}) + if len(toolCalls) != 1 { + t.Fatalf("tool_calls len = %d, want 1", len(toolCalls)) + } + tc := toolCalls[0].(map[string]interface{}) + if tc["type"] != "function" { + t.Errorf("tool_call type = %v, want function", tc["type"]) + } + fn := tc["function"].(map[string]interface{}) + if fn["name"] != "get_weather" { + t.Errorf("function name = %v, want get_weather", fn["name"]) + } + }, + }, + { + name: "incomplete_status", + input: 
`{"id":"resp_3","object":"response","status":"incomplete","model":"gpt-5","output":[{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"output_text","text":"truncated"}]}],"usage":{"input_tokens":5,"output_tokens":100}}`, + checkFn: func(t *testing.T, result map[string]interface{}) { + choices := result["choices"].([]interface{}) + choice := choices[0].(map[string]interface{}) + if choice["finish_reason"] != "length" { + t.Errorf("finish_reason = %v, want length", choice["finish_reason"]) + } + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + output, err := ResponsesAPIToOpenAIChat([]byte(tt.input)) + if err != nil { + t.Fatalf("ResponsesAPIToOpenAIChat() error: %v", err) + } + var result map[string]interface{} + if err := json.Unmarshal(output, &result); err != nil { + t.Fatalf("failed to parse output: %v", err) + } + tt.checkFn(t, result) + }) + } +} diff --git a/internal/proxy/transform/stream.go b/internal/proxy/transform/stream.go index 81e569d..a6965e1 100644 --- a/internal/proxy/transform/stream.go +++ b/internal/proxy/transform/stream.go @@ -44,7 +44,11 @@ func (st *StreamTransformer) TransformSSEStream(r io.Reader) io.Reader { normalizedClient := NormalizeFormat(st.ClientFormat) normalizedProvider := NormalizeFormat(st.ProviderFormat) - if normalizedClient == normalizedProvider { + // Short-circuit only when formats are truly identical (including fine-grained). + // openai-responses → openai-chat still needs transformation even though both + // normalize to "openai". 
+	crossOpenAI := st.ProviderFormat == FormatOpenAIResponses && st.ClientFormat == FormatOpenAIChat
+	if normalizedClient == normalizedProvider && !crossOpenAI {
 		return r
 	}
 
@@ -52,8 +56,11 @@ func (st *StreamTransformer) TransformSSEStream(r io.Reader) io.Reader {
 	go func() {
 		defer pw.Close()
 
+		// Responses API → OpenAI Chat Completions SSE
+		if st.ProviderFormat == FormatOpenAIResponses && st.ClientFormat == FormatOpenAIChat {
+			st.transformResponsesAPIToOpenAIChat(r, pw)
 		// Check specific format first before normalized comparison
-		if st.ProviderFormat == FormatOpenAIResponses && normalizedClient == "anthropic" {
+		} else if st.ProviderFormat == FormatOpenAIResponses && normalizedClient == "anthropic" {
 			st.transformResponsesAPIToAnthropic(r, pw)
 		} else if normalizedProvider == "anthropic" && normalizedClient == "openai" {
 			// Provider is Anthropic, client expects OpenAI
@@ -1084,3 +1091,188 @@ func (st *StreamTransformer) transformResponsesAPIToAnthropic(r io.Reader, w io.
 
 	_ = inputTokens
 }
+
+// transformResponsesAPIToOpenAIChat converts an OpenAI Responses API SSE stream
+// read from r into OpenAI Chat Completions SSE written to w: "data: {...}\n\n"
+// lines carrying "chat.completion.chunk" objects, ended by "data: [DONE]".
+//
+// Event mapping:
+//   - response.created / response.in_progress: adopt the upstream response id
+//     and model, and emit the initial assistant role delta (once per stream).
+//   - response.output_text.delta: content delta.
+//   - response.output_item.added for a function_call item: opening tool_calls
+//     delta with the call id and function name.
+//   - response.function_call_arguments.delta: arguments delta for that call.
+//   - response.completed / response.incomplete: final chunk with finish_reason,
+//     then the [DONE] sentinel.
+//
+// Other events are dropped. When an event has no "event:" line, the "type"
+// field of its JSON payload is used instead. Scanner errors are reported via
+// st.writeStreamError.
+func (st *StreamTransformer) transformResponsesAPIToOpenAIChat(r io.Reader, w io.Writer) {
+	scanner := bufio.NewScanner(r)
+	scanner.Buffer(make([]byte, 64*1024), 1024*1024)
+
+	// Placeholder ID used until response.created / response.in_progress
+	// supplies the upstream response ID.
+	completionID := fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano())
+	model := st.Model
+	created := time.Now().Unix()
+
+	var (
+		roleSent  bool               // initial role delta already emitted
+		finished  bool               // final chunk and [DONE] already emitted
+		toolIndex = map[string]int{} // Responses item ID -> tool_calls index
+	)
+
+	emitChunk := func(delta map[string]interface{}, finishReason interface{}) {
+		chunk := map[string]interface{}{
+			"id":      completionID,
+			"object":  "chat.completion.chunk",
+			"created": created,
+			"model":   model,
+			"choices": []interface{}{
+				map[string]interface{}{
+					"index":         0,
+					"delta":         delta,
+					"finish_reason": finishReason,
+				},
+			},
+		}
+		data, err := json.Marshal(chunk)
+		if err != nil {
+			return
+		}
+		fmt.Fprintf(w, "data: %s\n\n", data)
+	}
+
+	// emitToolCall sends one tool_calls delta. The first time an item ID is
+	// seen it is assigned the next free tool_calls index, so parallel calls
+	// stay distinguishable on the client side.
+	emitToolCall := func(itemID string, fields map[string]interface{}) {
+		idx, ok := toolIndex[itemID]
+		if !ok {
+			idx = len(toolIndex)
+			toolIndex[itemID] = idx
+		}
+		fields["index"] = idx
+		emitChunk(map[string]interface{}{"tool_calls": []interface{}{fields}}, nil)
+	}
+
+	// finishReasonFor derives the Chat Completions finish_reason from the
+	// terminal event and its nested "response" object.
+	finishReasonFor := func(event string, resp map[string]interface{}) string {
+		if len(toolIndex) > 0 {
+			return "tool_calls"
+		}
+		output, _ := resp["output"].([]interface{})
+		for _, item := range output {
+			if itemMap, ok := item.(map[string]interface{}); ok && itemMap["type"] == "function_call" {
+				return "tool_calls"
+			}
+		}
+		if status, _ := resp["status"].(string); status == "incomplete" || event == "response.incomplete" {
+			if details, ok := resp["incomplete_details"].(map[string]interface{}); ok && details["reason"] == "content_filter" {
+				return "content_filter"
+			}
+			return "length"
+		}
+		return "stop"
+	}
+
+	handleEvent := func(event string, data []byte) {
+		if finished {
+			return
+		}
+		var eventData map[string]interface{}
+		if err := json.Unmarshal(data, &eventData); err != nil {
+			return
+		}
+		if event == "" {
+			event, _ = eventData["type"].(string)
+		}
+
+		switch event {
+		case "response.created", "response.in_progress":
+			// id and model are fields of the nested "response" object; fall
+			// back to the top level for payloads that flatten them.
+			src := eventData
+			if resp, ok := eventData["response"].(map[string]interface{}); ok {
+				src = resp
+			}
+			if id, ok := src["id"].(string); ok && id != "" {
+				completionID = "chatcmpl-" + id
+			}
+			if m, ok := src["model"].(string); ok && m != "" {
+				model = m
+			}
+			if !roleSent {
+				roleSent = true
+				emitChunk(map[string]interface{}{"role": "assistant", "content": ""}, nil)
+			}
+
+		case "response.output_text.delta":
+			if delta, ok := eventData["delta"].(string); ok && delta != "" {
+				emitChunk(map[string]interface{}{"content": delta}, nil)
+			}
+
+		case "response.output_item.added":
+			// A new function_call item carries the call id and function name
+			// that Chat Completions clients expect in the first tool_calls delta.
+			item, ok := eventData["item"].(map[string]interface{})
+			if !ok || item["type"] != "function_call" {
+				return
+			}
+			itemID, _ := item["id"].(string)
+			emitToolCall(itemID, map[string]interface{}{
+				"id":   item["call_id"],
+				"type": "function",
+				"function": map[string]interface{}{
+					"name":      item["name"],
+					"arguments": "",
+				},
+			})
+
+		case "response.function_call_arguments.delta":
+			if delta, ok := eventData["delta"].(string); ok && delta != "" {
+				itemID, _ := eventData["item_id"].(string)
+				emitToolCall(itemID, map[string]interface{}{
+					"function": map[string]interface{}{"arguments": delta},
+				})
+			}
+
+		case "response.completed", "response.incomplete":
+			finished = true
+			resp, _ := eventData["response"].(map[string]interface{})
+			emitChunk(map[string]interface{}{}, finishReasonFor(event, resp))
+			fmt.Fprintf(w, "data: [DONE]\n\n")
+		}
+	}
+
+	var currentEvent string
+	var dataBuffer bytes.Buffer
+
+	for scanner.Scan() {
+		line := scanner.Text()
+		switch {
+		case strings.HasPrefix(line, "event: "):
+			currentEvent = strings.TrimPrefix(line, "event: ")
+		case strings.HasPrefix(line, "data: "):
+			dataBuffer.WriteString(strings.TrimPrefix(line, "data: "))
+		case line == "" && dataBuffer.Len() > 0:
+			// A blank line ends the event.
+			handleEvent(currentEvent, dataBuffer.Bytes())
+			dataBuffer.Reset()
+			currentEvent = ""
+		}
+	}
+
+	// The stream may end without the blank line that closes the last event.
+	if dataBuffer.Len() > 0 {
+		handleEvent(currentEvent, dataBuffer.Bytes())
+	}
+
+	if err := scanner.Err(); err != nil {
+		st.writeStreamError(w, err)
+	}
+}
diff --git a/internal/proxy/transform/stream_test.go b/internal/proxy/transform/stream_test.go
index 6c97a46..17cf0fc 100644
--- a/internal/proxy/transform/stream_test.go
+++ b/internal/proxy/transform/stream_test.go
@@ -1178,3 +1178,80 @@ func parseSSEEvents(output string) []sseEvent {
 
 	return events
 }
+
+func TestTransformResponsesAPIToOpenAIChat_Text(t *testing.T) {
+	input := strings.Join([]string{
+		`event: response.created`,
+		`data: {"type":"response.created","response":{"id":"resp_chat1","status":"in_progress","model":"gpt-5","output":[]}}`,
+		``,
+		`event: response.output_text.delta`,
+		`data: {"type":"response.output_text.delta","item_id":"msg_1","output_index":0,"content_index":0,"delta":"Hello"}`,
+		``,
+		`event: response.output_text.delta`,
+		`data: {"type":"response.output_text.delta","item_id":"msg_1","output_index":0,"content_index":0,"delta":" world"}`,
+		``,
+		`event: response.completed`,
+		`data: {"type":"response.completed","response":{"id":"resp_chat1","status":"completed","model":"gpt-5","output":[],"usage":{"input_tokens":5,"output_tokens":3}}}`,
+		``,
+	}, "\n")
+
+	st := &StreamTransformer{
+		ClientFormat:   FormatOpenAIChat,
+		ProviderFormat: FormatOpenAIResponses,
+	}
+	reader := st.TransformSSEStream(strings.NewReader(input))
+	output, err := io.ReadAll(reader)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	result := string(output)
+
+	// Should produce Chat Completions chunks
+	if !strings.Contains(result, `"chat.completion.chunk"`) {
+		t.Error("should emit chat.completion.chunk objects")
+	}
+	if !strings.Contains(result, `"Hello"`) {
+		t.Error("should include first delta text")
+	}
+	if !strings.Contains(result, `" world"`) {
+		t.Error("should include second delta text")
+	}
+	if !strings.Contains(result, `"finish_reason":"stop"`) {
+		t.Error("should include finish_reason stop in final chunk")
+	}
+	if !strings.Contains(result, "data: [DONE]") {
+		t.Error("should emit [DONE] sentinel")
+	}
+}
+
+func TestTransformResponsesAPIToOpenAIChat_ToolCall(t *testing.T) {
+	input := strings.Join([]string{
+		`event: response.created`,
+		`data: {"type":"response.created","response":{"id":"resp_tc1","status":"in_progress","model":"gpt-5","output":[]}}`,
+		``,
+		`event: response.function_call_arguments.delta`,
+		`data: {"type":"response.function_call_arguments.delta","item_id":"fc_1","output_index":0,"delta":"{\"loc"}`,
+		``,
+		`event: response.completed`,
+		`data: {"type":"response.completed","response":{"id":"resp_tc1","status":"completed","model":"gpt-5","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"weather"}]}}`,
+		``,
+	}, "\n")
+
+	st := &StreamTransformer{
+		ClientFormat:   FormatOpenAIChat,
+		ProviderFormat: FormatOpenAIResponses,
+	}
+	reader := st.TransformSSEStream(strings.NewReader(input))
+	output, err := io.ReadAll(reader)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	result := string(output)
+
+	if !strings.Contains(result, `"tool_calls"`) {
+		t.Error("should emit tool_calls delta")
+	}
+	if !strings.Contains(result, `"finish_reason":"tool_calls"`) {
+		t.Error("should emit finish chunk with tool_calls finish_reason")
+	}
+}