Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions cmd/cmds.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ import (
var logger = setupLogger()

type Globals struct {
LogLevel string `enum:"debug,info,warn,error" default:"info" env:"LOG_LEVEL" help:"Log level, one of: ${enum}"`
ConfigPath string `name:"config" default:"ontap.yaml" env:"ONTAP_MCP_CONFIG" help:"ONTAP-MCP config path"`
LogLevel string `enum:"debug,info,warn,error" default:"info" env:"LOG_LEVEL" help:"Log level, one of: ${enum}"`
ConfigPath string `name:"config" default:"ontap.yaml" env:"ONTAP_MCP_CONFIG" help:"ONTAP-MCP config path"`
OAuthServerURL string `name:"oauth-server-url" env:"OAUTH_SERVER_URL" help:"Authorization Server URL"`
JwksURL string `name:"jwks-url" env:"JWKS_URL" help:"JWKS URL"`
ResourceURL string `name:"resource-url" env:"RESOURCE_URL" help:"Resource URL for this MCP server"`
}

type CLI struct {
Expand Down Expand Up @@ -51,6 +54,9 @@ func (a *StartCmd) Run(cli *CLI) error {
InspectTraffic: cli.Start.InspectTraffic,
ReadOnly: cli.Start.ReadOnly,
Stateless: cli.Start.Stateless,
OauthServerURL: cli.OAuthServerURL,
JwksURL: cli.JwksURL,
ResourceURL: cli.ResourceURL,
JSONResponse: cli.Start.JSONResponse,
}

Expand Down
4 changes: 4 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,22 @@ module github.com/netapp/ontap-mcp
go 1.26.3

require (
github.com/MicahParks/keyfunc/v3 v3.8.0
github.com/alecthomas/kong v1.15.0
github.com/carlmjohnson/requests v0.25.1
github.com/goccy/go-yaml v1.19.2
github.com/golang-jwt/jwt/v5 v5.3.1
github.com/modelcontextprotocol/go-sdk v1.6.0
)

require (
github.com/MicahParks/jwkset v0.11.0 // indirect
github.com/google/jsonschema-go v0.4.3 // indirect
github.com/segmentio/asm v1.2.1 // indirect
github.com/segmentio/encoding v0.5.4 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
golang.org/x/net v0.53.0 // indirect
golang.org/x/oauth2 v0.36.0 // indirect
golang.org/x/sys v0.44.0 // indirect
golang.org/x/time v0.9.0 // indirect
)
12 changes: 6 additions & 6 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOhoyQ=
github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0=
github.com/MicahParks/keyfunc/v3 v3.8.0 h1:Hx2dgIjAXGk9slakM6rV9BOeaWDPEXXZ4Us8guNBfds=
github.com/MicahParks/keyfunc/v3 v3.8.0/go.mod h1:z66bkCviwqfg2YUp+Jcc/xRE9IXLcMq6DrgV/+Htru0=
github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
github.com/alecthomas/kong v1.15.0 h1:BVJstKbpO73zKpmIu+m/aLRrNmWwxXPIGTNin9VmLVI=
Expand All @@ -16,8 +20,6 @@ github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+
github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/modelcontextprotocol/go-sdk v1.6.0-pre.1 h1:wAz+jUrWmkDOnD1fSba+inmp7dzMHg7yG3jv7XU3+mc=
github.com/modelcontextprotocol/go-sdk v1.6.0-pre.1/go.mod h1:kzm3kzFL1/+AziGOE0nUs3gvPoNxMCvkxokMkuFapXQ=
github.com/modelcontextprotocol/go-sdk v1.6.0 h1:PPLS3kn7WtOEnR+Af4X5H96SG0qSab8R/ZQT/HkhPkY=
github.com/modelcontextprotocol/go-sdk v1.6.0/go.mod h1:kzm3kzFL1/+AziGOE0nUs3gvPoNxMCvkxokMkuFapXQ=
github.com/segmentio/asm v1.2.1 h1:DTNbBqs57ioxAD4PrArqftgypG4/qNpXoJx8TVXxPR0=
Expand All @@ -26,15 +28,13 @@ github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfv
github.com/segmentio/encoding v0.5.4/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ=
golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
188 changes: 188 additions & 0 deletions server/oauth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
package server

import (
"bytes"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"slices"
"strings"
"time"

"github.com/MicahParks/keyfunc/v3"
"github.com/golang-jwt/jwt/v5"
)

type OAuthConfig struct {
AuthServerURL string
JwksURL string
ResourceURL string
jwks keyfunc.Keyfunc
}

func (c *OAuthConfig) InitJWKS() error {
jwks, err := keyfunc.NewDefault([]string{c.JwksURL})
if err != nil {
return fmt.Errorf("failed to create client of JWKS: %w", err)
}
c.jwks = jwks
slog.Info("Initialized JWKS", slog.String("JwksURL", c.JwksURL))
return nil
}

func (c *OAuthConfig) OAuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
c.sendUnauthorized(w, r)
return
}

// Bearer token
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
if tokenString == authHeader {
c.sendUnauthorized(w, r)
return
}

// Validate JWT token
token, err := jwt.Parse(tokenString, c.jwks.Keyfunc, jwt.WithValidMethods([]string{"RS256"}))
if err != nil {
slog.Error("failed to parse token", slog.Any("err", err))
c.sendUnauthorized(w, r)
return
}

if !token.Valid {
slog.Error("token is invalid")
c.sendUnauthorized(w, r)
return
}

// Get claims
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
slog.Error("claim type invalid")
c.sendUnauthorized(w, r)
return
}

slog.Debug("=== JWT Access Token Debug ===")
slog.Debug("", slog.String("token", tokenString))
claimsJSON, _ := json.MarshalIndent(claims, "", " ")
slog.Debug("", slog.String("claim", string(claimsJSON)))
slog.Debug("===============================")
Comment on lines +73 to +77

// Validate audience
if !c.validateAudience(claims) {
slog.Error("audience invalid")
c.sendUnauthorized(w, r)
return
}

// Validate issuer
if !c.validateIssuer(claims) {
slog.Error("issuer invalid")
c.sendUnauthorized(w, r)
return
}

// Validate expiration
// Note: jwt.Parse already validates exp by default, but we explicitly check here for clarity
if !c.validateExpiration(claims) {
slog.Error("token has been expired")
c.sendUnauthorized(w, r)
return
}

// Validate scope
if !c.validateScope(claims) {
slog.Error("scope insufficient")
c.sendUnauthorized(w, r)
return
}

// Authorization successful
next.ServeHTTP(w, r)
})
}

func (c *OAuthConfig) validateAudience(claims jwt.MapClaims) bool {
aud, ok := claims["aud"]
if !ok {
return false
}

// aud can be a string or array of strings
switch v := aud.(type) {
case string:
return v == c.ResourceURL
case []any:
for _, a := range v {
if audStr, ok := a.(string); ok && audStr == c.ResourceURL {
return true
}
}
return false
default:
return false
}
}

func (c *OAuthConfig) validateIssuer(claims jwt.MapClaims) bool {
iss, ok := claims["iss"].(string)
if !ok {
return false
}
return iss == c.AuthServerURL
}

func (c *OAuthConfig) validateExpiration(claims jwt.MapClaims) bool {
exp, ok := claims["exp"].(float64)
if !ok {
return false
}
// Allow 60 seconds of clock skew
return time.Now().Unix() < int64(exp)+60
Comment on lines +144 to +149
}

func (c *OAuthConfig) validateScope(claims jwt.MapClaims) bool {
scope, ok := claims["scope"].(string)
if !ok {
return false
}
// Scope is a space-separated string (OAuth 2.0 standard)
// Check if "mcp:tools" is present
s := strings.Split(scope, " ")
return slices.Contains(s, "mcp:tools")
}

func (c *OAuthConfig) sendUnauthorized(w http.ResponseWriter, _ *http.Request) {
metadataURL := c.ResourceURL + "/.well-known/oauth-protected-resource"
w.Header().Set("WWW-Authenticate",
fmt.Sprintf(`Bearer resource_metadata="%q", scope="openid profile email"`, metadataURL))
http.Error(w, "Unauthorized", http.StatusUnauthorized)
}

func LoggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
slog.Debug("", slog.String("method", r.Method), slog.String("urlPath", r.URL.Path), slog.String("RemoteAddr", r.RemoteAddr))

if r.Method == http.MethodPost && r.Body != nil {
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
slog.Error("Error reading body", slog.Any("err", err))
} else {
slog.Debug("", slog.String("body", string(bodyBytes)))
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}
}

next.ServeHTTP(w, r)
slog.Debug("request finished", slog.Any("duration", time.Since(start)))
})
Comment on lines +170 to +187
}
Loading
Loading