Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions autorouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ import (

var skipHeaders = []string{"Content-Encoding", "Content-Length"}

func disableUpstreamResponseCompression(req *http.Request) {
// Response extractors parse provider bodies directly, so upstream
// responses must stay uncompressed.
req.Header.Set("Accept-Encoding", "identity")
}

func copyResponseHeaders(w http.ResponseWriter, headers http.Header) {
header := w.Header()

Expand Down Expand Up @@ -189,6 +195,7 @@ func (a *AutoRouter) Forward(ctx context.Context, req *http.Request) (*http.Resp
if err := provider.RequestEnricher().Enrich(upstreamReq, meta, body); err != nil {
return nil, ResponseMetadata{}, err
}
disableUpstreamResponseCompression(upstreamReq)

ctxValue := MetaContextValue{Meta: meta, RawBody: body}
upstreamReq = upstreamReq.WithContext(context.WithValue(upstreamReq.Context(), MetaContextKey{}, ctxValue))
Expand Down Expand Up @@ -319,12 +326,10 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w
upstreamReq.Header[k] = v
}

// FOR SSE, turn off compression explicitly
upstreamReq.Header["Accept-Encoding"] = []string{"identity"}

if err := provider.RequestEnricher().Enrich(upstreamReq, meta, body); err != nil {
return ResponseMetadata{}, err
}
disableUpstreamResponseCompression(upstreamReq)

ctxValue := MetaContextValue{Meta: meta, RawBody: body}
upstreamReq = upstreamReq.WithContext(context.WithValue(upstreamReq.Context(), MetaContextKey{}, ctxValue))
Expand Down
140 changes: 140 additions & 0 deletions autorouter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package llmproxy

import (
"bytes"
"compress/gzip"
"context"
"encoding/json"
"io"
Expand Down Expand Up @@ -153,6 +154,78 @@ func TestAutoRouter_Forward(t *testing.T) {
}
}

func TestAutoRouter_ForwardForcesIdentityAcceptEncoding(t *testing.T) {
var upstreamAcceptEncoding string
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
upstreamAcceptEncoding = r.Header.Get("Accept-Encoding")
w.Header().Set("Content-Type", "application/json")
if upstreamAcceptEncoding != "identity" {
w.Header().Set("Content-Encoding", "gzip")
zw := gzip.NewWriter(w)
_, _ = zw.Write([]byte(`{"id":"test","model":"gpt-4","choices":[]}`))
_ = zw.Close()
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"id":"test","model":"gpt-4","choices":[]}`))
}))
defer upstream.Close()

provider := &mockProvider{
name: "test-provider",
parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) {
data, _ := io.ReadAll(body)
return BodyMetadata{Model: "gpt-4"}, data, nil
},
enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error {
req.Header.Set("Accept-Encoding", "gzip")
return nil
},
resolveFn: func(meta BodyMetadata) (*url.URL, error) {
return ParseURL(upstream.URL + "/v1/chat/completions")
},
extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) {
body, err := io.ReadAll(resp.Body)
if err != nil {
return ResponseMetadata{}, nil, err
}
var raw map[string]any
if err := json.Unmarshal(body, &raw); err != nil {
return ResponseMetadata{}, nil, err
}
id, _ := raw["id"].(string)
return ResponseMetadata{ID: id}, body, nil
},
}

router := NewAutoRouter(
WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string {
return "test-provider"
})),
)
router.RegisterProvider(provider)

req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}]}`)))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept-Encoding", "gzip, deflate, br")

resp, meta, err := router.Forward(context.Background(), req)
if err != nil {
t.Fatalf("Forward() error = %v", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %d, want 200", resp.StatusCode)
}
if meta.ID != "test" {
t.Errorf("ID = %q, want test", meta.ID)
}
if upstreamAcceptEncoding != "identity" {
t.Errorf("upstream Accept-Encoding = %q, want identity", upstreamAcceptEncoding)
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

func TestAutoRouter_NoProvider(t *testing.T) {
detector := &mockDetector{
detectFn: func(hint ProviderHint) string { return "" },
Expand Down Expand Up @@ -618,6 +691,73 @@ func TestAutoRouter_StreamingNoBillingNoStreamOptions(t *testing.T) {
}
}

func TestAutoRouter_ForwardStreamingForcesIdentityAcceptEncoding(t *testing.T) {
eventStream := "data: {\"id\":\"test\"}\n\ndata: [DONE]\n\n"
var upstreamAcceptEncoding string
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
upstreamAcceptEncoding = r.Header.Get("Accept-Encoding")
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(eventStream))
}))
defer upstream.Close()

provider := &mockStreamingProvider{
mockProvider: &mockProvider{
name: "test",
parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) {
data, _ := io.ReadAll(body)
return BodyMetadata{Model: "gpt-4", Stream: true}, data, nil
},
enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error {
req.Header.Set("Accept-Encoding", "gzip")
return nil
},
resolveFn: func(meta BodyMetadata) (*url.URL, error) {
return url.Parse(upstream.URL)
},
},
streamingExtractor: &mockStreamingExtractor{
isStreaming: true,
extractStreamingFn: func(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (ResponseMetadata, error) {
if _, err := io.Copy(w, resp.Body); err != nil {
return ResponseMetadata{}, err
}
_ = rc.Flush()
return ResponseMetadata{ID: "test"}, nil
},
},
}

router := NewAutoRouter(
WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "test" })),
)
router.RegisterProvider(provider)

req := httptest.NewRequest("POST", "/", bytes.NewReader([]byte(`{"model":"gpt-4","stream":true,"messages":[{"role":"user","content":"Hello"}]}`)))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept-Encoding", "gzip, deflate, br")
w := httptest.NewRecorder()

meta, err := router.ForwardStreaming(context.Background(), req, w)
if err != nil {
t.Fatalf("ForwardStreaming() error = %v", err)
}

if w.Code != http.StatusOK {
t.Errorf("StatusCode = %d, want 200", w.Code)
}
if meta.ID != "test" {
t.Errorf("ID = %q, want test", meta.ID)
}
if upstreamAcceptEncoding != "identity" {
t.Errorf("upstream Accept-Encoding = %q, want identity", upstreamAcceptEncoding)
}
if w.Body.String() != eventStream {
t.Errorf("Body = %q, want %q", w.Body.String(), eventStream)
}
}

func TestAutoRouter_NonStreamingNoStreamOptions(t *testing.T) {
var receivedBody map[string]any
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
Loading