diff --git a/src/extension_shield/api/supabase_auth.py b/src/extension_shield/api/supabase_auth.py index f11eb21..c4eb241 100644 --- a/src/extension_shield/api/supabase_auth.py +++ b/src/extension_shield/api/supabase_auth.py @@ -56,27 +56,15 @@ def _get_jwks_by_kid(jwks_url: str) -> Dict[str, Dict[str, Any]]: def _get_expected_issuer(supabase_url: str) -> str: - """Derive the expected JWT issuer from the Supabase URL.""" return f"{supabase_url.rstrip('/')}/auth/v1" def verify_supabase_access_token(token: str) -> Optional[Dict[str, Any]]: - """ - Verify a Supabase access token using the project's JWKS. - - Validates: - - Signature (RS256 via JWKS) - - Expiration (exp) - - Audience (aud) if configured - - Issuer (iss) must match {SUPABASE_URL}/auth/v1 - - Returns: - Decoded JWT payload dict if valid; otherwise None. - """ settings = get_settings() jwks_url = settings.supabase_jwks_url aud = settings.supabase_jwt_aud supabase_url = settings.supabase_url + if not jwks_url or not supabase_url: return None @@ -90,8 +78,8 @@ def verify_supabase_access_token(token: str) -> Optional[Dict[str, Any]]: jwks_by_kid = _get_jwks_by_kid(jwks_url) jwk_data = jwks_by_kid.get(str(kid)) + if not jwk_data: - # Key rotation: refresh once before giving up. global _JWKS_CACHE _JWKS_CACHE = None jwks_by_kid = _get_jwks_by_kid(jwks_url) @@ -101,9 +89,14 @@ def verify_supabase_access_token(token: str) -> Optional[Dict[str, Any]]: public_key = jwk.construct(jwk_data) - # Verify signature manually first (jose.jwt.decode also does this, but this - # gives clearer control and avoids surprises around key formats). - message, encoded_sig = token.rsplit(".", 1) + # ✅ FIX: validate token structure before splitting + parts = token.split(".") + if len(parts) != 3: + return None + + message = ".".join(parts[:2]) + encoded_sig = parts[2] + decoded_sig = base64url_decode(encoded_sig.encode("utf-8")) if not public_key.verify(message.encode("utf-8"), decoded_sig): return None @@ -120,13 +113,10 @@ def verify_supabase_access_token(token: str) -> Optional[Dict[str, Any]]: "verify_exp": True, }, ) - - # Explicitly validate decoded claims - # Validate issuer + if payload.get("iss") != expected_issuer: return None - - # Validate audience (support both string and list) + if aud: token_aud = payload.get("aud") if isinstance(token_aud, list): @@ -137,21 +127,18 @@ def verify_supabase_access_token(token: str) -> Optional[Dict[str, Any]]: return None else: return None - - # exp is validated by jwt.decode with verify_exp=True - + return payload + except Exception: return None def get_current_user_id(request: Request) -> Optional[str]: - """ - Best-effort user identity extraction from Authorization header. - """ authz = request.headers.get("authorization") or request.headers.get("Authorization") if not authz: return None + parts = authz.split() if len(parts) != 2 or parts[0].lower() != "bearer": return None @@ -159,7 +146,6 @@ def get_current_user_id(request: Request) -> Optional[str]: payload = verify_supabase_access_token(parts[1]) if not payload: return None - sub = payload.get("sub") - return str(sub) if sub else None - + sub = payload.get("sub") + return str(sub) if sub else None \ No newline at end of file