diff --git a/go.mod b/go.mod index 73d9ae72ab..1bea53d518 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/badoux/checkmail v0.0.0-20170203135005-d0a759655d62 github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc github.com/coreos/go-oidc/v3 v3.6.0 + github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 github.com/didip/tollbooth/v5 v5.1.1 github.com/gobuffalo/validate/v3 v3.3.3 // indirect github.com/gobwas/glob v0.2.3 @@ -39,7 +40,6 @@ require ( github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/consensys/gnark-crypto v0.18.1 // indirect github.com/crate-crypto/go-eth-kzg v1.4.0 // indirect - github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect github.com/ethereum/c-kzg-4844/v2 v2.1.5 // indirect github.com/go-jose/go-jose/v3 v3.0.5 // indirect github.com/go-viper/mapstructure/v2 v2.5.0 // indirect diff --git a/internal/api/provider/es256k_keyset.go b/internal/api/provider/es256k_keyset.go new file mode 100644 index 0000000000..30a208be62 --- /dev/null +++ b/internal/api/provider/es256k_keyset.go @@ -0,0 +1,283 @@ +package provider + +import ( + "context" + "crypto/ecdsa" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "math/big" + "net/http" + "strings" + "sync" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/decred/dcrd/dcrec/secp256k1/v4" + "golang.org/x/oauth2" +) + +const es256kAlgorithm = "ES256K" + +type oidcDiscoveryClaims struct { + Issuer string `json:"issuer"` + JWKSURI string `json:"jwks_uri"` + IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"` +} + +type jwtHeader struct { + Algorithm string `json:"alg"` + KeyID string `json:"kid"` + Critical []string `json:"crit"` +} + +type es256kJWKSet struct { + Keys []es256kJWK `json:"keys"` +} + +type es256kJWK struct { + KeyType string `json:"kty"` + Curve string `json:"crv"` + KeyID string `json:"kid"` + Algorithm string `json:"alg"` + Use string `json:"use"` + X string `json:"x"` + Y string `json:"y"` +} + +type es256kRemoteKeySet struct { + jwksURL string + + mu sync.RWMutex + cachedKeys []es256kJWK +} + +var es256kKeySets sync.Map + +const es256kKeySetCacheTTL = time.Hour + +type cachedES256KKeySet struct { + keySet *es256kRemoteKeySet + createdAt time.Time +} + +func getES256KRemoteKeySet(jwksURL string) *es256kRemoteKeySet { + now := time.Now() + if value, ok := es256kKeySets.Load(jwksURL); ok { + cached := value.(cachedES256KKeySet) + if now.Sub(cached.createdAt) < es256kKeySetCacheTTL { + return cached.keySet + } + } + + cached := cachedES256KKeySet{ + keySet: newES256KRemoteKeySet(jwksURL), + createdAt: now, + } + es256kKeySets.Store(jwksURL, cached) + return cached.keySet +} + +func newES256KRemoteKeySet(jwksURL string) *es256kRemoteKeySet { + return &es256kRemoteKeySet{jwksURL: jwksURL} +} + +func (r *es256kRemoteKeySet) VerifySignature(ctx context.Context, token string) ([]byte, error) { + parts := strings.Split(token, ".") + if len(parts) != 3 { + return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts)) + } + + header, err := parseJWTHeader(token) + if err != nil { + return nil, err + } + if header.Algorithm != es256kAlgorithm { + return nil, fmt.Errorf("oidc: unsupported jwt algorithm %q for ES256K key set", header.Algorithm) + } + if len(header.Critical) > 0 { + return nil, errors.New("oidc: unsupported critical jwt headers") + } + + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, fmt.Errorf("oidc: malformed jwt payload: %w", err) + } + + signature, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + return nil, fmt.Errorf("oidc: malformed jwt signature: %w", err) + } + if len(signature) != 64 { + return nil, fmt.Errorf("oidc: malformed ES256K signature length %d", len(signature)) + } + + if r.verifyWithKeys(r.keysFromCache(), header, parts[0]+"."+parts[1], signature) { + return payload, nil + } + + keys, err := r.keysFromRemote(ctx) + if err != nil { + return nil, fmt.Errorf("fetching keys %v", err) + } + if r.verifyWithKeys(keys, header, parts[0]+"."+parts[1], signature) { + return payload, nil + } + + return nil, errors.New("failed to verify id token signature") +} + +func parseJWTHeader(token string) (*jwtHeader, error) { + parts := strings.Split(token, ".") + if len(parts) != 3 { + return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts)) + } + + headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + return nil, fmt.Errorf("oidc: malformed jwt header: %w", err) + } + + var header jwtHeader + if err := json.Unmarshal(headerBytes, &header); err != nil { + return nil, fmt.Errorf("oidc: malformed jwt header: %w", err) + } + + return &header, nil +} + +func verifySHA256AccessTokenHash(expected, accessToken string) error { + digest := sha256.Sum256([]byte(accessToken)) + actual := base64.RawURLEncoding.EncodeToString(digest[:len(digest)/2]) + if actual != expected { + return errors.New("oidc: access token hash does not match value in ID token") + } + return nil +} + +func supportsES256K(claims oidcDiscoveryClaims, config *oidc.Config) bool { + if len(config.SupportedSigningAlgs) > 0 { + return containsSigningAlg(config.SupportedSigningAlgs, es256kAlgorithm) + } + return containsSigningAlg(claims.IDTokenSigningAlgValuesSupported, es256kAlgorithm) +} + +func containsSigningAlg(algs []string, alg string) bool { + for _, candidate := range algs { + if candidate == alg { + return true + } + } + return false +} + +func (r *es256kRemoteKeySet) keysFromCache() []es256kJWK { + r.mu.RLock() + defer r.mu.RUnlock() + + keys := make([]es256kJWK, len(r.cachedKeys)) + copy(keys, r.cachedKeys) + return keys +} + +func (r *es256kRemoteKeySet) keysFromRemote(ctx context.Context) ([]es256kJWK, error) { + req, err := http.NewRequest(http.MethodGet, r.jwksURL, nil) + if err != nil { + return nil, fmt.Errorf("oidc: can't create request: %w", err) + } + + client := http.DefaultClient + if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok && c != nil { + client = c + } + + resp, err := client.Do(req.WithContext(ctx)) + if err != nil { + return nil, fmt.Errorf("oidc: get keys failed %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("unable to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("oidc: get keys failed: %s %s", resp.Status, body) + } + + var keySet es256kJWKSet + if err := json.Unmarshal(body, &keySet); err != nil { + return nil, fmt.Errorf("oidc: failed to decode keys: %w %s", err, body) + } + + r.mu.Lock() + r.cachedKeys = keySet.Keys + r.mu.Unlock() + + return keySet.Keys, nil +} + +func (r *es256kRemoteKeySet) verifyWithKeys(keys []es256kJWK, header *jwtHeader, signingInput string, signature []byte) bool { + for _, key := range keys { + if !key.matches(header.KeyID) { + continue + } + + publicKey, err := key.publicKey() + if err != nil { + continue + } + + digest := sha256.Sum256([]byte(signingInput)) + r := new(big.Int).SetBytes(signature[:32]) + s := new(big.Int).SetBytes(signature[32:]) + if ecdsa.Verify(publicKey, digest[:], r, s) { + return true + } + } + + return false +} + +func (k es256kJWK) matches(keyID string) bool { + if keyID != "" && k.KeyID != keyID { + return false + } + if k.KeyType != "EC" || k.Curve != "secp256k1" { + return false + } + if k.Algorithm != "" && k.Algorithm != es256kAlgorithm { + return false + } + if k.Use != "" && k.Use != "sig" { + return false + } + return true +} + +func (k es256kJWK) publicKey() (*ecdsa.PublicKey, error) { + xBytes, err := base64.RawURLEncoding.DecodeString(k.X) + if err != nil { + return nil, fmt.Errorf("oidc: malformed ES256K key x coordinate: %w", err) + } + yBytes, err := base64.RawURLEncoding.DecodeString(k.Y) + if err != nil { + return nil, fmt.Errorf("oidc: malformed ES256K key y coordinate: %w", err) + } + if len(xBytes) != 32 || len(yBytes) != 32 { + return nil, fmt.Errorf("oidc: malformed ES256K key coordinate lengths %d/%d", len(xBytes), len(yBytes)) + } + + x := new(big.Int).SetBytes(xBytes) + y := new(big.Int).SetBytes(yBytes) + curve := secp256k1.S256() + if !curve.IsOnCurve(x, y) { + return nil, errors.New("oidc: ES256K key is not on secp256k1 curve") + } + + return &ecdsa.PublicKey{Curve: curve, X: x, Y: y}, nil +} diff --git a/internal/api/provider/es256k_keyset_test.go b/internal/api/provider/es256k_keyset_test.go new file mode 100644 index 0000000000..ba8ac2e12e --- /dev/null +++ b/internal/api/provider/es256k_keyset_test.go @@ -0,0 +1,251 @@ +package provider + +import ( + "context" + "crypto/ecdsa" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "math/big" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/stretchr/testify/require" +) + +func TestParseIDTokenES256K(t *testing.T) { + privateKey, err := secp256k1.GeneratePrivateKey() + require.NoError(t, err) + + accessToken := "telegram-access-token" + accessTokenHash := sha256.Sum256([]byte(accessToken)) + + issuer, jwksURI, oidcServer := newES256KOIDCTestServer(t, privateKey.PubKey().ToECDSA(), "telegram-key") + defer oidcServer.Close() + + idToken := mustSignES256KJWT(t, privateKey.ToECDSA(), map[string]any{ + "alg": es256kAlgorithm, + "kid": "telegram-key", + "typ": "JWT", + }, map[string]any{ + "iss": issuer, + "sub": "telegram-user", + "aud": "telegram-client", + "email": "telegram@example.com", + "email_verified": true, + "iat": time.Now().Add(-time.Minute).Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + "at_hash": base64.RawURLEncoding.EncodeToString(accessTokenHash[:len(accessTokenHash)/2]), + }) + + oidcProvider, err := oidc.NewProvider(context.Background(), issuer) + require.NoError(t, err) + + var claims oidcDiscoveryClaims + require.NoError(t, oidcProvider.Claims(&claims)) + require.Equal(t, jwksURI, claims.JWKSURI) + + token, user, err := ParseIDToken(context.Background(), oidcProvider, &oidc.Config{ + ClientID: "telegram-client", + }, idToken, ParseIDTokenOptions{ + AccessToken: accessToken, + }) + require.NoError(t, err) + require.Equal(t, issuer, token.Issuer) + require.Equal(t, "telegram-user", token.Subject) + require.Len(t, user.Emails, 1) + require.Equal(t, "telegram@example.com", user.Emails[0].Email) + require.True(t, user.Emails[0].Verified) +} + +func TestParseIDTokenES256KRequiresAdvertisedAlgorithm(t *testing.T) { + privateKey, err := secp256k1.GeneratePrivateKey() + require.NoError(t, err) + + issuer, _, oidcServer := newES256KOIDCTestServerWithAlgorithms(t, privateKey.PubKey().ToECDSA(), "telegram-key", nil) + defer oidcServer.Close() + + idToken := mustSignES256KJWT(t, privateKey.ToECDSA(), map[string]any{ + "alg": es256kAlgorithm, + "kid": "telegram-key", + }, map[string]any{ + "iss": issuer, + "sub": "telegram-user", + "aud": "telegram-client", + "iat": time.Now().Add(-time.Minute).Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + }) + + oidcProvider, err := oidc.NewProvider(context.Background(), issuer) + require.NoError(t, err) + + _, _, err = ParseIDToken(context.Background(), oidcProvider, &oidc.Config{ + ClientID: "telegram-client", + }, idToken, ParseIDTokenOptions{SkipAccessTokenCheck: true}) + require.Error(t, err) +} + +func TestParseIDTokenES256KRejectsInvalidAccessTokenHash(t *testing.T) { + privateKey, err := secp256k1.GeneratePrivateKey() + require.NoError(t, err) + + issuer, _, oidcServer := newES256KOIDCTestServer(t, privateKey.PubKey().ToECDSA(), "telegram-key") + defer oidcServer.Close() + + idToken := mustSignES256KJWT(t, privateKey.ToECDSA(), map[string]any{ + "alg": es256kAlgorithm, + "kid": "telegram-key", + }, map[string]any{ + "iss": issuer, + "sub": "telegram-user", + "aud": "telegram-client", + "iat": time.Now().Add(-time.Minute).Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + "at_hash": "invalid-hash", + }) + + oidcProvider, err := oidc.NewProvider(context.Background(), issuer) + require.NoError(t, err) + + _, _, err = ParseIDToken(context.Background(), oidcProvider, &oidc.Config{ + ClientID: "telegram-client", + }, idToken, ParseIDTokenOptions{AccessToken: "telegram-access-token"}) + require.Error(t, err) +} + +func TestES256KRemoteKeySetRejectsBadSignature(t *testing.T) { + privateKey, err := secp256k1.GeneratePrivateKey() + require.NoError(t, err) + + _, jwksURI, oidcServer := newES256KOIDCTestServer(t, privateKey.PubKey().ToECDSA(), "telegram-key") + defer oidcServer.Close() + + token := mustSignES256KJWT(t, privateKey.ToECDSA(), map[string]any{ + "alg": es256kAlgorithm, + "kid": "telegram-key", + }, map[string]any{"sub": "telegram-user"}) + + parts := strings.Split(token, ".") + require.Len(t, parts, 3) + parts[2] = base64.RawURLEncoding.EncodeToString(make([]byte, 64)) + + _, err = newES256KRemoteKeySet(jwksURI).VerifySignature(context.Background(), strings.Join(parts, ".")) + require.Error(t, err) +} + +func TestES256KRemoteKeySetRejectsUnknownKeyID(t *testing.T) { + privateKey, err := secp256k1.GeneratePrivateKey() + require.NoError(t, err) + + _, jwksURI, oidcServer := newES256KOIDCTestServer(t, privateKey.PubKey().ToECDSA(), "other-key") + defer oidcServer.Close() + + token := mustSignES256KJWT(t, privateKey.ToECDSA(), map[string]any{ + "alg": es256kAlgorithm, + "kid": "telegram-key", + }, map[string]any{"sub": "telegram-user"}) + + _, err = newES256KRemoteKeySet(jwksURI).VerifySignature(context.Background(), token) + require.Error(t, err) +} + +func TestES256KRemoteKeySetRejectsInvalidCurvePoint(t *testing.T) { + keySet := es256kJWKSet{Keys: []es256kJWK{{ + KeyType: "EC", + Curve: "secp256k1", + KeyID: "telegram-key", + Algorithm: es256kAlgorithm, + Use: "sig", + X: base64.RawURLEncoding.EncodeToString(make([]byte, 32)), + Y: base64.RawURLEncoding.EncodeToString(make([]byte, 32)), + }}} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.NoError(t, json.NewEncoder(w).Encode(keySet)) + })) + defer server.Close() + + privateKey, err := secp256k1.GeneratePrivateKey() + require.NoError(t, err) + token := mustSignES256KJWT(t, privateKey.ToECDSA(), map[string]any{ + "alg": es256kAlgorithm, + "kid": "telegram-key", + }, map[string]any{"sub": "telegram-user"}) + + _, err = newES256KRemoteKeySet(server.URL).VerifySignature(context.Background(), token) + require.Error(t, err) +} + +func newES256KOIDCTestServer(t *testing.T, publicKey *ecdsa.PublicKey, keyID string) (string, string, *httptest.Server) { + return newES256KOIDCTestServerWithAlgorithms(t, publicKey, keyID, []string{es256kAlgorithm}) +} + +func newES256KOIDCTestServerWithAlgorithms(t *testing.T, publicKey *ecdsa.PublicKey, keyID string, algorithms []string) (string, string, *httptest.Server) { + t.Helper() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + issuer := "http://" + r.Host + jwksURI := issuer + "/jwks" + + switch r.URL.Path { + case "/.well-known/openid-configuration": + discovery := map[string]any{ + "issuer": issuer, + "authorization_endpoint": issuer + "/authorize", + "token_endpoint": issuer + "/token", + "jwks_uri": jwksURI, + } + if algorithms != nil { + discovery["id_token_signing_alg_values_supported"] = algorithms + } + require.NoError(t, json.NewEncoder(w).Encode(discovery)) + case "/jwks": + require.NoError(t, json.NewEncoder(w).Encode(es256kJWKSet{Keys: []es256kJWK{{ + KeyType: "EC", + Curve: "secp256k1", + KeyID: keyID, + Algorithm: es256kAlgorithm, + Use: "sig", + X: base64.RawURLEncoding.EncodeToString(padBigInt(publicKey.X, 32)), + Y: base64.RawURLEncoding.EncodeToString(padBigInt(publicKey.Y, 32)), + }}})) + default: + http.NotFound(w, r) + } + })) + + return server.URL, server.URL + "/jwks", server +} + +func mustSignES256KJWT(t *testing.T, privateKey *ecdsa.PrivateKey, header, claims map[string]any) string { + t.Helper() + + headerBytes, err := json.Marshal(header) + require.NoError(t, err) + claimsBytes, err := json.Marshal(claims) + require.NoError(t, err) + + signingInput := base64.RawURLEncoding.EncodeToString(headerBytes) + "." + base64.RawURLEncoding.EncodeToString(claimsBytes) + digest := sha256.Sum256([]byte(signingInput)) + r, s, err := ecdsa.Sign(rand.Reader, privateKey, digest[:]) + require.NoError(t, err) + + signature := append(padBigInt(r, 32), padBigInt(s, 32)...) + return signingInput + "." + base64.RawURLEncoding.EncodeToString(signature) +} + +func padBigInt(value *big.Int, size int) []byte { + bytes := value.Bytes() + if len(bytes) >= size { + return bytes + } + padded := make([]byte, size) + copy(padded[size-len(bytes):], bytes) + return padded +} diff --git a/internal/api/provider/oidc.go b/internal/api/provider/oidc.go index daadf5456e..fedb7d1c2f 100644 --- a/internal/api/provider/oidc.go +++ b/internal/api/provider/oidc.go @@ -2,6 +2,7 @@ package provider import ( "context" + "errors" "strconv" "strings" "time" @@ -39,6 +40,26 @@ func ParseIDToken(ctx context.Context, provider *oidc.Provider, config *oidc.Con } verifier := provider.Verifier(config) + header, headerErr := parseJWTHeader(idToken) + isES256K := headerErr == nil && header.Algorithm == es256kAlgorithm + if isES256K { + var claims oidcDiscoveryClaims + if err := provider.Claims(&claims); err != nil { + return nil, nil, err + } + if claims.Issuer == "" || claims.JWKSURI == "" { + return nil, nil, errors.New("oidc: missing issuer or jwks_uri in provider metadata") + } + if !supportsES256K(claims, config) { + return nil, nil, errors.New("oidc: ES256K is not listed as a supported signing algorithm") + } + + clonedConfig := *config + if len(clonedConfig.SupportedSigningAlgs) == 0 { + clonedConfig.SupportedSigningAlgs = []string{es256kAlgorithm} + } + verifier = oidc.NewVerifier(claims.Issuer, getES256KRemoteKeySet(claims.JWKSURI), &clonedConfig) + } overrideVerifier, ok := OverrideVerifiers[provider.Endpoint().AuthURL] if ok && overrideVerifier != nil { verifier = overrideVerifier(ctx, config) @@ -80,7 +101,13 @@ func ParseIDToken(ctx context.Context, provider *oidc.Provider, config *oidc.Con } if !options.SkipAccessTokenCheck && token.AccessTokenHash != "" { - if err := token.VerifyAccessToken(options.AccessToken); err != nil { + var err error + if isES256K { + err = verifySHA256AccessTokenHash(token.AccessTokenHash, options.AccessToken) + } else { + err = token.VerifyAccessToken(options.AccessToken) + } + if err != nil { return nil, nil, err } }