diff --git a/autorouter.go b/autorouter.go
index f3e3d15..6ea5932 100644
--- a/autorouter.go
+++ b/autorouter.go
@@ -15,6 +15,12 @@ import (
 
 var skipHeaders = []string{"Content-Encoding", "Content-Length"}
 
+func disableUpstreamResponseCompression(req *http.Request) {
+	// Replace any Accept-Encoding already set on req with "identity": response extractors
+	// parse provider bodies directly, so upstream responses must stay uncompressed.
+	req.Header.Set("Accept-Encoding", "identity")
+}
+
 func copyResponseHeaders(w http.ResponseWriter, headers http.Header) {
 	header := w.Header()
 
@@ -189,6 +195,7 @@ func (a *AutoRouter) Forward(ctx context.Context, req *http.Request) (*http.Resp
 	if err := provider.RequestEnricher().Enrich(upstreamReq, meta, body); err != nil {
 		return nil, ResponseMetadata{}, err
 	}
+	disableUpstreamResponseCompression(upstreamReq)
 
 	ctxValue := MetaContextValue{Meta: meta, RawBody: body}
 	upstreamReq = upstreamReq.WithContext(context.WithValue(upstreamReq.Context(), MetaContextKey{}, ctxValue))
@@ -319,12 +326,10 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w
 		upstreamReq.Header[k] = v
 	}
 
-	// FOR SSE, turn off compression explicitly
-	upstreamReq.Header["Accept-Encoding"] = []string{"identity"}
-
 	if err := provider.RequestEnricher().Enrich(upstreamReq, meta, body); err != nil {
 		return ResponseMetadata{}, err
 	}
+	disableUpstreamResponseCompression(upstreamReq)
 
 	ctxValue := MetaContextValue{Meta: meta, RawBody: body}
 	upstreamReq = upstreamReq.WithContext(context.WithValue(upstreamReq.Context(), MetaContextKey{}, ctxValue))
diff --git a/autorouter_test.go b/autorouter_test.go
index 5ba7b4e..1a4ca0f 100644
--- a/autorouter_test.go
+++ b/autorouter_test.go
@@ -2,6 +2,7 @@ package llmproxy
 
 import (
 	"bytes"
+	"compress/gzip"
 	"context"
 	"encoding/json"
 	"io"
@@ -153,6 +154,78 @@ func TestAutoRouter_Forward(t *testing.T) {
 	}
 }
 
+func TestAutoRouter_ForwardForcesIdentityAcceptEncoding(t *testing.T) {
+	var upstreamAcceptEncoding string
+	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		upstreamAcceptEncoding = r.Header.Get("Accept-Encoding")
+		w.Header().Set("Content-Type", "application/json")
+		if upstreamAcceptEncoding != "identity" {
+			w.Header().Set("Content-Encoding", "gzip")
+			zw := gzip.NewWriter(w)
+			_, _ = zw.Write([]byte(`{"id":"test","model":"gpt-4","choices":[]}`))
+			_ = zw.Close()
+			return
+		}
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte(`{"id":"test","model":"gpt-4","choices":[]}`))
+	}))
+	defer upstream.Close()
+
+	provider := &mockProvider{
+		name: "test-provider",
+		parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) {
+			data, _ := io.ReadAll(body)
+			return BodyMetadata{Model: "gpt-4"}, data, nil
+		},
+		enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error {
+			req.Header.Set("Accept-Encoding", "gzip")
+			return nil
+		},
+		resolveFn: func(meta BodyMetadata) (*url.URL, error) {
+			return ParseURL(upstream.URL + "/v1/chat/completions")
+		},
+		extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) {
+			body, err := io.ReadAll(resp.Body)
+			if err != nil {
+				return ResponseMetadata{}, nil, err
+			}
+			var raw map[string]any
+			if err := json.Unmarshal(body, &raw); err != nil {
+				return ResponseMetadata{}, nil, err
+			}
+			id, _ := raw["id"].(string)
+			return ResponseMetadata{ID: id}, body, nil
+		},
+	}
+
+	router := NewAutoRouter(
+		WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string {
+			return "test-provider"
+		})),
+	)
+	router.RegisterProvider(provider)
+
+	req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}]}`)))
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Accept-Encoding", "gzip, deflate, br")
+
+	resp, meta, err := router.Forward(context.Background(), req)
+	if err != nil {
+		t.Fatalf("Forward() error = %v", err)
+	}
+	defer func() { _ = resp.Body.Close() }()
+
+	if resp.StatusCode != http.StatusOK {
+		t.Errorf("StatusCode = %d, want 200", resp.StatusCode)
+	}
+	if meta.ID != "test" {
+		t.Errorf("ID = %q, want test", meta.ID)
+	}
+	if upstreamAcceptEncoding != "identity" {
+		t.Errorf("upstream Accept-Encoding = %q, want identity", upstreamAcceptEncoding)
+	}
+}
+
 func TestAutoRouter_NoProvider(t *testing.T) {
 	detector := &mockDetector{
 		detectFn: func(hint ProviderHint) string { return "" },
@@ -618,6 +691,73 @@ func TestAutoRouter_StreamingNoBillingNoStreamOptions(t *testing.T) {
 	}
 }
 
+func TestAutoRouter_ForwardStreamingForcesIdentityAcceptEncoding(t *testing.T) {
+	eventStream := "data: {\"id\":\"test\"}\n\ndata: [DONE]\n\n"
+	var upstreamAcceptEncoding string
+	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		upstreamAcceptEncoding = r.Header.Get("Accept-Encoding")
+		w.Header().Set("Content-Type", "text/event-stream")
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte(eventStream))
+	}))
+	defer upstream.Close()
+
+	provider := &mockStreamingProvider{
+		mockProvider: &mockProvider{
+			name: "test",
+			parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) {
+				data, _ := io.ReadAll(body)
+				return BodyMetadata{Model: "gpt-4", Stream: true}, data, nil
+			},
+			enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error {
+				req.Header.Set("Accept-Encoding", "gzip")
+				return nil
+			},
+			resolveFn: func(meta BodyMetadata) (*url.URL, error) {
+				return url.Parse(upstream.URL)
+			},
+		},
+		streamingExtractor: &mockStreamingExtractor{
+			isStreaming: true,
+			extractStreamingFn: func(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (ResponseMetadata, error) {
+				if _, err := io.Copy(w, resp.Body); err != nil {
+					return ResponseMetadata{}, err
+				}
+				_ = rc.Flush()
+				return ResponseMetadata{ID: "test"}, nil
+			},
+		},
+	}
+
+	router := NewAutoRouter(
+		WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "test" })),
+	)
+	router.RegisterProvider(provider)
+
+	req := httptest.NewRequest("POST", "/", bytes.NewReader([]byte(`{"model":"gpt-4","stream":true,"messages":[{"role":"user","content":"Hello"}]}`)))
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Accept-Encoding", "gzip, deflate, br")
+	w := httptest.NewRecorder()
+
+	meta, err := router.ForwardStreaming(context.Background(), req, w)
+	if err != nil {
+		t.Fatalf("ForwardStreaming() error = %v", err)
+	}
+
+	if w.Code != http.StatusOK {
+		t.Errorf("StatusCode = %d, want 200", w.Code)
+	}
+	if meta.ID != "test" {
+		t.Errorf("ID = %q, want test", meta.ID)
+	}
+	if upstreamAcceptEncoding != "identity" {
+		t.Errorf("upstream Accept-Encoding = %q, want identity", upstreamAcceptEncoding)
+	}
+	if w.Body.String() != eventStream {
+		t.Errorf("Body = %q, want %q", w.Body.String(), eventStream)
+	}
+}
+
 func TestAutoRouter_NonStreamingNoStreamOptions(t *testing.T) {
 	var receivedBody map[string]any
 	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {