Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 17 additions & 31 deletions src/extension_shield/api/supabase_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,27 +56,15 @@ def _get_jwks_by_kid(jwks_url: str) -> Dict[str, Dict[str, Any]]:


def _get_expected_issuer(supabase_url: str) -> str:
"""Derive the expected JWT issuer from the Supabase URL."""
return f"{supabase_url.rstrip('/')}/auth/v1"


def verify_supabase_access_token(token: str) -> Optional[Dict[str, Any]]:
"""
Verify a Supabase access token using the project's JWKS.

Validates:
- Signature (RS256 via JWKS)
- Expiration (exp)
- Audience (aud) if configured
- Issuer (iss) must match {SUPABASE_URL}/auth/v1

Returns:
Decoded JWT payload dict if valid; otherwise None.
"""
settings = get_settings()
jwks_url = settings.supabase_jwks_url
aud = settings.supabase_jwt_aud
supabase_url = settings.supabase_url

if not jwks_url or not supabase_url:
return None

Expand All @@ -90,8 +78,8 @@ def verify_supabase_access_token(token: str) -> Optional[Dict[str, Any]]:

jwks_by_kid = _get_jwks_by_kid(jwks_url)
jwk_data = jwks_by_kid.get(str(kid))

if not jwk_data:
# Key rotation: refresh once before giving up.
global _JWKS_CACHE
_JWKS_CACHE = None
jwks_by_kid = _get_jwks_by_kid(jwks_url)
Expand All @@ -101,9 +89,14 @@ def verify_supabase_access_token(token: str) -> Optional[Dict[str, Any]]:

public_key = jwk.construct(jwk_data)

# Verify signature manually first (jose.jwt.decode also does this, but this
# gives clearer control and avoids surprises around key formats).
message, encoded_sig = token.rsplit(".", 1)
# ✅ FIX: validate token structure before splitting
parts = token.split(".")
if len(parts) != 3:
return None

message = ".".join(parts[:2])
encoded_sig = parts[2]

decoded_sig = base64url_decode(encoded_sig.encode("utf-8"))
if not public_key.verify(message.encode("utf-8"), decoded_sig):
return None
Expand All @@ -120,13 +113,10 @@ def verify_supabase_access_token(token: str) -> Optional[Dict[str, Any]]:
"verify_exp": True,
},
)

# Explicitly validate decoded claims
# Validate issuer

if payload.get("iss") != expected_issuer:
return None

# Validate audience (support both string and list)

if aud:
token_aud = payload.get("aud")
if isinstance(token_aud, list):
Expand All @@ -137,29 +127,25 @@ def verify_supabase_access_token(token: str) -> Optional[Dict[str, Any]]:
return None
else:
return None

# exp is validated by jwt.decode with verify_exp=True


return payload

except Exception:
return None


def get_current_user_id(request: Request) -> Optional[str]:
"""
Best-effort user identity extraction from Authorization header.
"""
authz = request.headers.get("authorization") or request.headers.get("Authorization")
if not authz:
return None

parts = authz.split()
if len(parts) != 2 or parts[0].lower() != "bearer":
return None

payload = verify_supabase_access_token(parts[1])
if not payload:
return None
sub = payload.get("sub")
return str(sub) if sub else None


sub = payload.get("sub")
return str(sub) if sub else None