diff --git a/mcp/streamable_headers.go b/mcp/streamable_headers.go index fba1b4d9..af9ff396 100644 --- a/mcp/streamable_headers.go +++ b/mcp/streamable_headers.go @@ -11,7 +11,9 @@ import ( "errors" "fmt" "log/slog" + "math" "net/http" + "strconv" "strings" internaljson "github.com/modelcontextprotocol/go-sdk/internal/json" @@ -70,58 +72,127 @@ func unmarshalSchemaProperties(schema any) map[string]headerSchemaProperty { return s.Properties } -// extractParamHeaderAnnotations returns a map of parameter name to header name -// for all properties in the tool's InputSchema that have an x-mcp-header -// annotation. -func extractParamHeaderAnnotations(tool *Tool) map[string]string { +// paramHeaderBinding maps a (possibly nested) input-schema property to the +// HTTP header it carries. +type paramHeaderBinding struct { + Path []string + Header string +} + +// extractParamHeaderAnnotations returns the bindings for every property in +// the tool's InputSchema that has an x-mcp-header annotation +func extractParamHeaderAnnotations(tool *Tool) []paramHeaderBinding { props := unmarshalSchemaProperties(tool.InputSchema) if len(props) == 0 { return nil } - result := make(map[string]string) - for propName, prop := range props { - var headerName string - if err := json.Unmarshal(prop.XMCPHeader, &headerName); err != nil || headerName == "" { - continue - } - result[propName] = headerName - } + var result []paramHeaderBinding + result = collectParamHeaderAnnotations(props, nil, result) if len(result) == 0 { return nil } return result } -// primitiveToString conversion. -// Returns false in the second return value if the argument is not a primitive value. -func primitiveToString(value any) (string, bool) { - switch v := value.(type) { - case string: - return v, true - case float64: - return fmt.Sprintf("%g", v), true - case bool: - return fmt.Sprintf("%t", v), true - default: - return "", false +// collectParamHeaderAnnotations walks the schema properties and records every +// x-mcp-header annotation it finds, keyed by the property-name path. +func collectParamHeaderAnnotations(props map[string]headerSchemaProperty, prefix []string, out []paramHeaderBinding) []paramHeaderBinding { + for propName, prop := range props { + path := make([]string, len(prefix)+1) + copy(path, prefix) + path[len(prefix)] = propName + + var headerName string + if err := json.Unmarshal(prop.XMCPHeader, &headerName); err == nil && headerName != "" { + out = append(out, paramHeaderBinding{Path: path, Header: headerName}) + } + if len(prop.Properties) > 0 { + out = collectParamHeaderAnnotations(prop.Properties, path, out) + } + } + return out +} + +// lookupArgument navigates the arguments object using the given property-name +// path and returns the raw JSON value at that location. It reports whether +// the value was found. +func lookupArgument(args map[string]json.RawMessage, path []string) (json.RawMessage, bool) { + if len(path) == 0 { + return nil, false } + cur, ok := args[path[0]] + if !ok { + return nil, false + } + for _, part := range path[1:] { + var obj map[string]json.RawMessage + if err := internaljson.Unmarshal(cur, &obj); err != nil { + return nil, false + } + cur, ok = obj[part] + if !ok { + return nil, false + } + } + return cur, true } -// unmarshalPrimitive unmarshals a JSON value into a Go primitive -// (string, float64, or bool). Returns nil for non-primitive types. +// maxSafeInteger and minSafeInteger bound the integer values that can be +// faithfully represented as IEEE-754 double-precision floats. +const ( + maxSafeInteger = 1<<53 - 1 // 2^53 - 1 = 9007199254740991 + minSafeInteger = -(1<<53 - 1) // -(2^53 - 1) = -9007199254740991 +) + +// unmarshalPrimitive unmarshals a JSON value into the Go representation used +// for x-mcp-header processing per SEP-2243: +// +// - JSON string -> string +// - JSON boolean -> bool +// - JSON integer (within the JavaScript safe-integer range) -> int64 +// +// JSON numbers that are non-integers (have a fractional part, NaN, or ±Inf) +// or integers outside the safe range are rejected because the `number` type +// is not permitted for x-mcp-header parameters; only integer, string, boolean +// are allowed. func unmarshalPrimitive(raw json.RawMessage) any { var val any if err := internaljson.Unmarshal(raw, &val); err != nil { return nil } - switch val.(type) { - case string, float64, bool: - return val + switch v := val.(type) { + case string, bool: + return v + case float64: + if math.IsNaN(v) || math.IsInf(v, 0) || v != math.Trunc(v) { + return nil + } + if v < minSafeInteger || v > maxSafeInteger { + return nil + } + return int64(v) default: return nil } } +// primitiveToString formats an x-mcp-header value (as produced by +// [unmarshalPrimitive]) to its canonical header string representation per +// SEP-2243. Returns false if value is not one of the permitted primitive +// types (string, bool, int64). +func primitiveToString(value any) (string, bool) { + switch v := value.(type) { + case string: + return v, true + case bool: + return fmt.Sprintf("%t", v), true + case int64: + return strconv.FormatInt(v, 10), true + default: + return "", false + } +} + // setStandardHeaders populates standard MCP headers. // It requires the protocol version header to be set. func setStandardHeaders(ctx context.Context, header http.Header, msg jsonrpc.Message) { @@ -164,8 +235,8 @@ func generateParamHeaders(tool *Tool, params json.RawMessage) map[string]string } res := make(map[string]string) - for paramName, headerName := range paramHeaders { - argRaw, ok := raw.Arguments[paramName] + for _, b := range paramHeaders { + argRaw, ok := lookupArgument(raw.Arguments, b.Path) if !ok { continue } @@ -180,7 +251,7 @@ func generateParamHeaders(tool *Tool, params json.RawMessage) map[string]string if !ok { continue } - res[paramHeaderPrefix+headerName] = encoded + res[paramHeaderPrefix+b.Header] = encoded } return res } @@ -201,64 +272,84 @@ func filterValidTools(logger *slog.Logger, tools []*Tool) []*Tool { } // validateParamHeaderAnnotations checks that a tool's x-mcp-header annotations -// are valid. +// are valid. Annotations may appear on properties at any nesting +// depth within the inputSchema and must be unique across all of them. func validateParamHeaderAnnotations(tool *Tool) error { props := unmarshalSchemaProperties(tool.InputSchema) if len(props) == 0 { return nil } - seen := make(map[string]bool) - for propName, prop := range props { - if err := checkForNestedHeaders(prop, propName); err != nil { - return err - } - if prop.XMCPHeader == nil { - continue - } - var headerName string - if err := json.Unmarshal(prop.XMCPHeader, &headerName); err != nil || headerName == "" { - return fmt.Errorf("property %q: x-mcp-header must be a non-empty string", propName) - } - if err := validateHeaderName(headerName); err != nil { - return fmt.Errorf("property %q: %w", propName, err) - } - lower := strings.ToLower(headerName) - if seen[lower] { - return fmt.Errorf("property %q: duplicate x-mcp-header value %q (case-insensitive)", propName, headerName) - } - seen[lower] = true - - if prop.Type != "string" && prop.Type != "number" && prop.Type != "integer" && prop.Type != "boolean" { - return fmt.Errorf("property %q: x-mcp-header can only be applied to primitive types, got %v", propName, prop.Type) - } - } - return nil + return validateParamHeadersIn(props, "", seen) } -func checkForNestedHeaders(prop headerSchemaProperty, path string) error { - for propName, nested := range prop.Properties { - if nested.XMCPHeader != nil { - return fmt.Errorf("property %q: x-mcp-header cannot be applied to nested properties", path+"."+propName) +func validateParamHeadersIn(props map[string]headerSchemaProperty, prefix string, seen map[string]bool) error { + for propName, prop := range props { + path := propName + if prefix != "" { + path = prefix + "." + propName + } + if prop.XMCPHeader != nil { + if prop.Type != "string" && prop.Type != "integer" && prop.Type != "boolean" { + return fmt.Errorf("property %q: x-mcp-header can only be applied to primitive types (integer, string, boolean), got %q", path, prop.Type) + } + var headerName string + if err := json.Unmarshal(prop.XMCPHeader, &headerName); err != nil || headerName == "" { + return fmt.Errorf("property %q: x-mcp-header must be a non-empty string", path) + } + if err := validateHeaderName(headerName); err != nil { + return fmt.Errorf("property %q: %w", path, err) + } + lower := strings.ToLower(headerName) + if seen[lower] { + return fmt.Errorf("property %q: duplicate x-mcp-header value %q (case-insensitive)", path, headerName) + } + seen[lower] = true } - if err := checkForNestedHeaders(nested, path+"."+propName); err != nil { - return err + if len(prop.Properties) > 0 { + if err := validateParamHeadersIn(prop.Properties, path, seen); err != nil { + return err + } } } return nil } -// validateHeaderName checks that a header name contains only valid -// ASCII characters (excluding space and ':'). +// validateHeaderName checks that a header name matches the HTTP field-name +// token syntax (1*tchar). func validateHeaderName(name string) error { + if name == "" { + return fmt.Errorf("x-mcp-header value must be a non-empty string") + } for _, c := range name { - if c <= 0x20 || c > 0x7E || c == ':' { + if !isTChar(c) { return fmt.Errorf("x-mcp-header value %q contains invalid character %q", name, c) } } return nil } +// isTChar reports whether c is a valid HTTP token character (tchar) +// +// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / +// "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA +func isTChar(c rune) bool { + switch { + case c >= '0' && c <= '9': + return true + case c >= 'A' && c <= 'Z': + return true + case c >= 'a' && c <= 'z': + return true + } + switch c { + case '!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', + '^', '_', '`', '|', '~': + return true + } + return false +} + func validateMcpHeaders(header http.Header, msg jsonrpc.Message, toolLookup func(string) (*serverTool, bool)) error { protocolVersion := header.Get(protocolVersionHeader) if protocolVersion == "" || protocolVersion < minVersionForStandardHeaders { @@ -315,20 +406,20 @@ func validateParamHeaders(header http.Header, msg *jsonrpc.Request, tool *Tool) return nil } - for paramName, headerName := range paramHeaders { - fullHeader := paramHeaderPrefix + headerName + for _, b := range paramHeaders { + fullHeader := paramHeaderPrefix + b.Header headerVal := header.Get(fullHeader) - argRaw, argExists := raw.Arguments[paramName] + argRaw, argExists := lookupArgument(raw.Arguments, b.Path) if !argExists || string(argRaw) == "null" { if headerVal != "" { - return fmt.Errorf("header mismatch: unexpected %s header for absent or null parameter %q", fullHeader, paramName) + return fmt.Errorf("header mismatch: unexpected %s header for absent or null parameter %q", fullHeader, strings.Join(b.Path, ".")) } continue } if headerVal == "" { - return fmt.Errorf("header mismatch: missing %s header for parameter %q", fullHeader, paramName) + return fmt.Errorf("header mismatch: missing %s header for parameter %q", fullHeader, strings.Join(b.Path, ".")) } decoded, ok := decodeHeaderValue(headerVal) @@ -338,26 +429,44 @@ func validateParamHeaders(header http.Header, msg *jsonrpc.Request, tool *Tool) bodyVal := unmarshalPrimitive(argRaw) if bodyVal == nil { - return fmt.Errorf("header mismatch: %s header present but body parameter %q is not a primitive type", fullHeader, paramName) - } - expected, ok := primitiveToString(bodyVal) - if !ok { - return fmt.Errorf("header mismatch: %s header present but body parameter %q is not a primitive type", fullHeader, paramName) + return fmt.Errorf("header mismatch: %s header present but body parameter %q is not a primitive type", fullHeader, strings.Join(b.Path, ".")) } - // TODO: String comparison may not work ideally for numbers - if decoded != expected { + if !primitiveEqual(decoded, bodyVal) { return fmt.Errorf("header mismatch: %s header value '%s' does not match body value", fullHeader, headerVal) } } return nil } +// primitiveEqual reports whether the (decoded) header string equals the +// JSON-derived body value. +func primitiveEqual(headerStr string, bodyVal any) bool { + if bodyInt, ok := bodyVal.(int64); ok { + headerNum, err := strconv.ParseFloat(headerStr, 64) + if err != nil { + return false + } + if math.IsNaN(headerNum) || math.IsInf(headerNum, 0) || headerNum != math.Trunc(headerNum) { + return false + } + if headerNum < minSafeInteger || headerNum > maxSafeInteger { + return false + } + return int64(headerNum) == bodyInt + } + expected, ok := primitiveToString(bodyVal) + if !ok { + return false + } + return headerStr == expected +} + // encodeHeaderValue converts a parameter value to an HTTP header-safe string // per the SEP-2243 encoding rules: // - string: used as-is if safe ASCII, otherwise Base64 encoded -// - number (float64): decimal string representation -// - bool: lowercase "true" or "false" +// - int64: decimal string representation +// - bool: lowercase "true" or "false" // // Values that contain non-ASCII characters, control characters, or // leading/trailing whitespace are Base64-encoded with the =?base64?...?= wrapper. @@ -383,14 +492,14 @@ func decodeHeaderValue(headerValue string) (string, bool) { return headerValue, true } - if strings.HasPrefix(strings.ToLower(headerValue), base64Prefix) && - strings.HasSuffix(headerValue, base64Suffix) { - encoded := headerValue[len(base64Prefix) : len(headerValue)-len(base64Suffix)] - decoded, err := base64.StdEncoding.DecodeString(encoded) - if err != nil { - return "", false + if encoded, ok := strings.CutPrefix(headerValue, base64Prefix); ok { + if encoded, ok = strings.CutSuffix(encoded, base64Suffix); ok { + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return "", false + } + return string(decoded), true } - return string(decoded), true } return headerValue, true } @@ -407,6 +516,11 @@ func requiresBase64Encoding(s string) bool { return true } } + // Per SEP-2243, plain-ASCII values that match the base64 sentinel pattern + // must also be base64-encoded to avoid ambiguity with already-encoded values. + if strings.HasPrefix(s, base64Prefix) && strings.HasSuffix(s, base64Suffix) { + return true + } return false } diff --git a/mcp/streamable_headers_test.go b/mcp/streamable_headers_test.go index 88e9f467..52d52447 100644 --- a/mcp/streamable_headers_test.go +++ b/mcp/streamable_headers_test.go @@ -519,6 +519,55 @@ func TestValidateToolParamHeaders(t *testing.T) { wantErr: true, wantErrSub: "invalid character", }, + { + name: "x-mcp-header with separator char (parens) is invalid per RFC 9110", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "X-(Region)", + }, + }, + }, + }, + wantErr: true, + wantErrSub: "invalid character", + }, + { + name: "x-mcp-header with equals sign is invalid per RFC 9110", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region=1", + }, + }, + }, + }, + wantErr: true, + wantErrSub: "invalid character", + }, + { + name: "x-mcp-header with all tchar specials is valid", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "!#$%&'*+-.^_`|~aZ0", + }, + }, + }, + }, + }, { name: "duplicate header names same case", tool: &Tool{ @@ -584,7 +633,7 @@ func TestValidateToolParamHeaders(t *testing.T) { wantErrSub: "primitive types", }, { - name: "x-mcp-header on number type is valid", + name: "x-mcp-header on number type is invalid", tool: &Tool{ Name: "test", InputSchema: map[string]any{ @@ -597,6 +646,8 @@ func TestValidateToolParamHeaders(t *testing.T) { }, }, }, + wantErr: true, + wantErrSub: "primitive types", }, { name: "x-mcp-header on integer type is valid", @@ -629,7 +680,7 @@ func TestValidateToolParamHeaders(t *testing.T) { }, }, { - name: "x-mcp-header on nested property inside object", + name: "x-mcp-header on nested property inside object is valid", tool: &Tool{ Name: "test", InputSchema: map[string]any{ @@ -647,11 +698,9 @@ func TestValidateToolParamHeaders(t *testing.T) { }, }, }, - wantErr: true, - wantErrSub: "nested", }, { - name: "x-mcp-header on deeply nested property", + name: "x-mcp-header on deeply nested property is valid", tool: &Tool{ Name: "test", InputSchema: map[string]any{ @@ -674,8 +723,26 @@ func TestValidateToolParamHeaders(t *testing.T) { }, }, }, + }, + { + name: "duplicate header names across nesting levels", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{"type": "string", "x-mcp-header": "Region"}, + "config": map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{"type": "string", "x-mcp-header": "Region"}, + }, + }, + }, + }, + }, wantErr: true, - wantErrSub: "nested", + wantErrSub: "duplicate", }, { name: "object property without nested x-mcp-header is valid", @@ -733,7 +800,7 @@ func TestValidateToolParamHeaders(t *testing.T) { wantErrSub: "primitive types", }, { - name: "jsonschema.Schema nested x-mcp-header", + name: "jsonschema.Schema nested x-mcp-header is valid", tool: &Tool{ Name: "test", InputSchema: &jsonschema.Schema{ @@ -751,8 +818,6 @@ func TestValidateToolParamHeaders(t *testing.T) { }, }, }, - wantErr: true, - wantErrSub: "nested", }, { name: "json.RawMessage valid x-mcp-header", @@ -770,6 +835,50 @@ func TestValidateToolParamHeaders(t *testing.T) { wantErr: true, wantErrSub: "primitive types", }, + { + name: "nested invalid header name is rejected", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "config": map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Bad Header", // contains a space + }, + }, + }, + }, + }, + }, + wantErr: true, + wantErrSub: "invalid character", + }, + { + name: "nested x-mcp-header on number type is rejected", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "config": map[string]any{ + "type": "object", + "properties": map[string]any{ + "count": map[string]any{ + "type": "number", + "x-mcp-header": "Count", + }, + }, + }, + }, + }, + }, + wantErr: true, + wantErrSub: "primitive types", + }, } for _, tt := range tests { @@ -812,15 +921,15 @@ func TestFilterValidTools(t *testing.T) { Name: "plain", InputSchema: map[string]any{"type": "object", "properties": map[string]any{"q": map[string]any{"type": "string"}}}, } - nestedInvalid := &Tool{ - Name: "nested-invalid", + nestedValid := &Tool{ + Name: "nested-valid", InputSchema: map[string]any{ "type": "object", "properties": map[string]any{ "config": map[string]any{ "type": "object", "properties": map[string]any{ - "region": map[string]any{"type": "string", "x-mcp-header": "Region"}, + "tenant": map[string]any{"type": "string", "x-mcp-header": "TenantId"}, }, }, }, @@ -852,12 +961,13 @@ func TestFilterValidTools(t *testing.T) { }, } - result := filterValidTools(nil, []*Tool{valid, invalid, noAnnotation, nestedInvalid, validJsonSchema, invalidJsonSchema}) - if len(result) != 3 { - t.Fatalf("filterValidTools returned %d tools, want 3", len(result)) + result := filterValidTools(nil, []*Tool{valid, invalid, noAnnotation, nestedValid, validJsonSchema, invalidJsonSchema}) + if len(result) != 4 { + t.Fatalf("filterValidTools returned %d tools, want 4", len(result)) } - if result[0].Name != "valid" || result[1].Name != "plain" || result[2].Name != "valid-jsonschema" { - t.Errorf("filterValidTools returned [%s, %s, %s], want [valid, plain, valid-jsonschema]", result[0].Name, result[1].Name, result[2].Name) + if result[0].Name != "valid" || result[1].Name != "plain" || result[2].Name != "nested-valid" || result[3].Name != "valid-jsonschema" { + t.Errorf("filterValidTools returned [%s, %s, %s, %s], want [valid, plain, nested-valid, valid-jsonschema]", + result[0].Name, result[1].Name, result[2].Name, result[3].Name) } } @@ -947,13 +1057,13 @@ func TestSetStandardHeadersWithParamHeaders(t *testing.T) { }, }, { - name: "handles number argument", + name: "handles integer argument", tool: &Tool{ Name: "test", InputSchema: map[string]any{ "type": "object", "properties": map[string]any{ - "count": map[string]any{"type": "number", "x-mcp-header": "Count"}, + "count": map[string]any{"type": "integer", "x-mcp-header": "Count"}, }, }, }, @@ -965,6 +1075,23 @@ func TestSetStandardHeadersWithParamHeaders(t *testing.T) { "Mcp-Param-Count": "42", }, }, + { + name: "out-of-range integer argument produces no header", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "count": map[string]any{"type": "integer", "x-mcp-header": "Count"}, + }, + }, + }, + params: &CallToolParams{ + Name: "test", + Arguments: map[string]any{"count": float64(maxSafeInteger) + 2}, + }, + wantHeaders: nil, + }, { name: "no tool in extra does not add param headers", tool: nil, @@ -974,6 +1101,37 @@ func TestSetStandardHeadersWithParamHeaders(t *testing.T) { }, wantHeaders: nil, }, + { + name: "nested arguments resolve via dotted path", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "config": map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{"type": "string", "x-mcp-header": "Region"}, + "tenant": map[string]any{"type": "string", "x-mcp-header": "TenantId"}, + }, + }, + }, + }, + }, + params: &CallToolParams{ + Name: "test", + Arguments: map[string]any{ + "config": map[string]any{ + "region": "us-west1", + "tenant": "acme", + }, + }, + }, + wantHeaders: map[string]string{ + "Mcp-Param-Region": "us-west1", + "Mcp-Param-TenantId": "acme", + }, + }, } for _, tt := range tests { @@ -1090,6 +1248,48 @@ func TestExtractToolParamHeaders(t *testing.T) { }, want: nil, }, + { + name: "nested x-mcp-header annotations produce path-slice bindings", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{"type": "string", "x-mcp-header": "Region"}, + "config": map[string]any{ + "type": "object", + "properties": map[string]any{ + "tenant": map[string]any{"type": "string", "x-mcp-header": "TenantId"}, + "deep": map[string]any{ + "type": "object", + "properties": map[string]any{ + "flag": map[string]any{"type": "boolean", "x-mcp-header": "DeepFlag"}, + }, + }, + }, + }, + }, + }, + }, + want: map[string]string{ + "region": "Region", + "config.tenant": "TenantId", + "config.deep.flag": "DeepFlag", + }, + }, + { + name: "property name containing a dot is preserved in path", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "user.id": map[string]any{"type": "string", "x-mcp-header": "UserId"}, + }, + }, + }, + want: map[string]string{"user.id": "UserId"}, + }, } for _, tt := range tests { @@ -1104,9 +1304,15 @@ func TestExtractToolParamHeaders(t *testing.T) { if len(got) != len(tt.want) { t.Fatalf("extractToolParamHeaders() returned %d entries, want %d", len(got), len(tt.want)) } + // Index returned bindings by joined-path for comparison; the + // expected map uses dotted-path keys for readability. + gotMap := make(map[string]string, len(got)) + for _, b := range got { + gotMap[strings.Join(b.Path, ".")] = b.Header + } for k, v := range tt.want { - if got[k] != v { - t.Errorf("extractToolParamHeaders()[%q] = %q, want %q", k, got[k], v) + if gotMap[k] != v { + t.Errorf("extractToolParamHeaders()[%q] = %q, want %q", k, gotMap[k], v) } } }) @@ -1120,10 +1326,24 @@ func TestUnmarshalPrimitive(t *testing.T) { want any }{ {"string", `"hello"`, "hello"}, - {"number", `42`, float64(42)}, - {"float", `3.14`, float64(3.14)}, {"true", `true`, true}, {"false", `false`, false}, + + // Integer JSON numbers are promoted to int64. + {"integer", `42`, int64(42)}, + {"integer zero", `0`, int64(0)}, + {"integer negative", `-7`, int64(-7)}, + {"integer max safe", `9007199254740991`, int64(maxSafeInteger)}, + {"integer min safe", `-9007199254740991`, int64(minSafeInteger)}, + // JSON serialization of an integer-valued float (e.g. "42.0") is + // still a valid integer at the value level and must be accepted. + {"integer-valued float", `42.0`, int64(42)}, + + // Non-integer numbers, out-of-range integers, and disallowed JSON + // kinds are rejected (return nil). + {"float with fraction", `3.14`, nil}, + {"integer above max safe", `9007199254740993`, nil}, + {"integer below min safe", `-9007199254740993`, nil}, {"null", `null`, nil}, {"array", `[1,2]`, nil}, {"object", `{"a":1}`, nil}, @@ -1158,15 +1378,24 @@ func TestEncodeHeaderValue(t *testing.T) { {"string with carriage return", "line1\r\nline2", "=?base64?bGluZTENCmxpbmUy?=", true}, {"string with leading tab", "\tindented", "=?base64?CWluZGVudGVk?=", true}, - // Numbers - {"integer", float64(42), "42", true}, - {"float", float64(3.14159), "3.14159", true}, + // Sentinel pattern collisions: plain-ASCII values that match the base64 + // sentinel pattern must also be base64-encoded to avoid ambiguity. + {"sentinel collision literal", "=?base64?literal?=", "=?base64?PT9iYXNlNjQ/bGl0ZXJhbD89?=", true}, + {"sentinel collision empty", "=?base64??=", "=?base64?PT9iYXNlNjQ/Pz0=?=", true}, + // Uppercase sentinel does NOT collide (case-sensitive markers). + {"uppercase pseudo-sentinel passes through", "=?BASE64?abc?=", "=?BASE64?abc?=", true}, + + {"integer", int64(42), "42", true}, + {"integer zero", int64(0), "0", true}, + {"integer negative", int64(-7), "-7", true}, + {"integer max safe", int64(maxSafeInteger), "9007199254740991", true}, + {"integer min safe", int64(minSafeInteger), "-9007199254740991", true}, // Booleans {"true", true, "true", true}, {"false", false, "false", true}, - // Unsupported types + {"raw float64 rejected", float64(42), "", false}, {"nil", nil, "", false}, {"slice", []string{"a"}, "", false}, } @@ -1196,7 +1425,10 @@ func TestDecodeHeaderValue(t *testing.T) { {"valid base64", "=?base64?SGVsbG8=?=", "Hello", true}, {"non-ASCII decoded", "=?base64?5pel5pys6Kqe?=", "日本語", true}, {"leading space decoded", "=?base64?IHVzLXdlc3Qx?=", " us-west1", true}, - {"case-insensitive prefix", "=?BASE64?SGVsbG8=?=", "Hello", true}, + // Per SEP-2243, the base64 sentinel markers are case-sensitive: an + // uppercase prefix is treated as a literal value, not a base64 marker. + {"uppercase prefix is literal", "=?BASE64?SGVsbG8=?=", "=?BASE64?SGVsbG8=?=", true}, + {"mixed case prefix is literal", "=?Base64?SGVsbG8=?=", "=?Base64?SGVsbG8=?=", true}, {"invalid base64 chars", "=?base64?SGVs!!!bG8=?=", "", false}, // Missing prefix or suffix: treated as literal values, not base64 {"missing prefix", "SGVsbG8=", "SGVsbG8=", true}, @@ -1225,6 +1457,7 @@ func TestEncodeDecodeRoundTrip(t *testing.T) { "Hello, 世界", "line1\nline2", "\ttab", + "=?base64?literal?=", // sentinel-pattern collision (SEP-2243) } for _, v := range values { encoded, ok := encodeHeaderValue(v) @@ -1240,3 +1473,96 @@ func TestEncodeDecodeRoundTrip(t *testing.T) { } } } + +// TestValidateParamHeaders_IntegerComparison verifies server-side validation +// of integer x-mcp-header parameters per SEP-2243. +func TestValidateParamHeaders_IntegerComparison(t *testing.T) { + tool := &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "count": map[string]any{"type": "integer", "x-mcp-header": "Count"}, + }, + }, + } + tests := []struct { + name string + headerVal string + bodyArg any + wantErr bool + }{ + // Canonical decimal form matches. + {"integer matches integer", "42", float64(42), false}, + {"integer header matches integer-valued float body", "42", float64(42.0), false}, + {"negative integer matches", "-7", float64(-7), false}, + {"large safe integer matches", "1000000000000", float64(1e12), false}, + + {"non-canonical '42.0' header matches integer body", "42.0", float64(42), false}, + {"scientific notation header matches integer body", "1e2", float64(100), false}, + {"negative non-canonical header matches", "-7.0", float64(-7), false}, + + // Genuine mismatches and invalid forms are still rejected. + {"different integers do not match", "42", float64(43), true}, + {"fractional header against integer body", "3.14", float64(3), true}, + {"non-numeric header fails", "not-a-number", float64(42), true}, + {"header outside safe range against integer body", "9007199254740993", float64(42), true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + header := http.Header{} + header.Set(paramHeaderPrefix+"Count", tt.headerVal) + args := map[string]any{"count": tt.bodyArg} + msg := &jsonrpc.Request{ + Method: "tools/call", + Params: mustMarshal(&CallToolParams{Name: "test", Arguments: args}), + } + err := validateParamHeaders(header, msg, tool) + if tt.wantErr && err == nil { + t.Errorf("validateParamHeaders() = nil, want error") + } + if !tt.wantErr && err != nil { + t.Errorf("validateParamHeaders() = %v, want nil", err) + } + }) + } +} + +// TestValidateParamHeaders_NestedArguments verifies that the server-side +// validation can look up nested arguments via the dotted path produced by +// extractParamHeaderAnnotations. +func TestValidateParamHeaders_NestedArguments(t *testing.T) { + tool := &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "config": map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{"type": "string", "x-mcp-header": "Region"}, + }, + }, + }, + }, + } + header := http.Header{} + header.Set(paramHeaderPrefix+"Region", "us-west1") + args := map[string]any{ + "config": map[string]any{"region": "us-west1"}, + } + msg := &jsonrpc.Request{ + Method: "tools/call", + Params: mustMarshal(&CallToolParams{Name: "test", Arguments: args}), + } + if err := validateParamHeaders(header, msg, tool); err != nil { + t.Errorf("validateParamHeaders() = %v, want nil", err) + } + + // Mismatched nested value should fail. + args["config"].(map[string]any)["region"] = "eu-west1" + msg.Params = mustMarshal(&CallToolParams{Name: "test", Arguments: args}) + if err := validateParamHeaders(header, msg, tool); err == nil { + t.Error("validateParamHeaders() = nil, want error for mismatched nested value") + } +}