Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions module/auth_m2m.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ const (

// M2MClient represents a registered machine-to-machine OAuth2 client.
type M2MClient struct {
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"` //nolint:gosec // G117: config DTO field
Description string `json:"description,omitempty"`
Scopes []string `json:"scopes,omitempty"`
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"` //nolint:gosec // G117: config DTO field
Description string `json:"description,omitempty"`
Scopes []string `json:"scopes,omitempty"`
Claims map[string]any `json:"claims,omitempty"`
}

// M2MAuthModule provides machine-to-machine (server-to-server) OAuth2 authentication.
Expand Down Expand Up @@ -251,7 +252,7 @@ func (m *M2MAuthModule) handleClientCredentials(w http.ResponseWriter, r *http.R
return
}

token, err := m.issueToken(clientID, grantedScopes, nil)
token, err := m.issueToken(clientID, grantedScopes, client.Claims)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
_ = json.NewEncoder(w).Encode(oauthError("server_error", "failed to issue token"))
Expand Down
140 changes: 140 additions & 0 deletions module/auth_m2m_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1171,3 +1171,143 @@ func TestM2M_ClientCredentials_SubMatchesClientID(t *testing.T) {
t.Errorf("expected sub=test-client, got %v", claims["sub"])
}
}

// --- per-client custom claims ---

// TestM2M_ClientCredentials_CustomClaimsInToken verifies that a client's Claims
// map is included in the issued access token.
func TestM2M_ClientCredentials_CustomClaimsInToken(t *testing.T) {
m := NewM2MAuthModule("m2m", "this-is-a-valid-secret-32-bytes!", time.Hour, "test-issuer")
m.RegisterClient(M2MClient{
ClientID: "org-alpha",
ClientSecret: "secret-org-alpha", //nolint:gosec // test credential
Scopes: []string{"read"},
Claims: map[string]any{"tenant_id": "alpha"},
})

params := url.Values{
"grant_type": {"client_credentials"},
"client_id": {"org-alpha"},
"client_secret": {"secret-org-alpha"},
}
w := postToken(t, m, params)

if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String())
}
var resp map[string]any
json.NewDecoder(w.Body).Decode(&resp)
tokenStr, _ := resp["access_token"].(string)

_, claims, err := m.Authenticate(tokenStr)
if err != nil {
t.Fatalf("authenticate: %v", err)
}
if claims["tenant_id"] != "alpha" {
t.Errorf("expected tenant_id=alpha, got %v", claims["tenant_id"])
}
}

// TestM2M_ClientCredentials_MultipleCustomClaims verifies that multiple custom
// claims are all present in the issued token.
func TestM2M_ClientCredentials_MultipleCustomClaims(t *testing.T) {
m := NewM2MAuthModule("m2m", "this-is-a-valid-secret-32-bytes!", time.Hour, "test-issuer")
m.RegisterClient(M2MClient{
ClientID: "org-beta",
ClientSecret: "secret-org-beta", //nolint:gosec // test credential
Scopes: []string{"read", "write"},
Claims: map[string]any{
"tenant_id": "beta",
"affiliate_id": "partner-42",
},
})

params := url.Values{
"grant_type": {"client_credentials"},
"client_id": {"org-beta"},
"client_secret": {"secret-org-beta"},
}
w := postToken(t, m, params)

if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String())
}
var resp map[string]any
json.NewDecoder(w.Body).Decode(&resp)
tokenStr, _ := resp["access_token"].(string)

_, claims, err := m.Authenticate(tokenStr)
if err != nil {
t.Fatalf("authenticate: %v", err)
}
if claims["tenant_id"] != "beta" {
t.Errorf("expected tenant_id=beta, got %v", claims["tenant_id"])
}
if claims["affiliate_id"] != "partner-42" {
t.Errorf("expected affiliate_id=partner-42, got %v", claims["affiliate_id"])
}
}

// TestM2M_ClientCredentials_CustomClaimsDoNotOverrideStandard verifies that
// custom claims on a client cannot override standard JWT claims.
func TestM2M_ClientCredentials_CustomClaimsDoNotOverrideStandard(t *testing.T) {
m := NewM2MAuthModule("m2m", "this-is-a-valid-secret-32-bytes!", time.Hour, "trusted-issuer")
m.RegisterClient(M2MClient{
ClientID: "attacker",
ClientSecret: "attacker-secret-here", //nolint:gosec // test credential
Scopes: []string{"read"},
Claims: map[string]any{
"iss": "evil-issuer",
"sub": "admin",
},
})

params := url.Values{
"grant_type": {"client_credentials"},
"client_id": {"attacker"},
"client_secret": {"attacker-secret-here"},
}
w := postToken(t, m, params)

if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String())
}
var resp map[string]any
json.NewDecoder(w.Body).Decode(&resp)
tokenStr, _ := resp["access_token"].(string)

_, claims, err := m.Authenticate(tokenStr)
if err != nil {
t.Fatalf("authenticate: %v", err)
}
// Standard claims must not be overridden by client.Claims.
if claims["iss"] != "trusted-issuer" {
t.Errorf("iss must not be overridable via client claims, got %v", claims["iss"])
}
if claims["sub"] != "attacker" {
t.Errorf("sub must not be overridable via client claims, got %v", claims["sub"])
}
}

// TestM2M_ClientCredentials_NilClaimsOK verifies that a client with nil Claims
// still issues tokens without error.
func TestM2M_ClientCredentials_NilClaimsOK(t *testing.T) {
m := NewM2MAuthModule("m2m", "this-is-a-valid-secret-32-bytes!", time.Hour, "test-issuer")
m.RegisterClient(M2MClient{
ClientID: "plain-client",
ClientSecret: "plain-client-secret!", //nolint:gosec // test credential
Scopes: []string{"read"},
Claims: nil,
})

params := url.Values{
"grant_type": {"client_credentials"},
"client_id": {"plain-client"},
"client_secret": {"plain-client-secret!"},
}
w := postToken(t, m, params)

if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String())
}
}
5 changes: 4 additions & 1 deletion plugins/auth/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ func (p *Plugin) ModuleFactories() map[string]plugin.ModuleFactory {
}
}
}
if claimsRaw, ok := cm["claims"].(map[string]any); ok {
client.Claims = claimsRaw
}
if client.ClientID != "" && client.ClientSecret != "" {
m.RegisterClient(client)
}
Expand Down Expand Up @@ -366,7 +369,7 @@ func (p *Plugin) ModuleSchemas() []*schema.ModuleSchema {
{Key: "privateKey", Label: "EC Private Key (PEM)", Type: schema.FieldTypeString, Description: "PEM-encoded EC private key for ES256 signing; if omitted a key is auto-generated", Sensitive: true},
{Key: "tokenExpiry", Label: "Token Expiry", Type: schema.FieldTypeDuration, DefaultValue: "1h", Description: "Access token expiration duration (e.g. 15m, 1h)", Placeholder: "1h"},
{Key: "issuer", Label: "Issuer", Type: schema.FieldTypeString, DefaultValue: "workflow", Description: "Token issuer (iss) claim", Placeholder: "workflow"},
{Key: "clients", Label: "Registered Clients", Type: schema.FieldTypeJSON, Description: "List of OAuth2 clients: [{clientId, clientSecret, scopes, description}]"},
{Key: "clients", Label: "Registered Clients", Type: schema.FieldTypeJSON, Description: "List of OAuth2 clients: [{clientId, clientSecret, scopes, description, claims}]"},
},
DefaultConfig: map[string]any{"algorithm": "ES256", "tokenExpiry": "1h", "issuer": "workflow", "clients": []any{}},
},
Expand Down
48 changes: 48 additions & 0 deletions plugins/auth/plugin_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
package auth

import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

"github.com/GoCodeAlone/workflow/module"
"github.com/GoCodeAlone/workflow/plugin"
)

Expand Down Expand Up @@ -120,3 +125,46 @@ func TestModuleSchemas(t *testing.T) {
}
}
}

func TestModuleFactoryM2MWithClaims(t *testing.T) {
p := New()
factories := p.ModuleFactories()

mod := factories["auth.m2m"]("m2m-test", map[string]any{
"algorithm": "HS256",
"secret": "this-is-a-valid-secret-32-bytes!",
"clients": []any{
map[string]any{
"clientId": "org-alpha",
"clientSecret": "secret-alpha",
"scopes": []any{"read"},
"claims": map[string]any{
"tenant_id": "alpha",
},
},
},
})
if mod == nil {
t.Fatal("auth.m2m factory returned nil")
}

m2mMod, ok := mod.(*module.M2MAuthModule)
if !ok {
t.Fatal("expected *module.M2MAuthModule")
}

// Issue a token via the Handle method.
params := url.Values{
"grant_type": {"client_credentials"},
"client_id": {"org-alpha"},
"client_secret": {"secret-alpha"},
}
req := httptest.NewRequest(http.MethodPost, "/oauth/token", strings.NewReader(params.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
w := httptest.NewRecorder()
m2mMod.Handle(w, req)

if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String())
}
}
Loading