Skip to content

Commit 75f259d

Browse files
Copilotintel352
andauthored
feat(auth.m2m): per-client custom claims in client_credentials tokens (#236)
* Initial plan * feat: add per-client claims support to auth.m2m module Co-authored-by: intel352 <77607+intel352@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: intel352 <77607+intel352@users.noreply.github.com>
1 parent 8bbb6dc commit 75f259d

4 files changed

Lines changed: 198 additions & 6 deletions

File tree

module/auth_m2m.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,11 @@ const (
4040

4141
// M2MClient represents a registered machine-to-machine OAuth2 client.
4242
type M2MClient struct {
43-
ClientID string `json:"clientId"`
44-
ClientSecret string `json:"clientSecret"` //nolint:gosec // G117: config DTO field
45-
Description string `json:"description,omitempty"`
46-
Scopes []string `json:"scopes,omitempty"`
43+
ClientID string `json:"clientId"`
44+
ClientSecret string `json:"clientSecret"` //nolint:gosec // G117: config DTO field
45+
Description string `json:"description,omitempty"`
46+
Scopes []string `json:"scopes,omitempty"`
47+
Claims map[string]any `json:"claims,omitempty"`
4748
}
4849

4950
// M2MAuthModule provides machine-to-machine (server-to-server) OAuth2 authentication.
@@ -251,7 +252,7 @@ func (m *M2MAuthModule) handleClientCredentials(w http.ResponseWriter, r *http.R
251252
return
252253
}
253254

254-
token, err := m.issueToken(clientID, grantedScopes, nil)
255+
token, err := m.issueToken(clientID, grantedScopes, client.Claims)
255256
if err != nil {
256257
w.WriteHeader(http.StatusInternalServerError)
257258
_ = json.NewEncoder(w).Encode(oauthError("server_error", "failed to issue token"))

module/auth_m2m_test.go

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,3 +1171,143 @@ func TestM2M_ClientCredentials_SubMatchesClientID(t *testing.T) {
11711171
t.Errorf("expected sub=test-client, got %v", claims["sub"])
11721172
}
11731173
}
1174+
1175+
// --- per-client custom claims ---
1176+
1177+
// TestM2M_ClientCredentials_CustomClaimsInToken verifies that a client's Claims
1178+
// map is included in the issued access token.
1179+
func TestM2M_ClientCredentials_CustomClaimsInToken(t *testing.T) {
1180+
m := NewM2MAuthModule("m2m", "this-is-a-valid-secret-32-bytes!", time.Hour, "test-issuer")
1181+
m.RegisterClient(M2MClient{
1182+
ClientID: "org-alpha",
1183+
ClientSecret: "secret-org-alpha", //nolint:gosec // test credential
1184+
Scopes: []string{"read"},
1185+
Claims: map[string]any{"tenant_id": "alpha"},
1186+
})
1187+
1188+
params := url.Values{
1189+
"grant_type": {"client_credentials"},
1190+
"client_id": {"org-alpha"},
1191+
"client_secret": {"secret-org-alpha"},
1192+
}
1193+
w := postToken(t, m, params)
1194+
1195+
if w.Code != http.StatusOK {
1196+
t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String())
1197+
}
1198+
var resp map[string]any
1199+
json.NewDecoder(w.Body).Decode(&resp)
1200+
tokenStr, _ := resp["access_token"].(string)
1201+
1202+
_, claims, err := m.Authenticate(tokenStr)
1203+
if err != nil {
1204+
t.Fatalf("authenticate: %v", err)
1205+
}
1206+
if claims["tenant_id"] != "alpha" {
1207+
t.Errorf("expected tenant_id=alpha, got %v", claims["tenant_id"])
1208+
}
1209+
}
1210+
1211+
// TestM2M_ClientCredentials_MultipleCustomClaims verifies that multiple custom
1212+
// claims are all present in the issued token.
1213+
func TestM2M_ClientCredentials_MultipleCustomClaims(t *testing.T) {
1214+
m := NewM2MAuthModule("m2m", "this-is-a-valid-secret-32-bytes!", time.Hour, "test-issuer")
1215+
m.RegisterClient(M2MClient{
1216+
ClientID: "org-beta",
1217+
ClientSecret: "secret-org-beta", //nolint:gosec // test credential
1218+
Scopes: []string{"read", "write"},
1219+
Claims: map[string]any{
1220+
"tenant_id": "beta",
1221+
"affiliate_id": "partner-42",
1222+
},
1223+
})
1224+
1225+
params := url.Values{
1226+
"grant_type": {"client_credentials"},
1227+
"client_id": {"org-beta"},
1228+
"client_secret": {"secret-org-beta"},
1229+
}
1230+
w := postToken(t, m, params)
1231+
1232+
if w.Code != http.StatusOK {
1233+
t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String())
1234+
}
1235+
var resp map[string]any
1236+
json.NewDecoder(w.Body).Decode(&resp)
1237+
tokenStr, _ := resp["access_token"].(string)
1238+
1239+
_, claims, err := m.Authenticate(tokenStr)
1240+
if err != nil {
1241+
t.Fatalf("authenticate: %v", err)
1242+
}
1243+
if claims["tenant_id"] != "beta" {
1244+
t.Errorf("expected tenant_id=beta, got %v", claims["tenant_id"])
1245+
}
1246+
if claims["affiliate_id"] != "partner-42" {
1247+
t.Errorf("expected affiliate_id=partner-42, got %v", claims["affiliate_id"])
1248+
}
1249+
}
1250+
1251+
// TestM2M_ClientCredentials_CustomClaimsDoNotOverrideStandard verifies that
1252+
// custom claims on a client cannot override standard JWT claims.
1253+
func TestM2M_ClientCredentials_CustomClaimsDoNotOverrideStandard(t *testing.T) {
1254+
m := NewM2MAuthModule("m2m", "this-is-a-valid-secret-32-bytes!", time.Hour, "trusted-issuer")
1255+
m.RegisterClient(M2MClient{
1256+
ClientID: "attacker",
1257+
ClientSecret: "attacker-secret-here", //nolint:gosec // test credential
1258+
Scopes: []string{"read"},
1259+
Claims: map[string]any{
1260+
"iss": "evil-issuer",
1261+
"sub": "admin",
1262+
},
1263+
})
1264+
1265+
params := url.Values{
1266+
"grant_type": {"client_credentials"},
1267+
"client_id": {"attacker"},
1268+
"client_secret": {"attacker-secret-here"},
1269+
}
1270+
w := postToken(t, m, params)
1271+
1272+
if w.Code != http.StatusOK {
1273+
t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String())
1274+
}
1275+
var resp map[string]any
1276+
json.NewDecoder(w.Body).Decode(&resp)
1277+
tokenStr, _ := resp["access_token"].(string)
1278+
1279+
_, claims, err := m.Authenticate(tokenStr)
1280+
if err != nil {
1281+
t.Fatalf("authenticate: %v", err)
1282+
}
1283+
// Standard claims must not be overridden by client.Claims.
1284+
if claims["iss"] != "trusted-issuer" {
1285+
t.Errorf("iss must not be overridable via client claims, got %v", claims["iss"])
1286+
}
1287+
if claims["sub"] != "attacker" {
1288+
t.Errorf("sub must not be overridable via client claims, got %v", claims["sub"])
1289+
}
1290+
}
1291+
1292+
// TestM2M_ClientCredentials_NilClaimsOK verifies that a client with nil Claims
1293+
// still issues tokens without error.
1294+
func TestM2M_ClientCredentials_NilClaimsOK(t *testing.T) {
1295+
m := NewM2MAuthModule("m2m", "this-is-a-valid-secret-32-bytes!", time.Hour, "test-issuer")
1296+
m.RegisterClient(M2MClient{
1297+
ClientID: "plain-client",
1298+
ClientSecret: "plain-client-secret!", //nolint:gosec // test credential
1299+
Scopes: []string{"read"},
1300+
Claims: nil,
1301+
})
1302+
1303+
params := url.Values{
1304+
"grant_type": {"client_credentials"},
1305+
"client_id": {"plain-client"},
1306+
"client_secret": {"plain-client-secret!"},
1307+
}
1308+
w := postToken(t, m, params)
1309+
1310+
if w.Code != http.StatusOK {
1311+
t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String())
1312+
}
1313+
}

plugins/auth/plugin.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,9 @@ func (p *Plugin) ModuleFactories() map[string]plugin.ModuleFactory {
196196
}
197197
}
198198
}
199+
if claimsRaw, ok := cm["claims"].(map[string]any); ok {
200+
client.Claims = claimsRaw
201+
}
199202
if client.ClientID != "" && client.ClientSecret != "" {
200203
m.RegisterClient(client)
201204
}
@@ -366,7 +369,7 @@ func (p *Plugin) ModuleSchemas() []*schema.ModuleSchema {
366369
{Key: "privateKey", Label: "EC Private Key (PEM)", Type: schema.FieldTypeString, Description: "PEM-encoded EC private key for ES256 signing; if omitted a key is auto-generated", Sensitive: true},
367370
{Key: "tokenExpiry", Label: "Token Expiry", Type: schema.FieldTypeDuration, DefaultValue: "1h", Description: "Access token expiration duration (e.g. 15m, 1h)", Placeholder: "1h"},
368371
{Key: "issuer", Label: "Issuer", Type: schema.FieldTypeString, DefaultValue: "workflow", Description: "Token issuer (iss) claim", Placeholder: "workflow"},
369-
{Key: "clients", Label: "Registered Clients", Type: schema.FieldTypeJSON, Description: "List of OAuth2 clients: [{clientId, clientSecret, scopes, description}]"},
372+
{Key: "clients", Label: "Registered Clients", Type: schema.FieldTypeJSON, Description: "List of OAuth2 clients: [{clientId, clientSecret, scopes, description, claims}]"},
370373
},
371374
DefaultConfig: map[string]any{"algorithm": "ES256", "tokenExpiry": "1h", "issuer": "workflow", "clients": []any{}},
372375
},

plugins/auth/plugin_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
package auth
22

33
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"net/url"
7+
"strings"
48
"testing"
59

10+
"github.com/GoCodeAlone/workflow/module"
611
"github.com/GoCodeAlone/workflow/plugin"
712
)
813

@@ -120,3 +125,46 @@ func TestModuleSchemas(t *testing.T) {
120125
}
121126
}
122127
}
128+
129+
func TestModuleFactoryM2MWithClaims(t *testing.T) {
130+
p := New()
131+
factories := p.ModuleFactories()
132+
133+
mod := factories["auth.m2m"]("m2m-test", map[string]any{
134+
"algorithm": "HS256",
135+
"secret": "this-is-a-valid-secret-32-bytes!",
136+
"clients": []any{
137+
map[string]any{
138+
"clientId": "org-alpha",
139+
"clientSecret": "secret-alpha",
140+
"scopes": []any{"read"},
141+
"claims": map[string]any{
142+
"tenant_id": "alpha",
143+
},
144+
},
145+
},
146+
})
147+
if mod == nil {
148+
t.Fatal("auth.m2m factory returned nil")
149+
}
150+
151+
m2mMod, ok := mod.(*module.M2MAuthModule)
152+
if !ok {
153+
t.Fatal("expected *module.M2MAuthModule")
154+
}
155+
156+
// Issue a token via the Handle method.
157+
params := url.Values{
158+
"grant_type": {"client_credentials"},
159+
"client_id": {"org-alpha"},
160+
"client_secret": {"secret-alpha"},
161+
}
162+
req := httptest.NewRequest(http.MethodPost, "/oauth/token", strings.NewReader(params.Encode()))
163+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
164+
w := httptest.NewRecorder()
165+
m2mMod.Handle(w, req)
166+
167+
if w.Code != http.StatusOK {
168+
t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String())
169+
}
170+
}

0 commit comments

Comments
 (0)