From 5361988829ec99186992e95a1a0d1ae0af76ce5f Mon Sep 17 00:00:00 2001 From: Jarvis Date: Sat, 27 Jun 2026 23:19:35 +1000 Subject: [PATCH] fix: serve HTTP server has no read/write/idle deadlines, so slow clients can pin handlers and block shutdown --- cmd/serve.go | 22 ++++++++-- cmd/serve_handlers_anthropic.go | 3 ++ cmd/serve_handlers_chat.go | 3 ++ cmd/serve_protocol.go | 49 +++++++++++++++++++++ cmd/serve_protocol_test.go | 77 +++++++++++++++++++++++++++++++++ cmd/serve_response_runs.go | 3 ++ cmd/serve_test.go | 18 ++++++++ 7 files changed, 171 insertions(+), 4 deletions(-) diff --git a/cmd/serve.go b/cmd/serve.go index 0d5a5300..76ad6be0 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -1105,6 +1105,12 @@ type serveServer struct { fileTrackStoreFn func() *filetrack.Store // test seam; nil → process-wide store from config } +const ( + serveReadHeaderTimeout = 5 * time.Second + serveIdleTimeout = 2 * time.Minute + serveStreamingWriteTimeout = 30 * time.Second +) + // fileTrackStore returns the file-change history store, or nil when file // tracking is disabled. func (s *serveServer) fileTrackStore() *filetrack.Store { @@ -1114,13 +1120,21 @@ func (s *serveServer) fileTrackStore() *filetrack.Store { return fileTrackingStore(s.cfgRef) } +func (s *serveServer) newHTTPServer() *http.Server { + return &http.Server{ + Addr: fmt.Sprintf("%s:%d", s.cfg.host, s.cfg.port), + Handler: s.httpHandler(), + ReadHeaderTimeout: serveReadHeaderTimeout, + IdleTimeout: serveIdleTimeout, + // Leave WriteTimeout at zero: several endpoints stream for minutes or + // longer, so they use per-write deadlines on the ResponseWriter instead. + } +} + func (s *serveServer) Start() error { s.shutdownCh = make(chan struct{}) s.shutdownOnce = sync.Once{} - s.server = &http.Server{ - Addr: fmt.Sprintf("%s:%d", s.cfg.host, s.cfg.port), - Handler: s.httpHandler(), - } + s.server = s.newHTTPServer() if s.cfg.ui { s.prewarmUIAssetCache() diff --git a/cmd/serve_handlers_anthropic.go b/cmd/serve_handlers_anthropic.go index 1a10f663..31e42156 100644 --- a/cmd/serve_handlers_anthropic.go +++ b/cmd/serve_handlers_anthropic.go @@ -185,6 +185,9 @@ func (s *serveServer) streamAnthropicMessages(ctx context.Context, w http.Respon writeAnthropicError(w, http.StatusInternalServerError, "api_error", "streaming not supported") return } + streamWriter := newStreamingResponseWriter(w, serveStreamingWriteTimeout) + w = streamWriter + flusher = streamWriter setSSEHeaders(w) flusher.Flush() diff --git a/cmd/serve_handlers_chat.go b/cmd/serve_handlers_chat.go index f45c02e9..116e8825 100644 --- a/cmd/serve_handlers_chat.go +++ b/cmd/serve_handlers_chat.go @@ -164,6 +164,9 @@ func (s *serveServer) streamChatCompletions(ctx context.Context, w http.Response writeOpenAIError(w, http.StatusInternalServerError, "server_error", "streaming not supported") return } + streamWriter := newStreamingResponseWriter(w, serveStreamingWriteTimeout) + w = streamWriter + flusher = streamWriter setSSEHeaders(w) flusher.Flush() diff --git a/cmd/serve_protocol.go b/cmd/serve_protocol.go index fb60d63c..c4fd9346 100644 --- a/cmd/serve_protocol.go +++ b/cmd/serve_protocol.go @@ -8,6 +8,7 @@ import ( "encoding/base64" "encoding/hex" "encoding/json" + "errors" "fmt" "io" "mime" @@ -50,6 +51,54 @@ func setSSEHeaders(w http.ResponseWriter) { w.Header().Set("X-Accel-Buffering", "no") } +// streamingResponseWriter bounds each blocking Write/Flush call without putting +// a lifetime cap on the entire streaming response. +type streamingResponseWriter struct { + http.ResponseWriter + writeTimeout time.Duration +} + +func newStreamingResponseWriter(w http.ResponseWriter, writeTimeout time.Duration) *streamingResponseWriter { + return &streamingResponseWriter{ResponseWriter: w, writeTimeout: writeTimeout} +} + +func (w *streamingResponseWriter) Write(b []byte) (int, error) { + if err := w.setWriteDeadline(); err != nil { + return 0, err + } + defer w.clearWriteDeadline() + return w.ResponseWriter.Write(b) +} + +func (w *streamingResponseWriter) Flush() { + if err := w.setWriteDeadline(); err != nil { + return + } + defer w.clearWriteDeadline() + w.ResponseWriter.(http.Flusher).Flush() +} + +func (w *streamingResponseWriter) setWriteDeadline() error { + if w.writeTimeout <= 0 { + return nil + } + err := http.NewResponseController(w.ResponseWriter).SetWriteDeadline(time.Now().Add(w.writeTimeout)) + if err == nil || errors.Is(err, http.ErrNotSupported) { + return nil + } + return fmt.Errorf("set streaming write deadline: %w", err) +} + +func (w *streamingResponseWriter) clearWriteDeadline() { + if w.writeTimeout <= 0 { + return + } + err := http.NewResponseController(w.ResponseWriter).SetWriteDeadline(time.Time{}) + if err != nil && !errors.Is(err, http.ErrNotSupported) { + return + } +} + func writeSSEEvent(w io.Writer, event string, payload any) error { b, err := json.Marshal(payload) if err != nil { diff --git a/cmd/serve_protocol_test.go b/cmd/serve_protocol_test.go index fa0db386..ecd21a84 100644 --- a/cmd/serve_protocol_test.go +++ b/cmd/serve_protocol_test.go @@ -8,11 +8,13 @@ import ( "image" "image/color" "image/jpeg" + "net/http" "net/http/httptest" "os" "path/filepath" "strings" "testing" + "time" "github.com/samsaffron/term-llm/internal/session" ) @@ -48,6 +50,81 @@ func TestSetSessionNumberHeader(t *testing.T) { } } +type streamingDeadlineRecorder struct { + header http.Header + body bytes.Buffer + deadlines []time.Time + flushes int +} + +func (r *streamingDeadlineRecorder) Header() http.Header { + if r.header == nil { + r.header = make(http.Header) + } + return r.header +} + +func (r *streamingDeadlineRecorder) Write(b []byte) (int, error) { + return r.body.Write(b) +} + +func (r *streamingDeadlineRecorder) WriteHeader(statusCode int) {} + +func (r *streamingDeadlineRecorder) Flush() { + r.flushes++ +} + +func (r *streamingDeadlineRecorder) SetWriteDeadline(deadline time.Time) error { + r.deadlines = append(r.deadlines, deadline) + return nil +} + +func TestStreamingResponseWriterSetsAndClearsWriteDeadline(t *testing.T) { + recorder := &streamingDeadlineRecorder{} + w := newStreamingResponseWriter(recorder, 3*time.Second) + + if _, err := w.Write([]byte("hello")); err != nil { + t.Fatalf("Write() error = %v", err) + } + w.Flush() + + if got := recorder.body.String(); got != "hello" { + t.Fatalf("body = %q, want hello", got) + } + if recorder.flushes != 1 { + t.Fatalf("flushes = %d, want 1", recorder.flushes) + } + if len(recorder.deadlines) != 4 { + t.Fatalf("deadline calls = %d, want 4", len(recorder.deadlines)) + } + if recorder.deadlines[0].IsZero() { + t.Fatal("first deadline was not set") + } + if !recorder.deadlines[1].IsZero() { + t.Fatal("second deadline was not cleared") + } + if recorder.deadlines[2].IsZero() { + t.Fatal("third deadline was not set") + } + if !recorder.deadlines[3].IsZero() { + t.Fatal("fourth deadline was not cleared") + } +} + +func TestStreamingResponseWriterIgnoresUnsupportedDeadline(t *testing.T) { + recorder := httptest.NewRecorder() + w := newStreamingResponseWriter(recorder, 3*time.Second) + + if _, err := w.Write([]byte("ok")); err != nil { + t.Fatalf("Write() error = %v", err) + } + w.Flush() + + if got := recorder.Body.String(); got != "ok" { + t.Fatalf("body = %q, want ok", got) + } +} + func TestParseUserMessageContent_AllowsUpToMaxInlineImages(t *testing.T) { t.Setenv("XDG_DATA_HOME", t.TempDir()) diff --git a/cmd/serve_response_runs.go b/cmd/serve_response_runs.go index b6be12c4..0c24128a 100644 --- a/cmd/serve_response_runs.go +++ b/cmd/serve_response_runs.go @@ -1648,6 +1648,9 @@ func (s *serveServer) streamResponseRunEvents(ctx context.Context, w http.Respon writeOpenAIError(w, http.StatusInternalServerError, "server_error", "streaming not supported") return } + streamWriter := newStreamingResponseWriter(w, serveStreamingWriteTimeout) + w = streamWriter + flusher = streamWriter setSSEHeaders(w) flusher.Flush() diff --git a/cmd/serve_test.go b/cmd/serve_test.go index 582278f6..f7aa4cc1 100644 --- a/cmd/serve_test.go +++ b/cmd/serve_test.go @@ -245,6 +245,24 @@ func TestServeServerStopIsIdempotent(t *testing.T) { } } +func TestServeServerNewHTTPServerSetsTimeouts(t *testing.T) { + s := &serveServer{cfg: serveServerConfig{host: "127.0.0.1", port: 8080}} + server := s.newHTTPServer() + + if server == nil { + t.Fatal("newHTTPServer() = nil") + } + if got := server.ReadHeaderTimeout; got != serveReadHeaderTimeout { + t.Fatalf("ReadHeaderTimeout = %v, want %v", got, serveReadHeaderTimeout) + } + if got := server.IdleTimeout; got != serveIdleTimeout { + t.Fatalf("IdleTimeout = %v, want %v", got, serveIdleTimeout) + } + if got := server.WriteTimeout; got != 0 { + t.Fatalf("WriteTimeout = %v, want 0 for streaming endpoints", got) + } +} + func readSSEEvent(t *testing.T, scanner *bufio.Scanner) (string, string, bool) { t.Helper()