From 78d7176030063260dc04663b547222c76786df27 Mon Sep 17 00:00:00 2001 From: Piyush Jagadish Bag Date: Tue, 9 Jun 2026 19:34:24 -0700 Subject: [PATCH] mcp: add StreamableHTTPHandler.Close for graceful shutdown Expose public Close() to tear down all sessions and reject new requests with 503, matching the API proposed in #440. --- mcp/streamable.go | 44 ++++++++++++++++++++++++-------------- mcp/streamable_test.go | 48 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 16 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 5bc31771..a568b748 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -51,6 +51,7 @@ type StreamableHTTPHandler struct { onTransportDeletion func(sessionID string) // for testing mu sync.Mutex + closed bool sessions map[string]*sessionInfo // keyed by session ID } @@ -219,27 +220,30 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea return h } -// closeAll closes all ongoing sessions, for tests. -// -// TODO(rfindley): investigate the best API for callers to configure their -// session lifecycle. (?) +// Close closes the handler, closing and removing all connected sessions, +// and preventing new sessions from being added. // -// Should we allow passing in a session store? That would allow the handler to -// be stateless. -func (h *StreamableHTTPHandler) closeAll() { - // TODO: if we ever expose this outside of tests, we'll need to do better - // than simply collecting sessions while holding the lock: we need to prevent - // new sessions from being added. - // - // Currently, sessions remove themselves from h.sessions when closed, so we - // can't call Close while holding the lock. +// Close is idempotent. +func (h *StreamableHTTPHandler) Close() error { h.mu.Lock() + if h.closed { + h.mu.Unlock() + return nil + } + h.closed = true sessionInfos := slices.Collect(maps.Values(h.sessions)) - h.sessions = nil + h.sessions = make(map[string]*sessionInfo) h.mu.Unlock() - for _, s := range sessionInfos { - s.session.Close() + + for _, info := range sessionInfos { + info.session.Close() } + return nil +} + +// closeAll closes all ongoing sessions, for tests. +func (h *StreamableHTTPHandler) closeAll() { + _ = h.Close() } // disablelocalhostprotection is a compatibility parameter that allows to disable @@ -302,6 +306,14 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque } } + h.mu.Lock() + closed := h.closed + h.mu.Unlock() + if closed { + http.Error(w, "handler closed", http.StatusServiceUnavailable) + return + } + // [ยง2.7] of the spec (2025-06-18): validate the MCP-Protocol-Version // header. If provided, it must be a supported version. If absent, the // version is unknown (the request may be an initialize for any version). diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 2efee20e..628b302c 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -627,6 +627,54 @@ func TestStreamableServerDisconnect(t *testing.T) { } } +func TestStreamableHTTPHandlerClose(t *testing.T) { + server := NewServer(testImpl, nil) + handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + client := NewClient(testImpl, nil) + clientSession, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + t.Cleanup(func() { _ = clientSession.Close() }) + + handler.mu.Lock() + if len(handler.sessions) != 1 { + t.Fatalf("want 1 session before Close, got %d", len(handler.sessions)) + } + handler.mu.Unlock() + + if err := handler.Close(); err != nil { + t.Fatalf("Close() failed: %v", err) + } + if err := handler.Close(); err != nil { + t.Fatalf("second Close() failed: %v", err) + } + + handler.mu.Lock() + if len(handler.sessions) != 0 { + t.Fatalf("want 0 sessions after Close, got %d", len(handler.sessions)) + } + if !handler.closed { + t.Fatal("want handler.closed true after Close") + } + handler.mu.Unlock() + + resp, err := http.Get(httpServer.URL) + if err != nil { + t.Fatalf("http.Get after Close failed: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusServiceUnavailable { + t.Fatalf("got status %d after Close, want %d", resp.StatusCode, http.StatusServiceUnavailable) + } +} + func TestServerTransportCleanup(t *testing.T) { nClient := 3