diff --git a/cmd/wfctl/type_registry.go b/cmd/wfctl/type_registry.go index 813bcf4c..d9bb8f12 100644 --- a/cmd/wfctl/type_registry.go +++ b/cmd/wfctl/type_registry.go @@ -161,6 +161,12 @@ func KnownModuleTypes() map[string]ModuleTypeInfo { Stateful: false, ConfigKeys: []string{"providers"}, }, + "auth.m2m": { + Type: "auth.m2m", + Plugin: "auth", + Stateful: false, + ConfigKeys: []string{"secret", "algorithm", "privateKey", "tokenExpiry", "issuer", "clients"}, + }, // messaging plugin "messaging.broker": { diff --git a/module/auth_m2m.go b/module/auth_m2m.go new file mode 100644 index 00000000..16fc8c74 --- /dev/null +++ b/module/auth_m2m.go @@ -0,0 +1,656 @@ +package module + +import ( + "crypto/ecdh" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "github.com/CrisisTextLine/modular" + "github.com/golang-jwt/jwt/v5" +) + +// GrantType constants for OAuth2 M2M flows. +const ( + GrantTypeClientCredentials = "client_credentials" + //nolint:gosec // G101: This is an OAuth2 grant type name, not a credential value. + GrantTypeJWTBearer = "urn:ietf:params:oauth:grant-type:jwt-bearer" +) + +// SigningAlgorithm defines the JWT signing algorithm for the M2M module. +type SigningAlgorithm string + +const ( + // SigningAlgHS256 uses HMAC-SHA256 (symmetric). + SigningAlgHS256 SigningAlgorithm = "HS256" + // SigningAlgES256 uses ECDSA P-256 (asymmetric). + SigningAlgES256 SigningAlgorithm = "ES256" +) + +// M2MClient represents a registered machine-to-machine OAuth2 client. +type M2MClient struct { + ClientID string `json:"clientId"` + ClientSecret string `json:"clientSecret"` //nolint:gosec // G117: config DTO field + Description string `json:"description,omitempty"` + Scopes []string `json:"scopes,omitempty"` +} + +// M2MAuthModule provides machine-to-machine (server-to-server) OAuth2 authentication. +// It supports the client_credentials grant and the JWT-bearer grant, and can issue +// tokens signed with either HS256 (shared secret) or ES256 (ECDSA P-256). +type M2MAuthModule struct { + name string + algorithm SigningAlgorithm + issuer string + tokenExpiry time.Duration + + // initErr holds an error from factory-time key setup (e.g. SetECDSAKey/GenerateECDSAKey), + // which is surfaced in Init() since module factories cannot return errors. + initErr error + + // HS256 fields + hmacSecret []byte + + // ES256 fields + privateKey *ecdsa.PrivateKey + publicKey *ecdsa.PublicKey + + // Trusted public keys for JWT-bearer grant (keyed by key ID or issuer) + trustedKeys map[string]*ecdsa.PublicKey + + // Registered clients + mu sync.RWMutex + clients map[string]*M2MClient // keyed by ClientID +} + +// NewM2MAuthModule creates a new M2MAuthModule with HS256 signing. +// Use SetECDSAKey or GenerateECDSAKey to switch to ES256 signing. +func NewM2MAuthModule(name string, hmacSecret string, tokenExpiry time.Duration, issuer string) *M2MAuthModule { + if tokenExpiry <= 0 { + tokenExpiry = time.Hour + } + if issuer == "" { + issuer = "workflow" + } + m := &M2MAuthModule{ + name: name, + algorithm: SigningAlgHS256, + issuer: issuer, + tokenExpiry: tokenExpiry, + hmacSecret: []byte(hmacSecret), + trustedKeys: make(map[string]*ecdsa.PublicKey), + clients: make(map[string]*M2MClient), + } + return m +} + +// GenerateECDSAKey generates a new P-256 key pair and switches the module to ES256 signing. +func (m *M2MAuthModule) GenerateECDSAKey() error { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return fmt.Errorf("generate ECDSA key: %w", err) + } + m.privateKey = key + m.publicKey = &key.PublicKey + m.algorithm = SigningAlgES256 + return nil +} + +// SetECDSAKey loads a PEM-encoded EC private key and switches the module to ES256 signing. +// Only P-256 keys are accepted; other curves are rejected. +func (m *M2MAuthModule) SetECDSAKey(pemKey string) error { + block, _ := pem.Decode([]byte(pemKey)) + if block == nil { + return fmt.Errorf("failed to decode PEM block") + } + key, err := x509.ParseECPrivateKey(block.Bytes) + if err != nil { + return fmt.Errorf("parse EC private key: %w", err) + } + if key.Curve != elliptic.P256() { + return fmt.Errorf("unsupported ECDSA curve: got %s, want P-256", key.Curve.Params().Name) + } + m.privateKey = key + m.publicKey = &key.PublicKey + m.algorithm = SigningAlgES256 + return nil +} + +// SetInitErr stores a deferred initialization error to be returned by Init(). +// This is used by factory functions which cannot return errors directly. +func (m *M2MAuthModule) SetInitErr(err error) { + m.initErr = err +} + +// AddTrustedKey registers a trusted ECDSA public key for JWT-bearer assertion validation. +// The keyID is used to look up the key; it can be an issuer name or any unique identifier. +func (m *M2MAuthModule) AddTrustedKey(keyID string, pubKey *ecdsa.PublicKey) { + m.mu.Lock() + defer m.mu.Unlock() + m.trustedKeys[keyID] = pubKey +} + +// RegisterClient registers a new OAuth2 client. +func (m *M2MAuthModule) RegisterClient(client M2MClient) { + m.mu.Lock() + defer m.mu.Unlock() + m.clients[client.ClientID] = &client +} + +// Name returns the module name. +func (m *M2MAuthModule) Name() string { return m.name } + +// Init validates the module configuration. It also surfaces any key-setup error +// that occurred in the factory (stored in initErr). +func (m *M2MAuthModule) Init(_ modular.Application) error { + if m.initErr != nil { + return fmt.Errorf("M2M auth: key setup failed: %w", m.initErr) + } + if m.algorithm == SigningAlgHS256 && len(m.hmacSecret) < 32 { + return fmt.Errorf("M2M auth: HMAC secret must be at least 32 bytes for HS256") + } + if m.algorithm == SigningAlgES256 && m.privateKey == nil { + return fmt.Errorf("M2M auth: ECDSA private key required for ES256") + } + return nil +} + +// ProvidesServices returns the services provided by this module. +func (m *M2MAuthModule) ProvidesServices() []modular.ServiceProvider { + return []modular.ServiceProvider{ + { + Name: m.name, + Description: "Machine-to-machine OAuth2 auth module (client_credentials + jwt-bearer)", + Instance: m, + }, + } +} + +// RequiresServices returns an empty list (no external dependencies). +func (m *M2MAuthModule) RequiresServices() []modular.ServiceDependency { return nil } + +// Handle routes M2M OAuth2 requests. +// +// Routes: +// +// POST /oauth/token — token endpoint (client_credentials + jwt-bearer grants) +// GET /oauth/jwks — JSON Web Key Set (ES256 public key) +func (m *M2MAuthModule) Handle(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + path := r.URL.Path + switch { + case r.Method == http.MethodPost && strings.HasSuffix(path, "/oauth/token"): + m.handleToken(w, r) + case r.Method == http.MethodGet && strings.HasSuffix(path, "/oauth/jwks"): + m.handleJWKS(w, r) + default: + w.WriteHeader(http.StatusNotFound) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "not found"}) + } +} + +// handleToken implements RFC 6749 § 4.4 (client_credentials) and +// RFC 7523 § 2.1 (jwt-bearer assertion) token endpoints. +func (m *M2MAuthModule) handleToken(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(oauthError("invalid_request", "failed to parse form")) + return + } + + grantType := r.FormValue("grant_type") + switch grantType { + case GrantTypeClientCredentials: + m.handleClientCredentials(w, r) + case GrantTypeJWTBearer: + m.handleJWTBearer(w, r) + default: + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(oauthError("unsupported_grant_type", + fmt.Sprintf("grant_type %q not supported; use %q or %q", + grantType, GrantTypeClientCredentials, GrantTypeJWTBearer))) + } +} + +// handleClientCredentials processes the OAuth2 client_credentials grant. +// Clients send client_id + client_secret (either as form params or HTTP Basic auth) +// and receive a signed access token. +func (m *M2MAuthModule) handleClientCredentials(w http.ResponseWriter, r *http.Request) { + clientID, clientSecret, ok := m.extractClientCredentials(r) + if !ok { + w.WriteHeader(http.StatusUnauthorized) + _ = json.NewEncoder(w).Encode(oauthError("invalid_client", "client credentials required")) + return + } + + client, err := m.authenticateClient(clientID, clientSecret) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + _ = json.NewEncoder(w).Encode(oauthError("invalid_client", err.Error())) + return + } + + // Validate requested scopes against client's allowed scopes. + requestedScope := r.FormValue("scope") + grantedScopes, err := m.validateScopes(client, requestedScope) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(oauthError("invalid_scope", err.Error())) + return + } + + token, err := m.issueToken(clientID, grantedScopes, nil) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + _ = json.NewEncoder(w).Encode(oauthError("server_error", "failed to issue token")) + return + } + + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": token, + "token_type": "Bearer", + "expires_in": int(m.tokenExpiry.Seconds()), + "scope": strings.Join(grantedScopes, " "), + }) +} + +// handleJWTBearer processes the JWT-bearer grant (RFC 7523). +// The client sends a signed JWT assertion; if the signature is valid and the +// assertion is trusted, an access token is returned. +func (m *M2MAuthModule) handleJWTBearer(w http.ResponseWriter, r *http.Request) { + assertion := r.FormValue("assertion") + if assertion == "" { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(oauthError("invalid_request", "assertion is required")) + return + } + + claims, err := m.validateJWTAssertion(assertion) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + _ = json.NewEncoder(w).Encode(oauthError("invalid_grant", err.Error())) + return + } + + // The subject (sub) becomes the client identity in the issued token. + subject, _ := claims["sub"].(string) + if subject == "" { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(oauthError("invalid_grant", "assertion missing sub claim")) + return + } + + requestedScope := r.FormValue("scope") + var grantedScopes []string + if requestedScope != "" { + grantedScopes = strings.Fields(requestedScope) + } + + token, err := m.issueToken(subject, grantedScopes, claims) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + _ = json.NewEncoder(w).Encode(oauthError("server_error", "failed to issue token")) + return + } + + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": token, + "token_type": "Bearer", + "expires_in": int(m.tokenExpiry.Seconds()), + "scope": strings.Join(grantedScopes, " "), + }) +} + +// handleJWKS returns the JSON Web Key Set containing the module's public key(s). +// Only available when the module is configured for ES256. +func (m *M2MAuthModule) handleJWKS(w http.ResponseWriter, _ *http.Request) { + if m.algorithm != SigningAlgES256 || m.publicKey == nil { + w.WriteHeader(http.StatusNotFound) + _ = json.NewEncoder(w).Encode(map[string]string{ + "error": "JWKS not available: algorithm must be ES256 with a configured public key", + }) + return + } + + jwk, err := ecPublicKeyToJWK(m.publicKey, m.name+"-key") + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + _ = json.NewEncoder(w).Encode(oauthError("server_error", "failed to generate JWK for ES256 public key")) + return + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "keys": []any{jwk}, + }) +} + +// --- token issuance --- + +// issueToken creates and signs a JWT access token. +// extraClaims are merged in (e.g., from a jwt-bearer assertion). +func (m *M2MAuthModule) issueToken(subject string, scopes []string, extraClaims map[string]any) (string, error) { + now := time.Now() + claims := jwt.MapClaims{ + "iss": m.issuer, + "sub": subject, + "iat": now.Unix(), + "exp": now.Add(m.tokenExpiry).Unix(), + } + if len(scopes) > 0 { + claims["scope"] = strings.Join(scopes, " ") + } + // Merge extra claims, but never let them override standard fields. + for k, v := range extraClaims { + switch k { + case "iss", "sub", "iat", "exp", "scope": + // protected — skip + default: + claims[k] = v + } + } + + switch m.algorithm { + case SigningAlgES256: + tok := jwt.NewWithClaims(jwt.SigningMethodES256, claims) + return tok.SignedString(m.privateKey) + default: // HS256 + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return tok.SignedString(m.hmacSecret) + } +} + +// --- authentication helpers --- + +// extractClientCredentials returns client_id and client_secret from either +// HTTP Basic Auth or the request form body (per RFC 6749 § 2.3). +func (m *M2MAuthModule) extractClientCredentials(r *http.Request) (string, string, bool) { + // Prefer HTTP Basic Auth. + if clientID, clientSecret, ok := r.BasicAuth(); ok && clientID != "" { + return clientID, clientSecret, true + } + // Fall back to form params. + clientID := r.FormValue("client_id") + clientSecret := r.FormValue("client_secret") + if clientID != "" && clientSecret != "" { + return clientID, clientSecret, true + } + return "", "", false +} + +// authenticateClient looks up and verifies a client by ID and secret. +func (m *M2MAuthModule) authenticateClient(clientID, clientSecret string) (*M2MClient, error) { + m.mu.RLock() + client, ok := m.clients[clientID] + m.mu.RUnlock() + + if !ok { + return nil, fmt.Errorf("client not found") + } + + // Compare fixed-length SHA-256 hashes to keep the comparison constant-time + // regardless of whether the provided secret length differs from the stored one, + // since subtle.ConstantTimeCompare returns early when lengths differ. + storedHash := sha256.Sum256([]byte(client.ClientSecret)) + providedHash := sha256.Sum256([]byte(clientSecret)) + if subtle.ConstantTimeCompare(storedHash[:], providedHash[:]) != 1 { + return nil, fmt.Errorf("invalid client secret") + } + + return client, nil +} + +// validateScopes checks that all requested scopes are permitted for the client. +// If no scopes are requested, the client's full scope list is granted. +func (m *M2MAuthModule) validateScopes(client *M2MClient, requestedScope string) ([]string, error) { + if requestedScope == "" { + return client.Scopes, nil + } + + requested := strings.Fields(requestedScope) + allowed := make(map[string]bool, len(client.Scopes)) + for _, s := range client.Scopes { + allowed[s] = true + } + + for _, s := range requested { + if !allowed[s] { + return nil, fmt.Errorf("scope %q not permitted for this client", s) + } + } + return requested, nil +} + +// validateJWTAssertion parses and validates a JWT bearer assertion (RFC 7523). +// It first parses the assertion unverified to extract the `iss` claim and the +// `kid` header, then selects the matching trusted key, and verifies the signature +// with that specific key. This prevents a holder of any trusted key from +// impersonating an arbitrary subject. +func (m *M2MAuthModule) validateJWTAssertion(assertion string) (jwt.MapClaims, error) { + // Parse unverified to extract iss/kid for key selection. + unverified, _, err := new(jwt.Parser).ParseUnverified(assertion, jwt.MapClaims{}) + if err != nil { + return nil, fmt.Errorf("malformed assertion: %w", err) + } + uClaims, ok := unverified.Claims.(jwt.MapClaims) + if !ok { + return nil, fmt.Errorf("malformed assertion claims") + } + iss, _ := uClaims["iss"].(string) + kid, _ := unverified.Header["kid"].(string) + + m.mu.RLock() + // Try kid first, then iss. + var selectedKey *ecdsa.PublicKey + if kid != "" { + selectedKey = m.trustedKeys[kid] + } + if selectedKey == nil && iss != "" { + selectedKey = m.trustedKeys[iss] + } + hmacSecret := m.hmacSecret + m.mu.RUnlock() + + // Try EC key if found. + if selectedKey != nil { + k := selectedKey + token, err := jwt.Parse(assertion, func(token *jwt.Token) (any, error) { + if _, ok := token.Method.(*jwt.SigningMethodECDSA); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return k, nil + }, jwt.WithExpirationRequired()) + if err != nil { + return nil, fmt.Errorf("invalid assertion: %w", err) + } + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + return nil, fmt.Errorf("invalid assertion claims") + } + return claims, nil + } + + // Fall back to HS256 using the module's own secret (for internal/testing use). + // The assertion must be signed with the module's exact secret. + if len(hmacSecret) >= 32 { + token, err := jwt.Parse(assertion, func(token *jwt.Token) (any, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return hmacSecret, nil + }, jwt.WithExpirationRequired()) + if err != nil { + return nil, fmt.Errorf("invalid assertion: %w", err) + } + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + return nil, fmt.Errorf("invalid assertion claims") + } + return claims, nil + } + + return nil, fmt.Errorf("no trusted key found for assertion issuer %q", iss) +} + +// Authenticate implements the AuthProvider interface so M2MAuthModule can be +// used as a provider in AuthMiddleware. It validates the token's signature +// using the configured algorithm and returns the embedded claims. +func (m *M2MAuthModule) Authenticate(tokenStr string) (bool, map[string]any, error) { + var token *jwt.Token + var err error + + switch m.algorithm { + case SigningAlgES256: + if m.publicKey == nil { + return false, nil, fmt.Errorf("no ECDSA public key configured") + } + token, err = jwt.Parse(tokenStr, func(t *jwt.Token) (any, error) { + if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) + } + return m.publicKey, nil + }) + default: // HS256 + token, err = jwt.Parse(tokenStr, func(t *jwt.Token) (any, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) + } + return m.hmacSecret, nil + }) + } + + if err != nil { + return false, nil, nil //nolint:nilerr // Invalid token is a failed auth, not an error + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + return false, nil, nil + } + + result := make(map[string]any, len(claims)) + for k, v := range claims { + result[k] = v + } + return true, result, nil +} + +// --- JWKS helpers --- + +// ecPublicKeyToJWK converts an ECDSA P-256 public key to a JWK (RFC 7517) map. +// It uses the ecdh package to extract the uncompressed point bytes, avoiding +// the deprecated ecdsa.PublicKey.X / .Y big.Int fields. +// Returns an error if the key cannot be converted. +func ecPublicKeyToJWK(pub *ecdsa.PublicKey, kid string) (map[string]any, error) { + ecdhPub, err := pub.ECDH() + if err != nil { + return nil, fmt.Errorf("convert to ECDH key: %w", err) + } + // Uncompressed point format for P-256: 0x04 || x (32 bytes) || y (32 bytes) = 65 bytes. + b := ecdhPub.Bytes() + if len(b) != 65 || b[0] != 0x04 { + return nil, fmt.Errorf("unexpected uncompressed point length %d or prefix 0x%02x (want 65, 0x04)", len(b), b[0]) + } + x := b[1:33] + y := b[33:65] + return map[string]any{ + "kty": "EC", + "crv": "P-256", + "alg": "ES256", + "use": "sig", + "kid": kid, + "x": base64.RawURLEncoding.EncodeToString(x), + "y": base64.RawURLEncoding.EncodeToString(y), + }, nil +} + +// jwkThumbprint computes the JWK thumbprint (RFC 7638) for an EC P-256 key. +// This is useful for deriving deterministic key IDs. +func jwkThumbprint(pub *ecdsa.PublicKey) string { + ecdhPub, err := pub.ECDH() + if err != nil { + return "" + } + b := ecdhPub.Bytes() + if len(b) != 65 || b[0] != 0x04 { + return "" + } + x := base64.RawURLEncoding.EncodeToString(b[1:33]) + y := base64.RawURLEncoding.EncodeToString(b[33:65]) + // RFC 7638: lexicographic JSON of required members. + //nolint:gocritic // sprintfQuotedString: %s is required here; %q would add extra escaping + raw := fmt.Sprintf(`{"crv":"P-256","kty":"EC","x":"%s","y":"%s"}`, x, y) + h := sha256.Sum256([]byte(raw)) + return base64.RawURLEncoding.EncodeToString(h[:]) +} + +// jwkToECPublicKey converts a JWK map (EC P-256) back to an *ecdsa.PublicKey. +// Returns an error if the JWK is not a valid EC P-256 key. +func jwkToECPublicKey(jwk map[string]any) (*ecdsa.PublicKey, error) { + kty, _ := jwk["kty"].(string) + if kty != "EC" { + return nil, fmt.Errorf("expected kty=EC, got %q", kty) + } + crv, _ := jwk["crv"].(string) + if crv != "P-256" { + return nil, fmt.Errorf("expected crv=P-256, got %q", crv) + } + xStr, _ := jwk["x"].(string) + yStr, _ := jwk["y"].(string) + if xStr == "" || yStr == "" { + return nil, fmt.Errorf("missing x or y coordinate in JWK") + } + xBytes, err := base64.RawURLEncoding.DecodeString(xStr) + if err != nil { + return nil, fmt.Errorf("decode x: %w", err) + } + yBytes, err := base64.RawURLEncoding.DecodeString(yStr) + if err != nil { + return nil, fmt.Errorf("decode y: %w", err) + } + if len(xBytes) != 32 || len(yBytes) != 32 { + return nil, fmt.Errorf("invalid P-256 coordinate length: x=%d y=%d", len(xBytes), len(yBytes)) + } + // Construct uncompressed point: 0x04 || x || y + uncompressed := make([]byte, 65) + uncompressed[0] = 0x04 + copy(uncompressed[1:33], xBytes) + copy(uncompressed[33:65], yBytes) + + // Parse via ecdh, then convert to ecdsa via PKIX round-trip. + ecdhPub, err := ecdh.P256().NewPublicKey(uncompressed) + if err != nil { + return nil, fmt.Errorf("parse uncompressed point: %w", err) + } + pkixBytes, err := x509.MarshalPKIXPublicKey(ecdhPub) + if err != nil { + return nil, fmt.Errorf("marshal PKIX: %w", err) + } + pub, err := x509.ParsePKIXPublicKey(pkixBytes) + if err != nil { + return nil, fmt.Errorf("parse PKIX: %w", err) + } + ecdsaPub, ok := pub.(*ecdsa.PublicKey) + if !ok { + return nil, fmt.Errorf("unexpected key type %T", pub) + } + return ecdsaPub, nil +} + +// oauthError builds an RFC 6749-compliant error response body. +func oauthError(code, description string) map[string]string { + return map[string]string{ + "error": code, + "error_description": description, + } +} diff --git a/module/auth_m2m_test.go b/module/auth_m2m_test.go new file mode 100644 index 00000000..4978883d --- /dev/null +++ b/module/auth_m2m_test.go @@ -0,0 +1,1173 @@ +package module + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +// --- helpers --- + +// newM2MHS256 creates an M2MAuthModule configured with HS256 and a test client. +func newM2MHS256(t *testing.T) *M2MAuthModule { + t.Helper() + m := NewM2MAuthModule("m2m", "this-is-a-valid-secret-32-bytes!", time.Hour, "test-issuer") + m.RegisterClient(M2MClient{ + ClientID: "test-client", + ClientSecret: "test-secret", //nolint:gosec // test credential + Scopes: []string{"read", "write"}, + }) + return m +} + +// newM2MES256 creates an M2MAuthModule configured with ES256 and a test client. +func newM2MES256(t *testing.T) *M2MAuthModule { + t.Helper() + m := NewM2MAuthModule("m2m", "", time.Hour, "test-issuer") + if err := m.GenerateECDSAKey(); err != nil { + t.Fatalf("GenerateECDSAKey: %v", err) + } + m.RegisterClient(M2MClient{ + ClientID: "es256-client", + ClientSecret: "es256-secret", //nolint:gosec // test credential + Scopes: []string{"api"}, + }) + return m +} + +// postToken is a test helper that sends a form-encoded POST to /oauth/token. +func postToken(t *testing.T, m *M2MAuthModule, params url.Values) *httptest.ResponseRecorder { + t.Helper() + req := httptest.NewRequest(http.MethodPost, "/oauth/token", + strings.NewReader(params.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w := httptest.NewRecorder() + m.Handle(w, req) + return w +} + +// --- Name / Init --- + +func TestM2M_Name(t *testing.T) { + m := NewM2MAuthModule("my-m2m", "secret-must-be-at-least-32bytes!!", time.Hour, "") + if m.Name() != "my-m2m" { + t.Errorf("expected 'my-m2m', got %q", m.Name()) + } +} + +func TestM2M_InitHS256_ValidSecret(t *testing.T) { + m := NewM2MAuthModule("m2m", "this-is-a-valid-secret-32-bytes!", time.Hour, "issuer") + if err := m.Init(nil); err != nil { + t.Errorf("unexpected error: %v", err) + } +} + +func TestM2M_InitHS256_ShortSecret(t *testing.T) { + m := NewM2MAuthModule("m2m", "short", time.Hour, "issuer") + if err := m.Init(nil); err == nil { + t.Error("expected error for short HMAC secret") + } +} + +func TestM2M_InitES256_NoKey(t *testing.T) { + m := &M2MAuthModule{ + name: "m2m", + algorithm: SigningAlgES256, + issuer: "issuer", + } + if err := m.Init(nil); err == nil { + t.Error("expected error when ES256 key not set") + } +} + +func TestM2M_InitES256_WithKey(t *testing.T) { + m := newM2MES256(t) + if err := m.Init(nil); err != nil { + t.Errorf("unexpected error: %v", err) + } +} + +func TestM2M_ProvidesServices(t *testing.T) { + m := newM2MHS256(t) + svcs := m.ProvidesServices() + if len(svcs) != 1 { + t.Fatalf("expected 1 service, got %d", len(svcs)) + } + if svcs[0].Name != "m2m" { + t.Errorf("expected service name 'm2m', got %q", svcs[0].Name) + } +} + +// --- client_credentials grant --- + +func TestM2M_ClientCredentials_FormParams(t *testing.T) { + m := newM2MHS256(t) + + params := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {"test-client"}, + "client_secret": {"test-secret"}, + } + w := postToken(t, m, params) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String()) + } + + var resp map[string]any + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("decode response: %v", err) + } + if resp["access_token"] == nil || resp["access_token"] == "" { + t.Error("expected non-empty access_token") + } + if resp["token_type"] != "Bearer" { + t.Errorf("expected token_type=Bearer, got %v", resp["token_type"]) + } + if resp["expires_in"] == nil { + t.Error("expected expires_in in response") + } +} + +func TestM2M_ClientCredentials_BasicAuth(t *testing.T) { + m := newM2MHS256(t) + + req := httptest.NewRequest(http.MethodPost, "/oauth/token", + strings.NewReader("grant_type=client_credentials")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth("test-client", "test-secret") + w := httptest.NewRecorder() + m.Handle(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String()) + } + + var resp map[string]any + json.NewDecoder(w.Body).Decode(&resp) + if resp["access_token"] == nil { + t.Error("expected access_token in response") + } +} + +func TestM2M_ClientCredentials_WrongSecret(t *testing.T) { + m := newM2MHS256(t) + + params := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {"test-client"}, + "client_secret": {"wrong-secret"}, + } + w := postToken(t, m, params) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401 for wrong secret, got %d", w.Code) + } + var resp map[string]string + json.NewDecoder(w.Body).Decode(&resp) + if resp["error"] != "invalid_client" { + t.Errorf("expected error=invalid_client, got %q", resp["error"]) + } +} + +func TestM2M_ClientCredentials_UnknownClient(t *testing.T) { + m := newM2MHS256(t) + + params := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {"unknown"}, + "client_secret": {"secret"}, + } + w := postToken(t, m, params) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401 for unknown client, got %d", w.Code) + } +} + +func TestM2M_ClientCredentials_MissingCredentials(t *testing.T) { + m := newM2MHS256(t) + + params := url.Values{ + "grant_type": {"client_credentials"}, + } + w := postToken(t, m, params) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401 when no credentials, got %d", w.Code) + } +} + +func TestM2M_ClientCredentials_ScopeGranted(t *testing.T) { + m := newM2MHS256(t) + + params := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {"test-client"}, + "client_secret": {"test-secret"}, + "scope": {"read"}, + } + w := postToken(t, m, params) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String()) + } + + var resp map[string]any + json.NewDecoder(w.Body).Decode(&resp) + if resp["scope"] != "read" { + t.Errorf("expected scope=read, got %v", resp["scope"]) + } +} + +func TestM2M_ClientCredentials_ScopeNotPermitted(t *testing.T) { + m := newM2MHS256(t) + + params := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {"test-client"}, + "client_secret": {"test-secret"}, + "scope": {"admin"}, + } + w := postToken(t, m, params) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for forbidden scope, got %d", w.Code) + } + var resp map[string]string + json.NewDecoder(w.Body).Decode(&resp) + if resp["error"] != "invalid_scope" { + t.Errorf("expected error=invalid_scope, got %q", resp["error"]) + } +} + +func TestM2M_ClientCredentials_NoScopeGrantsAll(t *testing.T) { + m := newM2MHS256(t) + + params := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {"test-client"}, + "client_secret": {"test-secret"}, + // no scope param → grant all client scopes + } + w := postToken(t, m, params) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String()) + } + var resp map[string]any + json.NewDecoder(w.Body).Decode(&resp) + scopeVal, _ := resp["scope"].(string) + if !strings.Contains(scopeVal, "read") || !strings.Contains(scopeVal, "write") { + t.Errorf("expected all client scopes, got %q", scopeVal) + } +} + +// --- ES256 token issuance --- + +func TestM2M_ES256_ClientCredentials_IssuesToken(t *testing.T) { + m := newM2MES256(t) + + params := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {"es256-client"}, + "client_secret": {"es256-secret"}, + } + w := postToken(t, m, params) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String()) + } + + var resp map[string]any + json.NewDecoder(w.Body).Decode(&resp) + tokenStr, _ := resp["access_token"].(string) + if tokenStr == "" { + t.Fatal("expected non-empty access_token") + } + + // Parse the token header to confirm ES256 algorithm. + tok, _, err := new(jwt.Parser).ParseUnverified(tokenStr, jwt.MapClaims{}) + if err != nil { + t.Fatalf("parse token: %v", err) + } + if tok.Method.Alg() != "ES256" { + t.Errorf("expected ES256 algorithm, got %q", tok.Method.Alg()) + } +} + +func TestM2M_ES256_TokenVerifiable(t *testing.T) { + m := newM2MES256(t) + + params := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {"es256-client"}, + "client_secret": {"es256-secret"}, + } + w := postToken(t, m, params) + + var resp map[string]any + json.NewDecoder(w.Body).Decode(&resp) + tokenStr, _ := resp["access_token"].(string) + + // Verify the token using the module's Authenticate method. + valid, claims, err := m.Authenticate(tokenStr) + if err != nil { + t.Fatalf("authenticate error: %v", err) + } + if !valid { + t.Error("expected token to be valid") + } + if claims["sub"] != "es256-client" { + t.Errorf("expected sub=es256-client, got %v", claims["sub"]) + } +} + +func TestM2M_ES256_GenerateKey(t *testing.T) { + m := NewM2MAuthModule("m2m", "", time.Hour, "issuer") + if err := m.GenerateECDSAKey(); err != nil { + t.Fatalf("GenerateECDSAKey: %v", err) + } + if m.privateKey == nil { + t.Error("expected private key to be set") + } + if m.publicKey == nil { + t.Error("expected public key to be set") + } + if m.algorithm != SigningAlgES256 { + t.Errorf("expected algorithm ES256, got %v", m.algorithm) + } +} + +func TestM2M_SetECDSAKey_ValidPEM(t *testing.T) { + // Generate a key to export as PEM. + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("generate key: %v", err) + } + der, err := x509.MarshalECPrivateKey(key) + if err != nil { + t.Fatalf("marshal key: %v", err) + } + pemBytes := pem.EncodeToMemory(&pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: der, + }) + + m := NewM2MAuthModule("m2m", "", time.Hour, "issuer") + if err := m.SetECDSAKey(string(pemBytes)); err != nil { + t.Fatalf("SetECDSAKey: %v", err) + } + if m.algorithm != SigningAlgES256 { + t.Error("expected algorithm to be ES256 after SetECDSAKey") + } +} + +func TestM2M_SetECDSAKey_InvalidPEM(t *testing.T) { + m := NewM2MAuthModule("m2m", "", time.Hour, "issuer") + if err := m.SetECDSAKey("not a pem"); err == nil { + t.Error("expected error for invalid PEM") + } +} + +func TestM2M_SetECDSAKey_NonP256Rejected(t *testing.T) { + // Generate a P-384 key (not P-256) and verify it is rejected. + key, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + t.Fatalf("generate P-384 key: %v", err) + } + der, err := x509.MarshalECPrivateKey(key) + if err != nil { + t.Fatalf("marshal key: %v", err) + } + pemBytes := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: der}) + m := NewM2MAuthModule("m2m", "", time.Hour, "issuer") + if err := m.SetECDSAKey(string(pemBytes)); err == nil { + t.Error("expected error for non-P256 key") + } else if !strings.Contains(err.Error(), "P-256") { + t.Errorf("expected P-256 mention in error, got %q", err.Error()) + } +} + +func TestM2M_InitErr_SurfacedInInit(t *testing.T) { + m := NewM2MAuthModule("m2m", "", time.Hour, "issuer") + m.SetInitErr(fmt.Errorf("injected key error")) + if err := m.Init(nil); err == nil { + t.Error("expected init error to surface") + } +} + +// --- JWKS endpoint --- + +func TestM2M_JWKS_ES256(t *testing.T) { + m := newM2MES256(t) + + req := httptest.NewRequest(http.MethodGet, "/oauth/jwks", nil) + w := httptest.NewRecorder() + m.Handle(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200 for JWKS, got %d; body: %s", w.Code, w.Body.String()) + } + + var resp map[string]any + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("decode JWKS: %v", err) + } + keys, ok := resp["keys"].([]any) + if !ok || len(keys) == 0 { + t.Fatal("expected non-empty keys array") + } + jwk, _ := keys[0].(map[string]any) + if jwk["kty"] != "EC" { + t.Errorf("expected kty=EC, got %v", jwk["kty"]) + } + if jwk["crv"] != "P-256" { + t.Errorf("expected crv=P-256, got %v", jwk["crv"]) + } + if jwk["alg"] != "ES256" { + t.Errorf("expected alg=ES256, got %v", jwk["alg"]) + } +} + +func TestM2M_JWKS_HS256_NotAvailable(t *testing.T) { + m := newM2MHS256(t) + + req := httptest.NewRequest(http.MethodGet, "/oauth/jwks", nil) + w := httptest.NewRecorder() + m.Handle(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected 404 for HS256 JWKS, got %d", w.Code) + } +} + +func TestM2M_JWKS_RoundTrip(t *testing.T) { + m := newM2MES256(t) + + // Get JWKS. + req := httptest.NewRequest(http.MethodGet, "/oauth/jwks", nil) + w := httptest.NewRecorder() + m.Handle(w, req) + + var resp map[string]any + json.NewDecoder(w.Body).Decode(&resp) + keys := resp["keys"].([]any) + jwk := keys[0].(map[string]any) + + // Reconstruct the public key from the JWK. + pub, err := jwkToECPublicKey(jwk) + if err != nil { + t.Fatalf("jwkToECPublicKey: %v", err) + } + + // Compare via ECDH byte representation to avoid using deprecated X/Y fields. + origBytes, err := m.publicKey.ECDH() + if err != nil { + t.Fatalf("original key ECDH: %v", err) + } + reconBytes, err := pub.ECDH() + if err != nil { + t.Fatalf("reconstructed key ECDH: %v", err) + } + if string(origBytes.Bytes()) != string(reconBytes.Bytes()) { + t.Error("reconstructed key does not match original") + } +} + +// --- JWT-bearer grant --- + +func TestM2M_JWTBearer_ES256_Valid(t *testing.T) { + // Server M2M module + server := newM2MES256(t) + + // Client generates its own key pair. + clientKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("generate client key: %v", err) + } + + // Register client's public key as trusted by the server. + server.AddTrustedKey("client-service", &clientKey.PublicKey) + + // Client creates a JWT assertion. + claims := jwt.MapClaims{ + "iss": "client-service", + "sub": "client-service", + "aud": "test-issuer", + "iat": time.Now().Unix(), + "exp": time.Now().Add(5 * time.Minute).Unix(), + } + tok := jwt.NewWithClaims(jwt.SigningMethodES256, claims) + assertion, err := tok.SignedString(clientKey) + if err != nil { + t.Fatalf("sign assertion: %v", err) + } + + params := url.Values{ + "grant_type": {GrantTypeJWTBearer}, + "assertion": {assertion}, + "scope": {"api"}, + } + w := postToken(t, server, params) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String()) + } + + var resp map[string]any + json.NewDecoder(w.Body).Decode(&resp) + if resp["access_token"] == nil || resp["access_token"] == "" { + t.Error("expected non-empty access_token") + } +} + +func TestM2M_JWTBearer_HS256_Valid(t *testing.T) { + m := newM2MHS256(t) + + // Create a JWT assertion signed with the module's own HMAC secret. + claims := jwt.MapClaims{ + "iss": "internal-service", + "sub": "internal-service", + "aud": "test-issuer", + "iat": time.Now().Unix(), + "exp": time.Now().Add(5 * time.Minute).Unix(), + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + assertion, err := tok.SignedString([]byte("this-is-a-valid-secret-32-bytes!")) + if err != nil { + t.Fatalf("sign assertion: %v", err) + } + + params := url.Values{ + "grant_type": {GrantTypeJWTBearer}, + "assertion": {assertion}, + } + w := postToken(t, m, params) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String()) + } +} + +func TestM2M_JWTBearer_MissingAssertion(t *testing.T) { + m := newM2MHS256(t) + + params := url.Values{ + "grant_type": {GrantTypeJWTBearer}, + } + w := postToken(t, m, params) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for missing assertion, got %d", w.Code) + } + var resp map[string]string + json.NewDecoder(w.Body).Decode(&resp) + if resp["error"] != "invalid_request" { + t.Errorf("expected error=invalid_request, got %q", resp["error"]) + } +} + +func TestM2M_JWTBearer_InvalidSignature(t *testing.T) { + m := newM2MHS256(t) + + // Sign with a different secret. + claims := jwt.MapClaims{ + "sub": "service", + "exp": time.Now().Add(5 * time.Minute).Unix(), + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + badAssertion, _ := tok.SignedString([]byte("wrong-secret-that-is-long-enough-x")) + + params := url.Values{ + "grant_type": {GrantTypeJWTBearer}, + "assertion": {badAssertion}, + } + w := postToken(t, m, params) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401 for bad assertion, got %d", w.Code) + } + var resp map[string]string + json.NewDecoder(w.Body).Decode(&resp) + if resp["error"] != "invalid_grant" { + t.Errorf("expected error=invalid_grant, got %q", resp["error"]) + } +} + +func TestM2M_JWTBearer_ExpiredAssertion(t *testing.T) { + m := newM2MHS256(t) + + claims := jwt.MapClaims{ + "sub": "service", + "exp": time.Now().Add(-time.Minute).Unix(), // expired + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + expired, _ := tok.SignedString([]byte("this-is-a-valid-secret-32-bytes!")) + + params := url.Values{ + "grant_type": {GrantTypeJWTBearer}, + "assertion": {expired}, + } + w := postToken(t, m, params) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401 for expired assertion, got %d", w.Code) + } +} + +func TestM2M_JWTBearer_MissingSub(t *testing.T) { + m := newM2MHS256(t) + + claims := jwt.MapClaims{ + // no "sub" + "exp": time.Now().Add(5 * time.Minute).Unix(), + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + assertion, _ := tok.SignedString([]byte("this-is-a-valid-secret-32-bytes!")) + + params := url.Values{ + "grant_type": {GrantTypeJWTBearer}, + "assertion": {assertion}, + } + w := postToken(t, m, params) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for missing sub, got %d", w.Code) + } +} + +func TestM2M_JWTBearer_UntrustedKey(t *testing.T) { + // Module with ES256 but no trusted keys (and no hmac secret). + m := newM2MES256(t) + + clientKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + // Note: key is NOT added to trusted keys. + + claims := jwt.MapClaims{ + "sub": "unknown-service", + "exp": time.Now().Add(5 * time.Minute).Unix(), + } + tok := jwt.NewWithClaims(jwt.SigningMethodES256, claims) + assertion, _ := tok.SignedString(clientKey) + + params := url.Values{ + "grant_type": {GrantTypeJWTBearer}, + "assertion": {assertion}, + } + w := postToken(t, m, params) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401 when no trusted key matches, got %d", w.Code) + } +} + +// TestM2M_JWTBearer_KeySelectedByIss verifies that validation selects the key +// that matches the assertion's iss claim, not an arbitrary trusted key. +func TestM2M_JWTBearer_KeySelectedByIss(t *testing.T) { + server := newM2MES256(t) + + // Register two different keys for two different issuers. + keyA, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + keyB, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + server.AddTrustedKey("service-a", &keyA.PublicKey) + server.AddTrustedKey("service-b", &keyB.PublicKey) + + // Build an assertion claiming iss=service-a but signed with keyB (mismatch). + badClaims := jwt.MapClaims{ + "iss": "service-a", + "sub": "service-a", + "exp": time.Now().Add(5 * time.Minute).Unix(), + } + badTok := jwt.NewWithClaims(jwt.SigningMethodES256, badClaims) + badAssertion, _ := badTok.SignedString(keyB) // signed by keyB but iss=service-a + + params := url.Values{ + "grant_type": {GrantTypeJWTBearer}, + "assertion": {badAssertion}, + } + w := postToken(t, server, params) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401 when assertion iss/key mismatch, got %d; body: %s", w.Code, w.Body.String()) + } +} + +// TestM2M_JWTBearer_KeySelectedByKid verifies that the kid header is used for key lookup. +func TestM2M_JWTBearer_KeySelectedByKid(t *testing.T) { + server := newM2MES256(t) + + clientKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + server.AddTrustedKey("my-kid", &clientKey.PublicKey) + + claims := jwt.MapClaims{ + "iss": "some-service", + "sub": "some-service", + "exp": time.Now().Add(5 * time.Minute).Unix(), + } + // Set kid in header; server should find key by kid. + tok := jwt.NewWithClaims(jwt.SigningMethodES256, claims) + tok.Header["kid"] = "my-kid" + assertion, err := tok.SignedString(clientKey) + if err != nil { + t.Fatalf("sign assertion: %v", err) + } + + params := url.Values{ + "grant_type": {GrantTypeJWTBearer}, + "assertion": {assertion}, + } + w := postToken(t, server, params) + + if w.Code != http.StatusOK { + t.Errorf("expected 200 for kid-based key lookup, got %d; body: %s", w.Code, w.Body.String()) + } +} + +// --- unsupported grant type --- + +func TestM2M_UnsupportedGrantType(t *testing.T) { + m := newM2MHS256(t) + + params := url.Values{ + "grant_type": {"authorization_code"}, + "code": {"abc"}, + } + w := postToken(t, m, params) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for unsupported grant, got %d", w.Code) + } + var resp map[string]string + json.NewDecoder(w.Body).Decode(&resp) + if resp["error"] != "unsupported_grant_type" { + t.Errorf("expected error=unsupported_grant_type, got %q", resp["error"]) + } +} + +// --- not found / unknown route --- + +func TestM2M_UnknownRoute(t *testing.T) { + m := newM2MHS256(t) + + req := httptest.NewRequest(http.MethodGet, "/oauth/unknown", nil) + w := httptest.NewRecorder() + m.Handle(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected 404, got %d", w.Code) + } +} + +// --- Authenticate (AuthProvider interface) --- + +func TestM2M_Authenticate_HS256_Valid(t *testing.T) { + m := newM2MHS256(t) + + // Get a token. + params := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {"test-client"}, + "client_secret": {"test-secret"}, + } + w := postToken(t, m, params) + var resp map[string]any + json.NewDecoder(w.Body).Decode(&resp) + tokenStr, _ := resp["access_token"].(string) + + valid, claims, err := m.Authenticate(tokenStr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !valid { + t.Error("expected token to be valid") + } + if claims["sub"] != "test-client" { + t.Errorf("expected sub=test-client, got %v", claims["sub"]) + } +} + +func TestM2M_Authenticate_HS256_Invalid(t *testing.T) { + m := newM2MHS256(t) + + valid, _, err := m.Authenticate("not.a.jwt") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if valid { + t.Error("expected invalid token to not authenticate") + } +} + +func TestM2M_Authenticate_ES256_Valid(t *testing.T) { + m := newM2MES256(t) + + params := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {"es256-client"}, + "client_secret": {"es256-secret"}, + } + w := postToken(t, m, params) + var resp map[string]any + json.NewDecoder(w.Body).Decode(&resp) + tokenStr, _ := resp["access_token"].(string) + + valid, claims, err := m.Authenticate(tokenStr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !valid { + t.Error("expected token to be valid") + } + if claims["iss"] != "test-issuer" { + t.Errorf("expected iss=test-issuer, got %v", claims["iss"]) + } +} + +func TestM2M_Authenticate_ES256_NoPublicKey(t *testing.T) { + m := &M2MAuthModule{ + name: "m2m", + algorithm: SigningAlgES256, + // no publicKey set + } + _, _, err := m.Authenticate("some.jwt.token") + if err == nil { + t.Error("expected error when no public key configured") + } +} + +// --- JWK helpers --- + +func TestM2M_ecPublicKeyToJWK(t *testing.T) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("generate key: %v", err) + } + jwk, err := ecPublicKeyToJWK(&key.PublicKey, "test-key") + if err != nil { + t.Fatalf("ecPublicKeyToJWK: %v", err) + } + if jwk["kty"] != "EC" { + t.Errorf("expected kty=EC, got %v", jwk["kty"]) + } + if jwk["crv"] != "P-256" { + t.Errorf("expected crv=P-256, got %v", jwk["crv"]) + } + if jwk["kid"] != "test-key" { + t.Errorf("expected kid=test-key, got %v", jwk["kid"]) + } +} + +func TestM2M_jwkToECPublicKey_InvalidKty(t *testing.T) { + _, err := jwkToECPublicKey(map[string]any{"kty": "RSA"}) + if err == nil { + t.Error("expected error for kty=RSA") + } +} + +func TestM2M_jwkToECPublicKey_InvalidCrv(t *testing.T) { + _, err := jwkToECPublicKey(map[string]any{"kty": "EC", "crv": "P-384"}) + if err == nil { + t.Error("expected error for crv=P-384") + } +} + +func TestM2M_jwkToECPublicKey_MissingCoords(t *testing.T) { + _, err := jwkToECPublicKey(map[string]any{"kty": "EC", "crv": "P-256"}) + if err == nil { + t.Error("expected error for missing x/y") + } +} + +func TestM2M_jwkThumbprint_Deterministic(t *testing.T) { + key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + a := jwkThumbprint(&key.PublicKey) + b := jwkThumbprint(&key.PublicKey) + if a != b { + t.Error("thumbprint must be deterministic") + } +} + +func TestM2M_AddTrustedKey(t *testing.T) { + m := newM2MES256(t) + key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + m.AddTrustedKey("svc", &key.PublicKey) + + m.mu.RLock() + stored := m.trustedKeys["svc"] + m.mu.RUnlock() + + if stored == nil { + t.Error("expected key to be stored") + } +} + +// --- DefaultExpiry / issuer defaults --- + +func TestM2M_DefaultExpiry(t *testing.T) { + m := NewM2MAuthModule("m2m", "this-is-a-valid-secret-32-bytes!", 0, "") + if m.tokenExpiry != time.Hour { + t.Errorf("expected default 1h expiry, got %v", m.tokenExpiry) + } + if m.issuer != "workflow" { + t.Errorf("expected default issuer 'workflow', got %q", m.issuer) + } +} + +// --- token claims --- + +func TestM2M_TokenClaims_Issuer(t *testing.T) { + m := newM2MHS256(t) + + params := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {"test-client"}, + "client_secret": {"test-secret"}, + } + w := postToken(t, m, params) + var resp map[string]any + json.NewDecoder(w.Body).Decode(&resp) + tokenStr, _ := resp["access_token"].(string) + + _, claims, _ := m.Authenticate(tokenStr) + if claims["iss"] != "test-issuer" { + t.Errorf("expected iss=test-issuer, got %v", claims["iss"]) + } +} + +func TestM2M_RegisterClient(t *testing.T) { + m := newM2MHS256(t) + m.RegisterClient(M2MClient{ + ClientID: "new-client", + ClientSecret: "new-secret-long-enough", //nolint:gosec // test credential + Scopes: []string{"read"}, + }) + + m.mu.RLock() + c, ok := m.clients["new-client"] + m.mu.RUnlock() + + if !ok { + t.Fatal("expected new client to be registered") + } + if c.ClientID != "new-client" { + t.Errorf("expected clientID 'new-client', got %q", c.ClientID) + } +} + +// --- JWK thumbprint used for key ID --- + +func TestM2M_JWKSKeyID(t *testing.T) { + m := newM2MES256(t) + + req := httptest.NewRequest(http.MethodGet, "/oauth/jwks", nil) + w := httptest.NewRecorder() + m.Handle(w, req) + + var resp map[string]any + json.NewDecoder(w.Body).Decode(&resp) + keys := resp["keys"].([]any) + jwk := keys[0].(map[string]any) + + kid, _ := jwk["kid"].(string) + if kid == "" { + t.Error("expected non-empty kid in JWK") + } +} + +// --- base64url encoding sanity check --- + +func TestM2M_ecPublicKeyToJWK_CoordinatesDecodable(t *testing.T) { + key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + jwk, err := ecPublicKeyToJWK(&key.PublicKey, "kid") + if err != nil { + t.Fatalf("ecPublicKeyToJWK: %v", err) + } + + xStr, _ := jwk["x"].(string) + yStr, _ := jwk["y"].(string) + + xBytes, err := base64.RawURLEncoding.DecodeString(xStr) + if err != nil { + t.Fatalf("decode x: %v", err) + } + yBytes, err := base64.RawURLEncoding.DecodeString(yStr) + if err != nil { + t.Fatalf("decode y: %v", err) + } + if len(xBytes) != 32 { + t.Errorf("expected x to be 32 bytes, got %d", len(xBytes)) + } + if len(yBytes) != 32 { + t.Errorf("expected y to be 32 bytes, got %d", len(yBytes)) + } + + // Reconstructed key should match original via ECDH bytes. + pub, err := jwkToECPublicKey(jwk) + if err != nil { + t.Fatalf("jwkToECPublicKey: %v", err) + } + origECDH, err := key.PublicKey.ECDH() + if err != nil { + t.Fatalf("original key ECDH: %v", err) + } + reconECDH, err := pub.ECDH() + if err != nil { + t.Fatalf("reconstructed key ECDH: %v", err) + } + if string(origECDH.Bytes()) != string(reconECDH.Bytes()) { + t.Error("round-trip key mismatch") + } +} + +// Test that oauthError returns the expected structure. +func TestM2M_oauthError(t *testing.T) { + e := oauthError("invalid_client", "bad creds") + if e["error"] != "invalid_client" { + t.Errorf("expected error=invalid_client, got %q", e["error"]) + } + if e["error_description"] != "bad creds" { + t.Errorf("expected error_description='bad creds', got %q", e["error_description"]) + } +} + +// Test that issueToken doesn't override iss/sub with extraClaims. +func TestM2M_IssueToken_ProtectedClaims(t *testing.T) { + m := newM2MHS256(t) + extra := map[string]any{ + "iss": "evil-issuer", + "sub": "evil-sub", + "custom": "value", + } + tokenStr, err := m.issueToken("legit-subject", nil, extra) + if err != nil { + t.Fatalf("issueToken: %v", err) + } + _, claims, _ := m.Authenticate(tokenStr) + if claims["iss"] != "test-issuer" { + t.Errorf("iss should not be overridable, got %v", claims["iss"]) + } + if claims["sub"] != "legit-subject" { + t.Errorf("sub should not be overridable, got %v", claims["sub"]) + } + if claims["custom"] != "value" { + t.Errorf("expected custom claim to be passed through, got %v", claims["custom"]) + } +} + +// Verify the JWT-bearer grant passes through extra claims from the assertion. +func TestM2M_JWTBearer_ExtraClaimsPassedThrough(t *testing.T) { + m := newM2MHS256(t) + + claims := jwt.MapClaims{ + "sub": "svc", + "exp": time.Now().Add(5 * time.Minute).Unix(), + "team": "platform", + "tenantId": "acme", + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + assertion, _ := tok.SignedString([]byte("this-is-a-valid-secret-32-bytes!")) + + params := url.Values{ + "grant_type": {GrantTypeJWTBearer}, + "assertion": {assertion}, + } + w := postToken(t, m, params) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String()) + } + var resp map[string]any + json.NewDecoder(w.Body).Decode(&resp) + tokenStr, _ := resp["access_token"].(string) + + _, issuedClaims, _ := m.Authenticate(tokenStr) + if issuedClaims["team"] != "platform" { + t.Errorf("expected team=platform, got %v", issuedClaims["team"]) + } + if issuedClaims["tenantId"] != "acme" { + t.Errorf("expected tenantId=acme, got %v", issuedClaims["tenantId"]) + } +} + +// Test RequiresServices returns nil. +func TestM2M_RequiresServices(t *testing.T) { + m := newM2MHS256(t) + if deps := m.RequiresServices(); deps != nil { + t.Errorf("expected nil deps, got %v", deps) + } +} + +// Test that the JWKS response has content-type application/json. +func TestM2M_Handle_ContentTypeJSON(t *testing.T) { + m := newM2MES256(t) + req := httptest.NewRequest(http.MethodGet, "/oauth/jwks", nil) + w := httptest.NewRecorder() + m.Handle(w, req) + + ct := w.Header().Get("Content-Type") + if !strings.Contains(ct, "application/json") { + t.Errorf("expected application/json content-type, got %q", ct) + } +} + +// Test a client with no scopes gets empty scope in response. +func TestM2M_ClientCredentials_NoClientScopes(t *testing.T) { + m := NewM2MAuthModule("m2m", "this-is-a-valid-secret-32-bytes!", time.Hour, "issuer") + m.RegisterClient(M2MClient{ + ClientID: "no-scope-client", + ClientSecret: "no-scope-secret", //nolint:gosec // test credential + Scopes: nil, // no scopes configured + }) + + params := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {"no-scope-client"}, + "client_secret": {"no-scope-secret"}, + } + w := postToken(t, m, params) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String()) + } + var resp map[string]any + json.NewDecoder(w.Body).Decode(&resp) + // No scopes → empty scope value. + scopeVal, _ := resp["scope"].(string) + if scopeVal != "" { + t.Errorf("expected empty scope for no-scope client, got %q", scopeVal) + } +} + +// Verify that the issued token sub matches the client_id. +func TestM2M_ClientCredentials_SubMatchesClientID(t *testing.T) { + m := newM2MHS256(t) + + params := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {"test-client"}, + "client_secret": {"test-secret"}, + } + w := postToken(t, m, params) + var resp map[string]any + json.NewDecoder(w.Body).Decode(&resp) + tokenStr, _ := resp["access_token"].(string) + + _, claims, _ := m.Authenticate(tokenStr) + if claims["sub"] != "test-client" { + t.Errorf("expected sub=test-client, got %v", claims["sub"]) + } +} diff --git a/plugins/auth/plugin.go b/plugins/auth/plugin.go index e47797b5..0dec02b1 100644 --- a/plugins/auth/plugin.go +++ b/plugins/auth/plugin.go @@ -11,9 +11,9 @@ import ( "github.com/GoCodeAlone/workflow/schema" ) -// Plugin provides authentication capabilities: auth.jwt, auth.user-store, and -// auth.oauth2 modules plus the wiring hook that connects AuthProviders to -// AuthMiddleware. +// Plugin provides authentication capabilities: auth.jwt, auth.user-store, +// auth.oauth2, and auth.m2m modules plus the wiring hook that connects +// AuthProviders to AuthMiddleware. type Plugin struct { plugin.BaseEnginePlugin } @@ -37,6 +37,7 @@ func New() *Plugin { "auth.jwt", "auth.user-store", "auth.oauth2", + "auth.m2m", }, Capabilities: []plugin.CapabilityDecl{ {Name: "authentication", Role: "provider", Priority: 10}, @@ -126,6 +127,55 @@ func (p *Plugin) ModuleFactories() map[string]plugin.ModuleFactory { // jwtAuth will be wired during the wiring hook. return module.NewOAuth2Module(name, providerCfgs, nil) }, + "auth.m2m": func(name string, cfg map[string]any) modular.Module { + secret := stringFromMap(cfg, "secret") + tokenExpiry := time.Hour + if te, ok := cfg["tokenExpiry"].(string); ok && te != "" { + if d, err := time.ParseDuration(te); err == nil { + tokenExpiry = d + } + } + issuer := "workflow" + if iss, ok := cfg["issuer"].(string); ok && iss != "" { + issuer = iss + } + m := module.NewM2MAuthModule(name, secret, tokenExpiry, issuer) + + if algo, ok := cfg["algorithm"].(string); ok && module.SigningAlgorithm(algo) == module.SigningAlgES256 { + var keyErr error + if pemKey, ok := cfg["privateKey"].(string); ok && pemKey != "" { + keyErr = m.SetECDSAKey(pemKey) + } else { + keyErr = m.GenerateECDSAKey() + } + if keyErr != nil { + m.SetInitErr(keyErr) + } + } + + if clients, ok := cfg["clients"].([]any); ok { + for _, c := range clients { + if cm, ok := c.(map[string]any); ok { + client := module.M2MClient{ + ClientID: stringFromMap(cm, "clientId"), + ClientSecret: stringFromMap(cm, "clientSecret"), + Description: stringFromMap(cm, "description"), + } + if scopes, ok := cm["scopes"].([]any); ok { + for _, s := range scopes { + if sv, ok := s.(string); ok { + client.Scopes = append(client.Scopes, sv) + } + } + } + if client.ClientID != "" && client.ClientSecret != "" { + m.RegisterClient(client) + } + } + } + } + return m + }, } } @@ -229,5 +279,22 @@ func (p *Plugin) ModuleSchemas() []*schema.ModuleSchema { }, DefaultConfig: map[string]any{"providers": []any{}}, }, + { + Type: "auth.m2m", + Label: "M2M Auth", + Category: "middleware", + Description: "Machine-to-machine OAuth2 auth: client_credentials grant, JWT-bearer assertion grant, ES256/HS256 token issuance, and JWKS endpoint", + Inputs: []schema.ServiceIODef{{Name: "client-credentials", Type: "ClientCredentials", Description: "OAuth2 client_id + client_secret, or a signed JWT assertion"}}, + Outputs: []schema.ServiceIODef{{Name: "access-token", Type: "BearerToken", Description: "Signed access token (HS256 or ES256)"}}, + ConfigFields: []schema.ConfigFieldDef{ + {Key: "secret", Label: "HMAC Secret", Type: schema.FieldTypeString, Description: "Secret for HS256 token signing (min 32 bytes; leave blank for ES256)", Placeholder: "$M2M_SECRET", Sensitive: true}, + {Key: "algorithm", Label: "Signing Algorithm", Type: schema.FieldTypeSelect, Options: []string{"HS256", "ES256"}, DefaultValue: "ES256", Description: "JWT signing algorithm: ES256 (ECDSA P-256) or HS256 (symmetric)"}, + {Key: "privateKey", Label: "EC Private Key (PEM)", Type: schema.FieldTypeString, Description: "PEM-encoded EC private key for ES256 signing; if omitted a key is auto-generated", Sensitive: true}, + {Key: "tokenExpiry", Label: "Token Expiry", Type: schema.FieldTypeDuration, DefaultValue: "1h", Description: "Access token expiration duration (e.g. 15m, 1h)", Placeholder: "1h"}, + {Key: "issuer", Label: "Issuer", Type: schema.FieldTypeString, DefaultValue: "workflow", Description: "Token issuer (iss) claim", Placeholder: "workflow"}, + {Key: "clients", Label: "Registered Clients", Type: schema.FieldTypeJSON, Description: "List of OAuth2 clients: [{clientId, clientSecret, scopes, description}]"}, + }, + DefaultConfig: map[string]any{"algorithm": "ES256", "tokenExpiry": "1h", "issuer": "workflow", "clients": []any{}}, + }, } } diff --git a/plugins/auth/plugin_test.go b/plugins/auth/plugin_test.go index 2c4286f8..0bd49919 100644 --- a/plugins/auth/plugin_test.go +++ b/plugins/auth/plugin_test.go @@ -21,8 +21,8 @@ func TestPluginManifest(t *testing.T) { if m.Name != "auth" { t.Errorf("expected name %q, got %q", "auth", m.Name) } - if len(m.ModuleTypes) != 3 { - t.Errorf("expected 3 module types, got %d", len(m.ModuleTypes)) + if len(m.ModuleTypes) != 4 { + t.Errorf("expected 4 module types, got %d", len(m.ModuleTypes)) } if len(m.WiringHooks) != 2 { t.Errorf("expected 2 wiring hooks, got %d", len(m.WiringHooks)) @@ -50,7 +50,7 @@ func TestModuleFactories(t *testing.T) { p := New() factories := p.ModuleFactories() - expectedTypes := []string{"auth.jwt", "auth.user-store", "auth.oauth2"} + expectedTypes := []string{"auth.jwt", "auth.user-store", "auth.oauth2", "auth.m2m"} for _, typ := range expectedTypes { factory, ok := factories[typ] if !ok { @@ -103,15 +103,15 @@ func TestWiringHooks(t *testing.T) { func TestModuleSchemas(t *testing.T) { p := New() schemas := p.ModuleSchemas() - if len(schemas) != 3 { - t.Fatalf("expected 3 module schemas, got %d", len(schemas)) + if len(schemas) != 4 { + t.Fatalf("expected 4 module schemas, got %d", len(schemas)) } types := map[string]bool{} for _, s := range schemas { types[s.Type] = true } - for _, expected := range []string{"auth.jwt", "auth.user-store", "auth.oauth2"} { + for _, expected := range []string{"auth.jwt", "auth.user-store", "auth.oauth2", "auth.m2m"} { if !types[expected] { t.Errorf("missing schema for %q", expected) }