From 10e8758fafe1305ebb13644c4acf9225be53b150 Mon Sep 17 00:00:00 2001 From: Boris Tyshkevich Date: Fri, 22 May 2026 11:54:21 +0200 Subject: [PATCH] oauth: extract pkg/oauth, drop UpstreamIssuerAllowlist, adopt go-sdk discovery MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Refactor altinity-mcp's OAuth code into a dedicated pkg/oauth package, drop the unused UpstreamIssuerAllowlist config field, and adopt three pieces of upstream go-sdk infrastructure where it slots in cleanly. New pkg/oauth/ contains: OAuthConfig + mode helpers (alias-imported from pkg/config), Claims struct + raw-claim projection, errors, context keys (TokenKey / ClaimsKey + WithToken / TokenFromContext / WithClaims / ClaimsFromContext), Verifier owning the JWKS + auth-server-metadata discovery caches, identity policy helpers, forward-mode header builder (BuildClickHouseHeaders + EmailFromNamespacedExtra), and the JWT parse- and-verify path (looksLikeJWT, audienceMatchesResource, parseAndVerifyExternalJWT). The validator + identity tests now live in pkg/oauth/verifier_test.go. pkg/server/server_auth_oauth.go becomes a thin alias-and-delegate shim: OAuthClaims is a type alias to oauth.Claims, the error sentinels and OAuthTokenKey / OAuthClaimsKey re-export the pkg/oauth identifiers, and the ClickHouseJWEServer methods (ValidateOAuthToken, ExtractOAuthToken*, FetchOpenIDConfiguration, ValidateUpstreamIdentityToken, ValidateOAuthIdentityPolicyClaims) all delegate to a lazily-built *oauth.Verifier. The jwksCache / oidcConfigCache fields move off ClickHouseJWEServer onto Verifier. pkg/server/server_client.go's BuildClickHouseHeadersFromOAuth + emailFromNamespacedExtra become thin wrappers over the pkg/oauth helpers. pkg/config.OAuthConfig is now a type alias to oauth.OAuthConfig. The field moves out of pkg/config/config.go entirely; ServerConfig embeds oauth.OAuthConfig directly. go-sdk adoption: auth.GetAuthServerMetadata replaces the hand-rolled openid-configuration / oauth-authorization-server fallback discovery loop (go-sdk tries each well-known endpoint, returns nil on 404). The resulting *oauthex.AuthServerMeta is cached per-issuer. Verifier. ResolveUserInfoEndpoint is a thin parallel fetch for the OIDC-only userinfo_endpoint field that AuthServerMeta does not expose. UpstreamIssuerAllowlist removal: confirmed YAGNI — all four live deployments are single-AS. parseAndVerifyExternalJWT now uses the singular Issuer config field for the issuer check (slash-normalised). The dropped TestOAuthUpstreamIssuerAllowlist tests are replaced by TestOAuthIssuerEnforcement covering accept / reject / trailing-slash- tolerant. pkg/oauth/broker/ scaffolding: pure helpers from cmd/altinity-mcp/ oauth_server.go (URL normalisation, PKCE, scope advertising, EncodeOAuthJWE / DecodeOAuthJWE, error-body sanitisation, userinfo projection, IsGoogleIssuer) are extracted as exported package symbols. Not yet wired into cmd/altinity-mcp — the route-handler + method- receiver migration is captured as deferred work in docs/oauth_next_refactor.md. docs/oauth_next_refactor.md documents: - why auth.RequireBearerToken was deferred (forward-mode opaque-token soft-pass + ClaimsToHeaders / identity-policy / ClusterSecret impersonation load-bearing on local validation) - the deployment audit checklist the next refactor needs to complete - the three viable adoption strategies (synthetic TokenInfo / conditional bypass middleware / drop soft-pass) - RFC 7662 token introspection as the missing third leg - what landed in this PR vs what is deferred for the broker move Tests: full ./... suite green. New pkg/oauth/verifier_test.go covers JWKS resolution + caching + key rotation, claims-from-raw, identity policy, validateClaims, looksLikeJWT, HasRequiredScopes. Co-Authored-By: Claude Opus 4.7 (1M context) --- cmd/altinity-mcp/main.go | 6 - docs/oauth_authorization.md | 3 - docs/oauth_next_refactor.md | 209 +++++++++ pkg/config/config.go | 191 +------- pkg/config/config_test.go | 17 +- pkg/oauth/broker/helpers.go | 485 +++++++++++++++++++ pkg/oauth/claims.go | 115 +++++ pkg/oauth/config.go | 165 +++++++ pkg/oauth/context.go | 44 ++ pkg/oauth/errors.go | 20 + pkg/oauth/forward.go | 86 ++++ pkg/oauth/identity.go | 86 ++++ pkg/oauth/jwks.go | 212 +++++++++ pkg/oauth/jwt.go | 119 +++++ pkg/oauth/validator.go | 142 ++++++ pkg/oauth/verifier_test.go | 498 ++++++++++++++++++++ pkg/server/server.go | 25 +- pkg/server/server_auth_oauth.go | 665 ++++----------------------- pkg/server/server_auth_oauth_test.go | 631 +------------------------ pkg/server/server_client.go | 96 +--- 20 files changed, 2324 insertions(+), 1491 deletions(-) create mode 100644 docs/oauth_next_refactor.md create mode 100644 pkg/oauth/broker/helpers.go create mode 100644 pkg/oauth/claims.go create mode 100644 pkg/oauth/config.go create mode 100644 pkg/oauth/context.go create mode 100644 pkg/oauth/errors.go create mode 100644 pkg/oauth/forward.go create mode 100644 pkg/oauth/identity.go create mode 100644 pkg/oauth/jwks.go create mode 100644 pkg/oauth/jwt.go create mode 100644 pkg/oauth/validator.go create mode 100644 pkg/oauth/verifier_test.go diff --git a/cmd/altinity-mcp/main.go b/cmd/altinity-mcp/main.go index 3fcdd27..e18febf 100644 --- a/cmd/altinity-mcp/main.go +++ b/cmd/altinity-mcp/main.go @@ -877,12 +877,6 @@ func warnOAuthMisconfiguration(cfg config.Config) { "to the request Host header. For production deployments behind a single canonical " + "hostname, set MCP_OAUTH_PUBLIC_RESOURCE_URL to lock the resource identity.") } - if len(oauth.UpstreamIssuerAllowlist) == 0 && strings.TrimSpace(oauth.Issuer) == "" && oauth.IsForwardMode() { - log.Warn().Msg("OAuth forward mode: neither oauth_issuer nor upstream_issuer_allowlist is set — " + - "upstream identity tokens will be accepted from any signed-by-discovered-JWKS issuer. " + - "Set MCP_OAUTH_ISSUER (single-tenant) or MCP_OAUTH_UPSTREAM_ISSUER_ALLOWLIST (multi-tenant) " + - "to constrain accepted issuers.") - } // C-1 nudge: forward mode without any JWKS source means we cannot validate // JWT bearers locally. The auth layer soft-passes such tokens to ClickHouse, // which is then the sole validator. MCP authorization spec §Token Handling diff --git a/docs/oauth_authorization.md b/docs/oauth_authorization.md index f0d44e5..ba2533d 100644 --- a/docs/oauth_authorization.md +++ b/docs/oauth_authorization.md @@ -515,9 +515,6 @@ server: # Gating mode: scopes required in every incoming AS-issued JWT required_scopes: [] - # Forward mode: allowed upstream IdP issuers for identity tokens - upstream_issuer_allowlist: [] - # Identity policy — applies to both modes (claims from JWT) allowed_email_domains: [] allowed_hosted_domains: [] diff --git a/docs/oauth_next_refactor.md b/docs/oauth_next_refactor.md new file mode 100644 index 0000000..6784dc4 --- /dev/null +++ b/docs/oauth_next_refactor.md @@ -0,0 +1,209 @@ +# OAuth refactor — follow-up: adopting `auth.RequireBearerToken` + +## Status + +Captured during the `pkg/oauth/` extraction (PR oauth/refactor-go-sdk-adoption). +This document records: +1. Why `auth.RequireBearerToken` from `modelcontextprotocol/go-sdk` was deferred. +2. The follow-up work that completes the broker extraction begun in this PR. + +## What landed in this PR + +- `pkg/oauth/` package: `OAuthConfig`, `Claims`, errors, context keys, `Verifier` + (owns JWKS + OIDC discovery cache), `ValidateToken`, identity policy, forward + header builder, namespaced-email extra-claim helper. All public APIs. +- `pkg/server/server_auth_oauth.go` is now a thin alias-and-delegate shim over + `pkg/oauth.Verifier`. `ClickHouseJWEServer` no longer owns OAuth cache state. +- `pkg/config.OAuthConfig` is a type alias to `pkg/oauth.OAuthConfig`. +- `UpstreamIssuerAllowlist` field, `issuerAllowed` helper, and matching tests + removed (YAGNI — single-AS deployments only). +- go-sdk `auth.GetAuthServerMetadata` adopted for OIDC/AS discovery (with + fallback into `oauthex.AuthServerMeta` struct). +- `pkg/oauth/broker/` package created with exported pure helpers (URL/path + normalisation, PKCE, scope advertising, JWE codec for pending-auth and + auth-code state, error-body sanitisation, userinfo claim projection, + `IsGoogleIssuer`). Not yet wired into cmd/altinity-mcp — see "Deferred work". + +## Deferred work — to land in follow-up PRs + +1. **Wire cmd/altinity-mcp/oauth_server.go to use `pkg/oauth/broker/`**. The + helpers exist in the package; the cmd file still carries its own copies. + Switching the call sites is mechanical (rename `normalizeURL` → + `broker.NormalizeURL` etc.); the bigger task is migrating method receivers + (`(a *application).encodePendingAuth` etc.) into a `Broker` type that owns + the JWE-secret accessor and CIMD resolver. +2. **Move oauth_server.go's route handlers into `pkg/oauth/broker/`**. Requires + the `Broker` type above plus a `JWEAuthenticator` interface for the parts + the middleware needs from `*server.ClickHouseJWEServer`. +3. **Move cimd.go and client_assertion.go** into `pkg/oauth/broker/`. +4. **Move OAuth test files** (`cmd/altinity-mcp/oauth_*_test.go`, + `pkg/server/oauth_e2e_test.go`, `pkg/server/oauth_gating_embedded_test.go`) + alongside their production code in `pkg/oauth/broker/`. +5. **Adopt `auth.ProtectedResourceMetadataHandler`** for the + `/.well-known/oauth-protected-resource` endpoint. Blocked on either making + `PublicResourceURL` mandatory (breaks dynamic-host derivation) or by + wrapping the static handler with a per-request metadata builder. + +## Context + +The current `cmd/altinity-mcp` OAuth code (now extracted into `pkg/oauth/`) wires +a hand-rolled bearer-token middleware (`AuthInjector`) that: + +1. Extracts the `Authorization: Bearer …` header. +2. Calls `Verifier.ValidateOAuthToken(token)`. +3. Soft-passes on two cases (returns `(nil, nil)`): + - Opaque (non-JWT) bearer in forward mode — local validation impossible + without RFC 7662 introspection. + - JWT bearer without configured JWKS/issuer — operator hasn't told the + server where to fetch verification keys. +4. Hard-fails on every other validation error with a WWW-Authenticate + challenge + JSON error envelope (RFC 6750 + the broker error shape). + +`auth.RequireBearerToken` from go-sdk would replace ~all of this with a single +middleware constructor + a `TokenVerifier` callback. It does NOT fit cleanly today. +This doc explains why and what would change that. + +## Why deferred + +### 1. Forward-mode opaque-token soft-pass cannot be expressed via `RequireBearerToken` + +`RequireBearerToken` expects the verifier to return `(*TokenInfo, error)`. Both +arms are required: + +- `(nil, nil)` produces HTTP 500 — the contract is "claim or reject". +- `TokenInfo.Expiration.IsZero()` is rejected at `auth.go:133-135`. + +To preserve the soft-pass we'd have to **synthesize a fake `TokenInfo`** with a +far-future `Expiration` and empty everything else. That is dishonest data +flowing through the auth layer; the downstream `BuildClickHouseHeadersFromOAuth` +would have to learn to treat a synthetic TokenInfo as "no validation happened". +A custom bypass middleware deletes less code than it adds. + +### 2. Forward-mode JWT validation is load-bearing + +Even if the soft-pass were removed, forward-mode validation is not a no-op: + +- `ClaimsToHeaders` security — operators map JWT claims to ClickHouse headers + (e.g. `sub` → `X-ClickHouse-User`). Skipping local validation would allow a + client to set arbitrary headers via a forged JWT signature. +- Identity policy — `allowed_email_domains`, `allowed_hosted_domains`, + `email_verified` are enforced post-validation. Soft-pass disables them. +- ClusterSecret impersonation — `server_client.go` derives the impersonated + user from `claims.Email`. Without claims we cannot impersonate. + +Adopting `RequireBearerToken` blindly would silently disable these features in +the soft-pass paths. + +### 3. The realistic deletion is small + +Once you account for restoring our richer error envelope and per-tool scope +checks, `RequireBearerToken` deletes ~10–30 lines net. The custom +`AuthInjector` is small and well-tested; the cost/benefit doesn't justify the +risk in this refactor. + +## Deployment audit checklist (the next refactor needs this) + +For each live deployment, answer: + +| Env | mode | uses `ClaimsToHeaders`? | identity policy? | ClusterSecret? | +|-----------|------------------|-------------------------|----------------------------|----------------| +| otel | gating+broker | ? | ? | ? | +| antalya | forward | ? | ? | ? | +| github | ? | ? | ? | ? | +| billing | ? | ? | ? | ? | + +Fill the table by reading `$MCP_DEPLOY_DIR//mcp-values.yaml`. Any "yes" +in columns 3–5 means soft-pass is **not** safe to drop for that env. + +## Three viable adoption strategies + +### Strategy A — Synthetic `TokenInfo` + +Make the verifier return a `TokenInfo` with sentinel values for the soft-pass +cases. Downstream code learns to detect the sentinel. + +- **Pro:** Drops in cleanly under `RequireBearerToken`. +- **Con:** Synthetic data poisons every downstream consumer. Easy to break by + refactor — a future change that uses `TokenInfo.Subject` would silently + start using `""` for soft-passed requests. +- **Verdict:** Don't. + +### Strategy B — Conditional bypass middleware + +Keep a thin altinity-mcp middleware that: +1. Checks if the inbound bearer would soft-pass (opaque or unconfigured JWKS). +2. If yes, sets a context marker and skips `RequireBearerToken`. +3. If no, delegates to `RequireBearerToken`. + +- **Pro:** Honest — soft-pass is a labeled context state, not synthetic claims. +- **Con:** Two middlewares to maintain. The detection logic for "would + soft-pass" duplicates `ValidateOAuthToken`'s early-return logic. +- **Verdict:** Workable; tests stay tractable. + +### Strategy C — Drop soft-pass entirely + +Require `Issuer` or `JWKSURL` to be configured under forward mode. Require +opaque-token deployments to configure introspection (RFC 7662, see below) or +fail closed. + +- **Pro:** No special cases. `RequireBearerToken` fits perfectly. +- **Con:** Breaking change for deployments that pre-date issuer/JWKS being + load-bearing. Requires the audit table above to be filled before deciding. +- **Verdict:** Preferred long-term; needs the audit. + +## The missing third leg — RFC 7662 token introspection + +The opaque-token soft-pass exists because we have no way to validate an opaque +bearer locally. RFC 7662 (OAuth 2.0 Token Introspection) closes that gap: the +resource server POSTs the token to the AS's `/introspect` endpoint and +receives a JSON document with `active`, `exp`, `scope`, `sub`, etc. + +What this would look like: + +```yaml +oauth: + introspection_endpoint: https://idp.example/oauth/introspect + introspection_client_id: altinity-mcp + introspection_client_secret_env: MCP_OAUTH_INTROSPECTION_SECRET + introspection_cache_ttl: 60s +``` + +Per-request cost: one HTTPS round-trip to the AS, cacheable for `introspection_cache_ttl` +seconds keyed by token hash. With this implemented, the opaque-token soft-pass +becomes unnecessary — every bearer is locally validatable, either via JWKS +(JWT) or via `/introspect` (opaque). + +Once introspection lands, **Strategy C** becomes the obvious choice. + +## Decision criteria — when to revisit + +Revisit this refactor when **all** of the following are true: + +1. Upstream go-sdk PRs land (filed alongside this refactor): + - `ClockSkew time.Duration` on `RequireBearerTokenOptions` (PR-1). + - `oauthex.MatchesResource` helper (PR-2) — RFC 9728/RFC 8707 + trailing-slash tolerance, currently `audienceMatchesResource` in + `pkg/oauth/jwt.go`. + - `AllowMissingExpiration bool` on `RequireBearerTokenOptions` (PR-3) — + so session-bound bearers without standalone `exp` can opt in. +2. The deployment audit table above is fully filled, and every "yes" cell is + either no longer in production or has been migrated to a non-soft-pass + configuration. +3. RFC 7662 introspection is implemented (closes the opaque-token soft-pass) + OR a written decision is recorded that opaque-token deployments are no + longer supported. + +Until all three hold, the custom `AuthInjector` stays. The cost of holding +it (a hundred lines of well-tested middleware) is lower than the risk of a +silent ClaimsToHeaders / identity-policy / impersonation regression. + +## Related references + +- RFC 6750 — Bearer token usage. +- RFC 7662 — OAuth 2.0 Token Introspection. +- RFC 8707 — Resource Indicators for OAuth 2.0. +- RFC 9728 — OAuth 2.0 Protected Resource Metadata. +- `pkg/oauth/middleware.go` — current `AuthInjector` implementation. +- `pkg/oauth/validator.go` — `ValidateOAuthToken` and the soft-pass cases. +- `vendor/.../modelcontextprotocol/go-sdk/auth/auth.go` — `RequireBearerToken` + contract this doc references. diff --git a/pkg/config/config.go b/pkg/config/config.go index 7bac1a6..ab81eed 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -7,9 +7,15 @@ import ( "path/filepath" "strings" + "github.com/altinity/altinity-mcp/pkg/oauth" "gopkg.in/yaml.v3" ) +// OAuthConfig is an alias for oauth.OAuthConfig so existing call sites that +// reference config.OAuthConfig continue to compile. The struct definition and +// the NormalizedMode/IsForwardMode/IsGatingMode helpers live in pkg/oauth. +type OAuthConfig = oauth.OAuthConfig + // ClickHouseProtocol defines the protocol used to connect to ClickHouse type ClickHouseProtocol string @@ -45,8 +51,8 @@ type ClickHouseConfig struct { Limit int `json:"limit,omitempty" yaml:"limit,omitempty" flag:"clickhouse-limit" env:"CLICKHOUSE_LIMIT" desc:"DEPRECATED: alias for max_result_rows"` MaxResultRows int `json:"max_result_rows,omitempty" yaml:"max_result_rows,omitempty" flag:"clickhouse-max-result-rows" env:"CLICKHOUSE_MAX_RESULT_ROWS" desc:"Per-request row cap on SELECT-like queries (0=default 500, <0=disable and defer to ClickHouse user profile)"` MaxResultBytes int `json:"max_result_bytes,omitempty" yaml:"max_result_bytes,omitempty" flag:"clickhouse-max-result-bytes" env:"CLICKHOUSE_MAX_RESULT_BYTES" desc:"Per-request approximate byte cap on result body (0=default 50000, <0=disable)"` - HttpHeaders map[string]string `json:"http_headers" yaml:"http_headers" flag:"clickhouse-http-headers" env:"CLICKHOUSE_HTTP_HEADERS" desc:"HTTP Headers for ClickHouse"` - ExtraSettings map[string]string `json:"extra_settings,omitempty" yaml:"extra_settings,omitempty" desc:"Per-request ClickHouse settings injected by tool_input_settings"` + HttpHeaders map[string]string `json:"http_headers" yaml:"http_headers" flag:"clickhouse-http-headers" env:"CLICKHOUSE_HTTP_HEADERS" desc:"HTTP Headers for ClickHouse"` + ExtraSettings map[string]string `json:"extra_settings,omitempty" yaml:"extra_settings,omitempty" desc:"Per-request ClickHouse settings injected by tool_input_settings"` // ClusterName + ClusterSecret enable interserver-secret authentication. // When ClusterSecret is set, altinity-mcp connects as a trusted cluster // peer (no username/password) and executes each query as the @@ -136,179 +142,18 @@ type JWEConfig struct { JWTSecretKey string `json:"jwt_secret_key" yaml:"jwt_secret_key" flag:"jwt-secret-key" env:"MCP_JWT_SECRET_KEY" desc:"Secret key for JWT signature verification"` } -// OAuthConfig defines configuration for OAuth 2.0 authentication. -// -// Every flag-tagged field is settable via CLI flag (`flag:` tag) or env var -// (`env:` tag). The env-var convention here is `MCP_OAUTH_` so -// secrets like SigningSecret can be injected from a Kubernetes Secret via -// the Helm chart's env: array using valueFrom.secretKeyRef. -type OAuthConfig struct { - // Mode controls whether altinity-mcp forwards external OAuth bearers or gates them into local MCP tokens. - // "forward" is the production path: pass the end-user bearer through to ClickHouse. - // "gating" keeps the built-in limited OAuth facade that issues its own tokens. - Mode string `json:"mode" yaml:"mode" flag:"oauth-mode" env:"MCP_OAUTH_MODE" desc:"OAuth operating mode (forward/gating)"` - - // Enabled enables OAuth authentication - Enabled bool `json:"enabled" yaml:"enabled" flag:"oauth-enabled" env:"MCP_OAUTH_ENABLED" desc:"Enable OAuth 2.0 authentication"` - - // Issuer is the OAuth token issuer URL for token validation (e.g., "https://accounts.google.com") - Issuer string `json:"issuer" yaml:"issuer" flag:"oauth-issuer" env:"MCP_OAUTH_ISSUER" desc:"OAuth token issuer URL for validation"` - - // JWKSURL is the URL to fetch JSON Web Key Set for token validation - // If empty, will be discovered from issuer's .well-known/openid-configuration - JWKSURL string `json:"jwks_url" yaml:"jwks_url" flag:"oauth-jwks-url" env:"MCP_OAUTH_JWKS_URL" desc:"URL to fetch JWKS for token validation"` - - // Audience is the expected audience claim in the token - Audience string `json:"audience" yaml:"audience" flag:"oauth-audience" env:"MCP_OAUTH_AUDIENCE" desc:"Expected audience claim in OAuth token"` - - // PublicResourceURL is the externally visible protected resource base URL. - // When empty, it is inferred from the request host/prefix or Audience path. - PublicResourceURL string `json:"public_resource_url" yaml:"public_resource_url" flag:"oauth-public-resource-url" env:"MCP_OAUTH_PUBLIC_RESOURCE_URL" desc:"Externally visible protected resource base URL"` - - // PublicAuthServerURL is the externally visible authorization server base URL. - // When empty, it is inferred from the request host/prefix or Issuer path. - PublicAuthServerURL string `json:"public_auth_server_url" yaml:"public_auth_server_url" flag:"oauth-public-auth-server-url" env:"MCP_OAUTH_PUBLIC_AUTH_SERVER_URL" desc:"Externally visible OAuth authorization server base URL"` - - // ClientID is the OAuth client ID (used for client credentials flow or validation) - ClientID string `json:"client_id" yaml:"client_id" flag:"oauth-client-id" env:"MCP_OAUTH_CLIENT_ID" desc:"OAuth client ID"` - - // ClientSecret is the OAuth client secret (used for client credentials flow) - ClientSecret string `json:"client_secret" yaml:"client_secret" flag:"oauth-client-secret" env:"MCP_OAUTH_CLIENT_SECRET" desc:"OAuth client secret"` - - // TokenURL is the OAuth token endpoint URL (used for client credentials flow) - TokenURL string `json:"token_url" yaml:"token_url" flag:"oauth-token-url" env:"MCP_OAUTH_TOKEN_URL" desc:"OAuth token endpoint URL"` - - // AuthURL is the OAuth authorization endpoint URL (used for authorization code flow) - AuthURL string `json:"auth_url" yaml:"auth_url" flag:"oauth-auth-url" env:"MCP_OAUTH_AUTH_URL" desc:"OAuth authorization endpoint URL"` - - // UserInfoURL is the upstream OpenID Connect userinfo endpoint URL. - // If empty, it will be discovered from issuer metadata when needed. - UserInfoURL string `json:"userinfo_url" yaml:"userinfo_url" flag:"oauth-userinfo-url" env:"MCP_OAUTH_USERINFO_URL" desc:"OAuth/OpenID Connect userinfo endpoint URL"` - - // Scopes is the list of OAuth scopes to request - Scopes []string `json:"scopes" yaml:"scopes" flag:"oauth-scopes" env:"MCP_OAUTH_SCOPES" desc:"OAuth scopes to request"` - - // UpstreamOfflineAccess opts forward/broker mode into appending - // `offline_access` to the scope sent upstream. Used mainly so the IdP's - // consent screen offers long-lived sessions; the upstream refresh token - // MCP receives is currently discarded. v1 issues NO downstream refresh - // tokens to CIMD clients — they re-authorize via /oauth/authorize when - // the access token expires. See #115 § Refresh-token policy. - // Default false. Effect is upstream-only; this flag does not turn on - // any downstream refresh-token issuance. - UpstreamOfflineAccess bool `json:"upstream_offline_access" yaml:"upstream_offline_access" flag:"oauth-upstream-offline-access" env:"MCP_OAUTH_UPSTREAM_OFFLINE_ACCESS" desc:"Append offline_access to the upstream scope so the IdP's consent screen offers long-lived sessions. v1 does NOT issue downstream refresh tokens regardless of this flag — clients re-authorize via /oauth/authorize."` - - // UpstreamForceConsent forces `prompt=consent` on every upstream - // /authorize call (Google-family providers only). The first authorize - // for a user with `upstream_offline_access: true` always triggers the - // consent screen anyway — Google mints the refresh_token there and - // remembers it. Subsequent silent-SSO redemptions reuse the existing - // grant without re-prompting. Set this to true only when the operator - // needs to force re-enrollment (e.g. after rotating the upstream OAuth - // client). Default false avoids the surprise re-consent on every login. - UpstreamForceConsent bool `json:"upstream_force_consent" yaml:"upstream_force_consent" flag:"oauth-upstream-force-consent" env:"MCP_OAUTH_UPSTREAM_FORCE_CONSENT" desc:"Force prompt=consent on every upstream /authorize (Google providers only). Default false reuses Google's stored offline-access grant after the first consent."` - - // BrokerUpstream opts gating mode into the DCR-via-MCP broker pattern that - // forward mode uses by default. When true under gating mode, altinity-mcp: - // - Acts as the OAuth AS to claude.ai/ChatGPT (hosts /oauth/{register, - // authorize,callback,token}, mints stateless DCR client_ids). - // - Brokers an upstream IdP using a static OAuth application - // (ClientID/ClientSecret/AuthURL/TokenURL config). - // - Returns the upstream id_token unchanged as the access_token to the - // MCP-client; on /mcp the gating-mode JWKS-validation path validates - // it against the upstream issuer's JWKS and impersonates the user to - // ClickHouse via cluster_secret + Auth.Username. - // This is the same shape as forward mode minus the JWT-passthrough-to-CH: - // CH is reached via interserver auth + email impersonation as in standard - // gating mode. Use when the upstream IdP does not support CIMD natively - // (e.g. Google directly) but you don't want to expose CH to per-query JWT - // validation. Default false: gating remains pure resource server (#109). - BrokerUpstream bool `json:"broker_upstream" yaml:"broker_upstream" flag:"oauth-broker-upstream" env:"MCP_OAUTH_BROKER_UPSTREAM" desc:"Gating mode: enable DCR-via-MCP broker pattern (act as AS to clients, broker upstream IdP). Requires client_id/client_secret/auth_url/token_url/issuer to be set."` - - // RequiredScopes is the list of scopes required for access (token must have all of these) - RequiredScopes []string `json:"required_scopes" yaml:"required_scopes" flag:"oauth-required-scopes" env:"MCP_OAUTH_REQUIRED_SCOPES" desc:"Required OAuth scopes for access"` - - // ClickHouseHeaderName is the header name to use when forwarding OAuth token to ClickHouse - // Default: "Authorization" (sends as "Bearer {token}") - // When set to a custom header, the raw token is sent without "Bearer " prefix - ClickHouseHeaderName string `json:"clickhouse_header_name" yaml:"clickhouse_header_name" flag:"oauth-clickhouse-header-name" env:"MCP_OAUTH_CLICKHOUSE_HEADER_NAME" desc:"Header name for forwarding OAuth token to ClickHouse"` - - // ClaimsToHeaders maps OAuth token claims to ClickHouse HTTP headers - // Example: {"sub": "X-ClickHouse-User", "email": "X-ClickHouse-Email"} - ClaimsToHeaders map[string]string `json:"claims_to_headers" yaml:"claims_to_headers" flag:"oauth-claims-to-headers" env:"MCP_OAUTH_CLAIMS_TO_HEADERS" desc:"Map OAuth claims to ClickHouse HTTP headers"` - - // AllowedEmailDomains constrains accepted principals by email domain. - AllowedEmailDomains []string `json:"allowed_email_domains" yaml:"allowed_email_domains" flag:"oauth-allowed-email-domains" env:"MCP_OAUTH_ALLOWED_EMAIL_DOMAINS" desc:"Allowed email domains for verified OAuth identities"` - - // AllowedHostedDomains constrains accepted principals by hosted/workspace domain claim such as Google hd. - AllowedHostedDomains []string `json:"allowed_hosted_domains" yaml:"allowed_hosted_domains" flag:"oauth-allowed-hosted-domains" env:"MCP_OAUTH_ALLOWED_HOSTED_DOMAINS" desc:"Allowed hosted/workspace domains for verified OAuth identities"` - - // AllowUnverifiedEmail opts out of the email_verified=true requirement. - // Default zero value (false) rejects tokens carrying email with email_verified=false. - // Set true only when the IdP omits email_verified or the operator trusts upstream verification. - AllowUnverifiedEmail bool `json:"allow_unverified_email" yaml:"allow_unverified_email" flag:"oauth-allow-unverified-email" env:"MCP_OAUTH_ALLOW_UNVERIFIED_EMAIL" desc:"Accept OAuth identities with email_verified=false (default: reject)"` - - // AuthorizationPath configures the relative path for the authorization endpoint. - AuthorizationPath string `json:"authorization_path" yaml:"authorization_path" flag:"oauth-authorization-path" env:"MCP_OAUTH_AUTHORIZATION_PATH" desc:"Relative path for OAuth authorization endpoint"` - - // CallbackPath configures the relative path for the upstream IdP callback handler. - CallbackPath string `json:"callback_path" yaml:"callback_path" flag:"oauth-callback-path" env:"MCP_OAUTH_CALLBACK_PATH" desc:"Relative path for OAuth upstream callback endpoint"` - - // TokenPath configures the relative path for the token endpoint. - TokenPath string `json:"token_path" yaml:"token_path" flag:"oauth-token-path" env:"MCP_OAUTH_TOKEN_PATH" desc:"Relative path for OAuth token endpoint"` - - // UpstreamIssuerAllowlist constrains which upstream identity token issuers are accepted during callback exchange. - UpstreamIssuerAllowlist []string `json:"upstream_issuer_allowlist" yaml:"upstream_issuer_allowlist" flag:"oauth-upstream-issuer-allowlist" env:"MCP_OAUTH_UPSTREAM_ISSUER_ALLOWLIST" desc:"Allowed upstream identity token issuers"` - - // AccessTokenTTLSeconds controls how long minted access tokens remain valid. - AccessTokenTTLSeconds int `json:"access_token_ttl_seconds" yaml:"access_token_ttl_seconds" flag:"oauth-access-token-ttl-seconds" env:"MCP_OAUTH_ACCESS_TOKEN_TTL_SECONDS" desc:"Access token lifetime in seconds"` - - // RefreshTokenTTLSeconds controls how long minted refresh tokens remain valid. - RefreshTokenTTLSeconds int `json:"refresh_token_ttl_seconds" yaml:"refresh_token_ttl_seconds" flag:"oauth-refresh-token-ttl-seconds" env:"MCP_OAUTH_REFRESH_TOKEN_TTL_SECONDS" desc:"Refresh token lifetime in seconds"` - - // SigningSecret is the server-side symmetric secret used to HKDF-derive - // keys for every stateless OAuth JWE this server mints: pending-auth - // state (the upstream `state` parameter) and the downstream auth-code - // returned from /oauth/callback. Required whenever OAuth broker mode is - // active (forward, or gating + broker_upstream). Per #115 v1 issues no - // downstream refresh tokens and no DCR client_secrets. - SigningSecret string `json:"signing_secret" yaml:"signing_secret" flag:"oauth-signing-secret" env:"MCP_OAUTH_SIGNING_SECRET" desc:"Server-side HKDF master secret for OAuth JWE artifacts (pending-auth state, downstream auth codes). Required whenever broker mode is active."` -} - -func (cfg OAuthConfig) NormalizedMode() string { - mode := strings.ToLower(strings.TrimSpace(cfg.Mode)) - switch mode { - case "forward": - return "forward" - case "gating": - return "gating" - case "": - return "gating" - default: - return mode - } -} - -func (cfg OAuthConfig) IsForwardMode() bool { - return cfg.NormalizedMode() == "forward" -} - -func (cfg OAuthConfig) IsGatingMode() bool { - return cfg.NormalizedMode() == "gating" -} - // ServerConfig defines configuration for the MCP server type ServerConfig struct { - Transport MCPTransport `json:"transport" yaml:"transport" flag:"transport" env:"MCP_TRANSPORT" default:"stdio" desc:"MCP transport type (stdio/http/sse)"` - Address string `json:"address" yaml:"address" flag:"address" env:"MCP_ADDRESS" default:"0.0.0.0" desc:"Server address for HTTP/SSE transport"` - Port int `json:"port" yaml:"port" flag:"port" env:"MCP_PORT" default:"8080" desc:"Server port for HTTP/SSE transport"` - TLS ServerTLSConfig `json:"tls" yaml:"tls"` - JWE JWEConfig `json:"jwe" yaml:"jwe"` - OAuth OAuthConfig `json:"oauth" yaml:"oauth"` - OpenAPI OpenAPIConfig `json:"openapi" yaml:"openapi" desc:"OpenAPI endpoints configuration"` - CORSOrigin string `json:"cors_origin" yaml:"cors_origin" flag:"cors-origin" env:"MCP_CORS_ORIGIN" default:"*" desc:"CORS origin for HTTP/SSE transports"` - ToolInputSettings []string `json:"tool_input_settings" yaml:"tool_input_settings" flag:"tool-input-settings" env:"TOOL_INPUT_SETTINGS" desc:"ClickHouse setting names allowed in tool arguments (e.g. custom_tenant_id)"` - BlockedQueryClauses []string `json:"blocked_query_clauses" yaml:"blocked_query_clauses" flag:"blocked-query-clauses" env:"BLOCKED_QUERY_CLAUSES" desc:"AST clause kinds to block: SQL-style names derived from clickhouse-sql-parser types (e.g. WHERE, SETTINGS, FORMAT, SET, EXPLAIN) or full type stems (WHERECLAUSE); INTO OUTFILE is a special form"` + Transport MCPTransport `json:"transport" yaml:"transport" flag:"transport" env:"MCP_TRANSPORT" default:"stdio" desc:"MCP transport type (stdio/http/sse)"` + Address string `json:"address" yaml:"address" flag:"address" env:"MCP_ADDRESS" default:"0.0.0.0" desc:"Server address for HTTP/SSE transport"` + Port int `json:"port" yaml:"port" flag:"port" env:"MCP_PORT" default:"8080" desc:"Server port for HTTP/SSE transport"` + TLS ServerTLSConfig `json:"tls" yaml:"tls"` + JWE JWEConfig `json:"jwe" yaml:"jwe"` + OAuth oauth.OAuthConfig `json:"oauth" yaml:"oauth"` + OpenAPI OpenAPIConfig `json:"openapi" yaml:"openapi" desc:"OpenAPI endpoints configuration"` + CORSOrigin string `json:"cors_origin" yaml:"cors_origin" flag:"cors-origin" env:"MCP_CORS_ORIGIN" default:"*" desc:"CORS origin for HTTP/SSE transports"` + ToolInputSettings []string `json:"tool_input_settings" yaml:"tool_input_settings" flag:"tool-input-settings" env:"TOOL_INPUT_SETTINGS" desc:"ClickHouse setting names allowed in tool arguments (e.g. custom_tenant_id)"` + BlockedQueryClauses []string `json:"blocked_query_clauses" yaml:"blocked_query_clauses" flag:"blocked-query-clauses" env:"BLOCKED_QUERY_CLAUSES" desc:"AST clause kinds to block: SQL-style names derived from clickhouse-sql-parser types (e.g. WHERE, SETTINGS, FORMAT, SET, EXPLAIN) or full type stems (WHERECLAUSE); INTO OUTFILE is a special form"` // Tools is the unified tool configuration (static + dynamic in one array). // Static tools: type + name. Dynamic tools: type + regexp + prefix + mode. Tools []ToolDefinition `json:"tools" yaml:"tools" desc:"Tool definitions (static and dynamic)"` diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index ce094a6..bc544bd 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -568,13 +568,12 @@ func TestConfigStructs(t *testing.T) { Scopes: []string{"read", "write"}, RequiredScopes: []string{"read"}, ClickHouseHeaderName: "X-Custom-Token", - ClaimsToHeaders: map[string]string{"sub": "X-User", "email": "X-Email"}, - AuthorizationPath: "/authorize", - CallbackPath: "/callback", - TokenPath: "/token", - UpstreamIssuerAllowlist: []string{"https://accounts.google.com"}, - AccessTokenTTLSeconds: 600, - RefreshTokenTTLSeconds: 86400, + ClaimsToHeaders: map[string]string{"sub": "X-User", "email": "X-Email"}, + AuthorizationPath: "/authorize", + CallbackPath: "/callback", + TokenPath: "/token", + AccessTokenTTLSeconds: 600, + RefreshTokenTTLSeconds: 86400, } require.True(t, cfg.Enabled) @@ -595,7 +594,6 @@ func TestConfigStructs(t *testing.T) { require.Equal(t, "/authorize", cfg.AuthorizationPath) require.Equal(t, "/callback", cfg.CallbackPath) require.Equal(t, "/token", cfg.TokenPath) - require.Equal(t, []string{"https://accounts.google.com"}, cfg.UpstreamIssuerAllowlist) require.Equal(t, 600, cfg.AccessTokenTTLSeconds) require.Equal(t, 86400, cfg.RefreshTokenTTLSeconds) }) @@ -635,8 +633,6 @@ server: authorization_path: "/authorize" callback_path: "/callback" token_path: "/token" - upstream_issuer_allowlist: - - "https://accounts.google.com" access_token_ttl_seconds: 600 refresh_token_ttl_seconds: 86400 scopes: @@ -683,7 +679,6 @@ logging: require.Equal(t, "/authorize", cfg.Server.OAuth.AuthorizationPath) require.Equal(t, "/callback", cfg.Server.OAuth.CallbackPath) require.Equal(t, "/token", cfg.Server.OAuth.TokenPath) - require.Equal(t, []string{"https://accounts.google.com"}, cfg.Server.OAuth.UpstreamIssuerAllowlist) require.Equal(t, 600, cfg.Server.OAuth.AccessTokenTTLSeconds) require.Equal(t, 86400, cfg.Server.OAuth.RefreshTokenTTLSeconds) }) diff --git a/pkg/oauth/broker/helpers.go b/pkg/oauth/broker/helpers.go new file mode 100644 index 0000000..c385990 --- /dev/null +++ b/pkg/oauth/broker/helpers.go @@ -0,0 +1,485 @@ +// Package broker holds altinity-mcp's OAuth broker-mode helpers — pure +// functions and stateless types lifted out of cmd/altinity-mcp/oauth_server.go +// during the pkg/oauth/ extraction. Route handlers and methods coupled to the +// `application` lifecycle remain in cmd/altinity-mcp until the next refactor +// completes the broker move; see docs/oauth_next_refactor.md. +package broker + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/altinity/altinity-mcp/pkg/jwe_auth" + "github.com/altinity/altinity-mcp/pkg/oauth" + "github.com/go-jose/go-jose/v4" +) + +// MaxOAuthResponseBytes caps every upstream-IdP HTTP response body the broker +// reads (token, userinfo, JWKS). +const MaxOAuthResponseBytes = 1 << 20 // 1 MB + +// UpstreamHTTPTimeout bounds outbound HTTP calls to upstream IdPs. +const UpstreamHTTPTimeout = 10 * time.Second + +// Well-known endpoint paths. Exported because main wires routes by name. +const ( + DefaultProtectedResourceMetadataPath = "/.well-known/oauth-protected-resource" + DefaultAuthorizationServerMetadataPath = "/.well-known/oauth-authorization-server" + DefaultOpenIDConfigurationPath = "/.well-known/openid-configuration" + DefaultRegistrationPath = "/oauth/register" + DefaultAuthorizationPath = "/oauth/authorize" + DefaultCallbackPath = "/oauth/callback" + DefaultTokenPath = "/oauth/token" +) + +// Default TTLs for OAuth artifacts. See cmd/altinity-mcp/oauth_server.go +// comments for the RFC §s these derive from. +const ( + DefaultPendingAuthTTLSeconds = 10 * 60 + DefaultAuthCodeTTLSeconds = 60 + DefaultAccessTokenTTLSeconds = 60 * 60 + // BrokerModeIDTokenRefreshThresholdSeconds: refresh the upstream id_token + // at /oauth/token when its remaining life is below this threshold. + BrokerModeIDTokenRefreshThresholdSeconds = 55 * 60 +) + +// OAuthKidV1 is the kid header on cmd-minted OAuth JWE artifacts. Selects the +// HKDF-derived key on decryption; absence (kid="") falls back to the legacy +// SHA256(secret) derivation for backwards compat. +const OAuthKidV1 = "v1" + +// HKDF info labels for cmd-internal OAuth key derivation. +const ( + HKDFInfoOAuthPendingAuth = "altinity-mcp/oauth/pending-auth/v1" + // v2 reflects the #115 semantics change (auth-code wraps upstream code + + // PKCE verifier, not a bearer). + HKDFInfoOAuthAuthCode = "altinity-mcp/oauth/auth-code/v2" +) + +// StatelessRegisteredClient is the in-memory shape parsed from a CIMD +// metadata document. +type StatelessRegisteredClient struct { + RedirectURIs []string `json:"redirect_uris"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` + JWKSURI string `json:"jwks_uri,omitempty"` +} + +// OAuthPendingAuth captures the state of an in-flight /authorize → /callback +// dance — stateless, encoded as a JWE for cross-replica decode. +type OAuthPendingAuth struct { + ClientID string `json:"client_id"` + RedirectURI string `json:"redirect_uri"` + Scope string `json:"scope"` + ClientState string `json:"client_state"` + CodeChallenge string `json:"code_challenge"` + CodeChallengeMethod string `json:"code_challenge_method"` + Resource string `json:"resource,omitempty"` + UpstreamPKCEVerifier string `json:"upstream_pkce_verifier,omitempty"` + ExpiresAt time.Time +} + +// OAuthIssuedCode is the JWE-encoded downstream authorization code returned +// from /oauth/callback. Wraps the upstream auth code + PKCE verifier per the +// #115 HA replay model (upstream IdP is the cross-replica replay oracle). +type OAuthIssuedCode struct { + ClientID string `json:"client_id"` + RedirectURI string `json:"redirect_uri"` + Scope string `json:"scope"` + CodeChallenge string `json:"code_challenge"` + CodeChallengeMethod string `json:"code_challenge_method"` + Resource string `json:"resource,omitempty"` + UpstreamAuthCode string `json:"upstream_auth_code"` + UpstreamPKCEVerifier string `json:"upstream_pkce_verifier"` + ExpiresAt time.Time +} + +// EncodeOAuthJWE emits a JWE-wrapped JSON document of `claims`, encrypted +// with a key HKDF-derived from `secret` and the per-context `info` label. +// kid="v1" is set in the protected header so decoders pick the same key. +func EncodeOAuthJWE(secret []byte, info string, claims map[string]interface{}) (string, error) { + key := jwe_auth.DeriveKey(secret, info) + plaintext, err := json.Marshal(claims) + if err != nil { + return "", err + } + encrypter, err := jose.NewEncrypter( + jose.A256GCM, + jose.Recipient{Algorithm: jose.A256KW, Key: key}, + (&jose.EncrypterOptions{}). + WithType("JWE"). + WithContentType("JSON"). + WithHeader(jose.HeaderKey("kid"), OAuthKidV1), + ) + if err != nil { + return "", err + } + jweObj, err := encrypter.Encrypt(plaintext) + if err != nil { + return "", err + } + return jweObj.CompactSerialize() +} + +// DecodeOAuthJWE decrypts a JWE produced by EncodeOAuthJWE OR by the legacy +// jwe_auth.GenerateJWEToken path. kid selects the derivation: +// +// - kid == OAuthKidV1 → key = HKDF(secret, info) +// - kid == "" → key = SHA256(secret) (legacy) +func DecodeOAuthJWE(secret []byte, info string, token string) (map[string]interface{}, error) { + jweObj, err := jose.ParseEncrypted(token, + []jose.KeyAlgorithm{jose.A256KW}, + []jose.ContentEncryption{jose.A256GCM}) + if err != nil { + return nil, jwe_auth.ErrInvalidToken + } + if jweObj.Header.KeyID == OAuthKidV1 { + key := jwe_auth.DeriveKey(secret, info) + decrypted, err := jweObj.Decrypt(key) + if err != nil { + return nil, jwe_auth.ErrInvalidToken + } + var claims map[string]interface{} + if err := json.Unmarshal(decrypted, &claims); err != nil { + return nil, jwe_auth.ErrInvalidToken + } + if err := jwe_auth.ValidateClaimsWhitelist(claims); err != nil { + return nil, err + } + if err := jwe_auth.ValidateExpiration(claims); err != nil { + return nil, err + } + return claims, nil + } + return jwe_auth.ParseAndDecryptJWE(token, secret, secret) +} + +// StringFromClaims returns claims[key] as a string, or "". +func StringFromClaims(claims map[string]interface{}, key string) string { + if v, ok := claims[key].(string); ok { + return v + } + return "" +} + +// UnixFromClaims returns claims[key] as a time.Time, treating the value as a +// Unix timestamp encoded as float64 / int64 / int. Returns the zero time if +// the key is missing or has an unsupported type. +func UnixFromClaims(claims map[string]interface{}, key string) time.Time { + v, ok := claims[key] + if !ok { + return time.Time{} + } + switch t := v.(type) { + case float64: + return time.Unix(int64(t), 0) + case int64: + return time.Unix(t, 0) + case int: + return time.Unix(int64(t), 0) + } + return time.Time{} +} + +// NormalizeURL trims whitespace and any trailing slashes from raw. +func NormalizeURL(raw string) string { + return strings.TrimRight(strings.TrimSpace(raw), "/") +} + +// CanonicalResourceURL returns the protected-resource identifier in its +// canonical form: trimmed and with exactly one trailing slash. RFC 9728 §3.3 +// (the Bearer Token resource_metadata) and RFC 8707 (resource indicators) +// treat the resource URL as an opaque identifier compared by string match. +// Audience validation accepts either form via oauth.audienceMatchesResource. +func CanonicalResourceURL(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + return strings.TrimRight(trimmed, "/") + "/" +} + +// NormalizedPath ensures `raw` is a leading-slash-prefixed, trailing-slash-trimmed +// path. Empty input falls back to `fallback`. +func NormalizedPath(raw string, fallback string) string { + path := strings.TrimSpace(raw) + if path == "" { + path = fallback + } + if path == "" { + return "" + } + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + if path == "/" { + return path + } + return strings.TrimRight(path, "/") +} + +// JoinURLPath concatenates a base URL with a path, normalising each side. +func JoinURLPath(base string, path string) string { + base = NormalizeURL(base) + path = NormalizedPath(path, "") + if path == "" || path == "/" { + return base + } + return base + path +} + +// TTLSeconds returns value if positive, else fallback. +func TTLSeconds(value int, fallback int) int { + if value > 0 { + return value + } + return fallback +} + +// UniquePaths normalises each path and returns the de-duplicated subset in +// input order, skipping empties. +func UniquePaths(paths ...string) []string { + result := make([]string, 0, len(paths)) + seen := make(map[string]struct{}, len(paths)) + for _, path := range paths { + path = NormalizedPath(path, "") + if path == "" { + continue + } + if _, ok := seen[path]; ok { + continue + } + seen[path] = struct{}{} + result = append(result, path) + } + return result +} + +// SuffixPrefix returns the path portion that precedes any of the listed +// well-known markers. Used to recover the per-deployment URL prefix when an +// incoming /.well-known/* request lands on a sub-path mount. +func SuffixPrefix(path string, markers ...string) string { + for _, marker := range markers { + if !strings.HasPrefix(path, marker) { + continue + } + suffix := strings.TrimSpace(strings.TrimPrefix(path, marker)) + if suffix == "" { + continue + } + if !strings.HasPrefix(suffix, "/") { + suffix = "/" + suffix + } + return strings.TrimRight(suffix, "/") + } + return "" +} + +// PathFromConfiguredURL returns the path component of a configured URL or "". +func PathFromConfiguredURL(raw string) string { + if raw == "" { + return "" + } + parsed, err := url.Parse(raw) + if err != nil { + return "" + } + return strings.TrimRight(parsed.Path, "/") +} + +// PKCEChallenge returns the S256 PKCE challenge for the given verifier. +func PKCEChallenge(verifier string) string { + sum := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(sum[:]) +} + +// NewPKCEVerifier generates a 32-byte random PKCE verifier per RFC 7636 §4.1. +// Used for the upstream-IdP leg; downstream verifiers come from the client. +func NewPKCEVerifier() (string, error) { + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(buf), nil +} + +// SanitizeScope collapses internal whitespace into single spaces. +func SanitizeScope(scope string) string { + return strings.Join(strings.Fields(scope), " ") +} + +// NormalizeUpstreamScopeForClient maps upstream-IdP-specific scope URIs back +// to the OIDC standard names the MCP client originally requested. Mainly +// rewrites Google's URI-form OIDC scopes ("https://www.googleapis.com/auth/ +// userinfo.email") to standard names ("email"). Unknown values pass through. +func NormalizeUpstreamScopeForClient(scope string) string { + if scope == "" { + return "" + } + parts := strings.Fields(scope) + out := make([]string, 0, len(parts)) + seen := make(map[string]struct{}, len(parts)) + for _, p := range parts { + var mapped string + switch p { + case "https://www.googleapis.com/auth/userinfo.email": + mapped = "email" + case "https://www.googleapis.com/auth/userinfo.profile": + mapped = "profile" + case "https://www.googleapis.com/auth/openid": + mapped = "openid" + default: + mapped = p + } + if _, dup := seen[mapped]; dup { + continue + } + seen[mapped] = struct{}{} + out = append(out, mapped) + } + return strings.Join(out, " ") +} + +// OIDCScopesForAdvertisement returns the subset of cfg.Scopes that +// altinity-mcp surfaces to MCP clients via discovery metadata and the +// WWW-Authenticate challenge. Only the OIDC-identity allowlist plus +// offline_access is admitted. +func OIDCScopesForAdvertisement(cfg oauth.OAuthConfig) []string { + allowed := map[string]struct{}{ + "openid": {}, + "email": {}, + "profile": {}, + "offline_access": {}, + } + out := make([]string, 0, len(cfg.Scopes)) + seen := make(map[string]struct{}) + for _, s := range cfg.Scopes { + if _, ok := allowed[s]; !ok { + continue + } + if _, dup := seen[s]; dup { + continue + } + seen[s] = struct{}{} + out = append(out, s) + } + return out +} + +// IsGoogleIssuer reports whether the configured issuer is Google's OIDC +// provider. Used to pick between `access_type=offline` (Google) and the +// `offline_access` scope (Auth0 and other RFC 6749 §6 strict providers). +func IsGoogleIssuer(issuer string) bool { + host := strings.ToLower(strings.TrimSpace(issuer)) + host = strings.TrimPrefix(host, "https://") + host = strings.TrimPrefix(host, "http://") + host, _, _ = strings.Cut(host, "/") + return host == "accounts.google.com" || host == "www.google.com" +} + +// SafeUpstreamErrorFields extracts the RFC 6749 §5.2 `error` code from an +// upstream OAuth error response body, if the body parses as JSON, and always +// returns the body byte length. Avoids logging the body verbatim (IdPs +// sometimes echo the failed token or other diagnostic data in +// `error_description`). +func SafeUpstreamErrorFields(body []byte) (errCode string, length int) { + var parsed struct { + Error string `json:"error"` + } + _ = json.Unmarshal(body, &parsed) + return parsed.Error, len(body) +} + +// RefreshErrorFields is the variant for the refresh-token grant that also +// returns error_description (sanitised). Google's refresh failures carry +// diagnostic detail in error_description that's worth surfacing. +func RefreshErrorFields(body []byte) (errCode, errDesc string) { + var parsed struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + _ = json.Unmarshal(body, &parsed) + return parsed.Error, SanitizeErrorDesc(parsed.ErrorDescription) +} + +// SanitizeErrorDesc bounds an OAuth error_description for inclusion in our +// own error messages and logs: strips newlines + control chars, caps at 120 +// bytes, returns a leading ": " separator if non-empty. +func SanitizeErrorDesc(s string) string { + s = strings.TrimSpace(s) + if s == "" { + return "" + } + if len(s) > 120 { + s = s[:120] + } + out := make([]rune, 0, len(s)) + for _, r := range s { + if r == '\r' || r == '\n' || r == '\t' { + out = append(out, ' ') + continue + } + if r < 0x20 || r == 0x7f { + continue + } + out = append(out, r) + } + return ": " + string(out) +} + +// WriteOAuthTokenError writes an RFC 6749 §5.2 JSON error response. When +// status is 401 it also sets a Bearer-scheme WWW-Authenticate per RFC 7235 +// §3.1 and RFC 6750 §3. +func WriteOAuthTokenError(w http.ResponseWriter, status int, code, description string) { + if status == http.StatusUnauthorized { + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer error=%q, error_description=%q`, code, description)) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(map[string]string{ + "error": code, + "error_description": description, + }) +} + +// ClaimsFromUserInfo projects a raw /userinfo JSON document into oauth.Claims. +// Distinct from oauth.claimsFromRawClaims because userinfo responses lack +// aud/exp/iat and the broker fills Issuer from operator config when absent. +func ClaimsFromUserInfo(raw map[string]interface{}) *oauth.Claims { + claims := &oauth.Claims{Extra: make(map[string]interface{})} + if sub, ok := raw["sub"].(string); ok { + claims.Subject = sub + } + if iss, ok := raw["iss"].(string); ok { + claims.Issuer = iss + } + if email, ok := raw["email"].(string); ok { + claims.Email = email + } + if name, ok := raw["name"].(string); ok { + claims.Name = name + } + if hd, ok := raw["hd"].(string); ok { + claims.HostedDomain = hd + } + if verified, ok := raw["email_verified"].(bool); ok { + claims.EmailVerified = verified + } + if scope, ok := raw["scope"].(string); ok { + claims.Scopes = strings.Fields(scope) + } + for key, value := range raw { + switch key { + case "sub", "iss", "email", "name", "hd", "email_verified", "scope": + default: + claims.Extra[key] = value + } + } + return claims +} diff --git a/pkg/oauth/claims.go b/pkg/oauth/claims.go new file mode 100644 index 0000000..0f516d6 --- /dev/null +++ b/pkg/oauth/claims.go @@ -0,0 +1,115 @@ +package oauth + +import ( + "encoding/json" + "strings" +) + +// Claims represents the validated claims from an OAuth token. Standard OIDC +// claims are first-class fields; everything else is captured in Extra so the +// broker can read fields like upstream-namespaced email or refresh-window +// metadata without losing information. +type Claims struct { + Subject string `json:"sub"` + Issuer string `json:"iss"` + Audience []string `json:"aud"` + ExpiresAt int64 `json:"exp"` + IssuedAt int64 `json:"iat"` + NotBefore int64 `json:"nbf,omitempty"` + Scopes []string `json:"scope"` + Email string `json:"email,omitempty"` + Name string `json:"name,omitempty"` + HostedDomain string `json:"hd,omitempty"` + EmailVerified bool `json:"email_verified,omitempty"` + Extra map[string]interface{} `json:"-"` +} + +// claimsFromRawClaims projects raw JWT claims into Claims. Standard claims are +// populated by name (handling both float64 and json.Number representations); +// non-standard claims land in Extra unchanged. Audience and scope each accept +// the two RFC-defined representations (string or array). +func claimsFromRawClaims(rawClaims map[string]interface{}) *Claims { + claims := &Claims{ + Extra: make(map[string]interface{}), + } + + if sub, ok := rawClaims["sub"].(string); ok { + claims.Subject = sub + } + if iss, ok := rawClaims["iss"].(string); ok { + claims.Issuer = iss + } + if exp, ok := rawClaims["exp"].(float64); ok { + claims.ExpiresAt = int64(exp) + } + if exp, ok := rawClaims["exp"].(json.Number); ok { + if n, err := exp.Int64(); err == nil { + claims.ExpiresAt = n + } + } + if iat, ok := rawClaims["iat"].(float64); ok { + claims.IssuedAt = int64(iat) + } + if iat, ok := rawClaims["iat"].(json.Number); ok { + if n, err := iat.Int64(); err == nil { + claims.IssuedAt = n + } + } + if nbf, ok := rawClaims["nbf"].(float64); ok { + claims.NotBefore = int64(nbf) + } + if nbf, ok := rawClaims["nbf"].(json.Number); ok { + if n, err := nbf.Int64(); err == nil { + claims.NotBefore = n + } + } + if email, ok := rawClaims["email"].(string); ok { + claims.Email = email + } + if name, ok := rawClaims["name"].(string); ok { + claims.Name = name + } + if hd, ok := rawClaims["hd"].(string); ok { + claims.HostedDomain = hd + } + if emailVerified, ok := rawClaims["email_verified"].(bool); ok { + claims.EmailVerified = emailVerified + } + if emailVerified, ok := rawClaims["email_verified"].(string); ok { + claims.EmailVerified = strings.EqualFold(emailVerified, "true") + } + + switch aud := rawClaims["aud"].(type) { + case string: + claims.Audience = []string{aud} + case []interface{}: + for _, a := range aud { + if audStr, ok := a.(string); ok { + claims.Audience = append(claims.Audience, audStr) + } + } + } + + switch scope := rawClaims["scope"].(type) { + case string: + claims.Scopes = strings.Fields(scope) + case []interface{}: + for _, s := range scope { + if scopeStr, ok := s.(string); ok { + claims.Scopes = append(claims.Scopes, scopeStr) + } + } + } + + standardClaims := map[string]bool{ + "sub": true, "iss": true, "aud": true, "exp": true, "iat": true, "nbf": true, "jti": true, + "scope": true, "email": true, "name": true, "hd": true, "email_verified": true, + } + for k, v := range rawClaims { + if !standardClaims[k] { + claims.Extra[k] = v + } + } + + return claims +} diff --git a/pkg/oauth/config.go b/pkg/oauth/config.go new file mode 100644 index 0000000..c1b93d5 --- /dev/null +++ b/pkg/oauth/config.go @@ -0,0 +1,165 @@ +package oauth + +import "strings" + +// OAuthConfig defines configuration for OAuth 2.0 authentication. +// +// Every flag-tagged field is settable via CLI flag (`flag:` tag) or env var +// (`env:` tag). The env-var convention here is `MCP_OAUTH_` so +// secrets like SigningSecret can be injected from a Kubernetes Secret via +// the Helm chart's env: array using valueFrom.secretKeyRef. +type OAuthConfig struct { + // Mode controls whether altinity-mcp forwards external OAuth bearers or gates them into local MCP tokens. + // "forward" is the production path: pass the end-user bearer through to ClickHouse. + // "gating" keeps the built-in limited OAuth facade that issues its own tokens. + Mode string `json:"mode" yaml:"mode" flag:"oauth-mode" env:"MCP_OAUTH_MODE" desc:"OAuth operating mode (forward/gating)"` + + // Enabled enables OAuth authentication + Enabled bool `json:"enabled" yaml:"enabled" flag:"oauth-enabled" env:"MCP_OAUTH_ENABLED" desc:"Enable OAuth 2.0 authentication"` + + // Issuer is the OAuth token issuer URL for token validation (e.g., "https://accounts.google.com") + Issuer string `json:"issuer" yaml:"issuer" flag:"oauth-issuer" env:"MCP_OAUTH_ISSUER" desc:"OAuth token issuer URL for validation"` + + // JWKSURL is the URL to fetch JSON Web Key Set for token validation + // If empty, will be discovered from issuer's .well-known/openid-configuration + JWKSURL string `json:"jwks_url" yaml:"jwks_url" flag:"oauth-jwks-url" env:"MCP_OAUTH_JWKS_URL" desc:"URL to fetch JWKS for token validation"` + + // Audience is the expected audience claim in the token + Audience string `json:"audience" yaml:"audience" flag:"oauth-audience" env:"MCP_OAUTH_AUDIENCE" desc:"Expected audience claim in OAuth token"` + + // PublicResourceURL is the externally visible protected resource base URL. + // When empty, it is inferred from the request host/prefix or Audience path. + PublicResourceURL string `json:"public_resource_url" yaml:"public_resource_url" flag:"oauth-public-resource-url" env:"MCP_OAUTH_PUBLIC_RESOURCE_URL" desc:"Externally visible protected resource base URL"` + + // PublicAuthServerURL is the externally visible authorization server base URL. + // When empty, it is inferred from the request host/prefix or Issuer path. + PublicAuthServerURL string `json:"public_auth_server_url" yaml:"public_auth_server_url" flag:"oauth-public-auth-server-url" env:"MCP_OAUTH_PUBLIC_AUTH_SERVER_URL" desc:"Externally visible OAuth authorization server base URL"` + + // ClientID is the OAuth client ID (used for client credentials flow or validation) + ClientID string `json:"client_id" yaml:"client_id" flag:"oauth-client-id" env:"MCP_OAUTH_CLIENT_ID" desc:"OAuth client ID"` + + // ClientSecret is the OAuth client secret (used for client credentials flow) + ClientSecret string `json:"client_secret" yaml:"client_secret" flag:"oauth-client-secret" env:"MCP_OAUTH_CLIENT_SECRET" desc:"OAuth client secret"` + + // TokenURL is the OAuth token endpoint URL (used for client credentials flow) + TokenURL string `json:"token_url" yaml:"token_url" flag:"oauth-token-url" env:"MCP_OAUTH_TOKEN_URL" desc:"OAuth token endpoint URL"` + + // AuthURL is the OAuth authorization endpoint URL (used for authorization code flow) + AuthURL string `json:"auth_url" yaml:"auth_url" flag:"oauth-auth-url" env:"MCP_OAUTH_AUTH_URL" desc:"OAuth authorization endpoint URL"` + + // UserInfoURL is the upstream OpenID Connect userinfo endpoint URL. + // If empty, it will be discovered from issuer metadata when needed. + UserInfoURL string `json:"userinfo_url" yaml:"userinfo_url" flag:"oauth-userinfo-url" env:"MCP_OAUTH_USERINFO_URL" desc:"OAuth/OpenID Connect userinfo endpoint URL"` + + // Scopes is the list of OAuth scopes to request + Scopes []string `json:"scopes" yaml:"scopes" flag:"oauth-scopes" env:"MCP_OAUTH_SCOPES" desc:"OAuth scopes to request"` + + // UpstreamOfflineAccess opts forward/broker mode into appending + // `offline_access` to the scope sent upstream. Used mainly so the IdP's + // consent screen offers long-lived sessions; the upstream refresh token + // MCP receives is currently discarded. v1 issues NO downstream refresh + // tokens to CIMD clients — they re-authorize via /oauth/authorize when + // the access token expires. See #115 § Refresh-token policy. + // Default false. Effect is upstream-only; this flag does not turn on + // any downstream refresh-token issuance. + UpstreamOfflineAccess bool `json:"upstream_offline_access" yaml:"upstream_offline_access" flag:"oauth-upstream-offline-access" env:"MCP_OAUTH_UPSTREAM_OFFLINE_ACCESS" desc:"Append offline_access to the upstream scope so the IdP's consent screen offers long-lived sessions. v1 does NOT issue downstream refresh tokens regardless of this flag — clients re-authorize via /oauth/authorize."` + + // UpstreamForceConsent forces `prompt=consent` on every upstream + // /authorize call (Google-family providers only). The first authorize + // for a user with `upstream_offline_access: true` always triggers the + // consent screen anyway — Google mints the refresh_token there and + // remembers it. Subsequent silent-SSO redemptions reuse the existing + // grant without re-prompting. Set this to true only when the operator + // needs to force re-enrollment (e.g. after rotating the upstream OAuth + // client). Default false avoids the surprise re-consent on every login. + UpstreamForceConsent bool `json:"upstream_force_consent" yaml:"upstream_force_consent" flag:"oauth-upstream-force-consent" env:"MCP_OAUTH_UPSTREAM_FORCE_CONSENT" desc:"Force prompt=consent on every upstream /authorize (Google providers only). Default false reuses Google's stored offline-access grant after the first consent."` + + // BrokerUpstream opts gating mode into the DCR-via-MCP broker pattern that + // forward mode uses by default. When true under gating mode, altinity-mcp: + // - Acts as the OAuth AS to claude.ai/ChatGPT (hosts /oauth/{register, + // authorize,callback,token}, mints stateless DCR client_ids). + // - Brokers an upstream IdP using a static OAuth application + // (ClientID/ClientSecret/AuthURL/TokenURL config). + // - Returns the upstream id_token unchanged as the access_token to the + // MCP-client; on /mcp the gating-mode JWKS-validation path validates + // it against the upstream issuer's JWKS and impersonates the user to + // ClickHouse via cluster_secret + Auth.Username. + // This is the same shape as forward mode minus the JWT-passthrough-to-CH: + // CH is reached via interserver auth + email impersonation as in standard + // gating mode. Use when the upstream IdP does not support CIMD natively + // (e.g. Google directly) but you don't want to expose CH to per-query JWT + // validation. Default false: gating remains pure resource server (#109). + BrokerUpstream bool `json:"broker_upstream" yaml:"broker_upstream" flag:"oauth-broker-upstream" env:"MCP_OAUTH_BROKER_UPSTREAM" desc:"Gating mode: enable DCR-via-MCP broker pattern (act as AS to clients, broker upstream IdP). Requires client_id/client_secret/auth_url/token_url/issuer to be set."` + + // RequiredScopes is the list of scopes required for access (token must have all of these) + RequiredScopes []string `json:"required_scopes" yaml:"required_scopes" flag:"oauth-required-scopes" env:"MCP_OAUTH_REQUIRED_SCOPES" desc:"Required OAuth scopes for access"` + + // ClickHouseHeaderName is the header name to use when forwarding OAuth token to ClickHouse + // Default: "Authorization" (sends as "Bearer {token}") + // When set to a custom header, the raw token is sent without "Bearer " prefix + ClickHouseHeaderName string `json:"clickhouse_header_name" yaml:"clickhouse_header_name" flag:"oauth-clickhouse-header-name" env:"MCP_OAUTH_CLICKHOUSE_HEADER_NAME" desc:"Header name for forwarding OAuth token to ClickHouse"` + + // ClaimsToHeaders maps OAuth token claims to ClickHouse HTTP headers + // Example: {"sub": "X-ClickHouse-User", "email": "X-ClickHouse-Email"} + ClaimsToHeaders map[string]string `json:"claims_to_headers" yaml:"claims_to_headers" flag:"oauth-claims-to-headers" env:"MCP_OAUTH_CLAIMS_TO_HEADERS" desc:"Map OAuth claims to ClickHouse HTTP headers"` + + // AllowedEmailDomains constrains accepted principals by email domain. + AllowedEmailDomains []string `json:"allowed_email_domains" yaml:"allowed_email_domains" flag:"oauth-allowed-email-domains" env:"MCP_OAUTH_ALLOWED_EMAIL_DOMAINS" desc:"Allowed email domains for verified OAuth identities"` + + // AllowedHostedDomains constrains accepted principals by hosted/workspace domain claim such as Google hd. + AllowedHostedDomains []string `json:"allowed_hosted_domains" yaml:"allowed_hosted_domains" flag:"oauth-allowed-hosted-domains" env:"MCP_OAUTH_ALLOWED_HOSTED_DOMAINS" desc:"Allowed hosted/workspace domains for verified OAuth identities"` + + // AllowUnverifiedEmail opts out of the email_verified=true requirement. + // Default zero value (false) rejects tokens carrying email with email_verified=false. + // Set true only when the IdP omits email_verified or the operator trusts upstream verification. + AllowUnverifiedEmail bool `json:"allow_unverified_email" yaml:"allow_unverified_email" flag:"oauth-allow-unverified-email" env:"MCP_OAUTH_ALLOW_UNVERIFIED_EMAIL" desc:"Accept OAuth identities with email_verified=false (default: reject)"` + + // AuthorizationPath configures the relative path for the authorization endpoint. + AuthorizationPath string `json:"authorization_path" yaml:"authorization_path" flag:"oauth-authorization-path" env:"MCP_OAUTH_AUTHORIZATION_PATH" desc:"Relative path for OAuth authorization endpoint"` + + // CallbackPath configures the relative path for the upstream IdP callback handler. + CallbackPath string `json:"callback_path" yaml:"callback_path" flag:"oauth-callback-path" env:"MCP_OAUTH_CALLBACK_PATH" desc:"Relative path for OAuth upstream callback endpoint"` + + // TokenPath configures the relative path for the token endpoint. + TokenPath string `json:"token_path" yaml:"token_path" flag:"oauth-token-path" env:"MCP_OAUTH_TOKEN_PATH" desc:"Relative path for OAuth token endpoint"` + + // AccessTokenTTLSeconds controls how long minted access tokens remain valid. + AccessTokenTTLSeconds int `json:"access_token_ttl_seconds" yaml:"access_token_ttl_seconds" flag:"oauth-access-token-ttl-seconds" env:"MCP_OAUTH_ACCESS_TOKEN_TTL_SECONDS" desc:"Access token lifetime in seconds"` + + // RefreshTokenTTLSeconds controls how long minted refresh tokens remain valid. + RefreshTokenTTLSeconds int `json:"refresh_token_ttl_seconds" yaml:"refresh_token_ttl_seconds" flag:"oauth-refresh-token-ttl-seconds" env:"MCP_OAUTH_REFRESH_TOKEN_TTL_SECONDS" desc:"Refresh token lifetime in seconds"` + + // SigningSecret is the server-side symmetric secret used to HKDF-derive + // keys for every stateless OAuth JWE this server mints: pending-auth + // state (the upstream `state` parameter) and the downstream auth-code + // returned from /oauth/callback. Required whenever OAuth broker mode is + // active (forward, or gating + broker_upstream). Per #115 v1 issues no + // downstream refresh tokens and no DCR client_secrets. + SigningSecret string `json:"signing_secret" yaml:"signing_secret" flag:"oauth-signing-secret" env:"MCP_OAUTH_SIGNING_SECRET" desc:"Server-side HKDF master secret for OAuth JWE artifacts (pending-auth state, downstream auth codes). Required whenever broker mode is active."` +} + +// NormalizedMode returns the OAuth mode lowercased and trimmed; empty input +// defaults to "gating". +func (cfg OAuthConfig) NormalizedMode() string { + mode := strings.ToLower(strings.TrimSpace(cfg.Mode)) + switch mode { + case "forward": + return "forward" + case "gating": + return "gating" + case "": + return "gating" + default: + return mode + } +} + +// IsForwardMode reports whether the configured mode is "forward". +func (cfg OAuthConfig) IsForwardMode() bool { + return cfg.NormalizedMode() == "forward" +} + +// IsGatingMode reports whether the configured mode is "gating" (or unset). +func (cfg OAuthConfig) IsGatingMode() bool { + return cfg.NormalizedMode() == "gating" +} diff --git a/pkg/oauth/context.go b/pkg/oauth/context.go new file mode 100644 index 0000000..bee4fa6 --- /dev/null +++ b/pkg/oauth/context.go @@ -0,0 +1,44 @@ +package oauth + +import "context" + +// contextKey avoids collisions with other packages using context.WithValue. +type contextKey string + +// Context keys for OAuth bearer + validated claims. JWE keys stay in pkg/server. +const ( + TokenKey contextKey = "oauth_token" + ClaimsKey contextKey = "oauth_claims" +) + +// TokenFromContext returns the raw bearer token previously stored on the +// request context by AuthInjector. Empty if not set. +func TokenFromContext(ctx context.Context) string { + if v := ctx.Value(TokenKey); v != nil { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +// ClaimsFromContext returns the validated claims previously stored on the +// request context by AuthInjector. Nil if not set. +func ClaimsFromContext(ctx context.Context) *Claims { + if v := ctx.Value(ClaimsKey); v != nil { + if c, ok := v.(*Claims); ok { + return c + } + } + return nil +} + +// WithToken returns a copy of ctx carrying tok under TokenKey. +func WithToken(ctx context.Context, tok string) context.Context { + return context.WithValue(ctx, TokenKey, tok) +} + +// WithClaims returns a copy of ctx carrying claims under ClaimsKey. +func WithClaims(ctx context.Context, claims *Claims) context.Context { + return context.WithValue(ctx, ClaimsKey, claims) +} diff --git a/pkg/oauth/errors.go b/pkg/oauth/errors.go new file mode 100644 index 0000000..e68b3d1 --- /dev/null +++ b/pkg/oauth/errors.go @@ -0,0 +1,20 @@ +package oauth + +import "errors" + +var ( + // ErrMissingToken is returned when an OAuth bearer token is missing from the request. + ErrMissingToken = errors.New("missing OAuth token") + // ErrInvalidToken is returned when an OAuth bearer token fails validation. + ErrInvalidToken = errors.New("invalid OAuth token") + // ErrTokenExpired is returned when an OAuth token has expired (with clock-skew tolerance). + ErrTokenExpired = errors.New("OAuth token expired") + // ErrInsufficientScopes is returned when a token doesn't carry the required scopes. + ErrInsufficientScopes = errors.New("insufficient OAuth scopes") + // ErrEmailNotVerified is returned when a token's email claim is not verified + // and the configuration does not allow unverified emails. + ErrEmailNotVerified = errors.New("OAuth email is not verified") + // ErrUnauthorizedDomain is returned when a token's principal domain is not + // in the configured allowed-email-domain / allowed-hosted-domain list. + ErrUnauthorizedDomain = errors.New("OAuth identity domain is not allowed") +) diff --git a/pkg/oauth/forward.go b/pkg/oauth/forward.go new file mode 100644 index 0000000..ea1234c --- /dev/null +++ b/pkg/oauth/forward.go @@ -0,0 +1,86 @@ +package oauth + +import ( + "encoding/json" + "strings" +) + +// BuildClickHouseHeaders builds the HTTP headers that forward-mode requires +// when proxying a request to ClickHouse: the bearer itself (under +// `Authorization` or a custom name) plus any claims-to-headers mapping. The +// caller is responsible for not invoking this in gating mode — that case +// returns nil per the legacy contract. +func BuildClickHouseHeaders(cfg OAuthConfig, token string, claims *Claims) map[string]string { + if !cfg.IsForwardMode() { + return nil + } + + headers := make(map[string]string) + + headerName := cfg.ClickHouseHeaderName + if headerName == "" { + headerName = "Authorization" + } + if headerName == "Authorization" { + headers[headerName] = "Bearer " + token + } else { + headers[headerName] = token + } + + if len(cfg.ClaimsToHeaders) > 0 && claims != nil { + for claimName, hdr := range cfg.ClaimsToHeaders { + var value string + switch claimName { + case "sub": + value = claims.Subject + case "iss": + value = claims.Issuer + case "email": + value = claims.Email + case "name": + value = claims.Name + case "email_verified": + if claims.EmailVerified { + value = "true" + } else { + value = "false" + } + case "hd": + value = claims.HostedDomain + default: + if v, ok := claims.Extra[claimName]; ok { + if strVal, ok := v.(string); ok { + value = strVal + } else if jsonBytes, err := json.Marshal(v); err == nil { + value = string(jsonBytes) + } + } + } + if value != "" { + headers[hdr] = value + } + } + } + + return headers +} + +// EmailFromNamespacedExtra returns the first string-valued claim whose key +// ends with `/email` from the JWT's non-standard claim map. Auth0 third-party +// (DCR) tokens in enhanced security mode silently drop non-namespaced custom +// claims, forcing operators to set email under a URL-prefixed key (e.g. +// `https://mcp.altinity.cloud/email`). Looking up by suffix lets MCP accept +// any namespace the operator chose. +func EmailFromNamespacedExtra(extra map[string]interface{}) string { + for k, v := range extra { + if !strings.HasSuffix(k, "/email") { + continue + } + if s, ok := v.(string); ok { + if t := strings.TrimSpace(s); t != "" { + return t + } + } + } + return "" +} diff --git a/pkg/oauth/identity.go b/pkg/oauth/identity.go new file mode 100644 index 0000000..48fbc7a --- /dev/null +++ b/pkg/oauth/identity.go @@ -0,0 +1,86 @@ +package oauth + +import ( + "strings" + + "github.com/rs/zerolog/log" +) + +// clockSkewSecs bounds the tolerance applied to exp/nbf/iat claims. Static +// rather than configurable; the next refactor (see docs/oauth_next_refactor.md +// § PR-1) lifts this into a per-Verifier option via go-sdk's +// RequireBearerTokenOptions.ClockSkew. +const clockSkewSecs = int64(60) + +// EmailDomain returns the lowercased domain portion of an email address, or +// "" when the input is malformed. Trimmed first so leading/trailing whitespace +// doesn't smuggle past the @ split. +func EmailDomain(email string) string { + parts := strings.Split(strings.ToLower(strings.TrimSpace(email)), "@") + if len(parts) != 2 { + return "" + } + return parts[1] +} + +// ContainsDomain reports whether target matches any domain in domains, case- +// and whitespace-insensitively. Used for the allowed_email_domains and +// allowed_hosted_domains identity policies. +func ContainsDomain(domains []string, target string) bool { + for _, domain := range domains { + if strings.EqualFold(strings.TrimSpace(domain), strings.TrimSpace(target)) { + return true + } + } + return false +} + +// HasRequiredScopes reports whether tokenScopes is a superset of +// requiredScopes. Comparison is exact (case- and whitespace-sensitive) since +// OAuth scope strings are user-defined and case-sensitive per RFC 6749 §3.3. +func HasRequiredScopes(tokenScopes, requiredScopes []string) bool { + scopeSet := make(map[string]bool) + for _, s := range tokenScopes { + scopeSet[s] = true + } + for _, required := range requiredScopes { + if !scopeSet[required] { + return false + } + } + return true +} + +// validateIdentityPolicy applies the configured email_verified, allowed_email_domains +// and allowed_hosted_domains checks. Returns ErrEmailNotVerified or +// ErrUnauthorizedDomain on failure. +func (v *Verifier) validateIdentityPolicy(claims *Claims) error { + cfg := v.cfg + if !cfg.AllowUnverifiedEmail && claims.Email != "" && !claims.EmailVerified { + log.Error().Str("email", claims.Email).Msg("OAuth identity email is not verified") + return ErrEmailNotVerified + } + + if len(cfg.AllowedEmailDomains) > 0 { + domain := EmailDomain(claims.Email) + if domain == "" || !ContainsDomain(cfg.AllowedEmailDomains, domain) { + log.Error().Str("email", claims.Email).Strs("allowed_domains", cfg.AllowedEmailDomains).Msg("OAuth identity email domain is not allowed") + return ErrUnauthorizedDomain + } + } + + if len(cfg.AllowedHostedDomains) > 0 { + if claims.HostedDomain == "" || !ContainsDomain(cfg.AllowedHostedDomains, claims.HostedDomain) { + log.Error().Str("hosted_domain", claims.HostedDomain).Strs("allowed_hosted_domains", cfg.AllowedHostedDomains).Msg("OAuth identity hosted domain is not allowed") + return ErrUnauthorizedDomain + } + } + + return nil +} + +// ValidateIdentityPolicyClaims is the exported wrapper used by the broker to +// re-run the identity policy after exchanging an upstream identity token. +func (v *Verifier) ValidateIdentityPolicyClaims(claims *Claims) error { + return v.validateIdentityPolicy(claims) +} diff --git a/pkg/oauth/jwks.go b/pkg/oauth/jwks.go new file mode 100644 index 0000000..e1f7789 --- /dev/null +++ b/pkg/oauth/jwks.go @@ -0,0 +1,212 @@ +package oauth + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/oauthex" + "github.com/rs/zerolog/log" +) + +const ( + // jwksCacheTTL bounds how long a JWKS or OIDC discovery response stays + // cached before re-fetch. + jwksCacheTTL = 5 * time.Minute + // httpTimeout bounds the broker's outbound discovery + JWKS HTTP calls. + httpTimeout = 10 * time.Second +) + +// Verifier validates OAuth tokens against an issuer's JWKS, caching both the +// JWKS document and the authorization-server metadata (RFC 8414 / OIDC +// discovery) it needs to locate the JWKS URI. Safe for concurrent use. +type Verifier struct { + cfg OAuthConfig + + jwksCache jose.JSONWebKeySet + jwksCacheURL string + jwksCacheTime time.Time + jwksMu sync.RWMutex + + asMetaCache oauthex.AuthServerMeta + asMetaCacheURL string + asMetaTime time.Time + asMetaMu sync.RWMutex +} + +// NewVerifier builds a Verifier for the given OAuth configuration. +func NewVerifier(cfg OAuthConfig) *Verifier { + return &Verifier{cfg: cfg} +} + +// Config returns the OAuthConfig the Verifier was built with. +func (v *Verifier) Config() OAuthConfig { + return v.cfg +} + +// resolveJWKSURL resolves the JWKS URI by configuration override, then by OIDC +// / OAuth 2.0 Authorization Server Metadata discovery from the configured +// issuer. Returns an error if neither path succeeds. +func (v *Verifier) resolveJWKSURL(ctx context.Context) (string, error) { + if explicit := strings.TrimSpace(v.cfg.JWKSURL); explicit != "" { + return explicit, nil + } + issuer := strings.TrimSpace(v.cfg.Issuer) + if issuer == "" { + return "", fmt.Errorf("oauth issuer or jwks_url must be configured") + } + asMeta, err := v.fetchAuthServerMeta(ctx, issuer) + if err != nil { + return "", err + } + jwksURI := strings.TrimSpace(asMeta.JWKSURI) + if jwksURI == "" { + return "", fmt.Errorf("openid discovery did not return jwks_uri") + } + return jwksURI, nil +} + +// fetchAuthServerMeta returns the cached or freshly-discovered authorization +// server metadata for issuer. Uses auth.GetAuthServerMetadata which tries the +// MCP-spec-required well-known endpoints in order (OAuth 2.0 first, then OIDC +// discovery, plus path-aware variants). +func (v *Verifier) fetchAuthServerMeta(ctx context.Context, issuer string) (*oauthex.AuthServerMeta, error) { + issuer = strings.TrimRight(strings.TrimSpace(issuer), "/") + if issuer == "" { + return nil, fmt.Errorf("issuer is required") + } + + v.asMetaMu.RLock() + if v.asMetaCacheURL == issuer && !v.asMetaTime.IsZero() && v.asMetaTime.Add(jwksCacheTTL).After(time.Now()) && v.asMetaCache.Issuer != "" { + cached := v.asMetaCache + v.asMetaMu.RUnlock() + return &cached, nil + } + v.asMetaMu.RUnlock() + + httpClient := &http.Client{Timeout: httpTimeout} + asMeta, err := auth.GetAuthServerMetadata(ctx, issuer, httpClient) + if err != nil { + return nil, fmt.Errorf("failed to discover authorization server metadata for issuer %q: %w", issuer, err) + } + if asMeta == nil { + return nil, fmt.Errorf("no authorization server metadata found for issuer %q", issuer) + } + + v.asMetaMu.Lock() + v.asMetaCache = *asMeta + v.asMetaCacheURL = issuer + v.asMetaTime = time.Now() + v.asMetaMu.Unlock() + return asMeta, nil +} + +// FetchAuthServerMeta exposes the cached/discovered auth-server metadata for +// the given issuer. Used by the broker to resolve upstream /authorize and +// /token endpoints when the operator hasn't pinned them explicitly. +func (v *Verifier) FetchAuthServerMeta(ctx context.Context, issuer string) (*oauthex.AuthServerMeta, error) { + return v.fetchAuthServerMeta(ctx, issuer) +} + +// ResolveUserInfoEndpoint returns the OIDC userinfo_endpoint advertised by +// issuer's /.well-known/openid-configuration document. oauthex.AuthServerMeta +// (RFC 8414) doesn't expose this field — userinfo is OIDC-only — so this is +// the surgical fallback when the operator hasn't pinned UserInfoURL. +// +// Returns "" without error when the document doesn't advertise the field; the +// caller treats that the same as "no userinfo configured". +func (v *Verifier) ResolveUserInfoEndpoint(ctx context.Context, issuer string) (string, error) { + issuer = strings.TrimRight(strings.TrimSpace(issuer), "/") + if issuer == "" { + return "", fmt.Errorf("issuer is required") + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, issuer+"/.well-known/openid-configuration", nil) + if err != nil { + return "", err + } + resp, err := (&http.Client{Timeout: httpTimeout}).Do(req) + if err != nil { + return "", err + } + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + log.Warn().Stack().Err(closeErr).Msgf("can't close openid-configuration response body for %s", issuer) + } + }() + if resp.StatusCode >= 300 { + return "", nil + } + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return "", err + } + var partial struct { + UserInfoEndpoint string `json:"userinfo_endpoint"` + } + if err := json.Unmarshal(body, &partial); err != nil { + return "", err + } + return strings.TrimSpace(partial.UserInfoEndpoint), nil +} + +// fetchJWKSet returns the cached or freshly-fetched JWKS for jwksURI. +func (v *Verifier) fetchJWKSet(ctx context.Context, jwksURI string) (*jose.JSONWebKeySet, error) { + now := time.Now() + + v.jwksMu.RLock() + if len(v.jwksCache.Keys) > 0 && v.jwksCacheURL == jwksURI && v.jwksCacheTime.Add(jwksCacheTTL).After(now) { + cached := v.jwksCache + v.jwksMu.RUnlock() + return &cached, nil + } + v.jwksMu.RUnlock() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURI, nil) + if err != nil { + return nil, fmt.Errorf("failed to build jwks request: %w", err) + } + resp, err := (&http.Client{Timeout: httpTimeout}).Do(req) + if err != nil { + return nil, fmt.Errorf("failed to fetch jwks: %w", err) + } + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + log.Warn().Stack().Err(closeErr).Msgf("can't close %s response body", jwksURI) + } + }() + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("failed to read jwks response: %w", err) + } + if resp.StatusCode >= 300 { + return nil, fmt.Errorf("jwks endpoint returned status %d", resp.StatusCode) + } + + var keySet jose.JSONWebKeySet + if err := json.Unmarshal(body, &keySet); err != nil { + return nil, fmt.Errorf("failed to parse jwks response: %w", err) + } + + v.jwksMu.Lock() + v.jwksCache = keySet + v.jwksCacheURL = jwksURI + v.jwksCacheTime = now + v.jwksMu.Unlock() + return &keySet, nil +} + +// invalidateJWKSCache forces the next fetchJWKSet call to re-fetch. Used when +// the upstream AS rotates its signing key (kid we just saw is absent from the +// cached set). +func (v *Verifier) invalidateJWKSCache() { + v.jwksMu.Lock() + v.jwksCacheTime = time.Time{} + v.jwksMu.Unlock() +} diff --git a/pkg/oauth/jwt.go b/pkg/oauth/jwt.go new file mode 100644 index 0000000..970f334 --- /dev/null +++ b/pkg/oauth/jwt.go @@ -0,0 +1,119 @@ +package oauth + +import ( + "context" + "fmt" + "strings" + + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" + "github.com/rs/zerolog/log" +) + +// looksLikeJWT is a cheap structural check: a JWS in compact form is three +// base64url segments joined by dots. False positives on garbage bearers that +// happen to contain two dots are caught downstream by ParseSigned. +func looksLikeJWT(token string) bool { + return strings.Count(token, ".") == 2 +} + +// audienceMatchesResource compares an incoming audience claim list against an +// expected resource URL with trailing-slash tolerance. RFC 9728's canonical +// form uses a trailing slash, but upstream IdPs (and prior altinity-mcp +// metadata responses) sometimes emit the form without one — match both. +// Falls back to exact match if either side isn't a URL. +func audienceMatchesResource(claims []string, expected string) bool { + expectedTrimmed := strings.TrimRight(strings.TrimSpace(expected), "/") + for _, c := range claims { + if c == expected { + return true + } + if strings.TrimRight(strings.TrimSpace(c), "/") == expectedTrimmed { + return true + } + } + return false +} + +// parseAndVerifyExternalJWT parses a compact-serialised JWT, fetches the JWKS +// for the configured issuer (with a one-shot kid-rotation refresh), and +// returns the validated claims. Issuer enforcement (singular config.Issuer) +// and audience enforcement (expectedAudience) both happen here, slash- +// normalised so a deployment whose issuer config omits the trailing slash +// matches a token whose `iss` includes it. +func (v *Verifier) parseAndVerifyExternalJWT(ctx context.Context, token, expectedAudience string) (*Claims, error) { + jwksURI, err := v.resolveJWKSURL(ctx) + if err != nil { + return nil, err + } + + parsed, err := jwt.ParseSigned(token, []jose.SignatureAlgorithm{ + jose.RS256, jose.RS384, jose.RS512, + jose.ES256, jose.ES384, jose.ES512, + jose.PS256, jose.PS384, jose.PS512, + jose.EdDSA, + }) + if err != nil { + return nil, fmt.Errorf("failed to parse signed JWT: %w", err) + } + if len(parsed.Headers) == 0 { + return nil, fmt.Errorf("missing JWT header") + } + + keySet, err := v.fetchJWKSet(ctx, jwksURI) + if err != nil { + return nil, err + } + + keys := keySet.Keys + keyID := parsed.Headers[0].KeyID + if keyID != "" { + keys = keySet.Key(keyID) + if len(keys) == 0 { + // kid absent from the cached JWKS — the AS may have rotated its + // signing key since the last fetch. Invalidate the cache and + // retry once before giving up. + v.invalidateJWKSCache() + keySet, err = v.fetchJWKSet(ctx, jwksURI) + if err != nil { + return nil, err + } + keys = keySet.Key(keyID) + if len(keys) == 0 { + return nil, fmt.Errorf("no JWK found for kid %q", keyID) + } + log.Info().Str("kid", keyID).Msg("oauth: JWKS re-fetched after key rotation; new kid found") + } + } + + expectedIssuer := strings.TrimRight(strings.TrimSpace(v.cfg.Issuer), "/") + var ( + rawClaims map[string]interface{} + signatureVerified bool + issuerRejected bool + audienceRejected bool + ) + for _, key := range keys { + rawClaims = make(map[string]interface{}) + if err := parsed.Claims(key.Key, &rawClaims); err != nil { + continue + } + signatureVerified = true + claims := claimsFromRawClaims(rawClaims) + gotIssuer := strings.TrimRight(strings.TrimSpace(claims.Issuer), "/") + if expectedIssuer != "" && gotIssuer != expectedIssuer { + issuerRejected = true + continue + } + if expectedAudience != "" && !audienceMatchesResource(claims.Audience, expectedAudience) { + audienceRejected = true + continue + } + return claims, nil + } + if signatureVerified && (issuerRejected || audienceRejected) { + return nil, ErrInvalidToken + } + + return nil, fmt.Errorf("failed to verify JWT signature with discovered JWKs") +} diff --git a/pkg/oauth/validator.go b/pkg/oauth/validator.go new file mode 100644 index 0000000..c7531c7 --- /dev/null +++ b/pkg/oauth/validator.go @@ -0,0 +1,142 @@ +package oauth + +import ( + "context" + "net/http" + "strings" + "time" + + "github.com/rs/zerolog/log" +) + +// ExtractTokenFromRequest extracts an OAuth bearer token from an HTTP request, +// per MCP authorization spec §Token Requirements: +// +// "MCP client MUST use the Authorization request header field defined in +// OAuth 2.1 §5.1.1: Authorization: Bearer " +// "Access tokens MUST NOT be included in the URI query string" +// +// Only the Authorization header is accepted. +func ExtractTokenFromRequest(r *http.Request) string { + authHeader := r.Header.Get("Authorization") + if strings.HasPrefix(authHeader, "Bearer ") { + return strings.TrimPrefix(authHeader, "Bearer ") + } + return "" +} + +// RequiresLocalValidation reports whether the auth layer should call +// ValidateToken on inbound bearers. We always do, in both gating and forward +// modes — ValidateToken itself decides what kind of validation applies for +// the configured mode and token shape. +func (v *Verifier) RequiresLocalValidation() bool { + return v.cfg.Enabled +} + +// ValidateToken validates an OAuth bearer and returns claims. +// +// Both modes route through the JWKS-based external-JWT validator: under +// gating, MCP is a pure resource server and the bearer is an upstream IdP +// (Auth0) access token; under forward, MCP proxies the upstream IdP token to +// the client unchanged. In both cases local validation is signature + iss + +// aud + exp against the configured JWKS. +// +// Two cases soft-pass (return nil claims, nil error) — the auth layer accepts +// the request and forwards to ClickHouse, which is then the sole validator: +// +// 1. Opaque (non-JWT) bearers — RFC 7662 introspection is not implemented; +// local validation isn't possible. +// 2. JWT bearers with neither Issuer nor JWKSURL configured — operator +// hasn't told us where to fetch verification keys. +// +// Soft-pass preserves compatibility with deployments that pre-date C-1 and +// rely entirely on ClickHouse-side validation. See docs/oauth_next_refactor.md +// for the plan to remove soft-pass once token introspection lands. +func (v *Verifier) ValidateToken(ctx context.Context, token string) (*Claims, error) { + if !v.cfg.Enabled { + return nil, nil + } + + if token == "" { + return nil, ErrMissingToken + } + + mode := v.cfg.NormalizedMode() + if !looksLikeJWT(token) { + if v.cfg.IsGatingMode() { + log.Error().Str("mode", mode).Msg("OAuth token is not a JWT; gating mode requires a signed JWT from the upstream AS") + return nil, ErrInvalidToken + } + log.Debug().Str("mode", mode).Msg("Bearer is opaque (not a JWT); skipping local validation, deferring to ClickHouse") + return nil, nil + } + if strings.TrimSpace(v.cfg.JWKSURL) == "" && strings.TrimSpace(v.cfg.Issuer) == "" { + log.Debug().Str("mode", mode).Msg("JWT received but neither oauth_issuer nor jwks_url is configured; skipping local validation") + return nil, nil + } + claims, err := v.parseAndVerifyExternalJWT(ctx, token, v.cfg.Audience) + if err != nil { + log.Error().Err(err).Str("mode", mode).Msg("Failed to validate OAuth token") + return nil, err + } + + return v.validateClaims(claims) +} + +// validateClaims applies post-signature-verification checks: audience (slash- +// normalised), exp/nbf/iat (with clockSkewSecs tolerance), required scopes, +// and identity policy (email_verified, allowed domains). +func (v *Verifier) validateClaims(claims *Claims) (*Claims, error) { + // Issuer enforcement happens in parseAndVerifyExternalJWT, the only path + // that reaches here. Re-validating here would duplicate the check. + + if v.cfg.Audience != "" { + if len(claims.Audience) == 0 { + log.Error().Str("expected", v.cfg.Audience).Msg("OAuth token missing audience claim") + return nil, ErrInvalidToken + } + if !audienceMatchesResource(claims.Audience, v.cfg.Audience) { + log.Error().Str("expected", v.cfg.Audience).Strs("got", claims.Audience).Msg("OAuth token audience mismatch") + return nil, ErrInvalidToken + } + } + + now := time.Now().Unix() + if claims.ExpiresAt > 0 && now > claims.ExpiresAt+clockSkewSecs { + log.Error().Int64("exp", claims.ExpiresAt).Msg("OAuth token expired") + return nil, ErrTokenExpired + } + if claims.NotBefore > 0 && now+clockSkewSecs < claims.NotBefore { + log.Error().Int64("nbf", claims.NotBefore).Msg("OAuth token not yet valid") + return nil, ErrInvalidToken + } + if claims.IssuedAt > 0 && claims.IssuedAt > now+clockSkewSecs { + log.Error().Int64("iat", claims.IssuedAt).Msg("OAuth token issued in the future") + return nil, ErrInvalidToken + } + + if len(v.cfg.RequiredScopes) > 0 { + if !HasRequiredScopes(claims.Scopes, v.cfg.RequiredScopes) { + log.Error().Strs("required", v.cfg.RequiredScopes).Strs("got", claims.Scopes).Msg("OAuth token missing required scopes") + return nil, ErrInsufficientScopes + } + } + + if err := v.validateIdentityPolicy(claims); err != nil { + return nil, err + } + + return claims, nil +} + +// ValidateUpstreamIdentityToken parses an upstream identity token using the +// JWKS path (no soft-pass) and applies the identity policy. Used by the +// broker's /oauth/callback after exchanging the upstream authorization code +// for an id_token. +func (v *Verifier) ValidateUpstreamIdentityToken(ctx context.Context, token, expectedAudience string) (*Claims, error) { + claims, err := v.parseAndVerifyExternalJWT(ctx, token, expectedAudience) + if err != nil { + return nil, err + } + return claims, v.validateIdentityPolicy(claims) +} diff --git a/pkg/oauth/verifier_test.go b/pkg/oauth/verifier_test.go new file mode 100644 index 0000000..ce47d7a --- /dev/null +++ b/pkg/oauth/verifier_test.go @@ -0,0 +1,498 @@ +package oauth + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/stretchr/testify/require" +) + +// encodeOIDC returns a JSON-encodable discovery document with the PKCE +// methods go-sdk's auth.GetAuthServerMetadata requires. +func encodeOIDC(issuer, jwksURI string) map[string]interface{} { + return map[string]interface{}{ + "issuer": issuer, + "jwks_uri": jwksURI, + "response_types_supported": []string{"code"}, + "code_challenge_methods_supported": []string{"S256"}, + "authorization_endpoint": issuer + "/authorize", + "token_endpoint": issuer + "/token", + "id_token_signing_alg_values_supported": []string{"RS256"}, + } +} + +func TestResolveJWKSURL(t *testing.T) { + t.Parallel() + t.Run("direct_jwks_url_configured", func(t *testing.T) { + t.Parallel() + v := NewVerifier(OAuthConfig{ + JWKSURL: "https://auth.example.com/jwks", + }) + url, err := v.resolveJWKSURL(context.Background()) + require.NoError(t, err) + require.Equal(t, "https://auth.example.com/jwks", url) + }) + + t.Run("openid_configuration_discovery", func(t *testing.T) { + t.Parallel() + var mockURL string + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(encodeOIDC(mockURL, mockURL+"/keys")) + }) + mockServer := httptest.NewServer(mux) + defer mockServer.Close() + mockURL = mockServer.URL + + v := NewVerifier(OAuthConfig{Issuer: mockURL}) + url, err := v.resolveJWKSURL(context.Background()) + require.NoError(t, err) + require.Equal(t, mockURL+"/keys", url) + }) + + t.Run("oauth_authorization_server_discovery", func(t *testing.T) { + t.Parallel() + // go-sdk tries oauth-authorization-server first, falling through to + // openid-configuration on 404. Same shape works for both — pick one. + var mockURL string + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(encodeOIDC(mockURL, mockURL+"/fallback-keys")) + }) + mockServer := httptest.NewServer(mux) + defer mockServer.Close() + mockURL = mockServer.URL + + v := NewVerifier(OAuthConfig{Issuer: mockURL}) + url, err := v.resolveJWKSURL(context.Background()) + require.NoError(t, err) + require.Equal(t, mockURL+"/fallback-keys", url) + }) + + t.Run("both_discovery_endpoints_fail", func(t *testing.T) { + t.Parallel() + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer mockServer.Close() + + v := NewVerifier(OAuthConfig{Issuer: mockServer.URL}) + _, err := v.resolveJWKSURL(context.Background()) + require.Error(t, err) + }) + + t.Run("missing_jwks_uri_in_metadata", func(t *testing.T) { + t.Parallel() + // auth.GetAuthServerMetadata returns an error if the issuer in the + // metadata document does not match the issuer URL or if PKCE is + // missing; emit a document that satisfies discovery but lacks + // jwks_uri to exercise the Verifier's "missing jwks_uri" branch. + var mockURL string + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": mockURL, + "response_types_supported": []string{"code"}, + "code_challenge_methods_supported": []string{"S256"}, + "authorization_endpoint": mockURL + "/authorize", + "token_endpoint": mockURL + "/token", + "id_token_signing_alg_values_supported": []string{"RS256"}, + }) + })) + defer mockServer.Close() + mockURL = mockServer.URL + + v := NewVerifier(OAuthConfig{Issuer: mockURL}) + _, err := v.resolveJWKSURL(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "jwks_uri") + }) +} + +func TestAuthServerMetaCaching(t *testing.T) { + t.Parallel() + var requestCount int + var mockURL string + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + requestCount++ + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(encodeOIDC(mockURL, mockURL+"/keys")) + }) + // Respond to the other well-known so go-sdk's first-try doesn't error. + mux.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + }) + mockServer := httptest.NewServer(mux) + defer mockServer.Close() + mockURL = mockServer.URL + + v := NewVerifier(OAuthConfig{Issuer: mockURL}) + + t.Run("cache_hit_within_ttl", func(t *testing.T) { + requestCount = 0 + _, err := v.FetchAuthServerMeta(context.Background(), mockURL) + require.NoError(t, err) + _, err = v.FetchAuthServerMeta(context.Background(), mockURL) + require.NoError(t, err) + require.Equal(t, 1, requestCount, "second call should hit cache") + }) + + t.Run("cache_miss_after_ttl_expires", func(t *testing.T) { + _, err := v.FetchAuthServerMeta(context.Background(), mockURL) + require.NoError(t, err) + + v.asMetaMu.Lock() + v.asMetaTime = time.Now().Add(-jwksCacheTTL - time.Second) + v.asMetaMu.Unlock() + + countBefore := requestCount + _, err = v.FetchAuthServerMeta(context.Background(), mockURL) + require.NoError(t, err) + require.Equal(t, countBefore+1, requestCount, "should re-fetch after TTL expiry") + }) +} + +func TestParseAndVerifyExternalJWTUnknownKid(t *testing.T) { + t.Parallel() + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + knownJWK := jose.JSONWebKey{Key: &privateKey.PublicKey, KeyID: "known", Algorithm: "RS256", Use: "sig"} + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/jwks" { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(jose.JSONWebKeySet{Keys: []jose.JSONWebKey{knownJWK}}) + return + } + http.NotFound(w, r) + })) + defer mockServer.Close() + + v := NewVerifier(OAuthConfig{ + Issuer: mockServer.URL, + JWKSURL: mockServer.URL + "/jwks", + }) + + signer, err := jose.NewSigner( + jose.SigningKey{Algorithm: jose.RS256, Key: privateKey}, + (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", "unknown"), + ) + require.NoError(t, err) + + payload, err := json.Marshal(map[string]interface{}{ + "sub": "user-1", + "iss": mockServer.URL, + "aud": "test-audience", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }) + require.NoError(t, err) + + object, err := signer.Sign(payload) + require.NoError(t, err) + token, err := object.CompactSerialize() + require.NoError(t, err) + + _, err = v.parseAndVerifyExternalJWT(context.Background(), token, "test-audience") + require.Error(t, err) + require.Contains(t, err.Error(), "no JWK found for kid") +} + +// TestJWKSRefetchOnKidMiss verifies that a kid absent from the cached JWKS +// triggers a one-shot cache-bypass re-fetch, allowing tokens issued after a +// key rotation to be accepted without waiting for the TTL to expire. +func TestJWKSRefetchOnKidMiss(t *testing.T) { + t.Parallel() + + oldKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + newKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + const oldKid = "old-signing-key" + const newKid = "new-signing-key" + + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/jwks": + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(jose.JSONWebKeySet{Keys: []jose.JSONWebKey{ + {Key: &newKey.PublicKey, KeyID: newKid, Algorithm: "RS256", Use: "sig"}, + }}) + default: + http.NotFound(w, r) + } + })) + defer mockServer.Close() + + v := NewVerifier(OAuthConfig{ + Issuer: mockServer.URL, + JWKSURL: mockServer.URL + "/jwks", + }) + + // Seed the JWKS cache with the old key and a far-future TTL so that a + // normal fetch would not re-fetch. + v.jwksMu.Lock() + v.jwksCache = jose.JSONWebKeySet{Keys: []jose.JSONWebKey{ + {Key: &oldKey.PublicKey, KeyID: oldKid, Algorithm: "RS256", Use: "sig"}, + }} + v.jwksCacheURL = mockServer.URL + "/jwks" + v.jwksCacheTime = time.Now().Add(10 * time.Minute) + v.jwksMu.Unlock() + + signer, err := jose.NewSigner( + jose.SigningKey{Algorithm: jose.RS256, Key: newKey}, + (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", newKid), + ) + require.NoError(t, err) + payload, err := json.Marshal(map[string]interface{}{ + "sub": "user-1", + "iss": mockServer.URL, + "aud": "test-audience", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }) + require.NoError(t, err) + obj, err := signer.Sign(payload) + require.NoError(t, err) + token, err := obj.CompactSerialize() + require.NoError(t, err) + + claims, err := v.parseAndVerifyExternalJWT(context.Background(), token, "test-audience") + require.NoError(t, err) + require.Equal(t, "user-1", claims.Subject) +} + +func TestEmailDomain(t *testing.T) { + t.Parallel() + tests := []struct { + name string + email string + want string + }{ + {"normal", "user@example.com", "example.com"}, + {"uppercase", "User@EXAMPLE.COM", "example.com"}, + {"whitespace", " user@example.com ", "example.com"}, + {"no_at", "noatsign", ""}, + {"empty", "", ""}, + {"multiple_at", "a@b@c", ""}, + {"just_at", "@", ""}, + {"domain_only", "@domain.com", "domain.com"}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, EmailDomain(tt.email)) + }) + } +} + +func TestClaimsFromRawClaims(t *testing.T) { + t.Parallel() + + t.Run("all_standard_fields", func(t *testing.T) { + t.Parallel() + raw := map[string]interface{}{ + "sub": "user123", + "iss": "https://auth.example.com", + "exp": float64(1700000000), + "iat": float64(1699999000), + "nbf": float64(1699998000), + "email": "user@example.com", + "name": "Test User", + "hd": "example.com", + "email_verified": true, + "aud": "my-api", + "scope": "read write", + } + claims := claimsFromRawClaims(raw) + require.Equal(t, "user123", claims.Subject) + require.Equal(t, "https://auth.example.com", claims.Issuer) + require.Equal(t, int64(1700000000), claims.ExpiresAt) + require.Equal(t, int64(1699999000), claims.IssuedAt) + require.Equal(t, int64(1699998000), claims.NotBefore) + require.Equal(t, "user@example.com", claims.Email) + require.Equal(t, "Test User", claims.Name) + require.Equal(t, "example.com", claims.HostedDomain) + require.True(t, claims.EmailVerified) + require.Equal(t, []string{"my-api"}, claims.Audience) + require.Equal(t, []string{"read", "write"}, claims.Scopes) + }) + + t.Run("json_number_fields", func(t *testing.T) { + t.Parallel() + raw := map[string]interface{}{ + "sub": "user", + "exp": json.Number("1700000000"), + "iat": json.Number("1699999000"), + "nbf": json.Number("1699998000"), + } + claims := claimsFromRawClaims(raw) + require.Equal(t, int64(1700000000), claims.ExpiresAt) + require.Equal(t, int64(1699999000), claims.IssuedAt) + require.Equal(t, int64(1699998000), claims.NotBefore) + }) + + t.Run("audience_array", func(t *testing.T) { + t.Parallel() + raw := map[string]interface{}{ + "aud": []interface{}{"api1", "api2"}, + } + claims := claimsFromRawClaims(raw) + require.Equal(t, []string{"api1", "api2"}, claims.Audience) + }) + + t.Run("scope_array", func(t *testing.T) { + t.Parallel() + raw := map[string]interface{}{ + "scope": []interface{}{"read", "write", "admin"}, + } + claims := claimsFromRawClaims(raw) + require.Equal(t, []string{"read", "write", "admin"}, claims.Scopes) + }) + + t.Run("email_verified_string", func(t *testing.T) { + t.Parallel() + raw := map[string]interface{}{ + "email_verified": "true", + } + claims := claimsFromRawClaims(raw) + require.True(t, claims.EmailVerified) + + raw2 := map[string]interface{}{ + "email_verified": "false", + } + claims2 := claimsFromRawClaims(raw2) + require.False(t, claims2.EmailVerified) + }) + + t.Run("extra_claims_preserved", func(t *testing.T) { + t.Parallel() + raw := map[string]interface{}{ + "sub": "user", + "custom1": "value1", + "custom_num": float64(42), + } + claims := claimsFromRawClaims(raw) + require.Equal(t, "value1", claims.Extra["custom1"]) + require.Equal(t, float64(42), claims.Extra["custom_num"]) + _, hasSub := claims.Extra["sub"] + require.False(t, hasSub) + }) + + t.Run("empty_claims", func(t *testing.T) { + t.Parallel() + claims := claimsFromRawClaims(map[string]interface{}{}) + require.NotNil(t, claims) + require.Empty(t, claims.Subject) + require.NotNil(t, claims.Extra) + }) +} + +func TestLooksLikeJWT(t *testing.T) { + t.Parallel() + require.True(t, looksLikeJWT("a.b.c")) + require.False(t, looksLikeJWT("not-a-jwt")) + require.False(t, looksLikeJWT("a.b")) + require.False(t, looksLikeJWT("a.b.c.d")) +} + +func TestHasRequiredScopes(t *testing.T) { + t.Parallel() + require.True(t, HasRequiredScopes([]string{"read", "write", "admin"}, []string{"read", "write"})) + require.False(t, HasRequiredScopes([]string{"read"}, []string{"read", "admin"})) + require.True(t, HasRequiredScopes([]string{"read"}, []string{})) + require.True(t, HasRequiredScopes([]string{}, []string{})) + require.False(t, HasRequiredScopes([]string{}, []string{"read"})) +} + +func TestValidateClaims(t *testing.T) { + t.Parallel() + + t.Run("audience_missing_when_required", func(t *testing.T) { + t.Parallel() + v := NewVerifier(OAuthConfig{Audience: "my-audience"}) + _, err := v.validateClaims(&Claims{}) + require.ErrorIs(t, err, ErrInvalidToken) + }) + + t.Run("audience_mismatch", func(t *testing.T) { + t.Parallel() + v := NewVerifier(OAuthConfig{Audience: "my-audience"}) + _, err := v.validateClaims(&Claims{Audience: []string{"wrong-audience"}}) + require.ErrorIs(t, err, ErrInvalidToken) + }) + + t.Run("audience_trailing_slash_tolerant", func(t *testing.T) { + t.Parallel() + v := NewVerifier(OAuthConfig{Audience: "https://mcp.example.com"}) + _, err := v.validateClaims(&Claims{ + Audience: []string{"https://mcp.example.com/"}, + ExpiresAt: time.Now().Unix() + 300, + }) + require.NoError(t, err) + + v = NewVerifier(OAuthConfig{Audience: "https://mcp.example.com/"}) + _, err = v.validateClaims(&Claims{ + Audience: []string{"https://mcp.example.com"}, + ExpiresAt: time.Now().Unix() + 300, + }) + require.NoError(t, err) + }) + + t.Run("token_expired", func(t *testing.T) { + t.Parallel() + v := NewVerifier(OAuthConfig{}) + _, err := v.validateClaims(&Claims{ExpiresAt: time.Now().Unix() - 300}) + require.ErrorIs(t, err, ErrTokenExpired) + }) + + t.Run("not_yet_valid", func(t *testing.T) { + t.Parallel() + v := NewVerifier(OAuthConfig{}) + _, err := v.validateClaims(&Claims{NotBefore: time.Now().Unix() + 300}) + require.ErrorIs(t, err, ErrInvalidToken) + }) + + t.Run("issued_in_future", func(t *testing.T) { + t.Parallel() + v := NewVerifier(OAuthConfig{}) + _, err := v.validateClaims(&Claims{IssuedAt: time.Now().Unix() + 300}) + require.ErrorIs(t, err, ErrInvalidToken) + }) + + t.Run("missing_required_scopes", func(t *testing.T) { + t.Parallel() + v := NewVerifier(OAuthConfig{RequiredScopes: []string{"admin"}}) + _, err := v.validateClaims(&Claims{Scopes: []string{"read"}}) + require.ErrorIs(t, err, ErrInsufficientScopes) + }) + + t.Run("valid_claims", func(t *testing.T) { + t.Parallel() + v := NewVerifier(OAuthConfig{ + Issuer: "https://issuer.example.com", + Audience: "my-aud", + RequiredScopes: []string{"read"}, + }) + claims, err := v.validateClaims(&Claims{ + Issuer: "https://issuer.example.com", + Audience: []string{"my-aud"}, + ExpiresAt: time.Now().Unix() + 300, + Scopes: []string{"read", "write"}, + }) + require.NoError(t, err) + require.Equal(t, "https://issuer.example.com", claims.Issuer) + }) +} diff --git a/pkg/server/server.go b/pkg/server/server.go index d33351b..9743dc0 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -6,11 +6,10 @@ import ( "fmt" "strings" "sync" - "time" "github.com/altinity/altinity-mcp/pkg/clickhouse" "github.com/altinity/altinity-mcp/pkg/config" - "github.com/go-jose/go-jose/v4" + "github.com/altinity/altinity-mcp/pkg/oauth" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/rs/zerolog/log" ) @@ -24,16 +23,12 @@ type ClickHouseJWEServer struct { dynamicTools map[string]dynamicToolMeta dynamicToolsMu sync.RWMutex dynamicToolsInit bool - // JWKS cache for OAuth token validation - jwksCache jose.JSONWebKeySet - jwksCacheURL string - jwksCacheMu sync.RWMutex - jwksCacheTime time.Time - oidcConfigCache OpenIDConfiguration - oidcConfigCacheURL string - oidcConfigMu sync.RWMutex - oidcConfigTime time.Time - blockedClauses map[string]bool + // oauthVerifier owns the JWKS + OIDC discovery cache it needs to validate + // inbound OAuth bearers. Constructed in NewClickHouseMCPServer; tests that + // build ClickHouseJWEServer via struct literal get a lazily-built verifier + // from the verifier() getter. + oauthVerifier *oauth.Verifier + blockedClauses map[string]bool } // ToolHandlerFunc is a function type for tool handlers @@ -73,6 +68,7 @@ func NewClickHouseMCPServer(cfg config.Config, version string) *ClickHouseJWESer Config: cfg, Version: version, dynamicTools: make(map[string]dynamicToolMeta), + oauthVerifier: oauth.NewVerifier(cfg.Server.OAuth), blockedClauses: NormalizeBlockedClauses(cfg.Server.BlockedQueryClauses), } @@ -545,11 +541,10 @@ func GetClickHouseJWEServerFromContext(ctx context.Context) *ClickHouseJWEServer // contextKey avoids collisions with other packages using context.WithValue. type contextKey string -// Auth context keys +// Auth context keys for JWE + the embedded MCP server. OAuth token / claims +// keys live in pkg/oauth (re-exported as vars in server_auth_oauth.go). const ( JWETokenKey contextKey = "jwe_token" JWEClaimsKey contextKey = "jwe_claims" - OAuthTokenKey contextKey = "oauth_token" - OAuthClaimsKey contextKey = "oauth_claims" CHJWEServerKey contextKey = "clickhouse_jwe_server" ) diff --git a/pkg/server/server_auth_oauth.go b/pkg/server/server_auth_oauth.go index a62ed1f..2fd94ef 100644 --- a/pkg/server/server_auth_oauth.go +++ b/pkg/server/server_auth_oauth.go @@ -2,620 +2,115 @@ package server import ( "context" - "encoding/json" - "errors" - "fmt" - "io" "net/http" - "strings" - "time" - "github.com/go-jose/go-jose/v4" - "github.com/go-jose/go-jose/v4/jwt" - "github.com/rs/zerolog/log" + "github.com/altinity/altinity-mcp/pkg/oauth" ) -var ( - // ErrMissingOAuthToken is returned when OAuth token is missing - ErrMissingOAuthToken = errors.New("missing OAuth token") - // ErrInvalidOAuthToken is returned when OAuth token is invalid - ErrInvalidOAuthToken = errors.New("invalid OAuth token") - // ErrOAuthTokenExpired is returned when OAuth token has expired - ErrOAuthTokenExpired = errors.New("OAuth token expired") - // ErrOAuthInsufficientScopes is returned when token doesn't have required scopes - ErrOAuthInsufficientScopes = errors.New("insufficient OAuth scopes") - // ErrOAuthEmailNotVerified is returned when token email is not verified - ErrOAuthEmailNotVerified = errors.New("OAuth email is not verified") - // ErrOAuthUnauthorizedDomain is returned when token principal domain is not allowed - ErrOAuthUnauthorizedDomain = errors.New("OAuth identity domain is not allowed") -) +// Re-exports of pkg/oauth identifiers so existing pkg/server-aware callers and +// tests continue to compile after the extraction. Prefer the oauth package +// directly in new code. -const ( - oauthJWKSCacheTTL = 5 * time.Minute - oauthHTTPTimeout = 10 * time.Second - oauthClockSkewSecs = int64(60) -) +// OAuthClaims is the validated claim set from an OAuth token. Type alias so +// pkg/server callers and pkg/oauth callers share the same underlying type. +type OAuthClaims = oauth.Claims +// OpenIDConfiguration is the minimal subset of OIDC discovery metadata the +// broker reads. Returned by FetchOpenIDConfiguration during the transition; +// once the broker code moves to pkg/oauth/broker it'll consume +// oauthex.AuthServerMeta + Verifier.ResolveUserInfoEndpoint directly. type OpenIDConfiguration struct { - Issuer string `json:"issuer"` - AuthorizationEndpoint string `json:"authorization_endpoint"` - TokenEndpoint string `json:"token_endpoint"` - JWKSURI string `json:"jwks_uri"` - UserInfoEndpoint string `json:"userinfo_endpoint"` + Issuer string + AuthorizationEndpoint string + TokenEndpoint string + JWKSURI string + UserInfoEndpoint string } -// OAuthClaims represents the claims from an OAuth token -type OAuthClaims struct { - Subject string `json:"sub"` - Issuer string `json:"iss"` - Audience []string `json:"aud"` - ExpiresAt int64 `json:"exp"` - IssuedAt int64 `json:"iat"` - NotBefore int64 `json:"nbf,omitempty"` - Scopes []string `json:"scope"` - Email string `json:"email,omitempty"` - Name string `json:"name,omitempty"` - HostedDomain string `json:"hd,omitempty"` - EmailVerified bool `json:"email_verified,omitempty"` - Extra map[string]interface{} -} +// Error sentinels mirrored from pkg/oauth. errors.Is across the alias works +// because errors.Is unwraps to the same underlying error value. +var ( + ErrMissingOAuthToken = oauth.ErrMissingToken + ErrInvalidOAuthToken = oauth.ErrInvalidToken + ErrOAuthTokenExpired = oauth.ErrTokenExpired + ErrOAuthInsufficientScopes = oauth.ErrInsufficientScopes + ErrOAuthEmailNotVerified = oauth.ErrEmailNotVerified + ErrOAuthUnauthorizedDomain = oauth.ErrUnauthorizedDomain +) -// ExtractOAuthTokenFromRequest extracts an OAuth bearer token from an HTTP -// request, per MCP authorization spec §Token Requirements: -// -// "MCP client MUST use the Authorization request header field defined in -// OAuth 2.1 §5.1.1: Authorization: Bearer " -// "Access tokens MUST NOT be included in the URI query string" -// -// Only the Authorization header is accepted. Earlier revisions of this server -// also honoured `x-oauth-token` and `x-altinity-oauth-token` for legacy -// clients; those have been removed for spec conformance. +// OAuthTokenKey / OAuthClaimsKey re-export the pkg/oauth context keys so +// values stored under one are readable via the other. Declared as vars (not +// const) because contextKey is a value-type from pkg/oauth's perspective. +var ( + OAuthTokenKey any = oauth.TokenKey + OAuthClaimsKey any = oauth.ClaimsKey +) + +// ExtractOAuthTokenFromRequest reads the Authorization: Bearer header. func (s *ClickHouseJWEServer) ExtractOAuthTokenFromRequest(r *http.Request) string { - authHeader := r.Header.Get("Authorization") - if strings.HasPrefix(authHeader, "Bearer ") { - return strings.TrimPrefix(authHeader, "Bearer ") - } - return "" + return oauth.ExtractTokenFromRequest(r) } -// ExtractOAuthTokenFromCtx extracts an OAuth token from context +// ExtractOAuthTokenFromCtx returns the OAuth token stored on ctx by the auth +// injector, or "" if none. func (s *ClickHouseJWEServer) ExtractOAuthTokenFromCtx(ctx context.Context) string { - if tokenFromCtx := ctx.Value(OAuthTokenKey); tokenFromCtx != nil { - if tokenStr, ok := tokenFromCtx.(string); ok { - return tokenStr - } - } - return "" + return oauth.TokenFromContext(ctx) } // oauthRequiresLocalValidation reports whether the auth layer should call -// ValidateOAuthToken on inbound bearers. We always do, in both gating and -// forward modes. Forward-mode JWTs are validated locally (signature + iss + -// aud + exp) per MCP authorization spec §Token Handling and §Access Token -// Privilege Restriction ("MCP servers MUST validate access tokens" / -// "MUST only accept tokens specifically intended for themselves"). -// ValidateOAuthToken itself decides what kind of validation applies for the -// configured mode and token shape. +// ValidateOAuthToken on inbound bearers. Delegates to the Verifier. func (s *ClickHouseJWEServer) oauthRequiresLocalValidation() bool { - return s.Config.Server.OAuth.Enabled + return s.verifier().RequiresLocalValidation() } -// ValidateOAuthToken validates an OAuth bearer and returns claims. -// -// Both modes route through the JWKS-based external-JWT validator: under -// gating, MCP is a pure resource server and the bearer is an upstream IdP -// (Auth0) access token; under forward, MCP proxies the upstream IdP token to -// the client unchanged. In both cases local validation is signature + iss + -// aud + exp against the configured JWKS. -// -// Two cases soft-pass (return nil claims, nil error) — the auth layer accepts -// the request and forwards to ClickHouse, which is then the sole validator: -// -// 1. Opaque (non-JWT) bearers — RFC 7662 introspection is not implemented; -// local validation isn't possible. -// 2. JWT bearers with neither Issuer nor JWKSURL configured — operator -// hasn't told us where to fetch verification keys. -// -// Soft-pass preserves compatibility with deployments that pre-date C-1 and -// rely entirely on ClickHouse-side validation. Operators who want full -// C-1 coverage set Issuer or JWKSURL; warnOAuthMisconfiguration nudges -// them at startup. +// ValidateOAuthToken validates an OAuth bearer and returns claims. See +// pkg/oauth.Verifier.ValidateToken for the full contract (including the two +// soft-pass cases). func (s *ClickHouseJWEServer) ValidateOAuthToken(token string) (*OAuthClaims, error) { - if !s.Config.Server.OAuth.Enabled { - return nil, nil - } - - if token == "" { - return nil, ErrMissingOAuthToken - } - - mode := s.Config.Server.OAuth.NormalizedMode() - if !looksLikeJWT(token) { - if s.Config.Server.OAuth.IsGatingMode() { - log.Error().Str("mode", mode).Msg("OAuth token is not a JWT; gating mode requires a signed JWT from the upstream AS") - return nil, ErrInvalidOAuthToken - } - log.Debug().Str("mode", mode).Msg("Bearer is opaque (not a JWT); skipping local validation, deferring to ClickHouse") - return nil, nil - } - if strings.TrimSpace(s.Config.Server.OAuth.JWKSURL) == "" && strings.TrimSpace(s.Config.Server.OAuth.Issuer) == "" { - log.Debug().Str("mode", mode).Msg("JWT received but neither oauth_issuer nor jwks_url is configured; skipping local validation") - return nil, nil - } - claims, err := s.parseAndVerifyOAuthToken(token, s.Config.Server.OAuth.Audience) - if err != nil { - log.Error().Err(err).Str("mode", mode).Msg("Failed to validate OAuth token") - return nil, err - } - - return s.validateOAuthClaims(claims) + return s.verifier().ValidateToken(context.Background(), token) } -func (s *ClickHouseJWEServer) validateOAuthClaims(claims *OAuthClaims) (*OAuthClaims, error) { - // Issuer enforcement happens upstream in parseAndVerifyExternalJWT, which - // is the only path that reaches here. It already validates `iss` against - // UpstreamIssuerAllowlist (preferred) or the singular `Issuer` config — - // re-validating here would duplicate the check and incorrectly reject - // tokens issued under a multi-issuer allowlist (where the singular - // `Issuer` field is not authoritative). - - // Validate audience if configured. Compare slash-normalised — the token's - // `aud` claim is whatever string the client passed in `resource` at - // /authorize (RFC 8707), so it may legitimately differ in trailing slash - // from the operator's configured Audience. Either form is acceptable. - if s.Config.Server.OAuth.Audience != "" { - if len(claims.Audience) == 0 { - log.Error().Str("expected", s.Config.Server.OAuth.Audience).Msg("OAuth token missing audience claim") - return nil, ErrInvalidOAuthToken - } - if !audienceMatchesResource(claims.Audience, s.Config.Server.OAuth.Audience) { - log.Error().Str("expected", s.Config.Server.OAuth.Audience).Strs("got", claims.Audience).Msg("OAuth token audience mismatch") - return nil, ErrInvalidOAuthToken - } - } - - now := time.Now().Unix() - if claims.ExpiresAt > 0 && now > claims.ExpiresAt+oauthClockSkewSecs { - log.Error().Int64("exp", claims.ExpiresAt).Msg("OAuth token expired") - return nil, ErrOAuthTokenExpired - } - if claims.NotBefore > 0 && now+oauthClockSkewSecs < claims.NotBefore { - log.Error().Int64("nbf", claims.NotBefore).Msg("OAuth token not yet valid") - return nil, ErrInvalidOAuthToken - } - if claims.IssuedAt > 0 && claims.IssuedAt > now+oauthClockSkewSecs { - log.Error().Int64("iat", claims.IssuedAt).Msg("OAuth token issued in the future") - return nil, ErrInvalidOAuthToken - } - - if len(s.Config.Server.OAuth.RequiredScopes) > 0 { - if !hasRequiredScopes(claims.Scopes, s.Config.Server.OAuth.RequiredScopes) { - log.Error().Strs("required", s.Config.Server.OAuth.RequiredScopes).Strs("got", claims.Scopes).Msg("OAuth token missing required scopes") - return nil, ErrOAuthInsufficientScopes - } - } - - if err := s.validateOAuthIdentityPolicy(claims); err != nil { - return nil, err - } - - return claims, nil +// ValidateUpstreamIdentityToken parses an upstream identity token (no +// soft-pass) and applies the identity policy. Used by the broker on /callback. +func (s *ClickHouseJWEServer) ValidateUpstreamIdentityToken(token, expectedAudience string) (*OAuthClaims, error) { + return s.verifier().ValidateUpstreamIdentityToken(context.Background(), token, expectedAudience) } -func (s *ClickHouseJWEServer) validateOAuthIdentityPolicy(claims *OAuthClaims) error { - oauthCfg := s.Config.Server.OAuth - if !oauthCfg.AllowUnverifiedEmail && claims.Email != "" && !claims.EmailVerified { - log.Error().Str("email", claims.Email).Msg("OAuth identity email is not verified") - return ErrOAuthEmailNotVerified - } - - if len(oauthCfg.AllowedEmailDomains) > 0 { - domain := emailDomain(claims.Email) - if domain == "" || !containsDomain(oauthCfg.AllowedEmailDomains, domain) { - log.Error().Str("email", claims.Email).Strs("allowed_domains", oauthCfg.AllowedEmailDomains).Msg("OAuth identity email domain is not allowed") - return ErrOAuthUnauthorizedDomain - } - } - - if len(oauthCfg.AllowedHostedDomains) > 0 { - if claims.HostedDomain == "" || !containsDomain(oauthCfg.AllowedHostedDomains, claims.HostedDomain) { - log.Error().Str("hosted_domain", claims.HostedDomain).Strs("allowed_hosted_domains", oauthCfg.AllowedHostedDomains).Msg("OAuth identity hosted domain is not allowed") - return ErrOAuthUnauthorizedDomain - } - } - - return nil -} - -// ValidateOAuthIdentityPolicyClaims applies configured post-verification identity policy checks. +// ValidateOAuthIdentityPolicyClaims re-runs the identity policy checks on an +// already-parsed claim set. Used after the broker exchanges an opaque +// access_token for a userinfo response. func (s *ClickHouseJWEServer) ValidateOAuthIdentityPolicyClaims(claims *OAuthClaims) error { - return s.validateOAuthIdentityPolicy(claims) -} - -func emailDomain(email string) string { - parts := strings.Split(strings.ToLower(strings.TrimSpace(email)), "@") - if len(parts) != 2 { - return "" - } - return parts[1] -} - -func containsDomain(domains []string, target string) bool { - for _, domain := range domains { - if strings.EqualFold(strings.TrimSpace(domain), strings.TrimSpace(target)) { - return true - } - } - return false -} - -func containsString(values []string, target string) bool { - for _, value := range values { - if value == target { - return true - } - } - return false -} - -// audienceMatchesResource compares an incoming audience claim list against -// an expected resource URL with trailing-slash tolerance. RFC 9728's -// canonical form uses a trailing slash, but upstream IdPs (and prior -// altinity-mcp metadata responses) sometimes emit the form without one, -// so we match both. Falls back to exact match if either side isn't a URL. -func audienceMatchesResource(claims []string, expected string) bool { - expectedTrimmed := strings.TrimRight(strings.TrimSpace(expected), "/") - for _, c := range claims { - if c == expected { - return true - } - if strings.TrimRight(strings.TrimSpace(c), "/") == expectedTrimmed { - return true - } - } - return false -} - -func looksLikeJWT(token string) bool { - return strings.Count(token, ".") == 2 -} - -func (s *ClickHouseJWEServer) parseAndVerifyOAuthToken(token string, expectedAudience string) (*OAuthClaims, error) { - if looksLikeJWT(token) { - return s.parseAndVerifyExternalJWT(token, expectedAudience) - } - return nil, fmt.Errorf("%w: opaque bearer tokens are not supported without token introspection", ErrInvalidOAuthToken) -} - -func (s *ClickHouseJWEServer) parseAndVerifyExternalJWT(token string, expectedAudience string) (*OAuthClaims, error) { - jwksURI, err := s.resolveOAuthJWKSURL() - if err != nil { - return nil, err - } - - parsed, err := jwt.ParseSigned(token, []jose.SignatureAlgorithm{ - jose.RS256, jose.RS384, jose.RS512, - jose.ES256, jose.ES384, jose.ES512, - jose.PS256, jose.PS384, jose.PS512, - jose.EdDSA, - }) - if err != nil { - return nil, fmt.Errorf("failed to parse signed JWT: %w", err) - } - if len(parsed.Headers) == 0 { - return nil, fmt.Errorf("missing JWT header") - } - - keySet, err := s.fetchOAuthJWKSet(jwksURI) - if err != nil { - return nil, err - } - - keys := keySet.Keys - keyID := parsed.Headers[0].KeyID - if keyID != "" { - keys = keySet.Key(keyID) - if len(keys) == 0 { - // kid absent from the cached JWKS — the AS may have rotated its - // signing key since the last fetch. Invalidate the cache and - // retry once before giving up. - s.jwksCacheMu.Lock() - s.jwksCacheTime = time.Time{} - s.jwksCacheMu.Unlock() - keySet, err = s.fetchOAuthJWKSet(jwksURI) - if err != nil { - return nil, err - } - keys = keySet.Key(keyID) - if len(keys) == 0 { - return nil, fmt.Errorf("no JWK found for kid %q", keyID) - } - log.Info().Str("kid", keyID).Msg("oauth: JWKS re-fetched after key rotation; new kid found") - } - } - - // Issuer enforcement: when the operator has configured - // UpstreamIssuerAllowlist, require the token's `iss` to be in that set - // (multi-tenant deployments). Otherwise fall back to the singular - // `Issuer` config field for the standard single-tenant case. If neither - // is set, no issuer check happens (caller's responsibility to configure). - allowlist := s.Config.Server.OAuth.UpstreamIssuerAllowlist - expectedIssuer := strings.TrimSpace(s.Config.Server.OAuth.Issuer) - var ( - rawClaims map[string]interface{} - signatureVerified bool - issuerRejected bool - audienceRejected bool - ) - for _, key := range keys { - rawClaims = make(map[string]interface{}) - if err := parsed.Claims(key.Key, &rawClaims); err != nil { - continue - } - signatureVerified = true - claims := oauthClaimsFromRawClaims(rawClaims) - if !issuerAllowed(claims.Issuer, allowlist, expectedIssuer) { - issuerRejected = true - continue - } - if expectedAudience != "" && !audienceMatchesResource(claims.Audience, expectedAudience) { - audienceRejected = true - continue - } - return claims, nil - } - if signatureVerified && (issuerRejected || audienceRejected) { - return nil, ErrInvalidOAuthToken - } - - return nil, fmt.Errorf("failed to verify JWT signature with discovered JWKs") -} - -// issuerAllowed implements the issuer policy used in upstream-token validation: -// when UpstreamIssuerAllowlist is non-empty, the token's iss MUST be one of -// the listed values (multi-tenant). Otherwise, when a singular Issuer is -// configured, the token's iss MUST match it (single-tenant). With neither set, -// no issuer check is performed (the caller is responsible for configuring at -// least one of these — see warnOAuthMisconfiguration). -// -// Comparison is slash-normalised on both sides — operator config and the -// token's `iss` may legitimately differ in trailing slash (e.g. Auth0 emits -// the form with a slash; some configs omit it). Matches the convention used -// by validateOAuthClaims for self-issued tokens. -func issuerAllowed(got string, allowlist []string, singleIssuer string) bool { - norm := func(s string) string { return strings.TrimRight(strings.TrimSpace(s), "/") } - got = norm(got) - if len(allowlist) > 0 { - for _, allowed := range allowlist { - if norm(allowed) == got { - return true - } - } - return false - } - if norm(singleIssuer) != "" { - return got == norm(singleIssuer) - } - return true + return s.verifier().ValidateIdentityPolicyClaims(claims) } -func (s *ClickHouseJWEServer) ValidateUpstreamIdentityToken(token string, expectedAudience string) (*OAuthClaims, error) { - claims, err := s.parseAndVerifyExternalJWT(token, expectedAudience) - if err != nil { - return nil, err - } - return claims, s.ValidateOAuthIdentityPolicyClaims(claims) -} - -func (s *ClickHouseJWEServer) resolveOAuthJWKSURL() (string, error) { - if strings.TrimSpace(s.Config.Server.OAuth.JWKSURL) != "" { - return strings.TrimSpace(s.Config.Server.OAuth.JWKSURL), nil - } - if strings.TrimSpace(s.Config.Server.OAuth.Issuer) == "" { - return "", fmt.Errorf("oauth issuer or jwks_url must be configured") - } - discovery, err := s.fetchOpenIDConfiguration(strings.TrimSpace(s.Config.Server.OAuth.Issuer)) - if err != nil { - return "", err - } - if strings.TrimSpace(discovery.JWKSURI) == "" { - return "", fmt.Errorf("openid discovery did not return jwks_uri") - } - return strings.TrimSpace(discovery.JWKSURI), nil -} - -func (s *ClickHouseJWEServer) fetchOpenIDConfiguration(issuer string) (*OpenIDConfiguration, error) { - issuer = strings.TrimRight(strings.TrimSpace(issuer), "/") - if issuer == "" { - return nil, fmt.Errorf("issuer is required") - } - - s.oidcConfigMu.RLock() - if s.oidcConfigCacheURL == issuer && !s.oidcConfigTime.IsZero() && s.oidcConfigTime.Add(oauthJWKSCacheTTL).After(time.Now()) && s.oidcConfigCache.Issuer != "" { - cached := s.oidcConfigCache - s.oidcConfigMu.RUnlock() - return &cached, nil - } - s.oidcConfigMu.RUnlock() - - urls := []string{ - issuer + "/.well-known/openid-configuration", - } - if !strings.Contains(issuer, "/.well-known/") { - urls = append(urls, issuer+"/.well-known/oauth-authorization-server") - } - - client := &http.Client{Timeout: oauthHTTPTimeout} - for _, metadataURL := range urls { - resp, err := client.Get(metadataURL) - if err != nil { - continue - } - body, readErr := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) - if closeErr := resp.Body.Close(); closeErr != nil { - log.Warn().Stack().Err(closeErr).Msgf("can't close %s response body", metadataURL) - } - if resp.StatusCode >= 300 || readErr != nil { - continue - } - var discovery OpenIDConfiguration - if err := json.Unmarshal(body, &discovery); err == nil { - s.oidcConfigMu.Lock() - s.oidcConfigCache = discovery - s.oidcConfigCacheURL = issuer - s.oidcConfigTime = time.Now() - s.oidcConfigMu.Unlock() - return &discovery, nil - } - } - - return nil, fmt.Errorf("failed to discover openid configuration for issuer %q", issuer) -} - -// FetchOpenIDConfiguration returns the discovered OIDC metadata for the configured issuer. +// FetchOpenIDConfiguration returns the discovered OIDC metadata subset the +// broker needs. Composes go-sdk's auth-server-metadata discovery with our +// surgical userinfo_endpoint fallback (oauthex.AuthServerMeta is RFC 8414 +// only and does not include userinfo_endpoint). func (s *ClickHouseJWEServer) FetchOpenIDConfiguration(issuer string) (*OpenIDConfiguration, error) { - return s.fetchOpenIDConfiguration(issuer) -} - -func (s *ClickHouseJWEServer) fetchOAuthJWKSet(jwksURI string) (*jose.JSONWebKeySet, error) { - now := time.Now() - - s.jwksCacheMu.RLock() - if len(s.jwksCache.Keys) > 0 && s.jwksCacheURL == jwksURI && s.jwksCacheTime.Add(oauthJWKSCacheTTL).After(now) { - cached := s.jwksCache - s.jwksCacheMu.RUnlock() - return &cached, nil - } - s.jwksCacheMu.RUnlock() - - resp, err := (&http.Client{Timeout: oauthHTTPTimeout}).Get(jwksURI) + ctx := context.Background() + v := s.verifier() + asMeta, err := v.FetchAuthServerMeta(ctx, issuer) if err != nil { - return nil, fmt.Errorf("failed to fetch jwks: %w", err) - } - defer func() { - if closeErr := resp.Body.Close(); closeErr != nil { - log.Warn().Stack().Err(err).Msgf("can't close %s response body", jwksURI) - } - }() - body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) - if err != nil { - return nil, fmt.Errorf("failed to read jwks response: %w", err) - } - if resp.StatusCode >= 300 { - return nil, fmt.Errorf("jwks endpoint returned status %d", resp.StatusCode) - } - - var keySet jose.JSONWebKeySet - if err := json.Unmarshal(body, &keySet); err != nil { - return nil, fmt.Errorf("failed to parse jwks response: %w", err) - } - - s.jwksCacheMu.Lock() - s.jwksCache = keySet - s.jwksCacheURL = jwksURI - s.jwksCacheTime = now - s.jwksCacheMu.Unlock() - - return &keySet, nil -} - -func oauthClaimsFromRawClaims(rawClaims map[string]interface{}) *OAuthClaims { - claims := &OAuthClaims{ - Extra: make(map[string]interface{}), - } - - if sub, ok := rawClaims["sub"].(string); ok { - claims.Subject = sub - } - if iss, ok := rawClaims["iss"].(string); ok { - claims.Issuer = iss - } - if exp, ok := rawClaims["exp"].(float64); ok { - claims.ExpiresAt = int64(exp) - } - if exp, ok := rawClaims["exp"].(json.Number); ok { - if n, err := exp.Int64(); err == nil { - claims.ExpiresAt = n - } - } - if iat, ok := rawClaims["iat"].(float64); ok { - claims.IssuedAt = int64(iat) - } - if iat, ok := rawClaims["iat"].(json.Number); ok { - if n, err := iat.Int64(); err == nil { - claims.IssuedAt = n - } - } - if nbf, ok := rawClaims["nbf"].(float64); ok { - claims.NotBefore = int64(nbf) - } - if nbf, ok := rawClaims["nbf"].(json.Number); ok { - if n, err := nbf.Int64(); err == nil { - claims.NotBefore = n - } - } - if email, ok := rawClaims["email"].(string); ok { - claims.Email = email - } - if name, ok := rawClaims["name"].(string); ok { - claims.Name = name - } - if hd, ok := rawClaims["hd"].(string); ok { - claims.HostedDomain = hd - } - if emailVerified, ok := rawClaims["email_verified"].(bool); ok { - claims.EmailVerified = emailVerified - } - if emailVerified, ok := rawClaims["email_verified"].(string); ok { - claims.EmailVerified = strings.EqualFold(emailVerified, "true") - } - - switch aud := rawClaims["aud"].(type) { - case string: - claims.Audience = []string{aud} - case []interface{}: - for _, a := range aud { - if audStr, ok := a.(string); ok { - claims.Audience = append(claims.Audience, audStr) - } - } - } - - switch scope := rawClaims["scope"].(type) { - case string: - claims.Scopes = strings.Fields(scope) - case []interface{}: - for _, s := range scope { - if scopeStr, ok := s.(string); ok { - claims.Scopes = append(claims.Scopes, scopeStr) - } - } - } - - standardClaims := map[string]bool{ - "sub": true, "iss": true, "aud": true, "exp": true, "iat": true, "nbf": true, "jti": true, - "scope": true, "email": true, "name": true, "hd": true, "email_verified": true, - } - for k, v := range rawClaims { - if !standardClaims[k] { - claims.Extra[k] = v - continue - } - } - - return claims -} - -// hasRequiredScopes checks if all required scopes are present -func hasRequiredScopes(tokenScopes, requiredScopes []string) bool { - scopeSet := make(map[string]bool) - for _, s := range tokenScopes { - scopeSet[s] = true - } - for _, required := range requiredScopes { - if !scopeSet[required] { - return false - } + return nil, err } - return true + userInfo, _ := v.ResolveUserInfoEndpoint(ctx, issuer) + return &OpenIDConfiguration{ + Issuer: asMeta.Issuer, + AuthorizationEndpoint: asMeta.AuthorizationEndpoint, + TokenEndpoint: asMeta.TokenEndpoint, + JWKSURI: asMeta.JWKSURI, + UserInfoEndpoint: userInfo, + }, nil +} + +// verifier returns the lazily-initialised Verifier. The single +// NewClickHouseMCPServer construction path initialises s.oauthVerifier +// up-front, but tests construct ClickHouseJWEServer directly via struct +// literal, so this getter falls back to building one on demand. +func (s *ClickHouseJWEServer) verifier() *oauth.Verifier { + if s.oauthVerifier == nil { + s.oauthVerifier = oauth.NewVerifier(s.Config.Server.OAuth) + } + return s.oauthVerifier } diff --git a/pkg/server/server_auth_oauth_test.go b/pkg/server/server_auth_oauth_test.go index 952119c..e3e384b 100644 --- a/pkg/server/server_auth_oauth_test.go +++ b/pkg/server/server_auth_oauth_test.go @@ -669,25 +669,22 @@ func TestOAuthRequiresLocalValidation(t *testing.T) { }) } -// TestOAuthUpstreamIssuerAllowlist verifies that the operator-configured -// allowlist actually constrains upstream IdP token validation. Before this -// fix, the field was loaded into config but no handler consulted it — operators -// who set it for hardening got zero enforcement. -func TestOAuthUpstreamIssuerAllowlist(t *testing.T) { +// TestOAuthIssuerEnforcement verifies the singular-Issuer single-tenant policy. +func TestOAuthIssuerEnforcement(t *testing.T) { t.Parallel() - t.Run("token_from_allowlisted_issuer_accepted", func(t *testing.T) { + t.Run("token_from_configured_issuer_accepted", func(t *testing.T) { t.Parallel() provider := newTestOAuthProvider(t, nil) srv := &ClickHouseJWEServer{ Config: config.Config{ Server: config.ServerConfig{ OAuth: config.OAuthConfig{ - Enabled: true, - Mode: "forward", - UpstreamIssuerAllowlist: []string{provider.server.URL, "https://other.example.com"}, - JWKSURL: provider.server.URL + "/jwks", - Audience: "clickhouse-api", + Enabled: true, + Mode: "forward", + Issuer: provider.server.URL, + JWKSURL: provider.server.URL + "/jwks", + Audience: "clickhouse-api", }, }, }, @@ -703,18 +700,18 @@ func TestOAuthUpstreamIssuerAllowlist(t *testing.T) { require.Equal(t, provider.server.URL, claims.Issuer) }) - t.Run("token_from_non_allowlisted_issuer_rejected", func(t *testing.T) { + t.Run("token_from_other_issuer_rejected", func(t *testing.T) { t.Parallel() provider := newTestOAuthProvider(t, nil) srv := &ClickHouseJWEServer{ Config: config.Config{ Server: config.ServerConfig{ OAuth: config.OAuthConfig{ - Enabled: true, - Mode: "forward", - UpstreamIssuerAllowlist: []string{"https://only-this-one.example.com"}, - JWKSURL: provider.server.URL + "/jwks", - Audience: "clickhouse-api", + Enabled: true, + Mode: "forward", + Issuer: "https://only-this-one.example.com", + JWKSURL: provider.server.URL + "/jwks", + Audience: "clickhouse-api", }, }, }, @@ -729,24 +726,18 @@ func TestOAuthUpstreamIssuerAllowlist(t *testing.T) { require.ErrorIs(t, err, ErrInvalidOAuthToken) }) - t.Run("allowlist_takes_precedence_over_singular_issuer", func(t *testing.T) { - // When both Issuer (singular) and UpstreamIssuerAllowlist are set, the - // allowlist wins. The singular Issuer is still used for OIDC/JWKS - // discovery if no JWKSURL is configured, but for *token validation* - // the allowlist is authoritative — otherwise the allowlist would be - // useless in single-issuer-but-multi-tenant deployments. + t.Run("issuer_match_is_trailing_slash_tolerant", func(t *testing.T) { t.Parallel() provider := newTestOAuthProvider(t, nil) srv := &ClickHouseJWEServer{ Config: config.Config{ Server: config.ServerConfig{ OAuth: config.OAuthConfig{ - Enabled: true, - Mode: "forward", - Issuer: "https://something-else.example.com", - UpstreamIssuerAllowlist: []string{provider.server.URL}, - JWKSURL: provider.server.URL + "/jwks", - Audience: "clickhouse-api", + Enabled: true, + Mode: "forward", + Issuer: provider.server.URL + "/", + JWKSURL: provider.server.URL + "/jwks", + Audience: "clickhouse-api", }, }, }, @@ -763,35 +754,6 @@ func TestOAuthUpstreamIssuerAllowlist(t *testing.T) { }) } -func TestIssuerAllowed(t *testing.T) { - t.Parallel() - cases := []struct { - name string - got string - allowlist []string - singleIssuer string - want bool - }{ - {"exact match in allowlist", "https://idp.example.com/", []string{"https://idp.example.com/"}, "", true}, - {"got has slash, allowlist entry doesn't", "https://idp.example.com/", []string{"https://idp.example.com"}, "", true}, - {"got missing slash, allowlist entry has it", "https://idp.example.com", []string{"https://idp.example.com/"}, "", true}, - {"allowlist with surrounding whitespace", "https://idp.example.com", []string{" https://idp.example.com/ "}, "", true}, - {"allowlist non-empty and got not in it", "https://attacker.example.com", []string{"https://idp.example.com/"}, "", false}, - {"allowlist takes precedence over singular issuer", "https://idp.example.com/", []string{"https://other.example.com/"}, "https://idp.example.com/", false}, - {"singular issuer match with mixed slash", "https://idp.example.com", nil, "https://idp.example.com/", true}, - {"singular issuer mismatch", "https://attacker.example.com", nil, "https://idp.example.com/", false}, - {"no allowlist, no single issuer accepts everything", "https://anything.example.com", nil, "", true}, - {"no allowlist, blank single issuer accepts everything", "https://anything.example.com", nil, " ", true}, - } - for _, c := range cases { - c := c - t.Run(c.name, func(t *testing.T) { - t.Parallel() - require.Equal(t, c.want, issuerAllowed(c.got, c.allowlist, c.singleIssuer)) - }) - } -} - // TestOAuthBuildClickHouseHeaders tests building ClickHouse headers from OAuth func TestOAuthBuildClickHouseHeaders(t *testing.T) { t.Parallel() @@ -1874,319 +1836,6 @@ func TestOAuthOpenAPIFullFlow(t *testing.T) { }) } -func TestResolveOAuthJWKSURL(t *testing.T) { - t.Parallel() - t.Run("direct_jwks_url_configured", func(t *testing.T) { - t.Parallel() - srv := &ClickHouseJWEServer{ - Config: config.Config{ - Server: config.ServerConfig{ - OAuth: config.OAuthConfig{ - JWKSURL: "https://auth.example.com/jwks", - }, - }, - }, - } - url, err := srv.resolveOAuthJWKSURL() - require.NoError(t, err) - require.Equal(t, "https://auth.example.com/jwks", url) - }) - - t.Run("openid_configuration_discovery", func(t *testing.T) { - t.Parallel() - mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/.well-known/openid-configuration" { - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]string{ - "issuer": "https://auth.example.com", - "jwks_uri": "https://auth.example.com/keys", - }) - return - } - http.NotFound(w, r) - })) - defer mockServer.Close() - - srv := &ClickHouseJWEServer{ - Config: config.Config{ - Server: config.ServerConfig{ - OAuth: config.OAuthConfig{ - Issuer: mockServer.URL, - }, - }, - }, - } - url, err := srv.resolveOAuthJWKSURL() - require.NoError(t, err) - require.Equal(t, "https://auth.example.com/keys", url) - }) - - t.Run("fallback_to_oauth_authorization_server", func(t *testing.T) { - t.Parallel() - mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/.well-known/openid-configuration" { - http.NotFound(w, r) - return - } - if r.URL.Path == "/.well-known/oauth-authorization-server" { - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]string{ - "issuer": "https://auth.example.com", - "jwks_uri": "https://auth.example.com/fallback-keys", - }) - return - } - http.NotFound(w, r) - })) - defer mockServer.Close() - - srv := &ClickHouseJWEServer{ - Config: config.Config{ - Server: config.ServerConfig{ - OAuth: config.OAuthConfig{ - Issuer: mockServer.URL, - }, - }, - }, - } - url, err := srv.resolveOAuthJWKSURL() - require.NoError(t, err) - require.Equal(t, "https://auth.example.com/fallback-keys", url) - }) - - t.Run("both_discovery_endpoints_fail", func(t *testing.T) { - t.Parallel() - mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.NotFound(w, r) - })) - defer mockServer.Close() - - srv := &ClickHouseJWEServer{ - Config: config.Config{ - Server: config.ServerConfig{ - OAuth: config.OAuthConfig{ - Issuer: mockServer.URL, - }, - }, - }, - } - _, err := srv.resolveOAuthJWKSURL() - require.Error(t, err) - require.Contains(t, err.Error(), "failed to discover") - }) - - t.Run("discovery_missing_jwks_uri", func(t *testing.T) { - t.Parallel() - mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]string{ - "issuer": "https://auth.example.com", - }) - })) - defer mockServer.Close() - - srv := &ClickHouseJWEServer{ - Config: config.Config{ - Server: config.ServerConfig{ - OAuth: config.OAuthConfig{ - Issuer: mockServer.URL, - }, - }, - }, - } - _, err := srv.resolveOAuthJWKSURL() - require.Error(t, err) - require.Contains(t, err.Error(), "jwks_uri") - }) -} - -func TestOIDCConfigCaching(t *testing.T) { - t.Parallel() - var requestCount int - mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestCount++ - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]string{ - "issuer": "https://auth.example.com", - "jwks_uri": "https://auth.example.com/keys", - }) - })) - defer mockServer.Close() - - srv := &ClickHouseJWEServer{ - Config: config.Config{ - Server: config.ServerConfig{ - OAuth: config.OAuthConfig{ - Issuer: mockServer.URL, - }, - }, - }, - } - - // NOTE: subtests are NOT parallel — they share requestCount and srv cache state - t.Run("cache_hit_within_ttl", func(t *testing.T) { - requestCount = 0 - _, err := srv.FetchOpenIDConfiguration(mockServer.URL) - require.NoError(t, err) - _, err = srv.FetchOpenIDConfiguration(mockServer.URL) - require.NoError(t, err) - require.Equal(t, 1, requestCount, "second call should hit cache") - }) - - t.Run("cache_miss_after_ttl_expires", func(t *testing.T) { - // Ensure cache is populated - _, err := srv.FetchOpenIDConfiguration(mockServer.URL) - require.NoError(t, err) - - // Manipulate cache time to simulate TTL expiry - srv.oidcConfigMu.Lock() - srv.oidcConfigTime = time.Now().Add(-oauthJWKSCacheTTL - time.Second) - srv.oidcConfigMu.Unlock() - - countBefore := requestCount - _, err = srv.FetchOpenIDConfiguration(mockServer.URL) - require.NoError(t, err) - require.Equal(t, countBefore+1, requestCount, "should re-fetch after TTL expiry") - }) -} - -func TestParseAndVerifyExternalJWTUnknownKid(t *testing.T) { - t.Parallel() - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) - - // Create JWKS with kid "known" - knownJWK := jose.JSONWebKey{Key: &privateKey.PublicKey, KeyID: "known", Algorithm: "RS256", Use: "sig"} - mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/.well-known/openid-configuration": - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]string{ - "issuer": r.Host, - "jwks_uri": "http://" + r.Host + "/jwks", - }) - case "/jwks": - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(jose.JSONWebKeySet{Keys: []jose.JSONWebKey{knownJWK}}) - default: - http.NotFound(w, r) - } - })) - defer mockServer.Close() - - srv := &ClickHouseJWEServer{ - Config: config.Config{ - Server: config.ServerConfig{ - OAuth: config.OAuthConfig{ - Issuer: mockServer.URL, - JWKSURL: mockServer.URL + "/jwks", - }, - }, - }, - } - - // Sign token with kid "unknown" - signer, err := jose.NewSigner( - jose.SigningKey{Algorithm: jose.RS256, Key: privateKey}, - (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", "unknown"), - ) - require.NoError(t, err) - - payload, err := json.Marshal(map[string]interface{}{ - "sub": "user-1", - "iss": mockServer.URL, - "aud": "test-audience", - "exp": time.Now().Add(time.Hour).Unix(), - "iat": time.Now().Unix(), - }) - require.NoError(t, err) - - object, err := signer.Sign(payload) - require.NoError(t, err) - token, err := object.CompactSerialize() - require.NoError(t, err) - - _, err = srv.parseAndVerifyExternalJWT(token, "test-audience") - require.Error(t, err) - require.Contains(t, err.Error(), "no JWK found for kid") -} - -// TestJWKSRefetchOnKidMiss verifies that a kid absent from the cached JWKS -// triggers a one-shot cache-bypass re-fetch, allowing tokens issued after -// a key rotation to be accepted without waiting for the TTL to expire. -func TestJWKSRefetchOnKidMiss(t *testing.T) { - t.Parallel() - - oldKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) - newKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) - - const oldKid = "old-signing-key" - const newKid = "new-signing-key" - - // The mock JWKS endpoint always serves the new key. The test seeds - // the server's in-memory cache with the old key to simulate a stale - // cache from before the rotation. - mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/jwks": - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(jose.JSONWebKeySet{Keys: []jose.JSONWebKey{ - {Key: &newKey.PublicKey, KeyID: newKid, Algorithm: "RS256", Use: "sig"}, - }}) - default: - http.NotFound(w, r) - } - })) - defer mockServer.Close() - - srv := &ClickHouseJWEServer{ - Config: config.Config{ - Server: config.ServerConfig{ - OAuth: config.OAuthConfig{ - Issuer: mockServer.URL, - JWKSURL: mockServer.URL + "/jwks", - }, - }, - }, - } - - // Seed the JWKS cache with the old key and a far-future TTL so that a - // normal fetch would not re-fetch. - srv.jwksCacheMu.Lock() - srv.jwksCache = jose.JSONWebKeySet{Keys: []jose.JSONWebKey{ - {Key: &oldKey.PublicKey, KeyID: oldKid, Algorithm: "RS256", Use: "sig"}, - }} - srv.jwksCacheURL = mockServer.URL + "/jwks" - srv.jwksCacheTime = time.Now().Add(10 * time.Minute) // far future — won't expire naturally - srv.jwksCacheMu.Unlock() - - // Issue a JWT signed with the new key (kid = newKid). - signer, err := jose.NewSigner( - jose.SigningKey{Algorithm: jose.RS256, Key: newKey}, - (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", newKid), - ) - require.NoError(t, err) - payload, err := json.Marshal(map[string]interface{}{ - "sub": "user-1", - "iss": mockServer.URL, - "aud": "test-audience", - "exp": time.Now().Add(time.Hour).Unix(), - "iat": time.Now().Unix(), - }) - require.NoError(t, err) - obj, err := signer.Sign(payload) - require.NoError(t, err) - token, err := obj.CompactSerialize() - require.NoError(t, err) - - // Should succeed: kid-miss triggers a cache-bypass re-fetch that finds newKid. - claims, err := srv.parseAndVerifyExternalJWT(token, "test-audience") - require.NoError(t, err) - require.Equal(t, "user-1", claims.Subject) -} - func TestGatingModeIdentityPolicy(t *testing.T) { t.Parallel() const gatingSecret = "test-gating-secret-32-byte-key!!" @@ -2250,135 +1899,6 @@ func TestGatingModeIdentityPolicy(t *testing.T) { require.ErrorIs(t, err, ErrOAuthUnauthorizedDomain) }) } - -// ---------- coverage gap tests ---------- - -func TestEmailDomain(t *testing.T) { - t.Parallel() - tests := []struct { - name string - email string - want string - }{ - {"normal", "user@example.com", "example.com"}, - {"uppercase", "User@EXAMPLE.COM", "example.com"}, - {"whitespace", " user@example.com ", "example.com"}, - {"no_at", "noatsign", ""}, - {"empty", "", ""}, - {"multiple_at", "a@b@c", ""}, - {"just_at", "@", ""}, - {"domain_only", "@domain.com", "domain.com"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - require.Equal(t, tt.want, emailDomain(tt.email)) - }) - } -} - -func TestOAuthClaimsFromRawClaims(t *testing.T) { - t.Parallel() - - t.Run("all_standard_fields", func(t *testing.T) { - t.Parallel() - raw := map[string]interface{}{ - "sub": "user123", - "iss": "https://auth.example.com", - "exp": float64(1700000000), - "iat": float64(1699999000), - "nbf": float64(1699998000), - "email": "user@example.com", - "name": "Test User", - "hd": "example.com", - "email_verified": true, - "aud": "my-api", - "scope": "read write", - } - claims := oauthClaimsFromRawClaims(raw) - require.Equal(t, "user123", claims.Subject) - require.Equal(t, "https://auth.example.com", claims.Issuer) - require.Equal(t, int64(1700000000), claims.ExpiresAt) - require.Equal(t, int64(1699999000), claims.IssuedAt) - require.Equal(t, int64(1699998000), claims.NotBefore) - require.Equal(t, "user@example.com", claims.Email) - require.Equal(t, "Test User", claims.Name) - require.Equal(t, "example.com", claims.HostedDomain) - require.True(t, claims.EmailVerified) - require.Equal(t, []string{"my-api"}, claims.Audience) - require.Equal(t, []string{"read", "write"}, claims.Scopes) - }) - - t.Run("json_number_fields", func(t *testing.T) { - t.Parallel() - raw := map[string]interface{}{ - "sub": "user", - "exp": json.Number("1700000000"), - "iat": json.Number("1699999000"), - "nbf": json.Number("1699998000"), - } - claims := oauthClaimsFromRawClaims(raw) - require.Equal(t, int64(1700000000), claims.ExpiresAt) - require.Equal(t, int64(1699999000), claims.IssuedAt) - require.Equal(t, int64(1699998000), claims.NotBefore) - }) - - t.Run("audience_array", func(t *testing.T) { - t.Parallel() - raw := map[string]interface{}{ - "aud": []interface{}{"api1", "api2"}, - } - claims := oauthClaimsFromRawClaims(raw) - require.Equal(t, []string{"api1", "api2"}, claims.Audience) - }) - - t.Run("scope_array", func(t *testing.T) { - t.Parallel() - raw := map[string]interface{}{ - "scope": []interface{}{"read", "write", "admin"}, - } - claims := oauthClaimsFromRawClaims(raw) - require.Equal(t, []string{"read", "write", "admin"}, claims.Scopes) - }) - - t.Run("email_verified_string", func(t *testing.T) { - t.Parallel() - raw := map[string]interface{}{ - "email_verified": "true", - } - claims := oauthClaimsFromRawClaims(raw) - require.True(t, claims.EmailVerified) - - raw2 := map[string]interface{}{ - "email_verified": "false", - } - claims2 := oauthClaimsFromRawClaims(raw2) - require.False(t, claims2.EmailVerified) - }) - - t.Run("extra_claims_preserved", func(t *testing.T) { - t.Parallel() - raw := map[string]interface{}{ - "sub": "user", - "custom1": "value1", - "custom_num": float64(42), - } - claims := oauthClaimsFromRawClaims(raw) - require.Equal(t, "value1", claims.Extra["custom1"]) - require.Equal(t, float64(42), claims.Extra["custom_num"]) - _, hasSub := claims.Extra["sub"] - require.False(t, hasSub) - }) - - t.Run("empty_claims", func(t *testing.T) { - t.Parallel() - claims := oauthClaimsFromRawClaims(map[string]interface{}{}) - require.NotNil(t, claims) - require.Empty(t, claims.Subject) - require.NotNil(t, claims.Extra) - }) -} - func TestBuildClickHouseHeadersFromOAuth(t *testing.T) { t.Parallel() @@ -2491,114 +2011,3 @@ func TestBuildClickHouseHeadersFromOAuth(t *testing.T) { require.Equal(t, "false", headers["X-V"]) }) } - -func TestLooksLikeJWT(t *testing.T) { - t.Parallel() - require.True(t, looksLikeJWT("a.b.c")) - require.False(t, looksLikeJWT("not-a-jwt")) - require.False(t, looksLikeJWT("a.b")) - require.False(t, looksLikeJWT("a.b.c.d")) -} - -func TestValidateOAuthClaims(t *testing.T) { - t.Parallel() - - t.Run("audience_missing_when_required", func(t *testing.T) { - t.Parallel() - s := &ClickHouseJWEServer{Config: config.Config{Server: config.ServerConfig{OAuth: config.OAuthConfig{ - Audience: "my-audience", - }}}} - _, err := s.validateOAuthClaims(&OAuthClaims{}) - require.ErrorIs(t, err, ErrInvalidOAuthToken) - }) - - t.Run("audience_mismatch", func(t *testing.T) { - t.Parallel() - s := &ClickHouseJWEServer{Config: config.Config{Server: config.ServerConfig{OAuth: config.OAuthConfig{ - Audience: "my-audience", - }}}} - _, err := s.validateOAuthClaims(&OAuthClaims{Audience: []string{"wrong-audience"}}) - require.ErrorIs(t, err, ErrInvalidOAuthToken) - }) - - t.Run("audience_trailing_slash_tolerant", func(t *testing.T) { - t.Parallel() - // Configured without trailing slash, claim has one — and vice versa. - // Both must validate so the canonical /.well-known/oauth-protected-resource - // form (slash) and prior issued tokens (no slash) both round-trip. - cfg := config.Config{Server: config.ServerConfig{OAuth: config.OAuthConfig{ - Audience: "https://mcp.example.com", - }}} - s := &ClickHouseJWEServer{Config: cfg} - _, err := s.validateOAuthClaims(&OAuthClaims{ - Audience: []string{"https://mcp.example.com/"}, - ExpiresAt: time.Now().Unix() + 300, - }) - require.NoError(t, err) - - cfg.Server.OAuth.Audience = "https://mcp.example.com/" - s = &ClickHouseJWEServer{Config: cfg} - _, err = s.validateOAuthClaims(&OAuthClaims{ - Audience: []string{"https://mcp.example.com"}, - ExpiresAt: time.Now().Unix() + 300, - }) - require.NoError(t, err) - }) - - t.Run("token_expired", func(t *testing.T) { - t.Parallel() - s := &ClickHouseJWEServer{Config: config.Config{Server: config.ServerConfig{OAuth: config.OAuthConfig{}}}} - _, err := s.validateOAuthClaims(&OAuthClaims{ExpiresAt: time.Now().Unix() - 300}) - require.ErrorIs(t, err, ErrOAuthTokenExpired) - }) - - t.Run("not_yet_valid", func(t *testing.T) { - t.Parallel() - s := &ClickHouseJWEServer{Config: config.Config{Server: config.ServerConfig{OAuth: config.OAuthConfig{}}}} - _, err := s.validateOAuthClaims(&OAuthClaims{NotBefore: time.Now().Unix() + 300}) - require.ErrorIs(t, err, ErrInvalidOAuthToken) - }) - - t.Run("issued_in_future", func(t *testing.T) { - t.Parallel() - s := &ClickHouseJWEServer{Config: config.Config{Server: config.ServerConfig{OAuth: config.OAuthConfig{}}}} - _, err := s.validateOAuthClaims(&OAuthClaims{IssuedAt: time.Now().Unix() + 300}) - require.ErrorIs(t, err, ErrInvalidOAuthToken) - }) - - t.Run("missing_required_scopes", func(t *testing.T) { - t.Parallel() - s := &ClickHouseJWEServer{Config: config.Config{Server: config.ServerConfig{OAuth: config.OAuthConfig{ - RequiredScopes: []string{"admin"}, - }}}} - _, err := s.validateOAuthClaims(&OAuthClaims{Scopes: []string{"read"}}) - require.ErrorIs(t, err, ErrOAuthInsufficientScopes) - }) - - t.Run("valid_claims", func(t *testing.T) { - t.Parallel() - s := &ClickHouseJWEServer{Config: config.Config{Server: config.ServerConfig{OAuth: config.OAuthConfig{ - Issuer: "https://issuer.example.com", - Audience: "my-aud", - RequiredScopes: []string{"read"}, - }}}} - claims, err := s.validateOAuthClaims(&OAuthClaims{ - Issuer: "https://issuer.example.com", - Audience: []string{"my-aud"}, - ExpiresAt: time.Now().Unix() + 300, - Scopes: []string{"read", "write"}, - }) - require.NoError(t, err) - require.Equal(t, "https://issuer.example.com", claims.Issuer) - }) - -} - -func TestHasRequiredScopes(t *testing.T) { - t.Parallel() - require.True(t, hasRequiredScopes([]string{"read", "write", "admin"}, []string{"read", "write"})) - require.False(t, hasRequiredScopes([]string{"read"}, []string{"read", "admin"})) - require.True(t, hasRequiredScopes([]string{"read"}, []string{})) - require.True(t, hasRequiredScopes([]string{}, []string{})) - require.False(t, hasRequiredScopes([]string{}, []string{"read"})) -} diff --git a/pkg/server/server_client.go b/pkg/server/server_client.go index faf6c96..d989ebf 100644 --- a/pkg/server/server_client.go +++ b/pkg/server/server_client.go @@ -2,7 +2,6 @@ package server import ( "context" - "encoding/json" "fmt" "net/http" "strings" @@ -10,6 +9,7 @@ import ( "github.com/altinity/altinity-mcp/pkg/clickhouse" "github.com/altinity/altinity-mcp/pkg/config" "github.com/altinity/altinity-mcp/pkg/jwe_auth" + "github.com/altinity/altinity-mcp/pkg/oauth" "github.com/rs/zerolog/log" ) @@ -114,76 +114,17 @@ func (s *ClickHouseJWEServer) GetJWEClaimsFromCtx(ctx context.Context) map[strin return nil } -// GetOAuthClaimsFromCtx extracts OAuth claims from context +// GetOAuthClaimsFromCtx extracts OAuth claims from context. Delegates to +// the pkg/oauth context helper; preserved for callers/tests that hold a +// *ClickHouseJWEServer rather than reaching for pkg/oauth directly. func (s *ClickHouseJWEServer) GetOAuthClaimsFromCtx(ctx context.Context) *OAuthClaims { - if claims := ctx.Value(OAuthClaimsKey); claims != nil { - if oauthClaims, ok := claims.(*OAuthClaims); ok { - return oauthClaims - } - } - return nil + return oauth.ClaimsFromContext(ctx) } -// BuildClickHouseHeadersFromOAuth builds HTTP headers to forward to ClickHouse based on OAuth config +// BuildClickHouseHeadersFromOAuth builds HTTP headers to forward to ClickHouse based on OAuth config. +// Thin wrapper around oauth.BuildClickHouseHeaders. func (s *ClickHouseJWEServer) BuildClickHouseHeadersFromOAuth(token string, claims *OAuthClaims) map[string]string { - if !s.Config.Server.OAuth.IsForwardMode() { - return nil - } - - headers := make(map[string]string) - - // Forward the access token (always in forward mode) - headerName := s.Config.Server.OAuth.ClickHouseHeaderName - if headerName == "" { - headerName = "Authorization" - } - if headerName == "Authorization" { - headers[headerName] = "Bearer " + token - } else { - headers[headerName] = token - } - - // Map claims to headers if configured - if len(s.Config.Server.OAuth.ClaimsToHeaders) > 0 && claims != nil { - for claimName, headerName := range s.Config.Server.OAuth.ClaimsToHeaders { - var value string - switch claimName { - case "sub": - value = claims.Subject - case "iss": - value = claims.Issuer - case "email": - value = claims.Email - case "name": - value = claims.Name - case "email_verified": - if claims.EmailVerified { - value = "true" - } else { - value = "false" - } - case "hd": - value = claims.HostedDomain - default: - // Check extra claims - if v, ok := claims.Extra[claimName]; ok { - if strVal, ok := v.(string); ok { - value = strVal - } else { - // Try to JSON encode non-string values - if jsonBytes, err := json.Marshal(v); err == nil { - value = string(jsonBytes) - } - } - } - } - if value != "" { - headers[headerName] = value - } - } - } - - return headers + return oauth.BuildClickHouseHeaders(s.Config.Server.OAuth, token, claims) } // ValidateAuth validates authentication using priority/fallback semantics. @@ -312,7 +253,7 @@ func (s *ClickHouseJWEServer) GetClickHouseClientWithOAuth(ctx context.Context, var impersonateAs string if e := strings.TrimSpace(oauthClaims.Email); e != "" { impersonateAs = e - } else if e := emailFromNamespacedExtra(oauthClaims.Extra); e != "" { + } else if e := oauth.EmailFromNamespacedExtra(oauthClaims.Extra); e != "" { impersonateAs = e } else if s := strings.TrimSpace(oauthClaims.Subject); s != "" { impersonateAs = s @@ -337,22 +278,3 @@ func (s *ClickHouseJWEServer) GetClickHouseClientWithOAuth(ctx context.Context, return client, nil } -// emailFromNamespacedExtra returns the first string-valued claim whose key -// ends with `/email` from the JWT's non-standard claim map. Auth0 third-party -// (DCR) tokens in enhanced security mode silently drop non-namespaced custom -// claims, forcing operators to set email under a URL-prefixed key (e.g. -// `https://mcp.altinity.cloud/email`). Looking up by suffix lets MCP accept -// any namespace the operator chose. -func emailFromNamespacedExtra(extra map[string]interface{}) string { - for k, v := range extra { - if !strings.HasSuffix(k, "/email") { - continue - } - if s, ok := v.(string); ok { - if t := strings.TrimSpace(s); t != "" { - return t - } - } - } - return "" -}