Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,12 @@ type serveServer struct {
fileTrackStoreFn func() *filetrack.Store // test seam; nil → process-wide store from config
}

const (
serveReadHeaderTimeout = 5 * time.Second
serveIdleTimeout = 2 * time.Minute
serveStreamingWriteTimeout = 30 * time.Second
)

// fileTrackStore returns the file-change history store, or nil when file
// tracking is disabled.
func (s *serveServer) fileTrackStore() *filetrack.Store {
Expand All @@ -1114,13 +1120,21 @@ func (s *serveServer) fileTrackStore() *filetrack.Store {
return fileTrackingStore(s.cfgRef)
}

func (s *serveServer) newHTTPServer() *http.Server {
return &http.Server{
Addr: fmt.Sprintf("%s:%d", s.cfg.host, s.cfg.port),
Handler: s.httpHandler(),
ReadHeaderTimeout: serveReadHeaderTimeout,
IdleTimeout: serveIdleTimeout,
// Leave WriteTimeout at zero: several endpoints stream for minutes or
// longer, so they use per-write deadlines on the ResponseWriter instead.
}
}

func (s *serveServer) Start() error {
s.shutdownCh = make(chan struct{})
s.shutdownOnce = sync.Once{}
s.server = &http.Server{
Addr: fmt.Sprintf("%s:%d", s.cfg.host, s.cfg.port),
Handler: s.httpHandler(),
}
s.server = s.newHTTPServer()

if s.cfg.ui {
s.prewarmUIAssetCache()
Expand Down
3 changes: 3 additions & 0 deletions cmd/serve_handlers_anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ func (s *serveServer) streamAnthropicMessages(ctx context.Context, w http.Respon
writeAnthropicError(w, http.StatusInternalServerError, "api_error", "streaming not supported")
return
}
streamWriter := newStreamingResponseWriter(w, serveStreamingWriteTimeout)
w = streamWriter
flusher = streamWriter

setSSEHeaders(w)
flusher.Flush()
Expand Down
3 changes: 3 additions & 0 deletions cmd/serve_handlers_chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ func (s *serveServer) streamChatCompletions(ctx context.Context, w http.Response
writeOpenAIError(w, http.StatusInternalServerError, "server_error", "streaming not supported")
return
}
streamWriter := newStreamingResponseWriter(w, serveStreamingWriteTimeout)
w = streamWriter
flusher = streamWriter

setSSEHeaders(w)
flusher.Flush()
Expand Down
49 changes: 49 additions & 0 deletions cmd/serve_protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"mime"
Expand Down Expand Up @@ -50,6 +51,54 @@ func setSSEHeaders(w http.ResponseWriter) {
w.Header().Set("X-Accel-Buffering", "no")
}

// streamingResponseWriter bounds each blocking Write/Flush call without putting
// a lifetime cap on the entire streaming response.
type streamingResponseWriter struct {
http.ResponseWriter
writeTimeout time.Duration
}

func newStreamingResponseWriter(w http.ResponseWriter, writeTimeout time.Duration) *streamingResponseWriter {
return &streamingResponseWriter{ResponseWriter: w, writeTimeout: writeTimeout}
}

func (w *streamingResponseWriter) Write(b []byte) (int, error) {
if err := w.setWriteDeadline(); err != nil {
return 0, err
}
defer w.clearWriteDeadline()
return w.ResponseWriter.Write(b)
}

func (w *streamingResponseWriter) Flush() {
if err := w.setWriteDeadline(); err != nil {
return
}
defer w.clearWriteDeadline()
w.ResponseWriter.(http.Flusher).Flush()
}

func (w *streamingResponseWriter) setWriteDeadline() error {
if w.writeTimeout <= 0 {
return nil
}
err := http.NewResponseController(w.ResponseWriter).SetWriteDeadline(time.Now().Add(w.writeTimeout))
if err == nil || errors.Is(err, http.ErrNotSupported) {
return nil
}
return fmt.Errorf("set streaming write deadline: %w", err)
}

func (w *streamingResponseWriter) clearWriteDeadline() {
if w.writeTimeout <= 0 {
return
}
err := http.NewResponseController(w.ResponseWriter).SetWriteDeadline(time.Time{})
if err != nil && !errors.Is(err, http.ErrNotSupported) {
return
}
}

func writeSSEEvent(w io.Writer, event string, payload any) error {
b, err := json.Marshal(payload)
if err != nil {
Expand Down
77 changes: 77 additions & 0 deletions cmd/serve_protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ import (
"image"
"image/color"
"image/jpeg"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"

"github.com/samsaffron/term-llm/internal/session"
)
Expand Down Expand Up @@ -48,6 +50,81 @@ func TestSetSessionNumberHeader(t *testing.T) {
}
}

type streamingDeadlineRecorder struct {
header http.Header
body bytes.Buffer
deadlines []time.Time
flushes int
}

func (r *streamingDeadlineRecorder) Header() http.Header {
if r.header == nil {
r.header = make(http.Header)
}
return r.header
}

func (r *streamingDeadlineRecorder) Write(b []byte) (int, error) {
return r.body.Write(b)
}

func (r *streamingDeadlineRecorder) WriteHeader(statusCode int) {}

func (r *streamingDeadlineRecorder) Flush() {
r.flushes++
}

func (r *streamingDeadlineRecorder) SetWriteDeadline(deadline time.Time) error {
r.deadlines = append(r.deadlines, deadline)
return nil
}

func TestStreamingResponseWriterSetsAndClearsWriteDeadline(t *testing.T) {
recorder := &streamingDeadlineRecorder{}
w := newStreamingResponseWriter(recorder, 3*time.Second)

if _, err := w.Write([]byte("hello")); err != nil {
t.Fatalf("Write() error = %v", err)
}
w.Flush()

if got := recorder.body.String(); got != "hello" {
t.Fatalf("body = %q, want hello", got)
}
if recorder.flushes != 1 {
t.Fatalf("flushes = %d, want 1", recorder.flushes)
}
if len(recorder.deadlines) != 4 {
t.Fatalf("deadline calls = %d, want 4", len(recorder.deadlines))
}
if recorder.deadlines[0].IsZero() {
t.Fatal("first deadline was not set")
}
if !recorder.deadlines[1].IsZero() {
t.Fatal("second deadline was not cleared")
}
if recorder.deadlines[2].IsZero() {
t.Fatal("third deadline was not set")
}
if !recorder.deadlines[3].IsZero() {
t.Fatal("fourth deadline was not cleared")
}
}

func TestStreamingResponseWriterIgnoresUnsupportedDeadline(t *testing.T) {
recorder := httptest.NewRecorder()
w := newStreamingResponseWriter(recorder, 3*time.Second)

if _, err := w.Write([]byte("ok")); err != nil {
t.Fatalf("Write() error = %v", err)
}
w.Flush()

if got := recorder.Body.String(); got != "ok" {
t.Fatalf("body = %q, want ok", got)
}
}

func TestParseUserMessageContent_AllowsUpToMaxInlineImages(t *testing.T) {
t.Setenv("XDG_DATA_HOME", t.TempDir())

Expand Down
3 changes: 3 additions & 0 deletions cmd/serve_response_runs.go
Original file line number Diff line number Diff line change
Expand Up @@ -1648,6 +1648,9 @@ func (s *serveServer) streamResponseRunEvents(ctx context.Context, w http.Respon
writeOpenAIError(w, http.StatusInternalServerError, "server_error", "streaming not supported")
return
}
streamWriter := newStreamingResponseWriter(w, serveStreamingWriteTimeout)
w = streamWriter
flusher = streamWriter

setSSEHeaders(w)
flusher.Flush()
Expand Down
18 changes: 18 additions & 0 deletions cmd/serve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,24 @@ func TestServeServerStopIsIdempotent(t *testing.T) {
}
}

func TestServeServerNewHTTPServerSetsTimeouts(t *testing.T) {
s := &serveServer{cfg: serveServerConfig{host: "127.0.0.1", port: 8080}}
server := s.newHTTPServer()

if server == nil {
t.Fatal("newHTTPServer() = nil")
}
if got := server.ReadHeaderTimeout; got != serveReadHeaderTimeout {
t.Fatalf("ReadHeaderTimeout = %v, want %v", got, serveReadHeaderTimeout)
}
if got := server.IdleTimeout; got != serveIdleTimeout {
t.Fatalf("IdleTimeout = %v, want %v", got, serveIdleTimeout)
}
if got := server.WriteTimeout; got != 0 {
t.Fatalf("WriteTimeout = %v, want 0 for streaming endpoints", got)
}
}

func readSSEEvent(t *testing.T, scanner *bufio.Scanner) (string, string, bool) {
t.Helper()

Expand Down
Loading