@@ -63,6 +63,32 @@ type M2MClient struct {
6363 Claims map [string ]any `json:"claims,omitempty"`
6464}
6565
66+ // TrustedKeyConfig holds the configuration for a trusted external JWT issuer.
67+ // It is used to register trusted keys for the JWT-bearer grant via YAML configuration.
68+ type TrustedKeyConfig struct {
69+ // Issuer is the expected `iss` claim value (e.g. "https://legacy-platform.example.com").
70+ Issuer string `json:"issuer" yaml:"issuer"`
71+ // Algorithm is the expected signing algorithm (e.g. "ES256"). Currently only ES256 is supported.
72+ Algorithm string `json:"algorithm,omitempty" yaml:"algorithm,omitempty"`
73+ // PublicKeyPEM is the PEM-encoded EC public key for the trusted issuer.
74+ // Literal `\n` sequences (common in Docker/Kubernetes env vars) are normalised to newlines.
75+ PublicKeyPEM string `json:"publicKeyPEM,omitempty" yaml:"publicKeyPEM,omitempty"` //nolint:gosec // G117: config DTO field
76+ // Audiences is an optional list of accepted audience values.
77+ // When non-empty, the assertion's `aud` claim must contain at least one of these values.
78+ Audiences []string `json:"audiences,omitempty" yaml:"audiences,omitempty"`
79+ // ClaimMapping renames claims from the external assertion before they are included in the
80+ // issued token. The map key is the external claim name; the value is the local claim name.
81+ // For example {"user_id": "sub"} promotes the external `user_id` claim to `sub`.
82+ ClaimMapping map [string ]string `json:"claimMapping,omitempty" yaml:"claimMapping,omitempty"`
83+ }
84+
85+ // trustedKeyEntry is the internal representation of a trusted external JWT issuer.
86+ type trustedKeyEntry struct {
87+ pubKey * ecdsa.PublicKey
88+ audiences []string
89+ claimMapping map [string ]string
90+ }
91+
6692// M2MAuthModule provides machine-to-machine (server-to-server) OAuth2 authentication.
6793// It supports the client_credentials grant and the JWT-bearer grant, and can issue
6894// tokens signed with either HS256 (shared secret) or ES256 (ECDSA P-256).
@@ -84,7 +110,7 @@ type M2MAuthModule struct {
84110 publicKey * ecdsa.PublicKey
85111
86112 // Trusted public keys for JWT-bearer grant (keyed by key ID or issuer)
87- trustedKeys map [string ]* ecdsa. PublicKey
113+ trustedKeys map [string ]* trustedKeyEntry
88114
89115 // Registered clients
90116 mu sync.RWMutex
@@ -116,7 +142,7 @@ func NewM2MAuthModule(name string, hmacSecret string, tokenExpiry time.Duration,
116142 issuer : issuer ,
117143 tokenExpiry : tokenExpiry ,
118144 hmacSecret : []byte (hmacSecret ),
119- trustedKeys : make (map [string ]* ecdsa. PublicKey ),
145+ trustedKeys : make (map [string ]* trustedKeyEntry ),
120146 clients : make (map [string ]* M2MClient ),
121147 jtiBlacklist : make (map [string ]time.Time ),
122148 }
@@ -166,7 +192,46 @@ func (m *M2MAuthModule) SetInitErr(err error) {
166192func (m * M2MAuthModule ) AddTrustedKey (keyID string , pubKey * ecdsa.PublicKey ) {
167193 m .mu .Lock ()
168194 defer m .mu .Unlock ()
169- m .trustedKeys [keyID ] = pubKey
195+ m .trustedKeys [keyID ] = & trustedKeyEntry {pubKey : pubKey }
196+ }
197+
198+ // AddTrustedKeyFromPEM parses a PEM-encoded EC public key and registers it as a trusted
199+ // key for JWT-bearer assertion validation. Literal `\n` sequences in the PEM string are
200+ // normalised to real newlines so that env-var-injected keys (Docker/Kubernetes) work without
201+ // additional preprocessing by the caller.
202+ //
203+ // audiences is an optional list; when non-empty the assertion's `aud` claim must match at
204+ // least one entry. claimMapping renames external claims before they are forwarded into the
205+ // issued token (map key = external name, map value = local name).
206+ func (m * M2MAuthModule ) AddTrustedKeyFromPEM (issuer , publicKeyPEM string , audiences []string , claimMapping map [string ]string ) error {
207+ // Normalise escaped newlines that are common in Docker/Kubernetes env vars.
208+ normalised := strings .ReplaceAll (publicKeyPEM , `\n` , "\n " )
209+
210+ block , _ := pem .Decode ([]byte (normalised ))
211+ if block == nil {
212+ return fmt .Errorf ("auth.m2m: failed to decode PEM block for issuer %q" , issuer )
213+ }
214+
215+ pubAny , err := x509 .ParsePKIXPublicKey (block .Bytes )
216+ if err != nil {
217+ return fmt .Errorf ("auth.m2m: parse public key for issuer %q: %w" , issuer , err )
218+ }
219+ ecKey , ok := pubAny .(* ecdsa.PublicKey )
220+ if ! ok {
221+ return fmt .Errorf ("auth.m2m: public key for issuer %q is not an ECDSA key" , issuer )
222+ }
223+ if ecKey .Curve != elliptic .P256 () {
224+ return fmt .Errorf ("auth.m2m: public key for issuer %q must use P-256 (ES256) curve" , issuer )
225+ }
226+
227+ m .mu .Lock ()
228+ defer m .mu .Unlock ()
229+ m .trustedKeys [issuer ] = & trustedKeyEntry {
230+ pubKey : ecKey ,
231+ audiences : audiences ,
232+ claimMapping : claimMapping ,
233+ }
234+ return nil
170235}
171236
172237// RegisterClient registers a new OAuth2 client.
@@ -676,19 +741,19 @@ func (m *M2MAuthModule) validateJWTAssertion(assertion string) (jwt.MapClaims, e
676741
677742 m .mu .RLock ()
678743 // Try kid first, then iss.
679- var selectedKey * ecdsa. PublicKey
744+ var selectedEntry * trustedKeyEntry
680745 if kid != "" {
681- selectedKey = m .trustedKeys [kid ]
746+ selectedEntry = m .trustedKeys [kid ]
682747 }
683- if selectedKey == nil && iss != "" {
684- selectedKey = m .trustedKeys [iss ]
748+ if selectedEntry == nil && iss != "" {
749+ selectedEntry = m .trustedKeys [iss ]
685750 }
686751 hmacSecret := m .hmacSecret
687752 m .mu .RUnlock ()
688753
689754 // Try EC key if found.
690- if selectedKey != nil {
691- k := selectedKey
755+ if selectedEntry != nil && selectedEntry . pubKey != nil {
756+ k := selectedEntry . pubKey
692757 token , err := jwt .Parse (assertion , func (token * jwt.Token ) (any , error ) {
693758 if _ , ok := token .Method .(* jwt.SigningMethodECDSA ); ! ok {
694759 return nil , fmt .Errorf ("unexpected signing method: %v" , token .Header ["alg" ])
@@ -702,6 +767,19 @@ func (m *M2MAuthModule) validateJWTAssertion(assertion string) (jwt.MapClaims, e
702767 if ! ok || ! token .Valid {
703768 return nil , fmt .Errorf ("invalid assertion claims" )
704769 }
770+
771+ // Validate audience if configured.
772+ if len (selectedEntry .audiences ) > 0 {
773+ if err := validateAssertionAudience (claims , selectedEntry .audiences ); err != nil {
774+ return nil , err
775+ }
776+ }
777+
778+ // Apply claim mapping if configured.
779+ if len (selectedEntry .claimMapping ) > 0 {
780+ claims = applyAssertionClaimMapping (claims , selectedEntry .claimMapping )
781+ }
782+
705783 return claims , nil
706784 }
707785
@@ -1032,3 +1110,50 @@ func oauthError(code, description string) map[string]string {
10321110 "error_description" : description ,
10331111 }
10341112}
1113+
1114+ // validateAssertionAudience checks that the JWT claims contain at least one of the
1115+ // required audience values. The `aud` claim can be a single string or a JSON array.
1116+ func validateAssertionAudience (claims jwt.MapClaims , requiredAudiences []string ) error {
1117+ aud := claims ["aud" ]
1118+ if aud == nil {
1119+ return fmt .Errorf ("assertion missing aud claim, expected one of %v" , requiredAudiences )
1120+ }
1121+ var tokenAuds []string
1122+ switch v := aud .(type ) {
1123+ case string :
1124+ tokenAuds = []string {v }
1125+ case []any :
1126+ for _ , a := range v {
1127+ if s , ok := a .(string ); ok {
1128+ tokenAuds = append (tokenAuds , s )
1129+ }
1130+ }
1131+ }
1132+ for _ , required := range requiredAudiences {
1133+ for _ , tokenAud := range tokenAuds {
1134+ if tokenAud == required {
1135+ return nil
1136+ }
1137+ }
1138+ }
1139+ return fmt .Errorf ("assertion audience %v does not include required audience %v" , tokenAuds , requiredAudiences )
1140+ }
1141+
1142+ // applyAssertionClaimMapping renames claims from an external assertion before they are
1143+ // forwarded into the issued token. The mapping key is the external claim name; the
1144+ // value is the local claim name. The original claim is removed when the names differ.
1145+ func applyAssertionClaimMapping (claims jwt.MapClaims , mapping map [string ]string ) jwt.MapClaims {
1146+ result := make (jwt.MapClaims , len (claims ))
1147+ for k , v := range claims {
1148+ result [k ] = v
1149+ }
1150+ for externalKey , localKey := range mapping {
1151+ if val , exists := claims [externalKey ]; exists {
1152+ result [localKey ] = val
1153+ if externalKey != localKey {
1154+ delete (result , externalKey )
1155+ }
1156+ }
1157+ }
1158+ return result
1159+ }
0 commit comments