From 611e8c983f8f245729e9a77e8cfc6344dbec9a6c Mon Sep 17 00:00:00 2001 From: Sourav Basu Date: Mon, 1 Jun 2026 19:07:27 +0530 Subject: [PATCH 01/12] Changes for passkey implementation --- .../auth_schemes/__init__.py | 3 +- .../auth_schemes/dpop_auth.py | 51 + .../auth_server/my_account_client.py | 464 ++++++++- .../auth_server/server_client.py | 967 +++++++++++------- .../auth_types/__init__.py | 274 ++++- .../tests/test_dpop_auth.py | 145 +++ .../tests/test_passkey_my_account.py | 473 +++++++++ .../tests/test_passkey_server_client.py | 523 ++++++++++ 8 files changed, 2483 insertions(+), 417 deletions(-) create mode 100644 src/auth0_server_python/auth_schemes/dpop_auth.py create mode 100644 src/auth0_server_python/tests/test_dpop_auth.py create mode 100644 src/auth0_server_python/tests/test_passkey_my_account.py create mode 100644 src/auth0_server_python/tests/test_passkey_server_client.py diff --git a/src/auth0_server_python/auth_schemes/__init__.py b/src/auth0_server_python/auth_schemes/__init__.py index 1c2c869..ef37613 100644 --- a/src/auth0_server_python/auth_schemes/__init__.py +++ b/src/auth0_server_python/auth_schemes/__init__.py @@ -1,3 +1,4 @@ from .bearer_auth import BearerAuth +from .dpop_auth import DPoPAuth -__all__ = ["BearerAuth"] +__all__ = ["BearerAuth", "DPoPAuth"] diff --git a/src/auth0_server_python/auth_schemes/dpop_auth.py b/src/auth0_server_python/auth_schemes/dpop_auth.py new file mode 100644 index 0000000..1517a78 --- /dev/null +++ b/src/auth0_server_python/auth_schemes/dpop_auth.py @@ -0,0 +1,51 @@ +import base64 +import hashlib +import time +import uuid + +import httpx +from jwcrypto import jwk +from jwcrypto import jwt as jwcrypto_jwt + + +def _base64url(data: bytes) -> str: + return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") + + +class DPoPAuth(httpx.Auth): + def __init__(self, token: str, key: "jwk.JWK") -> None: + public_jwk = key.export_public(as_dict=True) + if public_jwk.get("kty") != "EC" or public_jwk.get("crv") != "P-256": + raise ValueError("DPoP key must be an EC P-256 key") + self._token = token + self._key = key + self._public_jwk = public_jwk + + def __repr__(self) -> str: + return "DPoPAuth(token=[REDACTED], key=[REDACTED])" + + def __str__(self) -> str: + return "DPoPAuth(token=[REDACTED], key=[REDACTED])" + + def auth_flow(self, request: httpx.Request): + proof = self._make_proof(request.method, str(request.url)) + request.headers["Authorization"] = f"DPoP {self._token}" + request.headers["DPoP"] = proof + yield request + + def _make_proof(self, method: str, url: str) -> str: + htu = url.split("?")[0].split("#")[0] + ath = _base64url(hashlib.sha256(self._token.encode("ascii")).digest()) + + header = {"typ": "dpop+jwt", "alg": "ES256", "jwk": self._public_jwk} + payload = { + "jti": str(uuid.uuid4()), + "htm": method.upper(), + "htu": htu, + "iat": int(time.time()), + "ath": ath, + } + + token = jwcrypto_jwt.JWT(header=header, claims=payload) + token.make_signed_token(self._key) + return token.serialize() diff --git a/src/auth0_server_python/auth_server/my_account_client.py b/src/auth0_server_python/auth_server/my_account_client.py index 499b981..a6aed8f 100644 --- a/src/auth0_server_python/auth_server/my_account_client.py +++ b/src/auth0_server_python/auth_server/my_account_client.py @@ -1,16 +1,26 @@ - -from typing import Optional +import json +from typing import TYPE_CHECKING, Optional +from urllib.parse import quote import httpx +from pydantic import ValidationError from auth0_server_python.auth_schemes.bearer_auth import BearerAuth +from auth0_server_python.auth_schemes.dpop_auth import DPoPAuth from auth0_server_python.auth_types import ( + AuthenticationMethod, CompleteConnectAccountRequest, CompleteConnectAccountResponse, ConnectAccountRequest, ConnectAccountResponse, + EnrollAuthenticationMethodRequest, + EnrollmentChallengeResponse, + GetFactorsResponse, + ListAuthenticationMethodsResponse, ListConnectedAccountConnectionsResponse, ListConnectedAccountsResponse, + UpdateAuthenticationMethodRequest, + VerifyAuthenticationMethodRequest, ) from auth0_server_python.error import ( ApiError, @@ -19,6 +29,18 @@ MyAccountApiError, ) +if TYPE_CHECKING: + from jwcrypto import jwk + + +def _make_auth( + access_token: str, + dpop_key: Optional["jwk.JWK"] = None, +) -> httpx.Auth: + if dpop_key is not None: + return DPoPAuth(access_token, dpop_key) + return BearerAuth(access_token) + class MyAccountClient: """ @@ -52,9 +74,7 @@ def audience(self): return f"https://{self._domain}/me/" async def connect_account( - self, - access_token: str, - request: ConnectAccountRequest + self, access_token: str, request: ConnectAccountRequest ) -> ConnectAccountResponse: """ Initiate the connected account flow. @@ -75,7 +95,7 @@ async def connect_account( response = await client.post( url=f"{self.audience}v1/connected-accounts/connect", json=request.model_dump(exclude_none=True), - auth=BearerAuth(access_token) + auth=BearerAuth(access_token), ) if response.status_code != 201: @@ -85,7 +105,7 @@ async def connect_account( type=error_data.get("type", None), detail=error_data.get("detail", None), status=error_data.get("status", None), - validation_errors=error_data.get("validation_errors", None) + validation_errors=error_data.get("validation_errors", None), ) data = response.json() @@ -98,13 +118,11 @@ async def connect_account( raise ApiError( "connect_account_error", f"Connected Accounts connect request failed: {str(e) or 'Unknown error'}", - e + e, ) async def complete_connect_account( - self, - access_token: str, - request: CompleteConnectAccountRequest + self, access_token: str, request: CompleteConnectAccountRequest ) -> CompleteConnectAccountResponse: """ Complete the connected account flow after user authorization. @@ -125,7 +143,7 @@ async def complete_connect_account( response = await client.post( url=f"{self.audience}v1/connected-accounts/complete", json=request.model_dump(exclude_none=True), - auth=BearerAuth(access_token) + auth=BearerAuth(access_token), ) if response.status_code != 201: @@ -135,7 +153,7 @@ async def complete_connect_account( type=error_data.get("type", None), detail=error_data.get("detail", None), status=error_data.get("status", None), - validation_errors=error_data.get("validation_errors", None) + validation_errors=error_data.get("validation_errors", None), ) data = response.json() @@ -148,7 +166,7 @@ async def complete_connect_account( raise ApiError( "connect_account_error", f"Connected Accounts complete request failed: {str(e) or 'Unknown error'}", - e + e, ) async def list_connected_accounts( @@ -156,7 +174,7 @@ async def list_connected_accounts( access_token: str, connection: Optional[str] = None, from_param: Optional[str] = None, - take: Optional[int] = None + take: Optional[int] = None, ) -> ListConnectedAccountsResponse: """ List connected accounts for the authenticated user. @@ -195,7 +213,7 @@ async def list_connected_accounts( response = await client.get( url=f"{self.audience}v1/connected-accounts/accounts", params=params, - auth=BearerAuth(access_token) + auth=BearerAuth(access_token), ) if response.status_code != 200: @@ -205,7 +223,7 @@ async def list_connected_accounts( type=error_data.get("type", None), detail=error_data.get("detail", None), status=error_data.get("status", None), - validation_errors=error_data.get("validation_errors", None) + validation_errors=error_data.get("validation_errors", None), ) data = response.json() @@ -218,15 +236,10 @@ async def list_connected_accounts( raise ApiError( "connect_account_error", f"Connected Accounts list request failed: {str(e) or 'Unknown error'}", - e + e, ) - - async def delete_connected_account( - self, - access_token: str, - connected_account_id: str - ) -> None: + async def delete_connected_account(self, access_token: str, connected_account_id: str) -> None: """ Delete a connected account for the authenticated user. @@ -253,7 +266,7 @@ async def delete_connected_account( async with self._get_http_client() as client: response = await client.delete( url=f"{self.audience}v1/connected-accounts/accounts/{connected_account_id}", - auth=BearerAuth(access_token) + auth=BearerAuth(access_token), ) if response.status_code != 204: @@ -263,7 +276,7 @@ async def delete_connected_account( type=error_data.get("type", None), detail=error_data.get("detail", None), status=error_data.get("status", None), - validation_errors=error_data.get("validation_errors", None) + validation_errors=error_data.get("validation_errors", None), ) except Exception as e: @@ -272,14 +285,11 @@ async def delete_connected_account( raise ApiError( "connect_account_error", f"Connected Accounts delete request failed: {str(e) or 'Unknown error'}", - e + e, ) async def list_connected_account_connections( - self, - access_token: str, - from_param: Optional[str] = None, - take: Optional[int] = None + self, access_token: str, from_param: Optional[str] = None, take: Optional[int] = None ) -> ListConnectedAccountConnectionsResponse: """ List available connections that support connected accounts. @@ -315,7 +325,7 @@ async def list_connected_account_connections( response = await client.get( url=f"{self.audience}v1/connected-accounts/connections", params=params, - auth=BearerAuth(access_token) + auth=BearerAuth(access_token), ) if response.status_code != 200: @@ -325,7 +335,7 @@ async def list_connected_account_connections( type=error_data.get("type", None), detail=error_data.get("detail", None), status=error_data.get("status", None), - validation_errors=error_data.get("validation_errors", None) + validation_errors=error_data.get("validation_errors", None), ) data = response.json() @@ -338,5 +348,391 @@ async def list_connected_account_connections( raise ApiError( "connect_account_error", f"Connected Accounts list connections request failed: {str(e) or 'Unknown error'}", - e + e, + ) + + async def get_factors( + self, + access_token: str, + dpop_key: Optional["jwk.JWK"] = None, + ) -> GetFactorsResponse: + """ + Retrieve the list of factors available for enrollment. + + Args: + access_token: User's access token (scope: read:me:factors). + dpop_key: Optional EC P-256 key for DPoP-bound token presentation. + + Returns: + GetFactorsResponse containing the available factors. + + Raises: + MissingRequiredArgumentError: If access_token is not provided. + MyAccountApiError: If the API returns an error response. + ApiError: If the request fails due to network or other issues. + """ + if not access_token: + raise MissingRequiredArgumentError("access_token") + + try: + async with self._get_http_client() as client: + response = await client.get( + url=f"{self.audience}v1/factors", + auth=_make_auth(access_token, dpop_key), + ) + + if response.status_code != 200: + try: + error_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "get_factors_error", + f"Get factors failed with status {response.status_code}", + ) + raise MyAccountApiError( + title=error_data.get("title", None), + type=error_data.get("type", None), + detail=error_data.get("detail", None), + status=error_data.get("status", None), + validation_errors=error_data.get("validation_errors", None), + ) + + return GetFactorsResponse.model_validate(response.json()) + + except Exception as e: + if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + raise + raise ApiError( + "get_factors_error", + "Get factors request failed", + e, + ) + + async def list_authentication_methods( + self, + access_token: str, + type_filter: Optional[str] = None, + dpop_key: Optional["jwk.JWK"] = None, + ) -> ListAuthenticationMethodsResponse: + if not access_token: + raise MissingRequiredArgumentError("access_token") + + try: + async with self._get_http_client() as client: + params = {} + if type_filter: + params["type"] = type_filter + + response = await client.get( + url=f"{self.audience}v1/authentication-methods", + params=params, + auth=_make_auth(access_token, dpop_key), + ) + + if response.status_code != 200: + try: + error_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "list_authentication_methods_error", + f"List authentication methods failed with status {response.status_code}", + ) + raise MyAccountApiError( + title=error_data.get("title", None), + type=error_data.get("type", None), + detail=error_data.get("detail", None), + status=error_data.get("status", None), + validation_errors=error_data.get("validation_errors", None), + ) + + return ListAuthenticationMethodsResponse.model_validate(response.json()) + + except Exception as e: + if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + raise + raise ApiError( + "list_authentication_methods_error", + "List authentication methods request failed", + e, + ) + + async def get_authentication_method( + self, + access_token: str, + authentication_method_id: str, + dpop_key: Optional["jwk.JWK"] = None, + ) -> AuthenticationMethod: + if not access_token: + raise MissingRequiredArgumentError("access_token") + if not authentication_method_id: + raise MissingRequiredArgumentError("authentication_method_id") + + try: + async with self._get_http_client() as client: + response = await client.get( + url=f"{self.audience}v1/authentication-methods/{quote(authentication_method_id, safe='')}", + auth=_make_auth(access_token, dpop_key), + ) + + if response.status_code != 200: + try: + error_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "get_authentication_method_error", + f"Get authentication method failed with status {response.status_code}", + ) + raise MyAccountApiError( + title=error_data.get("title", None), + type=error_data.get("type", None), + detail=error_data.get("detail", None), + status=error_data.get("status", None), + validation_errors=error_data.get("validation_errors", None), + ) + + return AuthenticationMethod.model_validate(response.json()) + + except Exception as e: + if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + raise + raise ApiError( + "get_authentication_method_error", + "Get authentication method request failed", + e, + ) + + async def delete_authentication_method( + self, + access_token: str, + authentication_method_id: str, + dpop_key: Optional["jwk.JWK"] = None, + ) -> None: + if not access_token: + raise MissingRequiredArgumentError("access_token") + if not authentication_method_id: + raise MissingRequiredArgumentError("authentication_method_id") + + try: + async with self._get_http_client() as client: + response = await client.delete( + url=f"{self.audience}v1/authentication-methods/{quote(authentication_method_id, safe='')}", + auth=_make_auth(access_token, dpop_key), + ) + + if response.status_code != 204: + try: + error_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "delete_authentication_method_error", + f"Delete authentication method failed with status {response.status_code}", + ) + raise MyAccountApiError( + title=error_data.get("title", None), + type=error_data.get("type", None), + detail=error_data.get("detail", None), + status=error_data.get("status", None), + validation_errors=error_data.get("validation_errors", None), + ) + + except Exception as e: + if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + raise + raise ApiError( + "delete_authentication_method_error", + "Delete authentication method request failed", + e, + ) + + async def update_authentication_method( + self, + access_token: str, + authentication_method_id: str, + request: UpdateAuthenticationMethodRequest, + dpop_key: Optional["jwk.JWK"] = None, + ) -> AuthenticationMethod: + if not access_token: + raise MissingRequiredArgumentError("access_token") + if not authentication_method_id: + raise MissingRequiredArgumentError("authentication_method_id") + if request is None: + raise MissingRequiredArgumentError("request") + + try: + async with self._get_http_client() as client: + response = await client.patch( + url=f"{self.audience}v1/authentication-methods/{quote(authentication_method_id, safe='')}", + json=request.model_dump(exclude_none=True), + auth=_make_auth(access_token, dpop_key), + ) + + if response.status_code != 200: + try: + error_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "update_authentication_method_error", + f"Update authentication method failed with status {response.status_code}", + ) + raise MyAccountApiError( + title=error_data.get("title", None), + type=error_data.get("type", None), + detail=error_data.get("detail", None), + status=error_data.get("status", None), + validation_errors=error_data.get("validation_errors", None), + ) + + return AuthenticationMethod.model_validate(response.json()) + + except Exception as e: + if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + raise + raise ApiError( + "update_authentication_method_error", + "Update authentication method request failed", + e, + ) + + async def enroll_authentication_method( + self, + access_token: str, + request: EnrollAuthenticationMethodRequest, + dpop_key: Optional["jwk.JWK"] = None, + ) -> EnrollmentChallengeResponse: + """Step 1 of 2: Start enrollment (POST /me/v1/authentication-methods). + + For passkey enrollment, pass the returned authn_params_public_key to + navigator.credentials.create(), then call verify_authentication_method() + with the auth_session and credential result. + + Requires scope: create:me:authentication_methods + """ + if not access_token: + raise MissingRequiredArgumentError("access_token") + if request is None: + raise MissingRequiredArgumentError("request") + + try: + async with self._get_http_client() as client: + response = await client.post( + url=f"{self.audience}v1/authentication-methods", + json=request.model_dump(exclude_none=True), + auth=_make_auth(access_token, dpop_key), + ) + + if response.status_code != 201: + try: + error_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "enroll_authentication_method_error", + f"Enroll authentication method failed with status {response.status_code}", + ) + raise MyAccountApiError( + title=error_data.get("title", None), + type=error_data.get("type", None), + detail=error_data.get("detail", None), + status=error_data.get("status", None), + validation_errors=error_data.get("validation_errors", None), + ) + + location = response.headers.get("location") + if not location: + raise ApiError( + "enroll_authentication_method_error", + "Enrollment succeeded (201) but Location header is missing", + ) + + authentication_method_id = ( + location.split("?")[0].split("#")[0].rstrip("/").split("/")[-1] + ) + if not authentication_method_id: + raise ApiError( + "enroll_authentication_method_error", + "Enrollment succeeded (201) but could not extract ID from Location header", + ) + + try: + data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "enroll_authentication_method_error", + "Enrollment succeeded (201) but response body is not valid JSON", + ) + + auth_session = data.get("auth_session") + if not auth_session: + raise ApiError( + "enroll_authentication_method_error", + "Enrollment succeeded (201) but auth_session is missing from response", + ) + + return EnrollmentChallengeResponse.model_validate( + { + "authentication_method_id": authentication_method_id, + "auth_session": auth_session, + "authn_params_public_key": data.get("authn_params_public_key"), + } + ) + + except Exception as e: + if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + raise + raise ApiError( + "enroll_authentication_method_error", + "Enroll authentication method request failed", + e, + ) + + async def verify_authentication_method( + self, + access_token: str, + authentication_method_id: str, + request: VerifyAuthenticationMethodRequest, + dpop_key: Optional["jwk.JWK"] = None, + ) -> AuthenticationMethod: + """Step 2 of 2: Verify enrollment (POST /me/v1/authentication-methods/{id}/verify). + + Requires scope: create:me:authentication_methods + """ + if not access_token: + raise MissingRequiredArgumentError("access_token") + if not authentication_method_id: + raise MissingRequiredArgumentError("authentication_method_id") + if request is None: + raise MissingRequiredArgumentError("request") + + try: + async with self._get_http_client() as client: + response = await client.post( + url=f"{self.audience}v1/authentication-methods/{quote(authentication_method_id, safe='')}/verify", + json=request.model_dump(by_alias=True, exclude_none=True), + auth=_make_auth(access_token, dpop_key), + ) + + if response.status_code != 200: + try: + error_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "verify_authentication_method_error", + f"Verify authentication method failed with status {response.status_code}", + ) + raise MyAccountApiError( + title=error_data.get("title", None), + type=error_data.get("type", None), + detail=error_data.get("detail", None), + status=error_data.get("status", None), + validation_errors=error_data.get("validation_errors", None), + ) + + return AuthenticationMethod.model_validate(response.json()) + + except Exception as e: + if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + raise + raise ApiError( + "verify_authentication_method_error", + "Verify authentication method request failed", + e, ) diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 91de45d..334eb00 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -31,6 +31,10 @@ LogoutOptions, LogoutTokenClaims, MfaRequirements, + PasskeyAuthResponse, + PasskeyLoginChallengeResponse, + PasskeySignupChallengeResponse, + PasskeyTokenResponse, StartInteractiveLoginOptions, StateData, TokenExchangeResponse, @@ -65,11 +69,18 @@ ) # Generic type for store options -TStoreOptions = TypeVar('TStoreOptions') +TStoreOptions = TypeVar("TStoreOptions") # redirect_uri is intentionally excluded — in MCD mode it is built # dynamically from the resolved domain at login time. -INTERNAL_AUTHORIZE_PARAMS = ["client_id", "response_type", - "code_challenge", "code_challenge_method", "state", "nonce", "scope"] +INTERNAL_AUTHORIZE_PARAMS = [ + "client_id", + "response_type", + "code_challenge", + "code_challenge_method", + "state", + "nonce", + "scope", +] class ServerClient(Generic[TStoreOptions]): @@ -77,6 +88,7 @@ class ServerClient(Generic[TStoreOptions]): Main client for Auth0 server SDK. Handles authentication flows, session management, and token operations using Authlib for OIDC functionality. """ + DEFAULT_AUDIENCE_STATE_KEY = "default" # ============================================================================ @@ -117,9 +129,7 @@ def __init__( raise MissingRequiredArgumentError("secret") if domain is None: - raise ConfigurationError( - "Domain is required" - ) + raise ConfigurationError("Domain is required") # Validate domain type if not isinstance(domain, str) and not callable(domain): @@ -164,14 +174,12 @@ def __init__( headers=self._telemetry_headers, ) - self._my_account_client = MyAccountClient( - domain=domain, headers=self._telemetry_headers - ) + self._my_account_client = MyAccountClient(domain=domain, headers=self._telemetry_headers) # Unified cache for OIDC metadata and JWKS per domain (LRU eviction + TTL) self._discovery_cache: OrderedDict[str, dict] = OrderedDict() - self._cache_ttl = 600 # 10 mins. TTL - self._cache_max_entries = 100 # Max 100 domains + self._cache_ttl = 600 # 10 mins. TTL + self._cache_max_entries = 100 # Max 100 domains # Initialize MFA client self._mfa_client = MfaClient( @@ -198,14 +206,14 @@ def _normalize_url(self, value: str) -> str: return value value = value.lower() - if value.startswith('https://'): + if value.startswith("https://"): pass - elif value.startswith('http://'): - value = value.replace('http://', 'https://') + elif value.startswith("http://"): + value = value.replace("http://", "https://") else: - value = f'https://{value}' + value = f"https://{value}" - return value.rstrip('/') + return value.rstrip("/") async def _resolve_current_domain(self, store_options=None) -> str: """Resolve domain from resolver function or return static domain.""" @@ -218,8 +226,7 @@ async def _resolve_current_domain(self, store_options=None) -> str: raise except Exception as e: raise DomainResolverError( - f"Domain resolver function raised an exception: {str(e)}", - original_error=e + f"Domain resolver function raised an exception: {str(e)}", original_error=e ) return self._domain @@ -233,18 +240,18 @@ def _get_session_domain(self, state_data_dict: dict) -> Optional[str]: 2. self._domain — static domain (if configured) 3. Extract hostname from user.iss — derive from user's issuer claim """ - domain = state_data_dict.get('domain') + domain = state_data_dict.get("domain") if domain: return domain if self._domain: return self._domain - user = state_data_dict.get('user') + user = state_data_dict.get("user") if isinstance(user, dict): - iss = user.get('iss') + iss = user.get("iss") else: - iss = getattr(user, 'iss', None) if user else None + iss = getattr(user, "iss", None) if user else None if iss: parsed = urlparse(iss) @@ -347,7 +354,7 @@ async def _get_oidc_metadata_cached(self, domain: str) -> dict: self._discovery_cache[domain] = { "metadata": metadata, "jwks": None, - "expires_at": now + self._cache_ttl + "expires_at": now + self._cache_ttl, } return metadata @@ -409,11 +416,11 @@ async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: if not metadata: metadata = await self._get_oidc_metadata_cached(domain) - jwks_uri = metadata.get('jwks_uri') + jwks_uri = metadata.get("jwks_uri") if not jwks_uri: raise ApiError( "missing_jwks_uri", - f"OIDC metadata for {domain} does not contain jwks_uri. Provider may be non-RFC-compliant." + f"OIDC metadata for {domain} does not contain jwks_uri. Provider may be non-RFC-compliant.", ) # Fetch JWKS @@ -430,7 +437,7 @@ async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: self._discovery_cache[domain] = { "metadata": metadata, "jwks": jwks, - "expires_at": now + self._cache_ttl + "expires_at": now + self._cache_ttl, } return jwks @@ -442,9 +449,7 @@ async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: # ============================================================================ async def start_interactive_login( - self, - options: Optional[StartInteractiveLoginOptions] = None, - store_options: dict = None + self, options: Optional[StartInteractiveLoginOptions] = None, store_options: dict = None ) -> str: """ Starts the interactive login process and returns a URL to redirect to. @@ -465,15 +470,17 @@ async def start_interactive_login( try: metadata = await self._get_oidc_metadata_cached(origin_domain) except Exception as e: - raise ApiError("metadata_error", - "Failed to fetch OIDC metadata", e) + raise ApiError("metadata_error", "Failed to fetch OIDC metadata", e) # Get effective authorization params (merge defaults with provided ones) auth_params = dict(self._default_authorization_params) if options.authorization_params: auth_params.update( - {k: v for k, v in options.authorization_params.items( - ) if k not in INTERNAL_AUTHORIZE_PARAMS} + { + k: v + for k, v in options.authorization_params.items() + if k not in INTERNAL_AUTHORIZE_PARAMS + } ) # Ensure we have a redirect_uri @@ -497,7 +504,11 @@ async def start_interactive_login( auth_params["state"] = state # Merge any requested scope with defaults - requested_scope = options.authorization_params.get("scope", None) if options.authorization_params else None + requested_scope = ( + options.authorization_params.get("scope", None) + if options.authorization_params + else None + ) audience = auth_params.get("audience", None) merged_scope = self._merge_scope_with_defaults(requested_scope, audience) auth_params["scope"] = merged_scope @@ -513,65 +524,61 @@ async def start_interactive_login( # Store the transaction data await self._transaction_store.set( - f"{self._transaction_identifier}:{state}", - transaction_data, - options=store_options + f"{self._transaction_identifier}:{state}", transaction_data, options=store_options ) # Set metadata for OAuth client self._oauth.metadata = metadata # If PAR is enabled, use the PAR endpoint if self._pushed_authorization_requests: - par_endpoint = self._oauth.metadata.get( - "pushed_authorization_request_endpoint") + par_endpoint = self._oauth.metadata.get("pushed_authorization_request_endpoint") if not par_endpoint: raise ApiError( - "configuration_error", "PAR is enabled but pushed_authorization_request_endpoint is missing in metadata") + "configuration_error", + "PAR is enabled but pushed_authorization_request_endpoint is missing in metadata", + ) auth_params["client_id"] = self._client_id # Post the auth_params to the PAR endpoint async with self._get_http_client() as client: par_response = await client.post( - par_endpoint, - data=auth_params, - auth=(self._client_id, self._client_secret) + par_endpoint, data=auth_params, auth=(self._client_id, self._client_secret) ) if par_response.status_code not in (200, 201): error_data = par_response.json() raise ApiError( error_data.get("error", "par_error"), error_data.get( - "error_description", "Failed to obtain request_uri from PAR endpoint") + "error_description", "Failed to obtain request_uri from PAR endpoint" + ), ) par_data = par_response.json() request_uri = par_data.get("request_uri") if not request_uri: - raise ApiError( - "par_error", "No request_uri returned from PAR endpoint") + raise ApiError("par_error", "No request_uri returned from PAR endpoint") auth_endpoint = self._oauth.metadata.get("authorization_endpoint") final_url = f"{auth_endpoint}?request_uri={request_uri}&response_type={auth_params['response_type']}&client_id={self._client_id}" return final_url else: if "authorization_endpoint" not in self._oauth.metadata: - raise ApiError("configuration_error", - "Authorization endpoint missing in OIDC metadata") + raise ApiError( + "configuration_error", "Authorization endpoint missing in OIDC metadata" + ) authorization_endpoint = self._oauth.metadata["authorization_endpoint"] try: auth_url, state = self._oauth.create_authorization_url( - authorization_endpoint, **auth_params) + authorization_endpoint, **auth_params + ) except Exception as e: - raise ApiError("authorization_url_error", - "Failed to create authorization URL", e) + raise ApiError("authorization_url_error", "Failed to create authorization URL", e) return auth_url async def complete_interactive_login( - self, - url: str, - store_options: dict = None + self, url: str, store_options: dict = None ) -> dict[str, Any]: """ Completes the login process after user is redirected back. @@ -594,7 +601,9 @@ async def complete_interactive_login( # Retrieve the transaction data using the state transaction_identifier = f"{self._transaction_identifier}:{state}" - transaction_data = await self._transaction_store.get(transaction_identifier, options=store_options) + transaction_data = await self._transaction_store.get( + transaction_identifier, options=store_options + ) if not transaction_data: raise MissingTransactionError() @@ -615,7 +624,7 @@ async def complete_interactive_login( # Fetch metadata and derive issuer from the origin domain metadata = await self._get_oidc_metadata_cached(origin_domain) - origin_issuer = metadata.get('issuer') + origin_issuer = metadata.get("issuer") self._oauth.metadata = metadata # Exchange the code for tokens @@ -631,8 +640,7 @@ async def complete_interactive_login( ) except OAuthError as e: # Raise a custom error (or handle it as appropriate) - raise ApiError( - "token_error", f"Token exchange failed: {str(e)}", e) + raise ApiError("token_error", f"Token exchange failed: {str(e)}", e) # Use the userinfo field from the token_response for user claims user_info = token_response.get("userinfo") @@ -647,14 +655,14 @@ async def complete_interactive_login( # Decode and verify ID token with signature verification enabled try: - claims = await self._verify_and_decode_jwt( - id_token, jwks, audience=self._client_id - ) + claims = await self._verify_and_decode_jwt(id_token, jwks, audience=self._client_id) # Custom normalized issuer validation token_issuer = claims.get("iss", "") if self._normalize_url(token_issuer) != self._normalize_url(origin_issuer): - raise IssuerValidationError("ID token issuer mismatch. Ensure your Auth0 domain is configured correctly.") + raise IssuerValidationError( + "ID token issuer mismatch. Ensure your Auth0 domain is configured correctly." + ) user_claims = UserClaims.parse_obj(claims) except ValueError as e: @@ -663,40 +671,33 @@ async def complete_interactive_login( raise ApiError( "invalid_signature", f"ID token signature verification failed. The token may have been tampered with or is from an untrusted source: {str(e)}", - e + e, ) except jwt.InvalidAudienceError as e: raise ApiError( "invalid_audience", f"ID token audience mismatch. Expected: {self._client_id}. Ensure your client_id is configured correctly: {str(e)}", - e + e, ) except jwt.ExpiredSignatureError as e: - raise ApiError( - "token_expired", - f"ID token has expired: {str(e)}", - e - ) + raise ApiError("token_expired", f"ID token has expired: {str(e)}", e) except jwt.InvalidTokenError as e: - raise ApiError( - "invalid_token", - f"ID token verification failed: {str(e)}", - e - ) - + raise ApiError("invalid_token", f"ID token verification failed: {str(e)}", e) # Build a token set using the token response data token_set = TokenSet( audience=transaction_data.audience or self.DEFAULT_AUDIENCE_STATE_KEY, access_token=token_response.get("access_token", ""), scope=token_response.get("scope", ""), - expires_at=int(time.time()) + - token_response.get("expires_in", 3600) + expires_at=int(time.time()) + token_response.get("expires_in", 3600), ) # Generate a session id (sid) from token_response or transaction data, or create a new one - sid = user_info.get( - "sid") if user_info and "sid" in user_info else PKCE.generate_random_string(32) + sid = ( + user_info.get("sid") + if user_info and "sid" in user_info + else PKCE.generate_random_string(32) + ) # Construct state data to represent the session state_data = StateData( @@ -706,10 +707,7 @@ async def complete_interactive_login( refresh_token=token_response.get("refresh_token"), token_sets=[token_set], domain=origin_domain, - internal={ - "sid": sid, - "created_at": int(time.time()) - } + internal={"sid": sid, "created_at": int(time.time())}, ) # Store the state data in the state store using store_options (Response required) @@ -734,7 +732,9 @@ async def complete_interactive_login( # Methods for retrieving user information, session data, and logout operations. # ============================================================================ - async def get_user(self, store_options: Optional[dict[str, Any]] = None) -> Optional[dict[str, Any]]: + async def get_user( + self, store_options: Optional[dict[str, Any]] = None + ) -> Optional[dict[str, Any]]: """ Retrieves the user from the store, or None if no user found. @@ -763,7 +763,9 @@ async def get_user(self, store_options: Optional[dict[str, Any]] = None) -> Opti return state_data.get("user") return None - async def get_session(self, store_options: Optional[dict[str, Any]] = None) -> Optional[dict[str, Any]]: + async def get_session( + self, store_options: Optional[dict[str, Any]] = None + ) -> Optional[dict[str, Any]]: """ Retrieve the user session from the store, or None if no session found. @@ -789,15 +791,14 @@ async def get_session(self, store_options: Optional[dict[str, Any]] = None) -> O if self._normalize_url(session_domain) != self._normalize_url(current_domain): return None - session_data = {k: v for k, v in state_data.items() - if k != "internal"} + session_data = {k: v for k, v in state_data.items() if k != "internal"} return session_data return None async def logout( self, options: Optional[LogoutOptions] = None, - store_options: Optional[dict[str, Any]] = None + store_options: Optional[dict[str, Any]] = None, ) -> str: options = options or LogoutOptions() @@ -813,19 +814,18 @@ async def logout( if hasattr(state_data, "dict") and callable(state_data.dict): state_data = state_data.dict() session_domain = self._get_session_domain(state_data) - if session_domain and self._normalize_url(session_domain) == self._normalize_url(domain): + if session_domain and self._normalize_url(session_domain) == self._normalize_url( + domain + ): await self._state_store.delete(self._state_identifier, store_options) # Return logout URL for the current resolved domain - logout_url = URL.create_logout_url( - domain, self._client_id, options.return_to) + logout_url = URL.create_logout_url(domain, self._client_id, options.return_to) return logout_url async def handle_backchannel_logout( - self, - logout_token: str, - store_options: Optional[dict[str, Any]] = None + self, logout_token: str, store_options: Optional[dict[str, Any]] = None ) -> None: """ Handles backchannel logout requests. @@ -846,8 +846,7 @@ async def handle_backchannel_logout( # Read iss from unverified token for comparison try: unverified = jwt.decode( - logout_token, algorithms=["RS256"], - options={"verify_signature": False} + logout_token, algorithms=["RS256"], options={"verify_signature": False} ) token_issuer = unverified.get("iss", "") except Exception as e: @@ -876,13 +875,16 @@ async def handle_backchannel_logout( jwks = await self._get_jwks_cached(domain) try: - claims = await self._verify_and_decode_jwt(logout_token, jwks, audience=self._client_id) + claims = await self._verify_and_decode_jwt( + logout_token, jwks, audience=self._client_id + ) # Normalized issuer validation token_issuer = claims.get("iss", "") expected_issuer = self._normalize_url(domain) if self._normalize_url(token_issuer) != self._normalize_url(expected_issuer): - raise IssuerValidationError("Logout token issuer mismatch.Ensure your Auth0 domain is configured correctly." + raise IssuerValidationError( + "Logout token issuer mismatch.Ensure your Auth0 domain is configured correctly." ) except ValueError as e: raise BackchannelLogoutError(str(e)) @@ -891,30 +893,22 @@ async def handle_backchannel_logout( f"Logout token signature verification failed: {str(e)}" ) except jwt.InvalidTokenError as e: - raise BackchannelLogoutError( - f"Logout token verification failed: {str(e)}" - ) + raise BackchannelLogoutError(f"Logout token verification failed: {str(e)}") # Validate the token is a logout token events = claims.get("events", {}) if "http://schemas.openid.net/event/backchannel-logout" not in events: - raise BackchannelLogoutError( - "Invalid logout token: not a backchannel logout event") + raise BackchannelLogoutError("Invalid logout token: not a backchannel logout event") # Delete sessions associated with this token logout_claims = LogoutTokenClaims( - sub=claims.get("sub"), - sid=claims.get("sid"), - iss=claims.get("iss") + sub=claims.get("sub"), sid=claims.get("sid"), iss=claims.get("iss") ) - await self._state_store.delete_by_logout_token( - logout_claims.dict(), store_options - ) + await self._state_store.delete_by_logout_token(logout_claims.dict(), store_options) except (jwt.PyJWTError, ValidationError) as e: - raise BackchannelLogoutError( - f"Error processing logout token: {str(e)}") + raise BackchannelLogoutError(f"Error processing logout token: {str(e)}") # ============================================================================ # ACCESS TOKEN MANAGEMENT @@ -955,13 +949,13 @@ async def get_access_token( if not session_domain: raise AccessTokenError( AccessTokenErrorCode.MISSING_SESSION_DOMAIN, - "Session domain does not match the current domain." + "Session domain does not match the current domain.", ) current_domain = await self._resolve_current_domain(store_options) if self._normalize_url(session_domain) != self._normalize_url(current_domain): raise AccessTokenError( AccessTokenErrorCode.DOMAIN_MISMATCH, - "Session domain does not match the current domain." + "Session domain does not match the current domain.", ) auth_params = self._default_authorization_params or {} @@ -975,7 +969,9 @@ async def get_access_token( # Find matching token set token_set = None if state_data_dict and "token_sets" in state_data_dict: - token_set = self._find_matching_token_set(state_data_dict["token_sets"], audience, merged_scope) + token_set = self._find_matching_token_set( + state_data_dict["token_sets"], audience, merged_scope + ) # If token is valid, return it if token_set and token_set.get("expires_at", 0) > time.time(): @@ -985,7 +981,7 @@ async def get_access_token( if not state_data_dict or not state_data_dict.get("refresh_token"): raise AccessTokenError( AccessTokenErrorCode.MISSING_REFRESH_TOKEN, - "The access token has expired and a refresh token was not provided. The user needs to re-authenticate." + "The access token has expired and a refresh token was not provided. The user needs to re-authenticate.", ) # Get new token with refresh token @@ -994,7 +990,7 @@ async def get_access_token( session_domain = state_data_dict.get("domain") or self._domain get_refresh_token_options = { "refresh_token": state_data_dict["refresh_token"], - "domain": session_domain + "domain": session_domain, } if audience: get_refresh_token_options["audience"] = audience @@ -1002,15 +998,20 @@ async def get_access_token( if merged_scope: get_refresh_token_options["scope"] = merged_scope - token_endpoint_response = await self.get_token_by_refresh_token(get_refresh_token_options) + token_endpoint_response = await self.get_token_by_refresh_token( + get_refresh_token_options + ) # Update state data with new token existing_state_data = await self._state_store.get(self._state_identifier, store_options) updated_state_data = State.update_state_data( - audience, existing_state_data, token_endpoint_response) + audience, existing_state_data, token_endpoint_response + ) # Store updated state - await self._state_store.set(self._state_identifier, updated_state_data, options=store_options) + await self._state_store.set( + self._state_identifier, updated_state_data, options=store_options + ) return token_endpoint_response["access_token"] except Exception as e: @@ -1024,22 +1025,21 @@ async def get_access_token( raw_mfa_token=raw_mfa_token, audience=audience or self.DEFAULT_AUDIENCE_STATE_KEY, scope=merged_scope or "", - mfa_requirements=mfa_requirements + mfa_requirements=mfa_requirements, ) raise MfaRequiredError( "Multifactor authentication required", mfa_token=encrypted_token, - mfa_requirements=mfa_requirements + mfa_requirements=mfa_requirements, ) if isinstance(e, AccessTokenError): raise raise AccessTokenError( AccessTokenErrorCode.REFRESH_TOKEN_ERROR, - f"Failed to get token with refresh token: {str(e)}" + f"Failed to get token with refresh token: {str(e)}", ) - async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, Any]: """ Retrieves a token by exchanging a refresh token. @@ -1067,8 +1067,7 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, token_endpoint = metadata.get("token_endpoint") if not token_endpoint: - raise ApiError("configuration_error", - "Token endpoint missing in OIDC metadata") + raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") # Prepare the token request parameters token_params = { @@ -1083,8 +1082,7 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, # Merge scope if present in options with any in the original authorization params merged_scope = self._merge_scope_with_defaults( - request_scope=options.get("scope"), - audience=audience + request_scope=options.get("scope"), audience=audience ) if merged_scope: @@ -1093,9 +1091,7 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, # Exchange the refresh token for an access token async with self._get_http_client() as client: response = await client.post( - token_endpoint, - data=token_params, - auth=(self._client_id, self._client_secret) + token_endpoint, data=token_params, auth=(self._client_id, self._client_secret) ) if response.status_code != 200: @@ -1105,8 +1101,7 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, # Preserve mfa_required details for upstream handling if error_code == "mfa_required": error = ApiError( - error_code, - error_data.get("error_description", "MFA required") + error_code, error_data.get("error_description", "MFA required") ) error.mfa_token = error_data.get("mfa_token") mfa_requirements_data = error_data.get("mfa_requirements") @@ -1117,16 +1112,14 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, raise ApiError( error_code, - error_data.get("error_description", - "Failed to exchange refresh token") + error_data.get("error_description", "Failed to exchange refresh token"), ) token_response = response.json() # Add required fields if they are missing if "expires_in" in token_response and "expires_at" not in token_response: - token_response["expires_at"] = int( - time.time()) + token_response["expires_in"] + token_response["expires_at"] = int(time.time()) + token_response["expires_in"] return token_response @@ -1136,13 +1129,11 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, raise AccessTokenError( AccessTokenErrorCode.REFRESH_TOKEN_ERROR, "The access token has expired and there was an error while trying to refresh it.", - e + e, ) def _merge_scope_with_defaults( - self, - request_scope: Optional[str], - audience: Optional[str] + self, request_scope: Optional[str], audience: Optional[str] ) -> Optional[str]: """Helper: Merges requested scopes with default authorization params.""" audience = audience or self.DEFAULT_AUDIENCE_STATE_KEY @@ -1163,10 +1154,7 @@ def _merge_scope_with_defaults( return " ".join(merged_scopes) if merged_scopes else None def _find_matching_token_set( - self, - token_sets: list[dict[str, Any]], - audience: Optional[str], - scope: Optional[str] + self, token_sets: list[dict[str, Any]], audience: Optional[str], scope: Optional[str] ) -> Optional[dict[str, Any]]: """Helper: Finds a token set matching the requested audience and scopes.""" audience = audience or self.DEFAULT_AUDIENCE_STATE_KEY @@ -1192,9 +1180,7 @@ def _find_matching_token_set( # ============================================================================ async def login_backchannel( - self, - options: dict[str, Any], - store_options: Optional[dict[str, Any]] = None + self, options: dict[str, Any], store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Logs in using Client-Initiated Backchannel Authentication. @@ -1213,38 +1199,34 @@ async def login_backchannel( Returns: A dictionary containing the authorizationDetails (when RAR was used). """ - token_endpoint_response = await self.backchannel_authentication({ - "binding_message": options.get("binding_message"), - "login_hint": options.get("login_hint"), - "authorization_params": options.get("authorization_params"), - }, store_options=store_options) + token_endpoint_response = await self.backchannel_authentication( + { + "binding_message": options.get("binding_message"), + "login_hint": options.get("login_hint"), + "authorization_params": options.get("authorization_params"), + }, + store_options=store_options, + ) existing_state_data = await self._state_store.get(self._state_identifier, store_options) audience = self._default_authorization_params.get( - "audience", self.DEFAULT_AUDIENCE_STATE_KEY) - - state_data = State.update_state_data( - audience, - existing_state_data, - token_endpoint_response + "audience", self.DEFAULT_AUDIENCE_STATE_KEY ) + state_data = State.update_state_data(audience, existing_state_data, token_endpoint_response) + # Store domain for MCD session domain = await self._resolve_current_domain(store_options) state_data["domain"] = domain await self._state_store.set(self._state_identifier, state_data, store_options) - result = { - "authorization_details": token_endpoint_response.get("authorization_details") - } + result = {"authorization_details": token_endpoint_response.get("authorization_details")} return result async def backchannel_authentication( - self, - options: dict[str, Any], - store_options: Optional[dict[str, Any]] = None + self, options: dict[str, Any], store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Performs backchannel authentication with Auth0. @@ -1269,12 +1251,12 @@ async def backchannel_authentication( Raises: ApiError: If the backchannel authentication fails """ - backchannel_data = await self.initiate_backchannel_authentication(options, store_options=store_options) + backchannel_data = await self.initiate_backchannel_authentication( + options, store_options=store_options + ) auth_req_id = backchannel_data.get("auth_req_id") - expires_in = backchannel_data.get( - "expires_in", 120) # Default to 2 minutes - interval = backchannel_data.get( - "interval", 5) # Default to 5 seconds + expires_in = backchannel_data.get("expires_in", 120) # Default to 2 minutes + interval = backchannel_data.get("interval", 5) # Default to 5 seconds # Calculate when to stop polling end_time = time.time() + expires_in @@ -1283,7 +1265,9 @@ async def backchannel_authentication( while time.time() < end_time: # Make token request try: - token_response = await self.backchannel_authentication_grant(auth_req_id, store_options=store_options) + token_response = await self.backchannel_authentication_grant( + auth_req_id, store_options=store_options + ) return token_response except Exception as e: @@ -1299,17 +1283,14 @@ async def backchannel_authentication( raise ApiError( "backchannel_error", f"Backchannel authentication failed: {str(e) or 'Unknown error'}", - e + e, ) # If we get here, we've timed out - raise ApiError( - "timeout", "Backchannel authentication timed out") + raise ApiError("timeout", "Backchannel authentication timed out") async def initiate_backchannel_authentication( - self, - options: dict[str, Any], - store_options: Optional[dict[str, Any]] = None + self, options: dict[str, Any], store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Start backchannel authentication with Auth0. @@ -1339,18 +1320,13 @@ async def initiate_backchannel_authentication( https://auth0.com/docs/get-started/authentication-and-authorization-flow/client-initiated-backchannel-authentication-flow """ - sub = options.get('login_hint', {}).get("sub") + sub = options.get("login_hint", {}).get("sub") if not sub: - raise MissingRequiredArgumentError( - "login_hint.sub" - ) + raise MissingRequiredArgumentError("login_hint.sub") - authorization_params = options.get('authorization_params') + authorization_params = options.get("authorization_params") if authorization_params is not None and not isinstance(authorization_params, dict): - raise ApiError( - "invalid_argument", - "authorization_params must be a dict" - ) + raise ApiError("invalid_argument", "authorization_params must be a dict") if authorization_params: requested_expiry = authorization_params.get("requested_expiry") @@ -1358,7 +1334,7 @@ async def initiate_backchannel_authentication( if not isinstance(requested_expiry, int) or requested_expiry <= 0: raise ApiError( "invalid_argument", - "authorization_params.requested_expiry must be a positive integer" + "authorization_params.requested_expiry must be a positive integer", ) try: @@ -1367,24 +1343,18 @@ async def initiate_backchannel_authentication( metadata = await self._get_oidc_metadata_cached(domain) # Get the issuer from metadata - issuer = metadata.get( - "issuer") or f"https://{domain}/" + issuer = metadata.get("issuer") or f"https://{domain}/" # Get backchannel authentication endpoint - backchannel_endpoint = metadata.get( - "backchannel_authentication_endpoint") + backchannel_endpoint = metadata.get("backchannel_authentication_endpoint") if not backchannel_endpoint: raise ApiError( "configuration_error", - "Backchannel authentication is not supported by the authorization server" + "Backchannel authentication is not supported by the authorization server", ) # Prepare login hint in the required format - login_hint = json.dumps({ - "format": "iss_sub", - "iss": issuer, - "sub": sub - }) + login_hint = json.dumps({"format": "iss_sub", "iss": issuer, "sub": sub}) # The Request Parameters params = { @@ -1394,8 +1364,8 @@ async def initiate_backchannel_authentication( } # Add binding message if provided - if options.get('binding_message'): - params["binding_message"] = options.get('binding_message') + if options.get("binding_message"): + params["binding_message"] = options.get("binding_message") # Add any additional authorization parameters if self._default_authorization_params: @@ -1407,9 +1377,7 @@ async def initiate_backchannel_authentication( # Make the backchannel authentication request async with self._get_http_client() as client: backchannel_response = await client.post( - backchannel_endpoint, - data=params, - auth=(self._client_id, self._client_secret) + backchannel_endpoint, data=params, auth=(self._client_id, self._client_secret) ) if backchannel_response.status_code != 200: @@ -1417,7 +1385,8 @@ async def initiate_backchannel_authentication( raise ApiError( error_data.get("error", "backchannel_error"), error_data.get( - "error_description", "Backchannel authentication request failed") + "error_description", "Backchannel authentication request failed" + ), ) backchannel_data = backchannel_response.json() @@ -1426,7 +1395,7 @@ async def initiate_backchannel_authentication( if not auth_req_id: raise ApiError( "invalid_response", - "Missing auth_req_id in backchannel authentication response" + "Missing auth_req_id in backchannel authentication response", ) return backchannel_data @@ -1437,13 +1406,11 @@ async def initiate_backchannel_authentication( raise ApiError( "backchannel_error", f"Backchannel authentication failed: {str(e) or 'Unknown error'}", - e + e, ) async def backchannel_authentication_grant( - self, - auth_req_id: str, - store_options: Optional[dict[str, Any]] = None + self, auth_req_id: str, store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Retrieves a token by exchanging an auth_req_id. @@ -1468,23 +1435,20 @@ async def backchannel_authentication_grant( token_endpoint = metadata.get("token_endpoint") if not token_endpoint: - raise ApiError("configuration_error", - "Token endpoint missing in OIDC metadata") + raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") # Prepare the token request parameters token_params = { "grant_type": "urn:openid:params:grant-type:ciba", "auth_req_id": auth_req_id, "client_id": self._client_id, - "client_secret": self._client_secret + "client_secret": self._client_secret, } # Exchange the auth_req_id for an access token async with self._get_http_client() as client: response = await client.post( - token_endpoint, - data=token_params, - auth=(self._client_id, self._client_secret) + token_endpoint, data=token_params, auth=(self._client_id, self._client_secret) ) if response.status_code != 200: @@ -1493,23 +1457,18 @@ async def backchannel_authentication_grant( interval = int(retry_after) if retry_after is not None else None raise PollingApiError( error_data.get("error", "auth_req_id_error"), - error_data.get("error_description", - "Failed to exchange auth_req_id"), - interval + error_data.get("error_description", "Failed to exchange auth_req_id"), + interval, ) try: token_response = response.json() except json.JSONDecodeError: - raise ApiError( - "invalid_response", - "Failed to parse token response as JSON" - ) + raise ApiError("invalid_response", "Failed to parse token response as JSON") # Add required fields if they are missing if "expires_in" in token_response and "expires_at" not in token_response: - token_response["expires_at"] = int( - time.time()) + token_response["expires_in"] + token_response["expires_at"] = int(time.time()) + token_response["expires_in"] return token_response @@ -1519,7 +1478,7 @@ async def backchannel_authentication_grant( raise AccessTokenError( AccessTokenErrorCode.AUTH_REQ_ID_ERROR, "There was an error while trying to exchange the auth_req_id for an access token.", - e + e, ) # ============================================================================ @@ -1528,11 +1487,7 @@ async def backchannel_authentication_grant( # to a user's Auth0 profile. # ============================================================================ - async def start_link_user( - self, - options, - store_options: Optional[dict[str, Any]] = None - ): + async def start_link_user(self, options, store_options: Optional[dict[str, Any]] = None): """ Starts the user linking process, and returns a URL to redirect the user-agent to. @@ -1559,13 +1514,9 @@ async def start_link_user( state_data = state_data.dict() session_domain = self._get_session_domain(state_data) if not session_domain: - raise StartLinkUserError( - "Session domain does not match the current domain." - ) + raise StartLinkUserError("Session domain does not match the current domain.") if self._normalize_url(session_domain) != self._normalize_url(origin_domain): - raise StartLinkUserError( - "Session domain does not match the current domain." - ) + raise StartLinkUserError("Session domain does not match the current domain.") # Generate PKCE and state for security code_verifier = PKCE.generate_code_verifier() @@ -1579,7 +1530,7 @@ async def start_link_user( code_verifier=code_verifier, state=state, authorization_params=options.get("authorization_params"), - domain=origin_domain + domain=origin_domain, ) # Store transaction data @@ -1590,17 +1541,13 @@ async def start_link_user( ) await self._transaction_store.set( - f"{self._transaction_identifier}:{state}", - transaction_data, - options=store_options + f"{self._transaction_identifier}:{state}", transaction_data, options=store_options ) return link_user_url async def complete_link_user( - self, - url: str, - store_options: Optional[dict[str, Any]] = None + self, url: str, store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Completes the user linking process. @@ -1617,15 +1564,9 @@ async def complete_link_user( result = await self.complete_interactive_login(url, store_options) # Return just the app state as specified - return { - "app_state": result.get("app_state") - } + return {"app_state": result.get("app_state")} - async def start_unlink_user( - self, - options, - store_options: Optional[dict[str, Any]] = None - ): + async def start_unlink_user(self, options, store_options: Optional[dict[str, Any]] = None): """ Starts the user unlinking process, and returns a URL to redirect the user-agent to. @@ -1652,13 +1593,9 @@ async def start_unlink_user( state_data = state_data.dict() session_domain = self._get_session_domain(state_data) if not session_domain: - raise StartLinkUserError( - "Session domain does not match the current domain." - ) + raise StartLinkUserError("Session domain does not match the current domain.") if self._normalize_url(session_domain) != self._normalize_url(origin_domain): - raise StartLinkUserError( - "Session domain does not match the current domain." - ) + raise StartLinkUserError("Session domain does not match the current domain.") # Generate PKCE and state for security code_verifier = PKCE.generate_code_verifier() @@ -1671,7 +1608,7 @@ async def start_unlink_user( code_verifier=code_verifier, state=state, authorization_params=options.get("authorization_params"), - domain=origin_domain + domain=origin_domain, ) # Store transaction data @@ -1682,17 +1619,13 @@ async def start_unlink_user( ) await self._transaction_store.set( - f"{self._transaction_identifier}:{state}", - transaction_data, - options=store_options + f"{self._transaction_identifier}:{state}", transaction_data, options=store_options ) return link_user_url async def complete_unlink_user( - self, - url: str, - store_options: Optional[dict[str, Any]] = None + self, url: str, store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Completes the user unlinking process. @@ -1709,9 +1642,7 @@ async def complete_unlink_user( result = await self.complete_interactive_login(url, store_options) # Return just the app state as specified - return { - "app_state": result.get("app_state") - } + return {"app_state": result.get("app_state")} async def _build_link_user_url( self, @@ -1721,7 +1652,7 @@ async def _build_link_user_url( state: str, connection_scope: Optional[str] = None, authorization_params: Optional[dict[str, Any]] = None, - domain: Optional[str] = None + domain: Optional[str] = None, ) -> str: """Build a URL for linking user accounts""" # Generate code challenge from verifier @@ -1732,8 +1663,9 @@ async def _build_link_user_url( metadata = await self._get_oidc_metadata_cached(resolved_domain) # Get authorization endpoint - auth_endpoint = metadata.get("authorization_endpoint", - f"https://{resolved_domain}/authorize") + auth_endpoint = metadata.get( + "authorization_endpoint", f"https://{resolved_domain}/authorize" + ) # Build params params = { @@ -1746,7 +1678,7 @@ async def _build_link_user_url( "response_type": "code", "id_token_hint": id_token, "scope": "openid link_account", - "audience": "my-account" + "audience": "my-account", } # Add connection scope if provided @@ -1765,7 +1697,7 @@ async def _build_unlink_user_url( code_verifier: str, state: str, authorization_params: Optional[dict[str, Any]] = None, - domain: Optional[str] = None + domain: Optional[str] = None, ) -> str: """Build a URL for unlinking user accounts""" # Generate code challenge from verifier @@ -1776,8 +1708,9 @@ async def _build_unlink_user_url( metadata = await self._get_oidc_metadata_cached(resolved_domain) # Get authorization endpoint - auth_endpoint = metadata.get("authorization_endpoint", - f"https://{resolved_domain}/authorize") + auth_endpoint = metadata.get( + "authorization_endpoint", f"https://{resolved_domain}/authorize" + ) # Build params params = { @@ -1789,7 +1722,7 @@ async def _build_unlink_user_url( "response_type": "code", "id_token_hint": id_token, "scope": "openid unlink_account", - "audience": "my-account" + "audience": "my-account", } # Add any additional parameters if authorization_params: @@ -1804,9 +1737,7 @@ async def _build_unlink_user_url( # ============================================================================ async def get_access_token_for_connection( - self, - options: dict[str, Any], - store_options: Optional[dict[str, Any]] = None + self, options: dict[str, Any], store_options: Optional[dict[str, Any]] = None ) -> str: """ Retrieves an access token for a connection. @@ -1840,13 +1771,13 @@ async def get_access_token_for_connection( if not session_domain: raise AccessTokenForConnectionError( AccessTokenForConnectionErrorCode.MISSING_SESSION_DOMAIN, - "Session domain does not match the current domain." + "Session domain does not match the current domain.", ) current_domain = await self._resolve_current_domain(store_options) if self._normalize_url(session_domain) != self._normalize_url(current_domain): raise AccessTokenForConnectionError( AccessTokenForConnectionErrorCode.DOMAIN_MISMATCH, - "Session domain does not match the current domain." + "Session domain does not match the current domain.", ) # Find existing connection token @@ -1865,21 +1796,24 @@ async def get_access_token_for_connection( if not state_data_dict or not state_data_dict.get("refresh_token"): raise AccessTokenForConnectionError( AccessTokenForConnectionErrorCode.MISSING_REFRESH_TOKEN, - "A refresh token was not found but is required to be able to retrieve an access token for a connection." + "A refresh token was not found but is required to be able to retrieve an access token for a connection.", ) # Get new token for connection # Use session's domain for token exchange session_domain = state_data_dict.get("domain") or self._domain - token_endpoint_response = await self.get_token_for_connection({ - "connection": options.get("connection"), - "login_hint": options.get("login_hint"), - "refresh_token": state_data_dict["refresh_token"], - "domain": session_domain - }) + token_endpoint_response = await self.get_token_for_connection( + { + "connection": options.get("connection"), + "login_hint": options.get("login_hint"), + "refresh_token": state_data_dict["refresh_token"], + "domain": session_domain, + } + ) # Update state data with new token updated_state_data = State.update_state_data_for_connection_token_set( - options, state_data_dict, token_endpoint_response) + options, state_data_dict, token_endpoint_response + ) # Store updated state await self._state_store.set(self._state_identifier, updated_state_data, store_options) @@ -1903,8 +1837,12 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A """ # Constants SUBJECT_TYPE_REFRESH_TOKEN = "urn:ietf:params:oauth:token-type:refresh_token" - REQUESTED_TOKEN_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = "http://auth0.com/oauth/token-type/federated-connection-access-token" - GRANT_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = "urn:auth0:params:oauth:grant-type:token-exchange:federated-connection-access-token" + REQUESTED_TOKEN_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = ( + "http://auth0.com/oauth/token-type/federated-connection-access-token" + ) + GRANT_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = ( + "urn:auth0:params:oauth:grant-type:token-exchange:federated-connection-access-token" + ) try: # Use session domain if provided, otherwise fallback to static domain domain = options.get("domain") or self._domain @@ -1914,8 +1852,7 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A token_endpoint = metadata.get("token_endpoint") if not token_endpoint: - raise ApiError("configuration_error", - "Token endpoint missing in OIDC metadata") + raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") # Prepare parameters params = { @@ -1924,7 +1861,7 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A "subject_token": options["refresh_token"], "requested_token_type": REQUESTED_TOKEN_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN, "grant_type": GRANT_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN, - "client_id": self._client_id + "client_id": self._client_id, } # Add login_hint if provided @@ -1934,38 +1871,41 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A # Make the request async with self._get_http_client() as client: response = await client.post( - token_endpoint, - data=params, - auth=(self._client_id, self._client_secret) + token_endpoint, data=params, auth=(self._client_id, self._client_secret) ) if response.status_code != 200: - error_data = response.json() if response.headers.get( - "content-type") == "application/json" else {} + error_data = ( + response.json() + if response.headers.get("content-type") == "application/json" + else {} + ) raise ApiError( error_data.get("error", "connection_token_error"), error_data.get( - "error_description", f"Failed to get token for connection: {response.status_code}") + "error_description", + f"Failed to get token for connection: {response.status_code}", + ), ) token_endpoint_response = response.json() return { "access_token": token_endpoint_response.get("access_token"), - "expires_at": int(time.time()) + int(token_endpoint_response.get("expires_in", 3600)), - "scope": token_endpoint_response.get("scope", "") + "expires_at": int(time.time()) + + int(token_endpoint_response.get("expires_in", 3600)), + "scope": token_endpoint_response.get("scope", ""), } except Exception as e: if isinstance(e, ApiError): raise AccessTokenForConnectionError( - AccessTokenForConnectionErrorCode.API_ERROR, - str(e) + AccessTokenForConnectionErrorCode.API_ERROR, str(e) ) raise AccessTokenForConnectionError( AccessTokenForConnectionErrorCode.FETCH_ERROR, "There was an error while trying to retrieve an access token for a connection.", - e + e, ) # ============================================================================ @@ -1975,9 +1915,7 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A # ============================================================================ async def start_connect_account( - self, - options: ConnectAccountOptions, - store_options: dict = None + self, options: ConnectAccountOptions, store_options: dict = None ) -> str: """ Initiates the connect account flow for linking a third-party account to the user's profile. @@ -2002,26 +1940,25 @@ async def start_connect_account( code_verifier = PKCE.generate_code_verifier() code_challenge = PKCE.generate_code_challenge(code_verifier) - state= PKCE.generate_random_string(32) + state = PKCE.generate_random_string(32) connect_request = ConnectAccountRequest( connection=options.connection, scopes=options.scopes, - redirect_uri = redirect_uri, + redirect_uri=redirect_uri, code_challenge=code_challenge, code_challenge_method="S256", state=state, - authorization_params=options.authorization_params + authorization_params=options.authorization_params, ) access_token = await self.get_access_token( audience=self._my_account_client.audience, scope="create:me:connected_accounts", - store_options=store_options + store_options=store_options, ) connect_response = await self._my_account_client.connect_account( - access_token=access_token, - request=connect_request + access_token=access_token, request=connect_request ) # Build the transaction data to store @@ -2029,24 +1966,29 @@ async def start_connect_account( code_verifier=code_verifier, app_state=options.app_state, auth_session=connect_response.auth_session, - redirect_uri=redirect_uri + redirect_uri=redirect_uri, ) # Store the transaction data await self._transaction_store.set( - f"{self._transaction_identifier}:{state}", - transaction_data, - options=store_options + f"{self._transaction_identifier}:{state}", transaction_data, options=store_options ) parsedUrl = urlparse(connect_response.connect_uri) query = urlencode({"ticket": connect_response.connect_params.ticket}) - return urlunparse((parsedUrl.scheme, parsedUrl.netloc, parsedUrl.path, parsedUrl.params, query, parsedUrl.fragment)) + return urlunparse( + ( + parsedUrl.scheme, + parsedUrl.netloc, + parsedUrl.path, + parsedUrl.params, + query, + parsedUrl.fragment, + ) + ) async def complete_connect_account( - self, - url: str, - store_options: dict = None + self, url: str, store_options: dict = None ) -> CompleteConnectAccountResponse: """ Handles the redirect callback to complete the connect account flow for linking a third-party @@ -2078,7 +2020,9 @@ async def complete_connect_account( # Retrieve the transaction data using the state transaction_identifier = f"{self._transaction_identifier}:{state}" - transaction_data = await self._transaction_store.get(transaction_identifier, options=store_options) + transaction_data = await self._transaction_store.get( + transaction_identifier, options=store_options + ) if not transaction_data: raise MissingTransactionError() @@ -2086,18 +2030,19 @@ async def complete_connect_account( access_token = await self.get_access_token( audience=self._my_account_client.audience, scope="create:me:connected_accounts", - store_options=store_options + store_options=store_options, ) request = CompleteConnectAccountRequest( auth_session=transaction_data.auth_session, connect_code=connect_code, redirect_uri=transaction_data.redirect_uri, - code_verifier=transaction_data.code_verifier + code_verifier=transaction_data.code_verifier, ) try: response = await self._my_account_client.complete_connect_account( - access_token=access_token, request=request) + access_token=access_token, request=request + ) if transaction_data.app_state is not None: response.app_state = transaction_data.app_state finally: @@ -2111,7 +2056,7 @@ async def list_connected_accounts( connection: Optional[str] = None, from_param: Optional[str] = None, take: Optional[int] = None, - store_options: dict = None + store_options: dict = None, ) -> ListConnectedAccountsResponse: """ Retrieves a list of connected accounts for the authenticated user. @@ -2135,15 +2080,14 @@ async def list_connected_accounts( access_token = await self.get_access_token( audience=self._my_account_client.audience, scope="read:me:connected_accounts", - store_options=store_options + store_options=store_options, ) return await self._my_account_client.list_connected_accounts( - access_token=access_token, connection=connection, from_param=from_param, take=take) + access_token=access_token, connection=connection, from_param=from_param, take=take + ) async def delete_connected_account( - self, - connected_account_id: str, - store_options: dict = None + self, connected_account_id: str, store_options: dict = None ) -> None: """ Deletes a connected account. @@ -2162,16 +2106,17 @@ async def delete_connected_account( access_token = await self.get_access_token( audience=self._my_account_client.audience, scope="delete:me:connected_accounts", - store_options=store_options + store_options=store_options, ) await self._my_account_client.delete_connected_account( - access_token=access_token, connected_account_id=connected_account_id) + access_token=access_token, connected_account_id=connected_account_id + ) async def list_connected_account_connections( self, from_param: Optional[str] = None, take: Optional[int] = None, - store_options: dict = None + store_options: dict = None, ) -> ListConnectedAccountConnectionsResponse: """ Retrieves a list of available connections that can be used connected accounts for the authenticated user. @@ -2194,10 +2139,11 @@ async def list_connected_account_connections( access_token = await self.get_access_token( audience=self._my_account_client.audience, scope="read:me:connected_accounts", - store_options=store_options + store_options=store_options, ) return await self._my_account_client.list_connected_account_connections( - access_token=access_token, from_param=from_param, take=take) + access_token=access_token, from_param=from_param, take=take + ) # ============================================================================ # CUSTOM TOKEN EXCHANGE (RFC 8693) @@ -2205,9 +2151,7 @@ async def list_connected_account_connections( # ============================================================================ async def custom_token_exchange( - self, - options: CustomTokenExchangeOptions, - store_options: Optional[dict[str, Any]] = None + self, options: CustomTokenExchangeOptions, store_options: Optional[dict[str, Any]] = None ) -> TokenExchangeResponse: """ Exchanges a custom token for Auth0 tokens using RFC 8693. @@ -2280,7 +2224,12 @@ async def custom_token_exchange( # Merge additional authorization params if options.authorization_params: # Prevent override of critical parameters - forbidden_params = {"grant_type", "client_id", "subject_token", "subject_token_type"} + forbidden_params = { + "grant_type", + "client_id", + "subject_token", + "subject_token_type", + } for key, value in options.authorization_params.items(): if key not in forbidden_params: params[key] = value @@ -2288,17 +2237,20 @@ async def custom_token_exchange( # Make the token exchange request async with self._get_http_client() as client: response = await client.post( - token_endpoint, - data=params, - auth=(self._client_id, self._client_secret) + token_endpoint, data=params, auth=(self._client_id, self._client_secret) ) if response.status_code != 200: - error_data = response.json() if response.headers.get( - "content-type", "").startswith("application/json") else {} + error_data = ( + response.json() + if response.headers.get("content-type", "").startswith("application/json") + else {} + ) raise CustomTokenExchangeError( error_data.get("error", CustomTokenExchangeErrorCode.TOKEN_EXCHANGE_FAILED), - error_data.get("error_description", f"Token exchange failed: {response.status_code}") + error_data.get( + "error_description", f"Token exchange failed: {response.status_code}" + ), ) try: @@ -2306,7 +2258,7 @@ async def custom_token_exchange( except json.JSONDecodeError: raise CustomTokenExchangeError( CustomTokenExchangeErrorCode.INVALID_RESPONSE, - "Failed to parse token response as JSON" + "Failed to parse token response as JSON", ) # Validate and return response @@ -2315,7 +2267,7 @@ async def custom_token_exchange( except ValidationError as e: raise CustomTokenExchangeError( CustomTokenExchangeErrorCode.INVALID_TOKEN_FORMAT, - f"Token validation failed: {str(e)}" + f"Token validation failed: {str(e)}", ) except Exception as e: if isinstance(e, (CustomTokenExchangeError, ApiError)): @@ -2323,13 +2275,13 @@ async def custom_token_exchange( raise CustomTokenExchangeError( CustomTokenExchangeErrorCode.TOKEN_EXCHANGE_FAILED, f"Token exchange failed: {str(e)}", - e + e, ) async def login_with_custom_token_exchange( self, options: LoginWithCustomTokenExchangeOptions, - store_options: Optional[dict[str, Any]] = None + store_options: Optional[dict[str, Any]] = None, ) -> LoginWithCustomTokenExchangeResult: """ Performs token exchange and establishes a user session. @@ -2374,10 +2326,12 @@ async def login_with_custom_token_exchange( actor_token=options.actor_token, actor_token_type=options.actor_token_type, organization=options.organization, - authorization_params=options.authorization_params + authorization_params=options.authorization_params, ) - token_response = await self.custom_token_exchange(exchange_options, store_options=store_options) + token_response = await self.custom_token_exchange( + exchange_options, store_options=store_options + ) # Resolve domain and fetch metadata for verification domain = await self._resolve_current_domain(store_options) @@ -2409,28 +2363,18 @@ async def login_with_custom_token_exchange( raise ApiError("jwks_key_not_found", str(e)) except jwt.InvalidSignatureError as e: raise ApiError( - "invalid_signature", - f"ID token signature verification failed: {str(e)}", - e + "invalid_signature", f"ID token signature verification failed: {str(e)}", e ) except jwt.InvalidAudienceError as e: raise ApiError( "invalid_audience", f"ID token audience mismatch. Expected: {self._client_id}: {str(e)}", - e + e, ) except jwt.ExpiredSignatureError as e: - raise ApiError( - "token_expired", - f"ID token has expired: {str(e)}", - e - ) + raise ApiError("token_expired", f"ID token has expired: {str(e)}", e) except jwt.InvalidTokenError as e: - raise ApiError( - "invalid_token", - f"ID token verification failed: {str(e)}", - e - ) + raise ApiError("invalid_token", f"ID token verification failed: {str(e)}", e) # Determine audience for token set audience = options.audience or self.DEFAULT_AUDIENCE_STATE_KEY @@ -2440,7 +2384,7 @@ async def login_with_custom_token_exchange( audience=audience, access_token=token_response.access_token, scope=token_response.scope or options.scope or "", - expires_at=int(time.time()) + token_response.expires_in + expires_at=int(time.time()) + token_response.expires_in, ) # Construct state data @@ -2450,19 +2394,14 @@ async def login_with_custom_token_exchange( refresh_token=token_response.refresh_token, token_sets=[token_set], domain=domain, - internal={ - "sid": sid, - "created_at": int(time.time()) - } + internal={"sid": sid, "created_at": int(time.time())}, ) # Store session await self._state_store.set(self._state_identifier, state_data, options=store_options) # Build result - result = LoginWithCustomTokenExchangeResult( - state_data=state_data.dict() - ) + result = LoginWithCustomTokenExchangeResult(state_data=state_data.dict()) return result @@ -2472,9 +2411,291 @@ async def login_with_custom_token_exchange( raise CustomTokenExchangeError( CustomTokenExchangeErrorCode.TOKEN_EXCHANGE_FAILED, f"Login with custom token exchange failed: {str(e)}", - e + e, ) + # ============================================================================ + # PASSKEY AUTHENTICATION (Category 1) + # ============================================================================ + + GRANT_TYPE_PASSKEY = "urn:okta:params:oauth:grant-type:webauthn" + + async def passkey_signup_challenge( + self, + name: Optional[str] = None, + email: Optional[str] = None, + username: Optional[str] = None, + phone_number: Optional[str] = None, + given_name: Optional[str] = None, + family_name: Optional[str] = None, + nickname: Optional[str] = None, + picture: Optional[str] = None, + user_metadata: Optional[dict[str, Any]] = None, + connection: Optional[str] = None, + organization: Optional[str] = None, + store_options: Optional[dict[str, Any]] = None, + ) -> PasskeySignupChallengeResponse: + """ + Step 1 of 2: Initiate a passkey signup challenge (POST /passkey/register). + + Pass the returned authn_params_public_key to navigator.credentials.create(), + then call signin_with_passkey() with the auth_session and credential result. + + Args: + name: User's full name. + email: User's email address. + username: Username for the new account. + phone_number: User's phone number. + given_name: User's given (first) name. + family_name: User's family (last) name. + nickname: User's nickname. + picture: URL to the user's profile picture. + user_metadata: Arbitrary user metadata dict. + connection: Auth0 database connection name (realm). + organization: Auth0 organization ID or name. + store_options: Optional options for domain resolution. + + Returns: + PasskeySignupChallengeResponse with auth_session and authn_params_public_key. + + Raises: + ApiError: If the challenge request fails. + """ + try: + domain = await self._resolve_current_domain(store_options) + + user_profile: dict[str, Any] = {} + if email is not None: + user_profile["email"] = email + if name is not None: + user_profile["name"] = name + if username is not None: + user_profile["username"] = username + if phone_number is not None: + user_profile["phone_number"] = phone_number + if given_name is not None: + user_profile["given_name"] = given_name + if family_name is not None: + user_profile["family_name"] = family_name + if nickname is not None: + user_profile["nickname"] = nickname + if picture is not None: + user_profile["picture"] = picture + if user_metadata is not None: + user_profile["user_metadata"] = user_metadata + + body: dict[str, Any] = {"client_id": self._client_id} + if self._client_secret: + body["client_secret"] = self._client_secret + if user_profile: + body["user_profile"] = user_profile + if connection: + body["realm"] = connection + if organization: + body["organization"] = organization + + url = f"https://{domain}/passkey/register" + + async with self._get_http_client() as client: + response = await client.post(url, json=body) + + if response.status_code != 200: + try: + error_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "passkey_challenge_error", + f"Passkey signup challenge failed with status {response.status_code}", + ) + raise ApiError( + error_data.get("error", "passkey_challenge_error"), + error_data.get("error_description", "Passkey signup challenge failed"), + ) + + try: + data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "invalid_response", + "Failed to parse passkey signup challenge response as JSON", + ) + + return PasskeySignupChallengeResponse.model_validate(data) + + except Exception as e: + if isinstance(e, (ApiError, MissingRequiredArgumentError, ValidationError)): + raise + raise ApiError("passkey_challenge_error", "Passkey signup challenge failed", e) + + async def passkey_login_challenge( + self, + username: Optional[str] = None, + connection: Optional[str] = None, + organization: Optional[str] = None, + store_options: Optional[dict[str, Any]] = None, + ) -> PasskeyLoginChallengeResponse: + """ + Step 1 of 2: Initiate a passkey login challenge (POST /passkey/challenge). + + Pass the returned authn_params_public_key to navigator.credentials.get(), + then call signin_with_passkey() with the auth_session and credential result. + + Args: + username: Optional username hint for conditional UI. + connection: Auth0 database connection name (realm). + organization: Auth0 organization ID or name. + store_options: Optional options for domain resolution. + + Returns: + PasskeyLoginChallengeResponse with auth_session and authn_params_public_key. + + Raises: + ApiError: If the challenge request fails. + """ + try: + domain = await self._resolve_current_domain(store_options) + + body: dict[str, Any] = {"client_id": self._client_id} + if self._client_secret: + body["client_secret"] = self._client_secret + if username: + body["username"] = username + if connection: + body["realm"] = connection + if organization: + body["organization"] = organization + + url = f"https://{domain}/passkey/challenge" + + async with self._get_http_client() as client: + response = await client.post(url, json=body) + + if response.status_code != 200: + try: + error_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "passkey_challenge_error", + f"Passkey login challenge failed with status {response.status_code}", + ) + raise ApiError( + error_data.get("error", "passkey_challenge_error"), + error_data.get("error_description", "Passkey login challenge failed"), + ) + + try: + data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "invalid_response", + "Failed to parse passkey login challenge response as JSON", + ) + + return PasskeyLoginChallengeResponse.model_validate(data) + + except Exception as e: + if isinstance(e, (ApiError, MissingRequiredArgumentError, ValidationError)): + raise + raise ApiError("passkey_challenge_error", "Passkey login challenge failed", e) + + async def signin_with_passkey( + self, + auth_session: str, + authn_response: PasskeyAuthResponse, + store_options: Optional[dict[str, Any]] = None, + connection: Optional[str] = None, + organization: Optional[str] = None, + scope: Optional[str] = None, + audience: Optional[str] = None, + ) -> PasskeyTokenResponse: + """ + Completes passkey authentication by exchanging the WebAuthn assertion + for tokens (POST /oauth/token with webauthn grant). + + This is step 2 of 2: call passkey_signup_challenge or passkey_login_challenge + first to obtain auth_session and the WebAuthn challenge options. + + Uses Content-Type: application/json (required for nested authn_response). + + Args: + auth_session: Session credential from passkey_signup_challenge or passkey_login_challenge. + authn_response: Serialized WebAuthn credential from navigator.credentials.create/get. + store_options: Optional options for domain resolution and state store. + connection: Auth0 database connection name (realm). + organization: Auth0 organization ID or name. + scope: OAuth2 scope string. + audience: Target API audience. + + Returns: + PasskeyTokenResponse containing access_token, id_token, expires_in, etc. + + Raises: + MissingRequiredArgumentError: If auth_session or authn_response is missing. + ApiError: If token exchange fails. + """ + if not auth_session: + raise MissingRequiredArgumentError("auth_session") + if authn_response is None: + raise MissingRequiredArgumentError("authn_response") + + try: + domain = await self._resolve_current_domain(store_options) + metadata = await self._get_oidc_metadata_cached(domain) + + token_endpoint = metadata.get("token_endpoint") + if not token_endpoint: + raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") + + body: dict[str, Any] = { + "grant_type": self.GRANT_TYPE_PASSKEY, + "client_id": self._client_id, + "auth_session": auth_session, + "authn_response": authn_response.model_dump(by_alias=True, exclude_none=True), + } + if self._client_secret: + body["client_secret"] = self._client_secret + if connection: + body["realm"] = connection + if organization: + body["organization"] = organization + if scope: + body["scope"] = scope + if audience: + body["audience"] = audience + + async with self._get_http_client() as client: + response = await client.post(token_endpoint, json=body) + + if response.status_code != 200: + try: + error_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "passkey_token_error", + f"Passkey token exchange failed with status {response.status_code}", + ) + raise ApiError( + error_data.get("error", "passkey_token_error"), + error_data.get("error_description", "Passkey token exchange failed"), + ) + + try: + token_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "invalid_response", "Failed to parse passkey token response as JSON" + ) + + if "expires_in" in token_data and "expires_at" not in token_data: + token_data["expires_at"] = int(time.time()) + token_data["expires_in"] + + return PasskeyTokenResponse.model_validate(token_data) + + except Exception as e: + if isinstance(e, (ApiError, MissingRequiredArgumentError, ValidationError)): + raise + raise ApiError("passkey_token_error", "Passkey sign-in failed", e) + # ============================================================================ # MFA (Multi-Factor Authentication) # ============================================================================ diff --git a/src/auth0_server_python/auth_types/__init__.py b/src/auth0_server_python/auth_types/__init__.py index 055103a..d306efa 100644 --- a/src/auth0_server_python/auth_types/__init__.py +++ b/src/auth0_server_python/auth_types/__init__.py @@ -5,7 +5,7 @@ from typing import Any, Literal, Optional, Union -from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator class UserClaims(BaseModel): @@ -13,6 +13,7 @@ class UserClaims(BaseModel): User profile information as returned by Auth0. Contains standard OIDC claims about the authenticated user. """ + sub: str name: Optional[str] = None nickname: Optional[str] = None @@ -32,6 +33,7 @@ class TokenSet(BaseModel): Represents a set of tokens issued by Auth0. Contains the access token and related metadata. """ + audience: str access_token: str scope: Optional[str] = None @@ -43,6 +45,7 @@ class ConnectionTokenSet(TokenSet): Token set specific to a connection. Extends TokenSet with connection-specific information. """ + connection: str login_hint: str @@ -52,6 +55,7 @@ class InternalStateData(BaseModel): Internal data used for managing state. Not meant to be accessed directly by SDK users. """ + sid: str created_at: int @@ -61,6 +65,7 @@ class SessionData(BaseModel): Represents a user session with Auth0. Contains user information and tokens. """ + user: Optional[UserClaims] = None id_token: Optional[str] = None refresh_token: Optional[str] = None @@ -77,6 +82,7 @@ class StateData(SessionData): Complete state data stored in the state store. Extends SessionData with internal management information. """ + internal: InternalStateData @@ -85,6 +91,7 @@ class TransactionData(BaseModel): Represents data for an in-progress authentication transaction. Used during the authorization code flow to correlate requests. """ + audience: Optional[str] = None code_verifier: str app_state: Optional[Any] = None @@ -101,6 +108,7 @@ class LogoutTokenClaims(BaseModel): Claims expected in a logout token. Used for backchannel logout processing. """ + sub: str sid: str iss: Optional[str] = None @@ -111,6 +119,7 @@ class EncryptedStoreOptions(BaseModel): Options for encrypted stores. Contains the secret used for encryption. """ + secret: str @@ -119,6 +128,7 @@ class ServerClientOptionsBase(BaseModel): Base options for configuring the Auth0 server client. Contains core settings required for all clients. """ + domain: str client_id: str client_secret: str @@ -135,6 +145,7 @@ class ServerClientOptionsWithSecret(ServerClientOptionsBase): Client options using a secret for encryption. Extends base options with secret and duration settings. """ + secret: str state_absolute_duration: Optional[int] = 259200 # 3 days in seconds @@ -144,6 +155,7 @@ class StartInteractiveLoginOptions(BaseModel): Options for starting the interactive login process. Configures how the authorization request is constructed. """ + pushed_authorization_requests: Optional[bool] = False app_state: Optional[Any] = None authorization_params: Optional[dict[str, Any]] = None @@ -154,6 +166,7 @@ class LogoutOptions(BaseModel): Options for logout operations. Configures how the logout request is constructed. """ + return_to: Optional[str] = None @@ -162,6 +175,7 @@ class AuthorizationParameters(BaseModel): Parameters used in authorization requests. Based on standard OAuth2/OIDC parameters. """ + scope: Optional[str] = None audience: Optional[str] = None redirect_uri: Optional[str] = None @@ -169,11 +183,13 @@ class AuthorizationParameters(BaseModel): class Config: extra = "allow" # Allow additional OAuth parameters + class AuthorizationDetails(BaseModel): """ Authorization details returned from Auth0. Used for Resource Access Rights (RAR). """ + type: str actions: Optional[list[str]] = None locations: Optional[list[str]] = None @@ -188,6 +204,7 @@ class LoginBackchannelOptions(BaseModel): """ Options for Client-Initiated Backchannel Authentication. """ + binding_message: str login_hint: dict[str, str] # Should contain a 'sub' field authorization_params: Optional[dict[str, Any]] = None @@ -200,6 +217,7 @@ class LoginBackchannelResult(BaseModel): """ Result from Client-Initiated Backchannel Authentication. """ + authorization_details: Optional[list[AuthorizationDetails]] = None @@ -207,19 +225,23 @@ class AccessTokenForConnectionOptions(BaseModel): """ Options for retrieving an access token for a specific connection. """ + connection: str login_hint: Optional[str] = None + class StartLinkUserOptions(BaseModel): connection: str connection_scope: Optional[str] = None authorization_params: Optional[dict[str, Any]] = None app_state: Optional[Any] = None + # ============================================================================= # Multiple Custom Domain # ============================================================================= + class DomainResolverContext(BaseModel): """ Context passed to domain resolver function for MCD support. @@ -236,13 +258,16 @@ async def domain_resolver(context: DomainResolverContext) -> str: host = context.request_headers.get('host', '').split(':')[0] return DOMAIN_MAP.get(host, DEFAULT_DOMAIN) """ + request_url: Optional[str] = None request_headers: Optional[dict[str, str]] = None + # ============================================================================= # Custom Token Exchange Types # ============================================================================= + class CustomTokenExchangeOptions(BaseModel): """ Options for custom token exchange (RFC 8693). @@ -257,6 +282,7 @@ class CustomTokenExchangeOptions(BaseModel): organization: Organization identifier for the token exchange (optional) authorization_params: Additional OAuth parameters (optional) """ + subject_token: str = Field(..., min_length=1) subject_token_type: str = Field(..., min_length=1) audience: Optional[str] = None @@ -266,7 +292,7 @@ class CustomTokenExchangeOptions(BaseModel): organization: Optional[str] = None authorization_params: Optional[dict[str, Any]] = None - @field_validator('subject_token', 'actor_token') + @field_validator("subject_token", "actor_token") @classmethod def validate_token_format(cls, v: Optional[str]) -> Optional[str]: """Validate token doesn't have Bearer prefix and isn't whitespace-only.""" @@ -277,8 +303,8 @@ def validate_token_format(cls, v: Optional[str]) -> Optional[str]: raise ValueError("Token should not include 'Bearer ' prefix") return v - @model_validator(mode='after') - def validate_actor_token_type(self) -> 'CustomTokenExchangeOptions': + @model_validator(mode="after") + def validate_actor_token_type(self) -> "CustomTokenExchangeOptions": """Ensure actor_token_type is provided if actor_token is present.""" if self.actor_token and not self.actor_token_type: raise ValueError("actor_token_type is required when actor_token is provided") @@ -298,6 +324,7 @@ class TokenExchangeResponse(BaseModel): id_token: OpenID Connect ID token (optional) refresh_token: Refresh token (optional) """ + access_token: str token_type: str = "Bearer" expires_in: int @@ -313,6 +340,7 @@ class LoginWithCustomTokenExchangeOptions(BaseModel): Combines token exchange parameters with session management. """ + subject_token: str = Field(..., min_length=1) subject_token_type: str = Field(..., min_length=1) audience: Optional[str] = None @@ -322,7 +350,7 @@ class LoginWithCustomTokenExchangeOptions(BaseModel): organization: Optional[str] = None authorization_params: Optional[dict[str, Any]] = None - @field_validator('subject_token', 'actor_token') + @field_validator("subject_token", "actor_token") @classmethod def validate_token_format(cls, v: Optional[str]) -> Optional[str]: """Validate token doesn't have Bearer prefix and isn't whitespace-only.""" @@ -333,8 +361,8 @@ def validate_token_format(cls, v: Optional[str]) -> Optional[str]: raise ValueError("Token should not include 'Bearer ' prefix") return v - @model_validator(mode='after') - def validate_actor_token_type(self) -> 'LoginWithCustomTokenExchangeOptions': + @model_validator(mode="after") + def validate_actor_token_type(self) -> "LoginWithCustomTokenExchangeOptions": """Ensure actor_token_type is provided if actor_token is present.""" if self.actor_token and not self.actor_token_type: raise ValueError("actor_token_type is required when actor_token is provided") @@ -347,13 +375,16 @@ class LoginWithCustomTokenExchangeResult(BaseModel): Contains session data established after token exchange. """ + state_data: dict[str, Any] authorization_details: Optional[list[AuthorizationDetails]] = None + # ============================================================================= # Connected Accounts Types # ============================================================================= + # BASE & SHARED class ConnectedAccountBase(BaseModel): id: str @@ -363,6 +394,7 @@ class ConnectedAccountBase(BaseModel): created_at: str expires_at: Optional[str] = None + # ENTITIES (What exists) class ConnectedAccount(ConnectedAccountBase): id: str @@ -381,6 +413,7 @@ class ConnectedAccountConnection(BaseModel): # Connect Operations (How to connect) + class ConnectAccountOptions(BaseModel): connection: str redirect_uri: Optional[str] = None @@ -388,43 +421,244 @@ class ConnectAccountOptions(BaseModel): app_state: Optional[Any] = None authorization_params: Optional[dict[str, Any]] = None + class ConnectAccountRequest(BaseModel): connection: str scopes: Optional[list[str]] = None redirect_uri: Optional[str] = None state: Optional[str] = None code_challenge: Optional[str] = None - code_challenge_method: Optional[str] = 'S256' + code_challenge_method: Optional[str] = "S256" authorization_params: Optional[dict[str, Any]] = None + class ConnectParams(BaseModel): ticket: str + class ConnectAccountResponse(BaseModel): auth_session: str connect_uri: str connect_params: ConnectParams expires_in: int + class CompleteConnectAccountRequest(BaseModel): auth_session: str connect_code: str redirect_uri: str code_verifier: Optional[str] = None + class CompleteConnectAccountResponse(ConnectedAccountBase): app_state: Optional[Any] = None + # Manage operations class ListConnectedAccountsResponse(BaseModel): accounts: list[ConnectedAccount] next: Optional[str] = None + class ListConnectedAccountConnectionsResponse(BaseModel): connections: list[ConnectedAccountConnection] next: Optional[str] = None +# ============================================================================= +# Passkey & MyAccount Authentication Methods Types +# ============================================================================= + + +class PasskeyRpInfo(BaseModel): + id: str + name: str + + +class PasskeyUserInfo(BaseModel): + model_config = ConfigDict(populate_by_name=True) + id: str + name: str + display_name: Optional[str] = Field(None, alias="displayName") + + +class PasskeyPubKeyCredParam(BaseModel): + type: str + alg: int + + +class PasskeyAuthenticatorSelection(BaseModel): + model_config = ConfigDict(populate_by_name=True) + resident_key: Optional[str] = Field(None, alias="residentKey") + user_verification: Optional[str] = Field(None, alias="userVerification") + + +class PasskeyPublicKeyOptions(BaseModel): + model_config = ConfigDict(populate_by_name=True) + challenge: str + rp: Optional[PasskeyRpInfo] = None + rp_id: Optional[str] = Field(None, alias="rpId") + user: Optional[PasskeyUserInfo] = None + pub_key_cred_params: Optional[list[PasskeyPubKeyCredParam]] = Field( + None, alias="pubKeyCredParams" + ) + authenticator_selection: Optional[PasskeyAuthenticatorSelection] = Field( + None, alias="authenticatorSelection" + ) + timeout: Optional[int] = None + user_verification: Optional[str] = Field(None, alias="userVerification") + + +class EnrollAuthenticationMethodRequest(BaseModel): + type: str + email: Optional[str] = None + phone_number: Optional[str] = None + preferred_authentication_method: Optional[str] = None + user_identity_id: Optional[str] = None + connection: Optional[str] = None + + +class EnrollmentChallengeResponse(BaseModel): + authentication_method_id: str + auth_session: str + authn_params_public_key: Optional[PasskeyPublicKeyOptions] = None + + def __repr__(self) -> str: + return ( + f"EnrollmentChallengeResponse(" + f"authentication_method_id={self.authentication_method_id!r}, " + f"auth_session=[REDACTED], " + f"authn_params_public_key={self.authn_params_public_key!r})" + ) + + +class PasskeyAuthResponse(BaseModel): + model_config = ConfigDict(populate_by_name=True) + id: str + raw_id: str = Field(alias="rawId") + type: str + authenticator_attachment: Optional[str] = Field(None, alias="authenticatorAttachment") + response: dict[str, str] + client_extension_results: Optional[dict] = Field(None, alias="clientExtensionResults") + + +class VerifyAuthenticationMethodRequest(BaseModel): + auth_session: str + authn_response: Optional[PasskeyAuthResponse] = None + otp_code: Optional[str] = None + recovery_code: Optional[str] = None + password: Optional[str] = None + + @model_validator(mode="after") + def _check_at_least_one_method(self) -> "VerifyAuthenticationMethodRequest": + has_method = ( + self.authn_response is not None + or self.otp_code is not None + or self.recovery_code is not None + or self.password is not None + ) + if not has_method: + raise ValueError( + "At least one verification method must be provided: " + "authn_response, otp_code, recovery_code, or password" + ) + return self + + +class AuthenticationMethod(BaseModel): + model_config = ConfigDict(extra="allow", populate_by_name=True) + + id: str + type: str + created_at: str + confirmed: Optional[bool] = None + usage: Optional[list[str]] = None + identity_user_id: Optional[str] = None + credential_device_type: Optional[str] = None + credential_backed_up: Optional[bool] = None + key_id: Optional[str] = None + public_key: Optional[str] = None + transports: Optional[list[str]] = None + user_agent: Optional[str] = None + user_handle: Optional[str] = None + aaguid: Optional[str] = None + relying_party_id: Optional[str] = None + phone_number: Optional[str] = None + preferred_authentication_method: Optional[str] = None + email: Optional[str] = None + name: Optional[str] = None + last_password_reset: Optional[str] = None + + +class UpdateAuthenticationMethodRequest(BaseModel): + name: Optional[str] = None + preferred_authentication_method: Optional[str] = None + + +class ListAuthenticationMethodsResponse(BaseModel): + authentication_methods: list[AuthenticationMethod] + + +class Factor(BaseModel): + model_config = ConfigDict(extra="allow") + name: str + enabled: Optional[bool] = None + trial_expired: Optional[bool] = None + + +class GetFactorsResponse(BaseModel): + factors: list[Factor] + + +class PasskeySignupChallengeResponse(BaseModel): + model_config = ConfigDict(populate_by_name=True) + auth_session: str + authn_params_public_key: PasskeyPublicKeyOptions + + def __repr__(self) -> str: + return ( + f"PasskeySignupChallengeResponse(" + f"auth_session=[REDACTED], " + f"authn_params_public_key={self.authn_params_public_key!r})" + ) + + +class PasskeyLoginChallengeResponse(BaseModel): + model_config = ConfigDict(populate_by_name=True) + auth_session: str + authn_params_public_key: PasskeyPublicKeyOptions + + def __repr__(self) -> str: + return ( + f"PasskeyLoginChallengeResponse(" + f"auth_session=[REDACTED], " + f"authn_params_public_key={self.authn_params_public_key!r})" + ) + + +class PasskeyTokenResponse(BaseModel): + model_config = ConfigDict(extra="allow", populate_by_name=True) + access_token: str + token_type: str = "Bearer" + expires_in: int + expires_at: Optional[int] = None + scope: Optional[str] = None + id_token: Optional[str] = None + refresh_token: Optional[str] = None + + def __repr__(self) -> str: + return ( + f"PasskeyTokenResponse(" + f"token_type={self.token_type!r}, " + f"expires_in={self.expires_in!r}, " + f"expires_at={self.expires_at!r}, " + f"scope={self.scope!r}, " + f"access_token=[REDACTED], " + f"id_token=[REDACTED], " + f"refresh_token=[REDACTED])" + ) + + # ============================================================================= # MFA Types # ============================================================================= @@ -437,6 +671,7 @@ class ListConnectedAccountConnectionsResponse(BaseModel): class AuthenticatorResponse(BaseModel): """Represents an MFA authenticator enrolled by a user.""" + id: str authenticator_type: AuthenticatorType active: bool @@ -450,14 +685,17 @@ class AuthenticatorResponse(BaseModel): # Enrollment Options + class EnrollOtpOptions(BaseModel): """Options for enrolling an OTP authenticator.""" + authenticator_types: list[str] mfa_token: str class EnrollOobOptions(BaseModel): """Options for enrolling an OOB authenticator (SMS, Voice, Push).""" + authenticator_types: list[str] oob_channels: list[OobChannel] phone_number: Optional[str] = None @@ -466,6 +704,7 @@ class EnrollOobOptions(BaseModel): class EnrollEmailOptions(BaseModel): """Options for enrolling an email authenticator.""" + authenticator_types: list[str] oob_channels: list[OobChannel] email: Optional[str] = None @@ -477,8 +716,10 @@ class EnrollEmailOptions(BaseModel): # Enrollment Responses + class OtpEnrollmentResponse(BaseModel): """Response when enrolling an OTP authenticator.""" + authenticator_type: Literal["otp"] secret: str barcode_uri: str @@ -488,6 +729,7 @@ class OtpEnrollmentResponse(BaseModel): class OobEnrollmentResponse(BaseModel): """Response when enrolling an OOB authenticator.""" + authenticator_type: Literal["oob"] oob_channel: OobChannel oob_code: Optional[str] = None @@ -502,8 +744,10 @@ class OobEnrollmentResponse(BaseModel): # Challenge Types + class ChallengeOptions(BaseModel): """Options for initiating an MFA challenge.""" + challenge_type: ChallengeType authenticator_id: Optional[str] = None mfa_token: str @@ -511,6 +755,7 @@ class ChallengeOptions(BaseModel): class ChallengeResponse(BaseModel): """Response from initiating an MFA challenge.""" + challenge_type: ChallengeType oob_code: Optional[str] = None binding_method: Optional[str] = None @@ -519,21 +764,26 @@ class ChallengeResponse(BaseModel): # List Options + class ListAuthenticatorsOptions(BaseModel): """Options for listing MFA authenticators.""" + mfa_token: str # Verify Types + class VerifyOtpOptions(BaseModel): """Verify with OTP code.""" + mfa_token: str otp: str class VerifyOobOptions(BaseModel): """Verify with OOB code + binding code.""" + mfa_token: str oob_code: str binding_code: str @@ -541,6 +791,7 @@ class VerifyOobOptions(BaseModel): class VerifyRecoveryCodeOptions(BaseModel): """Verify with recovery code.""" + mfa_token: str recovery_code: str @@ -550,6 +801,7 @@ class VerifyRecoveryCodeOptions(BaseModel): class MfaVerifyResponse(BaseModel): """Response from MFA verification.""" + access_token: str token_type: str = "Bearer" expires_in: int @@ -562,24 +814,28 @@ class MfaVerifyResponse(BaseModel): # MFA Requirements + class MfaRequirement(BaseModel): """A single MFA requirement entry.""" + type: str class MfaRequirements(BaseModel): """MFA requirements from an mfa_required error response.""" + enroll: Optional[list[MfaRequirement]] = None challenge: Optional[list[MfaRequirement]] = None # MFA Token Context (for encrypted storage) + class MfaTokenContext(BaseModel): """Internal context stored inside encrypted mfa_token.""" + mfa_token: str audience: str scope: str mfa_requirements: Optional[MfaRequirements] = None created_at: int - diff --git a/src/auth0_server_python/tests/test_dpop_auth.py b/src/auth0_server_python/tests/test_dpop_auth.py new file mode 100644 index 0000000..b6beb69 --- /dev/null +++ b/src/auth0_server_python/tests/test_dpop_auth.py @@ -0,0 +1,145 @@ +import base64 +import hashlib +import json + +import httpx +import pytest +from jwcrypto import jwk + +from auth0_server_python.auth_schemes.bearer_auth import BearerAuth +from auth0_server_python.auth_schemes.dpop_auth import DPoPAuth, _base64url +from auth0_server_python.auth_server.my_account_client import _make_auth + + +@pytest.fixture +def ec_key(): + return jwk.JWK.generate(kty="EC", crv="P-256") + + +def _decode_jwt_parts(token: str) -> tuple[dict, dict]: + parts = token.split(".") + header = json.loads(base64.urlsafe_b64decode(parts[0] + "==")) + payload = json.loads(base64.urlsafe_b64decode(parts[1] + "==")) + return header, payload + + +def test_dpop_headers_set(ec_key): + auth = DPoPAuth(token="test_token", key=ec_key) + request = httpx.Request("POST", "https://example.com/me/v1/authentication-methods") + flow = auth.auth_flow(request) + modified = next(flow) + + assert modified.headers["Authorization"] == "DPoP test_token" + assert "DPoP" in modified.headers + assert "Bearer" not in modified.headers["Authorization"] + + +def test_dpop_proof_structure(ec_key): + auth = DPoPAuth(token="test_token", key=ec_key) + request = httpx.Request("POST", "https://example.com/me/v1/authentication-methods") + flow = auth.auth_flow(request) + modified = next(flow) + + proof = modified.headers["DPoP"] + header, payload = _decode_jwt_parts(proof) + + assert header["typ"] == "dpop+jwt" + assert header["alg"] == "ES256" + assert "jwk" in header + assert header["jwk"]["kty"] == "EC" + assert header["jwk"]["crv"] == "P-256" + + assert "jti" in payload + assert payload["htm"] == "POST" + assert payload["htu"] == "https://example.com/me/v1/authentication-methods" + assert "iat" in payload + assert "ath" in payload + + +def test_dpop_htm_binding(ec_key): + auth = DPoPAuth(token="test_token", key=ec_key) + + get_request = httpx.Request("GET", "https://example.com/me/v1/factors") + flow = auth.auth_flow(get_request) + modified = next(flow) + _, payload = _decode_jwt_parts(modified.headers["DPoP"]) + assert payload["htm"] == "GET" + + post_request = httpx.Request("post", "https://example.com/me/v1/factors") + flow = auth.auth_flow(post_request) + modified = next(flow) + _, payload = _decode_jwt_parts(modified.headers["DPoP"]) + assert payload["htm"] == "POST" + + +def test_dpop_htu_strips_query_and_fragment(ec_key): + auth = DPoPAuth(token="test_token", key=ec_key) + request = httpx.Request("GET", "https://example.com/me/v1/factors?foo=bar#section") + flow = auth.auth_flow(request) + modified = next(flow) + _, payload = _decode_jwt_parts(modified.headers["DPoP"]) + assert payload["htu"] == "https://example.com/me/v1/factors" + + +def test_dpop_htu_preserves_port(ec_key): + auth = DPoPAuth(token="test_token", key=ec_key) + request = httpx.Request("GET", "https://example.com:8443/me/v1/factors") + flow = auth.auth_flow(request) + modified = next(flow) + _, payload = _decode_jwt_parts(modified.headers["DPoP"]) + assert payload["htu"] == "https://example.com:8443/me/v1/factors" + + +def test_dpop_ath_binding(ec_key): + token = "my_access_token_value" + auth = DPoPAuth(token=token, key=ec_key) + request = httpx.Request("GET", "https://example.com/me/v1/factors") + flow = auth.auth_flow(request) + modified = next(flow) + _, payload = _decode_jwt_parts(modified.headers["DPoP"]) + + expected_ath = _base64url(hashlib.sha256(token.encode("ascii")).digest()) + assert payload["ath"] == expected_ath + + +def test_dpop_proof_uniqueness(ec_key): + auth = DPoPAuth(token="test_token", key=ec_key) + jtis = set() + for _ in range(10): + request = httpx.Request("GET", "https://example.com/me/v1/factors") + flow = auth.auth_flow(request) + modified = next(flow) + _, payload = _decode_jwt_parts(modified.headers["DPoP"]) + jtis.add(payload["jti"]) + + assert len(jtis) == 10 + + +def test_dpop_repr_redacts_credentials(ec_key): + auth = DPoPAuth(token="secret_access_token_value", key=ec_key) + assert "secret_access_token_value" not in repr(auth) + assert "secret_access_token_value" not in str(auth) + assert "[REDACTED]" in repr(auth) + assert "[REDACTED]" in str(auth) + + +def test_dpop_rejects_non_ec_key(): + rsa_key = jwk.JWK.generate(kty="RSA", size=2048) + with pytest.raises(ValueError, match="EC P-256"): + DPoPAuth(token="token", key=rsa_key) + + +def test_dpop_rejects_wrong_curve(): + p384_key = jwk.JWK.generate(kty="EC", crv="P-384") + with pytest.raises(ValueError, match="EC P-256"): + DPoPAuth(token="token", key=p384_key) + + +def test_make_auth_bearer_fallback(): + auth = _make_auth("token123", dpop_key=None) + assert isinstance(auth, BearerAuth) + + +def test_make_auth_dpop_when_key_provided(ec_key): + auth = _make_auth("token123", dpop_key=ec_key) + assert isinstance(auth, DPoPAuth) diff --git a/src/auth0_server_python/tests/test_passkey_my_account.py b/src/auth0_server_python/tests/test_passkey_my_account.py new file mode 100644 index 0000000..d7f181d --- /dev/null +++ b/src/auth0_server_python/tests/test_passkey_my_account.py @@ -0,0 +1,473 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest +from jwcrypto import jwk as jwk_module + +from auth0_server_python.auth_schemes.dpop_auth import DPoPAuth +from auth0_server_python.auth_server.my_account_client import MyAccountClient +from auth0_server_python.auth_types import ( + AuthenticationMethod, + EnrollAuthenticationMethodRequest, + EnrollmentChallengeResponse, + GetFactorsResponse, + ListAuthenticationMethodsResponse, + PasskeyAuthResponse, + UpdateAuthenticationMethodRequest, + VerifyAuthenticationMethodRequest, +) +from auth0_server_python.error import ApiError, MissingRequiredArgumentError, MyAccountApiError + + +@pytest.mark.asyncio +async def test_get_factors_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"factors": [{"name": "sms", "enabled": True}]}) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.get_factors(access_token="token123") + + assert isinstance(result, GetFactorsResponse) + assert len(result.factors) == 1 + assert result.factors[0].name == "sms" + assert result.factors[0].enabled is True + + +@pytest.mark.asyncio +@pytest.mark.parametrize("access_token", [None, ""]) +async def test_get_factors_missing_access_token(mocker, access_token): + client = MyAccountClient(domain="auth0.local") + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock) + + with pytest.raises(MissingRequiredArgumentError): + await client.get_factors(access_token=access_token) + + mock_get.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_get_factors_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 403 + response.json = MagicMock( + return_value={ + "title": "Forbidden", + "type": "forbidden", + "detail": "Insufficient scope", + "status": 403, + } + ) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + with pytest.raises(MyAccountApiError) as exc: + await client.get_factors(access_token="token123") + + assert exc.value.status == 403 + + +@pytest.mark.asyncio +async def test_get_factors_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=Exception("Connection refused") + ) + + with pytest.raises(ApiError): + await client.get_factors(access_token="token123") + + +@pytest.mark.asyncio +async def test_get_factors_empty_list(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"factors": []}) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.get_factors(access_token="token123") + assert result.factors == [] + + +@pytest.mark.asyncio +async def test_get_factors_extra_fields(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock( + return_value={ + "factors": [{"name": "webauthn-roaming", "enabled": True, "future_field": "value"}] + } + ) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.get_factors(access_token="token123") + assert result.factors[0].name == "webauthn-roaming" + + +@pytest.mark.asyncio +async def test_list_authentication_methods_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock( + return_value={ + "authentication_methods": [ + { + "id": "am_1", + "type": "passkey", + "created_at": "2026-01-01T00:00:00Z", + "key_id": "kid1", + } + ] + } + ) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.list_authentication_methods(access_token="token123") + assert isinstance(result, ListAuthenticationMethodsResponse) + assert len(result.authentication_methods) == 1 + assert result.authentication_methods[0].type == "passkey" + + +@pytest.mark.asyncio +async def test_list_authentication_methods_with_type_filter(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"authentication_methods": []}) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + await client.list_authentication_methods(access_token="token123", type_filter="passkey") + mock_get.assert_awaited_once() + call_kwargs = mock_get.call_args[1] + assert call_kwargs["params"] == {"type": "passkey"} + + +@pytest.mark.asyncio +async def test_list_authentication_methods_empty(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"authentication_methods": []}) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.list_authentication_methods(access_token="token123") + assert result.authentication_methods == [] + + +@pytest.mark.asyncio +async def test_get_authentication_method_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock( + return_value={"id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} + ) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.get_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + assert isinstance(result, AuthenticationMethod) + assert result.id == "am_1" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("method_id", [None, ""]) +async def test_get_authentication_method_missing_id(mocker, method_id): + client = MyAccountClient(domain="auth0.local") + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock) + + with pytest.raises(MissingRequiredArgumentError): + await client.get_authentication_method( + access_token="token123", authentication_method_id=method_id + ) + + mock_get.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_get_authentication_method_path_traversal(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock( + return_value={"id": "id/slash", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} + ) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + await client.get_authentication_method( + access_token="token123", authentication_method_id="id/slash" + ) + call_url = mock_get.call_args[1]["url"] + assert "id%2Fslash" in call_url + assert "id/slash" not in call_url.replace("https://auth0.local/me/", "") + + +@pytest.mark.asyncio +async def test_get_authentication_method_pipe_encoding(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock( + return_value={"id": "passkey|new", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} + ) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + await client.get_authentication_method( + access_token="token123", authentication_method_id="passkey|new" + ) + call_url = mock_get.call_args[1]["url"] + assert "passkey%7Cnew" in call_url + + +@pytest.mark.asyncio +async def test_delete_authentication_method_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 204 + mocker.patch("httpx.AsyncClient.delete", new_callable=AsyncMock, return_value=response) + + result = await client.delete_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + assert result is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("method_id", [None, ""]) +async def test_delete_authentication_method_missing_id(mocker, method_id): + client = MyAccountClient(domain="auth0.local") + mock_delete = mocker.patch("httpx.AsyncClient.delete", new_callable=AsyncMock) + + with pytest.raises(MissingRequiredArgumentError): + await client.delete_authentication_method( + access_token="token123", authentication_method_id=method_id + ) + + mock_delete.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_update_authentication_method_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock( + return_value={ + "id": "am_1", + "type": "passkey", + "created_at": "2026-01-01T00:00:00Z", + "name": "My Key", + } + ) + mock_patch = mocker.patch( + "httpx.AsyncClient.patch", new_callable=AsyncMock, return_value=response + ) + + req = UpdateAuthenticationMethodRequest(name="My Key") + result = await client.update_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req + ) + assert result.name == "My Key" + call_kwargs = mock_patch.call_args[1] + assert call_kwargs["json"] == {"name": "My Key"} + + +@pytest.mark.asyncio +async def test_update_authentication_method_missing_request(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch("httpx.AsyncClient.patch", new_callable=AsyncMock) + + with pytest.raises(MissingRequiredArgumentError): + await client.update_authentication_method( + access_token="token123", authentication_method_id="am_1", request=None + ) + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 201 + response.headers = {"location": "/me/v1/authentication-methods/passkey|new"} + response.json = MagicMock( + return_value={ + "auth_session": "session_abc", + "authn_params_public_key": { + "challenge": "dGVzdA", + "rp": {"id": "auth0.local", "name": "My App"}, + "user": {"id": "dXNlcl8x", "name": "user@test.com", "displayName": "Test User"}, + "pubKeyCredParams": [{"type": "public-key", "alg": -7}], + "authenticatorSelection": { + "residentKey": "required", + "userVerification": "preferred", + }, + "timeout": 60000, + }, + } + ) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + result = await client.enroll_authentication_method(access_token="token123", request=req) + + assert isinstance(result, EnrollmentChallengeResponse) + assert result.authentication_method_id == "passkey|new" + assert result.auth_session == "session_abc" + assert result.authn_params_public_key is not None + assert result.authn_params_public_key.pub_key_cred_params[0].alg == -7 + assert result.authn_params_public_key.authenticator_selection.resident_key == "required" + assert result.authn_params_public_key.user.display_name == "Test User" + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_missing_location(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 201 + response.headers = {} + response.json = MagicMock(return_value={"auth_session": "session_abc"}) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + with pytest.raises(ApiError) as exc: + await client.enroll_authentication_method(access_token="token123", request=req) + + assert "Location header" in str(exc.value) + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_location_with_query(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 201 + response.headers = {"location": "/me/v1/authentication-methods/abc123?tracking=1"} + response.json = MagicMock(return_value={"auth_session": "session_abc"}) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + result = await client.enroll_authentication_method(access_token="token123", request=req) + assert result.authentication_method_id == "abc123" + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_location_absolute_url(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 201 + response.headers = {"location": "https://tenant.auth0.com/me/v1/authentication-methods/am_xyz"} + response.json = MagicMock(return_value={"auth_session": "session_abc"}) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + result = await client.enroll_authentication_method(access_token="token123", request=req) + assert result.authentication_method_id == "am_xyz" + + +@pytest.mark.asyncio +async def test_verify_authentication_method_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock( + return_value={ + "id": "am_1", + "type": "passkey", + "created_at": "2026-01-01T00:00:00Z", + "confirmed": True, + } + ) + mock_post = mocker.patch( + "httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response + ) + + authn_response = PasskeyAuthResponse( + id="cred1", + raw_id="cmF3MQ", + type="public-key", + authenticator_attachment="platform", + response={"clientDataJSON": "abc", "attestationObject": "def"}, + ) + req = VerifyAuthenticationMethodRequest( + auth_session="session_abc", authn_response=authn_response + ) + result = await client.verify_authentication_method( + access_token="token123", authentication_method_id="passkey|new", request=req + ) + + assert isinstance(result, AuthenticationMethod) + assert result.confirmed is True + + call_kwargs = mock_post.call_args[1] + body = call_kwargs["json"] + assert "rawId" in body["authn_response"] + assert "raw_id" not in body["authn_response"] + assert "authenticatorAttachment" in body["authn_response"] + assert body["auth_session"] == "session_abc" + assert "passkey%7Cnew" in call_kwargs["url"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("method_id", [None, ""]) +async def test_verify_authentication_method_missing_id(mocker, method_id): + client = MyAccountClient(domain="auth0.local") + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") + with pytest.raises(MissingRequiredArgumentError): + await client.verify_authentication_method( + access_token="token123", authentication_method_id=method_id, request=req + ) + + +@pytest.mark.asyncio +async def test_enrollment_challenge_response_repr(): + resp = EnrollmentChallengeResponse( + authentication_method_id="am_1", + auth_session="super_secret_session", + authn_params_public_key=None, + ) + repr_str = repr(resp) + assert "super_secret_session" not in repr_str + assert "[REDACTED]" in repr_str + assert "am_1" in repr_str + + +def test_verify_request_requires_at_least_one_method(): + with pytest.raises(Exception, match="At least one verification method"): + VerifyAuthenticationMethodRequest(auth_session="session_abc") + + +def test_verify_request_accepts_otp_code(): + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") + assert req.otp_code == "123456" + + +def test_verify_request_accepts_authn_response(): + authn_resp = PasskeyAuthResponse( + id="cred1", + raw_id="cmF3MQ", + type="public-key", + response={"clientDataJSON": "abc", "attestationObject": "def"}, + ) + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", authn_response=authn_resp) + assert req.authn_response is not None + + +@pytest.mark.asyncio +async def test_get_factors_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"factors": [{"name": "sms", "enabled": True}]}) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + await client.get_factors(access_token="token123", dpop_key=dpop_key) + + mock_get.assert_awaited_once() + call_kwargs = mock_get.call_args[1] + assert isinstance(call_kwargs["auth"], DPoPAuth) diff --git a/src/auth0_server_python/tests/test_passkey_server_client.py b/src/auth0_server_python/tests/test_passkey_server_client.py new file mode 100644 index 0000000..8d39410 --- /dev/null +++ b/src/auth0_server_python/tests/test_passkey_server_client.py @@ -0,0 +1,523 @@ +import time +from unittest.mock import AsyncMock + +import httpx +import pytest + +from auth0_server_python.auth_server.server_client import ServerClient +from auth0_server_python.auth_types import ( + PasskeyAuthResponse, + PasskeyLoginChallengeResponse, + PasskeySignupChallengeResponse, + PasskeyTokenResponse, +) +from auth0_server_python.error import ApiError, MissingRequiredArgumentError + + +@pytest.fixture +def server_client(): + return ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + + +SIGNUP_CHALLENGE_RESPONSE = { + "auth_session": "session_abc123", + "authn_params_public_key": { + "challenge": "dGVzdC1jaGFsbGVuZ2U", + "rp": {"id": "auth0.local", "name": "Test App"}, + "user": {"id": "dXNlcl8x", "name": "user@example.com", "displayName": "Jane"}, + "pubKeyCredParams": [{"type": "public-key", "alg": -7}], + "authenticatorSelection": { + "residentKey": "required", + "userVerification": "preferred", + }, + "timeout": 60000, + }, +} + +LOGIN_CHALLENGE_RESPONSE = { + "auth_session": "session_login_xyz", + "authn_params_public_key": { + "challenge": "bG9naW4tY2hhbGxlbmdl", + "rpId": "auth0.local", + "timeout": 60000, + "userVerification": "preferred", + }, +} + +TOKEN_RESPONSE = { + "access_token": "at_passkey_123", + "id_token": "eyJ.test.jwt", + "token_type": "Bearer", + "expires_in": 86400, + "scope": "openid profile", +} + + +def _mock_response(status_code=200, json_data=None, headers=None): + resp = httpx.Response( + status_code=status_code, + json=json_data, + headers=headers or {}, + request=httpx.Request("POST", "https://auth0.local/passkey/register"), + ) + return resp + + +# ============================================================================= +# passkey_signup_challenge +# ============================================================================= + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_success(server_client, mocker): + mock_response = _mock_response(200, SIGNUP_CHALLENGE_RESPONSE) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + + result = await server_client.passkey_signup_challenge( + email="user@example.com", + name="Jane Doe", + connection="Username-Password-Authentication", + ) + + assert isinstance(result, PasskeySignupChallengeResponse) + assert result.auth_session == "session_abc123" + assert result.authn_params_public_key.challenge == "dGVzdC1jaGFsbGVuZ2U" + assert result.authn_params_public_key.rp.id == "auth0.local" + assert result.authn_params_public_key.user.display_name == "Jane" + assert result.authn_params_public_key.pub_key_cred_params[0].alg == -7 + assert result.authn_params_public_key.authenticator_selection.resident_key == "required" + + call_args = mock_client.post.call_args + assert "/passkey/register" in call_args.args[0] + body = call_args.kwargs["json"] + assert body["client_id"] == "test_client_id" + assert body["client_secret"] == "test_client_secret" + assert body["user_profile"]["email"] == "user@example.com" + assert body["user_profile"]["name"] == "Jane Doe" + assert body["realm"] == "Username-Password-Authentication" + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_user_profile_fields(server_client, mocker): + mock_response = _mock_response(200, SIGNUP_CHALLENGE_RESPONSE) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + + await server_client.passkey_signup_challenge( + email="u@e.com", + username="jdoe", + phone_number="+1234567890", + given_name="Jane", + family_name="Doe", + nickname="jd", + picture="https://example.com/pic.jpg", + user_metadata={"role": "admin"}, + organization="org_123", + ) + + body = mock_client.post.call_args.kwargs["json"] + assert body["user_profile"]["email"] == "u@e.com" + assert body["user_profile"]["username"] == "jdoe" + assert body["user_profile"]["phone_number"] == "+1234567890" + assert body["user_profile"]["given_name"] == "Jane" + assert body["user_profile"]["family_name"] == "Doe" + assert body["user_profile"]["nickname"] == "jd" + assert body["user_profile"]["picture"] == "https://example.com/pic.jpg" + assert body["user_profile"]["user_metadata"] == {"role": "admin"} + assert body["organization"] == "org_123" + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_minimal_body(server_client, mocker): + mock_response = _mock_response(200, SIGNUP_CHALLENGE_RESPONSE) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + + await server_client.passkey_signup_challenge() + + body = mock_client.post.call_args.kwargs["json"] + assert body == {"client_id": "test_client_id", "client_secret": "test_client_secret"} + assert "user_profile" not in body + assert "realm" not in body + assert "organization" not in body + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_api_error(server_client, mocker): + error_resp = _mock_response( + 403, + {"error": "access_denied", "error_description": "Passkey not enabled"}, + ) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=error_resp) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + + with pytest.raises(ApiError) as exc: + await server_client.passkey_signup_challenge(email="test@example.com") + assert "access_denied" in str(exc.value) or "Passkey not enabled" in str(exc.value) + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_non_json_error(server_client, mocker): + resp = httpx.Response( + status_code=502, + content=b"Bad Gateway", + headers={"content-type": "text/html"}, + request=httpx.Request("POST", "https://auth0.local/passkey/register"), + ) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=resp) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + + with pytest.raises(ApiError) as exc: + await server_client.passkey_signup_challenge() + assert "502" in str(exc.value) or "passkey_challenge_error" in str(exc.value) + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_network_error(server_client, mocker): + mock_client = AsyncMock() + mock_client.post = AsyncMock(side_effect=Exception("Connection refused")) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + + with pytest.raises(ApiError) as exc: + await server_client.passkey_signup_challenge() + assert "Passkey signup challenge failed" in str(exc.value) + + +# ============================================================================= +# passkey_login_challenge +# ============================================================================= + + +@pytest.mark.asyncio +async def test_passkey_login_challenge_success(server_client, mocker): + mock_response = _mock_response(200, LOGIN_CHALLENGE_RESPONSE) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + + result = await server_client.passkey_login_challenge( + connection="Username-Password-Authentication", + organization="org_abc", + ) + + assert isinstance(result, PasskeyLoginChallengeResponse) + assert result.auth_session == "session_login_xyz" + assert result.authn_params_public_key.challenge == "bG9naW4tY2hhbGxlbmdl" + assert result.authn_params_public_key.rp_id == "auth0.local" + assert result.authn_params_public_key.user_verification == "preferred" + + body = mock_client.post.call_args.kwargs["json"] + assert body["client_id"] == "test_client_id" + assert body["realm"] == "Username-Password-Authentication" + assert body["organization"] == "org_abc" + + +@pytest.mark.asyncio +async def test_passkey_login_challenge_with_username(server_client, mocker): + mock_response = _mock_response(200, LOGIN_CHALLENGE_RESPONSE) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + + await server_client.passkey_login_challenge(username="jane@example.com") + + body = mock_client.post.call_args.kwargs["json"] + assert body["username"] == "jane@example.com" + + +@pytest.mark.asyncio +async def test_passkey_login_challenge_api_error(server_client, mocker): + error_resp = _mock_response( + 400, + {"error": "invalid_request", "error_description": "Missing client_id"}, + ) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=error_resp) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + + with pytest.raises(ApiError): + await server_client.passkey_login_challenge() + + +@pytest.mark.asyncio +async def test_passkey_login_challenge_network_error(server_client, mocker): + mock_client = AsyncMock() + mock_client.post = AsyncMock(side_effect=Exception("timeout")) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + + with pytest.raises(ApiError): + await server_client.passkey_login_challenge() + + +# ============================================================================= +# signin_with_passkey +# ============================================================================= + + +@pytest.fixture +def authn_response(): + return PasskeyAuthResponse( + id="cred_abc123", + raw_id="Y3JlZF9hYmMxMjM", + type="public-key", + authenticator_attachment="platform", + response={ + "clientDataJSON": "eyJ0eXBlIjoid2ViYXV0aG4uZ2V0In0", + "authenticatorData": "SZYN5YgOjGh0NBcPZHZgW4_krrmihjLHmVzzuoMdl2M", + "signature": "MEUCIQC", + "userHandle": "dXNlcl8x", + }, + ) + + +@pytest.mark.asyncio +async def test_signin_with_passkey_success(server_client, authn_response, mocker): + mock_response = _mock_response(200, TOKEN_RESPONSE) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + mocker.patch.object( + server_client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + + result = await server_client.signin_with_passkey( + auth_session="session_xyz", + authn_response=authn_response, + scope="openid profile", + audience="https://api.example.com", + connection="Username-Password-Authentication", + organization="org_abc", + ) + + assert isinstance(result, PasskeyTokenResponse) + assert result.access_token == "at_passkey_123" + assert result.token_type == "Bearer" + assert abs(result.expires_at - (int(time.time()) + 86400)) <= 2 + + body = mock_client.post.call_args.kwargs["json"] + assert body["grant_type"] == "urn:okta:params:oauth:grant-type:webauthn" + assert body["client_id"] == "test_client_id" + assert body["client_secret"] == "test_client_secret" + assert body["auth_session"] == "session_xyz" + assert body["scope"] == "openid profile" + assert body["audience"] == "https://api.example.com" + assert body["realm"] == "Username-Password-Authentication" + assert body["organization"] == "org_abc" + assert body["authn_response"]["rawId"] == "Y3JlZF9hYmMxMjM" + assert body["authn_response"]["authenticatorAttachment"] == "platform" + assert "raw_id" not in body["authn_response"] + + +@pytest.mark.asyncio +async def test_signin_with_passkey_uses_json_content_type(server_client, authn_response, mocker): + mock_response = _mock_response(200, TOKEN_RESPONSE) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + mocker.patch.object( + server_client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + + await server_client.signin_with_passkey( + auth_session="s", + authn_response=authn_response, + ) + + call_kwargs = mock_client.post.call_args.kwargs + assert "json" in call_kwargs + assert "data" not in call_kwargs + + +@pytest.mark.asyncio +@pytest.mark.parametrize("auth_session", [None, ""]) +async def test_signin_with_passkey_missing_auth_session( + server_client, authn_response, auth_session +): + with pytest.raises(MissingRequiredArgumentError): + await server_client.signin_with_passkey( + auth_session=auth_session, + authn_response=authn_response, + ) + + +@pytest.mark.asyncio +async def test_signin_with_passkey_missing_authn_response(server_client): + with pytest.raises(MissingRequiredArgumentError): + await server_client.signin_with_passkey( + auth_session="session_abc", + authn_response=None, + ) + + +@pytest.mark.asyncio +async def test_signin_with_passkey_api_error(server_client, authn_response, mocker): + error_resp = _mock_response( + 401, + {"error": "invalid_grant", "error_description": "Invalid auth_session"}, + ) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=error_resp) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + mocker.patch.object( + server_client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + + with pytest.raises(ApiError) as exc: + await server_client.signin_with_passkey( + auth_session="expired_session", + authn_response=authn_response, + ) + assert "invalid_grant" in str(exc.value) or "Invalid auth_session" in str(exc.value) + + +@pytest.mark.asyncio +async def test_signin_with_passkey_missing_token_endpoint(server_client, authn_response, mocker): + mocker.patch.object( + server_client, + "_get_oidc_metadata_cached", + return_value={}, + ) + + with pytest.raises(ApiError) as exc: + await server_client.signin_with_passkey( + auth_session="session", + authn_response=authn_response, + ) + assert "token endpoint" in str(exc.value).lower() + + +@pytest.mark.asyncio +async def test_signin_with_passkey_network_error(server_client, authn_response, mocker): + mock_client = AsyncMock() + mock_client.post = AsyncMock(side_effect=Exception("Connection reset")) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + mocker.patch.object( + server_client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + + with pytest.raises(ApiError): + await server_client.signin_with_passkey( + auth_session="session", + authn_response=authn_response, + ) + + +@pytest.mark.asyncio +async def test_signin_with_passkey_no_client_secret(mocker): + client = ServerClient( + domain="auth0.local", + client_id="public_client", + client_secret=None, + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret", + ) + + mock_response = _mock_response(200, TOKEN_RESPONSE) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(client, "_get_http_client", return_value=mock_client) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + + authn_resp = PasskeyAuthResponse( + id="cred", + raw_id="cmF3", + type="public-key", + response={"clientDataJSON": "abc", "authenticatorData": "def", "signature": "ghi"}, + ) + + await client.signin_with_passkey( + auth_session="session", + authn_response=authn_resp, + ) + + body = mock_client.post.call_args.kwargs["json"] + assert "client_secret" not in body + assert body["client_id"] == "public_client" + + +@pytest.mark.asyncio +async def test_signup_challenge_repr_redacts_auth_session(): + resp = PasskeySignupChallengeResponse.model_validate(SIGNUP_CHALLENGE_RESPONSE) + repr_str = repr(resp) + assert "session_abc123" not in repr_str + assert "[REDACTED]" in repr_str + + +@pytest.mark.asyncio +async def test_login_challenge_repr_redacts_auth_session(): + resp = PasskeyLoginChallengeResponse.model_validate(LOGIN_CHALLENGE_RESPONSE) + repr_str = repr(resp) + assert "session_login_xyz" not in repr_str + assert "[REDACTED]" in repr_str + + +def test_passkey_token_response_repr_redacts_tokens(): + resp = PasskeyTokenResponse( + access_token="secret_at_value", + token_type="Bearer", + expires_in=86400, + id_token="secret_id_token", + refresh_token="secret_rt_value", + ) + repr_str = repr(resp) + assert "secret_at_value" not in repr_str + assert "secret_id_token" not in repr_str + assert "secret_rt_value" not in repr_str + assert "[REDACTED]" in repr_str + assert "86400" in repr_str From ec66c0f30db6d51f3319fb9d03fc55a109e80fa7 Mon Sep 17 00:00:00 2001 From: Sourav Basu Date: Tue, 2 Jun 2026 14:12:56 +0530 Subject: [PATCH 02/12] Added missing test cases and edge case fix --- .../auth_server/my_account_client.py | 12 +- .../tests/test_passkey_my_account.py | 357 ++++++++++++++++++ .../tests/test_passkey_server_client.py | 62 +++ 3 files changed, 427 insertions(+), 4 deletions(-) diff --git a/src/auth0_server_python/auth_server/my_account_client.py b/src/auth0_server_python/auth_server/my_account_client.py index a6aed8f..9089186 100644 --- a/src/auth0_server_python/auth_server/my_account_client.py +++ b/src/auth0_server_python/auth_server/my_account_client.py @@ -643,10 +643,14 @@ async def enroll_authentication_method( "Enrollment succeeded (201) but Location header is missing", ) - authentication_method_id = ( - location.split("?")[0].split("#")[0].rstrip("/").split("/")[-1] - ) - if not authentication_method_id: + path = location.split("?")[0].split("#")[0].rstrip("/") + segments = path.split("/") + authentication_method_id = segments[-1] if len(segments) > 1 else "" + if not authentication_method_id or authentication_method_id in ( + "authentication-methods", + "v1", + "me", + ): raise ApiError( "enroll_authentication_method_error", "Enrollment succeeded (201) but could not extract ID from Location header", diff --git a/src/auth0_server_python/tests/test_passkey_my_account.py b/src/auth0_server_python/tests/test_passkey_my_account.py index d7f181d..4b7f29d 100644 --- a/src/auth0_server_python/tests/test_passkey_my_account.py +++ b/src/auth0_server_python/tests/test_passkey_my_account.py @@ -471,3 +471,360 @@ async def test_get_factors_with_dpop_key(mocker): mock_get.assert_awaited_once() call_kwargs = mock_get.call_args[1] assert isinstance(call_kwargs["auth"], DPoPAuth) + + +# ============================================================================= +# DPoP integration(mock) tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_list_authentication_methods_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"authentication_methods": []}) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + await client.list_authentication_methods(access_token="token123", dpop_key=dpop_key) + + assert isinstance(mock_get.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_get_authentication_method_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock( + return_value={"id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} + ) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + await client.get_authentication_method( + access_token="token123", authentication_method_id="am_1", dpop_key=dpop_key + ) + + assert isinstance(mock_get.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_delete_authentication_method_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 204 + mock_delete = mocker.patch( + "httpx.AsyncClient.delete", new_callable=AsyncMock, return_value=response + ) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + await client.delete_authentication_method( + access_token="token123", authentication_method_id="am_1", dpop_key=dpop_key + ) + + assert isinstance(mock_delete.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_update_authentication_method_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock( + return_value={"id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} + ) + mock_patch = mocker.patch( + "httpx.AsyncClient.patch", new_callable=AsyncMock, return_value=response + ) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + req = UpdateAuthenticationMethodRequest(name="New Name") + await client.update_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req, dpop_key=dpop_key + ) + + assert isinstance(mock_patch.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 201 + response.headers = {"location": "/me/v1/authentication-methods/passkey|new"} + response.json = MagicMock(return_value={"auth_session": "session_abc"}) + mock_post = mocker.patch( + "httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response + ) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + req = EnrollAuthenticationMethodRequest(type="passkey") + await client.enroll_authentication_method( + access_token="token123", request=req, dpop_key=dpop_key + ) + + assert isinstance(mock_post.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_verify_authentication_method_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock( + return_value={"id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} + ) + mock_post = mocker.patch( + "httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response + ) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") + await client.verify_authentication_method( + access_token="token123", + authentication_method_id="am_1", + request=req, + dpop_key=dpop_key, + ) + + assert isinstance(mock_post.call_args[1]["auth"], DPoPAuth) + + +# ============================================================================= +# API error and network error tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_list_authentication_methods_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 403 + response.json = MagicMock( + return_value={ + "title": "Forbidden", + "type": "forbidden", + "detail": "Insufficient scope", + "status": 403, + } + ) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + with pytest.raises(MyAccountApiError) as exc: + await client.list_authentication_methods(access_token="token123") + assert exc.value.status == 403 + + +@pytest.mark.asyncio +async def test_list_authentication_methods_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=Exception("Connection refused") + ) + + with pytest.raises(ApiError): + await client.list_authentication_methods(access_token="token123") + + +@pytest.mark.asyncio +async def test_get_authentication_method_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 404 + response.json = MagicMock( + return_value={ + "title": "Not Found", + "type": "not_found", + "detail": "Not found", + "status": 404, + } + ) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + with pytest.raises(MyAccountApiError) as exc: + await client.get_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + assert exc.value.status == 404 + + +@pytest.mark.asyncio +async def test_get_authentication_method_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=Exception("timeout")) + + with pytest.raises(ApiError): + await client.get_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + + +@pytest.mark.asyncio +async def test_delete_authentication_method_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 404 + response.json = MagicMock( + return_value={ + "title": "Not Found", + "type": "not_found", + "detail": "Not found", + "status": 404, + } + ) + mocker.patch("httpx.AsyncClient.delete", new_callable=AsyncMock, return_value=response) + + with pytest.raises(MyAccountApiError) as exc: + await client.delete_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + assert exc.value.status == 404 + + +@pytest.mark.asyncio +async def test_delete_authentication_method_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.delete", + new_callable=AsyncMock, + side_effect=Exception("Connection reset"), + ) + + with pytest.raises(ApiError): + await client.delete_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + + +@pytest.mark.asyncio +async def test_update_authentication_method_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 422 + response.json = MagicMock( + return_value={ + "title": "Unprocessable", + "type": "validation_error", + "detail": "Invalid", + "status": 422, + } + ) + mocker.patch("httpx.AsyncClient.patch", new_callable=AsyncMock, return_value=response) + + req = UpdateAuthenticationMethodRequest(name="x") + with pytest.raises(MyAccountApiError) as exc: + await client.update_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req + ) + assert exc.value.status == 422 + + +@pytest.mark.asyncio +async def test_update_authentication_method_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.patch", new_callable=AsyncMock, side_effect=Exception("timeout") + ) + + req = UpdateAuthenticationMethodRequest(name="x") + with pytest.raises(ApiError): + await client.update_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req + ) + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 403 + response.json = MagicMock( + return_value={ + "title": "Forbidden", + "type": "forbidden", + "detail": "Scope missing", + "status": 403, + } + ) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + with pytest.raises(MyAccountApiError) as exc: + await client.enroll_authentication_method(access_token="token123", request=req) + assert exc.value.status == 403 + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.post", + new_callable=AsyncMock, + side_effect=Exception("Connection refused"), + ) + + req = EnrollAuthenticationMethodRequest(type="passkey") + with pytest.raises(ApiError): + await client.enroll_authentication_method(access_token="token123", request=req) + + +@pytest.mark.asyncio +async def test_verify_authentication_method_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 400 + response.json = MagicMock( + return_value={ + "title": "Bad Request", + "type": "invalid_request", + "detail": "Invalid OTP", + "status": 400, + } + ) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="000000") + with pytest.raises(MyAccountApiError) as exc: + await client.verify_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req + ) + assert exc.value.status == 400 + + +@pytest.mark.asyncio +async def test_verify_authentication_method_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.post", + new_callable=AsyncMock, + side_effect=Exception("Connection refused"), + ) + + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") + with pytest.raises(ApiError): + await client.verify_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req + ) + + +# ============================================================================= +# Location header extraction edge case +# ============================================================================= + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_location_collection_url(mocker): + """Rejects Location header that ends at collection path without resource ID.""" + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 201 + response.headers = {"location": "/me/v1/authentication-methods/"} + response.json = MagicMock(return_value={"auth_session": "session_abc"}) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + with pytest.raises(ApiError) as exc: + await client.enroll_authentication_method(access_token="token123", request=req) + assert "could not extract ID" in str(exc.value) diff --git a/src/auth0_server_python/tests/test_passkey_server_client.py b/src/auth0_server_python/tests/test_passkey_server_client.py index 8d39410..7c2be37 100644 --- a/src/auth0_server_python/tests/test_passkey_server_client.py +++ b/src/auth0_server_python/tests/test_passkey_server_client.py @@ -521,3 +521,65 @@ def test_passkey_token_response_repr_redacts_tokens(): assert "secret_rt_value" not in repr_str assert "[REDACTED]" in repr_str assert "86400" in repr_str + + +# ============================================================================= +# expires_at edge cases +# ============================================================================= + + +@pytest.mark.asyncio +async def test_signin_with_passkey_preserves_server_expires_at( + server_client, authn_response, mocker +): + token_data = { + "access_token": "at_123", + "token_type": "Bearer", + "expires_in": 3600, + "expires_at": 9999999999, + } + mock_response = _mock_response(200, token_data) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + mocker.patch.object( + server_client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + + result = await server_client.signin_with_passkey( + auth_session="session", authn_response=authn_response + ) + + assert result.expires_at == 9999999999 + + +@pytest.mark.asyncio +async def test_signin_with_passkey_missing_expires_at_calculates( + server_client, authn_response, mocker +): + token_data = { + "access_token": "at_123", + "token_type": "Bearer", + "expires_in": 60, + } + mock_response = _mock_response(200, token_data) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + mocker.patch.object( + server_client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + + result = await server_client.signin_with_passkey( + auth_session="session", authn_response=authn_response + ) + + assert abs(result.expires_at - (int(time.time()) + 60)) <= 2 From d2d1f216d4e159027d4bf85ef60a636fd13d2078 Mon Sep 17 00:00:00 2001 From: Sourav Basu Date: Tue, 2 Jun 2026 15:45:22 +0530 Subject: [PATCH 03/12] Resolved snake case to camel case for correct parsing --- src/auth0_server_python/auth_server/server_client.py | 10 +++++----- .../tests/test_passkey_server_client.py | 9 +++++---- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 334eb00..d5118c7 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -2472,23 +2472,23 @@ async def passkey_signup_challenge( if username is not None: user_profile["username"] = username if phone_number is not None: - user_profile["phone_number"] = phone_number + user_profile["phoneNumber"] = phone_number if given_name is not None: - user_profile["given_name"] = given_name + user_profile["givenName"] = given_name if family_name is not None: - user_profile["family_name"] = family_name + user_profile["familyName"] = family_name if nickname is not None: user_profile["nickname"] = nickname if picture is not None: user_profile["picture"] = picture - if user_metadata is not None: - user_profile["user_metadata"] = user_metadata body: dict[str, Any] = {"client_id": self._client_id} if self._client_secret: body["client_secret"] = self._client_secret if user_profile: body["user_profile"] = user_profile + if user_metadata is not None: + body["userMetadata"] = user_metadata if connection: body["realm"] = connection if organization: diff --git a/src/auth0_server_python/tests/test_passkey_server_client.py b/src/auth0_server_python/tests/test_passkey_server_client.py index 7c2be37..2d644af 100644 --- a/src/auth0_server_python/tests/test_passkey_server_client.py +++ b/src/auth0_server_python/tests/test_passkey_server_client.py @@ -132,12 +132,13 @@ async def test_passkey_signup_challenge_user_profile_fields(server_client, mocke body = mock_client.post.call_args.kwargs["json"] assert body["user_profile"]["email"] == "u@e.com" assert body["user_profile"]["username"] == "jdoe" - assert body["user_profile"]["phone_number"] == "+1234567890" - assert body["user_profile"]["given_name"] == "Jane" - assert body["user_profile"]["family_name"] == "Doe" + assert body["user_profile"]["phoneNumber"] == "+1234567890" + assert body["user_profile"]["givenName"] == "Jane" + assert body["user_profile"]["familyName"] == "Doe" assert body["user_profile"]["nickname"] == "jd" assert body["user_profile"]["picture"] == "https://example.com/pic.jpg" - assert body["user_profile"]["user_metadata"] == {"role": "admin"} + assert "user_metadata" not in body["user_profile"] + assert body["userMetadata"] == {"role": "admin"} assert body["organization"] == "org_123" From 7691a0c12fead5a655fd58860249a349daeb7c09 Mon Sep 17 00:00:00 2001 From: Sourav Basu Date: Tue, 2 Jun 2026 16:05:59 +0530 Subject: [PATCH 04/12] Reverting to snake case - as per auth0 api docs. --- src/auth0_server_python/auth_server/server_client.py | 10 +++++----- .../tests/test_passkey_server_client.py | 9 ++++----- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index d5118c7..334eb00 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -2472,23 +2472,23 @@ async def passkey_signup_challenge( if username is not None: user_profile["username"] = username if phone_number is not None: - user_profile["phoneNumber"] = phone_number + user_profile["phone_number"] = phone_number if given_name is not None: - user_profile["givenName"] = given_name + user_profile["given_name"] = given_name if family_name is not None: - user_profile["familyName"] = family_name + user_profile["family_name"] = family_name if nickname is not None: user_profile["nickname"] = nickname if picture is not None: user_profile["picture"] = picture + if user_metadata is not None: + user_profile["user_metadata"] = user_metadata body: dict[str, Any] = {"client_id": self._client_id} if self._client_secret: body["client_secret"] = self._client_secret if user_profile: body["user_profile"] = user_profile - if user_metadata is not None: - body["userMetadata"] = user_metadata if connection: body["realm"] = connection if organization: diff --git a/src/auth0_server_python/tests/test_passkey_server_client.py b/src/auth0_server_python/tests/test_passkey_server_client.py index 2d644af..7c2be37 100644 --- a/src/auth0_server_python/tests/test_passkey_server_client.py +++ b/src/auth0_server_python/tests/test_passkey_server_client.py @@ -132,13 +132,12 @@ async def test_passkey_signup_challenge_user_profile_fields(server_client, mocke body = mock_client.post.call_args.kwargs["json"] assert body["user_profile"]["email"] == "u@e.com" assert body["user_profile"]["username"] == "jdoe" - assert body["user_profile"]["phoneNumber"] == "+1234567890" - assert body["user_profile"]["givenName"] == "Jane" - assert body["user_profile"]["familyName"] == "Doe" + assert body["user_profile"]["phone_number"] == "+1234567890" + assert body["user_profile"]["given_name"] == "Jane" + assert body["user_profile"]["family_name"] == "Doe" assert body["user_profile"]["nickname"] == "jd" assert body["user_profile"]["picture"] == "https://example.com/pic.jpg" - assert "user_metadata" not in body["user_profile"] - assert body["userMetadata"] == {"role": "admin"} + assert body["user_profile"]["user_metadata"] == {"role": "admin"} assert body["organization"] == "org_123" From 3233fff43ee6985053b8b261a8838057df6578a2 Mon Sep 17 00:00:00 2001 From: Sourav Basu Date: Tue, 2 Jun 2026 17:15:27 +0530 Subject: [PATCH 05/12] Reverted lint fixes for easier review --- .../auth_server/my_account_client.py | 56 +- .../auth_server/server_client.py | 683 ++++++++++-------- 2 files changed, 410 insertions(+), 329 deletions(-) diff --git a/src/auth0_server_python/auth_server/my_account_client.py b/src/auth0_server_python/auth_server/my_account_client.py index 9089186..bd23a12 100644 --- a/src/auth0_server_python/auth_server/my_account_client.py +++ b/src/auth0_server_python/auth_server/my_account_client.py @@ -74,7 +74,9 @@ def audience(self): return f"https://{self._domain}/me/" async def connect_account( - self, access_token: str, request: ConnectAccountRequest + self, + access_token: str, + request: ConnectAccountRequest ) -> ConnectAccountResponse: """ Initiate the connected account flow. @@ -95,7 +97,7 @@ async def connect_account( response = await client.post( url=f"{self.audience}v1/connected-accounts/connect", json=request.model_dump(exclude_none=True), - auth=BearerAuth(access_token), + auth=BearerAuth(access_token) ) if response.status_code != 201: @@ -105,7 +107,7 @@ async def connect_account( type=error_data.get("type", None), detail=error_data.get("detail", None), status=error_data.get("status", None), - validation_errors=error_data.get("validation_errors", None), + validation_errors=error_data.get("validation_errors", None) ) data = response.json() @@ -118,11 +120,13 @@ async def connect_account( raise ApiError( "connect_account_error", f"Connected Accounts connect request failed: {str(e) or 'Unknown error'}", - e, + e ) async def complete_connect_account( - self, access_token: str, request: CompleteConnectAccountRequest + self, + access_token: str, + request: CompleteConnectAccountRequest ) -> CompleteConnectAccountResponse: """ Complete the connected account flow after user authorization. @@ -143,7 +147,7 @@ async def complete_connect_account( response = await client.post( url=f"{self.audience}v1/connected-accounts/complete", json=request.model_dump(exclude_none=True), - auth=BearerAuth(access_token), + auth=BearerAuth(access_token) ) if response.status_code != 201: @@ -153,7 +157,7 @@ async def complete_connect_account( type=error_data.get("type", None), detail=error_data.get("detail", None), status=error_data.get("status", None), - validation_errors=error_data.get("validation_errors", None), + validation_errors=error_data.get("validation_errors", None) ) data = response.json() @@ -166,7 +170,7 @@ async def complete_connect_account( raise ApiError( "connect_account_error", f"Connected Accounts complete request failed: {str(e) or 'Unknown error'}", - e, + e ) async def list_connected_accounts( @@ -174,7 +178,7 @@ async def list_connected_accounts( access_token: str, connection: Optional[str] = None, from_param: Optional[str] = None, - take: Optional[int] = None, + take: Optional[int] = None ) -> ListConnectedAccountsResponse: """ List connected accounts for the authenticated user. @@ -213,7 +217,7 @@ async def list_connected_accounts( response = await client.get( url=f"{self.audience}v1/connected-accounts/accounts", params=params, - auth=BearerAuth(access_token), + auth=BearerAuth(access_token) ) if response.status_code != 200: @@ -223,7 +227,7 @@ async def list_connected_accounts( type=error_data.get("type", None), detail=error_data.get("detail", None), status=error_data.get("status", None), - validation_errors=error_data.get("validation_errors", None), + validation_errors=error_data.get("validation_errors", None) ) data = response.json() @@ -236,10 +240,15 @@ async def list_connected_accounts( raise ApiError( "connect_account_error", f"Connected Accounts list request failed: {str(e) or 'Unknown error'}", - e, + e ) - async def delete_connected_account(self, access_token: str, connected_account_id: str) -> None: + + async def delete_connected_account( + self, + access_token: str, + connected_account_id: str + ) -> None: """ Delete a connected account for the authenticated user. @@ -266,7 +275,7 @@ async def delete_connected_account(self, access_token: str, connected_account_id async with self._get_http_client() as client: response = await client.delete( url=f"{self.audience}v1/connected-accounts/accounts/{connected_account_id}", - auth=BearerAuth(access_token), + auth=BearerAuth(access_token) ) if response.status_code != 204: @@ -276,7 +285,7 @@ async def delete_connected_account(self, access_token: str, connected_account_id type=error_data.get("type", None), detail=error_data.get("detail", None), status=error_data.get("status", None), - validation_errors=error_data.get("validation_errors", None), + validation_errors=error_data.get("validation_errors", None) ) except Exception as e: @@ -285,11 +294,14 @@ async def delete_connected_account(self, access_token: str, connected_account_id raise ApiError( "connect_account_error", f"Connected Accounts delete request failed: {str(e) or 'Unknown error'}", - e, + e ) async def list_connected_account_connections( - self, access_token: str, from_param: Optional[str] = None, take: Optional[int] = None + self, + access_token: str, + from_param: Optional[str] = None, + take: Optional[int] = None ) -> ListConnectedAccountConnectionsResponse: """ List available connections that support connected accounts. @@ -325,7 +337,7 @@ async def list_connected_account_connections( response = await client.get( url=f"{self.audience}v1/connected-accounts/connections", params=params, - auth=BearerAuth(access_token), + auth=BearerAuth(access_token) ) if response.status_code != 200: @@ -335,7 +347,7 @@ async def list_connected_account_connections( type=error_data.get("type", None), detail=error_data.get("detail", None), status=error_data.get("status", None), - validation_errors=error_data.get("validation_errors", None), + validation_errors=error_data.get("validation_errors", None) ) data = response.json() @@ -348,9 +360,13 @@ async def list_connected_account_connections( raise ApiError( "connect_account_error", f"Connected Accounts list connections request failed: {str(e) or 'Unknown error'}", - e, + e ) + # ============================================================================ + # AUTHENTICATION METHODS & FACTORS (Passkey / MyAccount API) + # ============================================================================ + async def get_factors( self, access_token: str, diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 334eb00..a233205 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -69,18 +69,11 @@ ) # Generic type for store options -TStoreOptions = TypeVar("TStoreOptions") +TStoreOptions = TypeVar('TStoreOptions') # redirect_uri is intentionally excluded — in MCD mode it is built # dynamically from the resolved domain at login time. -INTERNAL_AUTHORIZE_PARAMS = [ - "client_id", - "response_type", - "code_challenge", - "code_challenge_method", - "state", - "nonce", - "scope", -] +INTERNAL_AUTHORIZE_PARAMS = ["client_id", "response_type", + "code_challenge", "code_challenge_method", "state", "nonce", "scope"] class ServerClient(Generic[TStoreOptions]): @@ -88,7 +81,6 @@ class ServerClient(Generic[TStoreOptions]): Main client for Auth0 server SDK. Handles authentication flows, session management, and token operations using Authlib for OIDC functionality. """ - DEFAULT_AUDIENCE_STATE_KEY = "default" # ============================================================================ @@ -129,7 +121,9 @@ def __init__( raise MissingRequiredArgumentError("secret") if domain is None: - raise ConfigurationError("Domain is required") + raise ConfigurationError( + "Domain is required" + ) # Validate domain type if not isinstance(domain, str) and not callable(domain): @@ -174,12 +168,14 @@ def __init__( headers=self._telemetry_headers, ) - self._my_account_client = MyAccountClient(domain=domain, headers=self._telemetry_headers) + self._my_account_client = MyAccountClient( + domain=domain, headers=self._telemetry_headers + ) # Unified cache for OIDC metadata and JWKS per domain (LRU eviction + TTL) self._discovery_cache: OrderedDict[str, dict] = OrderedDict() - self._cache_ttl = 600 # 10 mins. TTL - self._cache_max_entries = 100 # Max 100 domains + self._cache_ttl = 600 # 10 mins. TTL + self._cache_max_entries = 100 # Max 100 domains # Initialize MFA client self._mfa_client = MfaClient( @@ -206,14 +202,14 @@ def _normalize_url(self, value: str) -> str: return value value = value.lower() - if value.startswith("https://"): + if value.startswith('https://'): pass - elif value.startswith("http://"): - value = value.replace("http://", "https://") + elif value.startswith('http://'): + value = value.replace('http://', 'https://') else: - value = f"https://{value}" + value = f'https://{value}' - return value.rstrip("/") + return value.rstrip('/') async def _resolve_current_domain(self, store_options=None) -> str: """Resolve domain from resolver function or return static domain.""" @@ -226,7 +222,8 @@ async def _resolve_current_domain(self, store_options=None) -> str: raise except Exception as e: raise DomainResolverError( - f"Domain resolver function raised an exception: {str(e)}", original_error=e + f"Domain resolver function raised an exception: {str(e)}", + original_error=e ) return self._domain @@ -240,18 +237,18 @@ def _get_session_domain(self, state_data_dict: dict) -> Optional[str]: 2. self._domain — static domain (if configured) 3. Extract hostname from user.iss — derive from user's issuer claim """ - domain = state_data_dict.get("domain") + domain = state_data_dict.get('domain') if domain: return domain if self._domain: return self._domain - user = state_data_dict.get("user") + user = state_data_dict.get('user') if isinstance(user, dict): - iss = user.get("iss") + iss = user.get('iss') else: - iss = getattr(user, "iss", None) if user else None + iss = getattr(user, 'iss', None) if user else None if iss: parsed = urlparse(iss) @@ -354,7 +351,7 @@ async def _get_oidc_metadata_cached(self, domain: str) -> dict: self._discovery_cache[domain] = { "metadata": metadata, "jwks": None, - "expires_at": now + self._cache_ttl, + "expires_at": now + self._cache_ttl } return metadata @@ -416,11 +413,11 @@ async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: if not metadata: metadata = await self._get_oidc_metadata_cached(domain) - jwks_uri = metadata.get("jwks_uri") + jwks_uri = metadata.get('jwks_uri') if not jwks_uri: raise ApiError( "missing_jwks_uri", - f"OIDC metadata for {domain} does not contain jwks_uri. Provider may be non-RFC-compliant.", + f"OIDC metadata for {domain} does not contain jwks_uri. Provider may be non-RFC-compliant." ) # Fetch JWKS @@ -437,7 +434,7 @@ async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: self._discovery_cache[domain] = { "metadata": metadata, "jwks": jwks, - "expires_at": now + self._cache_ttl, + "expires_at": now + self._cache_ttl } return jwks @@ -449,7 +446,9 @@ async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: # ============================================================================ async def start_interactive_login( - self, options: Optional[StartInteractiveLoginOptions] = None, store_options: dict = None + self, + options: Optional[StartInteractiveLoginOptions] = None, + store_options: dict = None ) -> str: """ Starts the interactive login process and returns a URL to redirect to. @@ -470,17 +469,15 @@ async def start_interactive_login( try: metadata = await self._get_oidc_metadata_cached(origin_domain) except Exception as e: - raise ApiError("metadata_error", "Failed to fetch OIDC metadata", e) + raise ApiError("metadata_error", + "Failed to fetch OIDC metadata", e) # Get effective authorization params (merge defaults with provided ones) auth_params = dict(self._default_authorization_params) if options.authorization_params: auth_params.update( - { - k: v - for k, v in options.authorization_params.items() - if k not in INTERNAL_AUTHORIZE_PARAMS - } + {k: v for k, v in options.authorization_params.items( + ) if k not in INTERNAL_AUTHORIZE_PARAMS} ) # Ensure we have a redirect_uri @@ -504,11 +501,7 @@ async def start_interactive_login( auth_params["state"] = state # Merge any requested scope with defaults - requested_scope = ( - options.authorization_params.get("scope", None) - if options.authorization_params - else None - ) + requested_scope = options.authorization_params.get("scope", None) if options.authorization_params else None audience = auth_params.get("audience", None) merged_scope = self._merge_scope_with_defaults(requested_scope, audience) auth_params["scope"] = merged_scope @@ -524,61 +517,65 @@ async def start_interactive_login( # Store the transaction data await self._transaction_store.set( - f"{self._transaction_identifier}:{state}", transaction_data, options=store_options + f"{self._transaction_identifier}:{state}", + transaction_data, + options=store_options ) # Set metadata for OAuth client self._oauth.metadata = metadata # If PAR is enabled, use the PAR endpoint if self._pushed_authorization_requests: - par_endpoint = self._oauth.metadata.get("pushed_authorization_request_endpoint") + par_endpoint = self._oauth.metadata.get( + "pushed_authorization_request_endpoint") if not par_endpoint: raise ApiError( - "configuration_error", - "PAR is enabled but pushed_authorization_request_endpoint is missing in metadata", - ) + "configuration_error", "PAR is enabled but pushed_authorization_request_endpoint is missing in metadata") auth_params["client_id"] = self._client_id # Post the auth_params to the PAR endpoint async with self._get_http_client() as client: par_response = await client.post( - par_endpoint, data=auth_params, auth=(self._client_id, self._client_secret) + par_endpoint, + data=auth_params, + auth=(self._client_id, self._client_secret) ) if par_response.status_code not in (200, 201): error_data = par_response.json() raise ApiError( error_data.get("error", "par_error"), error_data.get( - "error_description", "Failed to obtain request_uri from PAR endpoint" - ), + "error_description", "Failed to obtain request_uri from PAR endpoint") ) par_data = par_response.json() request_uri = par_data.get("request_uri") if not request_uri: - raise ApiError("par_error", "No request_uri returned from PAR endpoint") + raise ApiError( + "par_error", "No request_uri returned from PAR endpoint") auth_endpoint = self._oauth.metadata.get("authorization_endpoint") final_url = f"{auth_endpoint}?request_uri={request_uri}&response_type={auth_params['response_type']}&client_id={self._client_id}" return final_url else: if "authorization_endpoint" not in self._oauth.metadata: - raise ApiError( - "configuration_error", "Authorization endpoint missing in OIDC metadata" - ) + raise ApiError("configuration_error", + "Authorization endpoint missing in OIDC metadata") authorization_endpoint = self._oauth.metadata["authorization_endpoint"] try: auth_url, state = self._oauth.create_authorization_url( - authorization_endpoint, **auth_params - ) + authorization_endpoint, **auth_params) except Exception as e: - raise ApiError("authorization_url_error", "Failed to create authorization URL", e) + raise ApiError("authorization_url_error", + "Failed to create authorization URL", e) return auth_url async def complete_interactive_login( - self, url: str, store_options: dict = None + self, + url: str, + store_options: dict = None ) -> dict[str, Any]: """ Completes the login process after user is redirected back. @@ -601,9 +598,7 @@ async def complete_interactive_login( # Retrieve the transaction data using the state transaction_identifier = f"{self._transaction_identifier}:{state}" - transaction_data = await self._transaction_store.get( - transaction_identifier, options=store_options - ) + transaction_data = await self._transaction_store.get(transaction_identifier, options=store_options) if not transaction_data: raise MissingTransactionError() @@ -624,7 +619,7 @@ async def complete_interactive_login( # Fetch metadata and derive issuer from the origin domain metadata = await self._get_oidc_metadata_cached(origin_domain) - origin_issuer = metadata.get("issuer") + origin_issuer = metadata.get('issuer') self._oauth.metadata = metadata # Exchange the code for tokens @@ -640,7 +635,8 @@ async def complete_interactive_login( ) except OAuthError as e: # Raise a custom error (or handle it as appropriate) - raise ApiError("token_error", f"Token exchange failed: {str(e)}", e) + raise ApiError( + "token_error", f"Token exchange failed: {str(e)}", e) # Use the userinfo field from the token_response for user claims user_info = token_response.get("userinfo") @@ -655,14 +651,14 @@ async def complete_interactive_login( # Decode and verify ID token with signature verification enabled try: - claims = await self._verify_and_decode_jwt(id_token, jwks, audience=self._client_id) + claims = await self._verify_and_decode_jwt( + id_token, jwks, audience=self._client_id + ) # Custom normalized issuer validation token_issuer = claims.get("iss", "") if self._normalize_url(token_issuer) != self._normalize_url(origin_issuer): - raise IssuerValidationError( - "ID token issuer mismatch. Ensure your Auth0 domain is configured correctly." - ) + raise IssuerValidationError("ID token issuer mismatch. Ensure your Auth0 domain is configured correctly.") user_claims = UserClaims.parse_obj(claims) except ValueError as e: @@ -671,33 +667,40 @@ async def complete_interactive_login( raise ApiError( "invalid_signature", f"ID token signature verification failed. The token may have been tampered with or is from an untrusted source: {str(e)}", - e, + e ) except jwt.InvalidAudienceError as e: raise ApiError( "invalid_audience", f"ID token audience mismatch. Expected: {self._client_id}. Ensure your client_id is configured correctly: {str(e)}", - e, + e ) except jwt.ExpiredSignatureError as e: - raise ApiError("token_expired", f"ID token has expired: {str(e)}", e) + raise ApiError( + "token_expired", + f"ID token has expired: {str(e)}", + e + ) except jwt.InvalidTokenError as e: - raise ApiError("invalid_token", f"ID token verification failed: {str(e)}", e) + raise ApiError( + "invalid_token", + f"ID token verification failed: {str(e)}", + e + ) + # Build a token set using the token response data token_set = TokenSet( audience=transaction_data.audience or self.DEFAULT_AUDIENCE_STATE_KEY, access_token=token_response.get("access_token", ""), scope=token_response.get("scope", ""), - expires_at=int(time.time()) + token_response.get("expires_in", 3600), + expires_at=int(time.time()) + + token_response.get("expires_in", 3600) ) # Generate a session id (sid) from token_response or transaction data, or create a new one - sid = ( - user_info.get("sid") - if user_info and "sid" in user_info - else PKCE.generate_random_string(32) - ) + sid = user_info.get( + "sid") if user_info and "sid" in user_info else PKCE.generate_random_string(32) # Construct state data to represent the session state_data = StateData( @@ -707,7 +710,10 @@ async def complete_interactive_login( refresh_token=token_response.get("refresh_token"), token_sets=[token_set], domain=origin_domain, - internal={"sid": sid, "created_at": int(time.time())}, + internal={ + "sid": sid, + "created_at": int(time.time()) + } ) # Store the state data in the state store using store_options (Response required) @@ -732,9 +738,7 @@ async def complete_interactive_login( # Methods for retrieving user information, session data, and logout operations. # ============================================================================ - async def get_user( - self, store_options: Optional[dict[str, Any]] = None - ) -> Optional[dict[str, Any]]: + async def get_user(self, store_options: Optional[dict[str, Any]] = None) -> Optional[dict[str, Any]]: """ Retrieves the user from the store, or None if no user found. @@ -763,9 +767,7 @@ async def get_user( return state_data.get("user") return None - async def get_session( - self, store_options: Optional[dict[str, Any]] = None - ) -> Optional[dict[str, Any]]: + async def get_session(self, store_options: Optional[dict[str, Any]] = None) -> Optional[dict[str, Any]]: """ Retrieve the user session from the store, or None if no session found. @@ -791,14 +793,15 @@ async def get_session( if self._normalize_url(session_domain) != self._normalize_url(current_domain): return None - session_data = {k: v for k, v in state_data.items() if k != "internal"} + session_data = {k: v for k, v in state_data.items() + if k != "internal"} return session_data return None async def logout( self, options: Optional[LogoutOptions] = None, - store_options: Optional[dict[str, Any]] = None, + store_options: Optional[dict[str, Any]] = None ) -> str: options = options or LogoutOptions() @@ -814,18 +817,19 @@ async def logout( if hasattr(state_data, "dict") and callable(state_data.dict): state_data = state_data.dict() session_domain = self._get_session_domain(state_data) - if session_domain and self._normalize_url(session_domain) == self._normalize_url( - domain - ): + if session_domain and self._normalize_url(session_domain) == self._normalize_url(domain): await self._state_store.delete(self._state_identifier, store_options) # Return logout URL for the current resolved domain - logout_url = URL.create_logout_url(domain, self._client_id, options.return_to) + logout_url = URL.create_logout_url( + domain, self._client_id, options.return_to) return logout_url async def handle_backchannel_logout( - self, logout_token: str, store_options: Optional[dict[str, Any]] = None + self, + logout_token: str, + store_options: Optional[dict[str, Any]] = None ) -> None: """ Handles backchannel logout requests. @@ -846,7 +850,8 @@ async def handle_backchannel_logout( # Read iss from unverified token for comparison try: unverified = jwt.decode( - logout_token, algorithms=["RS256"], options={"verify_signature": False} + logout_token, algorithms=["RS256"], + options={"verify_signature": False} ) token_issuer = unverified.get("iss", "") except Exception as e: @@ -875,16 +880,13 @@ async def handle_backchannel_logout( jwks = await self._get_jwks_cached(domain) try: - claims = await self._verify_and_decode_jwt( - logout_token, jwks, audience=self._client_id - ) + claims = await self._verify_and_decode_jwt(logout_token, jwks, audience=self._client_id) # Normalized issuer validation token_issuer = claims.get("iss", "") expected_issuer = self._normalize_url(domain) if self._normalize_url(token_issuer) != self._normalize_url(expected_issuer): - raise IssuerValidationError( - "Logout token issuer mismatch.Ensure your Auth0 domain is configured correctly." + raise IssuerValidationError("Logout token issuer mismatch.Ensure your Auth0 domain is configured correctly." ) except ValueError as e: raise BackchannelLogoutError(str(e)) @@ -893,22 +895,30 @@ async def handle_backchannel_logout( f"Logout token signature verification failed: {str(e)}" ) except jwt.InvalidTokenError as e: - raise BackchannelLogoutError(f"Logout token verification failed: {str(e)}") + raise BackchannelLogoutError( + f"Logout token verification failed: {str(e)}" + ) # Validate the token is a logout token events = claims.get("events", {}) if "http://schemas.openid.net/event/backchannel-logout" not in events: - raise BackchannelLogoutError("Invalid logout token: not a backchannel logout event") + raise BackchannelLogoutError( + "Invalid logout token: not a backchannel logout event") # Delete sessions associated with this token logout_claims = LogoutTokenClaims( - sub=claims.get("sub"), sid=claims.get("sid"), iss=claims.get("iss") + sub=claims.get("sub"), + sid=claims.get("sid"), + iss=claims.get("iss") ) - await self._state_store.delete_by_logout_token(logout_claims.dict(), store_options) + await self._state_store.delete_by_logout_token( + logout_claims.dict(), store_options + ) except (jwt.PyJWTError, ValidationError) as e: - raise BackchannelLogoutError(f"Error processing logout token: {str(e)}") + raise BackchannelLogoutError( + f"Error processing logout token: {str(e)}") # ============================================================================ # ACCESS TOKEN MANAGEMENT @@ -949,13 +959,13 @@ async def get_access_token( if not session_domain: raise AccessTokenError( AccessTokenErrorCode.MISSING_SESSION_DOMAIN, - "Session domain does not match the current domain.", + "Session domain does not match the current domain." ) current_domain = await self._resolve_current_domain(store_options) if self._normalize_url(session_domain) != self._normalize_url(current_domain): raise AccessTokenError( AccessTokenErrorCode.DOMAIN_MISMATCH, - "Session domain does not match the current domain.", + "Session domain does not match the current domain." ) auth_params = self._default_authorization_params or {} @@ -969,9 +979,7 @@ async def get_access_token( # Find matching token set token_set = None if state_data_dict and "token_sets" in state_data_dict: - token_set = self._find_matching_token_set( - state_data_dict["token_sets"], audience, merged_scope - ) + token_set = self._find_matching_token_set(state_data_dict["token_sets"], audience, merged_scope) # If token is valid, return it if token_set and token_set.get("expires_at", 0) > time.time(): @@ -981,7 +989,7 @@ async def get_access_token( if not state_data_dict or not state_data_dict.get("refresh_token"): raise AccessTokenError( AccessTokenErrorCode.MISSING_REFRESH_TOKEN, - "The access token has expired and a refresh token was not provided. The user needs to re-authenticate.", + "The access token has expired and a refresh token was not provided. The user needs to re-authenticate." ) # Get new token with refresh token @@ -990,7 +998,7 @@ async def get_access_token( session_domain = state_data_dict.get("domain") or self._domain get_refresh_token_options = { "refresh_token": state_data_dict["refresh_token"], - "domain": session_domain, + "domain": session_domain } if audience: get_refresh_token_options["audience"] = audience @@ -998,20 +1006,15 @@ async def get_access_token( if merged_scope: get_refresh_token_options["scope"] = merged_scope - token_endpoint_response = await self.get_token_by_refresh_token( - get_refresh_token_options - ) + token_endpoint_response = await self.get_token_by_refresh_token(get_refresh_token_options) # Update state data with new token existing_state_data = await self._state_store.get(self._state_identifier, store_options) updated_state_data = State.update_state_data( - audience, existing_state_data, token_endpoint_response - ) + audience, existing_state_data, token_endpoint_response) # Store updated state - await self._state_store.set( - self._state_identifier, updated_state_data, options=store_options - ) + await self._state_store.set(self._state_identifier, updated_state_data, options=store_options) return token_endpoint_response["access_token"] except Exception as e: @@ -1025,21 +1028,22 @@ async def get_access_token( raw_mfa_token=raw_mfa_token, audience=audience or self.DEFAULT_AUDIENCE_STATE_KEY, scope=merged_scope or "", - mfa_requirements=mfa_requirements, + mfa_requirements=mfa_requirements ) raise MfaRequiredError( "Multifactor authentication required", mfa_token=encrypted_token, - mfa_requirements=mfa_requirements, + mfa_requirements=mfa_requirements ) if isinstance(e, AccessTokenError): raise raise AccessTokenError( AccessTokenErrorCode.REFRESH_TOKEN_ERROR, - f"Failed to get token with refresh token: {str(e)}", + f"Failed to get token with refresh token: {str(e)}" ) + async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, Any]: """ Retrieves a token by exchanging a refresh token. @@ -1067,7 +1071,8 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, token_endpoint = metadata.get("token_endpoint") if not token_endpoint: - raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") + raise ApiError("configuration_error", + "Token endpoint missing in OIDC metadata") # Prepare the token request parameters token_params = { @@ -1082,7 +1087,8 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, # Merge scope if present in options with any in the original authorization params merged_scope = self._merge_scope_with_defaults( - request_scope=options.get("scope"), audience=audience + request_scope=options.get("scope"), + audience=audience ) if merged_scope: @@ -1091,7 +1097,9 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, # Exchange the refresh token for an access token async with self._get_http_client() as client: response = await client.post( - token_endpoint, data=token_params, auth=(self._client_id, self._client_secret) + token_endpoint, + data=token_params, + auth=(self._client_id, self._client_secret) ) if response.status_code != 200: @@ -1101,7 +1109,8 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, # Preserve mfa_required details for upstream handling if error_code == "mfa_required": error = ApiError( - error_code, error_data.get("error_description", "MFA required") + error_code, + error_data.get("error_description", "MFA required") ) error.mfa_token = error_data.get("mfa_token") mfa_requirements_data = error_data.get("mfa_requirements") @@ -1112,14 +1121,16 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, raise ApiError( error_code, - error_data.get("error_description", "Failed to exchange refresh token"), + error_data.get("error_description", + "Failed to exchange refresh token") ) token_response = response.json() # Add required fields if they are missing if "expires_in" in token_response and "expires_at" not in token_response: - token_response["expires_at"] = int(time.time()) + token_response["expires_in"] + token_response["expires_at"] = int( + time.time()) + token_response["expires_in"] return token_response @@ -1129,11 +1140,13 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, raise AccessTokenError( AccessTokenErrorCode.REFRESH_TOKEN_ERROR, "The access token has expired and there was an error while trying to refresh it.", - e, + e ) def _merge_scope_with_defaults( - self, request_scope: Optional[str], audience: Optional[str] + self, + request_scope: Optional[str], + audience: Optional[str] ) -> Optional[str]: """Helper: Merges requested scopes with default authorization params.""" audience = audience or self.DEFAULT_AUDIENCE_STATE_KEY @@ -1154,7 +1167,10 @@ def _merge_scope_with_defaults( return " ".join(merged_scopes) if merged_scopes else None def _find_matching_token_set( - self, token_sets: list[dict[str, Any]], audience: Optional[str], scope: Optional[str] + self, + token_sets: list[dict[str, Any]], + audience: Optional[str], + scope: Optional[str] ) -> Optional[dict[str, Any]]: """Helper: Finds a token set matching the requested audience and scopes.""" audience = audience or self.DEFAULT_AUDIENCE_STATE_KEY @@ -1180,7 +1196,9 @@ def _find_matching_token_set( # ============================================================================ async def login_backchannel( - self, options: dict[str, Any], store_options: Optional[dict[str, Any]] = None + self, + options: dict[str, Any], + store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Logs in using Client-Initiated Backchannel Authentication. @@ -1199,22 +1217,22 @@ async def login_backchannel( Returns: A dictionary containing the authorizationDetails (when RAR was used). """ - token_endpoint_response = await self.backchannel_authentication( - { - "binding_message": options.get("binding_message"), - "login_hint": options.get("login_hint"), - "authorization_params": options.get("authorization_params"), - }, - store_options=store_options, - ) + token_endpoint_response = await self.backchannel_authentication({ + "binding_message": options.get("binding_message"), + "login_hint": options.get("login_hint"), + "authorization_params": options.get("authorization_params"), + }, store_options=store_options) existing_state_data = await self._state_store.get(self._state_identifier, store_options) audience = self._default_authorization_params.get( - "audience", self.DEFAULT_AUDIENCE_STATE_KEY - ) + "audience", self.DEFAULT_AUDIENCE_STATE_KEY) - state_data = State.update_state_data(audience, existing_state_data, token_endpoint_response) + state_data = State.update_state_data( + audience, + existing_state_data, + token_endpoint_response + ) # Store domain for MCD session domain = await self._resolve_current_domain(store_options) @@ -1222,11 +1240,15 @@ async def login_backchannel( await self._state_store.set(self._state_identifier, state_data, store_options) - result = {"authorization_details": token_endpoint_response.get("authorization_details")} + result = { + "authorization_details": token_endpoint_response.get("authorization_details") + } return result async def backchannel_authentication( - self, options: dict[str, Any], store_options: Optional[dict[str, Any]] = None + self, + options: dict[str, Any], + store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Performs backchannel authentication with Auth0. @@ -1251,12 +1273,12 @@ async def backchannel_authentication( Raises: ApiError: If the backchannel authentication fails """ - backchannel_data = await self.initiate_backchannel_authentication( - options, store_options=store_options - ) + backchannel_data = await self.initiate_backchannel_authentication(options, store_options=store_options) auth_req_id = backchannel_data.get("auth_req_id") - expires_in = backchannel_data.get("expires_in", 120) # Default to 2 minutes - interval = backchannel_data.get("interval", 5) # Default to 5 seconds + expires_in = backchannel_data.get( + "expires_in", 120) # Default to 2 minutes + interval = backchannel_data.get( + "interval", 5) # Default to 5 seconds # Calculate when to stop polling end_time = time.time() + expires_in @@ -1265,9 +1287,7 @@ async def backchannel_authentication( while time.time() < end_time: # Make token request try: - token_response = await self.backchannel_authentication_grant( - auth_req_id, store_options=store_options - ) + token_response = await self.backchannel_authentication_grant(auth_req_id, store_options=store_options) return token_response except Exception as e: @@ -1283,14 +1303,17 @@ async def backchannel_authentication( raise ApiError( "backchannel_error", f"Backchannel authentication failed: {str(e) or 'Unknown error'}", - e, + e ) # If we get here, we've timed out - raise ApiError("timeout", "Backchannel authentication timed out") + raise ApiError( + "timeout", "Backchannel authentication timed out") async def initiate_backchannel_authentication( - self, options: dict[str, Any], store_options: Optional[dict[str, Any]] = None + self, + options: dict[str, Any], + store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Start backchannel authentication with Auth0. @@ -1320,13 +1343,18 @@ async def initiate_backchannel_authentication( https://auth0.com/docs/get-started/authentication-and-authorization-flow/client-initiated-backchannel-authentication-flow """ - sub = options.get("login_hint", {}).get("sub") + sub = options.get('login_hint', {}).get("sub") if not sub: - raise MissingRequiredArgumentError("login_hint.sub") + raise MissingRequiredArgumentError( + "login_hint.sub" + ) - authorization_params = options.get("authorization_params") + authorization_params = options.get('authorization_params') if authorization_params is not None and not isinstance(authorization_params, dict): - raise ApiError("invalid_argument", "authorization_params must be a dict") + raise ApiError( + "invalid_argument", + "authorization_params must be a dict" + ) if authorization_params: requested_expiry = authorization_params.get("requested_expiry") @@ -1334,7 +1362,7 @@ async def initiate_backchannel_authentication( if not isinstance(requested_expiry, int) or requested_expiry <= 0: raise ApiError( "invalid_argument", - "authorization_params.requested_expiry must be a positive integer", + "authorization_params.requested_expiry must be a positive integer" ) try: @@ -1343,18 +1371,24 @@ async def initiate_backchannel_authentication( metadata = await self._get_oidc_metadata_cached(domain) # Get the issuer from metadata - issuer = metadata.get("issuer") or f"https://{domain}/" + issuer = metadata.get( + "issuer") or f"https://{domain}/" # Get backchannel authentication endpoint - backchannel_endpoint = metadata.get("backchannel_authentication_endpoint") + backchannel_endpoint = metadata.get( + "backchannel_authentication_endpoint") if not backchannel_endpoint: raise ApiError( "configuration_error", - "Backchannel authentication is not supported by the authorization server", + "Backchannel authentication is not supported by the authorization server" ) # Prepare login hint in the required format - login_hint = json.dumps({"format": "iss_sub", "iss": issuer, "sub": sub}) + login_hint = json.dumps({ + "format": "iss_sub", + "iss": issuer, + "sub": sub + }) # The Request Parameters params = { @@ -1364,8 +1398,8 @@ async def initiate_backchannel_authentication( } # Add binding message if provided - if options.get("binding_message"): - params["binding_message"] = options.get("binding_message") + if options.get('binding_message'): + params["binding_message"] = options.get('binding_message') # Add any additional authorization parameters if self._default_authorization_params: @@ -1377,7 +1411,9 @@ async def initiate_backchannel_authentication( # Make the backchannel authentication request async with self._get_http_client() as client: backchannel_response = await client.post( - backchannel_endpoint, data=params, auth=(self._client_id, self._client_secret) + backchannel_endpoint, + data=params, + auth=(self._client_id, self._client_secret) ) if backchannel_response.status_code != 200: @@ -1385,8 +1421,7 @@ async def initiate_backchannel_authentication( raise ApiError( error_data.get("error", "backchannel_error"), error_data.get( - "error_description", "Backchannel authentication request failed" - ), + "error_description", "Backchannel authentication request failed") ) backchannel_data = backchannel_response.json() @@ -1395,7 +1430,7 @@ async def initiate_backchannel_authentication( if not auth_req_id: raise ApiError( "invalid_response", - "Missing auth_req_id in backchannel authentication response", + "Missing auth_req_id in backchannel authentication response" ) return backchannel_data @@ -1406,11 +1441,13 @@ async def initiate_backchannel_authentication( raise ApiError( "backchannel_error", f"Backchannel authentication failed: {str(e) or 'Unknown error'}", - e, + e ) async def backchannel_authentication_grant( - self, auth_req_id: str, store_options: Optional[dict[str, Any]] = None + self, + auth_req_id: str, + store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Retrieves a token by exchanging an auth_req_id. @@ -1435,20 +1472,23 @@ async def backchannel_authentication_grant( token_endpoint = metadata.get("token_endpoint") if not token_endpoint: - raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") + raise ApiError("configuration_error", + "Token endpoint missing in OIDC metadata") # Prepare the token request parameters token_params = { "grant_type": "urn:openid:params:grant-type:ciba", "auth_req_id": auth_req_id, "client_id": self._client_id, - "client_secret": self._client_secret, + "client_secret": self._client_secret } # Exchange the auth_req_id for an access token async with self._get_http_client() as client: response = await client.post( - token_endpoint, data=token_params, auth=(self._client_id, self._client_secret) + token_endpoint, + data=token_params, + auth=(self._client_id, self._client_secret) ) if response.status_code != 200: @@ -1457,18 +1497,23 @@ async def backchannel_authentication_grant( interval = int(retry_after) if retry_after is not None else None raise PollingApiError( error_data.get("error", "auth_req_id_error"), - error_data.get("error_description", "Failed to exchange auth_req_id"), - interval, + error_data.get("error_description", + "Failed to exchange auth_req_id"), + interval ) try: token_response = response.json() except json.JSONDecodeError: - raise ApiError("invalid_response", "Failed to parse token response as JSON") + raise ApiError( + "invalid_response", + "Failed to parse token response as JSON" + ) # Add required fields if they are missing if "expires_in" in token_response and "expires_at" not in token_response: - token_response["expires_at"] = int(time.time()) + token_response["expires_in"] + token_response["expires_at"] = int( + time.time()) + token_response["expires_in"] return token_response @@ -1478,7 +1523,7 @@ async def backchannel_authentication_grant( raise AccessTokenError( AccessTokenErrorCode.AUTH_REQ_ID_ERROR, "There was an error while trying to exchange the auth_req_id for an access token.", - e, + e ) # ============================================================================ @@ -1487,7 +1532,11 @@ async def backchannel_authentication_grant( # to a user's Auth0 profile. # ============================================================================ - async def start_link_user(self, options, store_options: Optional[dict[str, Any]] = None): + async def start_link_user( + self, + options, + store_options: Optional[dict[str, Any]] = None + ): """ Starts the user linking process, and returns a URL to redirect the user-agent to. @@ -1514,9 +1563,13 @@ async def start_link_user(self, options, store_options: Optional[dict[str, Any]] state_data = state_data.dict() session_domain = self._get_session_domain(state_data) if not session_domain: - raise StartLinkUserError("Session domain does not match the current domain.") + raise StartLinkUserError( + "Session domain does not match the current domain." + ) if self._normalize_url(session_domain) != self._normalize_url(origin_domain): - raise StartLinkUserError("Session domain does not match the current domain.") + raise StartLinkUserError( + "Session domain does not match the current domain." + ) # Generate PKCE and state for security code_verifier = PKCE.generate_code_verifier() @@ -1530,7 +1583,7 @@ async def start_link_user(self, options, store_options: Optional[dict[str, Any]] code_verifier=code_verifier, state=state, authorization_params=options.get("authorization_params"), - domain=origin_domain, + domain=origin_domain ) # Store transaction data @@ -1541,13 +1594,17 @@ async def start_link_user(self, options, store_options: Optional[dict[str, Any]] ) await self._transaction_store.set( - f"{self._transaction_identifier}:{state}", transaction_data, options=store_options + f"{self._transaction_identifier}:{state}", + transaction_data, + options=store_options ) return link_user_url async def complete_link_user( - self, url: str, store_options: Optional[dict[str, Any]] = None + self, + url: str, + store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Completes the user linking process. @@ -1564,9 +1621,15 @@ async def complete_link_user( result = await self.complete_interactive_login(url, store_options) # Return just the app state as specified - return {"app_state": result.get("app_state")} + return { + "app_state": result.get("app_state") + } - async def start_unlink_user(self, options, store_options: Optional[dict[str, Any]] = None): + async def start_unlink_user( + self, + options, + store_options: Optional[dict[str, Any]] = None + ): """ Starts the user unlinking process, and returns a URL to redirect the user-agent to. @@ -1593,9 +1656,13 @@ async def start_unlink_user(self, options, store_options: Optional[dict[str, Any state_data = state_data.dict() session_domain = self._get_session_domain(state_data) if not session_domain: - raise StartLinkUserError("Session domain does not match the current domain.") + raise StartLinkUserError( + "Session domain does not match the current domain." + ) if self._normalize_url(session_domain) != self._normalize_url(origin_domain): - raise StartLinkUserError("Session domain does not match the current domain.") + raise StartLinkUserError( + "Session domain does not match the current domain." + ) # Generate PKCE and state for security code_verifier = PKCE.generate_code_verifier() @@ -1608,7 +1675,7 @@ async def start_unlink_user(self, options, store_options: Optional[dict[str, Any code_verifier=code_verifier, state=state, authorization_params=options.get("authorization_params"), - domain=origin_domain, + domain=origin_domain ) # Store transaction data @@ -1619,13 +1686,17 @@ async def start_unlink_user(self, options, store_options: Optional[dict[str, Any ) await self._transaction_store.set( - f"{self._transaction_identifier}:{state}", transaction_data, options=store_options + f"{self._transaction_identifier}:{state}", + transaction_data, + options=store_options ) return link_user_url async def complete_unlink_user( - self, url: str, store_options: Optional[dict[str, Any]] = None + self, + url: str, + store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Completes the user unlinking process. @@ -1642,7 +1713,9 @@ async def complete_unlink_user( result = await self.complete_interactive_login(url, store_options) # Return just the app state as specified - return {"app_state": result.get("app_state")} + return { + "app_state": result.get("app_state") + } async def _build_link_user_url( self, @@ -1652,7 +1725,7 @@ async def _build_link_user_url( state: str, connection_scope: Optional[str] = None, authorization_params: Optional[dict[str, Any]] = None, - domain: Optional[str] = None, + domain: Optional[str] = None ) -> str: """Build a URL for linking user accounts""" # Generate code challenge from verifier @@ -1663,9 +1736,8 @@ async def _build_link_user_url( metadata = await self._get_oidc_metadata_cached(resolved_domain) # Get authorization endpoint - auth_endpoint = metadata.get( - "authorization_endpoint", f"https://{resolved_domain}/authorize" - ) + auth_endpoint = metadata.get("authorization_endpoint", + f"https://{resolved_domain}/authorize") # Build params params = { @@ -1678,7 +1750,7 @@ async def _build_link_user_url( "response_type": "code", "id_token_hint": id_token, "scope": "openid link_account", - "audience": "my-account", + "audience": "my-account" } # Add connection scope if provided @@ -1697,7 +1769,7 @@ async def _build_unlink_user_url( code_verifier: str, state: str, authorization_params: Optional[dict[str, Any]] = None, - domain: Optional[str] = None, + domain: Optional[str] = None ) -> str: """Build a URL for unlinking user accounts""" # Generate code challenge from verifier @@ -1708,9 +1780,8 @@ async def _build_unlink_user_url( metadata = await self._get_oidc_metadata_cached(resolved_domain) # Get authorization endpoint - auth_endpoint = metadata.get( - "authorization_endpoint", f"https://{resolved_domain}/authorize" - ) + auth_endpoint = metadata.get("authorization_endpoint", + f"https://{resolved_domain}/authorize") # Build params params = { @@ -1722,7 +1793,7 @@ async def _build_unlink_user_url( "response_type": "code", "id_token_hint": id_token, "scope": "openid unlink_account", - "audience": "my-account", + "audience": "my-account" } # Add any additional parameters if authorization_params: @@ -1737,7 +1808,9 @@ async def _build_unlink_user_url( # ============================================================================ async def get_access_token_for_connection( - self, options: dict[str, Any], store_options: Optional[dict[str, Any]] = None + self, + options: dict[str, Any], + store_options: Optional[dict[str, Any]] = None ) -> str: """ Retrieves an access token for a connection. @@ -1771,13 +1844,13 @@ async def get_access_token_for_connection( if not session_domain: raise AccessTokenForConnectionError( AccessTokenForConnectionErrorCode.MISSING_SESSION_DOMAIN, - "Session domain does not match the current domain.", + "Session domain does not match the current domain." ) current_domain = await self._resolve_current_domain(store_options) if self._normalize_url(session_domain) != self._normalize_url(current_domain): raise AccessTokenForConnectionError( AccessTokenForConnectionErrorCode.DOMAIN_MISMATCH, - "Session domain does not match the current domain.", + "Session domain does not match the current domain." ) # Find existing connection token @@ -1796,24 +1869,21 @@ async def get_access_token_for_connection( if not state_data_dict or not state_data_dict.get("refresh_token"): raise AccessTokenForConnectionError( AccessTokenForConnectionErrorCode.MISSING_REFRESH_TOKEN, - "A refresh token was not found but is required to be able to retrieve an access token for a connection.", + "A refresh token was not found but is required to be able to retrieve an access token for a connection." ) # Get new token for connection # Use session's domain for token exchange session_domain = state_data_dict.get("domain") or self._domain - token_endpoint_response = await self.get_token_for_connection( - { - "connection": options.get("connection"), - "login_hint": options.get("login_hint"), - "refresh_token": state_data_dict["refresh_token"], - "domain": session_domain, - } - ) + token_endpoint_response = await self.get_token_for_connection({ + "connection": options.get("connection"), + "login_hint": options.get("login_hint"), + "refresh_token": state_data_dict["refresh_token"], + "domain": session_domain + }) # Update state data with new token updated_state_data = State.update_state_data_for_connection_token_set( - options, state_data_dict, token_endpoint_response - ) + options, state_data_dict, token_endpoint_response) # Store updated state await self._state_store.set(self._state_identifier, updated_state_data, store_options) @@ -1837,12 +1907,8 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A """ # Constants SUBJECT_TYPE_REFRESH_TOKEN = "urn:ietf:params:oauth:token-type:refresh_token" - REQUESTED_TOKEN_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = ( - "http://auth0.com/oauth/token-type/federated-connection-access-token" - ) - GRANT_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = ( - "urn:auth0:params:oauth:grant-type:token-exchange:federated-connection-access-token" - ) + REQUESTED_TOKEN_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = "http://auth0.com/oauth/token-type/federated-connection-access-token" + GRANT_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = "urn:auth0:params:oauth:grant-type:token-exchange:federated-connection-access-token" try: # Use session domain if provided, otherwise fallback to static domain domain = options.get("domain") or self._domain @@ -1852,7 +1918,8 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A token_endpoint = metadata.get("token_endpoint") if not token_endpoint: - raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") + raise ApiError("configuration_error", + "Token endpoint missing in OIDC metadata") # Prepare parameters params = { @@ -1861,7 +1928,7 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A "subject_token": options["refresh_token"], "requested_token_type": REQUESTED_TOKEN_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN, "grant_type": GRANT_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN, - "client_id": self._client_id, + "client_id": self._client_id } # Add login_hint if provided @@ -1871,41 +1938,38 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A # Make the request async with self._get_http_client() as client: response = await client.post( - token_endpoint, data=params, auth=(self._client_id, self._client_secret) + token_endpoint, + data=params, + auth=(self._client_id, self._client_secret) ) if response.status_code != 200: - error_data = ( - response.json() - if response.headers.get("content-type") == "application/json" - else {} - ) + error_data = response.json() if response.headers.get( + "content-type") == "application/json" else {} raise ApiError( error_data.get("error", "connection_token_error"), error_data.get( - "error_description", - f"Failed to get token for connection: {response.status_code}", - ), + "error_description", f"Failed to get token for connection: {response.status_code}") ) token_endpoint_response = response.json() return { "access_token": token_endpoint_response.get("access_token"), - "expires_at": int(time.time()) - + int(token_endpoint_response.get("expires_in", 3600)), - "scope": token_endpoint_response.get("scope", ""), + "expires_at": int(time.time()) + int(token_endpoint_response.get("expires_in", 3600)), + "scope": token_endpoint_response.get("scope", "") } except Exception as e: if isinstance(e, ApiError): raise AccessTokenForConnectionError( - AccessTokenForConnectionErrorCode.API_ERROR, str(e) + AccessTokenForConnectionErrorCode.API_ERROR, + str(e) ) raise AccessTokenForConnectionError( AccessTokenForConnectionErrorCode.FETCH_ERROR, "There was an error while trying to retrieve an access token for a connection.", - e, + e ) # ============================================================================ @@ -1915,7 +1979,9 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A # ============================================================================ async def start_connect_account( - self, options: ConnectAccountOptions, store_options: dict = None + self, + options: ConnectAccountOptions, + store_options: dict = None ) -> str: """ Initiates the connect account flow for linking a third-party account to the user's profile. @@ -1940,25 +2006,26 @@ async def start_connect_account( code_verifier = PKCE.generate_code_verifier() code_challenge = PKCE.generate_code_challenge(code_verifier) - state = PKCE.generate_random_string(32) + state= PKCE.generate_random_string(32) connect_request = ConnectAccountRequest( connection=options.connection, scopes=options.scopes, - redirect_uri=redirect_uri, + redirect_uri = redirect_uri, code_challenge=code_challenge, code_challenge_method="S256", state=state, - authorization_params=options.authorization_params, + authorization_params=options.authorization_params ) access_token = await self.get_access_token( audience=self._my_account_client.audience, scope="create:me:connected_accounts", - store_options=store_options, + store_options=store_options ) connect_response = await self._my_account_client.connect_account( - access_token=access_token, request=connect_request + access_token=access_token, + request=connect_request ) # Build the transaction data to store @@ -1966,29 +2033,24 @@ async def start_connect_account( code_verifier=code_verifier, app_state=options.app_state, auth_session=connect_response.auth_session, - redirect_uri=redirect_uri, + redirect_uri=redirect_uri ) # Store the transaction data await self._transaction_store.set( - f"{self._transaction_identifier}:{state}", transaction_data, options=store_options + f"{self._transaction_identifier}:{state}", + transaction_data, + options=store_options ) parsedUrl = urlparse(connect_response.connect_uri) query = urlencode({"ticket": connect_response.connect_params.ticket}) - return urlunparse( - ( - parsedUrl.scheme, - parsedUrl.netloc, - parsedUrl.path, - parsedUrl.params, - query, - parsedUrl.fragment, - ) - ) + return urlunparse((parsedUrl.scheme, parsedUrl.netloc, parsedUrl.path, parsedUrl.params, query, parsedUrl.fragment)) async def complete_connect_account( - self, url: str, store_options: dict = None + self, + url: str, + store_options: dict = None ) -> CompleteConnectAccountResponse: """ Handles the redirect callback to complete the connect account flow for linking a third-party @@ -2020,9 +2082,7 @@ async def complete_connect_account( # Retrieve the transaction data using the state transaction_identifier = f"{self._transaction_identifier}:{state}" - transaction_data = await self._transaction_store.get( - transaction_identifier, options=store_options - ) + transaction_data = await self._transaction_store.get(transaction_identifier, options=store_options) if not transaction_data: raise MissingTransactionError() @@ -2030,19 +2090,18 @@ async def complete_connect_account( access_token = await self.get_access_token( audience=self._my_account_client.audience, scope="create:me:connected_accounts", - store_options=store_options, + store_options=store_options ) request = CompleteConnectAccountRequest( auth_session=transaction_data.auth_session, connect_code=connect_code, redirect_uri=transaction_data.redirect_uri, - code_verifier=transaction_data.code_verifier, + code_verifier=transaction_data.code_verifier ) try: response = await self._my_account_client.complete_connect_account( - access_token=access_token, request=request - ) + access_token=access_token, request=request) if transaction_data.app_state is not None: response.app_state = transaction_data.app_state finally: @@ -2056,7 +2115,7 @@ async def list_connected_accounts( connection: Optional[str] = None, from_param: Optional[str] = None, take: Optional[int] = None, - store_options: dict = None, + store_options: dict = None ) -> ListConnectedAccountsResponse: """ Retrieves a list of connected accounts for the authenticated user. @@ -2080,14 +2139,15 @@ async def list_connected_accounts( access_token = await self.get_access_token( audience=self._my_account_client.audience, scope="read:me:connected_accounts", - store_options=store_options, + store_options=store_options ) return await self._my_account_client.list_connected_accounts( - access_token=access_token, connection=connection, from_param=from_param, take=take - ) + access_token=access_token, connection=connection, from_param=from_param, take=take) async def delete_connected_account( - self, connected_account_id: str, store_options: dict = None + self, + connected_account_id: str, + store_options: dict = None ) -> None: """ Deletes a connected account. @@ -2106,17 +2166,16 @@ async def delete_connected_account( access_token = await self.get_access_token( audience=self._my_account_client.audience, scope="delete:me:connected_accounts", - store_options=store_options, + store_options=store_options ) await self._my_account_client.delete_connected_account( - access_token=access_token, connected_account_id=connected_account_id - ) + access_token=access_token, connected_account_id=connected_account_id) async def list_connected_account_connections( self, from_param: Optional[str] = None, take: Optional[int] = None, - store_options: dict = None, + store_options: dict = None ) -> ListConnectedAccountConnectionsResponse: """ Retrieves a list of available connections that can be used connected accounts for the authenticated user. @@ -2139,11 +2198,10 @@ async def list_connected_account_connections( access_token = await self.get_access_token( audience=self._my_account_client.audience, scope="read:me:connected_accounts", - store_options=store_options, + store_options=store_options ) return await self._my_account_client.list_connected_account_connections( - access_token=access_token, from_param=from_param, take=take - ) + access_token=access_token, from_param=from_param, take=take) # ============================================================================ # CUSTOM TOKEN EXCHANGE (RFC 8693) @@ -2151,7 +2209,9 @@ async def list_connected_account_connections( # ============================================================================ async def custom_token_exchange( - self, options: CustomTokenExchangeOptions, store_options: Optional[dict[str, Any]] = None + self, + options: CustomTokenExchangeOptions, + store_options: Optional[dict[str, Any]] = None ) -> TokenExchangeResponse: """ Exchanges a custom token for Auth0 tokens using RFC 8693. @@ -2224,12 +2284,7 @@ async def custom_token_exchange( # Merge additional authorization params if options.authorization_params: # Prevent override of critical parameters - forbidden_params = { - "grant_type", - "client_id", - "subject_token", - "subject_token_type", - } + forbidden_params = {"grant_type", "client_id", "subject_token", "subject_token_type"} for key, value in options.authorization_params.items(): if key not in forbidden_params: params[key] = value @@ -2237,20 +2292,17 @@ async def custom_token_exchange( # Make the token exchange request async with self._get_http_client() as client: response = await client.post( - token_endpoint, data=params, auth=(self._client_id, self._client_secret) + token_endpoint, + data=params, + auth=(self._client_id, self._client_secret) ) if response.status_code != 200: - error_data = ( - response.json() - if response.headers.get("content-type", "").startswith("application/json") - else {} - ) + error_data = response.json() if response.headers.get( + "content-type", "").startswith("application/json") else {} raise CustomTokenExchangeError( error_data.get("error", CustomTokenExchangeErrorCode.TOKEN_EXCHANGE_FAILED), - error_data.get( - "error_description", f"Token exchange failed: {response.status_code}" - ), + error_data.get("error_description", f"Token exchange failed: {response.status_code}") ) try: @@ -2258,7 +2310,7 @@ async def custom_token_exchange( except json.JSONDecodeError: raise CustomTokenExchangeError( CustomTokenExchangeErrorCode.INVALID_RESPONSE, - "Failed to parse token response as JSON", + "Failed to parse token response as JSON" ) # Validate and return response @@ -2267,7 +2319,7 @@ async def custom_token_exchange( except ValidationError as e: raise CustomTokenExchangeError( CustomTokenExchangeErrorCode.INVALID_TOKEN_FORMAT, - f"Token validation failed: {str(e)}", + f"Token validation failed: {str(e)}" ) except Exception as e: if isinstance(e, (CustomTokenExchangeError, ApiError)): @@ -2275,13 +2327,13 @@ async def custom_token_exchange( raise CustomTokenExchangeError( CustomTokenExchangeErrorCode.TOKEN_EXCHANGE_FAILED, f"Token exchange failed: {str(e)}", - e, + e ) async def login_with_custom_token_exchange( self, options: LoginWithCustomTokenExchangeOptions, - store_options: Optional[dict[str, Any]] = None, + store_options: Optional[dict[str, Any]] = None ) -> LoginWithCustomTokenExchangeResult: """ Performs token exchange and establishes a user session. @@ -2326,12 +2378,10 @@ async def login_with_custom_token_exchange( actor_token=options.actor_token, actor_token_type=options.actor_token_type, organization=options.organization, - authorization_params=options.authorization_params, + authorization_params=options.authorization_params ) - token_response = await self.custom_token_exchange( - exchange_options, store_options=store_options - ) + token_response = await self.custom_token_exchange(exchange_options, store_options=store_options) # Resolve domain and fetch metadata for verification domain = await self._resolve_current_domain(store_options) @@ -2363,18 +2413,28 @@ async def login_with_custom_token_exchange( raise ApiError("jwks_key_not_found", str(e)) except jwt.InvalidSignatureError as e: raise ApiError( - "invalid_signature", f"ID token signature verification failed: {str(e)}", e + "invalid_signature", + f"ID token signature verification failed: {str(e)}", + e ) except jwt.InvalidAudienceError as e: raise ApiError( "invalid_audience", f"ID token audience mismatch. Expected: {self._client_id}: {str(e)}", - e, + e ) except jwt.ExpiredSignatureError as e: - raise ApiError("token_expired", f"ID token has expired: {str(e)}", e) + raise ApiError( + "token_expired", + f"ID token has expired: {str(e)}", + e + ) except jwt.InvalidTokenError as e: - raise ApiError("invalid_token", f"ID token verification failed: {str(e)}", e) + raise ApiError( + "invalid_token", + f"ID token verification failed: {str(e)}", + e + ) # Determine audience for token set audience = options.audience or self.DEFAULT_AUDIENCE_STATE_KEY @@ -2384,7 +2444,7 @@ async def login_with_custom_token_exchange( audience=audience, access_token=token_response.access_token, scope=token_response.scope or options.scope or "", - expires_at=int(time.time()) + token_response.expires_in, + expires_at=int(time.time()) + token_response.expires_in ) # Construct state data @@ -2394,14 +2454,19 @@ async def login_with_custom_token_exchange( refresh_token=token_response.refresh_token, token_sets=[token_set], domain=domain, - internal={"sid": sid, "created_at": int(time.time())}, + internal={ + "sid": sid, + "created_at": int(time.time()) + } ) # Store session await self._state_store.set(self._state_identifier, state_data, options=store_options) # Build result - result = LoginWithCustomTokenExchangeResult(state_data=state_data.dict()) + result = LoginWithCustomTokenExchangeResult( + state_data=state_data.dict() + ) return result @@ -2411,11 +2476,11 @@ async def login_with_custom_token_exchange( raise CustomTokenExchangeError( CustomTokenExchangeErrorCode.TOKEN_EXCHANGE_FAILED, f"Login with custom token exchange failed: {str(e)}", - e, + e ) # ============================================================================ - # PASSKEY AUTHENTICATION (Category 1) + # PASSKEY AUTHENTICATION # ============================================================================ GRANT_TYPE_PASSKEY = "urn:okta:params:oauth:grant-type:webauthn" From 66071e9698e597d22d58ea636be96ddfdc2e8792 Mon Sep 17 00:00:00 2001 From: Sourav Basu Date: Tue, 2 Jun 2026 17:42:43 +0530 Subject: [PATCH 06/12] Edge case fix for Double URL-encoding and extra validation check --- .../auth_schemes/dpop_auth.py | 4 ++++ .../auth_server/my_account_client.py | 20 +++++++++---------- .../auth_types/__init__.py | 6 +++--- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/auth0_server_python/auth_schemes/dpop_auth.py b/src/auth0_server_python/auth_schemes/dpop_auth.py index 1517a78..d10b8f3 100644 --- a/src/auth0_server_python/auth_schemes/dpop_auth.py +++ b/src/auth0_server_python/auth_schemes/dpop_auth.py @@ -17,6 +17,10 @@ def __init__(self, token: str, key: "jwk.JWK") -> None: public_jwk = key.export_public(as_dict=True) if public_jwk.get("kty") != "EC" or public_jwk.get("crv") != "P-256": raise ValueError("DPoP key must be an EC P-256 key") + try: + token.encode("ascii") + except UnicodeEncodeError: + raise ValueError("Access token must contain only ASCII characters") self._token = token self._key = key self._public_jwk = public_jwk diff --git a/src/auth0_server_python/auth_server/my_account_client.py b/src/auth0_server_python/auth_server/my_account_client.py index bd23a12..e5ef646 100644 --- a/src/auth0_server_python/auth_server/my_account_client.py +++ b/src/auth0_server_python/auth_server/my_account_client.py @@ -1,10 +1,8 @@ import json from typing import TYPE_CHECKING, Optional -from urllib.parse import quote +from urllib.parse import quote, unquote import httpx -from pydantic import ValidationError - from auth0_server_python.auth_schemes.bearer_auth import BearerAuth from auth0_server_python.auth_schemes.dpop_auth import DPoPAuth from auth0_server_python.auth_types import ( @@ -416,7 +414,7 @@ async def get_factors( return GetFactorsResponse.model_validate(response.json()) except Exception as e: - if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + if isinstance(e, (MyAccountApiError, ApiError)): raise raise ApiError( "get_factors_error", @@ -464,7 +462,7 @@ async def list_authentication_methods( return ListAuthenticationMethodsResponse.model_validate(response.json()) except Exception as e: - if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + if isinstance(e, (MyAccountApiError, ApiError)): raise raise ApiError( "list_authentication_methods_error", @@ -509,7 +507,7 @@ async def get_authentication_method( return AuthenticationMethod.model_validate(response.json()) except Exception as e: - if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + if isinstance(e, (MyAccountApiError, ApiError)): raise raise ApiError( "get_authentication_method_error", @@ -552,7 +550,7 @@ async def delete_authentication_method( ) except Exception as e: - if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + if isinstance(e, (MyAccountApiError, ApiError)): raise raise ApiError( "delete_authentication_method_error", @@ -601,7 +599,7 @@ async def update_authentication_method( return AuthenticationMethod.model_validate(response.json()) except Exception as e: - if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + if isinstance(e, (MyAccountApiError, ApiError)): raise raise ApiError( "update_authentication_method_error", @@ -661,7 +659,7 @@ async def enroll_authentication_method( path = location.split("?")[0].split("#")[0].rstrip("/") segments = path.split("/") - authentication_method_id = segments[-1] if len(segments) > 1 else "" + authentication_method_id = unquote(segments[-1]) if len(segments) > 1 else "" if not authentication_method_id or authentication_method_id in ( "authentication-methods", "v1", @@ -696,7 +694,7 @@ async def enroll_authentication_method( ) except Exception as e: - if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + if isinstance(e, (MyAccountApiError, ApiError)): raise raise ApiError( "enroll_authentication_method_error", @@ -749,7 +747,7 @@ async def verify_authentication_method( return AuthenticationMethod.model_validate(response.json()) except Exception as e: - if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + if isinstance(e, (MyAccountApiError, ApiError)): raise raise ApiError( "verify_authentication_method_error", diff --git a/src/auth0_server_python/auth_types/__init__.py b/src/auth0_server_python/auth_types/__init__.py index d306efa..8141a18 100644 --- a/src/auth0_server_python/auth_types/__init__.py +++ b/src/auth0_server_python/auth_types/__init__.py @@ -553,9 +553,9 @@ class VerifyAuthenticationMethodRequest(BaseModel): def _check_at_least_one_method(self) -> "VerifyAuthenticationMethodRequest": has_method = ( self.authn_response is not None - or self.otp_code is not None - or self.recovery_code is not None - or self.password is not None + or (self.otp_code is not None and self.otp_code.strip() != "") + or (self.recovery_code is not None and self.recovery_code.strip() != "") + or (self.password is not None and self.password.strip() != "") ) if not has_method: raise ValueError( From 3809edb21b3c0f7df4cbb15fffc5c9a167340856 Mon Sep 17 00:00:00 2001 From: Sourav Basu Date: Fri, 5 Jun 2026 12:29:33 +0530 Subject: [PATCH 07/12] SDK-8780 PR review changes --- .../auth_schemes/dpop_auth.py | 38 +- .../auth_server/my_account_client.py | 8 +- .../auth_server/server_client.py | 158 ++-- .../auth_types/__init__.py | 388 ++++---- src/auth0_server_python/error/__init__.py | 18 + .../tests/test_my_account_client.py | 835 ++++++++++++++++++ .../tests/test_passkey_my_account.py | 830 ----------------- .../tests/test_passkey_server_client.py | 585 ------------ .../tests/test_server_client.py | 794 +++++++++++++++++ 9 files changed, 1956 insertions(+), 1698 deletions(-) delete mode 100644 src/auth0_server_python/tests/test_passkey_my_account.py delete mode 100644 src/auth0_server_python/tests/test_passkey_server_client.py diff --git a/src/auth0_server_python/auth_schemes/dpop_auth.py b/src/auth0_server_python/auth_schemes/dpop_auth.py index d10b8f3..0bf2d66 100644 --- a/src/auth0_server_python/auth_schemes/dpop_auth.py +++ b/src/auth0_server_python/auth_schemes/dpop_auth.py @@ -12,6 +12,28 @@ def _base64url(data: bytes) -> str: return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") +def make_dpop_proof_for_token_endpoint(key: "jwk.JWK", method: str, url: str, nonce: str = None) -> str: + """ + Build a DPoP proof JWT for use at the token endpoint (RFC 9449 §4.2). + Unlike resource-server proofs, token-endpoint proofs do NOT include `ath` + because no access token exists yet at issuance time. + """ + public_jwk = key.export_public(as_dict=True) + htu = url.split("?")[0].split("#")[0] + header = {"typ": "dpop+jwt", "alg": "ES256", "jwk": public_jwk} + payload = { + "jti": str(uuid.uuid4()), + "htm": method.upper(), + "htu": htu, + "iat": int(time.time()), + } + if nonce is not None: + payload["nonce"] = nonce + token = jwcrypto_jwt.JWT(header=header, claims=payload) + token.make_signed_token(key) + return token.serialize() + + class DPoPAuth(httpx.Auth): def __init__(self, token: str, key: "jwk.JWK") -> None: public_jwk = key.export_public(as_dict=True) @@ -35,9 +57,19 @@ def auth_flow(self, request: httpx.Request): proof = self._make_proof(request.method, str(request.url)) request.headers["Authorization"] = f"DPoP {self._token}" request.headers["DPoP"] = proof - yield request + response = yield request + + # RFC 9449 §8.2 — server-nonce retry + if ( + response is not None + and response.status_code == 401 + and response.headers.get("DPoP-Nonce") + ): + nonce = response.headers["DPoP-Nonce"] + request.headers["DPoP"] = self._make_proof(request.method, str(request.url), nonce=nonce) + yield request - def _make_proof(self, method: str, url: str) -> str: + def _make_proof(self, method: str, url: str, nonce: str = None) -> str: htu = url.split("?")[0].split("#")[0] ath = _base64url(hashlib.sha256(self._token.encode("ascii")).digest()) @@ -49,6 +81,8 @@ def _make_proof(self, method: str, url: str) -> str: "iat": int(time.time()), "ath": ath, } + if nonce is not None: + payload["nonce"] = nonce token = jwcrypto_jwt.JWT(header=header, claims=payload) token.make_signed_token(self._key) diff --git a/src/auth0_server_python/auth_server/my_account_client.py b/src/auth0_server_python/auth_server/my_account_client.py index e5ef646..5ffadd9 100644 --- a/src/auth0_server_python/auth_server/my_account_client.py +++ b/src/auth0_server_python/auth_server/my_account_client.py @@ -619,7 +619,7 @@ async def enroll_authentication_method( navigator.credentials.create(), then call verify_authentication_method() with the auth_session and credential result. - Requires scope: create:me:authentication_methods + Requires scope: create:me:authentication-methods """ if not access_token: raise MissingRequiredArgumentError("access_token") @@ -634,7 +634,7 @@ async def enroll_authentication_method( auth=_make_auth(access_token, dpop_key), ) - if response.status_code != 201: + if response.status_code != 202: try: error_data = response.json() except (json.JSONDecodeError, ValueError): @@ -711,7 +711,7 @@ async def verify_authentication_method( ) -> AuthenticationMethod: """Step 2 of 2: Verify enrollment (POST /me/v1/authentication-methods/{id}/verify). - Requires scope: create:me:authentication_methods + Requires scope: create:me:authentication-methods """ if not access_token: raise MissingRequiredArgumentError("access_token") @@ -728,7 +728,7 @@ async def verify_authentication_method( auth=_make_auth(access_token, dpop_key), ) - if response.status_code != 200: + if response.status_code != 201: try: error_data = response.json() except (json.JSONDecodeError, ValueError): diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index a233205..110f33a 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -7,12 +7,16 @@ import json import time from collections import OrderedDict -from typing import Any, Callable, Generic, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, Union + +if TYPE_CHECKING: + from jwcrypto import jwk from urllib.parse import parse_qs, urlencode, urlparse, urlunparse import httpx import jwt from authlib.integrations.base_client.errors import OAuthError +from auth0_server_python.auth_schemes.dpop_auth import DPoPAuth, make_dpop_proof_for_token_endpoint from authlib.integrations.httpx_client import AsyncOAuth2Client from pydantic import ValidationError @@ -34,6 +38,7 @@ PasskeyAuthResponse, PasskeyLoginChallengeResponse, PasskeySignupChallengeResponse, + PasskeyUserProfile, PasskeyTokenResponse, StartInteractiveLoginOptions, StateData, @@ -58,6 +63,8 @@ MfaRequiredError, MissingRequiredArgumentError, MissingTransactionError, + PasskeyError, + PasskeyErrorCode, PollingApiError, StartLinkUserError, ) @@ -82,6 +89,9 @@ class ServerClient(Generic[TStoreOptions]): and token operations using Authlib for OIDC functionality. """ DEFAULT_AUDIENCE_STATE_KEY = "default" + GRANT_TYPE_PASSKEY = "urn:okta:params:oauth:grant-type:webauthn" + PASSKEY_REGISTER_PATH = "/passkey/register" + PASSKEY_CHALLENGE_PATH = "/passkey/challenge" # ============================================================================ # INITIALIZATION @@ -2480,22 +2490,21 @@ async def login_with_custom_token_exchange( ) # ============================================================================ - # PASSKEY AUTHENTICATION + # MFA (Multi-Factor Authentication) # ============================================================================ - GRANT_TYPE_PASSKEY = "urn:okta:params:oauth:grant-type:webauthn" + @property + def mfa(self) -> MfaClient: + """Access the MFA client for multi-factor authentication operations.""" + return self._mfa_client + + # ============================================================================ + # PASSKEY AUTHENTICATION + # ============================================================================ async def passkey_signup_challenge( self, - name: Optional[str] = None, - email: Optional[str] = None, - username: Optional[str] = None, - phone_number: Optional[str] = None, - given_name: Optional[str] = None, - family_name: Optional[str] = None, - nickname: Optional[str] = None, - picture: Optional[str] = None, - user_metadata: Optional[dict[str, Any]] = None, + user_profile: Optional[PasskeyUserProfile] = None, connection: Optional[str] = None, organization: Optional[str] = None, store_options: Optional[dict[str, Any]] = None, @@ -2507,15 +2516,8 @@ async def passkey_signup_challenge( then call signin_with_passkey() with the auth_session and credential result. Args: - name: User's full name. - email: User's email address. - username: Username for the new account. - phone_number: User's phone number. - given_name: User's given (first) name. - family_name: User's family (last) name. - nickname: User's nickname. - picture: URL to the user's profile picture. - user_metadata: Arbitrary user metadata dict. + user_profile: Optional user profile data (email, name, username, etc.). + Use PasskeyUserProfile — supports extra fields for forward compatibility. connection: Auth0 database connection name (realm). organization: Auth0 organization ID or name. store_options: Optional options for domain resolution. @@ -2524,73 +2526,52 @@ async def passkey_signup_challenge( PasskeySignupChallengeResponse with auth_session and authn_params_public_key. Raises: - ApiError: If the challenge request fails. + PasskeyError: If the challenge request fails. """ try: domain = await self._resolve_current_domain(store_options) - user_profile: dict[str, Any] = {} - if email is not None: - user_profile["email"] = email - if name is not None: - user_profile["name"] = name - if username is not None: - user_profile["username"] = username - if phone_number is not None: - user_profile["phone_number"] = phone_number - if given_name is not None: - user_profile["given_name"] = given_name - if family_name is not None: - user_profile["family_name"] = family_name - if nickname is not None: - user_profile["nickname"] = nickname - if picture is not None: - user_profile["picture"] = picture - if user_metadata is not None: - user_profile["user_metadata"] = user_metadata - body: dict[str, Any] = {"client_id": self._client_id} if self._client_secret: body["client_secret"] = self._client_secret if user_profile: - body["user_profile"] = user_profile + body["user_profile"] = user_profile.model_dump(exclude_none=True) if connection: body["realm"] = connection if organization: body["organization"] = organization - url = f"https://{domain}/passkey/register" - async with self._get_http_client() as client: + url = f"https://{domain}{self.PASSKEY_REGISTER_PATH}" response = await client.post(url, json=body) if response.status_code != 200: try: error_data = response.json() except (json.JSONDecodeError, ValueError): - raise ApiError( - "passkey_challenge_error", + raise PasskeyError( + PasskeyErrorCode.CHALLENGE_FAILED, f"Passkey signup challenge failed with status {response.status_code}", ) - raise ApiError( - error_data.get("error", "passkey_challenge_error"), + raise PasskeyError( + error_data.get("error", PasskeyErrorCode.CHALLENGE_FAILED), error_data.get("error_description", "Passkey signup challenge failed"), ) try: data = response.json() except (json.JSONDecodeError, ValueError): - raise ApiError( - "invalid_response", + raise PasskeyError( + PasskeyErrorCode.INVALID_RESPONSE, "Failed to parse passkey signup challenge response as JSON", ) return PasskeySignupChallengeResponse.model_validate(data) except Exception as e: - if isinstance(e, (ApiError, MissingRequiredArgumentError, ValidationError)): + if isinstance(e, (PasskeyError, MissingRequiredArgumentError, ValidationError)): raise - raise ApiError("passkey_challenge_error", "Passkey signup challenge failed", e) + raise PasskeyError(PasskeyErrorCode.CHALLENGE_FAILED, "Passkey signup challenge failed", e) from e async def passkey_login_challenge( self, @@ -2630,38 +2611,37 @@ async def passkey_login_challenge( if organization: body["organization"] = organization - url = f"https://{domain}/passkey/challenge" - async with self._get_http_client() as client: + url = f"https://{domain}{self.PASSKEY_CHALLENGE_PATH}" response = await client.post(url, json=body) if response.status_code != 200: try: error_data = response.json() except (json.JSONDecodeError, ValueError): - raise ApiError( - "passkey_challenge_error", + raise PasskeyError( + PasskeyErrorCode.CHALLENGE_FAILED, f"Passkey login challenge failed with status {response.status_code}", ) - raise ApiError( - error_data.get("error", "passkey_challenge_error"), + raise PasskeyError( + error_data.get("error", PasskeyErrorCode.CHALLENGE_FAILED), error_data.get("error_description", "Passkey login challenge failed"), ) try: data = response.json() except (json.JSONDecodeError, ValueError): - raise ApiError( - "invalid_response", + raise PasskeyError( + PasskeyErrorCode.INVALID_RESPONSE, "Failed to parse passkey login challenge response as JSON", ) return PasskeyLoginChallengeResponse.model_validate(data) except Exception as e: - if isinstance(e, (ApiError, MissingRequiredArgumentError, ValidationError)): + if isinstance(e, (PasskeyError, MissingRequiredArgumentError, ValidationError)): raise - raise ApiError("passkey_challenge_error", "Passkey login challenge failed", e) + raise PasskeyError(PasskeyErrorCode.CHALLENGE_FAILED, "Passkey login challenge failed", e) from e async def signin_with_passkey( self, @@ -2672,6 +2652,7 @@ async def signin_with_passkey( organization: Optional[str] = None, scope: Optional[str] = None, audience: Optional[str] = None, + dpop_key: Optional["jwk.JWK"] = None, ) -> PasskeyTokenResponse: """ Completes passkey authentication by exchanging the WebAuthn assertion @@ -2690,13 +2671,16 @@ async def signin_with_passkey( organization: Auth0 organization ID or name. scope: OAuth2 scope string. audience: Target API audience. + dpop_key: Optional EC P-256 JWK for DPoP-bound token exchange. When provided, + attaches a DPoP proof header so Auth0 issues a DPoP-bound token + (token_type: DPoP). Required when the tenant mandates DPoP binding. Returns: PasskeyTokenResponse containing access_token, id_token, expires_in, etc. Raises: MissingRequiredArgumentError: If auth_session or authn_response is missing. - ApiError: If token exchange fails. + PasskeyError: If token exchange fails. """ if not auth_session: raise MissingRequiredArgumentError("auth_session") @@ -2709,7 +2693,7 @@ async def signin_with_passkey( token_endpoint = metadata.get("token_endpoint") if not token_endpoint: - raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") + raise PasskeyError(PasskeyErrorCode.TOKEN_EXCHANGE_FAILED, "Token endpoint missing in OIDC metadata") body: dict[str, Any] = { "grant_type": self.GRANT_TYPE_PASSKEY, @@ -2729,26 +2713,43 @@ async def signin_with_passkey( body["audience"] = audience async with self._get_http_client() as client: - response = await client.post(token_endpoint, json=body) + headers = {} + if dpop_key is not None: + headers["DPoP"] = make_dpop_proof_for_token_endpoint( + dpop_key, "POST", token_endpoint + ) + response = await client.post(token_endpoint, json=body, headers=headers) + + # RFC 9449 §8.2 — nonce retry for DPoP token endpoint calls + if ( + dpop_key is not None + and response.status_code == 401 + and response.headers.get("DPoP-Nonce") + ): + nonce = response.headers["DPoP-Nonce"] + headers["DPoP"] = make_dpop_proof_for_token_endpoint( + dpop_key, "POST", token_endpoint, nonce=nonce + ) + response = await client.post(token_endpoint, json=body, headers=headers) if response.status_code != 200: try: error_data = response.json() except (json.JSONDecodeError, ValueError): - raise ApiError( - "passkey_token_error", + raise PasskeyError( + PasskeyErrorCode.TOKEN_EXCHANGE_FAILED, f"Passkey token exchange failed with status {response.status_code}", ) - raise ApiError( - error_data.get("error", "passkey_token_error"), + raise PasskeyError( + error_data.get("error", PasskeyErrorCode.TOKEN_EXCHANGE_FAILED), error_data.get("error_description", "Passkey token exchange failed"), ) try: token_data = response.json() except (json.JSONDecodeError, ValueError): - raise ApiError( - "invalid_response", "Failed to parse passkey token response as JSON" + raise PasskeyError( + PasskeyErrorCode.INVALID_RESPONSE, "Failed to parse passkey token response as JSON" ) if "expires_in" in token_data and "expires_at" not in token_data: @@ -2757,15 +2758,6 @@ async def signin_with_passkey( return PasskeyTokenResponse.model_validate(token_data) except Exception as e: - if isinstance(e, (ApiError, MissingRequiredArgumentError, ValidationError)): + if isinstance(e, (PasskeyError, MissingRequiredArgumentError, ValidationError)): raise - raise ApiError("passkey_token_error", "Passkey sign-in failed", e) - - # ============================================================================ - # MFA (Multi-Factor Authentication) - # ============================================================================ - - @property - def mfa(self) -> MfaClient: - """Access the MFA client for multi-factor authentication operations.""" - return self._mfa_client + raise PasskeyError(PasskeyErrorCode.TOKEN_EXCHANGE_FAILED, "Passkey sign-in failed", e) from e diff --git a/src/auth0_server_python/auth_types/__init__.py b/src/auth0_server_python/auth_types/__init__.py index 8141a18..9494a22 100644 --- a/src/auth0_server_python/auth_types/__init__.py +++ b/src/auth0_server_python/auth_types/__init__.py @@ -465,200 +465,6 @@ class ListConnectedAccountConnectionsResponse(BaseModel): next: Optional[str] = None -# ============================================================================= -# Passkey & MyAccount Authentication Methods Types -# ============================================================================= - - -class PasskeyRpInfo(BaseModel): - id: str - name: str - - -class PasskeyUserInfo(BaseModel): - model_config = ConfigDict(populate_by_name=True) - id: str - name: str - display_name: Optional[str] = Field(None, alias="displayName") - - -class PasskeyPubKeyCredParam(BaseModel): - type: str - alg: int - - -class PasskeyAuthenticatorSelection(BaseModel): - model_config = ConfigDict(populate_by_name=True) - resident_key: Optional[str] = Field(None, alias="residentKey") - user_verification: Optional[str] = Field(None, alias="userVerification") - - -class PasskeyPublicKeyOptions(BaseModel): - model_config = ConfigDict(populate_by_name=True) - challenge: str - rp: Optional[PasskeyRpInfo] = None - rp_id: Optional[str] = Field(None, alias="rpId") - user: Optional[PasskeyUserInfo] = None - pub_key_cred_params: Optional[list[PasskeyPubKeyCredParam]] = Field( - None, alias="pubKeyCredParams" - ) - authenticator_selection: Optional[PasskeyAuthenticatorSelection] = Field( - None, alias="authenticatorSelection" - ) - timeout: Optional[int] = None - user_verification: Optional[str] = Field(None, alias="userVerification") - - -class EnrollAuthenticationMethodRequest(BaseModel): - type: str - email: Optional[str] = None - phone_number: Optional[str] = None - preferred_authentication_method: Optional[str] = None - user_identity_id: Optional[str] = None - connection: Optional[str] = None - - -class EnrollmentChallengeResponse(BaseModel): - authentication_method_id: str - auth_session: str - authn_params_public_key: Optional[PasskeyPublicKeyOptions] = None - - def __repr__(self) -> str: - return ( - f"EnrollmentChallengeResponse(" - f"authentication_method_id={self.authentication_method_id!r}, " - f"auth_session=[REDACTED], " - f"authn_params_public_key={self.authn_params_public_key!r})" - ) - - -class PasskeyAuthResponse(BaseModel): - model_config = ConfigDict(populate_by_name=True) - id: str - raw_id: str = Field(alias="rawId") - type: str - authenticator_attachment: Optional[str] = Field(None, alias="authenticatorAttachment") - response: dict[str, str] - client_extension_results: Optional[dict] = Field(None, alias="clientExtensionResults") - - -class VerifyAuthenticationMethodRequest(BaseModel): - auth_session: str - authn_response: Optional[PasskeyAuthResponse] = None - otp_code: Optional[str] = None - recovery_code: Optional[str] = None - password: Optional[str] = None - - @model_validator(mode="after") - def _check_at_least_one_method(self) -> "VerifyAuthenticationMethodRequest": - has_method = ( - self.authn_response is not None - or (self.otp_code is not None and self.otp_code.strip() != "") - or (self.recovery_code is not None and self.recovery_code.strip() != "") - or (self.password is not None and self.password.strip() != "") - ) - if not has_method: - raise ValueError( - "At least one verification method must be provided: " - "authn_response, otp_code, recovery_code, or password" - ) - return self - - -class AuthenticationMethod(BaseModel): - model_config = ConfigDict(extra="allow", populate_by_name=True) - - id: str - type: str - created_at: str - confirmed: Optional[bool] = None - usage: Optional[list[str]] = None - identity_user_id: Optional[str] = None - credential_device_type: Optional[str] = None - credential_backed_up: Optional[bool] = None - key_id: Optional[str] = None - public_key: Optional[str] = None - transports: Optional[list[str]] = None - user_agent: Optional[str] = None - user_handle: Optional[str] = None - aaguid: Optional[str] = None - relying_party_id: Optional[str] = None - phone_number: Optional[str] = None - preferred_authentication_method: Optional[str] = None - email: Optional[str] = None - name: Optional[str] = None - last_password_reset: Optional[str] = None - - -class UpdateAuthenticationMethodRequest(BaseModel): - name: Optional[str] = None - preferred_authentication_method: Optional[str] = None - - -class ListAuthenticationMethodsResponse(BaseModel): - authentication_methods: list[AuthenticationMethod] - - -class Factor(BaseModel): - model_config = ConfigDict(extra="allow") - name: str - enabled: Optional[bool] = None - trial_expired: Optional[bool] = None - - -class GetFactorsResponse(BaseModel): - factors: list[Factor] - - -class PasskeySignupChallengeResponse(BaseModel): - model_config = ConfigDict(populate_by_name=True) - auth_session: str - authn_params_public_key: PasskeyPublicKeyOptions - - def __repr__(self) -> str: - return ( - f"PasskeySignupChallengeResponse(" - f"auth_session=[REDACTED], " - f"authn_params_public_key={self.authn_params_public_key!r})" - ) - - -class PasskeyLoginChallengeResponse(BaseModel): - model_config = ConfigDict(populate_by_name=True) - auth_session: str - authn_params_public_key: PasskeyPublicKeyOptions - - def __repr__(self) -> str: - return ( - f"PasskeyLoginChallengeResponse(" - f"auth_session=[REDACTED], " - f"authn_params_public_key={self.authn_params_public_key!r})" - ) - - -class PasskeyTokenResponse(BaseModel): - model_config = ConfigDict(extra="allow", populate_by_name=True) - access_token: str - token_type: str = "Bearer" - expires_in: int - expires_at: Optional[int] = None - scope: Optional[str] = None - id_token: Optional[str] = None - refresh_token: Optional[str] = None - - def __repr__(self) -> str: - return ( - f"PasskeyTokenResponse(" - f"token_type={self.token_type!r}, " - f"expires_in={self.expires_in!r}, " - f"expires_at={self.expires_at!r}, " - f"scope={self.scope!r}, " - f"access_token=[REDACTED], " - f"id_token=[REDACTED], " - f"refresh_token=[REDACTED])" - ) - - # ============================================================================= # MFA Types # ============================================================================= @@ -839,3 +645,197 @@ class MfaTokenContext(BaseModel): scope: str mfa_requirements: Optional[MfaRequirements] = None created_at: int + + +# ============================================================================= +# Passkey & MyAccount Authentication Methods Types +# ============================================================================= + + +class PasskeyRpInfo(BaseModel): + id: str + name: str + + +class PasskeyUserInfo(BaseModel): + model_config = ConfigDict(populate_by_name=True) + id: str + name: str + display_name: Optional[str] = Field(None, alias="displayName") + + +class PasskeyPubKeyCredParam(BaseModel): + type: str + alg: int + + +class PasskeyAuthenticatorSelection(BaseModel): + model_config = ConfigDict(populate_by_name=True) + resident_key: Optional[str] = Field(None, alias="residentKey") + user_verification: Optional[str] = Field(None, alias="userVerification") + + +class PasskeyPublicKeyOptions(BaseModel): + model_config = ConfigDict(populate_by_name=True) + challenge: str + rp: Optional[PasskeyRpInfo] = None + rp_id: Optional[str] = Field(None, alias="rpId") + user: Optional[PasskeyUserInfo] = None + pub_key_cred_params: Optional[list[PasskeyPubKeyCredParam]] = Field( + None, alias="pubKeyCredParams" + ) + authenticator_selection: Optional[PasskeyAuthenticatorSelection] = Field( + None, alias="authenticatorSelection" + ) + timeout: Optional[int] = None + user_verification: Optional[str] = Field(None, alias="userVerification") + + +EnrollmentType = Literal["passkey", "email", "phone", "totp", "push-notification", "recovery-code", "password"] +PreferredAuthMethod = Literal["sms", "voice"] + + +class EnrollAuthenticationMethodRequest(BaseModel): + type: EnrollmentType + email: Optional[str] = None + phone_number: Optional[str] = None + preferred_authentication_method: Optional[PreferredAuthMethod] = None + user_identity_id: Optional[str] = None + connection: Optional[str] = None + + +class EnrollmentChallengeResponse(BaseModel): + authentication_method_id: str + auth_session: str + authn_params_public_key: Optional[PasskeyPublicKeyOptions] = None + + def __repr__(self) -> str: + return ( + f"EnrollmentChallengeResponse(" + f"authentication_method_id={self.authentication_method_id!r}, " + f"auth_session=[REDACTED], " + f"authn_params_public_key={self.authn_params_public_key!r})" + ) + + +class PasskeyAuthResponse(BaseModel): + model_config = ConfigDict(populate_by_name=True) + id: str + raw_id: str = Field(alias="rawId") + type: str + authenticator_attachment: Optional[str] = Field(None, alias="authenticatorAttachment") + response: dict[str, str] + client_extension_results: Optional[dict] = Field(None, alias="clientExtensionResults") + + +class VerifyAuthenticationMethodRequest(BaseModel): + auth_session: str + authn_response: Optional[PasskeyAuthResponse] = None + otp_code: Optional[str] = None + recovery_code: Optional[str] = None + password: Optional[str] = None + + +class AuthenticationMethod(BaseModel): + model_config = ConfigDict(extra="allow") + + id: str + type: str + created_at: str + confirmed: Optional[bool] = None + usage: Optional[list[str]] = None + identity_user_id: Optional[str] = None + credential_device_type: Optional[str] = None + credential_backed_up: Optional[bool] = None + key_id: Optional[str] = None + public_key: Optional[str] = None + transports: Optional[list[str]] = None + user_agent: Optional[str] = None + user_handle: Optional[str] = None + aaguid: Optional[str] = None + relying_party_id: Optional[str] = None + phone_number: Optional[str] = None + preferred_authentication_method: Optional[str] = None + email: Optional[str] = None + name: Optional[str] = None + last_password_reset: Optional[str] = None + + +class UpdateAuthenticationMethodRequest(BaseModel): + name: Optional[str] = None + preferred_authentication_method: Optional[str] = None + + +class ListAuthenticationMethodsResponse(BaseModel): + authentication_methods: list[AuthenticationMethod] + + +class Factor(BaseModel): + model_config = ConfigDict(extra="allow") + name: str + enabled: Optional[bool] = None + trial_expired: Optional[bool] = None + + +class GetFactorsResponse(BaseModel): + factors: list[Factor] + + +class PasskeyUserProfile(BaseModel): + model_config = ConfigDict(extra="allow") + email: Optional[str] = None + name: Optional[str] = None + username: Optional[str] = None + phone_number: Optional[str] = None + given_name: Optional[str] = None + family_name: Optional[str] = None + nickname: Optional[str] = None + picture: Optional[str] = None + user_metadata: Optional[dict[str, Any]] = None + + +class PasskeySignupChallengeResponse(BaseModel): + auth_session: str + authn_params_public_key: PasskeyPublicKeyOptions + + def __repr__(self) -> str: + return ( + f"PasskeySignupChallengeResponse(" + f"auth_session=[REDACTED], " + f"authn_params_public_key={self.authn_params_public_key!r})" + ) + + +class PasskeyLoginChallengeResponse(BaseModel): + auth_session: str + authn_params_public_key: PasskeyPublicKeyOptions + + def __repr__(self) -> str: + return ( + f"PasskeyLoginChallengeResponse(" + f"auth_session=[REDACTED], " + f"authn_params_public_key={self.authn_params_public_key!r})" + ) + + +class PasskeyTokenResponse(BaseModel): + model_config = ConfigDict(extra="allow") + access_token: str + token_type: str = "Bearer" + expires_in: int + expires_at: Optional[int] = None + scope: Optional[str] = None + id_token: Optional[str] = None + refresh_token: Optional[str] = None + + def __repr__(self) -> str: + return ( + f"PasskeyTokenResponse(" + f"token_type={self.token_type!r}, " + f"expires_in={self.expires_in!r}, " + f"expires_at={self.expires_at!r}, " + f"scope={self.scope!r}, " + f"access_token=[REDACTED], " + f"id_token=[REDACTED], " + f"refresh_token=[REDACTED])" + ) diff --git a/src/auth0_server_python/error/__init__.py b/src/auth0_server_python/error/__init__.py index db4f28e..615c112 100644 --- a/src/auth0_server_python/error/__init__.py +++ b/src/auth0_server_python/error/__init__.py @@ -229,6 +229,24 @@ class CustomTokenExchangeErrorCode: INVALID_RESPONSE = "invalid_response" +class PasskeyError(Auth0Error): + """ + Error raised during passkey authentication operations. + """ + def __init__(self, code: str, message: str, cause=None): + super().__init__(message) + self.code = code + self.name = "PasskeyError" + self.cause = cause + + +class PasskeyErrorCode: + """Error codes for passkey operations.""" + CHALLENGE_FAILED = "passkey_challenge_error" + TOKEN_EXCHANGE_FAILED = "passkey_token_error" + INVALID_RESPONSE = "invalid_response" + + # ============================================================================= # MFA Error Classes # ============================================================================= diff --git a/src/auth0_server_python/tests/test_my_account_client.py b/src/auth0_server_python/tests/test_my_account_client.py index e4ff74c..da2875d 100644 --- a/src/auth0_server_python/tests/test_my_account_client.py +++ b/src/auth0_server_python/tests/test_my_account_client.py @@ -1,9 +1,13 @@ from unittest.mock import ANY, AsyncMock, MagicMock +import httpx import pytest +from jwcrypto import jwk as jwk_module +from auth0_server_python.auth_schemes.dpop_auth import DPoPAuth from auth0_server_python.auth_server.my_account_client import MyAccountClient from auth0_server_python.auth_types import ( + AuthenticationMethod, CompleteConnectAccountRequest, CompleteConnectAccountResponse, ConnectAccountRequest, @@ -11,10 +15,18 @@ ConnectedAccount, ConnectedAccountConnection, ConnectParams, + EnrollAuthenticationMethodRequest, + EnrollmentChallengeResponse, + GetFactorsResponse, + ListAuthenticationMethodsResponse, ListConnectedAccountConnectionsResponse, ListConnectedAccountsResponse, + PasskeyAuthResponse, + UpdateAuthenticationMethodRequest, + VerifyAuthenticationMethodRequest, ) from auth0_server_python.error import ( + ApiError, InvalidArgumentError, MissingRequiredArgumentError, MyAccountApiError, @@ -502,3 +514,826 @@ async def test_list_connected_account_connections_api_response_failure(mocker): mock_get.assert_awaited_once() assert "Invalid Token" in str(exc.value) + +# ============================================================================= +# AUTHENTICATION METHODS & FACTORS (Passkey / MyAccount API) +# ============================================================================= + + +@pytest.mark.asyncio +async def test_get_factors_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"factors": [{"name": "sms", "enabled": True}]}) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.get_factors(access_token="token123") + + assert isinstance(result, GetFactorsResponse) + assert len(result.factors) == 1 + assert result.factors[0].name == "sms" + assert result.factors[0].enabled is True + + +@pytest.mark.asyncio +@pytest.mark.parametrize("access_token", [None, ""]) +async def test_get_factors_missing_access_token(mocker, access_token): + client = MyAccountClient(domain="auth0.local") + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock) + + with pytest.raises(MissingRequiredArgumentError): + await client.get_factors(access_token=access_token) + + mock_get.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_get_factors_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 403 + response.json = MagicMock(return_value={ + "title": "Forbidden", + "type": "forbidden", + "detail": "Insufficient scope", + "status": 403, + }) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + with pytest.raises(MyAccountApiError) as exc: + await client.get_factors(access_token="token123") + + assert exc.value.status == 403 + + +@pytest.mark.asyncio +async def test_get_factors_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=Exception("Connection refused") + ) + + with pytest.raises(ApiError): + await client.get_factors(access_token="token123") + + +@pytest.mark.asyncio +async def test_get_factors_empty_list(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"factors": []}) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.get_factors(access_token="token123") + assert result.factors == [] + + +@pytest.mark.asyncio +async def test_get_factors_extra_fields(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={ + "factors": [{"name": "webauthn-roaming", "enabled": True, "future_field": "value"}] + }) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.get_factors(access_token="token123") + assert result.factors[0].name == "webauthn-roaming" + + +@pytest.mark.asyncio +async def test_list_authentication_methods_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={ + "authentication_methods": [ + {"id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z", "key_id": "kid1"} + ] + }) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.list_authentication_methods(access_token="token123") + assert isinstance(result, ListAuthenticationMethodsResponse) + assert len(result.authentication_methods) == 1 + assert result.authentication_methods[0].type == "passkey" + + +@pytest.mark.asyncio +async def test_list_authentication_methods_with_type_filter(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"authentication_methods": []}) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + await client.list_authentication_methods(access_token="token123", type_filter="passkey") + mock_get.assert_awaited_once() + call_kwargs = mock_get.call_args[1] + assert call_kwargs["params"] == {"type": "passkey"} + + +@pytest.mark.asyncio +async def test_list_authentication_methods_empty(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"authentication_methods": []}) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.list_authentication_methods(access_token="token123") + assert result.authentication_methods == [] + + +@pytest.mark.asyncio +async def test_get_authentication_method_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={ + "id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z" + }) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.get_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + assert isinstance(result, AuthenticationMethod) + assert result.id == "am_1" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("method_id", [None, ""]) +async def test_get_authentication_method_missing_id(mocker, method_id): + client = MyAccountClient(domain="auth0.local") + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock) + + with pytest.raises(MissingRequiredArgumentError): + await client.get_authentication_method( + access_token="token123", authentication_method_id=method_id + ) + + mock_get.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_get_authentication_method_path_traversal(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={ + "id": "id/slash", "type": "passkey", "created_at": "2026-01-01T00:00:00Z" + }) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + await client.get_authentication_method( + access_token="token123", authentication_method_id="id/slash" + ) + call_url = mock_get.call_args[1]["url"] + assert "id%2Fslash" in call_url + assert "id/slash" not in call_url.replace("https://auth0.local/me/", "") + + +@pytest.mark.asyncio +async def test_get_authentication_method_pipe_encoding(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={ + "id": "passkey|new", "type": "passkey", "created_at": "2026-01-01T00:00:00Z" + }) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + await client.get_authentication_method( + access_token="token123", authentication_method_id="passkey|new" + ) + call_url = mock_get.call_args[1]["url"] + assert "passkey%7Cnew" in call_url + + +@pytest.mark.asyncio +async def test_delete_authentication_method_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 204 + mocker.patch("httpx.AsyncClient.delete", new_callable=AsyncMock, return_value=response) + + result = await client.delete_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + assert result is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("method_id", [None, ""]) +async def test_delete_authentication_method_missing_id(mocker, method_id): + client = MyAccountClient(domain="auth0.local") + mock_delete = mocker.patch("httpx.AsyncClient.delete", new_callable=AsyncMock) + + with pytest.raises(MissingRequiredArgumentError): + await client.delete_authentication_method( + access_token="token123", authentication_method_id=method_id + ) + + mock_delete.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_update_authentication_method_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={ + "id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z", "name": "My Key", + }) + mock_patch = mocker.patch( + "httpx.AsyncClient.patch", new_callable=AsyncMock, return_value=response + ) + + req = UpdateAuthenticationMethodRequest(name="My Key") + result = await client.update_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req + ) + assert result.name == "My Key" + call_kwargs = mock_patch.call_args[1] + assert call_kwargs["json"] == {"name": "My Key"} + + +@pytest.mark.asyncio +async def test_update_authentication_method_missing_request(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch("httpx.AsyncClient.patch", new_callable=AsyncMock) + + with pytest.raises(MissingRequiredArgumentError): + await client.update_authentication_method( + access_token="token123", authentication_method_id="am_1", request=None + ) + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 202 + response.headers = {"location": "/me/v1/authentication-methods/passkey|new"} + response.json = MagicMock(return_value={ + "auth_session": "session_abc", + "authn_params_public_key": { + "challenge": "dGVzdA", + "rp": {"id": "auth0.local", "name": "My App"}, + "user": {"id": "dXNlcl8x", "name": "user@test.com", "displayName": "Test User"}, + "pubKeyCredParams": [{"type": "public-key", "alg": -7}], + "authenticatorSelection": {"residentKey": "required", "userVerification": "preferred"}, + "timeout": 60000, + }, + }) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + result = await client.enroll_authentication_method(access_token="token123", request=req) + + assert isinstance(result, EnrollmentChallengeResponse) + assert result.authentication_method_id == "passkey|new" + assert result.auth_session == "session_abc" + assert result.authn_params_public_key is not None + assert result.authn_params_public_key.pub_key_cred_params[0].alg == -7 + assert result.authn_params_public_key.authenticator_selection.resident_key == "required" + assert result.authn_params_public_key.user.display_name == "Test User" + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_missing_location(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 202 + response.headers = {} + response.json = MagicMock(return_value={"auth_session": "session_abc"}) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + with pytest.raises(ApiError) as exc: + await client.enroll_authentication_method(access_token="token123", request=req) + + assert "Location header" in str(exc.value) + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_location_with_query(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 202 + response.headers = {"location": "/me/v1/authentication-methods/abc123?tracking=1"} + response.json = MagicMock(return_value={"auth_session": "session_abc"}) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + result = await client.enroll_authentication_method(access_token="token123", request=req) + assert result.authentication_method_id == "abc123" + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_location_absolute_url(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 202 + response.headers = {"location": "https://tenant.auth0.com/me/v1/authentication-methods/am_xyz"} + response.json = MagicMock(return_value={"auth_session": "session_abc"}) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + result = await client.enroll_authentication_method(access_token="token123", request=req) + assert result.authentication_method_id == "am_xyz" + + +@pytest.mark.asyncio +async def test_verify_authentication_method_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 201 + response.json = MagicMock(return_value={ + "id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z", "confirmed": True, + }) + mock_post = mocker.patch( + "httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response + ) + + authn_response = PasskeyAuthResponse( + id="cred1", + raw_id="cmF3MQ", + type="public-key", + authenticator_attachment="platform", + response={"clientDataJSON": "abc", "attestationObject": "def"}, + ) + req = VerifyAuthenticationMethodRequest( + auth_session="session_abc", authn_response=authn_response + ) + result = await client.verify_authentication_method( + access_token="token123", authentication_method_id="passkey|new", request=req + ) + + assert isinstance(result, AuthenticationMethod) + assert result.confirmed is True + + call_kwargs = mock_post.call_args[1] + body = call_kwargs["json"] + assert "rawId" in body["authn_response"] + assert "raw_id" not in body["authn_response"] + assert "authenticatorAttachment" in body["authn_response"] + assert body["auth_session"] == "session_abc" + assert "passkey%7Cnew" in call_kwargs["url"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("method_id", [None, ""]) +async def test_verify_authentication_method_missing_id(mocker, method_id): + client = MyAccountClient(domain="auth0.local") + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") + with pytest.raises(MissingRequiredArgumentError): + await client.verify_authentication_method( + access_token="token123", authentication_method_id=method_id, request=req + ) + + +def test_enrollment_challenge_response_repr(): + resp = EnrollmentChallengeResponse( + authentication_method_id="am_1", + auth_session="super_secret_session", + authn_params_public_key=None, + ) + repr_str = repr(resp) + assert "super_secret_session" not in repr_str + assert "[REDACTED]" in repr_str + assert "am_1" in repr_str + + +def test_verify_request_auth_session_only_is_valid(): + req = VerifyAuthenticationMethodRequest(auth_session="session_abc") + assert req.auth_session == "session_abc" + assert req.otp_code is None + assert req.authn_response is None + + +def test_verify_request_accepts_otp_code(): + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") + assert req.otp_code == "123456" + + +def test_verify_request_accepts_authn_response(): + authn_resp = PasskeyAuthResponse( + id="cred1", + raw_id="cmF3MQ", + type="public-key", + response={"clientDataJSON": "abc", "attestationObject": "def"}, + ) + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", authn_response=authn_resp) + assert req.authn_response is not None + + +@pytest.mark.asyncio +async def test_get_factors_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"factors": [{"name": "sms", "enabled": True}]}) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + await client.get_factors(access_token="token123", dpop_key=dpop_key) + + mock_get.assert_awaited_once() + call_kwargs = mock_get.call_args[1] + assert isinstance(call_kwargs["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_list_authentication_methods_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"authentication_methods": []}) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + await client.list_authentication_methods(access_token="token123", dpop_key=dpop_key) + + assert isinstance(mock_get.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_get_authentication_method_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={ + "id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z" + }) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + await client.get_authentication_method( + access_token="token123", authentication_method_id="am_1", dpop_key=dpop_key + ) + + assert isinstance(mock_get.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_delete_authentication_method_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 204 + mock_delete = mocker.patch( + "httpx.AsyncClient.delete", new_callable=AsyncMock, return_value=response + ) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + await client.delete_authentication_method( + access_token="token123", authentication_method_id="am_1", dpop_key=dpop_key + ) + + assert isinstance(mock_delete.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_update_authentication_method_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={ + "id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z" + }) + mock_patch = mocker.patch( + "httpx.AsyncClient.patch", new_callable=AsyncMock, return_value=response + ) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + req = UpdateAuthenticationMethodRequest(name="New Name") + await client.update_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req, dpop_key=dpop_key + ) + + assert isinstance(mock_patch.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 202 + response.headers = {"location": "/me/v1/authentication-methods/passkey|new"} + response.json = MagicMock(return_value={"auth_session": "session_abc"}) + mock_post = mocker.patch( + "httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response + ) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + req = EnrollAuthenticationMethodRequest(type="passkey") + await client.enroll_authentication_method( + access_token="token123", request=req, dpop_key=dpop_key + ) + + assert isinstance(mock_post.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_verify_authentication_method_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 201 + response.json = MagicMock(return_value={ + "id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z" + }) + mock_post = mocker.patch( + "httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response + ) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") + await client.verify_authentication_method( + access_token="token123", + authentication_method_id="am_1", + request=req, + dpop_key=dpop_key, + ) + + assert isinstance(mock_post.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_list_authentication_methods_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 403 + response.json = MagicMock(return_value={ + "title": "Forbidden", "type": "forbidden", "detail": "Insufficient scope", "status": 403, + }) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + with pytest.raises(MyAccountApiError) as exc: + await client.list_authentication_methods(access_token="token123") + assert exc.value.status == 403 + + +@pytest.mark.asyncio +async def test_list_authentication_methods_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=Exception("Connection refused") + ) + + with pytest.raises(ApiError): + await client.list_authentication_methods(access_token="token123") + + +@pytest.mark.asyncio +async def test_get_authentication_method_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 404 + response.json = MagicMock(return_value={ + "title": "Not Found", "type": "not_found", "detail": "Not found", "status": 404, + }) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + with pytest.raises(MyAccountApiError) as exc: + await client.get_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + assert exc.value.status == 404 + + +@pytest.mark.asyncio +async def test_get_authentication_method_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=Exception("timeout")) + + with pytest.raises(ApiError): + await client.get_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + + +@pytest.mark.asyncio +async def test_delete_authentication_method_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 404 + response.json = MagicMock(return_value={ + "title": "Not Found", "type": "not_found", "detail": "Not found", "status": 404, + }) + mocker.patch("httpx.AsyncClient.delete", new_callable=AsyncMock, return_value=response) + + with pytest.raises(MyAccountApiError) as exc: + await client.delete_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + assert exc.value.status == 404 + + +@pytest.mark.asyncio +async def test_delete_authentication_method_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.delete", + new_callable=AsyncMock, + side_effect=Exception("Connection reset"), + ) + + with pytest.raises(ApiError): + await client.delete_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + + +@pytest.mark.asyncio +async def test_update_authentication_method_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 422 + response.json = MagicMock(return_value={ + "title": "Unprocessable", "type": "validation_error", "detail": "Invalid", "status": 422, + }) + mocker.patch("httpx.AsyncClient.patch", new_callable=AsyncMock, return_value=response) + + req = UpdateAuthenticationMethodRequest(name="x") + with pytest.raises(MyAccountApiError) as exc: + await client.update_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req + ) + assert exc.value.status == 422 + + +@pytest.mark.asyncio +async def test_update_authentication_method_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.patch", new_callable=AsyncMock, side_effect=Exception("timeout") + ) + + req = UpdateAuthenticationMethodRequest(name="x") + with pytest.raises(ApiError): + await client.update_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req + ) + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 403 + response.json = MagicMock(return_value={ + "title": "Forbidden", "type": "forbidden", "detail": "Scope missing", "status": 403, + }) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + with pytest.raises(MyAccountApiError) as exc: + await client.enroll_authentication_method(access_token="token123", request=req) + assert exc.value.status == 403 + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.post", + new_callable=AsyncMock, + side_effect=Exception("Connection refused"), + ) + + req = EnrollAuthenticationMethodRequest(type="passkey") + with pytest.raises(ApiError): + await client.enroll_authentication_method(access_token="token123", request=req) + + +@pytest.mark.asyncio +async def test_verify_authentication_method_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 400 + response.json = MagicMock(return_value={ + "title": "Bad Request", "type": "invalid_request", "detail": "Invalid OTP", "status": 400, + }) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="000000") + with pytest.raises(MyAccountApiError) as exc: + await client.verify_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req + ) + assert exc.value.status == 400 + + +@pytest.mark.asyncio +async def test_verify_authentication_method_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.post", + new_callable=AsyncMock, + side_effect=Exception("Connection refused"), + ) + + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") + with pytest.raises(ApiError): + await client.verify_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req + ) + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_location_collection_url(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 202 + response.headers = {"location": "/me/v1/authentication-methods/"} + response.json = MagicMock(return_value={"auth_session": "session_abc"}) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + with pytest.raises(ApiError) as exc: + await client.enroll_authentication_method(access_token="token123", request=req) + assert "could not extract ID" in str(exc.value) + + +# ============================================================================= +# DPoP nonce retry (RFC 9449 §8.2) — tests DPoPAuth.auth_flow directly +# ============================================================================= + + +def test_dpop_auth_flow_retries_with_nonce_on_401(): + """ + DPoPAuth.auth_flow() must retry with DPoP-Nonce when server responds 401 + + DPoP-Nonce header (RFC 9449 §8.2). Tested by driving the generator directly. + """ + import base64 + import json as _json + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + auth = DPoPAuth(token="test_access_token", key=dpop_key) + + request = httpx.Request("GET", "https://auth0.local/me/v1/factors") + flow = auth.auth_flow(request) + + # First yield — initial request + first_request = next(flow) + assert "DPoP" in first_request.headers + assert "Authorization" in first_request.headers + + # First proof must not have nonce + proof1 = first_request.headers["DPoP"] + payload1_b64 = proof1.split(".")[1] + padding = 4 - len(payload1_b64) % 4 + payload1 = _json.loads(base64.urlsafe_b64decode(payload1_b64 + "=" * padding)) + assert "nonce" not in payload1 + + # Simulate 401 + DPoP-Nonce response + nonce_response = httpx.Response( + status_code=401, + headers={"DPoP-Nonce": "server-nonce-abc"}, + content=b'{"error":"use_dpop_nonce"}', + request=request, + ) + + # Second yield — retry request with nonce + try: + second_request = flow.send(nonce_response) + except StopIteration: + second_request = None + + assert second_request is not None + proof2 = second_request.headers["DPoP"] + payload2_b64 = proof2.split(".")[1] + padding = 4 - len(payload2_b64) % 4 + payload2 = _json.loads(base64.urlsafe_b64decode(payload2_b64 + "=" * padding)) + assert payload2["nonce"] == "server-nonce-abc" + + +def test_dpop_auth_flow_no_retry_on_non_401(): + """DPoPAuth.auth_flow() must NOT retry when the response is not 401.""" + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + auth = DPoPAuth(token="test_access_token", key=dpop_key) + + request = httpx.Request("GET", "https://auth0.local/me/v1/factors") + flow = auth.auth_flow(request) + next(flow) + + success_response = httpx.Response( + status_code=200, + content=b'{"factors":[]}', + request=request, + ) + + try: + flow.send(success_response) + retried = True + except StopIteration: + retried = False + + assert not retried + diff --git a/src/auth0_server_python/tests/test_passkey_my_account.py b/src/auth0_server_python/tests/test_passkey_my_account.py deleted file mode 100644 index 4b7f29d..0000000 --- a/src/auth0_server_python/tests/test_passkey_my_account.py +++ /dev/null @@ -1,830 +0,0 @@ -from unittest.mock import AsyncMock, MagicMock - -import pytest -from jwcrypto import jwk as jwk_module - -from auth0_server_python.auth_schemes.dpop_auth import DPoPAuth -from auth0_server_python.auth_server.my_account_client import MyAccountClient -from auth0_server_python.auth_types import ( - AuthenticationMethod, - EnrollAuthenticationMethodRequest, - EnrollmentChallengeResponse, - GetFactorsResponse, - ListAuthenticationMethodsResponse, - PasskeyAuthResponse, - UpdateAuthenticationMethodRequest, - VerifyAuthenticationMethodRequest, -) -from auth0_server_python.error import ApiError, MissingRequiredArgumentError, MyAccountApiError - - -@pytest.mark.asyncio -async def test_get_factors_success(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock(return_value={"factors": [{"name": "sms", "enabled": True}]}) - mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - result = await client.get_factors(access_token="token123") - - assert isinstance(result, GetFactorsResponse) - assert len(result.factors) == 1 - assert result.factors[0].name == "sms" - assert result.factors[0].enabled is True - - -@pytest.mark.asyncio -@pytest.mark.parametrize("access_token", [None, ""]) -async def test_get_factors_missing_access_token(mocker, access_token): - client = MyAccountClient(domain="auth0.local") - mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock) - - with pytest.raises(MissingRequiredArgumentError): - await client.get_factors(access_token=access_token) - - mock_get.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_get_factors_api_error(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 403 - response.json = MagicMock( - return_value={ - "title": "Forbidden", - "type": "forbidden", - "detail": "Insufficient scope", - "status": 403, - } - ) - mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - with pytest.raises(MyAccountApiError) as exc: - await client.get_factors(access_token="token123") - - assert exc.value.status == 403 - - -@pytest.mark.asyncio -async def test_get_factors_network_error(mocker): - client = MyAccountClient(domain="auth0.local") - mocker.patch( - "httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=Exception("Connection refused") - ) - - with pytest.raises(ApiError): - await client.get_factors(access_token="token123") - - -@pytest.mark.asyncio -async def test_get_factors_empty_list(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock(return_value={"factors": []}) - mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - result = await client.get_factors(access_token="token123") - assert result.factors == [] - - -@pytest.mark.asyncio -async def test_get_factors_extra_fields(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock( - return_value={ - "factors": [{"name": "webauthn-roaming", "enabled": True, "future_field": "value"}] - } - ) - mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - result = await client.get_factors(access_token="token123") - assert result.factors[0].name == "webauthn-roaming" - - -@pytest.mark.asyncio -async def test_list_authentication_methods_success(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock( - return_value={ - "authentication_methods": [ - { - "id": "am_1", - "type": "passkey", - "created_at": "2026-01-01T00:00:00Z", - "key_id": "kid1", - } - ] - } - ) - mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - result = await client.list_authentication_methods(access_token="token123") - assert isinstance(result, ListAuthenticationMethodsResponse) - assert len(result.authentication_methods) == 1 - assert result.authentication_methods[0].type == "passkey" - - -@pytest.mark.asyncio -async def test_list_authentication_methods_with_type_filter(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock(return_value={"authentication_methods": []}) - mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - await client.list_authentication_methods(access_token="token123", type_filter="passkey") - mock_get.assert_awaited_once() - call_kwargs = mock_get.call_args[1] - assert call_kwargs["params"] == {"type": "passkey"} - - -@pytest.mark.asyncio -async def test_list_authentication_methods_empty(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock(return_value={"authentication_methods": []}) - mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - result = await client.list_authentication_methods(access_token="token123") - assert result.authentication_methods == [] - - -@pytest.mark.asyncio -async def test_get_authentication_method_success(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock( - return_value={"id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} - ) - mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - result = await client.get_authentication_method( - access_token="token123", authentication_method_id="am_1" - ) - assert isinstance(result, AuthenticationMethod) - assert result.id == "am_1" - - -@pytest.mark.asyncio -@pytest.mark.parametrize("method_id", [None, ""]) -async def test_get_authentication_method_missing_id(mocker, method_id): - client = MyAccountClient(domain="auth0.local") - mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock) - - with pytest.raises(MissingRequiredArgumentError): - await client.get_authentication_method( - access_token="token123", authentication_method_id=method_id - ) - - mock_get.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_get_authentication_method_path_traversal(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock( - return_value={"id": "id/slash", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} - ) - mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - await client.get_authentication_method( - access_token="token123", authentication_method_id="id/slash" - ) - call_url = mock_get.call_args[1]["url"] - assert "id%2Fslash" in call_url - assert "id/slash" not in call_url.replace("https://auth0.local/me/", "") - - -@pytest.mark.asyncio -async def test_get_authentication_method_pipe_encoding(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock( - return_value={"id": "passkey|new", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} - ) - mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - await client.get_authentication_method( - access_token="token123", authentication_method_id="passkey|new" - ) - call_url = mock_get.call_args[1]["url"] - assert "passkey%7Cnew" in call_url - - -@pytest.mark.asyncio -async def test_delete_authentication_method_success(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 204 - mocker.patch("httpx.AsyncClient.delete", new_callable=AsyncMock, return_value=response) - - result = await client.delete_authentication_method( - access_token="token123", authentication_method_id="am_1" - ) - assert result is None - - -@pytest.mark.asyncio -@pytest.mark.parametrize("method_id", [None, ""]) -async def test_delete_authentication_method_missing_id(mocker, method_id): - client = MyAccountClient(domain="auth0.local") - mock_delete = mocker.patch("httpx.AsyncClient.delete", new_callable=AsyncMock) - - with pytest.raises(MissingRequiredArgumentError): - await client.delete_authentication_method( - access_token="token123", authentication_method_id=method_id - ) - - mock_delete.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_update_authentication_method_success(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock( - return_value={ - "id": "am_1", - "type": "passkey", - "created_at": "2026-01-01T00:00:00Z", - "name": "My Key", - } - ) - mock_patch = mocker.patch( - "httpx.AsyncClient.patch", new_callable=AsyncMock, return_value=response - ) - - req = UpdateAuthenticationMethodRequest(name="My Key") - result = await client.update_authentication_method( - access_token="token123", authentication_method_id="am_1", request=req - ) - assert result.name == "My Key" - call_kwargs = mock_patch.call_args[1] - assert call_kwargs["json"] == {"name": "My Key"} - - -@pytest.mark.asyncio -async def test_update_authentication_method_missing_request(mocker): - client = MyAccountClient(domain="auth0.local") - mocker.patch("httpx.AsyncClient.patch", new_callable=AsyncMock) - - with pytest.raises(MissingRequiredArgumentError): - await client.update_authentication_method( - access_token="token123", authentication_method_id="am_1", request=None - ) - - -@pytest.mark.asyncio -async def test_enroll_authentication_method_success(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 201 - response.headers = {"location": "/me/v1/authentication-methods/passkey|new"} - response.json = MagicMock( - return_value={ - "auth_session": "session_abc", - "authn_params_public_key": { - "challenge": "dGVzdA", - "rp": {"id": "auth0.local", "name": "My App"}, - "user": {"id": "dXNlcl8x", "name": "user@test.com", "displayName": "Test User"}, - "pubKeyCredParams": [{"type": "public-key", "alg": -7}], - "authenticatorSelection": { - "residentKey": "required", - "userVerification": "preferred", - }, - "timeout": 60000, - }, - } - ) - mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) - - req = EnrollAuthenticationMethodRequest(type="passkey") - result = await client.enroll_authentication_method(access_token="token123", request=req) - - assert isinstance(result, EnrollmentChallengeResponse) - assert result.authentication_method_id == "passkey|new" - assert result.auth_session == "session_abc" - assert result.authn_params_public_key is not None - assert result.authn_params_public_key.pub_key_cred_params[0].alg == -7 - assert result.authn_params_public_key.authenticator_selection.resident_key == "required" - assert result.authn_params_public_key.user.display_name == "Test User" - - -@pytest.mark.asyncio -async def test_enroll_authentication_method_missing_location(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 201 - response.headers = {} - response.json = MagicMock(return_value={"auth_session": "session_abc"}) - mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) - - req = EnrollAuthenticationMethodRequest(type="passkey") - with pytest.raises(ApiError) as exc: - await client.enroll_authentication_method(access_token="token123", request=req) - - assert "Location header" in str(exc.value) - - -@pytest.mark.asyncio -async def test_enroll_authentication_method_location_with_query(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 201 - response.headers = {"location": "/me/v1/authentication-methods/abc123?tracking=1"} - response.json = MagicMock(return_value={"auth_session": "session_abc"}) - mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) - - req = EnrollAuthenticationMethodRequest(type="passkey") - result = await client.enroll_authentication_method(access_token="token123", request=req) - assert result.authentication_method_id == "abc123" - - -@pytest.mark.asyncio -async def test_enroll_authentication_method_location_absolute_url(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 201 - response.headers = {"location": "https://tenant.auth0.com/me/v1/authentication-methods/am_xyz"} - response.json = MagicMock(return_value={"auth_session": "session_abc"}) - mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) - - req = EnrollAuthenticationMethodRequest(type="passkey") - result = await client.enroll_authentication_method(access_token="token123", request=req) - assert result.authentication_method_id == "am_xyz" - - -@pytest.mark.asyncio -async def test_verify_authentication_method_success(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock( - return_value={ - "id": "am_1", - "type": "passkey", - "created_at": "2026-01-01T00:00:00Z", - "confirmed": True, - } - ) - mock_post = mocker.patch( - "httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response - ) - - authn_response = PasskeyAuthResponse( - id="cred1", - raw_id="cmF3MQ", - type="public-key", - authenticator_attachment="platform", - response={"clientDataJSON": "abc", "attestationObject": "def"}, - ) - req = VerifyAuthenticationMethodRequest( - auth_session="session_abc", authn_response=authn_response - ) - result = await client.verify_authentication_method( - access_token="token123", authentication_method_id="passkey|new", request=req - ) - - assert isinstance(result, AuthenticationMethod) - assert result.confirmed is True - - call_kwargs = mock_post.call_args[1] - body = call_kwargs["json"] - assert "rawId" in body["authn_response"] - assert "raw_id" not in body["authn_response"] - assert "authenticatorAttachment" in body["authn_response"] - assert body["auth_session"] == "session_abc" - assert "passkey%7Cnew" in call_kwargs["url"] - - -@pytest.mark.asyncio -@pytest.mark.parametrize("method_id", [None, ""]) -async def test_verify_authentication_method_missing_id(mocker, method_id): - client = MyAccountClient(domain="auth0.local") - mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) - - req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") - with pytest.raises(MissingRequiredArgumentError): - await client.verify_authentication_method( - access_token="token123", authentication_method_id=method_id, request=req - ) - - -@pytest.mark.asyncio -async def test_enrollment_challenge_response_repr(): - resp = EnrollmentChallengeResponse( - authentication_method_id="am_1", - auth_session="super_secret_session", - authn_params_public_key=None, - ) - repr_str = repr(resp) - assert "super_secret_session" not in repr_str - assert "[REDACTED]" in repr_str - assert "am_1" in repr_str - - -def test_verify_request_requires_at_least_one_method(): - with pytest.raises(Exception, match="At least one verification method"): - VerifyAuthenticationMethodRequest(auth_session="session_abc") - - -def test_verify_request_accepts_otp_code(): - req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") - assert req.otp_code == "123456" - - -def test_verify_request_accepts_authn_response(): - authn_resp = PasskeyAuthResponse( - id="cred1", - raw_id="cmF3MQ", - type="public-key", - response={"clientDataJSON": "abc", "attestationObject": "def"}, - ) - req = VerifyAuthenticationMethodRequest(auth_session="session_abc", authn_response=authn_resp) - assert req.authn_response is not None - - -@pytest.mark.asyncio -async def test_get_factors_with_dpop_key(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock(return_value={"factors": [{"name": "sms", "enabled": True}]}) - mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") - await client.get_factors(access_token="token123", dpop_key=dpop_key) - - mock_get.assert_awaited_once() - call_kwargs = mock_get.call_args[1] - assert isinstance(call_kwargs["auth"], DPoPAuth) - - -# ============================================================================= -# DPoP integration(mock) tests -# ============================================================================= - - -@pytest.mark.asyncio -async def test_list_authentication_methods_with_dpop_key(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock(return_value={"authentication_methods": []}) - mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") - await client.list_authentication_methods(access_token="token123", dpop_key=dpop_key) - - assert isinstance(mock_get.call_args[1]["auth"], DPoPAuth) - - -@pytest.mark.asyncio -async def test_get_authentication_method_with_dpop_key(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock( - return_value={"id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} - ) - mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") - await client.get_authentication_method( - access_token="token123", authentication_method_id="am_1", dpop_key=dpop_key - ) - - assert isinstance(mock_get.call_args[1]["auth"], DPoPAuth) - - -@pytest.mark.asyncio -async def test_delete_authentication_method_with_dpop_key(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 204 - mock_delete = mocker.patch( - "httpx.AsyncClient.delete", new_callable=AsyncMock, return_value=response - ) - - dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") - await client.delete_authentication_method( - access_token="token123", authentication_method_id="am_1", dpop_key=dpop_key - ) - - assert isinstance(mock_delete.call_args[1]["auth"], DPoPAuth) - - -@pytest.mark.asyncio -async def test_update_authentication_method_with_dpop_key(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock( - return_value={"id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} - ) - mock_patch = mocker.patch( - "httpx.AsyncClient.patch", new_callable=AsyncMock, return_value=response - ) - - dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") - req = UpdateAuthenticationMethodRequest(name="New Name") - await client.update_authentication_method( - access_token="token123", authentication_method_id="am_1", request=req, dpop_key=dpop_key - ) - - assert isinstance(mock_patch.call_args[1]["auth"], DPoPAuth) - - -@pytest.mark.asyncio -async def test_enroll_authentication_method_with_dpop_key(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 201 - response.headers = {"location": "/me/v1/authentication-methods/passkey|new"} - response.json = MagicMock(return_value={"auth_session": "session_abc"}) - mock_post = mocker.patch( - "httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response - ) - - dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") - req = EnrollAuthenticationMethodRequest(type="passkey") - await client.enroll_authentication_method( - access_token="token123", request=req, dpop_key=dpop_key - ) - - assert isinstance(mock_post.call_args[1]["auth"], DPoPAuth) - - -@pytest.mark.asyncio -async def test_verify_authentication_method_with_dpop_key(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock( - return_value={"id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} - ) - mock_post = mocker.patch( - "httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response - ) - - dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") - req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") - await client.verify_authentication_method( - access_token="token123", - authentication_method_id="am_1", - request=req, - dpop_key=dpop_key, - ) - - assert isinstance(mock_post.call_args[1]["auth"], DPoPAuth) - - -# ============================================================================= -# API error and network error tests -# ============================================================================= - - -@pytest.mark.asyncio -async def test_list_authentication_methods_api_error(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 403 - response.json = MagicMock( - return_value={ - "title": "Forbidden", - "type": "forbidden", - "detail": "Insufficient scope", - "status": 403, - } - ) - mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - with pytest.raises(MyAccountApiError) as exc: - await client.list_authentication_methods(access_token="token123") - assert exc.value.status == 403 - - -@pytest.mark.asyncio -async def test_list_authentication_methods_network_error(mocker): - client = MyAccountClient(domain="auth0.local") - mocker.patch( - "httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=Exception("Connection refused") - ) - - with pytest.raises(ApiError): - await client.list_authentication_methods(access_token="token123") - - -@pytest.mark.asyncio -async def test_get_authentication_method_api_error(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 404 - response.json = MagicMock( - return_value={ - "title": "Not Found", - "type": "not_found", - "detail": "Not found", - "status": 404, - } - ) - mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - with pytest.raises(MyAccountApiError) as exc: - await client.get_authentication_method( - access_token="token123", authentication_method_id="am_1" - ) - assert exc.value.status == 404 - - -@pytest.mark.asyncio -async def test_get_authentication_method_network_error(mocker): - client = MyAccountClient(domain="auth0.local") - mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=Exception("timeout")) - - with pytest.raises(ApiError): - await client.get_authentication_method( - access_token="token123", authentication_method_id="am_1" - ) - - -@pytest.mark.asyncio -async def test_delete_authentication_method_api_error(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 404 - response.json = MagicMock( - return_value={ - "title": "Not Found", - "type": "not_found", - "detail": "Not found", - "status": 404, - } - ) - mocker.patch("httpx.AsyncClient.delete", new_callable=AsyncMock, return_value=response) - - with pytest.raises(MyAccountApiError) as exc: - await client.delete_authentication_method( - access_token="token123", authentication_method_id="am_1" - ) - assert exc.value.status == 404 - - -@pytest.mark.asyncio -async def test_delete_authentication_method_network_error(mocker): - client = MyAccountClient(domain="auth0.local") - mocker.patch( - "httpx.AsyncClient.delete", - new_callable=AsyncMock, - side_effect=Exception("Connection reset"), - ) - - with pytest.raises(ApiError): - await client.delete_authentication_method( - access_token="token123", authentication_method_id="am_1" - ) - - -@pytest.mark.asyncio -async def test_update_authentication_method_api_error(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 422 - response.json = MagicMock( - return_value={ - "title": "Unprocessable", - "type": "validation_error", - "detail": "Invalid", - "status": 422, - } - ) - mocker.patch("httpx.AsyncClient.patch", new_callable=AsyncMock, return_value=response) - - req = UpdateAuthenticationMethodRequest(name="x") - with pytest.raises(MyAccountApiError) as exc: - await client.update_authentication_method( - access_token="token123", authentication_method_id="am_1", request=req - ) - assert exc.value.status == 422 - - -@pytest.mark.asyncio -async def test_update_authentication_method_network_error(mocker): - client = MyAccountClient(domain="auth0.local") - mocker.patch( - "httpx.AsyncClient.patch", new_callable=AsyncMock, side_effect=Exception("timeout") - ) - - req = UpdateAuthenticationMethodRequest(name="x") - with pytest.raises(ApiError): - await client.update_authentication_method( - access_token="token123", authentication_method_id="am_1", request=req - ) - - -@pytest.mark.asyncio -async def test_enroll_authentication_method_api_error(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 403 - response.json = MagicMock( - return_value={ - "title": "Forbidden", - "type": "forbidden", - "detail": "Scope missing", - "status": 403, - } - ) - mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) - - req = EnrollAuthenticationMethodRequest(type="passkey") - with pytest.raises(MyAccountApiError) as exc: - await client.enroll_authentication_method(access_token="token123", request=req) - assert exc.value.status == 403 - - -@pytest.mark.asyncio -async def test_enroll_authentication_method_network_error(mocker): - client = MyAccountClient(domain="auth0.local") - mocker.patch( - "httpx.AsyncClient.post", - new_callable=AsyncMock, - side_effect=Exception("Connection refused"), - ) - - req = EnrollAuthenticationMethodRequest(type="passkey") - with pytest.raises(ApiError): - await client.enroll_authentication_method(access_token="token123", request=req) - - -@pytest.mark.asyncio -async def test_verify_authentication_method_api_error(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 400 - response.json = MagicMock( - return_value={ - "title": "Bad Request", - "type": "invalid_request", - "detail": "Invalid OTP", - "status": 400, - } - ) - mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) - - req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="000000") - with pytest.raises(MyAccountApiError) as exc: - await client.verify_authentication_method( - access_token="token123", authentication_method_id="am_1", request=req - ) - assert exc.value.status == 400 - - -@pytest.mark.asyncio -async def test_verify_authentication_method_network_error(mocker): - client = MyAccountClient(domain="auth0.local") - mocker.patch( - "httpx.AsyncClient.post", - new_callable=AsyncMock, - side_effect=Exception("Connection refused"), - ) - - req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") - with pytest.raises(ApiError): - await client.verify_authentication_method( - access_token="token123", authentication_method_id="am_1", request=req - ) - - -# ============================================================================= -# Location header extraction edge case -# ============================================================================= - - -@pytest.mark.asyncio -async def test_enroll_authentication_method_location_collection_url(mocker): - """Rejects Location header that ends at collection path without resource ID.""" - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 201 - response.headers = {"location": "/me/v1/authentication-methods/"} - response.json = MagicMock(return_value={"auth_session": "session_abc"}) - mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) - - req = EnrollAuthenticationMethodRequest(type="passkey") - with pytest.raises(ApiError) as exc: - await client.enroll_authentication_method(access_token="token123", request=req) - assert "could not extract ID" in str(exc.value) diff --git a/src/auth0_server_python/tests/test_passkey_server_client.py b/src/auth0_server_python/tests/test_passkey_server_client.py deleted file mode 100644 index 7c2be37..0000000 --- a/src/auth0_server_python/tests/test_passkey_server_client.py +++ /dev/null @@ -1,585 +0,0 @@ -import time -from unittest.mock import AsyncMock - -import httpx -import pytest - -from auth0_server_python.auth_server.server_client import ServerClient -from auth0_server_python.auth_types import ( - PasskeyAuthResponse, - PasskeyLoginChallengeResponse, - PasskeySignupChallengeResponse, - PasskeyTokenResponse, -) -from auth0_server_python.error import ApiError, MissingRequiredArgumentError - - -@pytest.fixture -def server_client(): - return ServerClient( - domain="auth0.local", - client_id="test_client_id", - client_secret="test_client_secret", - state_store=AsyncMock(), - transaction_store=AsyncMock(), - secret="test-secret-value", - ) - - -SIGNUP_CHALLENGE_RESPONSE = { - "auth_session": "session_abc123", - "authn_params_public_key": { - "challenge": "dGVzdC1jaGFsbGVuZ2U", - "rp": {"id": "auth0.local", "name": "Test App"}, - "user": {"id": "dXNlcl8x", "name": "user@example.com", "displayName": "Jane"}, - "pubKeyCredParams": [{"type": "public-key", "alg": -7}], - "authenticatorSelection": { - "residentKey": "required", - "userVerification": "preferred", - }, - "timeout": 60000, - }, -} - -LOGIN_CHALLENGE_RESPONSE = { - "auth_session": "session_login_xyz", - "authn_params_public_key": { - "challenge": "bG9naW4tY2hhbGxlbmdl", - "rpId": "auth0.local", - "timeout": 60000, - "userVerification": "preferred", - }, -} - -TOKEN_RESPONSE = { - "access_token": "at_passkey_123", - "id_token": "eyJ.test.jwt", - "token_type": "Bearer", - "expires_in": 86400, - "scope": "openid profile", -} - - -def _mock_response(status_code=200, json_data=None, headers=None): - resp = httpx.Response( - status_code=status_code, - json=json_data, - headers=headers or {}, - request=httpx.Request("POST", "https://auth0.local/passkey/register"), - ) - return resp - - -# ============================================================================= -# passkey_signup_challenge -# ============================================================================= - - -@pytest.mark.asyncio -async def test_passkey_signup_challenge_success(server_client, mocker): - mock_response = _mock_response(200, SIGNUP_CHALLENGE_RESPONSE) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - - result = await server_client.passkey_signup_challenge( - email="user@example.com", - name="Jane Doe", - connection="Username-Password-Authentication", - ) - - assert isinstance(result, PasskeySignupChallengeResponse) - assert result.auth_session == "session_abc123" - assert result.authn_params_public_key.challenge == "dGVzdC1jaGFsbGVuZ2U" - assert result.authn_params_public_key.rp.id == "auth0.local" - assert result.authn_params_public_key.user.display_name == "Jane" - assert result.authn_params_public_key.pub_key_cred_params[0].alg == -7 - assert result.authn_params_public_key.authenticator_selection.resident_key == "required" - - call_args = mock_client.post.call_args - assert "/passkey/register" in call_args.args[0] - body = call_args.kwargs["json"] - assert body["client_id"] == "test_client_id" - assert body["client_secret"] == "test_client_secret" - assert body["user_profile"]["email"] == "user@example.com" - assert body["user_profile"]["name"] == "Jane Doe" - assert body["realm"] == "Username-Password-Authentication" - - -@pytest.mark.asyncio -async def test_passkey_signup_challenge_user_profile_fields(server_client, mocker): - mock_response = _mock_response(200, SIGNUP_CHALLENGE_RESPONSE) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - - await server_client.passkey_signup_challenge( - email="u@e.com", - username="jdoe", - phone_number="+1234567890", - given_name="Jane", - family_name="Doe", - nickname="jd", - picture="https://example.com/pic.jpg", - user_metadata={"role": "admin"}, - organization="org_123", - ) - - body = mock_client.post.call_args.kwargs["json"] - assert body["user_profile"]["email"] == "u@e.com" - assert body["user_profile"]["username"] == "jdoe" - assert body["user_profile"]["phone_number"] == "+1234567890" - assert body["user_profile"]["given_name"] == "Jane" - assert body["user_profile"]["family_name"] == "Doe" - assert body["user_profile"]["nickname"] == "jd" - assert body["user_profile"]["picture"] == "https://example.com/pic.jpg" - assert body["user_profile"]["user_metadata"] == {"role": "admin"} - assert body["organization"] == "org_123" - - -@pytest.mark.asyncio -async def test_passkey_signup_challenge_minimal_body(server_client, mocker): - mock_response = _mock_response(200, SIGNUP_CHALLENGE_RESPONSE) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - - await server_client.passkey_signup_challenge() - - body = mock_client.post.call_args.kwargs["json"] - assert body == {"client_id": "test_client_id", "client_secret": "test_client_secret"} - assert "user_profile" not in body - assert "realm" not in body - assert "organization" not in body - - -@pytest.mark.asyncio -async def test_passkey_signup_challenge_api_error(server_client, mocker): - error_resp = _mock_response( - 403, - {"error": "access_denied", "error_description": "Passkey not enabled"}, - ) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=error_resp) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - - with pytest.raises(ApiError) as exc: - await server_client.passkey_signup_challenge(email="test@example.com") - assert "access_denied" in str(exc.value) or "Passkey not enabled" in str(exc.value) - - -@pytest.mark.asyncio -async def test_passkey_signup_challenge_non_json_error(server_client, mocker): - resp = httpx.Response( - status_code=502, - content=b"Bad Gateway", - headers={"content-type": "text/html"}, - request=httpx.Request("POST", "https://auth0.local/passkey/register"), - ) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=resp) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - - with pytest.raises(ApiError) as exc: - await server_client.passkey_signup_challenge() - assert "502" in str(exc.value) or "passkey_challenge_error" in str(exc.value) - - -@pytest.mark.asyncio -async def test_passkey_signup_challenge_network_error(server_client, mocker): - mock_client = AsyncMock() - mock_client.post = AsyncMock(side_effect=Exception("Connection refused")) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - - with pytest.raises(ApiError) as exc: - await server_client.passkey_signup_challenge() - assert "Passkey signup challenge failed" in str(exc.value) - - -# ============================================================================= -# passkey_login_challenge -# ============================================================================= - - -@pytest.mark.asyncio -async def test_passkey_login_challenge_success(server_client, mocker): - mock_response = _mock_response(200, LOGIN_CHALLENGE_RESPONSE) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - - result = await server_client.passkey_login_challenge( - connection="Username-Password-Authentication", - organization="org_abc", - ) - - assert isinstance(result, PasskeyLoginChallengeResponse) - assert result.auth_session == "session_login_xyz" - assert result.authn_params_public_key.challenge == "bG9naW4tY2hhbGxlbmdl" - assert result.authn_params_public_key.rp_id == "auth0.local" - assert result.authn_params_public_key.user_verification == "preferred" - - body = mock_client.post.call_args.kwargs["json"] - assert body["client_id"] == "test_client_id" - assert body["realm"] == "Username-Password-Authentication" - assert body["organization"] == "org_abc" - - -@pytest.mark.asyncio -async def test_passkey_login_challenge_with_username(server_client, mocker): - mock_response = _mock_response(200, LOGIN_CHALLENGE_RESPONSE) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - - await server_client.passkey_login_challenge(username="jane@example.com") - - body = mock_client.post.call_args.kwargs["json"] - assert body["username"] == "jane@example.com" - - -@pytest.mark.asyncio -async def test_passkey_login_challenge_api_error(server_client, mocker): - error_resp = _mock_response( - 400, - {"error": "invalid_request", "error_description": "Missing client_id"}, - ) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=error_resp) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - - with pytest.raises(ApiError): - await server_client.passkey_login_challenge() - - -@pytest.mark.asyncio -async def test_passkey_login_challenge_network_error(server_client, mocker): - mock_client = AsyncMock() - mock_client.post = AsyncMock(side_effect=Exception("timeout")) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - - with pytest.raises(ApiError): - await server_client.passkey_login_challenge() - - -# ============================================================================= -# signin_with_passkey -# ============================================================================= - - -@pytest.fixture -def authn_response(): - return PasskeyAuthResponse( - id="cred_abc123", - raw_id="Y3JlZF9hYmMxMjM", - type="public-key", - authenticator_attachment="platform", - response={ - "clientDataJSON": "eyJ0eXBlIjoid2ViYXV0aG4uZ2V0In0", - "authenticatorData": "SZYN5YgOjGh0NBcPZHZgW4_krrmihjLHmVzzuoMdl2M", - "signature": "MEUCIQC", - "userHandle": "dXNlcl8x", - }, - ) - - -@pytest.mark.asyncio -async def test_signin_with_passkey_success(server_client, authn_response, mocker): - mock_response = _mock_response(200, TOKEN_RESPONSE) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - mocker.patch.object( - server_client, - "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, - ) - - result = await server_client.signin_with_passkey( - auth_session="session_xyz", - authn_response=authn_response, - scope="openid profile", - audience="https://api.example.com", - connection="Username-Password-Authentication", - organization="org_abc", - ) - - assert isinstance(result, PasskeyTokenResponse) - assert result.access_token == "at_passkey_123" - assert result.token_type == "Bearer" - assert abs(result.expires_at - (int(time.time()) + 86400)) <= 2 - - body = mock_client.post.call_args.kwargs["json"] - assert body["grant_type"] == "urn:okta:params:oauth:grant-type:webauthn" - assert body["client_id"] == "test_client_id" - assert body["client_secret"] == "test_client_secret" - assert body["auth_session"] == "session_xyz" - assert body["scope"] == "openid profile" - assert body["audience"] == "https://api.example.com" - assert body["realm"] == "Username-Password-Authentication" - assert body["organization"] == "org_abc" - assert body["authn_response"]["rawId"] == "Y3JlZF9hYmMxMjM" - assert body["authn_response"]["authenticatorAttachment"] == "platform" - assert "raw_id" not in body["authn_response"] - - -@pytest.mark.asyncio -async def test_signin_with_passkey_uses_json_content_type(server_client, authn_response, mocker): - mock_response = _mock_response(200, TOKEN_RESPONSE) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - mocker.patch.object( - server_client, - "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, - ) - - await server_client.signin_with_passkey( - auth_session="s", - authn_response=authn_response, - ) - - call_kwargs = mock_client.post.call_args.kwargs - assert "json" in call_kwargs - assert "data" not in call_kwargs - - -@pytest.mark.asyncio -@pytest.mark.parametrize("auth_session", [None, ""]) -async def test_signin_with_passkey_missing_auth_session( - server_client, authn_response, auth_session -): - with pytest.raises(MissingRequiredArgumentError): - await server_client.signin_with_passkey( - auth_session=auth_session, - authn_response=authn_response, - ) - - -@pytest.mark.asyncio -async def test_signin_with_passkey_missing_authn_response(server_client): - with pytest.raises(MissingRequiredArgumentError): - await server_client.signin_with_passkey( - auth_session="session_abc", - authn_response=None, - ) - - -@pytest.mark.asyncio -async def test_signin_with_passkey_api_error(server_client, authn_response, mocker): - error_resp = _mock_response( - 401, - {"error": "invalid_grant", "error_description": "Invalid auth_session"}, - ) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=error_resp) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - mocker.patch.object( - server_client, - "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, - ) - - with pytest.raises(ApiError) as exc: - await server_client.signin_with_passkey( - auth_session="expired_session", - authn_response=authn_response, - ) - assert "invalid_grant" in str(exc.value) or "Invalid auth_session" in str(exc.value) - - -@pytest.mark.asyncio -async def test_signin_with_passkey_missing_token_endpoint(server_client, authn_response, mocker): - mocker.patch.object( - server_client, - "_get_oidc_metadata_cached", - return_value={}, - ) - - with pytest.raises(ApiError) as exc: - await server_client.signin_with_passkey( - auth_session="session", - authn_response=authn_response, - ) - assert "token endpoint" in str(exc.value).lower() - - -@pytest.mark.asyncio -async def test_signin_with_passkey_network_error(server_client, authn_response, mocker): - mock_client = AsyncMock() - mock_client.post = AsyncMock(side_effect=Exception("Connection reset")) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - mocker.patch.object( - server_client, - "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, - ) - - with pytest.raises(ApiError): - await server_client.signin_with_passkey( - auth_session="session", - authn_response=authn_response, - ) - - -@pytest.mark.asyncio -async def test_signin_with_passkey_no_client_secret(mocker): - client = ServerClient( - domain="auth0.local", - client_id="public_client", - client_secret=None, - state_store=AsyncMock(), - transaction_store=AsyncMock(), - secret="test-secret", - ) - - mock_response = _mock_response(200, TOKEN_RESPONSE) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(client, "_get_http_client", return_value=mock_client) - mocker.patch.object( - client, - "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, - ) - - authn_resp = PasskeyAuthResponse( - id="cred", - raw_id="cmF3", - type="public-key", - response={"clientDataJSON": "abc", "authenticatorData": "def", "signature": "ghi"}, - ) - - await client.signin_with_passkey( - auth_session="session", - authn_response=authn_resp, - ) - - body = mock_client.post.call_args.kwargs["json"] - assert "client_secret" not in body - assert body["client_id"] == "public_client" - - -@pytest.mark.asyncio -async def test_signup_challenge_repr_redacts_auth_session(): - resp = PasskeySignupChallengeResponse.model_validate(SIGNUP_CHALLENGE_RESPONSE) - repr_str = repr(resp) - assert "session_abc123" not in repr_str - assert "[REDACTED]" in repr_str - - -@pytest.mark.asyncio -async def test_login_challenge_repr_redacts_auth_session(): - resp = PasskeyLoginChallengeResponse.model_validate(LOGIN_CHALLENGE_RESPONSE) - repr_str = repr(resp) - assert "session_login_xyz" not in repr_str - assert "[REDACTED]" in repr_str - - -def test_passkey_token_response_repr_redacts_tokens(): - resp = PasskeyTokenResponse( - access_token="secret_at_value", - token_type="Bearer", - expires_in=86400, - id_token="secret_id_token", - refresh_token="secret_rt_value", - ) - repr_str = repr(resp) - assert "secret_at_value" not in repr_str - assert "secret_id_token" not in repr_str - assert "secret_rt_value" not in repr_str - assert "[REDACTED]" in repr_str - assert "86400" in repr_str - - -# ============================================================================= -# expires_at edge cases -# ============================================================================= - - -@pytest.mark.asyncio -async def test_signin_with_passkey_preserves_server_expires_at( - server_client, authn_response, mocker -): - token_data = { - "access_token": "at_123", - "token_type": "Bearer", - "expires_in": 3600, - "expires_at": 9999999999, - } - mock_response = _mock_response(200, token_data) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - mocker.patch.object( - server_client, - "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, - ) - - result = await server_client.signin_with_passkey( - auth_session="session", authn_response=authn_response - ) - - assert result.expires_at == 9999999999 - - -@pytest.mark.asyncio -async def test_signin_with_passkey_missing_expires_at_calculates( - server_client, authn_response, mocker -): - token_data = { - "access_token": "at_123", - "token_type": "Bearer", - "expires_in": 60, - } - mock_response = _mock_response(200, token_data) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - mocker.patch.object( - server_client, - "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, - ) - - result = await server_client.signin_with_passkey( - auth_session="session", authn_response=authn_response - ) - - assert abs(result.expires_at - (int(time.time()) + 60)) <= 2 diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index 47ba774..f987567 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -4816,3 +4816,797 @@ async def _fake_fetch(self, domain): assert exc.value.mfa_requirements is not None finally: ServerClient._fetch_oidc_metadata = original_fetch + + +# ============================================================================= +# PASSKEY AUTHENTICATION +# ============================================================================= + +_PASSKEY_SIGNUP_CHALLENGE_RESPONSE = { + "auth_session": "session_abc123", + "authn_params_public_key": { + "challenge": "dGVzdC1jaGFsbGVuZ2U", + "rp": {"id": "auth0.local", "name": "Test App"}, + "user": {"id": "dXNlcl8x", "name": "user@example.com", "displayName": "Jane"}, + "pubKeyCredParams": [{"type": "public-key", "alg": -7}], + "authenticatorSelection": { + "residentKey": "required", + "userVerification": "preferred", + }, + "timeout": 60000, + }, +} + +_PASSKEY_LOGIN_CHALLENGE_RESPONSE = { + "auth_session": "session_login_xyz", + "authn_params_public_key": { + "challenge": "bG9naW4tY2hhbGxlbmdl", + "rpId": "auth0.local", + "timeout": 60000, + "userVerification": "preferred", + }, +} + +_PASSKEY_TOKEN_RESPONSE = { + "access_token": "at_passkey_123", + "id_token": "eyJ.test.jwt", + "token_type": "Bearer", + "expires_in": 86400, + "scope": "openid profile", +} + + +def _make_passkey_authn_response(): + from auth0_server_python.auth_types import PasskeyAuthResponse + return PasskeyAuthResponse( + id="cred_abc123", + raw_id="Y3JlZF9hYmMxMjM", + type="public-key", + authenticator_attachment="platform", + response={ + "clientDataJSON": "eyJ0eXBlIjoid2ViYXV0aG4uZ2V0In0", + "authenticatorData": "SZYN5YgOjGh0NBcPZHZgW4_krrmihjLHmVzzuoMdl2M", + "signature": "MEUCIQC", + "userHandle": "dXNlcl8x", + }, + ) + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_success(mocker): + from auth0_server_python.auth_types import PasskeySignupChallengeResponse, PasskeyUserProfile + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_SIGNUP_CHALLENGE_RESPONSE) + mock_post.return_value = mock_response + + result = await client.passkey_signup_challenge( + user_profile=PasskeyUserProfile(email="user@example.com", name="Jane Doe"), + connection="Username-Password-Authentication", + ) + + assert isinstance(result, PasskeySignupChallengeResponse) + assert result.auth_session == "session_abc123" + assert result.authn_params_public_key.challenge == "dGVzdC1jaGFsbGVuZ2U" + assert result.authn_params_public_key.rp.id == "auth0.local" + assert result.authn_params_public_key.user.display_name == "Jane" + assert result.authn_params_public_key.pub_key_cred_params[0].alg == -7 + assert result.authn_params_public_key.authenticator_selection.resident_key == "required" + + mock_post.assert_awaited_once() + args, kwargs = mock_post.call_args + assert "/passkey/register" in args[0] + body = kwargs["json"] + assert body["client_id"] == "test_client_id" + assert body["client_secret"] == "test_client_secret" + assert body["user_profile"]["email"] == "user@example.com" + assert body["user_profile"]["name"] == "Jane Doe" + assert body["realm"] == "Username-Password-Authentication" + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_user_profile_fields(mocker): + from auth0_server_python.auth_types import PasskeyUserProfile + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_SIGNUP_CHALLENGE_RESPONSE) + mock_post.return_value = mock_response + + await client.passkey_signup_challenge( + user_profile=PasskeyUserProfile( + email="u@e.com", + username="jdoe", + phone_number="+1234567890", + given_name="Jane", + family_name="Doe", + nickname="jd", + picture="https://example.com/pic.jpg", + user_metadata={"role": "admin"}, + ), + organization="org_123", + ) + + args, kwargs = mock_post.call_args + body = kwargs["json"] + assert body["user_profile"]["email"] == "u@e.com" + assert body["user_profile"]["username"] == "jdoe" + assert body["user_profile"]["phone_number"] == "+1234567890" + assert body["user_profile"]["given_name"] == "Jane" + assert body["user_profile"]["family_name"] == "Doe" + assert body["user_profile"]["nickname"] == "jd" + assert body["user_profile"]["picture"] == "https://example.com/pic.jpg" + assert body["user_profile"]["user_metadata"] == {"role": "admin"} + assert body["organization"] == "org_123" + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_minimal_body(mocker): + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_SIGNUP_CHALLENGE_RESPONSE) + mock_post.return_value = mock_response + + await client.passkey_signup_challenge() + + args, kwargs = mock_post.call_args + body = kwargs["json"] + assert body == {"client_id": "test_client_id", "client_secret": "test_client_secret"} + assert "user_profile" not in body + assert "realm" not in body + assert "organization" not in body + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_api_error(mocker): + from auth0_server_python.auth_types import PasskeyUserProfile + from auth0_server_python.error import PasskeyError + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 403 + mock_response.json = MagicMock(return_value={ + "error": "access_denied", + "error_description": "Passkey not enabled", + }) + mock_post.return_value = mock_response + + with pytest.raises(PasskeyError) as exc: + await client.passkey_signup_challenge( + user_profile=PasskeyUserProfile(email="test@example.com") + ) + assert "access_denied" in str(exc.value) or "Passkey not enabled" in str(exc.value) + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_non_json_error(mocker): + from auth0_server_python.error import PasskeyError + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 502 + mock_response.json = MagicMock(side_effect=json.JSONDecodeError("bad", "", 0)) + mock_post.return_value = mock_response + + with pytest.raises(PasskeyError) as exc: + await client.passkey_signup_challenge() + assert "502" in str(exc.value) or "passkey_challenge_error" in str(exc.value) + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_network_error(mocker): + from auth0_server_python.error import PasskeyError + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_post.side_effect = Exception("Connection refused") + + with pytest.raises(PasskeyError) as exc: + await client.passkey_signup_challenge() + assert "Passkey signup challenge failed" in str(exc.value) + + +@pytest.mark.asyncio +async def test_passkey_login_challenge_success(mocker): + from auth0_server_python.auth_types import PasskeyLoginChallengeResponse + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_LOGIN_CHALLENGE_RESPONSE) + mock_post.return_value = mock_response + + result = await client.passkey_login_challenge( + connection="Username-Password-Authentication", + organization="org_abc", + ) + + assert isinstance(result, PasskeyLoginChallengeResponse) + assert result.auth_session == "session_login_xyz" + assert result.authn_params_public_key.challenge == "bG9naW4tY2hhbGxlbmdl" + assert result.authn_params_public_key.rp_id == "auth0.local" + assert result.authn_params_public_key.user_verification == "preferred" + + args, kwargs = mock_post.call_args + body = kwargs["json"] + assert body["client_id"] == "test_client_id" + assert body["realm"] == "Username-Password-Authentication" + assert body["organization"] == "org_abc" + + +@pytest.mark.asyncio +async def test_passkey_login_challenge_with_username(mocker): + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_LOGIN_CHALLENGE_RESPONSE) + mock_post.return_value = mock_response + + await client.passkey_login_challenge(username="jane@example.com") + + args, kwargs = mock_post.call_args + body = kwargs["json"] + assert body["username"] == "jane@example.com" + + +@pytest.mark.asyncio +async def test_passkey_login_challenge_api_error(mocker): + from auth0_server_python.error import PasskeyError + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 400 + mock_response.json = MagicMock(return_value={ + "error": "invalid_request", + "error_description": "Missing client_id", + }) + mock_post.return_value = mock_response + + with pytest.raises(PasskeyError): + await client.passkey_login_challenge() + + +@pytest.mark.asyncio +async def test_passkey_login_challenge_network_error(mocker): + from auth0_server_python.error import PasskeyError + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_post.side_effect = Exception("timeout") + + with pytest.raises(PasskeyError): + await client.passkey_login_challenge() + + +@pytest.mark.asyncio +async def test_signin_with_passkey_success(mocker): + from auth0_server_python.auth_types import PasskeyTokenResponse + from auth0_server_python.error import PasskeyError + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + mock_post.return_value = mock_response + authn_response = _make_passkey_authn_response() + + result = await client.signin_with_passkey( + auth_session="session_xyz", + authn_response=authn_response, + scope="openid profile", + audience="https://api.example.com", + connection="Username-Password-Authentication", + organization="org_abc", + ) + + assert isinstance(result, PasskeyTokenResponse) + assert result.access_token == "at_passkey_123" + assert result.token_type == "Bearer" + assert abs(result.expires_at - (int(time.time()) + 86400)) <= 2 + + mock_post.assert_awaited_once() + args, kwargs = mock_post.call_args + body = kwargs["json"] + assert body["grant_type"] == "urn:okta:params:oauth:grant-type:webauthn" + assert body["client_id"] == "test_client_id" + assert body["client_secret"] == "test_client_secret" + assert body["auth_session"] == "session_xyz" + assert body["scope"] == "openid profile" + assert body["audience"] == "https://api.example.com" + assert body["realm"] == "Username-Password-Authentication" + assert body["organization"] == "org_abc" + assert body["authn_response"]["rawId"] == "Y3JlZF9hYmMxMjM" + assert body["authn_response"]["authenticatorAttachment"] == "platform" + assert "raw_id" not in body["authn_response"] + + +@pytest.mark.asyncio +async def test_signin_with_passkey_uses_json_content_type(mocker): + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + mock_post.return_value = mock_response + + await client.signin_with_passkey( + auth_session="s", + authn_response=_make_passkey_authn_response(), + ) + + args, kwargs = mock_post.call_args + assert "json" in kwargs + assert "data" not in kwargs + + +@pytest.mark.asyncio +@pytest.mark.parametrize("auth_session", [None, ""]) +async def test_signin_with_passkey_missing_auth_session(auth_session): + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + with pytest.raises(MissingRequiredArgumentError): + await client.signin_with_passkey( + auth_session=auth_session, + authn_response=_make_passkey_authn_response(), + ) + + +@pytest.mark.asyncio +async def test_signin_with_passkey_missing_authn_response(): + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + with pytest.raises(MissingRequiredArgumentError): + await client.signin_with_passkey( + auth_session="session_abc", + authn_response=None, + ) + + +@pytest.mark.asyncio +async def test_signin_with_passkey_api_error(mocker): + from auth0_server_python.error import PasskeyError + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 401 + mock_response.json = MagicMock(return_value={ + "error": "invalid_grant", + "error_description": "Invalid auth_session", + }) + mock_post.return_value = mock_response + + with pytest.raises(PasskeyError) as exc: + await client.signin_with_passkey( + auth_session="expired_session", + authn_response=_make_passkey_authn_response(), + ) + assert "invalid_grant" in str(exc.value) or "Invalid auth_session" in str(exc.value) + + +@pytest.mark.asyncio +async def test_signin_with_passkey_missing_token_endpoint(mocker): + from auth0_server_python.error import PasskeyError + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={}) + + with pytest.raises(PasskeyError) as exc: + await client.signin_with_passkey( + auth_session="session", + authn_response=_make_passkey_authn_response(), + ) + assert "token endpoint" in str(exc.value).lower() + + +@pytest.mark.asyncio +async def test_signin_with_passkey_network_error(mocker): + from auth0_server_python.error import PasskeyError + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_post.side_effect = Exception("Connection reset") + + with pytest.raises(PasskeyError): + await client.signin_with_passkey( + auth_session="session", + authn_response=_make_passkey_authn_response(), + ) + + +@pytest.mark.asyncio +async def test_signin_with_passkey_no_client_secret(mocker): + client = ServerClient( + domain="auth0.local", + client_id="public_client", + client_secret=None, + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + mock_post.return_value = mock_response + + from auth0_server_python.auth_types import PasskeyAuthResponse + authn_resp = PasskeyAuthResponse( + id="cred", + raw_id="cmF3", + type="public-key", + response={"clientDataJSON": "abc", "authenticatorData": "def", "signature": "ghi"}, + ) + await client.signin_with_passkey(auth_session="session", authn_response=authn_resp) + + args, kwargs = mock_post.call_args + body = kwargs["json"] + assert "client_secret" not in body + assert body["client_id"] == "public_client" + + +def test_passkey_signup_challenge_repr_redacts_auth_session(): + from auth0_server_python.auth_types import PasskeySignupChallengeResponse + resp = PasskeySignupChallengeResponse.model_validate(_PASSKEY_SIGNUP_CHALLENGE_RESPONSE) + repr_str = repr(resp) + assert "session_abc123" not in repr_str + assert "[REDACTED]" in repr_str + + +def test_passkey_login_challenge_repr_redacts_auth_session(): + from auth0_server_python.auth_types import PasskeyLoginChallengeResponse + resp = PasskeyLoginChallengeResponse.model_validate(_PASSKEY_LOGIN_CHALLENGE_RESPONSE) + repr_str = repr(resp) + assert "session_login_xyz" not in repr_str + assert "[REDACTED]" in repr_str + + +def test_passkey_token_response_repr_redacts_tokens(): + from auth0_server_python.auth_types import PasskeyTokenResponse + resp = PasskeyTokenResponse( + access_token="secret_at_value", + token_type="Bearer", + expires_in=86400, + id_token="secret_id_token", + refresh_token="secret_rt_value", + ) + repr_str = repr(resp) + assert "secret_at_value" not in repr_str + assert "secret_id_token" not in repr_str + assert "secret_rt_value" not in repr_str + assert "[REDACTED]" in repr_str + assert "86400" in repr_str + + +@pytest.mark.asyncio +async def test_signin_with_passkey_preserves_server_expires_at(mocker): + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value={ + "access_token": "at_123", + "token_type": "Bearer", + "expires_in": 3600, + "expires_at": 9999999999, + }) + mock_post.return_value = mock_response + + from auth0_server_python.auth_types import PasskeyTokenResponse + result = await client.signin_with_passkey( + auth_session="session", + authn_response=_make_passkey_authn_response(), + ) + assert result.expires_at == 9999999999 + + +@pytest.mark.asyncio +async def test_signin_with_passkey_missing_expires_at_calculates(mocker): + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value={ + "access_token": "at_123", + "token_type": "Bearer", + "expires_in": 60, + }) + mock_post.return_value = mock_response + + result = await client.signin_with_passkey( + auth_session="session", + authn_response=_make_passkey_authn_response(), + ) + assert abs(result.expires_at - (int(time.time()) + 60)) <= 2 + + +@pytest.mark.asyncio +async def test_signin_with_passkey_dpop_attaches_proof_header(mocker): + import base64 + import json as _json + from jwcrypto import jwk as jwk_module + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + mock_post.return_value = mock_response + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + await client.signin_with_passkey( + auth_session="session_xyz", + authn_response=_make_passkey_authn_response(), + dpop_key=dpop_key, + ) + + args, kwargs = mock_post.call_args + assert "DPoP" in kwargs["headers"] + + # Decode proof and assert no ath claim (token endpoint proof — RFC 9449 §4.2) + proof = kwargs["headers"]["DPoP"] + payload_b64 = proof.split(".")[1] + padding = 4 - len(payload_b64) % 4 + payload = _json.loads(base64.urlsafe_b64decode(payload_b64 + "=" * padding)) + assert "ath" not in payload + assert "jti" in payload + assert payload["htm"] == "POST" + assert payload["htu"] == "https://auth0.local/oauth/token" + + +@pytest.mark.asyncio +async def test_signin_with_passkey_dpop_nonce_retry(mocker): + import base64 + import json as _json + from jwcrypto import jwk as jwk_module + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + + nonce_response = AsyncMock() + nonce_response.status_code = 401 + nonce_response.headers = {"DPoP-Nonce": "server-nonce-abc"} + nonce_response.json = MagicMock(return_value={"error": "use_dpop_nonce"}) + + success_response = AsyncMock() + success_response.status_code = 200 + success_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + + mock_post.side_effect = [nonce_response, success_response] + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + result = await client.signin_with_passkey( + auth_session="session_xyz", + authn_response=_make_passkey_authn_response(), + dpop_key=dpop_key, + ) + + assert mock_post.await_count == 2 + assert result.access_token == "at_passkey_123" + + # Second call must include the nonce in the DPoP proof + second_call_kwargs = mock_post.call_args_list[1][1] + proof = second_call_kwargs["headers"]["DPoP"] + payload_b64 = proof.split(".")[1] + padding = 4 - len(payload_b64) % 4 + payload = _json.loads(base64.urlsafe_b64decode(payload_b64 + "=" * padding)) + assert payload["nonce"] == "server-nonce-abc" + + +@pytest.mark.asyncio +async def test_signin_with_passkey_without_dpop_no_dpop_header(mocker): + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + mock_post.return_value = mock_response + + await client.signin_with_passkey( + auth_session="session_xyz", + authn_response=_make_passkey_authn_response(), + ) + + args, kwargs = mock_post.call_args + assert "DPoP" not in kwargs.get("headers", {}) From d299dbff859f633e6bfc4646a1b7db735f761039 Mon Sep 17 00:00:00 2001 From: Sourav Basu Date: Fri, 5 Jun 2026 13:16:43 +0530 Subject: [PATCH 08/12] SDK-8780 Added review changes for integrating passkey sign-in with SDKs state handling --- .../auth_server/server_client.py | 66 ++++++- .../auth_types/__init__.py | 12 ++ .../tests/test_server_client.py | 166 ++++++++++++++++-- 3 files changed, 218 insertions(+), 26 deletions(-) diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 110f33a..1803ea2 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -37,6 +37,7 @@ MfaRequirements, PasskeyAuthResponse, PasskeyLoginChallengeResponse, + PasskeyLoginResult, PasskeySignupChallengeResponse, PasskeyUserProfile, PasskeyTokenResponse, @@ -2653,20 +2654,22 @@ async def signin_with_passkey( scope: Optional[str] = None, audience: Optional[str] = None, dpop_key: Optional["jwk.JWK"] = None, - ) -> PasskeyTokenResponse: + ) -> PasskeyLoginResult: """ Completes passkey authentication by exchanging the WebAuthn assertion - for tokens (POST /oauth/token with webauthn grant). + for tokens and establishing a server-side session. This is step 2 of 2: call passkey_signup_challenge or passkey_login_challenge first to obtain auth_session and the WebAuthn challenge options. Uses Content-Type: application/json (required for nested authn_response). + Persists the session to the state store (same as complete_interactive_login). Args: auth_session: Session credential from passkey_signup_challenge or passkey_login_challenge. authn_response: Serialized WebAuthn credential from navigator.credentials.create/get. - store_options: Optional options for domain resolution and state store. + store_options: Options passed to the state store (e.g., request/response for cookies). + When None, session storage is skipped (stateless deployments). connection: Auth0 database connection name (realm). organization: Auth0 organization ID or name. scope: OAuth2 scope string. @@ -2676,11 +2679,12 @@ async def signin_with_passkey( (token_type: DPoP). Required when the tenant mandates DPoP binding. Returns: - PasskeyTokenResponse containing access_token, id_token, expires_in, etc. + PasskeyLoginResult containing state_data with user claims and token sets, + consistent with complete_interactive_login and login_with_custom_token_exchange. Raises: MissingRequiredArgumentError: If auth_session or authn_response is missing. - PasskeyError: If token exchange fails. + PasskeyError: If token exchange or session creation fails. """ if not auth_session: raise MissingRequiredArgumentError("auth_session") @@ -2755,9 +2759,57 @@ async def signin_with_passkey( if "expires_in" in token_data and "expires_at" not in token_data: token_data["expires_at"] = int(time.time()) + token_data["expires_in"] - return PasskeyTokenResponse.model_validate(token_data) + token_response = PasskeyTokenResponse.model_validate(token_data) + + # Extract user claims from ID token if present + user_claims = None + sid = PKCE.generate_random_string(32) + if token_response.id_token: + jwks = await self._get_jwks_cached(domain, metadata) + try: + claims = await self._verify_and_decode_jwt( + token_response.id_token, jwks, audience=self._client_id + ) + origin_issuer = metadata.get("issuer") + token_issuer = claims.get("iss", "") + if self._normalize_url(token_issuer) != self._normalize_url(origin_issuer): + raise IssuerValidationError( + "ID token issuer mismatch. Ensure your Auth0 domain is configured correctly." + ) + user_claims = UserClaims.parse_obj(claims) + sid = claims.get("sid", sid) + except ValueError as e: + raise ApiError("jwks_key_not_found", str(e)) + except jwt.InvalidSignatureError as e: + raise ApiError("invalid_signature", f"ID token signature verification failed: {str(e)}", e) + except jwt.InvalidAudienceError as e: + raise ApiError("invalid_audience", f"ID token audience mismatch: {str(e)}", e) + except jwt.ExpiredSignatureError as e: + raise ApiError("token_expired", f"ID token has expired: {str(e)}", e) + except jwt.InvalidTokenError as e: + raise ApiError("invalid_token", f"ID token verification failed: {str(e)}", e) + + # Build token set and session state + token_set = TokenSet( + audience=audience or self.DEFAULT_AUDIENCE_STATE_KEY, + access_token=token_response.access_token, + scope=token_response.scope or scope or "", + expires_at=token_response.expires_at or int(time.time()) + token_response.expires_in, + ) + state_data = StateData( + user=user_claims, + id_token=token_response.id_token, + refresh_token=token_response.refresh_token, + token_sets=[token_set], + domain=domain, + internal={"sid": sid, "created_at": int(time.time())}, + ) + + await self._state_store.set(self._state_identifier, state_data, options=store_options) + + return PasskeyLoginResult(state_data=state_data.dict()) except Exception as e: - if isinstance(e, (PasskeyError, MissingRequiredArgumentError, ValidationError)): + if isinstance(e, (PasskeyError, MissingRequiredArgumentError, ValidationError, ApiError, IssuerValidationError)): raise raise PasskeyError(PasskeyErrorCode.TOKEN_EXCHANGE_FAILED, "Passkey sign-in failed", e) from e diff --git a/src/auth0_server_python/auth_types/__init__.py b/src/auth0_server_python/auth_types/__init__.py index 9494a22..c6caf8c 100644 --- a/src/auth0_server_python/auth_types/__init__.py +++ b/src/auth0_server_python/auth_types/__init__.py @@ -380,6 +380,18 @@ class LoginWithCustomTokenExchangeResult(BaseModel): authorization_details: Optional[list[AuthorizationDetails]] = None +class PasskeyLoginResult(BaseModel): + """ + Result from signin_with_passkey. + + Contains the session data established after the webauthn token exchange. + Mirrors LoginWithCustomTokenExchangeResult — passkey sign-in is a complete + login ceremony and creates a server-side session like every other login path. + """ + + state_data: dict[str, Any] + + # ============================================================================= # Connected Accounts Types # ============================================================================= diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index f987567..88d565a 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -5154,21 +5154,25 @@ async def test_passkey_login_challenge_network_error(mocker): @pytest.mark.asyncio async def test_signin_with_passkey_success(mocker): - from auth0_server_python.auth_types import PasskeyTokenResponse - from auth0_server_python.error import PasskeyError + from auth0_server_python.auth_types import PasskeyLoginResult + state_store = AsyncMock() client = ServerClient( domain="auth0.local", client_id="test_client_id", client_secret="test_client_secret", - state_store=AsyncMock(), + state_store=state_store, transaction_store=AsyncMock(), secret="test-secret-value", ) mocker.patch.object( client, "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, ) + mocker.patch.object(client, "_get_jwks_cached", return_value={}) + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "sub": "auth0|user123", "name": "Jane", "iss": "https://auth0.local/", "sid": "sid_abc" + }) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) mock_response = AsyncMock() mock_response.status_code = 200 @@ -5185,10 +5189,13 @@ async def test_signin_with_passkey_success(mocker): organization="org_abc", ) - assert isinstance(result, PasskeyTokenResponse) - assert result.access_token == "at_passkey_123" - assert result.token_type == "Bearer" - assert abs(result.expires_at - (int(time.time()) + 86400)) <= 2 + assert isinstance(result, PasskeyLoginResult) + assert "token_sets" in result.state_data + assert result.state_data["token_sets"][0]["access_token"] == "at_passkey_123" + assert result.state_data["token_sets"][0]["audience"] == "https://api.example.com" + + # Session must be persisted + state_store.set.assert_awaited_once() mock_post.assert_awaited_once() args, kwargs = mock_post.call_args @@ -5219,8 +5226,12 @@ async def test_signin_with_passkey_uses_json_content_type(mocker): mocker.patch.object( client, "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, ) + mocker.patch.object(client, "_get_jwks_cached", return_value={}) + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "sub": "auth0|user123", "iss": "https://auth0.local/" + }) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) mock_response = AsyncMock() mock_response.status_code = 200 @@ -5365,8 +5376,12 @@ async def test_signin_with_passkey_no_client_secret(mocker): mocker.patch.object( client, "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, ) + mocker.patch.object(client, "_get_jwks_cached", return_value={}) + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "sub": "auth0|user123", "iss": "https://auth0.local/" + }) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) mock_response = AsyncMock() mock_response.status_code = 200 @@ -5434,8 +5449,12 @@ async def test_signin_with_passkey_preserves_server_expires_at(mocker): mocker.patch.object( client, "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, ) + mocker.patch.object(client, "_get_jwks_cached", return_value={}) + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "sub": "auth0|user123", "iss": "https://auth0.local/" + }) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) mock_response = AsyncMock() mock_response.status_code = 200 @@ -5447,12 +5466,11 @@ async def test_signin_with_passkey_preserves_server_expires_at(mocker): }) mock_post.return_value = mock_response - from auth0_server_python.auth_types import PasskeyTokenResponse result = await client.signin_with_passkey( auth_session="session", authn_response=_make_passkey_authn_response(), ) - assert result.expires_at == 9999999999 + assert result.state_data["token_sets"][0]["expires_at"] == 9999999999 @pytest.mark.asyncio @@ -5468,8 +5486,12 @@ async def test_signin_with_passkey_missing_expires_at_calculates(mocker): mocker.patch.object( client, "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, ) + mocker.patch.object(client, "_get_jwks_cached", return_value={}) + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "sub": "auth0|user123", "iss": "https://auth0.local/" + }) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) mock_response = AsyncMock() mock_response.status_code = 200 @@ -5484,7 +5506,7 @@ async def test_signin_with_passkey_missing_expires_at_calculates(mocker): auth_session="session", authn_response=_make_passkey_authn_response(), ) - assert abs(result.expires_at - (int(time.time()) + 60)) <= 2 + assert abs(result.state_data["token_sets"][0]["expires_at"] - (int(time.time()) + 60)) <= 2 @pytest.mark.asyncio @@ -5503,8 +5525,12 @@ async def test_signin_with_passkey_dpop_attaches_proof_header(mocker): mocker.patch.object( client, "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, ) + mocker.patch.object(client, "_get_jwks_cached", return_value={}) + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "sub": "auth0|user123", "iss": "https://auth0.local/" + }) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) mock_response = AsyncMock() mock_response.status_code = 200 @@ -5548,8 +5574,12 @@ async def test_signin_with_passkey_dpop_nonce_retry(mocker): mocker.patch.object( client, "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, ) + mocker.patch.object(client, "_get_jwks_cached", return_value={}) + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "sub": "auth0|user123", "iss": "https://auth0.local/" + }) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) nonce_response = AsyncMock() @@ -5571,7 +5601,7 @@ async def test_signin_with_passkey_dpop_nonce_retry(mocker): ) assert mock_post.await_count == 2 - assert result.access_token == "at_passkey_123" + assert result.state_data["token_sets"][0]["access_token"] == "at_passkey_123" # Second call must include the nonce in the DPoP proof second_call_kwargs = mock_post.call_args_list[1][1] @@ -5595,8 +5625,12 @@ async def test_signin_with_passkey_without_dpop_no_dpop_header(mocker): mocker.patch.object( client, "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, ) + mocker.patch.object(client, "_get_jwks_cached", return_value={}) + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "sub": "auth0|user123", "iss": "https://auth0.local/" + }) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) mock_response = AsyncMock() mock_response.status_code = 200 @@ -5610,3 +5644,97 @@ async def test_signin_with_passkey_without_dpop_no_dpop_header(mocker): args, kwargs = mock_post.call_args assert "DPoP" not in kwargs.get("headers", {}) + + +@pytest.mark.asyncio +async def test_signin_with_passkey_creates_session_in_state_store(mocker): + """signin_with_passkey must persist a session — consistent with complete_interactive_login.""" + from auth0_server_python.auth_types import PasskeyLoginResult + state_store = AsyncMock() + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=state_store, + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, + ) + mocker.patch.object(client, "_get_jwks_cached", return_value={}) + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "sub": "auth0|user123", + "name": "Jane Doe", + "email": "jane@example.com", + "iss": "https://auth0.local/", + "sid": "session_sid_abc", + }) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + mock_post.return_value = mock_response + + result = await client.signin_with_passkey( + auth_session="session_xyz", + authn_response=_make_passkey_authn_response(), + ) + + # State store must be called exactly once + state_store.set.assert_awaited_once() + + # Result must be PasskeyLoginResult, not bare tokens + assert isinstance(result, PasskeyLoginResult) + + # State data must contain user, token_sets, domain, internal + sd = result.state_data + assert sd["user"]["sub"] == "auth0|user123" + assert sd["user"]["name"] == "Jane Doe" + assert sd["token_sets"][0]["access_token"] == "at_passkey_123" + assert sd["id_token"] == "eyJ.test.jwt" + assert sd["refresh_token"] is None + assert sd["domain"] == "auth0.local" + assert sd["internal"]["sid"] == "session_sid_abc" + assert "created_at" in sd["internal"] + + +@pytest.mark.asyncio +async def test_signin_with_passkey_session_without_id_token(mocker): + """When no id_token is returned, session is still created with user=None.""" + from auth0_server_python.auth_types import PasskeyLoginResult + state_store = AsyncMock() + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=state_store, + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value={ + "access_token": "at_no_id_token", + "token_type": "Bearer", + "expires_in": 3600, + }) + mock_post.return_value = mock_response + + result = await client.signin_with_passkey( + auth_session="session_xyz", + authn_response=_make_passkey_authn_response(), + ) + + assert isinstance(result, PasskeyLoginResult) + state_store.set.assert_awaited_once() + assert result.state_data["user"] is None + assert result.state_data["token_sets"][0]["access_token"] == "at_no_id_token" From 0aacc8ebd0590121d5b6be707eb92216b571321e Mon Sep 17 00:00:00 2001 From: Sourav Basu Date: Fri, 5 Jun 2026 22:29:02 +0530 Subject: [PATCH 09/12] PR Review Changes --- .../auth_schemes/dpop_auth.py | 3 +- .../auth_server/my_account_client.py | 19 +-- .../auth_server/server_client.py | 28 +++- .../auth_types/__init__.py | 24 ++- .../tests/test_my_account_client.py | 73 +++++++++ .../tests/test_server_client.py | 143 +++++++++++++++++- 6 files changed, 254 insertions(+), 36 deletions(-) diff --git a/src/auth0_server_python/auth_schemes/dpop_auth.py b/src/auth0_server_python/auth_schemes/dpop_auth.py index 0bf2d66..a0e0a19 100644 --- a/src/auth0_server_python/auth_schemes/dpop_auth.py +++ b/src/auth0_server_python/auth_schemes/dpop_auth.py @@ -61,8 +61,7 @@ def auth_flow(self, request: httpx.Request): # RFC 9449 §8.2 — server-nonce retry if ( - response is not None - and response.status_code == 401 + response.status_code == 401 and response.headers.get("DPoP-Nonce") ): nonce = response.headers["DPoP-Nonce"] diff --git a/src/auth0_server_python/auth_server/my_account_client.py b/src/auth0_server_python/auth_server/my_account_client.py index 5ffadd9..5e10b60 100644 --- a/src/auth0_server_python/auth_server/my_account_client.py +++ b/src/auth0_server_python/auth_server/my_account_client.py @@ -1,8 +1,9 @@ import json from typing import TYPE_CHECKING, Optional -from urllib.parse import quote, unquote +from urllib.parse import quote, unquote, urlparse import httpx + from auth0_server_python.auth_schemes.bearer_auth import BearerAuth from auth0_server_python.auth_schemes.dpop_auth import DPoPAuth from auth0_server_python.auth_types import ( @@ -654,12 +655,12 @@ async def enroll_authentication_method( if not location: raise ApiError( "enroll_authentication_method_error", - "Enrollment succeeded (201) but Location header is missing", + "Enrollment succeeded (202) but Location header is missing", ) - path = location.split("?")[0].split("#")[0].rstrip("/") - segments = path.split("/") - authentication_method_id = unquote(segments[-1]) if len(segments) > 1 else "" + parsed_path = urlparse(location).path.rstrip("/") + raw_id = parsed_path.rsplit("/", 1)[-1] if "/" in parsed_path else "" + authentication_method_id = unquote(raw_id) if not authentication_method_id or authentication_method_id in ( "authentication-methods", "v1", @@ -667,7 +668,7 @@ async def enroll_authentication_method( ): raise ApiError( "enroll_authentication_method_error", - "Enrollment succeeded (201) but could not extract ID from Location header", + "Enrollment succeeded (202) but could not extract ID from Location header", ) try: @@ -675,21 +676,21 @@ async def enroll_authentication_method( except (json.JSONDecodeError, ValueError): raise ApiError( "enroll_authentication_method_error", - "Enrollment succeeded (201) but response body is not valid JSON", + "Enrollment succeeded (202) but response body is not valid JSON", ) auth_session = data.get("auth_session") if not auth_session: raise ApiError( "enroll_authentication_method_error", - "Enrollment succeeded (201) but auth_session is missing from response", + "Enrollment succeeded (202) but auth_session is missing from response", ) return EnrollmentChallengeResponse.model_validate( { + **data, "authentication_method_id": authentication_method_id, "auth_session": auth_session, - "authn_params_public_key": data.get("authn_params_public_key"), } ) diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 1803ea2..22f6d21 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -16,10 +16,10 @@ import httpx import jwt from authlib.integrations.base_client.errors import OAuthError -from auth0_server_python.auth_schemes.dpop_auth import DPoPAuth, make_dpop_proof_for_token_endpoint from authlib.integrations.httpx_client import AsyncOAuth2Client from pydantic import ValidationError +from auth0_server_python.auth_schemes.dpop_auth import make_dpop_proof_for_token_endpoint from auth0_server_python.auth_server.mfa_client import MfaClient from auth0_server_python.auth_server.my_account_client import MyAccountClient from auth0_server_python.auth_types import ( @@ -39,8 +39,8 @@ PasskeyLoginChallengeResponse, PasskeyLoginResult, PasskeySignupChallengeResponse, - PasskeyUserProfile, PasskeyTokenResponse, + PasskeyUserProfile, StartInteractiveLoginOptions, StateData, TokenExchangeResponse, @@ -2508,6 +2508,7 @@ async def passkey_signup_challenge( user_profile: Optional[PasskeyUserProfile] = None, connection: Optional[str] = None, organization: Optional[str] = None, + user_metadata: Optional[dict[str, Any]] = None, store_options: Optional[dict[str, Any]] = None, ) -> PasskeySignupChallengeResponse: """ @@ -2521,6 +2522,8 @@ async def passkey_signup_challenge( Use PasskeyUserProfile — supports extra fields for forward compatibility. connection: Auth0 database connection name (realm). organization: Auth0 organization ID or name. + user_metadata: Optional custom metadata added at the root of the request body, + not nested inside user_profile (per Auth0 API spec). store_options: Optional options for domain resolution. Returns: @@ -2537,6 +2540,8 @@ async def passkey_signup_challenge( body["client_secret"] = self._client_secret if user_profile: body["user_profile"] = user_profile.model_dump(exclude_none=True) + if user_metadata: + body["user_metadata"] = user_metadata if connection: body["realm"] = connection if organization: @@ -2669,7 +2674,7 @@ async def signin_with_passkey( auth_session: Session credential from passkey_signup_challenge or passkey_login_challenge. authn_response: Serialized WebAuthn credential from navigator.credentials.create/get. store_options: Options passed to the state store (e.g., request/response for cookies). - When None, session storage is skipped (stateless deployments). + Passed through to the store on every call. connection: Auth0 database connection name (realm). organization: Auth0 organization ID or name. scope: OAuth2 scope string. @@ -2761,6 +2766,13 @@ async def signin_with_passkey( token_response = PasskeyTokenResponse.model_validate(token_data) + if dpop_key is not None and token_response.token_type.lower() != "dpop": + raise PasskeyError( + PasskeyErrorCode.TOKEN_EXCHANGE_FAILED, + f"DPoP token binding failed: expected token_type 'DPoP', " + f"got '{token_response.token_type}'", + ) + # Extract user claims from ID token if present user_claims = None sid = PKCE.generate_random_string(32) @@ -2771,12 +2783,16 @@ async def signin_with_passkey( token_response.id_token, jwks, audience=self._client_id ) origin_issuer = metadata.get("issuer") + if not origin_issuer: + raise IssuerValidationError( + "Issuer missing from OIDC metadata. Cannot validate ID token issuer." + ) token_issuer = claims.get("iss", "") if self._normalize_url(token_issuer) != self._normalize_url(origin_issuer): raise IssuerValidationError( "ID token issuer mismatch. Ensure your Auth0 domain is configured correctly." ) - user_claims = UserClaims.parse_obj(claims) + user_claims = UserClaims.model_validate(claims) sid = claims.get("sid", sid) except ValueError as e: raise ApiError("jwks_key_not_found", str(e)) @@ -2794,7 +2810,7 @@ async def signin_with_passkey( audience=audience or self.DEFAULT_AUDIENCE_STATE_KEY, access_token=token_response.access_token, scope=token_response.scope or scope or "", - expires_at=token_response.expires_at or int(time.time()) + token_response.expires_in, + expires_at=token_response.expires_at if token_response.expires_at is not None else int(time.time()) + token_response.expires_in, ) state_data = StateData( user=user_claims, @@ -2807,7 +2823,7 @@ async def signin_with_passkey( await self._state_store.set(self._state_identifier, state_data, options=store_options) - return PasskeyLoginResult(state_data=state_data.dict()) + return PasskeyLoginResult(state_data=state_data.model_dump()) except Exception as e: if isinstance(e, (PasskeyError, MissingRequiredArgumentError, ValidationError, ApiError, IssuerValidationError)): diff --git a/src/auth0_server_python/auth_types/__init__.py b/src/auth0_server_python/auth_types/__init__.py index c6caf8c..44ec918 100644 --- a/src/auth0_server_python/auth_types/__init__.py +++ b/src/auth0_server_python/auth_types/__init__.py @@ -688,7 +688,7 @@ class PasskeyAuthenticatorSelection(BaseModel): class PasskeyPublicKeyOptions(BaseModel): - model_config = ConfigDict(populate_by_name=True) + model_config = ConfigDict(populate_by_name=True, extra="allow") challenge: str rp: Optional[PasskeyRpInfo] = None rp_id: Optional[str] = Field(None, alias="rpId") @@ -717,6 +717,7 @@ class EnrollAuthenticationMethodRequest(BaseModel): class EnrollmentChallengeResponse(BaseModel): + model_config = ConfigDict(extra="allow") authentication_method_id: str auth_session: str authn_params_public_key: Optional[PasskeyPublicKeyOptions] = None @@ -737,7 +738,7 @@ class PasskeyAuthResponse(BaseModel): type: str authenticator_attachment: Optional[str] = Field(None, alias="authenticatorAttachment") response: dict[str, str] - client_extension_results: Optional[dict] = Field(None, alias="clientExtensionResults") + client_extension_results: Optional[dict[str, Any]] = Field(None, alias="clientExtensionResults") class VerifyAuthenticationMethodRequest(BaseModel): @@ -803,31 +804,26 @@ class PasskeyUserProfile(BaseModel): family_name: Optional[str] = None nickname: Optional[str] = None picture: Optional[str] = None - user_metadata: Optional[dict[str, Any]] = None -class PasskeySignupChallengeResponse(BaseModel): +class _PasskeyChallengeResponseBase(BaseModel): auth_session: str authn_params_public_key: PasskeyPublicKeyOptions def __repr__(self) -> str: return ( - f"PasskeySignupChallengeResponse(" + f"{self.__class__.__name__}(" f"auth_session=[REDACTED], " f"authn_params_public_key={self.authn_params_public_key!r})" ) -class PasskeyLoginChallengeResponse(BaseModel): - auth_session: str - authn_params_public_key: PasskeyPublicKeyOptions +class PasskeySignupChallengeResponse(_PasskeyChallengeResponseBase): + pass - def __repr__(self) -> str: - return ( - f"PasskeyLoginChallengeResponse(" - f"auth_session=[REDACTED], " - f"authn_params_public_key={self.authn_params_public_key!r})" - ) + +class PasskeyLoginChallengeResponse(_PasskeyChallengeResponseBase): + pass class PasskeyTokenResponse(BaseModel): diff --git a/src/auth0_server_python/tests/test_my_account_client.py b/src/auth0_server_python/tests/test_my_account_client.py index da2875d..6f254c1 100644 --- a/src/auth0_server_python/tests/test_my_account_client.py +++ b/src/auth0_server_python/tests/test_my_account_client.py @@ -804,6 +804,36 @@ async def test_enroll_authentication_method_success(mocker): assert result.authn_params_public_key.user.display_name == "Test User" +@pytest.mark.asyncio +async def test_enroll_authentication_method_public_key_extra_fields_preserved(mocker): + """Unknown WebAuthn fields (excludeCredentials, attestation, extensions) must not be dropped.""" + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 202 + response.headers = {"location": "/me/v1/authentication-methods/passkey|new"} + response.json = MagicMock(return_value={ + "auth_session": "session_abc", + "authn_params_public_key": { + "challenge": "dGVzdA", + "rp": {"id": "auth0.local", "name": "My App"}, + "user": {"id": "dXNlcl8x", "name": "user@test.com"}, + "pubKeyCredParams": [{"type": "public-key", "alg": -7}], + "excludeCredentials": [{"type": "public-key", "id": "Y3JlZA"}], + "attestation": "direct", + "extensions": {"appid": "https://auth0.local"}, + }, + }) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + result = await client.enroll_authentication_method(access_token="token123", request=req) + + pk = result.authn_params_public_key + assert pk.model_extra["excludeCredentials"] == [{"type": "public-key", "id": "Y3JlZA"}] + assert pk.model_extra["attestation"] == "direct" + assert pk.model_extra["extensions"] == {"appid": "https://auth0.local"} + + @pytest.mark.asyncio async def test_enroll_authentication_method_missing_location(mocker): client = MyAccountClient(domain="auth0.local") @@ -848,6 +878,49 @@ async def test_enroll_authentication_method_location_absolute_url(mocker): assert result.authentication_method_id == "am_xyz" +@pytest.mark.asyncio +async def test_enroll_authentication_method_totp_preserves_secret(mocker): + """TOTP enrollment response includes totp_secret and barcode_uri — must not be dropped.""" + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 202 + response.headers = {"location": "/me/v1/authentication-methods/totp|new"} + response.json = MagicMock(return_value={ + "auth_session": "session_totp", + "totp_secret": "JBSWY3DPEHPK3PXP", + "barcode_uri": "otpauth://totp/Example:alice@example.com?secret=JBSWY3DPEHPK3PXP", + }) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="totp") + result = await client.enroll_authentication_method(access_token="token123", request=req) + + assert result.authentication_method_id == "totp|new" + assert result.auth_session == "session_totp" + assert result.model_extra["totp_secret"] == "JBSWY3DPEHPK3PXP" + assert result.model_extra["barcode_uri"].startswith("otpauth://") + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_oob_preserves_oob_code(mocker): + """OOB (email/phone) enrollment response includes oob_code — must not be dropped.""" + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 202 + response.headers = {"location": "/me/v1/authentication-methods/email|new"} + response.json = MagicMock(return_value={ + "auth_session": "session_oob", + "oob_code": "oob_abc123", + }) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="email") + result = await client.enroll_authentication_method(access_token="token123", request=req) + + assert result.authentication_method_id == "email|new" + assert result.model_extra["oob_code"] == "oob_abc123" + + @pytest.mark.asyncio async def test_verify_authentication_method_success(mocker): client = MyAccountClient(domain="auth0.local") diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index 88d565a..3c271ea 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -4855,6 +4855,14 @@ async def _fake_fetch(self, domain): "scope": "openid profile", } +_PASSKEY_TOKEN_RESPONSE_DPOP = { + "access_token": "at_passkey_dpop_123", + "id_token": "eyJ.test.jwt", + "token_type": "DPoP", + "expires_in": 86400, + "scope": "openid profile", +} + def _make_passkey_authn_response(): from auth0_server_python.auth_types import PasskeyAuthResponse @@ -4939,8 +4947,8 @@ async def test_passkey_signup_challenge_user_profile_fields(mocker): family_name="Doe", nickname="jd", picture="https://example.com/pic.jpg", - user_metadata={"role": "admin"}, ), + user_metadata={"role": "admin"}, organization="org_123", ) @@ -4953,7 +4961,8 @@ async def test_passkey_signup_challenge_user_profile_fields(mocker): assert body["user_profile"]["family_name"] == "Doe" assert body["user_profile"]["nickname"] == "jd" assert body["user_profile"]["picture"] == "https://example.com/pic.jpg" - assert body["user_profile"]["user_metadata"] == {"role": "admin"} + assert "user_metadata" not in body["user_profile"] + assert body["user_metadata"] == {"role": "admin"} assert body["organization"] == "org_123" @@ -4979,10 +4988,38 @@ async def test_passkey_signup_challenge_minimal_body(mocker): body = kwargs["json"] assert body == {"client_id": "test_client_id", "client_secret": "test_client_secret"} assert "user_profile" not in body + assert "user_metadata" not in body assert "realm" not in body assert "organization" not in body +@pytest.mark.asyncio +async def test_passkey_signup_challenge_user_metadata_root_level(mocker): + """user_metadata must be sent at root level, not nested inside user_profile.""" + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_SIGNUP_CHALLENGE_RESPONSE) + mock_post.return_value = mock_response + + await client.passkey_signup_challenge( + user_metadata={"preferred_language": "en"}, + ) + + args, kwargs = mock_post.call_args + body = kwargs["json"] + assert body["user_metadata"] == {"preferred_language": "en"} + assert "user_profile" not in body + + @pytest.mark.asyncio async def test_passkey_signup_challenge_api_error(mocker): from auth0_server_python.auth_types import PasskeyUserProfile @@ -5085,6 +5122,34 @@ async def test_passkey_login_challenge_success(mocker): assert body["client_id"] == "test_client_id" assert body["realm"] == "Username-Password-Authentication" assert body["organization"] == "org_abc" + assert "username" not in body + + +@pytest.mark.asyncio +async def test_passkey_login_challenge_minimal_body(mocker): + """No optional fields sent when called with no arguments.""" + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_LOGIN_CHALLENGE_RESPONSE) + mock_post.return_value = mock_response + + await client.passkey_login_challenge() + + args, kwargs = mock_post.call_args + body = kwargs["json"] + assert body == {"client_id": "test_client_id", "client_secret": "test_client_secret"} + assert "username" not in body + assert "realm" not in body + assert "organization" not in body @pytest.mark.asyncio @@ -5534,7 +5599,7 @@ async def test_signin_with_passkey_dpop_attaches_proof_header(mocker): mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) mock_response = AsyncMock() mock_response.status_code = 200 - mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE_DPOP) mock_post.return_value = mock_response dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") @@ -5589,7 +5654,7 @@ async def test_signin_with_passkey_dpop_nonce_retry(mocker): success_response = AsyncMock() success_response.status_code = 200 - success_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + success_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE_DPOP) mock_post.side_effect = [nonce_response, success_response] @@ -5601,7 +5666,7 @@ async def test_signin_with_passkey_dpop_nonce_retry(mocker): ) assert mock_post.await_count == 2 - assert result.state_data["token_sets"][0]["access_token"] == "at_passkey_123" + assert result.state_data["token_sets"][0]["access_token"] == "at_passkey_dpop_123" # Second call must include the nonce in the DPoP proof second_call_kwargs = mock_post.call_args_list[1][1] @@ -5612,6 +5677,74 @@ async def test_signin_with_passkey_dpop_nonce_retry(mocker): assert payload["nonce"] == "server-nonce-abc" +@pytest.mark.asyncio +async def test_signin_with_passkey_dpop_rejects_bearer_downgrade(mocker): + """Server returning token_type=Bearer when DPoP was requested must raise PasskeyError.""" + from auth0_server_python.error import PasskeyError + from jwcrypto import jwk as jwk_module + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + mock_post.return_value = mock_response + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + with pytest.raises(PasskeyError) as exc: + await client.signin_with_passkey( + auth_session="session_xyz", + authn_response=_make_passkey_authn_response(), + dpop_key=dpop_key, + ) + assert "DPoP" in str(exc.value) or "token_type" in str(exc.value).lower() + + +@pytest.mark.asyncio +async def test_signin_with_passkey_missing_issuer_in_metadata(mocker): + """Missing 'issuer' in OIDC metadata must raise IssuerValidationError, not silently pass.""" + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mocker.patch.object(client, "_get_jwks_cached", return_value={}) + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "sub": "auth0|user123", "iss": "https://auth0.local/" + }) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + mock_post.return_value = mock_response + + with pytest.raises(Exception) as exc: + await client.signin_with_passkey( + auth_session="session_xyz", + authn_response=_make_passkey_authn_response(), + ) + assert "issuer" in str(exc.value).lower() + + @pytest.mark.asyncio async def test_signin_with_passkey_without_dpop_no_dpop_header(mocker): client = ServerClient( From 786b4ece19a6c759994a9ce811873e72f0fc230b Mon Sep 17 00:00:00 2001 From: Sourav Basu Date: Mon, 15 Jun 2026 14:51:39 +0530 Subject: [PATCH 10/12] MFA Support for Passkeys --- .../auth_server/mfa_client.py | 95 ++++-- .../auth_server/server_client.py | 69 ++-- .../tests/test_mfa_client.py | 96 +++--- .../tests/test_server_client.py | 301 ++++++++++++++++-- 4 files changed, 449 insertions(+), 112 deletions(-) diff --git a/src/auth0_server_python/auth_server/mfa_client.py b/src/auth0_server_python/auth_server/mfa_client.py index 3904203..bfe1e08 100644 --- a/src/auth0_server_python/auth_server/mfa_client.py +++ b/src/auth0_server_python/auth_server/mfa_client.py @@ -38,6 +38,7 @@ ) DEFAULT_MFA_TOKEN_TTL = 300 # 5 minutes +MFA_PENDING_IDENTIFIER = "_a0_mfa_pending" class MfaClient: @@ -47,9 +48,9 @@ class MfaClient: Provides methods for listing authenticators, enrolling new authenticators, deleting authenticators, challenging authenticators, and verifying MFA codes. - All API operations require a raw mfa_token. If the token was encrypted - (e.g. from MfaRequiredError raised by get_access_token()), use - decrypt_mfa_token() first to obtain the raw token. + All public API methods accept an encrypted mfa_token (as issued by + MfaRequiredError) and decrypt it internally. Callers never handle the + raw Auth0 mfa_token directly. """ def __init__( @@ -130,13 +131,63 @@ def decrypt_mfa_token(self, encrypted_token: str) -> MfaTokenContext: except Exception: raise MfaTokenInvalidError() - # Check TTL elapsed = int(time.time()) - context.created_at if elapsed > DEFAULT_MFA_TOKEN_TTL: raise MfaTokenExpiredError() return context + # ============================================================================ + # MFA STATE + # ============================================================================ + + async def store_pending_mfa( + self, + encrypted_token: str, + store_options: Optional[dict[str, Any]] = None, + ) -> None: + """Save an in-progress MFA token so challenge and verify can proceed without the client carrying the token.""" + if self._state_store: + await self._state_store.set( + MFA_PENDING_IDENTIFIER, + {"mfa_token": encrypted_token}, + options=store_options, + ) + + async def get_pending_mfa( + self, + store_options: Optional[dict[str, Any]] = None, + ) -> Optional[str]: + """Retrieve the in-progress MFA token if a challenge is pending for the current session, or None.""" + if not self._state_store: + return None + data = await self._state_store.get(MFA_PENDING_IDENTIFIER, store_options) + if data and isinstance(data, dict): + return data.get("mfa_token") + return None + + async def _clear_pending_mfa( + self, + store_options: Optional[dict[str, Any]] = None, + ) -> None: + """Clear the in-progress MFA state after successful verification.""" + if self._state_store: + await self._state_store.delete(MFA_PENDING_IDENTIFIER, store_options) + + def _resolve_encrypted_token( + self, + options: dict[str, Any], + ) -> str: + """ + Extract the encrypted mfa_token from options. + + Raises MfaTokenInvalidError if the key is absent or empty. + """ + token = options.get("mfa_token") + if not token: + raise MfaTokenInvalidError() + return token + # ============================================================================ # MFA API OPERATIONS # ============================================================================ @@ -159,7 +210,7 @@ async def list_authenticators( Raises: MfaListAuthenticatorsError: When the request fails. """ - mfa_token = options["mfa_token"] + context = self.decrypt_mfa_token(self._resolve_encrypted_token(options)) base_url = await self._resolve_base_url(store_options) url = f"{base_url}/mfa/authenticators" @@ -167,7 +218,7 @@ async def list_authenticators( async with self._get_http_client() as client: response = await client.get( url, - auth=BearerAuth(mfa_token) + auth=BearerAuth(context.mfa_token) ) if response.status_code != 200: @@ -207,7 +258,7 @@ async def enroll_authenticator( Raises: MfaEnrollmentError: When enrollment fails. """ - mfa_token = options["mfa_token"] + context = self.decrypt_mfa_token(self._resolve_encrypted_token(options)) factor_type = options["factor_type"] base_url = await self._resolve_base_url(store_options) url = f"{base_url}/mfa/associate" @@ -243,7 +294,7 @@ async def enroll_authenticator( response = await client.post( url, json=body, - auth=BearerAuth(mfa_token), + auth=BearerAuth(context.mfa_token), headers={"Content-Type": "application/json"} ) @@ -292,7 +343,7 @@ async def challenge_authenticator( Raises: MfaChallengeError: When the challenge fails. """ - mfa_token = options["mfa_token"] + context = self.decrypt_mfa_token(self._resolve_encrypted_token(options)) factor_type = options["factor_type"] base_url = await self._resolve_base_url(store_options) url = f"{base_url}/mfa/challenge" @@ -308,7 +359,7 @@ async def challenge_authenticator( ) body: dict[str, Any] = { - "mfa_token": mfa_token, + "mfa_token": context.mfa_token, "client_id": self._client_id, "client_secret": self._client_secret, "challenge_type": challenge_type @@ -373,13 +424,12 @@ async def verify( MfaVerifyError: When verification fails. MfaRequiredError: When chained MFA is required. """ - mfa_token = options["mfa_token"] + context = self.decrypt_mfa_token(self._resolve_encrypted_token(options)) - # Determine grant type and build body body: dict[str, Any] = { "client_id": self._client_id, "client_secret": self._client_secret, - "mfa_token": mfa_token + "mfa_token": context.mfa_token } if "otp" in options: @@ -412,19 +462,23 @@ async def verify( if response.status_code != 200: error_data = response.json() - # Handle chained MFA — token is raw; encryption is the - # framework SDK's responsibility (see ServerClient.get_access_token). if error_data.get("error") == "mfa_required": - new_mfa_token = error_data.get("mfa_token") + new_raw_token = error_data.get("mfa_token") mfa_requirements_data = error_data.get("mfa_requirements") mfa_requirements = None if mfa_requirements_data: mfa_requirements = MfaRequirements(**mfa_requirements_data) + new_encrypted = self._encrypt_mfa_token( + raw_mfa_token=new_raw_token, + audience=context.audience, + scope=context.scope, + mfa_requirements=mfa_requirements, + ) raise MfaRequiredError( error_data.get("error_description", "Additional MFA factor required"), - mfa_token=new_mfa_token, - mfa_requirements=mfa_requirements + mfa_token=new_encrypted, + mfa_requirements=mfa_requirements, ) raise MfaVerifyError( @@ -435,11 +489,12 @@ async def verify( token_response = response.json() verify_response = MfaVerifyResponse(**token_response) - # Persist tokens to state store if requested + await self._clear_pending_mfa(store_options) + if options.get("persist") and self._state_store: await self._persist_mfa_tokens( verify_response=verify_response, - options=options, + options={**options, "audience": options.get("audience") or context.audience, "scope": options.get("scope") or context.scope}, store_options=store_options ) diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 22f6d21..7997804 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -1029,24 +1029,9 @@ async def get_access_token( return token_endpoint_response["access_token"] except Exception as e: - # Check for mfa_required error from token refresh - if isinstance(e, ApiError) and e.code == "mfa_required": - raw_mfa_token = getattr(e, "mfa_token", None) - mfa_requirements = getattr(e, "mfa_requirements", None) - - if raw_mfa_token: - encrypted_token = self._mfa_client._encrypt_mfa_token( - raw_mfa_token=raw_mfa_token, - audience=audience or self.DEFAULT_AUDIENCE_STATE_KEY, - scope=merged_scope or "", - mfa_requirements=mfa_requirements - ) - raise MfaRequiredError( - "Multifactor authentication required", - mfa_token=encrypted_token, - mfa_requirements=mfa_requirements - ) - + if isinstance(e, MfaRequiredError): + await self._mfa_client.store_pending_mfa(e.mfa_token, store_options) + raise if isinstance(e, AccessTokenError): raise raise AccessTokenError( @@ -1117,18 +1102,22 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, error_data = response.json() error_code = error_data.get("error", "refresh_token_error") - # Preserve mfa_required details for upstream handling if error_code == "mfa_required": - error = ApiError( - error_code, - error_data.get("error_description", "MFA required") - ) - error.mfa_token = error_data.get("mfa_token") + raw_mfa_token = error_data.get("mfa_token") mfa_requirements_data = error_data.get("mfa_requirements") - error.mfa_requirements = None - if mfa_requirements_data: - error.mfa_requirements = MfaRequirements(**mfa_requirements_data) - raise error + mfa_requirements = MfaRequirements(**mfa_requirements_data) if mfa_requirements_data else None + if raw_mfa_token: + encrypted_token = self._mfa_client._encrypt_mfa_token( + raw_mfa_token=raw_mfa_token, + audience=audience or self.DEFAULT_AUDIENCE_STATE_KEY, + scope=merged_scope or "", + mfa_requirements=mfa_requirements, + ) + raise MfaRequiredError( + error_data.get("error_description", "MFA required"), + mfa_token=encrypted_token, + mfa_requirements=mfa_requirements, + ) raise ApiError( error_code, @@ -1146,7 +1135,7 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, return token_response except Exception as e: - if isinstance(e, ApiError): + if isinstance(e, (ApiError, MfaRequiredError)): raise raise AccessTokenError( AccessTokenErrorCode.REFRESH_TOKEN_ERROR, @@ -2749,8 +2738,26 @@ async def signin_with_passkey( PasskeyErrorCode.TOKEN_EXCHANGE_FAILED, f"Passkey token exchange failed with status {response.status_code}", ) + error_code = error_data.get("error", PasskeyErrorCode.TOKEN_EXCHANGE_FAILED) + if error_code == "mfa_required": + raw_mfa_token = error_data.get("mfa_token") + mfa_requirements_data = error_data.get("mfa_requirements") + mfa_requirements = MfaRequirements(**mfa_requirements_data) if mfa_requirements_data else None + if raw_mfa_token: + encrypted_token = self._mfa_client._encrypt_mfa_token( + raw_mfa_token=raw_mfa_token, + audience=audience or self.DEFAULT_AUDIENCE_STATE_KEY, + scope=scope or "", + mfa_requirements=mfa_requirements, + ) + await self._mfa_client.store_pending_mfa(encrypted_token, store_options) + raise MfaRequiredError( + "Multifactor authentication required", + mfa_token=encrypted_token, + mfa_requirements=mfa_requirements, + ) raise PasskeyError( - error_data.get("error", PasskeyErrorCode.TOKEN_EXCHANGE_FAILED), + error_code, error_data.get("error_description", "Passkey token exchange failed"), ) @@ -2826,6 +2833,6 @@ async def signin_with_passkey( return PasskeyLoginResult(state_data=state_data.model_dump()) except Exception as e: - if isinstance(e, (PasskeyError, MissingRequiredArgumentError, ValidationError, ApiError, IssuerValidationError)): + if isinstance(e, (PasskeyError, MissingRequiredArgumentError, ValidationError, ApiError, IssuerValidationError, MfaRequiredError)): raise raise PasskeyError(PasskeyErrorCode.TOKEN_EXCHANGE_FAILED, "Passkey sign-in failed", e) from e diff --git a/src/auth0_server_python/tests/test_mfa_client.py b/src/auth0_server_python/tests/test_mfa_client.py index ed93275..3533b57 100644 --- a/src/auth0_server_python/tests/test_mfa_client.py +++ b/src/auth0_server_python/tests/test_mfa_client.py @@ -42,6 +42,12 @@ def _make_client() -> MfaClient: ) +def _enc(raw: str = "raw_mfa_tok", audience: str = "default", scope: str = "") -> str: + """Encrypt a raw MFA token using the shared test secret.""" + client = _make_client() + return client._encrypt_mfa_token(raw, audience, scope) + + # ── Constructor ────────────────────────────────────────────────────────────── class TestMfaClientConstructor: @@ -121,7 +127,7 @@ async def test_resolver_failure_propagates_through_api_method(self, mocker): # list_authenticators wraps unexpected errors in MfaListAuthenticatorsError, # but DomainResolverError is NOT caught by the inner try/except — it propagates. with pytest.raises(DomainResolverError): - await client.list_authenticators({"mfa_token": "tok"}) + await client.list_authenticators({"mfa_token": _enc()}) @pytest.mark.asyncio async def test_store_options_forwarded_to_resolver(self): @@ -234,7 +240,7 @@ async def test_list_authenticators_success(self, mocker): ]) mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - result = await client.list_authenticators({"mfa_token": "mfa_tok"}) + result = await client.list_authenticators({"mfa_token": _enc("mfa_tok")}) assert len(result) == 2 assert isinstance(result[0], AuthenticatorResponse) assert result[0].id == "auth|123" @@ -252,7 +258,7 @@ async def test_list_authenticators_api_error(self, mocker): mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) with pytest.raises(MfaListAuthenticatorsError) as exc: - await client.list_authenticators({"mfa_token": "bad_tok"}) + await client.list_authenticators({"mfa_token": _enc("bad_tok")}) assert "Invalid MFA token" in str(exc.value) @pytest.mark.asyncio @@ -261,7 +267,7 @@ async def test_list_authenticators_unexpected_error(self, mocker): mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=Exception("network down")) with pytest.raises(MfaListAuthenticatorsError) as exc: - await client.list_authenticators({"mfa_token": "tok"}) + await client.list_authenticators({"mfa_token": _enc()}) assert "network down" in str(exc.value) @@ -282,7 +288,7 @@ async def test_enroll_otp_success(self, mocker): mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) result = await client.enroll_authenticator({ - "mfa_token": "tok", + "mfa_token": _enc(), "factor_type": "otp" }) assert isinstance(result, OtpEnrollmentResponse) @@ -303,7 +309,7 @@ async def test_enroll_sms_oob_success(self, mocker): mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) result = await client.enroll_authenticator({ - "mfa_token": "tok", + "mfa_token": _enc(), "factor_type": "sms", "phone_number": "+1234567890" }) @@ -323,7 +329,7 @@ async def test_enroll_email_oob_success(self, mocker): mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) result = await client.enroll_authenticator({ - "mfa_token": "tok", + "mfa_token": _enc(), "factor_type": "email", "email": "user@example.com" }) @@ -345,7 +351,7 @@ async def test_enroll_push_auth0_channel_success(self, mocker): mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) result = await client.enroll_authenticator({ - "mfa_token": "tok", + "mfa_token": _enc(), "factor_type": "auth0" }) assert isinstance(result, OobEnrollmentResponse) @@ -364,7 +370,7 @@ async def test_enroll_api_error(self, mocker): with pytest.raises(MfaEnrollmentError) as exc: await client.enroll_authenticator({ - "mfa_token": "tok", + "mfa_token": _enc(), "factor_type": "otp" }) assert "Bad enrollment request" in str(exc.value) @@ -381,7 +387,7 @@ async def test_enroll_unexpected_authenticator_type(self, mocker): with pytest.raises(MfaEnrollmentError) as exc: await client.enroll_authenticator({ - "mfa_token": "tok", + "mfa_token": _enc(), "factor_type": "unknown" }) assert "Unsupported factor_type" in str(exc.value) @@ -401,7 +407,7 @@ async def test_challenge_otp_success(self, mocker): mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) result = await client.challenge_authenticator({ - "mfa_token": "tok", + "mfa_token": _enc(), "factor_type": "otp" }) assert isinstance(result, ChallengeResponse) @@ -420,7 +426,7 @@ async def test_challenge_oob_success(self, mocker): mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) result = await client.challenge_authenticator({ - "mfa_token": "tok", + "mfa_token": _enc(), "factor_type": "sms", "authenticator_id": "auth|456" }) @@ -440,7 +446,7 @@ async def test_challenge_api_error(self, mocker): with pytest.raises(MfaChallengeError) as exc: await client.challenge_authenticator({ - "mfa_token": "tok", + "mfa_token": _enc(), "factor_type": "otp" }) assert "Token expired" in str(exc.value) @@ -459,7 +465,7 @@ async def test_challenge_expired_mfa_token(self, mocker): with pytest.raises(MfaChallengeError) as exc: await client.challenge_authenticator({ - "mfa_token": "expired_tok", + "mfa_token": _enc("expired_tok"), "factor_type": "otp" }) assert "mfa_token is expired" in str(exc.value) @@ -478,7 +484,7 @@ async def test_challenge_email_with_authenticator_id(self, mocker): mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) result = await client.challenge_authenticator({ - "mfa_token": "tok", + "mfa_token": _enc(), "factor_type": "email", "authenticator_id": "email|dev_Fvx38nHufsGL5lWI" }) @@ -500,7 +506,7 @@ async def test_challenge_sms_with_authenticator_id(self, mocker): mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) result = await client.challenge_authenticator({ - "mfa_token": "tok", + "mfa_token": _enc(), "factor_type": "sms", "authenticator_id": "sms|dev_h1uXXoVjQ5BpU9iQ" }) @@ -524,7 +530,7 @@ async def test_verify_otp_success(self, mocker): mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) result = await client.verify({ - "mfa_token": "tok", + "mfa_token": _enc(), "otp": "123456" }) assert isinstance(result, MfaVerifyResponse) @@ -543,7 +549,7 @@ async def test_verify_oob_success(self, mocker): mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) result = await client.verify({ - "mfa_token": "tok", + "mfa_token": _enc(), "oob_code": "oob_123", "binding_code": "bind_456" }) @@ -562,7 +568,7 @@ async def test_verify_recovery_code_success(self, mocker): mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) result = await client.verify({ - "mfa_token": "tok", + "mfa_token": _enc(), "recovery_code": "ABCD-1234-EFGH" }) assert isinstance(result, MfaVerifyResponse) @@ -571,7 +577,7 @@ async def test_verify_recovery_code_success(self, mocker): async def test_verify_no_credential_raises(self): client = _make_client() with pytest.raises(MfaVerifyError) as exc: - await client.verify({"mfa_token": "tok"}) + await client.verify({"mfa_token": _enc()}) assert "No verification credential" in str(exc.value) @pytest.mark.asyncio @@ -596,7 +602,7 @@ async def mock_post(self_client, url, **kwargs): mocker.patch("httpx.AsyncClient.post", new=mock_post) await client.verify({ - "mfa_token": "my_mfa_token", + "mfa_token": _enc("my_mfa_token"), "otp": "123456" }) @@ -622,7 +628,7 @@ async def test_verify_expired_mfa_token(self, mocker): mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) with pytest.raises(MfaVerifyError) as exc: - await client.verify({"mfa_token": "expired_tok", "otp": "123456"}) + await client.verify({"mfa_token": _enc("expired_tok"), "otp": "123456"}) assert "mfa_token is expired" in str(exc.value) @pytest.mark.asyncio @@ -638,7 +644,7 @@ async def test_verify_invalid_challenge_type(self, mocker): mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) with pytest.raises(MfaVerifyError) as exc: - await client.verify({"mfa_token": "tok", "recovery_code": "ABCD-1234"}) + await client.verify({"mfa_token": _enc(), "recovery_code": "ABCD-1234"}) assert "Invalid challenge type" in str(exc.value) @pytest.mark.asyncio @@ -656,7 +662,7 @@ async def test_verify_response_includes_recovery_code(self, mocker): mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) result = await client.verify({ - "mfa_token": "tok", + "mfa_token": _enc(), "recovery_code": "OLD-RECOVERY-CODE" }) assert isinstance(result, MfaVerifyResponse) @@ -676,7 +682,7 @@ async def test_verify_push_oob_success(self, mocker): mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) result = await client.verify({ - "mfa_token": "tok", + "mfa_token": _enc(), "oob_code": "oob_push_code", "binding_code": "" }) @@ -695,7 +701,7 @@ async def test_verify_wrong_code_raises(self, mocker): mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) with pytest.raises(MfaVerifyError) as exc: - await client.verify({"mfa_token": "tok", "otp": "000000"}) + await client.verify({"mfa_token": _enc(), "otp": "000000"}) assert "Invalid OTP" in str(exc.value) @pytest.mark.asyncio @@ -711,9 +717,12 @@ async def test_verify_chained_mfa_raises_mfa_required(self, mocker): mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) with pytest.raises(MfaRequiredError) as exc: - await client.verify({"mfa_token": "tok", "otp": "123456"}) - assert exc.value.mfa_token == "new_raw_mfa_token" + await client.verify({"mfa_token": _enc(), "otp": "123456"}) + assert exc.value.mfa_token is not None + assert exc.value.mfa_token != "new_raw_mfa_token" # must be encrypted assert exc.value.code == "mfa_required" + decrypted = client.decrypt_mfa_token(exc.value.mfa_token) + assert decrypted.mfa_token == "new_raw_mfa_token" @pytest.mark.asyncio async def test_verify_unexpected_error(self, mocker): @@ -721,7 +730,7 @@ async def test_verify_unexpected_error(self, mocker): mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, side_effect=Exception("connection reset")) with pytest.raises(MfaVerifyError) as exc: - await client.verify({"mfa_token": "tok", "otp": "123456"}) + await client.verify({"mfa_token": _enc(), "otp": "123456"}) assert "connection reset" in str(exc.value) @pytest.mark.asyncio @@ -751,7 +760,7 @@ async def test_verify_persist_updates_session(self, mocker): mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) await client.verify( - {"mfa_token": "tok", "otp": "123456", + {"mfa_token": _enc(), "otp": "123456", "persist": True, "audience": "https://api.example.com"} ) @@ -762,8 +771,15 @@ async def test_verify_persist_updates_session(self, mocker): assert saved_state["token_sets"][0]["access_token"] == "new_at_from_mfa" @pytest.mark.asyncio - async def test_verify_persist_missing_audience_raises(self, mocker): + async def test_verify_persist_uses_context_audience_when_not_in_options(self, mocker): + """persist=True without explicit audience falls back to audience in encrypted token context.""" store = AsyncMock() + store.get = AsyncMock(return_value={ + "user": {"sub": "auth0|123"}, + "id_token": "id", + "token_sets": [], + "internal": {"sid": "s", "created_at": 1000} + }) client = MfaClient( domain=DOMAIN, client_id=CLIENT_ID, client_secret=CLIENT_SECRET, secret=SECRET, state_store=store @@ -776,10 +792,14 @@ async def test_verify_persist_missing_audience_raises(self, mocker): }) mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) - with pytest.raises(MfaVerifyError, match="audience is required"): - await client.verify( - {"mfa_token": "tok", "otp": "123456", "persist": True} - ) + result = await client.verify( + {"mfa_token": _enc(audience="https://api.example.com"), "otp": "123456", "persist": True} + ) + assert result.access_token == "at" + store.set.assert_called_once() + saved_state = store.set.call_args[0][1] + saved_audience = saved_state["token_sets"][0]["audience"] + assert saved_audience == "https://api.example.com" @pytest.mark.asyncio async def test_verify_persist_no_existing_session_raises(self, mocker): @@ -799,7 +819,7 @@ async def test_verify_persist_no_existing_session_raises(self, mocker): with pytest.raises(MfaVerifyError, match="No existing session"): await client.verify( - {"mfa_token": "tok", "otp": "123456", + {"mfa_token": _enc(), "otp": "123456", "persist": True, "audience": "https://api.example.com"} ) @@ -816,7 +836,7 @@ async def test_verify_persist_skipped_when_no_state_store(self, mocker): mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) result = await client.verify( - {"mfa_token": "tok", "otp": "123456", + {"mfa_token": _enc(), "otp": "123456", "persist": True, "audience": "https://api.example.com"} ) assert result.access_token == "at" @@ -843,6 +863,6 @@ async def test_verify_persist_store_failure_raises(self, mocker): with pytest.raises(MfaVerifyError, match="Failed to persist"): await client.verify( - {"mfa_token": "tok", "otp": "123456", + {"mfa_token": _enc(), "otp": "123456", "persist": True, "audience": "https://api.example.com"} ) diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index 3c271ea..3bc9cdb 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -2320,6 +2320,40 @@ async def test_get_token_by_refresh_token_exchange_failed(mocker): args, kwargs = mock_post.call_args assert kwargs["data"]["refresh_token"] == "" + +@pytest.mark.asyncio +async def test_get_token_by_refresh_token_mfa_required_raises_mfa_required_error(mocker): + """get_token_by_refresh_token raises MfaRequiredError (not ApiError) with encrypted token.""" + client = ServerClient( + domain="auth0.local", + client_id="", + client_secret="", + secret="a-test-secret-with-enough-length", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/token"} + ) + + fail_response = AsyncMock() + fail_response.status_code = 403 + fail_response.json = MagicMock(return_value={ + "error": "mfa_required", + "error_description": "MFA required", + "mfa_token": "raw_server_mfa_token", + }) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=fail_response) + + with pytest.raises(MfaRequiredError) as exc: + await client.get_token_by_refresh_token({"refresh_token": "rt_abc"}) + + assert exc.value.mfa_token is not None + assert exc.value.mfa_token != "raw_server_mfa_token" + decrypted = client._mfa_client.decrypt_mfa_token(exc.value.mfa_token) + assert decrypted.mfa_token == "raw_server_mfa_token" + + # ============================================================================= # Connected Accounts Tests (My Account Client) # ============================================================================= @@ -4684,8 +4718,8 @@ async def _fake_fetch(self, domain): @pytest.mark.asyncio async def test_get_access_token_mfa_required(mocker): """ - When get_token_by_refresh_token returns an mfa_required error, - get_access_token should raise MfaRequiredError with an encrypted mfa_token. + When get_token_by_refresh_token returns MfaRequiredError, + get_access_token re-raises it (token is already encrypted by get_token_by_refresh_token). """ mock_secret = "a-test-secret-with-enough-length" mock_store = MagicMock() @@ -4714,7 +4748,6 @@ async def _fake_fetch(self, domain): state_store=mock_store, ) - # Simulate state with a refresh_token and expired access token mock_store.get = AsyncMock(return_value={ "refresh_token": "rt_123", "token_sets": [ @@ -4726,13 +4759,15 @@ async def _fake_fetch(self, domain): ] }) - # Simulate mfa_required ApiError from token refresh - mfa_err = ApiError( - code="mfa_required", - message="Multifactor authentication required", + encrypted_token = client._mfa_client._encrypt_mfa_token( + raw_mfa_token="raw_mfa_token_xyz", + audience="default", + scope="", + ) + mfa_err = MfaRequiredError( + "Multifactor authentication required", + mfa_token=encrypted_token, ) - mfa_err.mfa_token = "raw_mfa_token_xyz" - mfa_err.mfa_requirements = None mocker.patch.object(client, "get_token_by_refresh_token", new_callable=AsyncMock, side_effect=mfa_err) @@ -4740,8 +4775,7 @@ async def _fake_fetch(self, domain): with pytest.raises(MfaRequiredError) as exc: await client.get_access_token() - assert exc.value.mfa_token is not None - assert exc.value.mfa_token != "raw_mfa_token_xyz" # encrypted + assert exc.value.mfa_token == encrypted_token finally: ServerClient._fetch_oidc_metadata = original_fetch @@ -4749,8 +4783,8 @@ async def _fake_fetch(self, domain): @pytest.mark.asyncio async def test_get_access_token_mfa_required_with_enroll_requirements(mocker): """ - When get_token_by_refresh_token returns mfa_required with enroll requirements, - get_access_token should raise MfaRequiredError with mfa_requirements containing enroll. + When get_token_by_refresh_token returns MfaRequiredError with mfa_requirements, + get_access_token re-raises it preserving requirements. """ mock_secret = "a-test-secret-with-enough-length" mock_store = MagicMock() @@ -4779,7 +4813,6 @@ async def _fake_fetch(self, domain): state_store=mock_store, ) - # Simulate state with a refresh_token and expired access token mock_store.get = AsyncMock(return_value={ "refresh_token": "rt_123", "token_sets": [ @@ -4791,19 +4824,24 @@ async def _fake_fetch(self, domain): ] }) - # Simulate mfa_required with enroll requirements - mfa_err = ApiError( - code="mfa_required", - message="Multifactor authentication required", - ) - mfa_err.mfa_token = "raw_mfa_token_enroll" - mfa_err.mfa_requirements = MfaRequirements( + requirements = MfaRequirements( enroll=[ {"type": "otp"}, {"type": "phone"}, {"type": "push-notification"} ] ) + encrypted_token = client._mfa_client._encrypt_mfa_token( + raw_mfa_token="raw_mfa_token_enroll", + audience="default", + scope="", + mfa_requirements=requirements, + ) + mfa_err = MfaRequiredError( + "Multifactor authentication required", + mfa_token=encrypted_token, + mfa_requirements=requirements, + ) mocker.patch.object(client, "get_token_by_refresh_token", new_callable=AsyncMock, side_effect=mfa_err) @@ -4811,8 +4849,7 @@ async def _fake_fetch(self, domain): with pytest.raises(MfaRequiredError) as exc: await client.get_access_token() - assert exc.value.mfa_token is not None - assert exc.value.mfa_token != "raw_mfa_token_enroll" # encrypted + assert exc.value.mfa_token == encrypted_token assert exc.value.mfa_requirements is not None finally: ServerClient._fetch_oidc_metadata = original_fetch @@ -5871,3 +5908,221 @@ async def test_signin_with_passkey_session_without_id_token(mocker): state_store.set.assert_awaited_once() assert result.state_data["user"] is None assert result.state_data["token_sets"][0]["access_token"] == "at_no_id_token" + + +@pytest.mark.asyncio +async def test_signin_with_passkey_mfa_required_raises_mfa_required_error(mocker): + """Server returns 403 mfa_required — SDK raises MfaRequiredError with encrypted token.""" + from auth0_server_python.error import MfaRequiredError + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 403 + mock_response.json = MagicMock(return_value={ + "error": "mfa_required", + "error_description": "MFA required", + "mfa_token": "raw_mfa_token_xyz", + }) + mock_post.return_value = mock_response + + with pytest.raises(MfaRequiredError) as exc: + await client.signin_with_passkey( + auth_session="session_abc", + authn_response=_make_passkey_authn_response(), + ) + assert exc.value.mfa_token is not None + assert exc.value.mfa_token != "raw_mfa_token_xyz" + + +@pytest.mark.asyncio +async def test_signin_with_passkey_mfa_required_with_requirements(mocker): + """mfa_required response including mfa_requirements is propagated correctly.""" + from auth0_server_python.error import MfaRequiredError + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 403 + mock_response.json = MagicMock(return_value={ + "error": "mfa_required", + "error_description": "MFA required", + "mfa_token": "raw_mfa_token_xyz", + "mfa_requirements": {"challengeTypes": ["oob"], "mfaToken": "raw_mfa_token_xyz"}, + }) + mock_post.return_value = mock_response + + with pytest.raises(MfaRequiredError) as exc: + await client.signin_with_passkey( + auth_session="session_abc", + authn_response=_make_passkey_authn_response(), + ) + assert exc.value.mfa_token is not None + assert exc.value.mfa_requirements is not None + + +@pytest.mark.asyncio +async def test_signin_with_passkey_mfa_required_without_mfa_token_falls_through(mocker): + """mfa_required response missing mfa_token raises PasskeyError (server misconfiguration).""" + from auth0_server_python.error import PasskeyError + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 403 + mock_response.json = MagicMock(return_value={ + "error": "mfa_required", + "error_description": "MFA required", + }) + mock_post.return_value = mock_response + + with pytest.raises(PasskeyError) as exc: + await client.signin_with_passkey( + auth_session="session_abc", + authn_response=_make_passkey_authn_response(), + ) + assert exc.value.code == "mfa_required" + + +@pytest.mark.asyncio +async def test_signin_with_passkey_mfa_required_stores_pending_mfa(mocker): + """When signin_with_passkey raises MfaRequiredError, the encrypted token is stored in the state store.""" + mock_store = AsyncMock() + mock_store.get = AsyncMock(return_value=None) + mock_store.set = AsyncMock() + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=mock_store, + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 403 + mock_response.json = MagicMock(return_value={ + "error": "mfa_required", + "error_description": "MFA required", + "mfa_token": "raw_mfa_token_xyz", + }) + mock_post.return_value = mock_response + + with pytest.raises(MfaRequiredError) as exc: + await client.signin_with_passkey( + auth_session="session_abc", + authn_response=_make_passkey_authn_response(), + ) + + mock_store.set.assert_called_once() + store_key, store_payload = mock_store.set.call_args[0][:2] + assert store_key == "_a0_mfa_pending" + assert store_payload["mfa_token"] == exc.value.mfa_token + + +@pytest.mark.asyncio +async def test_get_access_token_mfa_required_stores_pending_mfa(mocker): + """When get_access_token raises MfaRequiredError, the encrypted token is stored in the state store.""" + mock_secret = "a-test-secret-with-enough-length" + mock_store = AsyncMock() + mock_store.get = AsyncMock(return_value={ + "refresh_token": "rt_123", + "token_sets": [ + {"audience": "default", "access_token": "expired_at", "expires_at": 0} + ] + }) + mock_store.set = AsyncMock() + mock_store.delete = AsyncMock() + + client = ServerClient( + domain="auth0.local", + client_id="cid", + client_secret="csecret", + secret=mock_secret, + transaction_store=mock_store, + state_store=mock_store, + ) + + encrypted_token = client._mfa_client._encrypt_mfa_token( + raw_mfa_token="raw_mfa_token_xyz", + audience="default", + scope="", + ) + mfa_err = MfaRequiredError( + "Multifactor authentication required", + mfa_token=encrypted_token, + ) + mocker.patch.object(client, "get_token_by_refresh_token", + new_callable=AsyncMock, side_effect=mfa_err) + + with pytest.raises(MfaRequiredError): + await client.get_access_token() + + pending_calls = [ + call for call in mock_store.set.call_args_list + if call[0][0] == "_a0_mfa_pending" + ] + assert len(pending_calls) == 1 + assert pending_calls[0][0][1]["mfa_token"] == encrypted_token + + +@pytest.mark.asyncio +async def test_mfa_client_store_and_get_pending_mfa(): + """store_pending_mfa / get_pending_mfa roundtrip through the state store.""" + store = AsyncMock() + store.set = AsyncMock() + store.get = AsyncMock(return_value={"mfa_token": "enc_tok"}) + + client = MfaClient( + domain="auth0.local", + client_id="cid", + client_secret="csecret", + secret="a-test-secret-with-enough-length", + state_store=store, + ) + + await client.store_pending_mfa("enc_tok") + store.set.assert_called_once_with( + "_a0_mfa_pending", {"mfa_token": "enc_tok"}, options=None + ) + + result = await client.get_pending_mfa() + assert result == "enc_tok" From 22d2997b4ae66de5cf2a6a13a538e1863cd4b009 Mon Sep 17 00:00:00 2001 From: Sourav Basu Date: Thu, 25 Jun 2026 16:02:01 +0530 Subject: [PATCH 11/12] Changes for MyAccount Factors Schema --- .../auth_server/my_account_client.py | 5 ++- .../auth_types/__init__.py | 7 ++-- .../tests/test_my_account_client.py | 37 ++++++++++++++++--- 3 files changed, 38 insertions(+), 11 deletions(-) diff --git a/src/auth0_server_python/auth_server/my_account_client.py b/src/auth0_server_python/auth_server/my_account_client.py index 5e10b60..4a09ed3 100644 --- a/src/auth0_server_python/auth_server/my_account_client.py +++ b/src/auth0_server_python/auth_server/my_account_client.py @@ -412,7 +412,10 @@ async def get_factors( validation_errors=error_data.get("validation_errors", None), ) - return GetFactorsResponse.model_validate(response.json()) + # Auth0 /me/v1/factors returns a plain array, not {"factors":[...]} + raw = response.json() + payload = raw if isinstance(raw, dict) else {"factors": raw} + return GetFactorsResponse.model_validate(payload) except Exception as e: if isinstance(e, (MyAccountApiError, ApiError)): diff --git a/src/auth0_server_python/auth_types/__init__.py b/src/auth0_server_python/auth_types/__init__.py index 0433ae7..485a4eb 100644 --- a/src/auth0_server_python/auth_types/__init__.py +++ b/src/auth0_server_python/auth_types/__init__.py @@ -717,7 +717,7 @@ class EnrollAuthenticationMethodRequest(BaseModel): email: Optional[str] = None phone_number: Optional[str] = None preferred_authentication_method: Optional[PreferredAuthMethod] = None - user_identity_id: Optional[str] = None + identity_user_id: Optional[str] = None # OAS: IdentityAuthenticationMethodBase.identity_user_id connection: Optional[str] = None @@ -790,9 +790,8 @@ class ListAuthenticationMethodsResponse(BaseModel): class Factor(BaseModel): model_config = ConfigDict(extra="allow") - name: str - enabled: Optional[bool] = None - trial_expired: Optional[bool] = None + type: str + usage: Optional[list[str]] = None class GetFactorsResponse(BaseModel): diff --git a/src/auth0_server_python/tests/test_my_account_client.py b/src/auth0_server_python/tests/test_my_account_client.py index 6f254c1..1737f32 100644 --- a/src/auth0_server_python/tests/test_my_account_client.py +++ b/src/auth0_server_python/tests/test_my_account_client.py @@ -525,15 +525,17 @@ async def test_get_factors_success(mocker): client = MyAccountClient(domain="auth0.local") response = AsyncMock() response.status_code = 200 - response.json = MagicMock(return_value={"factors": [{"name": "sms", "enabled": True}]}) + response.json = MagicMock( + return_value={"factors": [{"type": "phone", "usage": ["primary"]}]} + ) mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) result = await client.get_factors(access_token="token123") assert isinstance(result, GetFactorsResponse) assert len(result.factors) == 1 - assert result.factors[0].name == "sms" - assert result.factors[0].enabled is True + assert result.factors[0].type == "phone" + assert result.factors[0].usage == ["primary"] @pytest.mark.asyncio @@ -596,12 +598,13 @@ async def test_get_factors_extra_fields(mocker): response = AsyncMock() response.status_code = 200 response.json = MagicMock(return_value={ - "factors": [{"name": "webauthn-roaming", "enabled": True, "future_field": "value"}] + "factors": [{"type": "webauthn-roaming", "usage": ["secondary"], "future_field": "value"}] }) mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) result = await client.get_factors(access_token="token123") - assert result.factors[0].name == "webauthn-roaming" + assert result.factors[0].type == "webauthn-roaming" + assert result.factors[0].model_extra["future_field"] == "value" @pytest.mark.asyncio @@ -804,6 +807,26 @@ async def test_enroll_authentication_method_success(mocker): assert result.authn_params_public_key.user.display_name == "Test User" +@pytest.mark.asyncio +async def test_enroll_authentication_method_sends_identity_user_id(mocker): + """identity_user_id must serialize to the request body under its OAS wire key.""" + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 202 + response.headers = {"location": "/me/v1/authentication-methods/passkey|new"} + response.json = MagicMock(return_value={"auth_session": "session_abc"}) + mock_post = mocker.patch( + "httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response + ) + + req = EnrollAuthenticationMethodRequest(type="passkey", identity_user_id="auth0|abc123") + await client.enroll_authentication_method(access_token="token123", request=req) + + sent_body = mock_post.call_args[1]["json"] + assert sent_body["identity_user_id"] == "auth0|abc123" + assert "user_identity_id" not in sent_body + + @pytest.mark.asyncio async def test_enroll_authentication_method_public_key_extra_fields_preserved(mocker): """Unknown WebAuthn fields (excludeCredentials, attestation, extensions) must not be dropped.""" @@ -1012,7 +1035,9 @@ async def test_get_factors_with_dpop_key(mocker): client = MyAccountClient(domain="auth0.local") response = AsyncMock() response.status_code = 200 - response.json = MagicMock(return_value={"factors": [{"name": "sms", "enabled": True}]}) + response.json = MagicMock( + return_value={"factors": [{"type": "phone", "usage": ["primary"]}]} + ) mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") From 69ea84794bf07ad5c7bcd0cc7c05ac1bd0a1e4bc Mon Sep 17 00:00:00 2001 From: Sourav Basu Date: Sun, 28 Jun 2026 19:43:32 +0530 Subject: [PATCH 12/12] Added docs and feedback changes --- README.md | 71 +++++ examples/DPoP.md | 133 +++++++++ examples/MyAccountAuthenticationMethods.md | 240 +++++++++++++++++ examples/Passkeys.md | 252 ++++++++++++++++++ .../auth_schemes/dpop_auth.py | 14 +- .../auth_server/mfa_client.py | 108 ++++++-- .../auth_server/server_client.py | 66 +++-- .../auth_types/__init__.py | 16 +- .../tests/test_dpop_auth.py | 44 ++- .../tests/test_mfa_client.py | 30 ++- .../tests/test_server_client.py | 73 ++++- 11 files changed, 962 insertions(+), 85 deletions(-) create mode 100644 examples/DPoP.md create mode 100644 examples/MyAccountAuthenticationMethods.md create mode 100644 examples/Passkeys.md diff --git a/README.md b/README.md index d0cefb7..ae68aa3 100644 --- a/README.md +++ b/README.md @@ -177,6 +177,77 @@ The SDK handles per-domain OIDC discovery, JWKS fetching, issuer validation, and For more details and examples, see [examples/MultipleCustomDomains.md](examples/MultipleCustomDomains.md). +### 6. Passkey Authentication + +Sign users up or in with [WebAuthn](https://www.w3.org/TR/webauthn-2/) passkeys (Touch ID, Face ID, Windows Hello, or a security key) instead of a password. The ceremony is two steps — request a challenge, sign it in the browser, then complete sign-in — and establishes a server-side session like every other login path: + +```python +from auth0_server_python.auth_types import PasskeyUserProfile, PasskeyAuthResponse + +# Step 1 — request a challenge +challenge = await auth0.passkey_login_challenge( + store_options={"request": request, "response": response} +) + +# Step 2 — browser signs: navigator.credentials.get(challenge.authn_params_public_key) + +# Step 3 — complete sign-in and establish the session +result = await auth0.signin_with_passkey( + auth_session=challenge.auth_session, + authn_response=PasskeyAuthResponse(**credential), + store_options={"request": request, "response": response} +) + +user = result.state_data["user"] +``` + +For signup, organizations, step-up MFA, and error handling, see [examples/Passkeys.md](examples/Passkeys.md). + +### 7. My Account API — Authentication Methods + +Let a logged-in user manage their own enrolled authentication methods — enroll a new passkey (or other factor), list, rename, and delete — via the [My Account API](https://auth0.com/docs/manage-users/my-account-api): + +```python +from auth0_server_python.auth_server.my_account_client import MyAccountClient +from auth0_server_python.auth_types import EnrollAuthenticationMethodRequest + +# Obtain a My Account-scoped token for the current session (MRRT) +access_token = await auth0.get_access_token( + store_options={"request": request, "response": response}, + audience=f"https://{YOUR_CUSTOM_DOMAIN}/me/", + scope="create:me:authentication-methods read:me:authentication-methods", +) + +my_account = MyAccountClient(domain=YOUR_CUSTOM_DOMAIN) + +# Start enrolling a passkey (then sign it in the browser and verify) +challenge = await my_account.enroll_authentication_method( + access_token=access_token, + request=EnrollAuthenticationMethodRequest(type="passkey"), +) +``` + +For the full enroll/verify ceremony, listing, updating, deleting, and error handling, see [examples/MyAccountAuthenticationMethods.md](examples/MyAccountAuthenticationMethods.md). + +### 8. DPoP — Sender-Constrained Tokens + +Bind tokens to a key your server holds ([RFC 9449](https://www.rfc-editor.org/rfc/rfc9449)) so a stolen token alone cannot be replayed. Generate an EC P-256 key and pass it to passkey sign-in or any My Account API call: + +```python +from jwcrypto import jwk + +dpop_key = jwk.JWK.generate(kty="EC", crv="P-256") # you create and keep this key + +result = await auth0.signin_with_passkey( + auth_session=challenge.auth_session, + authn_response=authn_response, + dpop_key=dpop_key, + store_options={"request": request, "response": response} +) +``` + +For the `dpop_key` vs `dpop_proof` distinction, key lifecycle, nonce handling, and error handling, see [examples/DPoP.md](examples/DPoP.md). + ## Feedback ### Contributing diff --git a/examples/DPoP.md b/examples/DPoP.md new file mode 100644 index 0000000..d33911d --- /dev/null +++ b/examples/DPoP.md @@ -0,0 +1,133 @@ +# DPoP — Sender-Constrained Tokens + +DPoP (Demonstrating Proof of Possession, [RFC 9449](https://www.rfc-editor.org/rfc/rfc9449)) binds an access token to a cryptographic key the client holds. A normal **Bearer** token is usable by anyone who holds it; a **DPoP-bound** token is useless without a matching proof signed by the private key — so a stolen token alone cannot be replayed. + +This SDK supports DPoP for **passkey sign-in** (`ServerClient.signin_with_passkey`) and for every **My Account API** call (`MyAccountClient`). + +> [!NOTE] +> DPoP is a confidential-client (Regular Web App) capability here: your server holds the key. The SDK does not store the key for you — you generate it and pass it in, so it lives in whatever secret store you choose (KMS/HSM/etc.). + +## Table of Contents + +- [`dpop_key` vs `dpop_proof`](#dpop_key-vs-dpop_proof) +- [1. Generate a key](#1-generate-a-key) +- [2. DPoP-bound passkey sign-in](#2-dpop-bound-passkey-sign-in) +- [3. DPoP on My Account API calls](#3-dpop-on-my-account-api-calls) +- [4. Generating a proof manually](#4-generating-a-proof-manually) +- [Key lifecycle and security](#key-lifecycle-and-security) +- [Error Handling](#error-handling) +- [Additional Resources](#additional-resources) + +## `dpop_key` vs `dpop_proof` + +These are **different things**, and the distinction is the whole mental model. You only ever handle the **key**; the SDK derives a fresh **proof** from it on every request. + +| | `dpop_key` | `dpop_proof` | +|---|------------|--------------| +| What it is | A long-lived **EC P-256 key pair** | A signed **JWT**, created fresh for one request | +| Lifetime | Reused across sign-in and every API call | Single-use — one per HTTP request | +| Who holds it | You (the private key never leaves your server) | Sent on the wire in the `DPoP:` header | +| Sensitivity | **Tier 0** — it is a secret | Not a stored secret — a short-lived derived artifact | +| In the SDK | The `dpop_key` parameter you pass in | Built internally — you never construct one | + +Think of `dpop_key` as a **signet ring** you keep, and `dpop_proof` as the **wax seal** you stamp on each letter: verifiably yours, but the seal from one letter is worthless on another. Each request the SDK mints a new proof (binding the HTTP method, the URL, a unique id, a timestamp, and — at the resource server — a hash of the access token), so a captured proof cannot be reused elsewhere. + +## 1. Generate a key + +The SDK uses `jwcrypto` (already a dependency). Generate one EC P-256 key and reuse the **same instance** for sign-in and for all subsequent API calls — the token is bound to that key. + +```python +from jwcrypto import jwk + +dpop_key = jwk.JWK.generate(kty="EC", crv="P-256") +``` + +> [!NOTE] +> The key **must** be EC P-256 (Auth0 advertises `ES256` only). Passing an RSA or P-384 key raises `ValueError` before any network call — it fails closed. + +## 2. DPoP-bound passkey sign-in + +Pass `dpop_key` to `signin_with_passkey`. The SDK attaches a token-endpoint DPoP proof so Auth0 issues a DPoP-bound token, and **rejects a Bearer downgrade**: if a key was supplied but the server returns `token_type: Bearer`, it raises instead of silently accepting an unbound token. + +```python +result = await server_client.signin_with_passkey( + auth_session=challenge.auth_session, + authn_response=authn_response, + dpop_key=dpop_key, + store_options={"request": request, "response": response}, +) +``` + +See [examples/Passkeys.md](Passkeys.md) for the full passkey flow. + +## 3. DPoP on My Account API calls + +Every `MyAccountClient` method takes an optional `dpop_key`. Supply it and the call sends `Authorization: DPoP ` plus a fresh `DPoP:` proof header; omit it and the call uses a plain `Authorization: Bearer ` — no behaviour change for callers that don't need DPoP. + +```python +from auth0_server_python.auth_server.my_account_client import MyAccountClient + +my_account = MyAccountClient(domain="YOUR_CUSTOM_DOMAIN") + +methods = await my_account.list_authentication_methods( + access_token=access_token, # a DPoP-bound token from sign-in / MRRT + dpop_key=dpop_key, # the SAME key the token was bound to +) +``` + +> [!NOTE] +> If a `/me/v1/...` call is answered with `401 + DPoP-Nonce` (the server demanding a nonce), the SDK transparently retries the request **once** with the nonce embedded in the proof (RFC 9449 §9.1). The token endpoint nonce challenge (`400 + DPoP-Nonce`, §8.1) is handled the same way during sign-in. There is never more than one retry — it will not loop. + +## 4. Generating a proof manually + +For the token endpoint specifically (no access token exists yet, so the proof omits the `ath` claim), the SDK exposes a helper. You rarely need this — `signin_with_passkey` and the `MyAccountClient` methods build proofs for you — but it is available for custom token requests: + +```python +from auth0_server_python.auth_schemes.dpop_auth import make_dpop_proof_for_token_endpoint + +proof = make_dpop_proof_for_token_endpoint( + dpop_key, + "POST", + "https://YOUR_CUSTOM_DOMAIN/oauth/token", + # nonce="..." # supply when the server returned a DPoP-Nonce +) +# send as the "DPoP" request header +``` + +For resource-server requests, the `DPoPAuth` httpx handler (also exported from `auth_schemes`) builds the proof — including the `ath` token-hash claim — automatically. The `MyAccountClient` methods select it internally when you pass `dpop_key`. + +## Key lifecycle and security + +- **You own the key.** Generate it, store it in your secret store, and reuse the same instance for the bound token's lifetime. Discard it when the session ends. +- **One key, one bound token.** The token is bound to the key; using a different key on a later API call will be rejected by the resource server (`401 invalid_dpop_proof`). +- **The proof is request-specific.** Method, URL, a unique `jti`, and a timestamp are baked into every proof, so it cannot be replayed against a different endpoint or reused. +- **Never log the private key or a proof.** Treat the key as Tier 0 and proofs as transient secrets. The SDK's auth handlers redact the key and token in their `repr()`. + +## Error Handling + +DPoP failures surface through the error type of the operation that used the key: + +```python +from auth0_server_python.error import PasskeyError, MyAccountApiError, Auth0Error + +# Wrong key type — fails closed before any request +try: + await server_client.signin_with_passkey( + auth_session=auth_session, authn_response=authn_response, + dpop_key=rsa_key, # not EC P-256 + ) +except ValueError as e: + print(e) # "DPoP key must be an EC P-256 key" + +# Bearer downgrade when DPoP was requested +except PasskeyError as e: + print(e.code, e.message) # passkey_token_error — "DPoP token binding failed..." +``` + +On the My Account surface, a key mismatch or a DPoP-required endpoint reached without binding surfaces as `MyAccountApiError` (typically `status=401`). Catch `Auth0Error` for uniform handling. + +## Additional Resources + +- [Passkey Authentication](Passkeys.md) +- [My Account — Authentication Methods](MyAccountAuthenticationMethods.md) +- [RFC 9449 — OAuth 2.0 Demonstrating Proof of Possession (DPoP)](https://www.rfc-editor.org/rfc/rfc9449) diff --git a/examples/MyAccountAuthenticationMethods.md b/examples/MyAccountAuthenticationMethods.md new file mode 100644 index 0000000..9adca45 --- /dev/null +++ b/examples/MyAccountAuthenticationMethods.md @@ -0,0 +1,240 @@ +# My Account API — Authentication Methods & Factors + +The [My Account API](https://auth0.com/docs/manage-users/my-account-api) lets a **logged-in user manage their own account**. This guide covers the **authentication-methods** and **factors** surface: enrolling a new passkey (or other factor), and listing, reading, renaming, and deleting a user's enrolled methods. + +> [!NOTE] +> This is a different My Account resource from [Connected Accounts](ConnectedAccounts.md) (Token Vault). Connected-accounts management is exposed as convenience methods on `ServerClient`; **authentication-method management is on `MyAccountClient` directly**, because each call takes a user access token you obtain yourself. The two share the same My Account setup (activation, MRRT, scopes, `MyAccountApiError`) — see [ConnectedAccounts.md → Pre-requisites](ConnectedAccounts.md#pre-requisites) for that common setup. + +> [!NOTE] +> To **sign in** with a passkey (rather than manage one), see [examples/Passkeys.md](Passkeys.md). To **bind these calls to a held key** with DPoP, see [examples/DPoP.md](DPoP.md). + +## Table of Contents + +- [Prerequisites](#prerequisites) +- [Obtaining a scoped token](#obtaining-a-scoped-token) +- [1. List factors available for enrollment](#1-list-factors-available-for-enrollment) +- [2. Enroll an authentication method (passkey)](#2-enroll-an-authentication-method-passkey) +- [3. List authentication methods](#3-list-authentication-methods) +- [4. Get a single authentication method](#4-get-a-single-authentication-method) +- [5. Update (rename) an authentication method](#5-update-rename-an-authentication-method) +- [6. Delete an authentication method](#6-delete-an-authentication-method) +- [DPoP](#dpop) +- [Error Handling](#error-handling) +- [Additional Resources](#additional-resources) + +## Prerequisites + +1. [Activate the My Account API](https://auth0.com/docs/manage-users/my-account-api#activate-the-my-account-api) on your tenant and enable access for your application. +2. [Configure MRRT](https://auth0.com/docs/secure/tokens/refresh-tokens/multi-resource-refresh-token) so your refresh-token policy can mint tokens for the My Account audience (`https://{yourDomain}/me/`) with the authentication-methods scopes. +3. Passkey enrollment additionally requires a [Custom Domain](https://auth0.com/docs/customize/custom-domains) and the native passkey feature on your tenant. + +The scopes for this surface (note the **hyphens**): + +| Operation | Scope | +|-----------|-------| +| List factors | `read:me:factors` | +| List / get methods | `read:me:authentication-methods` | +| Enroll / verify | `create:me:authentication-methods` | +| Update | `update:me:authentication-methods` | +| Delete | `delete:me:authentication-methods` | + +> [!TIP] +> As with Connected Accounts, set the default `scope` for the My Account audience when constructing `ServerClient` to avoid a fresh token request per scope. See [ConnectedAccounts.md → A note about scopes](ConnectedAccounts.md#a-note-about-scopes). + +## Obtaining a scoped token + +`MyAccountClient` is **stateless** — it takes a correctly-scoped user access token on every call. Obtain that token from your `ServerClient` session via MRRT, then construct the client: + +```python +from auth0_server_python.auth_server.my_account_client import MyAccountClient + +# Fresh My Account-scoped token for the current session (MRRT exchange) +access_token = await server_client.get_access_token( + store_options={"request": request, "response": response}, + audience=f"https://{YOUR_CUSTOM_DOMAIN}/me/", + scope="create:me:authentication-methods read:me:authentication-methods read:me:factors", +) + +my_account = MyAccountClient(domain=YOUR_CUSTOM_DOMAIN) +``` + +## 1. List factors available for enrollment + +```python +factors = await my_account.get_factors(access_token=access_token) +for factor in factors.factors: + print(factor.type, factor.usage) +``` + +## 2. Enroll an authentication method (passkey) + +Enrollment is a **two-step** ceremony, mirroring sign-in: request a challenge, sign it in the browser, then verify. + +### Step 1 — Start enrollment + +```python +from auth0_server_python.auth_types import EnrollAuthenticationMethodRequest + +challenge = await my_account.enroll_authentication_method( + access_token=access_token, + request=EnrollAuthenticationMethodRequest(type="passkey"), +) + +# challenge.authentication_method_id -> id of the new (unverified) method +# challenge.auth_session -> Tier 1 session credential (do not log) +# challenge.authn_params_public_key -> pass to navigator.credentials.create() +``` + +`EnrollAuthenticationMethodRequest.type` is a closed set: `passkey`, `email`, `phone`, `totp`, `push-notification`, `recovery-code`, `password`. For non-passkey types, supply the relevant fields (`email`, `phone_number`, `preferred_authentication_method`). An invalid type fails at construction with a clear `ValidationError`. + +### Step 2 — Create the credential in the browser + +Pass `challenge.authn_params_public_key` to `navigator.credentials.create()` and collect the resulting credential. + +### Step 3 — Verify enrollment + +```python +from auth0_server_python.auth_types import ( + VerifyAuthenticationMethodRequest, + PasskeyAuthResponse, +) + +method = await my_account.verify_authentication_method( + access_token=access_token, + authentication_method_id=challenge.authentication_method_id, + request=VerifyAuthenticationMethodRequest( + auth_session=challenge.auth_session, + authn_response=PasskeyAuthResponse( + id=credential["id"], + raw_id=credential["rawId"], + type="public-key", + response={ + "clientDataJSON": credential["response"]["clientDataJSON"], + "attestationObject": credential["response"]["attestationObject"], + }, + ), + ), +) +print(f"Enrolled: {method.id} ({method.type})") +``` + +> [!NOTE] +> For non-passkey types, set the matching field on `VerifyAuthenticationMethodRequest` instead of `authn_response`: `otp_code` (email/phone/totp), `recovery_code`, or `password`. A push enrollment needs only `auth_session`. + +## 3. List authentication methods + +```python +all_methods = await my_account.list_authentication_methods(access_token=access_token) + +# Filter by type +passkeys = await my_account.list_authentication_methods( + access_token=access_token, + type_filter="passkey", +) +for m in passkeys.authentication_methods: + print(m.id, m.type, m.created_at) +``` + +> [!NOTE] +> `AuthenticationMethod` and `Factor` are forward-tolerant (`extra="allow"`): fields or method/factor types Auth0 adds later still deserialize. Don't switch exhaustively on `type` — handle unknown types gracefully. + +## 4. Get a single authentication method + +```python +method = await my_account.get_authentication_method( + access_token=access_token, + authentication_method_id="passkey|abc123", +) +``` + +> [!NOTE] +> Method IDs (e.g. `passkey|abc123`) can contain characters like `|`. The SDK URL-encodes every ID it places in a path, so pass the raw ID exactly as returned — do not pre-encode it. + +## 5. Update (rename) an authentication method + +```python +from auth0_server_python.auth_types import UpdateAuthenticationMethodRequest + +method = await my_account.update_authentication_method( + access_token=access_token, + authentication_method_id="passkey|abc123", + request=UpdateAuthenticationMethodRequest(name="My Work Laptop"), +) +``` + +## 6. Delete an authentication method + +```python +await my_account.delete_authentication_method( + access_token=access_token, + authentication_method_id="passkey|abc123", +) +# Returns None on success (HTTP 204). +``` + +## DPoP + +Every method above accepts an optional `dpop_key` to present a sender-constrained token (`Authorization: DPoP` + a per-request proof) instead of a Bearer token. Pass the **same key** the access token was bound to: + +```python +methods = await my_account.list_authentication_methods( + access_token=access_token, + dpop_key=dpop_key, +) +``` + +See [examples/DPoP.md](DPoP.md) for key generation, the `dpop_key` vs `dpop_proof` distinction, and nonce handling. + +## Error Handling + +All errors inherit from `Auth0Error`. My Account API errors are `MyAccountApiError` (RFC 7807 problem-details, carrying `status`, `detail`, and optional `validation_errors`); missing arguments raise `MissingRequiredArgumentError`; transport or non-JSON responses surface as `ApiError`. + +### Basic handling (recommended) + +```python +from auth0_server_python.error import Auth0Error + +try: + methods = await my_account.list_authentication_methods(access_token=access_token) +except Auth0Error as e: + return {"error": str(e)} +``` + +### Advanced handling (when actions differ by case) + +```python +from auth0_server_python.error import Auth0Error, MyAccountApiError + +try: + await my_account.enroll_authentication_method( + access_token=access_token, + request=EnrollAuthenticationMethodRequest(type="passkey"), + ) +except MyAccountApiError as e: + if e.status == 401: + return redirect_to_login() # token expired + if e.status == 403: + return {"error": "Missing required scope"} # e.g. create:me:authentication-methods + if e.status == 400 and e.validation_errors: + return {"error": "Validation failed", "details": e.validation_errors} + raise +except Auth0Error as e: + return {"error": str(e)} +``` + +> [!NOTE] +> Enrollment raises `MyAccountApiError`/`ApiError`, whereas passkey **sign-in** (`ServerClient`) raises `PasskeyError`. They are two distinct API surfaces — an auth grant versus a My Account resource — so write the `except` that matches the call you made. + +### Common error types + +- **`Auth0Error`** (base): catch for general handling +- **`MyAccountApiError`**: My Account API errors with `status`, `detail`, optional `validation_errors` +- **`MissingRequiredArgumentError`**: a required parameter (`access_token`, `authentication_method_id`, `request`) was not provided +- **`ApiError`**: transport failure or a non-JSON error body + +## Additional Resources + +- [Connected Accounts (Token Vault)](ConnectedAccounts.md) — the other My Account surface, and shared My Account/MRRT setup +- [Passkey Authentication](Passkeys.md) — signing in with a passkey +- [DPoP](DPoP.md) — sender-constrained tokens +- [Auth0 My Account API documentation](https://auth0.com/docs/manage-users/my-account-api) diff --git a/examples/Passkeys.md b/examples/Passkeys.md new file mode 100644 index 0000000..7ec1a27 --- /dev/null +++ b/examples/Passkeys.md @@ -0,0 +1,252 @@ +# Passkey Authentication + +Passkeys let users sign up and log in with [WebAuthn](https://www.w3.org/TR/webauthn-2/) credentials (Touch ID, Face ID, Windows Hello, or a hardware security key) instead of a password. This guide covers the **primary authentication** flow on `ServerClient` — signing a user up or in with a passkey and establishing a server-side session. + +> [!NOTE] +> Passkeys require a [Custom Domain](https://auth0.com/docs/customize/custom-domains) (WebAuthn binds the credential to the relying-party domain) and the native passkey feature enabled on your tenant. See the [Auth0 passkey documentation](https://auth0.com/docs/authenticate/database-connections/passkeys). + +> [!NOTE] +> Managing a logged-in user's enrolled passkeys (enroll a new passkey, list, rename, delete) is a **separate** surface on the My Account API. See [examples/MyAccountAuthenticationMethods.md](MyAccountAuthenticationMethods.md). + +## Table of Contents + +- [How the flow works](#how-the-flow-works) +- [Prerequisites](#prerequisites) +- [1. Passkey Signup](#1-passkey-signup) +- [2. Passkey Login](#2-passkey-login) +- [3. Organizations](#3-organizations) +- [4. Step-up MFA during passkey login](#4-step-up-mfa-during-passkey-login) +- [5. DPoP-bound passkey tokens](#5-dpop-bound-passkey-tokens) +- [Error Handling](#error-handling) +- [Additional Resources](#additional-resources) + +## How the flow works + +A passkey ceremony is always **two steps**, because the WebAuthn signature happens in the browser between them: + +1. **Challenge** — the SDK asks Auth0 for a challenge (`passkey_signup_challenge` / `passkey_login_challenge`). Auth0 returns an `auth_session` and the WebAuthn options (`authn_params_public_key`). +2. **Browser** — your front end passes those options to `navigator.credentials.create()` (signup) or `navigator.credentials.get()` (login). The authenticator produces a signed credential. +3. **Verify / sign-in** — the SDK exchanges the signed credential for tokens (`signin_with_passkey`) and **creates a server-side session**, exactly like every other login path. + +``` +ServerClient.passkey_*_challenge() ──► auth_session + authn_params_public_key + │ + navigator.credentials.create()/get() (browser signs) + │ +ServerClient.signin_with_passkey() ◄── signed authn_response + └─► tokens validated, session persisted → PasskeyLoginResult +``` + +## Prerequisites + +```python +from auth0_server_python.auth_server.server_client import ServerClient + +server_client = ServerClient( + domain="YOUR_CUSTOM_DOMAIN", + client_id="YOUR_CLIENT_ID", + client_secret="YOUR_CLIENT_SECRET", + secret="YOUR_SECRET", +) +``` + +The **Passkey** grant (`urn:okta:params:oauth:grant-type:webauthn`) must be enabled for your application under **Applications → Your App → Grant Types**. + +## 1. Passkey Signup + +### Step 1 — Request a signup challenge + +```python +from auth0_server_python.auth_types import PasskeyUserProfile + +challenge = await server_client.passkey_signup_challenge( + user_profile=PasskeyUserProfile( + email="new.user@example.com", + name="Jane Doe", + ), + connection="Username-Password-Authentication", # optional database connection (realm) + store_options={"request": request, "response": response}, +) + +# Hand these to the browser: +# challenge.auth_session -> opaque session credential (Tier 1, do not log) +# challenge.authn_params_public_key -> pass to navigator.credentials.create() +``` + +> [!TIP] +> `PasskeyUserProfile` allows extra fields — any additional profile attribute your tenant accepts (for example `given_name`, `family_name`, `picture`) passes through without an SDK change. Pass tenant-specific custom data via the separate `user_metadata` argument. + +### Step 2 — Create the credential in the browser + +Pass `authn_params_public_key` to `navigator.credentials.create()`. The resulting credential serializes to the shape the SDK expects in step 3 (`id`, `rawId`, `type`, and a `response` object with `clientDataJSON` + `attestationObject`). + +### Step 3 — Verify and establish the session + +```python +from auth0_server_python.auth_types import PasskeyAuthResponse + +result = await server_client.signin_with_passkey( + auth_session=challenge.auth_session, + authn_response=PasskeyAuthResponse( + id=credential["id"], + raw_id=credential["rawId"], # accepts rawId alias too + type="public-key", + response={ + "clientDataJSON": credential["response"]["clientDataJSON"], + "attestationObject": credential["response"]["attestationObject"], + }, + ), + store_options={"request": request, "response": response}, +) + +user = result.state_data["user"] +print(f"Signed up and logged in: {user['sub']}") +``` + +`signin_with_passkey` returns a `PasskeyLoginResult` whose `state_data` holds the user claims and token sets — the same shape as `complete_interactive_login` and `login_with_custom_token_exchange`. The session is persisted to your configured state store. + +## 2. Passkey Login + +Identical shape, different endpoints. The login challenge takes an optional `username` hint (for conditional UI), and the browser uses `navigator.credentials.get()`. + +```python +# Step 1 — login challenge +challenge = await server_client.passkey_login_challenge( + username="existing.user@example.com", # optional + connection="Username-Password-Authentication", # optional + store_options={"request": request, "response": response}, +) + +# Step 2 — browser: navigator.credentials.get(challenge.authn_params_public_key) + +# Step 3 — sign in. The login credential's response carries +# clientDataJSON + authenticatorData + signature + userHandle. +result = await server_client.signin_with_passkey( + auth_session=challenge.auth_session, + authn_response=PasskeyAuthResponse( + id=credential["id"], + raw_id=credential["rawId"], + type="public-key", + response={ + "clientDataJSON": credential["response"]["clientDataJSON"], + "authenticatorData": credential["response"]["authenticatorData"], + "signature": credential["response"]["signature"], + "userHandle": credential["response"]["userHandle"], + }, + ), + store_options={"request": request, "response": response}, +) +``` + +> [!NOTE] +> The SDK is transparent to the signup-vs-login difference in the credential `response` — both flow through the same `PasskeyAuthResponse.response` dict. Send exactly the keys the browser produced. + +## 3. Organizations + +Pass an `organization` (ID or name) on the challenge to scope the passkey ceremony to an organization. The resulting `id_token` carries the `org_id` claim, validated automatically at session creation. + +```python +challenge = await server_client.passkey_login_challenge( + organization="org_abc123", + store_options={"request": request, "response": response}, +) +# ... signin_with_passkey(organization="org_abc123", ...) +``` + +## 4. Step-up MFA during passkey login + +If tenant policy requires a second factor, `signin_with_passkey` raises `MfaRequiredError` — the login does **not** complete silently. The raw MFA token is encrypted by the SDK before it reaches you, and stored server-side so your challenge/verify routes can retrieve it without a client round-trip. + +```python +from auth0_server_python.error import MfaRequiredError + +try: + result = await server_client.signin_with_passkey( + auth_session=challenge.auth_session, + authn_response=authn_response, + store_options={"request": request, "response": response}, + ) +except MfaRequiredError as e: + # e.mfa_token is ENCRYPTED — hand it straight to MfaClient. + # See examples/MFA.md for the challenge/verify flow. + ... +``` + +See [examples/MFA.md](MFA.md) for the full challenge → verify continuation. + +> [!NOTE] +> A passkey is supported as a **first** factor today. WebAuthn as a **second** factor is currently only available through Universal Login (hosted), not as a headless API a server SDK can drive — so this SDK does not implement it. The response models are forward-tolerant for when that capability ships. + +## 5. DPoP-bound passkey tokens + +Pass a `dpop_key` to bind the issued tokens to a key you hold (RFC 9449). When supplied, the SDK attaches a DPoP proof to the token exchange and Auth0 issues a DPoP-bound token; if the server returns an unbound (`Bearer`) token instead, `signin_with_passkey` raises rather than accept the downgrade. + +```python +from jwcrypto import jwk + +dpop_key = jwk.JWK.generate(kty="EC", crv="P-256") # you create and keep this key + +result = await server_client.signin_with_passkey( + auth_session=challenge.auth_session, + authn_response=authn_response, + dpop_key=dpop_key, + store_options={"request": request, "response": response}, +) +``` + +Reuse the **same** `dpop_key` for any My Account API calls made with the resulting token. See [examples/DPoP.md](DPoP.md) for the full picture. + +## Error Handling + +The three passkey methods raise `PasskeyError` (a subclass of `Auth0Error`). Input-validation failures raise `MissingRequiredArgumentError`; a required step-up raises `MfaRequiredError`. For most code, catching `Auth0Error` is enough. + +### Basic handling (recommended) + +```python +from auth0_server_python.error import Auth0Error + +try: + result = await server_client.signin_with_passkey( + auth_session=auth_session, + authn_response=authn_response, + store_options={"request": request, "response": response}, + ) +except Auth0Error as e: + return {"error": str(e)} +``` + +### Advanced handling (when actions differ by case) + +```python +from auth0_server_python.error import PasskeyError, MfaRequiredError, Auth0Error + +try: + result = await server_client.signin_with_passkey( + auth_session=auth_session, + authn_response=authn_response, + store_options={"request": request, "response": response}, + ) +except MfaRequiredError as e: + return start_mfa(e.mfa_token) # step-up required — continue with MfaClient +except PasskeyError as e: + return {"error": e.code, "detail": e.message} # branch on e.code, never on message text +except Auth0Error as e: + return {"error": str(e)} +``` + +### Common error codes (`PasskeyErrorCode`) + +- `passkey_challenge_error` — the signup/login challenge request failed +- `passkey_token_error` — token exchange failed (also used for a rejected DPoP downgrade) +- `invalid_response` — Auth0 returned a response that could not be parsed + +> [!NOTE] +> `auth_session` is a short-lived (typically ~5 min) Tier 1 credential. It is redacted in the SDK's model `repr()`, and you should never log or persist it. If the ceremony takes too long, re-request the challenge. + +## Additional Resources + +- [Managing passkeys via My Account API](MyAccountAuthenticationMethods.md) — enroll/list/delete a logged-in user's passkeys +- [DPoP](DPoP.md) — sender-constrained tokens +- [MFA](MFA.md) — handling `MfaRequiredError` +- [Auth0 Passkey documentation](https://auth0.com/docs/authenticate/database-connections/passkeys) +- [WebAuthn Level 2 (W3C)](https://www.w3.org/TR/webauthn-2/) diff --git a/src/auth0_server_python/auth_schemes/dpop_auth.py b/src/auth0_server_python/auth_schemes/dpop_auth.py index a0e0a19..30fadb2 100644 --- a/src/auth0_server_python/auth_schemes/dpop_auth.py +++ b/src/auth0_server_python/auth_schemes/dpop_auth.py @@ -12,13 +12,21 @@ def _base64url(data: bytes) -> str: return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") +def _validate_dpop_key(key: "jwk.JWK") -> dict: + """Return the public JWK after enforcing the EC P-256 requirement (ES256).""" + public_jwk = key.export_public(as_dict=True) + if public_jwk.get("kty") != "EC" or public_jwk.get("crv") != "P-256": + raise ValueError("DPoP key must be an EC P-256 key") + return public_jwk + + def make_dpop_proof_for_token_endpoint(key: "jwk.JWK", method: str, url: str, nonce: str = None) -> str: """ Build a DPoP proof JWT for use at the token endpoint (RFC 9449 §4.2). Unlike resource-server proofs, token-endpoint proofs do NOT include `ath` because no access token exists yet at issuance time. """ - public_jwk = key.export_public(as_dict=True) + public_jwk = _validate_dpop_key(key) htu = url.split("?")[0].split("#")[0] header = {"typ": "dpop+jwt", "alg": "ES256", "jwk": public_jwk} payload = { @@ -36,9 +44,7 @@ def make_dpop_proof_for_token_endpoint(key: "jwk.JWK", method: str, url: str, no class DPoPAuth(httpx.Auth): def __init__(self, token: str, key: "jwk.JWK") -> None: - public_jwk = key.export_public(as_dict=True) - if public_jwk.get("kty") != "EC" or public_jwk.get("crv") != "P-256": - raise ValueError("DPoP key must be an EC P-256 key") + public_jwk = _validate_dpop_key(key) try: token.encode("ascii") except UnicodeEncodeError: diff --git a/src/auth0_server_python/auth_server/mfa_client.py b/src/auth0_server_python/auth_server/mfa_client.py index bfe1e08..587dc51 100644 --- a/src/auth0_server_python/auth_server/mfa_client.py +++ b/src/auth0_server_python/auth_server/mfa_client.py @@ -3,6 +3,7 @@ Handles Multi-Factor Authentication operations against the Auth0 MFA API. """ +import json import time from typing import Any, Callable, Optional, Union @@ -188,6 +189,70 @@ def _resolve_encrypted_token( raise MfaTokenInvalidError() return token + @staticmethod + def _parse_error_body(response: httpx.Response) -> dict[str, Any]: + """ + Parse an error response body as JSON, falling back to a status-coded + stub when the body is not JSON (e.g. a gateway 502/504 HTML page). + + Never raises — the caller always gets a dict it can read error fields + from, so a non-JSON error surfaces the real HTTP status rather than a + JSON-parser exception folded into the message. + """ + try: + data = response.json() + except (json.JSONDecodeError, ValueError): + data = None + if not isinstance(data, dict): + return { + "error_description": f"Request failed with status {response.status_code}", + } + return data + + async def _raise_mfa_required( + self, + error_data: dict[str, Any], + *, + audience: str, + scope: str, + default_description: str, + store_pending: bool = False, + store_options: Optional[dict[str, Any]] = None, + ) -> None: + """ + Encrypt the server-issued mfa_token and raise MfaRequiredError. + + Shared by every site that handles an `mfa_required` response so the + encrypt-then-raise behaviour cannot drift between entry points. Returns + only when the response carries no mfa_token (caller then falls through + to its own typed error). + + store_pending controls whether the encrypted token is persisted to the + state store before raising. It is an explicit argument so the difference + between entry points is visible: the passkey grant persists it here, + while the refresh-token path relies on its get_access_token caller. + """ + raw_mfa_token = error_data.get("mfa_token") + if not raw_mfa_token: + return + mfa_requirements_data = error_data.get("mfa_requirements") + mfa_requirements = ( + MfaRequirements(**mfa_requirements_data) if mfa_requirements_data else None + ) + encrypted_token = self._encrypt_mfa_token( + raw_mfa_token=raw_mfa_token, + audience=audience, + scope=scope, + mfa_requirements=mfa_requirements, + ) + if store_pending: + await self.store_pending_mfa(encrypted_token, store_options) + raise MfaRequiredError( + error_data.get("error_description", default_description), + mfa_token=encrypted_token, + mfa_requirements=mfa_requirements, + ) + # ============================================================================ # MFA API OPERATIONS # ============================================================================ @@ -222,7 +287,7 @@ async def list_authenticators( ) if response.status_code != 200: - error_data = response.json() + error_data = self._parse_error_body(response) raise MfaListAuthenticatorsError( error_data.get("error_description", "Failed to list authenticators"), error_data @@ -235,8 +300,8 @@ async def list_authenticators( raise except Exception as e: raise MfaListAuthenticatorsError( - f"Unexpected error listing authenticators: {str(e)}" - ) + "Unexpected error listing authenticators" + ) from e async def enroll_authenticator( self, @@ -299,7 +364,7 @@ async def enroll_authenticator( ) if response.status_code != 200: - error_data = response.json() + error_data = self._parse_error_body(response) raise MfaEnrollmentError( error_data.get("error_description", "Failed to enroll authenticator"), error_data @@ -321,8 +386,8 @@ async def enroll_authenticator( raise except Exception as e: raise MfaEnrollmentError( - f"Unexpected error enrolling authenticator: {str(e)}" - ) + "Unexpected error enrolling authenticator" + ) from e async def challenge_authenticator( self, @@ -377,7 +442,7 @@ async def challenge_authenticator( ) if response.status_code != 200: - error_data = response.json() + error_data = self._parse_error_body(response) raise MfaChallengeError( error_data.get("error_description", "Failed to challenge authenticator"), error_data @@ -390,8 +455,8 @@ async def challenge_authenticator( raise except Exception as e: raise MfaChallengeError( - f"Unexpected error challenging authenticator: {str(e)}" - ) + "Unexpected error challenging authenticator" + ) from e async def verify( self, @@ -460,25 +525,16 @@ async def verify( ) if response.status_code != 200: - error_data = response.json() + error_data = self._parse_error_body(response) if error_data.get("error") == "mfa_required": - new_raw_token = error_data.get("mfa_token") - mfa_requirements_data = error_data.get("mfa_requirements") - mfa_requirements = None - if mfa_requirements_data: - mfa_requirements = MfaRequirements(**mfa_requirements_data) - - new_encrypted = self._encrypt_mfa_token( - raw_mfa_token=new_raw_token, + # Chained MFA: re-encrypt the new token with the original + # audience/scope from the incoming context before raising. + await self._raise_mfa_required( + error_data, audience=context.audience, scope=context.scope, - mfa_requirements=mfa_requirements, - ) - raise MfaRequiredError( - error_data.get("error_description", "Additional MFA factor required"), - mfa_token=new_encrypted, - mfa_requirements=mfa_requirements, + default_description="Additional MFA factor required", ) raise MfaVerifyError( @@ -504,8 +560,8 @@ async def verify( raise except Exception as e: raise MfaVerifyError( - f"Unexpected error during MFA verification: {str(e)}" - ) + "Unexpected error during MFA verification" + ) from e async def _persist_mfa_tokens( self, diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index c31d2e0..f194271 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -34,7 +34,6 @@ LoginWithCustomTokenExchangeResult, LogoutOptions, LogoutTokenClaims, - MfaRequirements, PasskeyAuthResponse, PasskeyLoginChallengeResponse, PasskeyLoginResult, @@ -1144,21 +1143,17 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, error_code = error_data.get("error", "refresh_token_error") if error_code == "mfa_required": - raw_mfa_token = error_data.get("mfa_token") - mfa_requirements_data = error_data.get("mfa_requirements") - mfa_requirements = MfaRequirements(**mfa_requirements_data) if mfa_requirements_data else None - if raw_mfa_token: - encrypted_token = self._mfa_client._encrypt_mfa_token( - raw_mfa_token=raw_mfa_token, - audience=audience or self.DEFAULT_AUDIENCE_STATE_KEY, - scope=merged_scope or "", - mfa_requirements=mfa_requirements, - ) - raise MfaRequiredError( - error_data.get("error_description", "MFA required"), - mfa_token=encrypted_token, - mfa_requirements=mfa_requirements, - ) + # Encrypt + raise via the shared helper so this matches + # the passkey and chained-verify sites. Returns only when + # no mfa_token is present (then falls through to ApiError). + # store_pending is left False here: the get_access_token + # caller persists the token in its own catch block. + await self._mfa_client._raise_mfa_required( + error_data, + audience=audience or self.DEFAULT_AUDIENCE_STATE_KEY, + scope=merged_scope or "", + default_description="MFA required", + ) raise ApiError( error_code, @@ -2407,7 +2402,10 @@ async def custom_token_exchange( # before trusting any claim from the token. if self._normalize_url(claims.get("iss", "")) == self._normalize_url(metadata.get("issuer")): token_response.act = claims.get("act") - except Exception: + except (jwt.InvalidTokenError, ValueError, KeyError): + # A genuinely absent/optional act claim or a benign decode + # gap leaves act None. Anything outside these types (an + # unexpected verify failure) surfaces rather than being masked. token_response.act = None return token_response @@ -2817,10 +2815,13 @@ async def signin_with_passkey( ) response = await client.post(token_endpoint, json=body, headers=headers) - # RFC 9449 §8.2 — nonce retry for DPoP token endpoint calls + # RFC 9449 — the authorization server signals a required nonce + # with HTTP 400 + error="use_dpop_nonce" + DPoP-Nonce. Accept + # 401 as well so the retry holds against servers that mirror the + # resource-server status. if ( dpop_key is not None - and response.status_code == 401 + and response.status_code in (400, 401) and response.headers.get("DPoP-Nonce") ): nonce = response.headers["DPoP-Nonce"] @@ -2839,22 +2840,17 @@ async def signin_with_passkey( ) error_code = error_data.get("error", PasskeyErrorCode.TOKEN_EXCHANGE_FAILED) if error_code == "mfa_required": - raw_mfa_token = error_data.get("mfa_token") - mfa_requirements_data = error_data.get("mfa_requirements") - mfa_requirements = MfaRequirements(**mfa_requirements_data) if mfa_requirements_data else None - if raw_mfa_token: - encrypted_token = self._mfa_client._encrypt_mfa_token( - raw_mfa_token=raw_mfa_token, - audience=audience or self.DEFAULT_AUDIENCE_STATE_KEY, - scope=scope or "", - mfa_requirements=mfa_requirements, - ) - await self._mfa_client.store_pending_mfa(encrypted_token, store_options) - raise MfaRequiredError( - "Multifactor authentication required", - mfa_token=encrypted_token, - mfa_requirements=mfa_requirements, - ) + # Passkey grant persists the pending token here so the + # challenge/verify routes can retrieve it server-side. + # Returns only when no mfa_token is present. + await self._mfa_client._raise_mfa_required( + error_data, + audience=audience or self.DEFAULT_AUDIENCE_STATE_KEY, + scope=scope or "", + default_description="Multifactor authentication required", + store_pending=True, + store_options=store_options, + ) raise PasskeyError( error_code, error_data.get("error_description", "Passkey token exchange failed"), diff --git a/src/auth0_server_python/auth_types/__init__.py b/src/auth0_server_python/auth_types/__init__.py index bdeb137..ae6c6ce 100644 --- a/src/auth0_server_python/auth_types/__init__.py +++ b/src/auth0_server_python/auth_types/__init__.py @@ -452,8 +452,9 @@ class ListConnectedAccountConnectionsResponse(BaseModel): # MFA Types # ============================================================================= -# Type aliases using Literal types -AuthenticatorType = Literal["otp", "oob", "recovery-code"] +# Type aliases using Literal types. Used to validate caller-supplied input. +# Server-controlled response fields use plain str instead, so a new factor or +# challenge type (e.g. a future webauthn second factor) does not fail closed. OobChannel = Literal["sms", "voice", "auth0", "email"] ChallengeType = Literal["otp", "oob"] @@ -461,8 +462,11 @@ class ListConnectedAccountConnectionsResponse(BaseModel): class AuthenticatorResponse(BaseModel): """Represents an MFA authenticator enrolled by a user.""" + model_config = ConfigDict(extra="allow") id: str - authenticator_type: AuthenticatorType + # Server-controlled value; kept as str so a new factor type (e.g. a future + # webauthn second factor) does not fail closed when Auth0 adds it. + authenticator_type: str active: bool name: Optional[str] = None oob_channel: Optional[OobChannel] = None @@ -545,7 +549,10 @@ class ChallengeOptions(BaseModel): class ChallengeResponse(BaseModel): """Response from initiating an MFA challenge.""" - challenge_type: ChallengeType + model_config = ConfigDict(extra="allow") + # Server-controlled value; kept as str so a new challenge type does not fail + # closed when Auth0 adds it. + challenge_type: str oob_code: Optional[str] = None binding_method: Optional[str] = None expires_in: Optional[int] = None @@ -591,6 +598,7 @@ class VerifyRecoveryCodeOptions(BaseModel): class MfaVerifyResponse(BaseModel): """Response from MFA verification.""" + model_config = ConfigDict(extra="allow") access_token: str token_type: str = "Bearer" expires_in: int diff --git a/src/auth0_server_python/tests/test_dpop_auth.py b/src/auth0_server_python/tests/test_dpop_auth.py index b6beb69..2cd120b 100644 --- a/src/auth0_server_python/tests/test_dpop_auth.py +++ b/src/auth0_server_python/tests/test_dpop_auth.py @@ -7,7 +7,11 @@ from jwcrypto import jwk from auth0_server_python.auth_schemes.bearer_auth import BearerAuth -from auth0_server_python.auth_schemes.dpop_auth import DPoPAuth, _base64url +from auth0_server_python.auth_schemes.dpop_auth import ( + DPoPAuth, + _base64url, + make_dpop_proof_for_token_endpoint, +) from auth0_server_python.auth_server.my_account_client import _make_auth @@ -115,6 +119,44 @@ def test_dpop_proof_uniqueness(ec_key): assert len(jtis) == 10 +def test_dpop_nonce_retry_on_resource_server(ec_key): + auth = DPoPAuth(token="test_token", key=ec_key) + request = httpx.Request("GET", "https://example.com/me/v1/factors") + flow = auth.auth_flow(request) + + first = next(flow) + _, first_payload = _decode_jwt_parts(first.headers["DPoP"]) + assert "nonce" not in first_payload + + nonce_response = httpx.Response( + status_code=401, + headers={"DPoP-Nonce": "rs-nonce-xyz"}, + request=first, + ) + retried = flow.send(nonce_response) + _, retried_payload = _decode_jwt_parts(retried.headers["DPoP"]) + assert retried_payload["nonce"] == "rs-nonce-xyz" + assert retried_payload["ath"] == first_payload["ath"] + assert retried_payload["jti"] != first_payload["jti"] + + +def test_dpop_no_retry_without_nonce_header(ec_key): + auth = DPoPAuth(token="test_token", key=ec_key) + request = httpx.Request("GET", "https://example.com/me/v1/factors") + flow = auth.auth_flow(request) + next(flow) + + ok_response = httpx.Response(status_code=200, request=request) + with pytest.raises(StopIteration): + flow.send(ok_response) + + +def test_token_endpoint_proof_rejects_non_ec_key(): + rsa_key = jwk.JWK.generate(kty="RSA", size=2048) + with pytest.raises(ValueError, match="EC P-256"): + make_dpop_proof_for_token_endpoint(rsa_key, "POST", "https://example.com/oauth/token") + + def test_dpop_repr_redacts_credentials(ec_key): auth = DPoPAuth(token="secret_access_token_value", key=ec_key) assert "secret_access_token_value" not in repr(auth) diff --git a/src/auth0_server_python/tests/test_mfa_client.py b/src/auth0_server_python/tests/test_mfa_client.py index 3533b57..ac7e9a4 100644 --- a/src/auth0_server_python/tests/test_mfa_client.py +++ b/src/auth0_server_python/tests/test_mfa_client.py @@ -2,6 +2,7 @@ Tests for MfaClient — MFA API operations. """ +import json from unittest.mock import AsyncMock, MagicMock import pytest @@ -261,6 +262,21 @@ async def test_list_authenticators_api_error(self, mocker): await client.list_authenticators({"mfa_token": _enc("bad_tok")}) assert "Invalid MFA token" in str(exc.value) + @pytest.mark.asyncio + async def test_list_authenticators_non_json_error_body(self, mocker): + """A non-JSON error body (e.g. a gateway 502 HTML page) surfaces the + HTTP status, not a JSON-parser exception folded into the message.""" + client = _make_client() + response = AsyncMock() + response.status_code = 502 + response.json = MagicMock(side_effect=json.JSONDecodeError("Expecting value", "", 0)) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + with pytest.raises(MfaListAuthenticatorsError) as exc: + await client.list_authenticators({"mfa_token": _enc()}) + assert "502" in str(exc.value) + assert "Expecting value" not in str(exc.value) + @pytest.mark.asyncio async def test_list_authenticators_unexpected_error(self, mocker): client = _make_client() @@ -268,7 +284,12 @@ async def test_list_authenticators_unexpected_error(self, mocker): with pytest.raises(MfaListAuthenticatorsError) as exc: await client.list_authenticators({"mfa_token": _enc()}) - assert "network down" in str(exc.value) + # Generic message — the underlying error is not leaked into the message... + assert "network down" not in str(exc.value) + assert "Unexpected error listing authenticators" in str(exc.value) + # ...but is preserved on the exception chain for internal debugging. + assert isinstance(exc.value.__cause__, Exception) + assert "network down" in str(exc.value.__cause__) # ── enroll_authenticator ───────────────────────────────────────────────────── @@ -731,7 +752,12 @@ async def test_verify_unexpected_error(self, mocker): with pytest.raises(MfaVerifyError) as exc: await client.verify({"mfa_token": _enc(), "otp": "123456"}) - assert "connection reset" in str(exc.value) + # Generic message — the underlying error is not leaked into the message... + assert "connection reset" not in str(exc.value) + assert "Unexpected error during MFA verification" in str(exc.value) + # ...but is preserved on the exception chain for internal debugging. + assert isinstance(exc.value.__cause__, Exception) + assert "connection reset" in str(exc.value.__cause__) @pytest.mark.asyncio async def test_verify_persist_updates_session(self, mocker): diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index ffe6c67..117c07f 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -1,3 +1,4 @@ +import base64 import json import time import unicodedata @@ -5,6 +6,7 @@ from urllib.parse import parse_qs, urlparse import pytest +from jwcrypto import jwk from auth0_server_python.auth_server.mfa_client import MfaClient from auth0_server_python.auth_server.my_account_client import MyAccountClient @@ -5978,9 +5980,6 @@ async def test_signin_with_passkey_missing_expires_at_calculates(mocker): @pytest.mark.asyncio async def test_signin_with_passkey_dpop_attaches_proof_header(mocker): - import base64 - import json as _json - from jwcrypto import jwk as jwk_module client = ServerClient( domain="auth0.local", client_id="test_client_id", @@ -6004,7 +6003,7 @@ async def test_signin_with_passkey_dpop_attaches_proof_header(mocker): mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE_DPOP) mock_post.return_value = mock_response - dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + dpop_key = jwk.JWK.generate(kty="EC", crv="P-256") await client.signin_with_passkey( auth_session="session_xyz", authn_response=_make_passkey_authn_response(), @@ -6018,7 +6017,7 @@ async def test_signin_with_passkey_dpop_attaches_proof_header(mocker): proof = kwargs["headers"]["DPoP"] payload_b64 = proof.split(".")[1] padding = 4 - len(payload_b64) % 4 - payload = _json.loads(base64.urlsafe_b64decode(payload_b64 + "=" * padding)) + payload = json.loads(base64.urlsafe_b64decode(payload_b64 + "=" * padding)) assert "ath" not in payload assert "jti" in payload assert payload["htm"] == "POST" @@ -6027,9 +6026,6 @@ async def test_signin_with_passkey_dpop_attaches_proof_header(mocker): @pytest.mark.asyncio async def test_signin_with_passkey_dpop_nonce_retry(mocker): - import base64 - import json as _json - from jwcrypto import jwk as jwk_module client = ServerClient( domain="auth0.local", client_id="test_client_id", @@ -6049,8 +6045,9 @@ async def test_signin_with_passkey_dpop_nonce_retry(mocker): }) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + # RFC 9449 §8.1 — the token endpoint signals a required nonce with HTTP 400. nonce_response = AsyncMock() - nonce_response.status_code = 401 + nonce_response.status_code = 400 nonce_response.headers = {"DPoP-Nonce": "server-nonce-abc"} nonce_response.json = MagicMock(return_value={"error": "use_dpop_nonce"}) @@ -6060,7 +6057,7 @@ async def test_signin_with_passkey_dpop_nonce_retry(mocker): mock_post.side_effect = [nonce_response, success_response] - dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + dpop_key = jwk.JWK.generate(kty="EC", crv="P-256") result = await client.signin_with_passkey( auth_session="session_xyz", authn_response=_make_passkey_authn_response(), @@ -6075,15 +6072,65 @@ async def test_signin_with_passkey_dpop_nonce_retry(mocker): proof = second_call_kwargs["headers"]["DPoP"] payload_b64 = proof.split(".")[1] padding = 4 - len(payload_b64) % 4 - payload = _json.loads(base64.urlsafe_b64decode(payload_b64 + "=" * padding)) + payload = json.loads(base64.urlsafe_b64decode(payload_b64 + "=" * padding)) assert payload["nonce"] == "server-nonce-abc" +@pytest.mark.asyncio +async def test_signin_with_passkey_dpop_nonce_retry_on_401(mocker): + """Token endpoint nonce retry must also hold when the server returns 401 + DPoP-Nonce.""" + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, + ) + mocker.patch.object(client, "_get_jwks_cached", return_value={}) + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "sub": "auth0|user123", "iss": "https://auth0.local/" + }) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + + nonce_response = AsyncMock() + nonce_response.status_code = 401 + nonce_response.headers = {"DPoP-Nonce": "server-nonce-401"} + nonce_response.json = MagicMock(return_value={"error": "use_dpop_nonce"}) + + success_response = AsyncMock() + success_response.status_code = 200 + success_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE_DPOP) + + mock_post.side_effect = [nonce_response, success_response] + + dpop_key = jwk.JWK.generate(kty="EC", crv="P-256") + result = await client.signin_with_passkey( + auth_session="session_xyz", + authn_response=_make_passkey_authn_response(), + dpop_key=dpop_key, + ) + + assert mock_post.await_count == 2 + assert result.state_data["token_sets"][0]["access_token"] == "at_passkey_dpop_123" + second_call_kwargs = mock_post.call_args_list[1][1] + proof = second_call_kwargs["headers"]["DPoP"] + payload_b64 = proof.split(".")[1] + padding = 4 - len(payload_b64) % 4 + payload = json.loads(base64.urlsafe_b64decode(payload_b64 + "=" * padding)) + assert payload["nonce"] == "server-nonce-401" + + @pytest.mark.asyncio async def test_signin_with_passkey_dpop_rejects_bearer_downgrade(mocker): """Server returning token_type=Bearer when DPoP was requested must raise PasskeyError.""" from auth0_server_python.error import PasskeyError - from jwcrypto import jwk as jwk_module + client = ServerClient( domain="auth0.local", client_id="test_client_id", @@ -6103,7 +6150,7 @@ async def test_signin_with_passkey_dpop_rejects_bearer_downgrade(mocker): mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) mock_post.return_value = mock_response - dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + dpop_key = jwk.JWK.generate(kty="EC", crv="P-256") with pytest.raises(PasskeyError) as exc: await client.signin_with_passkey( auth_session="session_xyz",