From c99ca347f6d70988c155d6e75550f182938cf8de Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 11 Jun 2026 15:32:52 +0000 Subject: [PATCH 1/7] feat: implement subscriptions/listen with SSE support for long-lived streams and end-to-end testing --- mcp/client.go | 36 +++++++++ mcp/protocol.go | 48 ++++++++++++ mcp/requests.go | 1 + mcp/server.go | 162 ++++++++++++++++++++++++++++++++++++++--- mcp/streamable.go | 33 +++++++-- mcp/streamable_test.go | 134 ++++++++++++++++++++++++++++++++++ mcp/transport_test.go | 151 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 550 insertions(+), 15 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index 1e0e18a4..e40f148a 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -300,6 +300,26 @@ func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOp if hc, ok := cs.mcpConn.(clientConnection); ok { hc.sessionUpdated(cs.state) } + subscribeParams := &SubscriptionsListenParams{} + if c.opts.ToolListChangedHandler != nil { + subscribeParams.Notifications.ToolsListChanged = true + } + if c.opts.PromptListChangedHandler != nil { + subscribeParams.Notifications.PromptsListChanged = true + } + if c.opts.ResourceListChangedHandler != nil { + subscribeParams.Notifications.ResourcesListChanged = true + } + if subscribeParams.Notifications.ToolsListChanged || + subscribeParams.Notifications.PromptsListChanged || + subscribeParams.Notifications.ResourcesListChanged { + // The listen blocks until the server cancels it (or the + // underlying connection ends). Run it in a goroutine so + // Connect can return; the SDK retires it on cs.Close(). + go func() { + _ = cs.subscriptionsListen(context.Background(), subscribeParams) + }() + } return cs, nil } @@ -1079,6 +1099,7 @@ var clientMethodInfos = map[string]methodInfo{ notificationLoggingMessage: newClientMethodInfo(clientMethod((*Client).callLoggingHandler), notification), notificationProgress: newClientMethodInfo(clientSessionMethod((*ClientSession).callProgressNotificationHandler), notification), notificationElicitationComplete: newClientMethodInfo(clientMethod((*Client).callElicitationCompleteHandler), notification|missingParamsOK), + notificationSubscriptionsAck: newClientMethodInfo(clientMethod((*Client).callSubscriptionsAckHandler), notification|missingParamsOK), } func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo { @@ -1234,6 +1255,21 @@ func (cs *ClientSession) Unsubscribe(ctx context.Context, params *UnsubscribePar return err } +// SubscriptionsListen opens a SEP-2575 "subscriptions/listen" stream and +// blocks for the lifetime of the subscription. The server's first message on +// the stream is "notifications/subscriptions/acknowledged"; subsequent +// opted-in notifications (e.g. tools/list_changed) are delivered through the +// usual handlers registered in [ClientOptions]. +func (cs *ClientSession) subscriptionsListen(ctx context.Context, params *SubscriptionsListenParams) error { + params = injectRequestMeta(cs, params) + _, err := handleSend[*emptyResult](ctx, methodSubscriptionsListen, newClientRequest(cs, orZero[Params](params))) + return err +} + +func (c *Client) callSubscriptionsAckHandler(context.Context, *ClientRequest[*SubscriptionsAcknowledgedParams]) (Result, error) { + return nil, nil +} + func (c *Client) callToolChangedHandler(ctx context.Context, req *ToolListChangedRequest) (Result, error) { if h := c.opts.ToolListChangedHandler; h != nil { h(ctx, req) diff --git a/mcp/protocol.go b/mcp/protocol.go index 4ba8600e..6eca3698 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -1883,6 +1883,49 @@ type ResourceUpdatedNotificationParams struct { func (x *ResourceUpdatedNotificationParams) isParams() {} func (x *ResourceUpdatedNotificationParams) isNil() bool { return x == nil } +// NotificationSubscriptions describes the set of server-to-client +// notifications a client wishes to receive on a [SubscriptionsListenParams] +// stream. Each field is an explicit opt-in: a server MUST NOT push +// notifications of a type the client did not request. +type NotificationSubscriptions struct { + // ToolsListChanged opts in to "notifications/tools/list_changed". + ToolsListChanged bool `json:"toolsListChanged,omitempty"` + // PromptsListChanged opts in to "notifications/prompts/list_changed". + PromptsListChanged bool `json:"promptsListChanged,omitempty"` + // ResourcesListChanged opts in to "notifications/resources/list_changed". + ResourcesListChanged bool `json:"resourcesListChanged,omitempty"` + // ResourceSubscriptions enumerates the resource URIs for which the client + // wants "notifications/resources/updated". Replaces the legacy + // resources/subscribe RPC. + ResourceSubscriptions []string `json:"resourceSubscriptions,omitempty"` +} + +// SubscriptionsListenParams are the parameters for the +// "subscriptions/listen" RPC. +type SubscriptionsListenParams struct { + // Meta carries the per-request `_meta` triple. + Meta `json:"_meta,omitempty"` + // Notifications declares which notification types the client wants to + // receive on this stream. + Notifications NotificationSubscriptions `json:"notifications"` +} + +func (x *SubscriptionsListenParams) isParams() {} +func (x *SubscriptionsListenParams) isNil() bool { return x == nil } + +// SubscriptionsAcknowledgedParams are the parameters for the +// "notifications/subscriptions/acknowledged" notification, which the server +// MUST send as the first message on a subscriptions/listen stream. It carries +// the subset of the requested [NotificationSubscriptions] that the server has +// agreed to honor. +type SubscriptionsAcknowledgedParams struct { + Meta `json:"_meta,omitempty"` + Notifications NotificationSubscriptions `json:"notifications"` +} + +func (x *SubscriptionsAcknowledgedParams) isParams() {} +func (x *SubscriptionsAcknowledgedParams) isNil() bool { return x == nil } + // TODO(jba): add CompleteRequest and related types. // A request from the server to elicit additional information from the user via the client. @@ -2086,8 +2129,10 @@ const ( notificationRootsListChanged = "notifications/roots/list_changed" methodSetLevel = "logging/setLevel" methodSubscribe = "resources/subscribe" + methodSubscriptionsListen = "subscriptions/listen" notificationToolListChanged = "notifications/tools/list_changed" methodUnsubscribe = "resources/unsubscribe" + notificationSubscriptionsAck = "notifications/subscriptions/acknowledged" ) // Per-request _meta field names for the >= 2026-06-30 protocol version. @@ -2104,6 +2149,9 @@ const ( MetaKeyClientCapabilities = "io.modelcontextprotocol/clientCapabilities" // MetaKeyLogLevel identifies the desired log level for the request. MetaKeyLogLevel = "io.modelcontextprotocol/logLevel" + // MetaKeySubscriptionID identifies the subscriptions/listen request that an + // out-of-band notification belongs to. + MetaKeySubscriptionID = "io.modelcontextprotocol/subscriptionId" ) // UnsupportedProtocolVersionData is the SEP-2575 payload carried in the diff --git a/mcp/requests.go b/mcp/requests.go index 36368c99..64414caa 100644 --- a/mcp/requests.go +++ b/mcp/requests.go @@ -19,6 +19,7 @@ type ( ReadResourceRequest = ServerRequest[*ReadResourceParams] RootsListChangedRequest = ServerRequest[*RootsListChangedParams] SubscribeRequest = ServerRequest[*SubscribeParams] + SubscriptionsListenRequest = ServerRequest[*SubscriptionsListenParams] UnsubscribeRequest = ServerRequest[*UnsubscribeParams] ) diff --git a/mcp/server.go b/mcp/server.go index adcf698f..d4af3761 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -53,7 +53,10 @@ type Server struct { sendingMethodHandler_ MethodHandler receivingMethodHandler_ MethodHandler resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool - pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send + toolsSubscriptions map[*ServerSession]map[jsonrpc.ID]bool + promptsSubscriptions map[*ServerSession]map[jsonrpc.ID]bool + resourcesSubscriptions map[*ServerSession]map[jsonrpc.ID]bool + pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send } // ServerOptions is used to configure behavior of the server. @@ -204,6 +207,9 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { sendingMethodHandler_: defaultSendingMethodHandler, receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], resourceSubscriptions: make(map[string]map[*ServerSession]bool), + toolsSubscriptions: make(map[*ServerSession]map[jsonrpc.ID]bool), + promptsSubscriptions: make(map[*ServerSession]map[jsonrpc.ID]bool), + resourcesSubscriptions: make(map[*ServerSession]map[jsonrpc.ID]bool), pendingNotifications: make(map[string]*time.Timer), } s.AddReceivingMiddleware(serverMultiRoundTripMiddleware()) @@ -643,12 +649,11 @@ func (s *Server) complete(ctx context.Context, req *CompleteRequest) (*CompleteR return s.opts.CompletionHandler(ctx, req) } -// Map from notification name to its corresponding params. The params have no fields, -// so a single struct can be reused. -var changeNotificationParams = map[string]Params{ - notificationToolListChanged: &ToolListChangedParams{}, - notificationPromptListChanged: &PromptListChangedParams{}, - notificationResourceListChanged: &ResourceListChangedParams{}, +// Map from notification name to its corresponding params. +var changeNotificationParams = map[string]func() Params{ + notificationToolListChanged: func() Params { return &ToolListChangedParams{} }, + notificationPromptListChanged: func() Params { return &PromptListChangedParams{} }, + notificationResourceListChanged: func() Params { return &ResourceListChangedParams{} }, } // How long to wait before sending a change notification. @@ -687,7 +692,47 @@ func (s *Server) notifySessions(n string) { sessions := slices.Clone(s.sessions) s.pendingNotifications[n] = nil s.mu.Unlock() // Don't hold the lock during notification: it causes deadlock. - notifySessions(sessions, n, changeNotificationParams[n], s.opts.Logger) + + // Only add legacy sessions for the notification, new ones use the new notification mechanism. + var legacySessions []*ServerSession + for _, s := range sessions { + if s.InitializeParams().isNil() || s.InitializeParams().ProtocolVersion < protocolVersion20260630 { + legacySessions = append(legacySessions, s) + } + } + notifySessions(legacySessions, n, changeNotificationParams[n](), s.opts.Logger) + + var activeSubscriptions map[*ServerSession]map[jsonrpc.ID]bool + switch n { + case notificationToolListChanged: + activeSubscriptions = s.toolsSubscriptions + case notificationPromptListChanged: + activeSubscriptions = s.promptsSubscriptions + case notificationResourceListChanged: + activeSubscriptions = s.resourcesSubscriptions + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + for session := range activeSubscriptions { + for reqID := range activeSubscriptions[session] { + params := changeNotificationParams[n]() + setSubscriptionID(params, reqID) + req := newRequest(session, params) + ctx = context.WithValue(ctx, idContextKey{}, reqID) + if err := handleNotify(ctx, n, req); err != nil { + s.opts.Logger.Warn(fmt.Sprintf("calling %s: %v", n, err)) + } + } + } +} + +func setSubscriptionID(params Params, reqID jsonrpc.ID) { + m := params.GetMeta() + if m == nil { + m = map[string]any{} + } + m[MetaKeySubscriptionID] = reqID.Raw() + params.SetMeta(m) } // shouldSendListChangedNotification checks if the server's capabilities allow @@ -1027,6 +1072,80 @@ func (s *Server) unsubscribe(ctx context.Context, req *UnsubscribeRequest) (*emp return &emptyResult{}, nil } +func (s *Server) addSubscription(index map[*ServerSession]map[jsonrpc.ID]bool, ss *ServerSession, id jsonrpc.ID) { + s.mu.Lock() + defer s.mu.Unlock() + ids := index[ss] + if ids == nil { + ids = make(map[jsonrpc.ID]bool) + index[ss] = ids + } + ids[id] = true +} + +func (s *Server) removeSubscription(index map[*ServerSession]map[jsonrpc.ID]bool, ss *ServerSession, id jsonrpc.ID) { + s.mu.Lock() + defer s.mu.Unlock() + ids, ok := index[ss] + if !ok { + return + } + delete(ids, id) + if len(ids) == 0 { + delete(index, ss) + } +} + +func (s *Server) subscriptionsListen(ctx context.Context, req *SubscriptionsListenRequest) (*emptyResult, error) { + requestID, ok := ctx.Value(idContextKey{}).(jsonrpc.ID) + if !ok || !requestID.IsValid() { + return nil, fmt.Errorf("%w: subscriptions/listen requires a request ID", jsonrpc2.ErrInvalidRequest) + } + + allowed := s.allowedSubscriptions(req.Params.Notifications) + if allowed.ToolsListChanged { + s.addSubscription(s.toolsSubscriptions, req.Session, requestID) + } + if allowed.PromptsListChanged { + s.addSubscription(s.promptsSubscriptions, req.Session, requestID) + } + if allowed.ResourcesListChanged { + s.addSubscription(s.resourcesSubscriptions, req.Session, requestID) + } + + ackParams := &SubscriptionsAcknowledgedParams{ + Notifications: allowed, + Meta: Meta{MetaKeySubscriptionID: requestID.Raw()}, + } + if err := req.Session.NotifySubscriptionAcked(ctx, ackParams); err != nil { + return nil, fmt.Errorf("sending subscriptions/acknowledged: %w", err) + } + + defer func() { + s.removeSubscription(s.toolsSubscriptions, req.Session, requestID) + s.removeSubscription(s.promptsSubscriptions, req.Session, requestID) + s.removeSubscription(s.resourcesSubscriptions, req.Session, requestID) + }() + + <-ctx.Done() + return &emptyResult{}, nil +} + +func (s *Server) allowedSubscriptions(want NotificationSubscriptions) NotificationSubscriptions { + caps := s.capabilities() + agreed := NotificationSubscriptions{} + if want.ToolsListChanged && caps.Tools != nil && caps.Tools.ListChanged { + agreed.ToolsListChanged = true + } + if want.PromptsListChanged && caps.Prompts != nil && caps.Prompts.ListChanged { + agreed.PromptsListChanged = true + } + if want.ResourcesListChanged && caps.Resources != nil && caps.Resources.ListChanged { + agreed.ResourcesListChanged = true + } + return agreed +} + // Run runs the server over the given transport, which must be persistent. // // Run blocks until the client terminates the connection or the provided @@ -1096,6 +1215,10 @@ func (s *Server) disconnect(cc *ServerSession) { for _, subscribedSessions := range s.resourceSubscriptions { delete(subscribedSessions, cc) } + delete(s.toolsSubscriptions, cc) + delete(s.promptsSubscriptions, cc) + delete(s.resourcesSubscriptions, cc) + s.opts.Logger.Info("server session disconnected", "session_id", cc.ID()) } @@ -1196,6 +1319,12 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot return handleNotify(ctx, notificationProgress, newServerRequest(ss, orZero[Params](params))) } +// NotifySubscriptionAcked sends a subscription acknowledged notification from the server to the client +// associated with this session. +func (ss *ServerSession) NotifySubscriptionAcked(ctx context.Context, params *SubscriptionsAcknowledgedParams) error { + return handleNotify(ctx, notificationSubscriptionsAck, newServerRequest(ss, orZero[Params](params))) +} + func newServerRequest[P Params](ss *ServerSession, params P) *ServerRequest[P] { return &ServerRequest[P]{Session: ss, Params: params} } @@ -1497,6 +1626,7 @@ var serverMethodInfos = map[string]methodInfo{ methodReadResource: newServerMethodInfo(serverMethod((*Server).readResource), 0), methodSetLevel: newServerMethodInfo(serverSessionMethod((*ServerSession).setLevel), 0), methodSubscribe: newServerMethodInfo(serverMethod((*Server).subscribe), 0), + methodSubscriptionsListen: newServerMethodInfo(serverMethod((*Server).subscriptionsListen), 0), methodUnsubscribe: newServerMethodInfo(serverMethod((*Server).unsubscribe), 0), notificationCancelled: newServerMethodInfo(serverSessionMethod((*ServerSession).cancel), notification|missingParamsOK), notificationInitialized: newServerMethodInfo(serverSessionMethod((*ServerSession).initialized), notification|missingParamsOK), @@ -1577,7 +1707,7 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, } switch req.Method { - case methodInitialize, methodPing, notificationInitialized, methodSetLevel: + case methodInitialize, methodPing, notificationInitialized, methodSetLevel, methodSubscribe, methodUnsubscribe: if validatedMeta.usesNewProtocol { ss.server.opts.Logger.Error("method removed in the new protocol", "method", req.Method) return nil, &jsonrpc.Error{ @@ -1663,7 +1793,11 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error // It should never be invoked in practice because cancellation is preempted, // but having its signature here facilitates the construction of methodInfo // that can be used to validate incoming cancellation notifications. -func (ss *ServerSession) cancel(context.Context, *CancelledParams) (Result, error) { +func (ss *ServerSession) cancel(ctx context.Context, param *CancelledParams) (Result, error) { + server := ss.server + server.removeSubscription(server.toolsSubscriptions, ss, param.RequestID.(jsonrpc.ID)) + server.removeSubscription(server.promptsSubscriptions, ss, param.RequestID.(jsonrpc.ID)) + server.removeSubscription(server.resourcesSubscriptions, ss, param.RequestID.(jsonrpc.ID)) return nil, nil } @@ -1696,6 +1830,14 @@ func (ss *ServerSession) Close() error { ss.onClose() } + // Clean up session subscriptions + server := ss.server + server.mu.Lock() + delete(server.toolsSubscriptions, ss) + delete(server.promptsSubscriptions, ss) + delete(server.resourcesSubscriptions, ss) + server.mu.Unlock() + return err } diff --git a/mcp/streamable.go b/mcp/streamable.go index 5bc31771..2c851dd3 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1365,6 +1365,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques calls := make(map[jsonrpc.ID]struct{}) tokenInfo := auth.TokenInfoFromContext(req.Context()) isInitialize := false + isSubscriptionsListen := false var initializeProtocolVersion string headerVersion := protocolVersionFromContext(req.Context()) for _, msg := range incoming { @@ -1391,6 +1392,9 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques initializeProtocolVersion = params.ProtocolVersion } } + if jreq.Method == methodSubscriptionsListen { + isSubscriptionsListen = true + } // SEP-2575: requests carrying `_meta.protocolVersion` require the // Mcp-Protocol-Version HTTP header to be present and to match the // per-request `_meta.protocolVersion` value. @@ -1528,13 +1532,20 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques return } + // subscriptions/listen is inherently a long-lived SSE endpoint (SEP-2575): + // it has no synchronous result, the response stream stays open until the + // client cancels, and the server pushes notifications on it as they occur. + // Force SSE mode (bypassing JSONResponse) so the buffered application/json + // path doesn't deadlock waiting for a completion that won't come. + useSSE := !c.jsonResponse || isSubscriptionsListen + // Set response headers. Accept was checked in [StreamableHTTPHandler]. w.Header().Set("Cache-Control", "no-cache, no-transform") - if c.jsonResponse { - w.Header().Set("Content-Type", "application/json") - } else { + if useSSE { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Connection", "keep-alive") + } else { + w.Header().Set("Content-Type", "application/json") } if c.sessionID != "" && isInitialize { w.Header().Set(sessionIDHeader, c.sessionID) @@ -1545,7 +1556,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques done := make(chan struct{}) stream.done = done stream.protocolVersion = effectiveVersion - if c.jsonResponse { + if !useSSE { // JSON mode: collect messages in pendingJSONMessages until done. // Set pendingJSONMessages to a non-nil value to signal that this is an // application/json stream. @@ -1571,6 +1582,17 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques c.logger.Warn(fmt.Sprintf("Writing priming event: %v", err)) } } + // For subscriptions/listen, flush headers eagerly so the client sees + // the open SSE stream even when an HTTP/2-aware reverse proxy is + // buffering the HEADERS frame waiting for a DATA frame to coalesce + // with. The acknowledgment notification will follow shortly, but the + // client can begin reading event-stream framing immediately. See the + // equivalent comment in [streamableServerConn.acquireStream]. + if isSubscriptionsListen { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, ": ok\n\n") + _ = http.NewResponseController(w).Flush() + } } // TODO(rfindley): if we have no event store, we should really cancel all @@ -2460,7 +2482,8 @@ func (c *streamableClientConn) processStream(ctx context.Context, requestSummary if jsonResp, ok := msg.(*jsonrpc.Response); ok && forCall != nil { // TODO: we should never get a response when forReq is nil (the standalone SSE request). // We should detect this case. - if jsonResp.ID == forCall.ID { + // The subscriptions/listen is now returning a response in the SSE stream. + if jsonResp.ID == forCall.ID && forCall.Method != methodSubscriptionsListen { return "", 0, true } } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 2efee20e..6e19336f 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -3716,3 +3716,137 @@ func TestStreamableHTTP_E2E_DiscoverSuccess(t *testing.T) { t.Errorf("CallTool result[0] = %+v, want TextContent{Text:\"hello\"}", res.Content[0]) } } + +// TestSubscriptionsListen_Streamable opens two concurrent subscriptions/listen +// streams on the same client session against a stateless HTTP server and +// verifies the full SEP-2575 contract end-to-end as observed by the client: +// +// - both listens are acknowledged, each tagged with its own request ID; +// - opt-in filtering: tool-list-changed reaches only the tools listen and +// prompt-list-changed reaches only the prompts listen; +// - the underlying SSE streams persist across multiple unrelated changes so +// notifications keep arriving until the session is closed; +// - subscription-ID tagging: every notification carries the originating +// listen's request ID in `_meta`, matching the ack's tag for that listen; +// - closing the client session ends all deliveries. +func TestSubscriptionsListen_Streamable(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append([]string{protocolVersion20260630}, slices.Clone(orig)...) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "t1"}, sayHi) + server.AddPrompt(&Prompt{Name: "p1"}, nil) + + handler := NewStreamableHTTPHandler( + func(*http.Request) *Server { return server }, + &StreamableHTTPOptions{Stateless: true}, + ) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + + type event struct { + kind string + id string + } + events := make(chan event, 64) + asEvent := func(kind string, raw any) event { return event{kind, fmt.Sprint(raw)} } + + client := NewClient(testImpl, &ClientOptions{ + ToolListChangedHandler: func(_ context.Context, req *ToolListChangedRequest) { + events <- asEvent("tool", req.Params.Meta[MetaKeySubscriptionID]) + }, + PromptListChangedHandler: func(_ context.Context, req *PromptListChangedRequest) { + events <- asEvent("prompt", req.Params.Meta[MetaKeySubscriptionID]) + }, + }) + client.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method == notificationSubscriptionsAck { + if cr, ok := req.(*ClientRequest[*SubscriptionsAcknowledgedParams]); ok && cr.Params != nil { + events <- asEvent("ack", cr.Params.Meta[MetaKeySubscriptionID]) + } + } + return next(ctx, method, req) + } + }) + + cs, err := client.Connect(context.Background(), &StreamableClientTransport{Endpoint: httpServer.URL}, + &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + + startListen := func(notifs NotificationSubscriptions) { + go func() { + _ = cs.SubscriptionsListen(context.Background(), &SubscriptionsListenParams{Notifications: notifs}) + }() + } + waitFor := func(kind string) event { + t.Helper() + select { + case e := <-events: + if e.kind != kind { + t.Fatalf("got event %q (id=%s), want kind %q", e.kind, e.id, kind) + } + return e + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for %q event", kind) + return event{} + } + } + expectNoEvent := func(d time.Duration) { + t.Helper() + select { + case e := <-events: + t.Fatalf("unexpected event %q (id=%s)", e.kind, e.id) + case <-time.After(d): + } + } + + startListen(NotificationSubscriptions{ToolsListChanged: true}) + startListen(NotificationSubscriptions{PromptsListChanged: true}) + + ack1 := waitFor("ack") + ack2 := waitFor("ack") + if ack1.id == ack2.id { + t.Fatalf("both acks share subscription ID %s; expected distinct per-listen IDs", ack1.id) + } + + // Trigger a tool change. The notification's subscription ID identifies + // which of the two acks belongs to the tools listen; the other ack + // therefore belongs to the prompts listen. + server.AddTool(&Tool{Name: "t2", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + toolEv := waitFor("tool") + if toolEv.id != ack1.id && toolEv.id != ack2.id { + t.Fatalf("tool notif id %s matches neither ack (%s, %s)", toolEv.id, ack1.id, ack2.id) + } + toolSubID := toolEv.id + promptSubID := ack2.id + if ack1.id != toolSubID { + promptSubID = ack1.id + } + expectNoEvent(notificationDelay * 5) + + server.AddPrompt(&Prompt{Name: "p2"}, nil) + if e := waitFor("prompt"); e.id != promptSubID { + t.Errorf("prompt notif id = %s, want %s", e.id, promptSubID) + } + expectNoEvent(notificationDelay * 5) + + server.AddTool(&Tool{Name: "t3", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + if e := waitFor("tool"); e.id != toolSubID { + t.Errorf("second tool notif id = %s, want %s", e.id, toolSubID) + } + server.AddPrompt(&Prompt{Name: "p3"}, nil) + if e := waitFor("prompt"); e.id != promptSubID { + t.Errorf("second prompt notif id = %s, want %s", e.id, promptSubID) + } + expectNoEvent(notificationDelay * 5) + + cs.Close() + time.Sleep(50 * time.Millisecond) + server.AddTool(&Tool{Name: "t4", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + server.AddPrompt(&Prompt{Name: "p4"}, nil) + expectNoEvent(notificationDelay * 20) +} diff --git a/mcp/transport_test.go b/mcp/transport_test.go index 515b8c19..d882998c 100644 --- a/mcp/transport_test.go +++ b/mcp/transport_test.go @@ -6,10 +6,14 @@ package mcp import ( "context" + "fmt" "io" + "slices" "strings" "testing" + "time" + "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -124,3 +128,150 @@ func TestIOConnRead(t *testing.T) { }) } } + +// TestSubscriptionsListen_InMemory verifies SEP-2575 subscriptions/listen +// over a single shared session (in-memory transport, semantically equivalent +// to STDIO). It exercises behavior that is harder to observe over streamable +// HTTP, where each listen lives in its own ephemeral session: +// +// - two concurrent listens on the SAME session both deliver notifications; +// - opt-in filtering: each listen receives only its opted-in notification +// types, tagged with its own subscription ID; +// - per-listen cancellation propagates over notifications/cancelled: when +// the client cancels one listen's context, the server stops fanning out +// notifications for that listen but keeps delivering to the other; +// - the remaining listen continues to work after the first cancellation. +func TestSubscriptionsListen_InMemory(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append([]string{protocolVersion20260630}, slices.Clone(orig)...) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx, topCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer topCancel() + + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "t1"}, sayHi) + server.AddPrompt(&Prompt{Name: "p1"}, nil) + + ct, st := NewInMemoryTransports() + ss, err := server.Connect(ctx, st, nil) + if err != nil { + t.Fatalf("server connect: %v", err) + } + defer ss.Close() + + type event struct { + kind string + id string + } + events := make(chan event, 64) + asEvent := func(kind string, raw any) event { return event{kind, fmt.Sprint(raw)} } + + client := NewClient(testImpl, &ClientOptions{ + ToolListChangedHandler: func(_ context.Context, req *ToolListChangedRequest) { + events <- asEvent("tool", req.Params.Meta[MetaKeySubscriptionID]) + }, + PromptListChangedHandler: func(_ context.Context, req *PromptListChangedRequest) { + events <- asEvent("prompt", req.Params.Meta[MetaKeySubscriptionID]) + }, + }) + client.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method == notificationSubscriptionsAck { + if cr, ok := req.(*ClientRequest[*SubscriptionsAcknowledgedParams]); ok && cr.Params != nil { + events <- asEvent("ack", cr.Params.Meta[MetaKeySubscriptionID]) + } + } + return next(ctx, method, req) + } + }) + + cs, err := client.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + defer cs.Close() + + startListen := func(notifs NotificationSubscriptions) context.CancelFunc { + lctx, cancel := context.WithCancel(ctx) + go func() { + _ = cs.SubscriptionsListen(lctx, &SubscriptionsListenParams{Notifications: notifs}) + }() + return cancel + } + waitFor := func(kind string) event { + t.Helper() + select { + case e := <-events: + if e.kind != kind { + t.Fatalf("got event %q (id=%s), want kind %q", e.kind, e.id, kind) + } + return e + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for %q event", kind) + return event{} + } + } + expectNoEvent := func(d time.Duration) { + t.Helper() + select { + case e := <-events: + t.Fatalf("unexpected event %q (id=%s)", e.kind, e.id) + case <-time.After(d): + } + } + + cancelTools := startListen(NotificationSubscriptions{ToolsListChanged: true}) + cancelPrompts := startListen(NotificationSubscriptions{PromptsListChanged: true}) + + ack1 := waitFor("ack") + ack2 := waitFor("ack") + if ack1.id == ack2.id { + t.Fatalf("both acks share subscription ID %s; expected distinct per-listen IDs", ack1.id) + } + + // Identify which ack belongs to which listen by triggering a tool change + // and observing the tagged subscription ID on the notification. + server.AddTool(&Tool{Name: "t2", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + toolEv := waitFor("tool") + if toolEv.id != ack1.id && toolEv.id != ack2.id { + t.Fatalf("tool notif id %s matches neither ack (%s, %s)", toolEv.id, ack1.id, ack2.id) + } + toolSubID := toolEv.id + promptSubID := ack2.id + if ack1.id != toolSubID { + promptSubID = ack1.id + } + expectNoEvent(notificationDelay * 5) + + server.AddPrompt(&Prompt{Name: "p2"}, nil) + if e := waitFor("prompt"); e.id != promptSubID { + t.Errorf("prompt notif id = %s, want %s", e.id, promptSubID) + } + expectNoEvent(notificationDelay * 5) + + // Cancel the tools listen. The SDK sends a notifications/cancelled to the + // server, which flips the listen handler's ctx, which unblocks the + // goroutine that removes the subscription from the server's index. + cancelTools() + + // Give the cancellation a moment to propagate (notifications/cancelled + // → server-side cancel → cleanup goroutine). + time.Sleep(50 * time.Millisecond) + + // A new tool change must NOT reach the (cancelled) tools listen, while + // the prompts listen continues to receive its notifications. + server.AddTool(&Tool{Name: "t3", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + expectNoEvent(notificationDelay * 20) + + server.AddPrompt(&Prompt{Name: "p3"}, nil) + if e := waitFor("prompt"); e.id != promptSubID { + t.Errorf("prompt notif after tools-cancel id = %s, want %s", e.id, promptSubID) + } + + cancelPrompts() + time.Sleep(50 * time.Millisecond) + + server.AddPrompt(&Prompt{Name: "p4"}, nil) + expectNoEvent(notificationDelay * 20) +} From 0066ac920db4f16ef3a1b2501f02a499641f08e5 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 11 Jun 2026 15:55:27 +0000 Subject: [PATCH 2/7] refactor: remove subscription-based notification tests and cleanup unused imports --- mcp/client.go | 14 ++- mcp/mcp_test.go | 201 +++++++++++++++++++++++++++++++++++++++++ mcp/server.go | 12 ++- mcp/streamable_test.go | 134 --------------------------- mcp/transport_test.go | 151 ------------------------------- 5 files changed, 219 insertions(+), 293 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index e40f148a..fff96859 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -313,11 +313,13 @@ func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOp if subscribeParams.Notifications.ToolsListChanged || subscribeParams.Notifications.PromptsListChanged || subscribeParams.Notifications.ResourcesListChanged { - // The listen blocks until the server cancels it (or the - // underlying connection ends). Run it in a goroutine so - // Connect can return; the SDK retires it on cs.Close(). + // The listen blocks until the server cancels it. Run it in + // a goroutine so Connect can return; ClientSession.Close + // cancels its context to send notifications/cancelled. + listenCtx, cancelListen := context.WithCancel(context.Background()) + cs.listenCancel = cancelListen go func() { - _ = cs.subscriptionsListen(context.Background(), subscribeParams) + _ = cs.subscriptionsListen(listenCtx, subscribeParams) }() } return cs, nil @@ -436,6 +438,7 @@ type ClientSession struct { conn *jsonrpc2.Connection client *Client keepaliveCancel context.CancelFunc + listenCancel context.CancelFunc mcpConn Connection // No mutex is (currently) required to guard the session state, because it is @@ -518,6 +521,9 @@ func (cs *ClientSession) Close() error { if cs.keepaliveCancel != nil { cs.keepaliveCancel() } + if cs.listenCancel != nil { + cs.listenCancel() + } err := cs.conn.Close() if cs.onClose != nil && cs.calledOnClose.CompareAndSwap(false, true) { diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 5c8e7d12..02f05618 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -12,6 +12,8 @@ import ( "fmt" "io" "log/slog" + "net/http" + "net/http/httptest" "net/url" "path/filepath" "runtime" @@ -2449,3 +2451,202 @@ func TestSetErrorPreservesContent(t *testing.T) { } var ctrCmpOpts = []cmp.Option{cmpopts.IgnoreUnexported(CallToolResult{}, GetPromptResult{}, ReadResourceResult{})} + +// runSubscriptionsListenTest exercises the SEP-2575 auto-listen flow end-to-end +// against the supplied transport pair. It captures every notification and the +// acknowledgment the client sees, then asserts: +// +// - the auto-listen issued by Client.Connect is acknowledged with a tagged +// subscription ID; +// - tool and prompt list-changed notifications are delivered to the matching +// handlers, each carrying the same subscription ID as the ack; +// - the subscription persists across multiple unrelated changes; +// - cs.Close() ends the subscription and further changes don't deliver. +func runSubscriptionsListenTest(t *testing.T, client *Client, server *Server, ct Transport, events chan subListenEvent) { + t.Helper() + + ctx, topCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer topCancel() + + cs, err := client.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + + waitFor := func(kind string) subListenEvent { + t.Helper() + select { + case e := <-events: + if e.kind != kind { + t.Fatalf("got event %q (id=%s), want kind %q", e.kind, e.id, kind) + } + return e + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for %q event", kind) + return subListenEvent{} + } + } + expectNoEvent := func(d time.Duration) { + t.Helper() + select { + case e := <-events: + t.Fatalf("unexpected event %q (id=%s)", e.kind, e.id) + case <-time.After(d): + } + } + + ack := waitFor("ack") + if ack.id == "" { + t.Fatalf("acknowledgment missing subscription ID") + } + + server.AddTool(&Tool{Name: "t2", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + if e := waitFor("tool"); e.id != ack.id { + t.Errorf("first tool notif id = %s, want %s", e.id, ack.id) + } + + server.AddPrompt(&Prompt{Name: "p2"}, nil) + if e := waitFor("prompt"); e.id != ack.id { + t.Errorf("first prompt notif id = %s, want %s", e.id, ack.id) + } + + server.AddTool(&Tool{Name: "t3", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + if e := waitFor("tool"); e.id != ack.id { + t.Errorf("second tool notif id = %s, want %s", e.id, ack.id) + } + server.AddPrompt(&Prompt{Name: "p3"}, nil) + if e := waitFor("prompt"); e.id != ack.id { + t.Errorf("second prompt notif id = %s, want %s", e.id, ack.id) + } + expectNoEvent(notificationDelay * 5) + + cs.Close() + time.Sleep(50 * time.Millisecond) + + server.AddTool(&Tool{Name: "t4", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + server.AddPrompt(&Prompt{Name: "p4"}, nil) + expectNoEvent(notificationDelay * 20) +} + +type subListenEvent struct { + kind string // "ack", "tool", "prompt" + id string // subscription ID from _meta, stringified for cross-encoding equality +} + +// newSubListenClient returns a client wired to push every ack and every +// list-changed notification it receives into events, tagged with the kind +// and the subscription ID extracted from _meta. +func newSubListenClient(events chan subListenEvent) *Client { + asEvent := func(kind string, raw any) subListenEvent { + return subListenEvent{kind, fmt.Sprint(raw)} + } + c := NewClient(testImpl, &ClientOptions{ + ToolListChangedHandler: func(_ context.Context, req *ToolListChangedRequest) { + events <- asEvent("tool", req.Params.Meta[MetaKeySubscriptionID]) + }, + PromptListChangedHandler: func(_ context.Context, req *PromptListChangedRequest) { + events <- asEvent("prompt", req.Params.Meta[MetaKeySubscriptionID]) + }, + }) + c.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method == notificationSubscriptionsAck { + if cr, ok := req.(*ClientRequest[*SubscriptionsAcknowledgedParams]); ok && cr.Params != nil { + events <- asEvent("ack", cr.Params.Meta[MetaKeySubscriptionID]) + } + } + return next(ctx, method, req) + } + }) + return c +} + +func newSubListenServer() *Server { + s := NewServer(testImpl, nil) + AddTool(s, &Tool{Name: "t1"}, sayHi) + s.AddPrompt(&Prompt{Name: "p1"}, nil) + return s +} + +func enableNewProtocol(t *testing.T) { + t.Helper() + orig := supportedProtocolVersions + supportedProtocolVersions = append([]string{protocolVersion20260630}, slices.Clone(orig)...) + t.Cleanup(func() { supportedProtocolVersions = orig }) +} + +// TestSubscriptionsListen_InMemory exercises the listen flow over the +// session-shared in-memory transport (semantically equivalent to STDIO). +// Cancellation here propagates via notifications/cancelled. +func TestSubscriptionsListen_InMemory(t *testing.T) { + enableNewProtocol(t) + events := make(chan subListenEvent, 64) + server := newSubListenServer() + ct, st := NewInMemoryTransports() + ss, err := server.Connect(context.Background(), st, nil) + if err != nil { + t.Fatalf("server connect: %v", err) + } + defer ss.Close() + runSubscriptionsListenTest(t, newSubListenClient(events), server, ct, events) +} + +// TestSubscriptionsListen_Streamable exercises the listen flow over a +// stateless HTTP server (SEP-2575). Each listen rides its own SSE response +// stream; cs.Close() tears it down. +func TestSubscriptionsListen_Streamable(t *testing.T) { + enableNewProtocol(t) + events := make(chan subListenEvent, 64) + server := newSubListenServer() + handler := NewStreamableHTTPHandler( + func(*http.Request) *Server { return server }, + &StreamableHTTPOptions{Stateless: true}, + ) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + runSubscriptionsListenTest(t, newSubListenClient(events), server, + &StreamableClientTransport{Endpoint: httpServer.URL}, events) +} + +// TestSubscriptionsListen_NoHandlersNoListen verifies that a new-protocol +// client without any list-changed handlers registered does not open an +// auto-listen on connect, and therefore does not receive any acknowledgment +// or downstream notifications. +func TestSubscriptionsListen_NoHandlersNoListen(t *testing.T) { + enableNewProtocol(t) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + events := make(chan subListenEvent, 8) + server := newSubListenServer() + ct, st := NewInMemoryTransports() + ss, err := server.Connect(ctx, st, nil) + if err != nil { + t.Fatalf("server connect: %v", err) + } + defer ss.Close() + + c := NewClient(testImpl, nil) + c.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method == notificationSubscriptionsAck { + events <- subListenEvent{"ack", ""} + } + return next(ctx, method, req) + } + }) + cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + defer cs.Close() + + server.AddTool(&Tool{Name: "t2", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + + select { + case e := <-events: + t.Fatalf("unexpected event %q on no-handler client", e.kind) + case <-time.After(notificationDelay * 10): + } +} diff --git a/mcp/server.go b/mcp/server.go index d4af3761..7cf2a9c5 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -1793,11 +1793,15 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error // It should never be invoked in practice because cancellation is preempted, // but having its signature here facilitates the construction of methodInfo // that can be used to validate incoming cancellation notifications. -func (ss *ServerSession) cancel(ctx context.Context, param *CancelledParams) (Result, error) { +func (ss *ServerSession) cancel(_ context.Context, param *CancelledParams) (Result, error) { + id, err := jsonrpc.MakeID(param.RequestID) + if err != nil { + return nil, nil + } server := ss.server - server.removeSubscription(server.toolsSubscriptions, ss, param.RequestID.(jsonrpc.ID)) - server.removeSubscription(server.promptsSubscriptions, ss, param.RequestID.(jsonrpc.ID)) - server.removeSubscription(server.resourcesSubscriptions, ss, param.RequestID.(jsonrpc.ID)) + server.removeSubscription(server.toolsSubscriptions, ss, id) + server.removeSubscription(server.promptsSubscriptions, ss, id) + server.removeSubscription(server.resourcesSubscriptions, ss, id) return nil, nil } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 6e19336f..2efee20e 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -3716,137 +3716,3 @@ func TestStreamableHTTP_E2E_DiscoverSuccess(t *testing.T) { t.Errorf("CallTool result[0] = %+v, want TextContent{Text:\"hello\"}", res.Content[0]) } } - -// TestSubscriptionsListen_Streamable opens two concurrent subscriptions/listen -// streams on the same client session against a stateless HTTP server and -// verifies the full SEP-2575 contract end-to-end as observed by the client: -// -// - both listens are acknowledged, each tagged with its own request ID; -// - opt-in filtering: tool-list-changed reaches only the tools listen and -// prompt-list-changed reaches only the prompts listen; -// - the underlying SSE streams persist across multiple unrelated changes so -// notifications keep arriving until the session is closed; -// - subscription-ID tagging: every notification carries the originating -// listen's request ID in `_meta`, matching the ack's tag for that listen; -// - closing the client session ends all deliveries. -func TestSubscriptionsListen_Streamable(t *testing.T) { - orig := supportedProtocolVersions - supportedProtocolVersions = append([]string{protocolVersion20260630}, slices.Clone(orig)...) - t.Cleanup(func() { supportedProtocolVersions = orig }) - - server := NewServer(testImpl, nil) - AddTool(server, &Tool{Name: "t1"}, sayHi) - server.AddPrompt(&Prompt{Name: "p1"}, nil) - - handler := NewStreamableHTTPHandler( - func(*http.Request) *Server { return server }, - &StreamableHTTPOptions{Stateless: true}, - ) - httpServer := httptest.NewServer(mustNotPanic(t, handler)) - defer httpServer.Close() - - type event struct { - kind string - id string - } - events := make(chan event, 64) - asEvent := func(kind string, raw any) event { return event{kind, fmt.Sprint(raw)} } - - client := NewClient(testImpl, &ClientOptions{ - ToolListChangedHandler: func(_ context.Context, req *ToolListChangedRequest) { - events <- asEvent("tool", req.Params.Meta[MetaKeySubscriptionID]) - }, - PromptListChangedHandler: func(_ context.Context, req *PromptListChangedRequest) { - events <- asEvent("prompt", req.Params.Meta[MetaKeySubscriptionID]) - }, - }) - client.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { - return func(ctx context.Context, method string, req Request) (Result, error) { - if method == notificationSubscriptionsAck { - if cr, ok := req.(*ClientRequest[*SubscriptionsAcknowledgedParams]); ok && cr.Params != nil { - events <- asEvent("ack", cr.Params.Meta[MetaKeySubscriptionID]) - } - } - return next(ctx, method, req) - } - }) - - cs, err := client.Connect(context.Background(), &StreamableClientTransport{Endpoint: httpServer.URL}, - &ClientSessionOptions{protocolVersion: protocolVersion20260630}) - if err != nil { - t.Fatalf("client connect: %v", err) - } - - startListen := func(notifs NotificationSubscriptions) { - go func() { - _ = cs.SubscriptionsListen(context.Background(), &SubscriptionsListenParams{Notifications: notifs}) - }() - } - waitFor := func(kind string) event { - t.Helper() - select { - case e := <-events: - if e.kind != kind { - t.Fatalf("got event %q (id=%s), want kind %q", e.kind, e.id, kind) - } - return e - case <-time.After(5 * time.Second): - t.Fatalf("timed out waiting for %q event", kind) - return event{} - } - } - expectNoEvent := func(d time.Duration) { - t.Helper() - select { - case e := <-events: - t.Fatalf("unexpected event %q (id=%s)", e.kind, e.id) - case <-time.After(d): - } - } - - startListen(NotificationSubscriptions{ToolsListChanged: true}) - startListen(NotificationSubscriptions{PromptsListChanged: true}) - - ack1 := waitFor("ack") - ack2 := waitFor("ack") - if ack1.id == ack2.id { - t.Fatalf("both acks share subscription ID %s; expected distinct per-listen IDs", ack1.id) - } - - // Trigger a tool change. The notification's subscription ID identifies - // which of the two acks belongs to the tools listen; the other ack - // therefore belongs to the prompts listen. - server.AddTool(&Tool{Name: "t2", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) - toolEv := waitFor("tool") - if toolEv.id != ack1.id && toolEv.id != ack2.id { - t.Fatalf("tool notif id %s matches neither ack (%s, %s)", toolEv.id, ack1.id, ack2.id) - } - toolSubID := toolEv.id - promptSubID := ack2.id - if ack1.id != toolSubID { - promptSubID = ack1.id - } - expectNoEvent(notificationDelay * 5) - - server.AddPrompt(&Prompt{Name: "p2"}, nil) - if e := waitFor("prompt"); e.id != promptSubID { - t.Errorf("prompt notif id = %s, want %s", e.id, promptSubID) - } - expectNoEvent(notificationDelay * 5) - - server.AddTool(&Tool{Name: "t3", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) - if e := waitFor("tool"); e.id != toolSubID { - t.Errorf("second tool notif id = %s, want %s", e.id, toolSubID) - } - server.AddPrompt(&Prompt{Name: "p3"}, nil) - if e := waitFor("prompt"); e.id != promptSubID { - t.Errorf("second prompt notif id = %s, want %s", e.id, promptSubID) - } - expectNoEvent(notificationDelay * 5) - - cs.Close() - time.Sleep(50 * time.Millisecond) - server.AddTool(&Tool{Name: "t4", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) - server.AddPrompt(&Prompt{Name: "p4"}, nil) - expectNoEvent(notificationDelay * 20) -} diff --git a/mcp/transport_test.go b/mcp/transport_test.go index d882998c..515b8c19 100644 --- a/mcp/transport_test.go +++ b/mcp/transport_test.go @@ -6,14 +6,10 @@ package mcp import ( "context" - "fmt" "io" - "slices" "strings" "testing" - "time" - "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -128,150 +124,3 @@ func TestIOConnRead(t *testing.T) { }) } } - -// TestSubscriptionsListen_InMemory verifies SEP-2575 subscriptions/listen -// over a single shared session (in-memory transport, semantically equivalent -// to STDIO). It exercises behavior that is harder to observe over streamable -// HTTP, where each listen lives in its own ephemeral session: -// -// - two concurrent listens on the SAME session both deliver notifications; -// - opt-in filtering: each listen receives only its opted-in notification -// types, tagged with its own subscription ID; -// - per-listen cancellation propagates over notifications/cancelled: when -// the client cancels one listen's context, the server stops fanning out -// notifications for that listen but keeps delivering to the other; -// - the remaining listen continues to work after the first cancellation. -func TestSubscriptionsListen_InMemory(t *testing.T) { - orig := supportedProtocolVersions - supportedProtocolVersions = append([]string{protocolVersion20260630}, slices.Clone(orig)...) - t.Cleanup(func() { supportedProtocolVersions = orig }) - - ctx, topCancel := context.WithTimeout(context.Background(), 30*time.Second) - defer topCancel() - - server := NewServer(testImpl, nil) - AddTool(server, &Tool{Name: "t1"}, sayHi) - server.AddPrompt(&Prompt{Name: "p1"}, nil) - - ct, st := NewInMemoryTransports() - ss, err := server.Connect(ctx, st, nil) - if err != nil { - t.Fatalf("server connect: %v", err) - } - defer ss.Close() - - type event struct { - kind string - id string - } - events := make(chan event, 64) - asEvent := func(kind string, raw any) event { return event{kind, fmt.Sprint(raw)} } - - client := NewClient(testImpl, &ClientOptions{ - ToolListChangedHandler: func(_ context.Context, req *ToolListChangedRequest) { - events <- asEvent("tool", req.Params.Meta[MetaKeySubscriptionID]) - }, - PromptListChangedHandler: func(_ context.Context, req *PromptListChangedRequest) { - events <- asEvent("prompt", req.Params.Meta[MetaKeySubscriptionID]) - }, - }) - client.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { - return func(ctx context.Context, method string, req Request) (Result, error) { - if method == notificationSubscriptionsAck { - if cr, ok := req.(*ClientRequest[*SubscriptionsAcknowledgedParams]); ok && cr.Params != nil { - events <- asEvent("ack", cr.Params.Meta[MetaKeySubscriptionID]) - } - } - return next(ctx, method, req) - } - }) - - cs, err := client.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) - if err != nil { - t.Fatalf("client connect: %v", err) - } - defer cs.Close() - - startListen := func(notifs NotificationSubscriptions) context.CancelFunc { - lctx, cancel := context.WithCancel(ctx) - go func() { - _ = cs.SubscriptionsListen(lctx, &SubscriptionsListenParams{Notifications: notifs}) - }() - return cancel - } - waitFor := func(kind string) event { - t.Helper() - select { - case e := <-events: - if e.kind != kind { - t.Fatalf("got event %q (id=%s), want kind %q", e.kind, e.id, kind) - } - return e - case <-time.After(5 * time.Second): - t.Fatalf("timed out waiting for %q event", kind) - return event{} - } - } - expectNoEvent := func(d time.Duration) { - t.Helper() - select { - case e := <-events: - t.Fatalf("unexpected event %q (id=%s)", e.kind, e.id) - case <-time.After(d): - } - } - - cancelTools := startListen(NotificationSubscriptions{ToolsListChanged: true}) - cancelPrompts := startListen(NotificationSubscriptions{PromptsListChanged: true}) - - ack1 := waitFor("ack") - ack2 := waitFor("ack") - if ack1.id == ack2.id { - t.Fatalf("both acks share subscription ID %s; expected distinct per-listen IDs", ack1.id) - } - - // Identify which ack belongs to which listen by triggering a tool change - // and observing the tagged subscription ID on the notification. - server.AddTool(&Tool{Name: "t2", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) - toolEv := waitFor("tool") - if toolEv.id != ack1.id && toolEv.id != ack2.id { - t.Fatalf("tool notif id %s matches neither ack (%s, %s)", toolEv.id, ack1.id, ack2.id) - } - toolSubID := toolEv.id - promptSubID := ack2.id - if ack1.id != toolSubID { - promptSubID = ack1.id - } - expectNoEvent(notificationDelay * 5) - - server.AddPrompt(&Prompt{Name: "p2"}, nil) - if e := waitFor("prompt"); e.id != promptSubID { - t.Errorf("prompt notif id = %s, want %s", e.id, promptSubID) - } - expectNoEvent(notificationDelay * 5) - - // Cancel the tools listen. The SDK sends a notifications/cancelled to the - // server, which flips the listen handler's ctx, which unblocks the - // goroutine that removes the subscription from the server's index. - cancelTools() - - // Give the cancellation a moment to propagate (notifications/cancelled - // → server-side cancel → cleanup goroutine). - time.Sleep(50 * time.Millisecond) - - // A new tool change must NOT reach the (cancelled) tools listen, while - // the prompts listen continues to receive its notifications. - server.AddTool(&Tool{Name: "t3", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) - expectNoEvent(notificationDelay * 20) - - server.AddPrompt(&Prompt{Name: "p3"}, nil) - if e := waitFor("prompt"); e.id != promptSubID { - t.Errorf("prompt notif after tools-cancel id = %s, want %s", e.id, promptSubID) - } - - cancelPrompts() - time.Sleep(50 * time.Millisecond) - - server.AddPrompt(&Prompt{Name: "p4"}, nil) - expectNoEvent(notificationDelay * 20) -} From 8974819e587c0535cb49fdb7172917fd2d7885cb Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 12 Jun 2026 07:21:07 +0000 Subject: [PATCH 3/7] wip --- mcp/server.go | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index 7cf2a9c5..7cb383c2 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -1793,15 +1793,7 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error // It should never be invoked in practice because cancellation is preempted, // but having its signature here facilitates the construction of methodInfo // that can be used to validate incoming cancellation notifications. -func (ss *ServerSession) cancel(_ context.Context, param *CancelledParams) (Result, error) { - id, err := jsonrpc.MakeID(param.RequestID) - if err != nil { - return nil, nil - } - server := ss.server - server.removeSubscription(server.toolsSubscriptions, ss, id) - server.removeSubscription(server.promptsSubscriptions, ss, id) - server.removeSubscription(server.resourcesSubscriptions, ss, id) +func (ss *ServerSession) cancel(_ context.Context, _ *CancelledParams) (Result, error) { return nil, nil } From d2803b70e8c7ffbd998346e330125c7a672dc8a4 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 12 Jun 2026 14:03:49 +0000 Subject: [PATCH 4/7] refactor: encapsulate subscription management with activeSubscriptions type and dedicated methods --- mcp/server.go | 106 ++++++++++++++++++++++++++++++++------------------ 1 file changed, 68 insertions(+), 38 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index 7cb383c2..ab7b6878 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -35,6 +35,9 @@ import ( // DefaultPageSize is the default for [ServerOptions.PageSize]. const DefaultPageSize = 1000 +// A map from sessions to the set of subscription IDs they have active for a given feature. +type activeSubscriptions map[*ServerSession]map[jsonrpc.ID]bool + // A Server is an instance of an MCP server. // // Servers expose server-side MCP features, which can serve one or more MCP @@ -53,9 +56,9 @@ type Server struct { sendingMethodHandler_ MethodHandler receivingMethodHandler_ MethodHandler resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool - toolsSubscriptions map[*ServerSession]map[jsonrpc.ID]bool - promptsSubscriptions map[*ServerSession]map[jsonrpc.ID]bool - resourcesSubscriptions map[*ServerSession]map[jsonrpc.ID]bool + toolsSubscriptions activeSubscriptions + promptsSubscriptions activeSubscriptions + resourcesSubscriptions activeSubscriptions pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send } @@ -207,9 +210,9 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { sendingMethodHandler_: defaultSendingMethodHandler, receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], resourceSubscriptions: make(map[string]map[*ServerSession]bool), - toolsSubscriptions: make(map[*ServerSession]map[jsonrpc.ID]bool), - promptsSubscriptions: make(map[*ServerSession]map[jsonrpc.ID]bool), - resourcesSubscriptions: make(map[*ServerSession]map[jsonrpc.ID]bool), + toolsSubscriptions: make(activeSubscriptions), + promptsSubscriptions: make(activeSubscriptions), + resourcesSubscriptions: make(activeSubscriptions), pendingNotifications: make(map[string]*time.Timer), } s.AddReceivingMiddleware(serverMultiRoundTripMiddleware()) @@ -702,7 +705,15 @@ func (s *Server) notifySessions(n string) { } notifySessions(legacySessions, n, changeNotificationParams[n](), s.opts.Logger) - var activeSubscriptions map[*ServerSession]map[jsonrpc.ID]bool + // Notify sessions that subscribed to the notification with 'subscriptions/listen' method. + type subscription struct { + session *ServerSession + reqID jsonrpc.ID + } + var subs []subscription + + s.mu.Lock() + var activeSubscriptions activeSubscriptions switch n { case notificationToolListChanged: activeSubscriptions = s.toolsSubscriptions @@ -711,22 +722,27 @@ func (s *Server) notifySessions(n string) { case notificationResourceListChanged: activeSubscriptions = s.resourcesSubscriptions } + for session, reqIDs := range activeSubscriptions { + for reqID := range reqIDs { + subs = append(subs, subscription{session: session, reqID: reqID}) + } + } + s.mu.Unlock() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - for session := range activeSubscriptions { - for reqID := range activeSubscriptions[session] { - params := changeNotificationParams[n]() - setSubscriptionID(params, reqID) - req := newRequest(session, params) - ctx = context.WithValue(ctx, idContextKey{}, reqID) - if err := handleNotify(ctx, n, req); err != nil { - s.opts.Logger.Warn(fmt.Sprintf("calling %s: %v", n, err)) - } + for _, sub := range subs { + params := changeNotificationParams[n]() + injectMetaSubscriptionID(params, sub.reqID) + req := newRequest(sub.session, params) + reqCtx := context.WithValue(ctx, idContextKey{}, sub.reqID) + if err := handleNotify(reqCtx, n, req); err != nil { + s.opts.Logger.Warn(fmt.Sprintf("calling %s: %v", n, err)) } } } -func setSubscriptionID(params Params, reqID jsonrpc.ID) { +func injectMetaSubscriptionID(params Params, reqID jsonrpc.ID) { m := params.GetMeta() if m == nil { m = map[string]any{} @@ -1072,46 +1088,62 @@ func (s *Server) unsubscribe(ctx context.Context, req *UnsubscribeRequest) (*emp return &emptyResult{}, nil } -func (s *Server) addSubscription(index map[*ServerSession]map[jsonrpc.ID]bool, ss *ServerSession, id jsonrpc.ID) { - s.mu.Lock() - defer s.mu.Unlock() - ids := index[ss] +func (m activeSubscriptions) add(ss *ServerSession, id jsonrpc.ID) { + ids := m[ss] if ids == nil { ids = make(map[jsonrpc.ID]bool) - index[ss] = ids + m[ss] = ids } ids[id] = true } -func (s *Server) removeSubscription(index map[*ServerSession]map[jsonrpc.ID]bool, ss *ServerSession, id jsonrpc.ID) { - s.mu.Lock() - defer s.mu.Unlock() - ids, ok := index[ss] +func (m activeSubscriptions) remove(ss *ServerSession, id jsonrpc.ID) { + ids, ok := m[ss] if !ok { return } delete(ids, id) if len(ids) == 0 { - delete(index, ss) + delete(m, ss) } } -func (s *Server) subscriptionsListen(ctx context.Context, req *SubscriptionsListenRequest) (*emptyResult, error) { - requestID, ok := ctx.Value(idContextKey{}).(jsonrpc.ID) - if !ok || !requestID.IsValid() { - return nil, fmt.Errorf("%w: subscriptions/listen requires a request ID", jsonrpc2.ErrInvalidRequest) +func (s *Server) addSubscriptions(allowed NotificationSubscriptions, ss *ServerSession, id jsonrpc.ID) { + s.mu.Lock() + defer s.mu.Unlock() + if allowed.ToolsListChanged { + s.toolsSubscriptions.add(ss, id) + } + if allowed.PromptsListChanged { + s.promptsSubscriptions.add(ss, id) + } + if allowed.ResourcesListChanged { + s.resourcesSubscriptions.add(ss, id) } +} - allowed := s.allowedSubscriptions(req.Params.Notifications) +func (s *Server) removeSubscriptions(allowed NotificationSubscriptions, ss *ServerSession, id jsonrpc.ID) { + s.mu.Lock() + defer s.mu.Unlock() if allowed.ToolsListChanged { - s.addSubscription(s.toolsSubscriptions, req.Session, requestID) + s.toolsSubscriptions.remove(ss, id) } if allowed.PromptsListChanged { - s.addSubscription(s.promptsSubscriptions, req.Session, requestID) + s.promptsSubscriptions.remove(ss, id) } if allowed.ResourcesListChanged { - s.addSubscription(s.resourcesSubscriptions, req.Session, requestID) + s.resourcesSubscriptions.remove(ss, id) } +} + +func (s *Server) subscriptionsListen(ctx context.Context, req *SubscriptionsListenRequest) (*emptyResult, error) { + requestID, ok := ctx.Value(idContextKey{}).(jsonrpc.ID) + if !ok || !requestID.IsValid() { + return nil, fmt.Errorf("%w: subscriptions/listen requires a request ID", jsonrpc2.ErrInvalidRequest) + } + + allowed := s.allowedSubscriptions(req.Params.Notifications) + s.addSubscriptions(allowed, req.Session, requestID) ackParams := &SubscriptionsAcknowledgedParams{ Notifications: allowed, @@ -1122,9 +1154,7 @@ func (s *Server) subscriptionsListen(ctx context.Context, req *SubscriptionsList } defer func() { - s.removeSubscription(s.toolsSubscriptions, req.Session, requestID) - s.removeSubscription(s.promptsSubscriptions, req.Session, requestID) - s.removeSubscription(s.resourcesSubscriptions, req.Session, requestID) + s.removeSubscriptions(allowed, req.Session, requestID) }() <-ctx.Done() From 3ca0f0c75698645722f734f6fe09230c1dd3ce11 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Mon, 15 Jun 2026 09:19:06 +0000 Subject: [PATCH 5/7] refactor: move subscription management from Server to ServerSession to simplify notification handling --- mcp/server.go | 188 +++++++++++++++++++++----------------------------- 1 file changed, 80 insertions(+), 108 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index ab7b6878..fc19d5fb 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -35,9 +35,6 @@ import ( // DefaultPageSize is the default for [ServerOptions.PageSize]. const DefaultPageSize = 1000 -// A map from sessions to the set of subscription IDs they have active for a given feature. -type activeSubscriptions map[*ServerSession]map[jsonrpc.ID]bool - // A Server is an instance of an MCP server. // // Servers expose server-side MCP features, which can serve one or more MCP @@ -56,10 +53,7 @@ type Server struct { sendingMethodHandler_ MethodHandler receivingMethodHandler_ MethodHandler resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool - toolsSubscriptions activeSubscriptions - promptsSubscriptions activeSubscriptions - resourcesSubscriptions activeSubscriptions - pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send + pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send } // ServerOptions is used to configure behavior of the server. @@ -210,9 +204,6 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { sendingMethodHandler_: defaultSendingMethodHandler, receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], resourceSubscriptions: make(map[string]map[*ServerSession]bool), - toolsSubscriptions: make(activeSubscriptions), - promptsSubscriptions: make(activeSubscriptions), - resourcesSubscriptions: make(activeSubscriptions), pendingNotifications: make(map[string]*time.Timer), } s.AddReceivingMiddleware(serverMultiRoundTripMiddleware()) @@ -690,6 +681,11 @@ func (s *Server) changeAndNotify(notification string, change func() bool) { // notifySessions sends the notification n to all existing sessions. // It is called asynchronously by changeAndNotify. +// +// Legacy (pre-SEP-2575) sessions receive the notification on the shared +// session channel. Sessions speaking the new protocol receive it only if they +// have an active subscriptions/listen stream that opted in to this +// notification type. func (s *Server) notifySessions(n string) { s.mu.Lock() sessions := slices.Clone(s.sessions) @@ -698,46 +694,24 @@ func (s *Server) notifySessions(n string) { // Only add legacy sessions for the notification, new ones use the new notification mechanism. var legacySessions []*ServerSession - for _, s := range sessions { - if s.InitializeParams().isNil() || s.InitializeParams().ProtocolVersion < protocolVersion20260630 { - legacySessions = append(legacySessions, s) + for _, session := range sessions { + if session.InitializeParams().isNil() || session.InitializeParams().ProtocolVersion < protocolVersion20260630 { + legacySessions = append(legacySessions, session) } } notifySessions(legacySessions, n, changeNotificationParams[n](), s.opts.Logger) - // Notify sessions that subscribed to the notification with 'subscriptions/listen' method. - type subscription struct { - session *ServerSession - reqID jsonrpc.ID - } - var subs []subscription - - s.mu.Lock() - var activeSubscriptions activeSubscriptions - switch n { - case notificationToolListChanged: - activeSubscriptions = s.toolsSubscriptions - case notificationPromptListChanged: - activeSubscriptions = s.promptsSubscriptions - case notificationResourceListChanged: - activeSubscriptions = s.resourcesSubscriptions - } - for session, reqIDs := range activeSubscriptions { - for reqID := range reqIDs { - subs = append(subs, subscription{session: session, reqID: reqID}) - } - } - s.mu.Unlock() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - for _, sub := range subs { - params := changeNotificationParams[n]() - injectMetaSubscriptionID(params, sub.reqID) - req := newRequest(sub.session, params) - reqCtx := context.WithValue(ctx, idContextKey{}, sub.reqID) - if err := handleNotify(reqCtx, n, req); err != nil { - s.opts.Logger.Warn(fmt.Sprintf("calling %s: %v", n, err)) + for _, session := range sessions { + for _, reqID := range session.subscribedListenIDs(n) { + params := changeNotificationParams[n]() + injectMetaSubscriptionID(params, reqID) + req := newRequest(session, params) + reqCtx := context.WithValue(ctx, idContextKey{}, reqID) + if err := handleNotify(reqCtx, n, req); err != nil { + s.opts.Logger.Warn(fmt.Sprintf("calling %s: %v", n, err)) + } } } } @@ -1088,54 +1062,6 @@ func (s *Server) unsubscribe(ctx context.Context, req *UnsubscribeRequest) (*emp return &emptyResult{}, nil } -func (m activeSubscriptions) add(ss *ServerSession, id jsonrpc.ID) { - ids := m[ss] - if ids == nil { - ids = make(map[jsonrpc.ID]bool) - m[ss] = ids - } - ids[id] = true -} - -func (m activeSubscriptions) remove(ss *ServerSession, id jsonrpc.ID) { - ids, ok := m[ss] - if !ok { - return - } - delete(ids, id) - if len(ids) == 0 { - delete(m, ss) - } -} - -func (s *Server) addSubscriptions(allowed NotificationSubscriptions, ss *ServerSession, id jsonrpc.ID) { - s.mu.Lock() - defer s.mu.Unlock() - if allowed.ToolsListChanged { - s.toolsSubscriptions.add(ss, id) - } - if allowed.PromptsListChanged { - s.promptsSubscriptions.add(ss, id) - } - if allowed.ResourcesListChanged { - s.resourcesSubscriptions.add(ss, id) - } -} - -func (s *Server) removeSubscriptions(allowed NotificationSubscriptions, ss *ServerSession, id jsonrpc.ID) { - s.mu.Lock() - defer s.mu.Unlock() - if allowed.ToolsListChanged { - s.toolsSubscriptions.remove(ss, id) - } - if allowed.PromptsListChanged { - s.promptsSubscriptions.remove(ss, id) - } - if allowed.ResourcesListChanged { - s.resourcesSubscriptions.remove(ss, id) - } -} - func (s *Server) subscriptionsListen(ctx context.Context, req *SubscriptionsListenRequest) (*emptyResult, error) { requestID, ok := ctx.Value(idContextKey{}).(jsonrpc.ID) if !ok || !requestID.IsValid() { @@ -1143,7 +1069,8 @@ func (s *Server) subscriptionsListen(ctx context.Context, req *SubscriptionsList } allowed := s.allowedSubscriptions(req.Params.Notifications) - s.addSubscriptions(allowed, req.Session, requestID) + req.Session.addSubscription(requestID, allowed) + defer req.Session.removeSubscription(requestID) ackParams := &SubscriptionsAcknowledgedParams{ Notifications: allowed, @@ -1153,10 +1080,6 @@ func (s *Server) subscriptionsListen(ctx context.Context, req *SubscriptionsList return nil, fmt.Errorf("sending subscriptions/acknowledged: %w", err) } - defer func() { - s.removeSubscriptions(allowed, req.Session, requestID) - }() - <-ctx.Done() return &emptyResult{}, nil } @@ -1176,6 +1099,55 @@ func (s *Server) allowedSubscriptions(want NotificationSubscriptions) Notificati return agreed } +// addSubscription registers a subscriptions/listen stream on this session. +func (ss *ServerSession) addSubscription(id jsonrpc.ID, allowed NotificationSubscriptions) { + if !allowed.ToolsListChanged && !allowed.PromptsListChanged && !allowed.ResourcesListChanged { + return + } + ss.mu.Lock() + defer ss.mu.Unlock() + if ss.subscriptions == nil { + ss.subscriptions = make(map[jsonrpc.ID]*listenSubscription) + } + ss.subscriptions[id] = &listenSubscription{ + toolsListChanged: allowed.ToolsListChanged, + promptsListChanged: allowed.PromptsListChanged, + resourcesListChanged: allowed.ResourcesListChanged, + } +} + +func (ss *ServerSession) removeSubscription(id jsonrpc.ID) { + ss.mu.Lock() + defer ss.mu.Unlock() + delete(ss.subscriptions, id) +} + +// subscribedListenIDs returns the listen-request IDs that opted in to the +// given list-changed notification. +func (ss *ServerSession) subscribedListenIDs(notification string) []jsonrpc.ID { + ss.mu.Lock() + defer ss.mu.Unlock() + if len(ss.subscriptions) == 0 { + return nil + } + var ids []jsonrpc.ID + for id, sub := range ss.subscriptions { + var match bool + switch notification { + case notificationToolListChanged: + match = sub.toolsListChanged + case notificationPromptListChanged: + match = sub.promptsListChanged + case notificationResourceListChanged: + match = sub.resourcesListChanged + } + if match { + ids = append(ids, id) + } + } + return ids +} + // Run runs the server over the given transport, which must be persistent. // // Run blocks until the client terminates the connection or the provided @@ -1245,9 +1217,6 @@ func (s *Server) disconnect(cc *ServerSession) { for _, subscribedSessions := range s.resourceSubscriptions { delete(subscribedSessions, cc) } - delete(s.toolsSubscriptions, cc) - delete(s.promptsSubscriptions, cc) - delete(s.resourcesSubscriptions, cc) s.opts.Logger.Info("server session disconnected", "session_id", cc.ID()) } @@ -1385,6 +1354,17 @@ type ServerSession struct { mu sync.Mutex state ServerSessionState + // subscriptions tracks the SEP-2575 subscriptions/listen streams opened by + // this session, keyed by the originating listen request ID. + subscriptions map[jsonrpc.ID]*listenSubscription +} + +// listenSubscription records the notification types a single +// subscriptions/listen stream has opted in to. +type listenSubscription struct { + toolsListChanged bool + promptsListChanged bool + resourcesListChanged bool } func (ss *ServerSession) updateState(mut func(*ServerSessionState)) { @@ -1856,14 +1836,6 @@ func (ss *ServerSession) Close() error { ss.onClose() } - // Clean up session subscriptions - server := ss.server - server.mu.Lock() - delete(server.toolsSubscriptions, ss) - delete(server.promptsSubscriptions, ss) - delete(server.resourcesSubscriptions, ss) - server.mu.Unlock() - return err } From c729ae9118480831e01e337f58e9cfffd98810a1 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Mon, 15 Jun 2026 09:41:54 +0000 Subject: [PATCH 6/7] refactor: remove redundant SSE priming and simplify server session cancellation signature and request matching logic --- mcp/server.go | 2 +- mcp/streamable.go | 13 +------------ 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index fc19d5fb..535e0a6e 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -1803,7 +1803,7 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error // It should never be invoked in practice because cancellation is preempted, // but having its signature here facilitates the construction of methodInfo // that can be used to validate incoming cancellation notifications. -func (ss *ServerSession) cancel(_ context.Context, _ *CancelledParams) (Result, error) { +func (ss *ServerSession) cancel(context.Context, *CancelledParams) (Result, error) { return nil, nil } diff --git a/mcp/streamable.go b/mcp/streamable.go index 2c851dd3..c01e1eb2 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1582,17 +1582,6 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques c.logger.Warn(fmt.Sprintf("Writing priming event: %v", err)) } } - // For subscriptions/listen, flush headers eagerly so the client sees - // the open SSE stream even when an HTTP/2-aware reverse proxy is - // buffering the HEADERS frame waiting for a DATA frame to coalesce - // with. The acknowledgment notification will follow shortly, but the - // client can begin reading event-stream framing immediately. See the - // equivalent comment in [streamableServerConn.acquireStream]. - if isSubscriptionsListen { - w.WriteHeader(http.StatusOK) - fmt.Fprint(w, ": ok\n\n") - _ = http.NewResponseController(w).Flush() - } } // TODO(rfindley): if we have no event store, we should really cancel all @@ -2483,7 +2472,7 @@ func (c *streamableClientConn) processStream(ctx context.Context, requestSummary // TODO: we should never get a response when forReq is nil (the standalone SSE request). // We should detect this case. // The subscriptions/listen is now returning a response in the SSE stream. - if jsonResp.ID == forCall.ID && forCall.Method != methodSubscriptionsListen { + if jsonResp.ID == forCall.ID { return "", 0, true } } From c9557ad958c901aa22afe8eccaacbad05f5fd132 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Mon, 15 Jun 2026 09:42:57 +0000 Subject: [PATCH 7/7] refactor: remove stale comment regarding SSE stream response handling --- mcp/streamable.go | 1 - 1 file changed, 1 deletion(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index c01e1eb2..f4d81c15 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -2471,7 +2471,6 @@ func (c *streamableClientConn) processStream(ctx context.Context, requestSummary if jsonResp, ok := msg.(*jsonrpc.Response); ok && forCall != nil { // TODO: we should never get a response when forReq is nil (the standalone SSE request). // We should detect this case. - // The subscriptions/listen is now returning a response in the SSE stream. if jsonResp.ID == forCall.ID { return "", 0, true }