From e483a471dcf7c2a01e9fbeec519be029b9e811a0 Mon Sep 17 00:00:00 2001 From: hardikl Date: Fri, 15 May 2026 20:40:51 +0530 Subject: [PATCH 1/2] feat: poc for oauth support in mcp server - keycloak --- cmd/cmds.go | 10 ++- go.mod | 4 + go.sum | 6 ++ server/oauth.go | 188 +++++++++++++++++++++++++++++++++++++++++++++++ server/server.go | 136 ++++++++++++++++++++++++++++------ 5 files changed, 321 insertions(+), 23 deletions(-) create mode 100644 server/oauth.go diff --git a/cmd/cmds.go b/cmd/cmds.go index 4c2d649..425c408 100644 --- a/cmd/cmds.go +++ b/cmd/cmds.go @@ -15,8 +15,11 @@ import ( var logger = setupLogger() type Globals struct { - LogLevel string `enum:"debug,info,warn,error" default:"info" env:"LOG_LEVEL" help:"Log level, one of: ${enum}"` - ConfigPath string `name:"config" default:"ontap.yaml" env:"ONTAP_MCP_CONFIG" help:"ONTAP-MCP config path"` + LogLevel string `enum:"debug,info,warn,error" default:"info" env:"LOG_LEVEL" help:"Log level, one of: ${enum}"` + ConfigPath string `name:"config" default:"ontap.yaml" env:"ONTAP_MCP_CONFIG" help:"ONTAP-MCP config path"` + OAuthServerURL string `name:"oauth-server-url" env:"OAUTH_SERVER_URL" help:"Authorization Server URL"` + JwksURL string `name:"jwks-url" env:"JWKS_URL" help:"JWKS URL"` + ResourceURL string `name:"resource-url" env:"RESOURCE_URL" help:"Resource URL for this MCP server"` } type CLI struct { @@ -50,6 +53,9 @@ func (a *StartCmd) Run(cli *CLI) error { InspectTraffic: cli.Start.InspectTraffic, ReadOnly: cli.Start.ReadOnly, Stateless: cli.Start.Stateless, + OauthServerURL: cli.OAuthServerURL, + JwksURL: cli.JwksURL, + ResourceURL: cli.ResourceURL, } app := server.NewApp(cfg, opts, logger) diff --git a/go.mod b/go.mod index c9a2482..834f9d4 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/netapp/ontap-mcp go 1.26.2 require ( + github.com/MicahParks/keyfunc/v3 v3.8.0 github.com/alecthomas/kong v1.15.0 github.com/carlmjohnson/requests v0.25.1 github.com/goccy/go-yaml v1.19.2 @@ -10,6 +11,8 @@ require ( ) require ( + github.com/MicahParks/jwkset v0.11.0 // indirect + github.com/golang-jwt/jwt/v5 v5.3.1 // indirect github.com/google/jsonschema-go v0.4.3 // indirect github.com/segmentio/asm v1.2.1 // indirect github.com/segmentio/encoding v0.5.4 // indirect @@ -17,4 +20,5 @@ require ( golang.org/x/net v0.52.0 // indirect golang.org/x/oauth2 v0.36.0 // indirect golang.org/x/sys v0.42.0 // indirect + golang.org/x/time v0.9.0 // indirect ) diff --git a/go.sum b/go.sum index 327abef..b374e6e 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOhoyQ= +github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0= +github.com/MicahParks/keyfunc/v3 v3.8.0 h1:Hx2dgIjAXGk9slakM6rV9BOeaWDPEXXZ4Us8guNBfds= +github.com/MicahParks/keyfunc/v3 v3.8.0/go.mod h1:z66bkCviwqfg2YUp+Jcc/xRE9IXLcMq6DrgV/+Htru0= github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= github.com/alecthomas/kong v1.15.0 h1:BVJstKbpO73zKpmIu+m/aLRrNmWwxXPIGTNin9VmLVI= @@ -30,5 +34,7 @@ golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= +golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= diff --git a/server/oauth.go b/server/oauth.go new file mode 100644 index 0000000..2e96564 --- /dev/null +++ b/server/oauth.go @@ -0,0 +1,188 @@ +package server + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "slices" + "strings" + "time" + + "github.com/MicahParks/keyfunc/v3" + "github.com/golang-jwt/jwt/v5" +) + +type OAuthConfig struct { + AuthServerURL string + JwksURL string + ResourceURL string + jwks keyfunc.Keyfunc +} + +func (c *OAuthConfig) InitJWKS() error { + jwks, err := keyfunc.NewDefault([]string{c.JwksURL}) + if err != nil { + return fmt.Errorf("failed to create client of JWKS: %w", err) + } + c.jwks = jwks + slog.Info("Initialized JWKS", slog.String("JwksURL", c.JwksURL)) + return nil +} + +func (c *OAuthConfig) OAuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + c.sendUnauthorized(w, r) + return + } + + // Bearer token + tokenString := strings.TrimPrefix(authHeader, "Bearer ") + if tokenString == authHeader { + c.sendUnauthorized(w, r) + return + } + + // Validate JWT token + token, err := jwt.Parse(tokenString, c.jwks.Keyfunc, jwt.WithValidMethods([]string{"RS256"})) + if err != nil { + slog.Error("failed to parse token", slog.Any("err", err)) + c.sendUnauthorized(w, r) + return + } + + if !token.Valid { + slog.Error("token is invalid") + c.sendUnauthorized(w, r) + return + } + + // Get claims + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + slog.Error("claim type invalid") + c.sendUnauthorized(w, r) + return + } + + slog.Debug("=== JWT Access Token Debug ===") + slog.Debug("", slog.String("token", tokenString)) + claimsJSON, _ := json.MarshalIndent(claims, "", " ") + slog.Debug("", slog.String("claim", string(claimsJSON))) + slog.Debug("===============================") + + // Validate audience + if !c.validateAudience(claims) { + slog.Error("audience invalid") + c.sendUnauthorized(w, r) + return + } + + // Validate issuer + if !c.validateIssuer(claims) { + slog.Error("issuer invalid") + c.sendUnauthorized(w, r) + return + } + + // Validate expiration + // Note: jwt.Parse already validates exp by default, but we explicitly check here for clarity + if !c.validateExpiration(claims) { + slog.Error("token has been expired") + c.sendUnauthorized(w, r) + return + } + + // Validate scope + if !c.validateScope(claims) { + slog.Error("scope insufficient") + c.sendUnauthorized(w, r) + return + } + + // Authorization successful + next.ServeHTTP(w, r) + }) +} + +func (c *OAuthConfig) validateAudience(claims jwt.MapClaims) bool { + aud, ok := claims["aud"] + if !ok { + return false + } + + // aud can be a string or array of strings + switch v := aud.(type) { + case string: + return v == c.ResourceURL + case []any: + for _, a := range v { + if audStr, ok := a.(string); ok && audStr == c.ResourceURL { + return true + } + } + return false + default: + return false + } +} + +func (c *OAuthConfig) validateIssuer(claims jwt.MapClaims) bool { + iss, ok := claims["iss"].(string) + if !ok { + return false + } + return iss == c.AuthServerURL +} + +func (c *OAuthConfig) validateExpiration(claims jwt.MapClaims) bool { + exp, ok := claims["exp"].(float64) + if !ok { + return false + } + // Allow 60 seconds of clock skew + return time.Now().Unix() < int64(exp)+60 +} + +func (c *OAuthConfig) validateScope(claims jwt.MapClaims) bool { + scope, ok := claims["scope"].(string) + if !ok { + return false + } + // Scope is a space-separated string (OAuth 2.0 standard) + // Check if "mcp:tools" is present + s := strings.Split(scope, " ") + return slices.Contains(s, "mcp:tools") +} + +func (c *OAuthConfig) sendUnauthorized(w http.ResponseWriter, _ *http.Request) { + metadataURL := c.ResourceURL + "/.well-known/oauth-protected-resource" + w.Header().Set("WWW-Authenticate", + fmt.Sprintf(`Bearer resource_metadata="%q", scope="openid profile email"`, metadataURL)) + http.Error(w, "Unauthorized", http.StatusUnauthorized) +} + +func LoggingMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + slog.Debug("", slog.String("method", r.Method), slog.String("urlPath", r.URL.Path), slog.String("RemoteAddr", r.RemoteAddr)) + + if r.Method == http.MethodPost && r.Body != nil { + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + slog.Error("Error reading body", slog.Any("err", err)) + } else { + slog.Debug("", slog.String("body", string(bodyBytes))) + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + } + } + + next.ServeHTTP(w, r) + slog.Debug("request finished", slog.Any("duration", time.Since(start))) + }) +} diff --git a/server/server.go b/server/server.go index b5026ad..1980e78 100644 --- a/server/server.go +++ b/server/server.go @@ -1,11 +1,13 @@ package server import ( + "bufio" "bytes" "context" "encoding/json" "errors" "fmt" + "github.com/modelcontextprotocol/go-sdk/oauthex" "io" "log/slog" "net/http" @@ -37,6 +39,9 @@ type Options struct { Port int ReadOnly bool Stateless bool + OauthServerURL string + JwksURL string + ResourceURL string TestHTTPClient *http.Client // Optional HTTP client for testing } @@ -82,6 +87,7 @@ func (a *App) StartServer() { if a.options.Stateless { a.logger.Info("MCP server is running in stateless mode; mcp-session-id header validation is disabled") } + server := a.createMCPServer() a.runHTTPServer(server) } @@ -228,6 +234,20 @@ func (a *App) createMCPServer() *mcp.Server { func (a *App) runHTTPServer(server *mcp.Server) { var handler http.Handler + oauthConfig := &OAuthConfig{} + authServerURL, jwksURL, resourceURL, oAthExist := a.loadEnv() + + if oAthExist { + oauthConfig.AuthServerURL = authServerURL + oauthConfig.JwksURL = jwksURL + oauthConfig.ResourceURL = resourceURL + if err := oauthConfig.InitJWKS(); err != nil { + a.logger.Error("failed to initialize JWKS", slog.Any("err", err)) + } + a.logger.Info("MCP Server started with Oauth") + } else { + a.logger.Info("MCP Server started without any Oauth") + } address := a.options.Host + ":" + strconv.Itoa(a.options.Port) a.logger.Info("starting MCP server over HTTP transport", @@ -235,12 +255,6 @@ func (a *App) runHTTPServer(server *mcp.Server) { slog.String("host", a.options.Host), slog.Int("port", a.options.Port)) - // Health check endpoint - http.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("OK")) - }) - handler = mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { return server }, &mcp.StreamableHTTPOptions{Stateless: a.options.Stateless}) @@ -267,24 +281,52 @@ func (a *App) runHTTPServer(server *mcp.Server) { handler = loggingHandler } - wrappedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Skip MCP handler for health endpoint - if r.URL.Path == "/health" { - http.DefaultServeMux.ServeHTTP(w, r) - return - } + // Setup routing + mux := http.NewServeMux() - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Mcp-Protocol-Version, Mcp-Session-Id") + // Health check endpoint + mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + }) - if r.Method == http.MethodOptions { + if oAthExist { + mux.HandleFunc("/.well-known/oauth-protected-resource", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) - return - } + metadata := oauthex.ProtectedResourceMetadata{ + Resource: oauthConfig.ResourceURL, + ScopesSupported: []string{"mcp:tools"}, + AuthorizationServers: []string{oauthConfig.AuthServerURL}, + } - handler.ServeHTTP(w, r) - }) + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(metadata); err != nil { + a.logger.Error("metadata encoding failed", slog.Any("error", err)) + } + }) + + // MCP endpoint (OAuth authorization required, with logging) + mux.Handle("/", LoggingMiddleware(oauthConfig.OAuthMiddleware(handler))) + } else { + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + // Skip MCP handler for health endpoint + if r.URL.Path == "/health" { + http.DefaultServeMux.ServeHTTP(w, r) + return + } + + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Mcp-Protocol-Version, Mcp-Session-Id") + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return + } + + handler.ServeHTTP(w, r) + }) + } //goland:noinspection HttpUrlsUsage a.logger.Info("MCP server endpoint available", slog.String("url", "http://"+address)) @@ -292,7 +334,7 @@ func (a *App) runHTTPServer(server *mcp.Server) { httpServer := &http.Server{ Addr: address, - Handler: wrappedHandler, + Handler: mux, ReadHeaderTimeout: 60 * time.Second, IdleTimeout: 60 * time.Second, } @@ -301,6 +343,7 @@ func (a *App) runHTTPServer(server *mcp.Server) { a.logger.Error("http server failed to start", slog.String("error", err.Error())) os.Exit(1) } + a.logger.Info("mcp server shutdown gracefully") } @@ -742,3 +785,54 @@ func parseSize(size string) (int64, error) { return 0, fmt.Errorf("invalid size format '%s'. Use '100MB', '2GB', '1TB', or raw bytes", size) } + +func (a *App) loadEnv() (string, string, string, bool) { + file, err := os.Open(".ontap-mcp.env") + if err != nil { + // .ontap-mcp.env file doesn't exist, that's okay + slog.Warn(".ontap-mcp.env file not exist", slog.Any("error", err)) + } + defer func() { + if err := file.Close(); err != nil { + slog.Warn("failed to close file", slog.Any("error", err)) + } + }() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + parts := strings.SplitN(line, "=", 2) + if len(parts) == 2 { + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + if os.Getenv(key) == "" { + if err = os.Setenv(key, value); err != nil { + // Log the error and proceed further + slog.Error("Error setting environment variable", slog.String("key", key), slog.String("value", value), slog.Any("err", err)) + } + } + } + } + + oauthServerURL := os.Getenv("OAUTH_SERVER_URL") + if a.options.OauthServerURL != "" { + oauthServerURL = a.options.OauthServerURL + } + slog.Debug("", slog.String("OAUTH SERVER URL", oauthServerURL)) + + jwksURL := os.Getenv("JWKS_URL") + if a.options.JwksURL != "" { + jwksURL = a.options.JwksURL + } + slog.Debug("", slog.String("JWKS URL", jwksURL)) + resourceURL := os.Getenv("RESOURCE_URL") + if a.options.ResourceURL != "" { + resourceURL = a.options.ResourceURL + } + slog.Debug("", slog.String("RESOURCE URL", resourceURL)) + + return oauthServerURL, jwksURL, resourceURL, oauthServerURL != "" && resourceURL != "" && jwksURL != "" +} From cc0fd8fa314cc062b4769f1432fe6dd0f26fda6c Mon Sep 17 00:00:00 2001 From: hardikl Date: Fri, 15 May 2026 20:49:07 +0530 Subject: [PATCH 2/2] feat: minor change --- go.mod | 5 ++--- go.sum | 6 ------ 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index f6328e1..45d25fb 100644 --- a/go.mod +++ b/go.mod @@ -7,19 +7,18 @@ require ( github.com/alecthomas/kong v1.15.0 github.com/carlmjohnson/requests v0.25.1 github.com/goccy/go-yaml v1.19.2 + github.com/golang-jwt/jwt/v5 v5.3.1 github.com/modelcontextprotocol/go-sdk v1.6.0 ) require ( github.com/MicahParks/jwkset v0.11.0 // indirect - github.com/golang-jwt/jwt/v5 v5.3.1 // indirect github.com/google/jsonschema-go v0.4.3 // indirect github.com/segmentio/asm v1.2.1 // indirect github.com/segmentio/encoding v0.5.4 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect golang.org/x/net v0.53.0 // indirect golang.org/x/oauth2 v0.36.0 // indirect - golang.org/x/sys v0.42.0 // indirect - golang.org/x/time v0.9.0 // indirect golang.org/x/sys v0.44.0 // indirect + golang.org/x/time v0.9.0 // indirect ) diff --git a/go.sum b/go.sum index d616720..fee349c 100644 --- a/go.sum +++ b/go.sum @@ -20,8 +20,6 @@ github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+ github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= -github.com/modelcontextprotocol/go-sdk v1.6.0-pre.1 h1:wAz+jUrWmkDOnD1fSba+inmp7dzMHg7yG3jv7XU3+mc= -github.com/modelcontextprotocol/go-sdk v1.6.0-pre.1/go.mod h1:kzm3kzFL1/+AziGOE0nUs3gvPoNxMCvkxokMkuFapXQ= github.com/modelcontextprotocol/go-sdk v1.6.0 h1:PPLS3kn7WtOEnR+Af4X5H96SG0qSab8R/ZQT/HkhPkY= github.com/modelcontextprotocol/go-sdk v1.6.0/go.mod h1:kzm3kzFL1/+AziGOE0nUs3gvPoNxMCvkxokMkuFapXQ= github.com/segmentio/asm v1.2.1 h1:DTNbBqs57ioxAD4PrArqftgypG4/qNpXoJx8TVXxPR0= @@ -30,14 +28,10 @@ github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfv github.com/segmentio/encoding v0.5.4/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= -golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= -golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= -golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= -golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=