Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions mcp/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
// Use of this source code is governed by the license
// that can be found in the LICENSE file.

package mcp

import (
"sync"
"time"
)

// methodCache is a per-method TTL cache for list and read results, as
// described in SEP-2549. Each entry is keyed by cursor (for paginated list
// methods) or URI (for resources/read).
type methodCache[R CacheableResult] struct {
mu sync.Mutex
cachedValues map[string]*cacheEntry[R]
}

type cacheEntry[R CacheableResult] struct {
result R
receivedAt time.Time
}

func (e *cacheEntry[R]) isValid() bool {
return time.Since(e.receivedAt) < time.Duration(e.result.GetTTLMs())*time.Millisecond
}

func (mc *methodCache[R]) get(key string) (R, bool) {
mc.mu.Lock()
defer mc.mu.Unlock()
entry, ok := mc.cachedValues[key]
if !ok {
var zero R
return zero, false
}
if entry.result.GetTTLMs() <= 0 || !entry.isValid() {
delete(mc.cachedValues, key)
var zero R
return zero, false
}
return entry.result, true
}

func (mc *methodCache[R]) put(key string, result R) {
mc.mu.Lock()
defer mc.mu.Unlock()
if mc.cachedValues == nil {
mc.cachedValues = make(map[string]*cacheEntry[R])
}
mc.cachedValues[key] = &cacheEntry[R]{
result: result,
receivedAt: time.Now(),
}
}

func (mc *methodCache[R]) forEach(f func(R)) {
mc.mu.Lock()
defer mc.mu.Unlock()
for _, entry := range mc.cachedValues {
f(entry.result)
}
}

func (mc *methodCache[R]) invalidate() {
mc.mu.Lock()
defer mc.mu.Unlock()
clear(mc.cachedValues)
}

func (mc *methodCache[R]) invalidateKey(key string) {
mc.mu.Lock()
defer mc.mu.Unlock()
delete(mc.cachedValues, key)
}

// cursorParams is the constraint for list-method params that carry a pagination
// cursor and can be checked for nil. Both methods are already implemented by
// every concrete list-params type.
type cursorParams interface {
Params
cursorPtr() *string
}

// cachedListResult returns a cached list result keyed by the request cursor
// (SEP-2549). It returns the zero value and false on miss or when params is nil.
func cachedListResult[P cursorParams, R CacheableResult](cache *methodCache[R], params P) (R, bool) {
key := ""
if !params.isNil() {
if cp := params.cursorPtr(); cp != nil {
key = *cp
}
}
return cache.get(key)
}
123 changes: 97 additions & 26 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,16 +422,16 @@ type ClientSession struct {
// only set synchronously during Client.Connect.
state clientSessionState

// Per-method TTL caches for list results (SEP-2549).
toolsCache methodCache[*ListToolsResult]
promptsCache methodCache[*ListPromptsResult]
resourcesCache methodCache[*ListResourcesResult]
resourceTemplatesCache methodCache[*ListResourceTemplatesResult]
readResourceCache methodCache[*ReadResourceResult]

// Pending URL elicitations waiting for completion notifications.
pendingElicitationsMu sync.Mutex
pendingElicitations map[string]chan struct{}

// toolCacheMu guards toolCache.
toolCacheMu sync.RWMutex
// toolCache stores tool definitions keyed by name.
// It is used to look up x-mcp-header annotations when
// constructing Mcp-Param-* headers for tools/call requests.
toolCache map[string]*Tool
}

type clientSessionState struct {
Expand Down Expand Up @@ -513,19 +513,25 @@ func (cs *ClientSession) Wait() error {
return cs.conn.Wait()
}

func (cs *ClientSession) cacheTools(tools []*Tool) {
cs.toolCacheMu.Lock()
defer cs.toolCacheMu.Unlock()
cs.toolCache = make(map[string]*Tool, len(tools))
for _, tool := range tools {
cs.toolCache[tool.Name] = tool
}
}

func (cs *ClientSession) getCachedTool(name string) *Tool {
cs.toolCacheMu.RLock()
defer cs.toolCacheMu.RUnlock()
return cs.toolCache[name]
// lookupTool returns the most recently seen definition of the tool with the
// given name across all cached ListTools results, or nil if no such tool has
// been seen. It is used by CallTool to inject the tool definition into the
// outgoing request context for transport-layer features (e.g. x-mcp-header
// param annotations).
func (cs *ClientSession) lookupTool(name string) *Tool {
var found *Tool
cs.toolsCache.forEach(func(r *ListToolsResult) {
if found != nil {
return
}
for _, t := range r.Tools {
if t.Name == name {
found = t
return
}
}
})
return found
}

// registerElicitationWaiter registers a waiter for an elicitation complete
Expand Down Expand Up @@ -1135,11 +1141,24 @@ func (cs *ClientSession) Ping(ctx context.Context, params *PingParams) error {
}

// ListPrompts lists prompts that are currently available on the server.
//
// Results may be served from a client-side TTL cache populated by previous
// calls; see SEP-2549.
func (cs *ClientSession) ListPrompts(ctx context.Context, params *ListPromptsParams) (*ListPromptsResult, error) {
if cs.usesNewProtocol() {
if result, ok := cachedListResult(&cs.promptsCache, params); ok {
return result, nil
}
params = injectRequestMeta(cs, params)
}
return handleSend[*ListPromptsResult](ctx, methodListPrompts, newClientRequest(cs, orZero[Params](params)))
result, err := handleSend[*ListPromptsResult](ctx, methodListPrompts, newClientRequest(cs, orZero[Params](params)))
if err != nil {
return nil, err
}
if cs.usesNewProtocol() {
cs.promptsCache.put(params.Cursor, result)
}
return result, nil
}

// GetPrompt gets a prompt from the server.
Expand All @@ -1153,14 +1172,19 @@ func (cs *ClientSession) GetPrompt(ctx context.Context, params *GetPromptParams)
// ListTools lists tools that are currently available on the server.
func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams) (*ListToolsResult, error) {
if cs.usesNewProtocol() {
if result, ok := cachedListResult(&cs.toolsCache, params); ok {
return result, nil
}
params = injectRequestMeta(cs, params)
}
result, err := handleSend[*ListToolsResult](ctx, methodListTools, newClientRequest(cs, orZero[Params](params)))
if err != nil {
return nil, err
}
result.Tools = filterValidTools(cs.client.opts.Logger, result.Tools)
cs.cacheTools(result.Tools)
if cs.usesNewProtocol() {
cs.toolsCache.put(params.Cursor, result)
}
return result, nil
}

Expand All @@ -1175,7 +1199,7 @@ func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) (
// Avoid sending nil over the wire.
params.Arguments = map[string]any{}
}
if tool := cs.getCachedTool(params.Name); tool != nil {
if tool := cs.lookupTool(params.Name); tool != nil {
ctx = context.WithValue(ctx, toolContextKey, tool)
}
if cs.usesNewProtocol() {
Expand All @@ -1192,25 +1216,59 @@ func (cs *ClientSession) SetLoggingLevel(ctx context.Context, params *SetLogging
// ListResources lists the resources that are currently available on the server.
func (cs *ClientSession) ListResources(ctx context.Context, params *ListResourcesParams) (*ListResourcesResult, error) {
if cs.usesNewProtocol() {
if result, ok := cachedListResult(&cs.resourcesCache, params); ok {
return result, nil
}
params = injectRequestMeta(cs, params)
}
return handleSend[*ListResourcesResult](ctx, methodListResources, newClientRequest(cs, orZero[Params](params)))
result, err := handleSend[*ListResourcesResult](ctx, methodListResources, newClientRequest(cs, orZero[Params](params)))
if err != nil {
return nil, err
}
if cs.usesNewProtocol() {
cs.resourcesCache.put(params.Cursor, result)
}
return result, nil
}

// ListResourceTemplates lists the resource templates that are currently available on the server.
func (cs *ClientSession) ListResourceTemplates(ctx context.Context, params *ListResourceTemplatesParams) (*ListResourceTemplatesResult, error) {
if cs.usesNewProtocol() {
if result, ok := cachedListResult(&cs.resourceTemplatesCache, params); ok {
return result, nil
}
params = injectRequestMeta(cs, params)
}
return handleSend[*ListResourceTemplatesResult](ctx, methodListResourceTemplates, newClientRequest(cs, orZero[Params](params)))
result, err := handleSend[*ListResourceTemplatesResult](ctx, methodListResourceTemplates, newClientRequest(cs, orZero[Params](params)))
if err != nil {
return nil, err
}
if cs.usesNewProtocol() {
cs.resourceTemplatesCache.put(params.Cursor, result)
}
return result, nil
}

// ReadResource asks the server to read a resource and return its contents.
func (cs *ClientSession) ReadResource(ctx context.Context, params *ReadResourceParams) (*ReadResourceResult, error) {
if cs.usesNewProtocol() {
var uri string
if params != nil {
uri = params.URI
}
if result, ok := cs.readResourceCache.get(uri); ok {
return result, nil
}
params = injectRequestMeta(cs, params)
}
return handleSend[*ReadResourceResult](ctx, methodReadResource, newClientRequest(cs, orZero[Params](params)))
result, err := handleSend[*ReadResourceResult](ctx, methodReadResource, newClientRequest(cs, orZero[Params](params)))
if err != nil {
return nil, err
}
if cs.usesNewProtocol() {
cs.readResourceCache.put(params.URI, result)
}
return result, nil
}

func (cs *ClientSession) Complete(ctx context.Context, params *CompleteParams) (*CompleteResult, error) {
Expand All @@ -1235,27 +1293,40 @@ func (cs *ClientSession) Unsubscribe(ctx context.Context, params *UnsubscribePar
}

func (c *Client) callToolChangedHandler(ctx context.Context, req *ToolListChangedRequest) (Result, error) {
if cs, ok := req.GetSession().(*ClientSession); ok {
cs.toolsCache.invalidate()
}
if h := c.opts.ToolListChangedHandler; h != nil {
h(ctx, req)
}
return nil, nil
}

func (c *Client) callPromptChangedHandler(ctx context.Context, req *PromptListChangedRequest) (Result, error) {
if cs, ok := req.GetSession().(*ClientSession); ok {
cs.promptsCache.invalidate()
}
if h := c.opts.PromptListChangedHandler; h != nil {
h(ctx, req)
}
return nil, nil
}

func (c *Client) callResourceChangedHandler(ctx context.Context, req *ResourceListChangedRequest) (Result, error) {
if cs, ok := req.GetSession().(*ClientSession); ok {
cs.resourcesCache.invalidate()
cs.resourceTemplatesCache.invalidate()
}
if h := c.opts.ResourceListChangedHandler; h != nil {
h(ctx, req)
}
return nil, nil
}

func (c *Client) callResourceUpdatedHandler(ctx context.Context, req *ResourceUpdatedNotificationRequest) (Result, error) {
if cs, ok := req.GetSession().(*ClientSession); ok && req.Params != nil {
cs.readResourceCache.invalidateKey(req.Params.URI)
}
if h := c.opts.ResourceUpdatedHandler; h != nil {
h(ctx, req)
}
Expand Down
Loading
Loading