Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 28 additions & 16 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ type StreamableHTTPHandler struct {
onTransportDeletion func(sessionID string) // for testing

mu sync.Mutex
closed bool
sessions map[string]*sessionInfo // keyed by session ID
}

Expand Down Expand Up @@ -219,27 +220,30 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea
return h
}

// closeAll closes all ongoing sessions, for tests.
//
// TODO(rfindley): investigate the best API for callers to configure their
// session lifecycle. (?)
// Close closes the handler, closing and removing all connected sessions,
// and preventing new sessions from being added.
//
// Should we allow passing in a session store? That would allow the handler to
// be stateless.
func (h *StreamableHTTPHandler) closeAll() {
// TODO: if we ever expose this outside of tests, we'll need to do better
// than simply collecting sessions while holding the lock: we need to prevent
// new sessions from being added.
//
// Currently, sessions remove themselves from h.sessions when closed, so we
// can't call Close while holding the lock.
// Close is idempotent.
func (h *StreamableHTTPHandler) Close() error {
h.mu.Lock()
if h.closed {
h.mu.Unlock()
return nil
}
h.closed = true
sessionInfos := slices.Collect(maps.Values(h.sessions))
h.sessions = nil
h.sessions = make(map[string]*sessionInfo)
h.mu.Unlock()
for _, s := range sessionInfos {
s.session.Close()

for _, info := range sessionInfos {
info.session.Close()
}
return nil
}

// closeAll closes all ongoing sessions, for tests.
func (h *StreamableHTTPHandler) closeAll() {
_ = h.Close()
}

// disablelocalhostprotection is a compatibility parameter that allows to disable
Expand Down Expand Up @@ -302,6 +306,14 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
}
}

h.mu.Lock()
closed := h.closed
h.mu.Unlock()
if closed {
http.Error(w, "handler closed", http.StatusServiceUnavailable)
return
}

// [§2.7] of the spec (2025-06-18): validate the MCP-Protocol-Version
// header. If provided, it must be a supported version. If absent, the
// version is unknown (the request may be an initialize for any version).
Expand Down
48 changes: 48 additions & 0 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,54 @@ func TestStreamableServerDisconnect(t *testing.T) {
}
}

func TestStreamableHTTPHandlerClose(t *testing.T) {
server := NewServer(testImpl, nil)
handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)
httpServer := httptest.NewServer(mustNotPanic(t, handler))
defer httpServer.Close()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

client := NewClient(testImpl, nil)
clientSession, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil)
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}
t.Cleanup(func() { _ = clientSession.Close() })

handler.mu.Lock()
if len(handler.sessions) != 1 {
t.Fatalf("want 1 session before Close, got %d", len(handler.sessions))
}
handler.mu.Unlock()

if err := handler.Close(); err != nil {
t.Fatalf("Close() failed: %v", err)
}
if err := handler.Close(); err != nil {
t.Fatalf("second Close() failed: %v", err)
}

handler.mu.Lock()
if len(handler.sessions) != 0 {
t.Fatalf("want 0 sessions after Close, got %d", len(handler.sessions))
}
if !handler.closed {
t.Fatal("want handler.closed true after Close")
}
handler.mu.Unlock()

resp, err := http.Get(httpServer.URL)
if err != nil {
t.Fatalf("http.Get after Close failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusServiceUnavailable {
t.Fatalf("got status %d after Close, want %d", resp.StatusCode, http.StatusServiceUnavailable)
}
}

func TestServerTransportCleanup(t *testing.T) {
nClient := 3

Expand Down