From 3b5b8a24e130c7f968b51ab140d2b54438da6664 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 21 May 2026 13:46:19 +0100 Subject: [PATCH] refactor: Move auth extractors into authentication module And split the auth check into two so that other methods can access the raw bearer token if required. --- src/blueapi/service/authentication.py | 55 ++++++++++++++++- src/blueapi/service/main.py | 31 +--------- .../unit_tests/service/test_authentication.py | 60 +++++++++++++++++-- 3 files changed, 112 insertions(+), 34 deletions(-) diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py index b107f7b2b2..944dccf5d3 100644 --- a/src/blueapi/service/authentication.py +++ b/src/blueapi/service/authentication.py @@ -6,16 +6,20 @@ import time import webbrowser from abc import ABC, abstractmethod +from collections.abc import Mapping from functools import cached_property from http import HTTPStatus from pathlib import Path -from typing import Any, cast +from typing import Annotated, Any, cast import httpx import jwt import requests +from fastapi import Depends, HTTPException, Request +from fastapi.security.utils import get_authorization_scheme_param from pydantic import TypeAdapter from requests.auth import AuthBase +from starlette.status import HTTP_401_UNAUTHORIZED from blueapi.config import OIDCConfig, ServiceAccount from blueapi.service.model import Cache @@ -272,3 +276,52 @@ def get_access_token(self): def sync_auth_flow(self, request): request.headers["Authorization"] = f"Bearer {self.get_access_token()}" yield request + + +def unchecked_bearer_token(req: Request) -> str | None: + """Get bearer token value from authorization header""" + auth = req.headers.get("Authorization") + scheme, param = get_authorization_scheme_param(auth) + if scheme.casefold() != "bearer": + return None + return param.strip() + + +UncheckedBearerToken = Annotated[str | None, Depends(unchecked_bearer_token)] + + +def build_access_token_check(config: OIDCConfig): + """ + Create a function to validate the bearer token of requests + + The returned function should be used via fastAPI's 'Depends' mechanism to + ensure users are authenticated + """ + jwkclient = jwt.PyJWKClient(config.jwks_uri) + + def validate_bearer_token(request: Request, token: UncheckedBearerToken): + """Check that a bearer token is valid and inject into request state""" + if not token: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) + + signing_key = jwkclient.get_signing_key_from_jwt(token) + decoded: dict[str, Any] = jwt.decode( + token, + signing_key.key, + algorithms=config.id_token_signing_alg_values_supported, + verify=True, + audience=config.client_audience, + issuer=config.issuer, + ) + request.state.decoded_access_token = decoded + + return validate_bearer_token + + +def access_token(request: Request) -> Mapping[str, Any] | None: + """Get the decoded and verified access token of the user making the request""" + return getattr(request.state, "decoded_access_token", None) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index a53c46885a..5dae9462a2 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -19,7 +19,6 @@ from fastapi.datastructures import Address from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse, StreamingResponse -from fastapi.security import OAuth2AuthorizationCodeBearer from observability_utils.tracing import ( add_span_attributes, get_tracer, @@ -37,6 +36,7 @@ from blueapi import __version__ from blueapi.config import ApplicationConfig, OIDCConfig, Tag from blueapi.service import interface +from blueapi.service.authentication import build_access_token_check from blueapi.worker import TrackableTask, WorkerState from blueapi.worker.event import TaskStatusEnum @@ -61,6 +61,7 @@ RUNNER: WorkerDispatcher | None = None LOGGER = logging.getLogger(__name__) +TRACER = get_tracer("interface") def _runner() -> WorkerDispatcher: @@ -117,7 +118,7 @@ def get_app(config: ApplicationConfig): ) dependencies = [] if config.oidc: - dependencies.append(Depends(decode_access_token(config.oidc))) + dependencies.append(Depends(build_access_token_check(config.oidc))) app.swagger_ui_init_oauth = { "clientId": "NOT_SUPPORTED", } @@ -140,32 +141,6 @@ def get_app(config: ApplicationConfig): return app -def decode_access_token(config: OIDCConfig): - jwkclient = jwt.PyJWKClient(config.jwks_uri) - oauth_scheme = OAuth2AuthorizationCodeBearer( - authorizationUrl=config.authorization_endpoint, - tokenUrl=config.token_endpoint, - refreshUrl=config.token_endpoint, - ) - - def inner(request: Request, access_token: str = Depends(oauth_scheme)): - signing_key = jwkclient.get_signing_key_from_jwt(access_token) - decoded: dict[str, Any] = jwt.decode( - access_token, - signing_key.key, - algorithms=config.id_token_signing_alg_values_supported, - verify=True, - audience=config.client_audience, - issuer=config.issuer, - ) - request.state.decoded_access_token = decoded - - return inner - - -TRACER = get_tracer("interface") - - async def on_key_error_404(_: Request, __: Exception): return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, diff --git a/tests/unit_tests/service/test_authentication.py b/tests/unit_tests/service/test_authentication.py index 88227706be..01bc426e20 100644 --- a/tests/unit_tests/service/test_authentication.py +++ b/tests/unit_tests/service/test_authentication.py @@ -8,15 +8,19 @@ import pytest import responses import respx +from fastapi import HTTPException from pydantic import SecretStr from starlette.status import HTTP_200_OK, HTTP_403_FORBIDDEN from blueapi.config import OIDCConfig, ServiceAccount -from blueapi.service import main +from blueapi.service import authentication from blueapi.service.authentication import ( SessionCacheManager, SessionManager, TiledAuth, + access_token, + build_access_token_check, + unchecked_bearer_token, ) @@ -124,9 +128,9 @@ def test_poll_for_token_timeout( def test_server_raises_exception_for_invalid_token( oidc_config: OIDCConfig, mock_authn_server: responses.RequestsMock ): - inner = main.decode_access_token(oidc_config) + inner = authentication.build_access_token_check(oidc_config) with pytest.raises(jwt.PyJWTError): - inner(Mock(), access_token="Invalid Token") + inner(Mock(), token="Invalid Token") def test_processes_valid_token( @@ -134,8 +138,8 @@ def test_processes_valid_token( mock_authn_server: responses.RequestsMock, valid_token_with_jwt, ): - inner = main.decode_access_token(oidc_config) - inner(Mock(), access_token=valid_token_with_jwt["access_token"]) + inner = authentication.build_access_token_check(oidc_config) + inner(Mock(), token=valid_token_with_jwt["access_token"]) def test_session_cache_manager_returns_writable_file_path(tmp_path): @@ -182,3 +186,49 @@ def test_tiled_auth_sync_auth_flow(): result = next(flow) assert result.headers["Authorization"] == f"Bearer {access_token}" + + +@pytest.mark.parametrize( + "header,token", + [ + (None, None), + ("ApiKey foobar", None), + ("Bearer foobar", "foobar"), + ("Bearer with_whitespace ", "with_whitespace"), + ("Bearerfoobar", None), + ], +) +def test_unchecked_bearer_token(header: str | None, token: str | None): + req = Mock() + req.headers.get.side_effect = lambda key: header if key == "Authorization" else None + + assert unchecked_bearer_token(req) == token + + +def test_access_token(): + req = Mock() + req.state.decoded_access_token = {"foo": "bar"} + + assert access_token(req) == {"foo": "bar"} + + +def test_access_token_without_token(): + req = Mock() + del req.state.decoded_access_token + + assert access_token(req) is None + + +@patch("blueapi.service.authentication.jwt") +def test_build_access_token(mock_jwt: Mock): + # Return None when building client to ensure no field/method access + mock_jwt.PyJWKClient.return_value = None + oidc_config = Mock() + req = Mock() + + validate_fn = build_access_token_check(oidc_config) + + with pytest.raises(HTTPException, match="401"): + validate_fn(req, token=None) + + mock_jwt.decode.assert_not_called()