Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion src/blueapi/service/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,20 @@
import time
import webbrowser
from abc import ABC, abstractmethod
from collections.abc import Mapping
from functools import cached_property
from http import HTTPStatus
from pathlib import Path
from typing import Any, cast
from typing import Annotated, Any, cast

import httpx
import jwt
import requests
from fastapi import Depends, HTTPException, Request
from fastapi.security.utils import get_authorization_scheme_param
from pydantic import TypeAdapter
from requests.auth import AuthBase
from starlette.status import HTTP_401_UNAUTHORIZED

from blueapi.config import OIDCConfig, ServiceAccount
from blueapi.service.model import Cache
Expand Down Expand Up @@ -272,3 +276,52 @@ def get_access_token(self):
def sync_auth_flow(self, request):
request.headers["Authorization"] = f"Bearer {self.get_access_token()}"
yield request


def unchecked_bearer_token(req: Request) -> str | None:
"""Get bearer token value from authorization header"""
auth = req.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(auth)
if scheme.casefold() != "bearer":
return None
return param.strip()


UncheckedBearerToken = Annotated[str | None, Depends(unchecked_bearer_token)]


def build_access_token_check(config: OIDCConfig):
"""
Create a function to validate the bearer token of requests

The returned function should be used via fastAPI's 'Depends' mechanism to
ensure users are authenticated
"""
jwkclient = jwt.PyJWKClient(config.jwks_uri)

def validate_bearer_token(request: Request, token: UncheckedBearerToken):
"""Check that a bearer token is valid and inject into request state"""
if not token:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)

signing_key = jwkclient.get_signing_key_from_jwt(token)
decoded: dict[str, Any] = jwt.decode(
token,
signing_key.key,
algorithms=config.id_token_signing_alg_values_supported,
verify=True,
audience=config.client_audience,
issuer=config.issuer,
)
request.state.decoded_access_token = decoded

return validate_bearer_token


def access_token(request: Request) -> Mapping[str, Any] | None:
"""Get the decoded and verified access token of the user making the request"""
return getattr(request.state, "decoded_access_token", None)
31 changes: 3 additions & 28 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from fastapi.datastructures import Address
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse, StreamingResponse
from fastapi.security import OAuth2AuthorizationCodeBearer
from observability_utils.tracing import (
add_span_attributes,
get_tracer,
Expand All @@ -37,6 +36,7 @@
from blueapi import __version__
from blueapi.config import ApplicationConfig, OIDCConfig, Tag
from blueapi.service import interface
from blueapi.service.authentication import build_access_token_check
from blueapi.worker import TrackableTask, WorkerState
from blueapi.worker.event import TaskStatusEnum

Expand All @@ -61,6 +61,7 @@
RUNNER: WorkerDispatcher | None = None

LOGGER = logging.getLogger(__name__)
TRACER = get_tracer("interface")


def _runner() -> WorkerDispatcher:
Expand Down Expand Up @@ -117,7 +118,7 @@ def get_app(config: ApplicationConfig):
)
dependencies = []
if config.oidc:
dependencies.append(Depends(decode_access_token(config.oidc)))
dependencies.append(Depends(build_access_token_check(config.oidc)))
app.swagger_ui_init_oauth = {
"clientId": "NOT_SUPPORTED",
}
Expand All @@ -140,32 +141,6 @@ def get_app(config: ApplicationConfig):
return app


def decode_access_token(config: OIDCConfig):
jwkclient = jwt.PyJWKClient(config.jwks_uri)
oauth_scheme = OAuth2AuthorizationCodeBearer(
authorizationUrl=config.authorization_endpoint,
tokenUrl=config.token_endpoint,
refreshUrl=config.token_endpoint,
)

def inner(request: Request, access_token: str = Depends(oauth_scheme)):
signing_key = jwkclient.get_signing_key_from_jwt(access_token)
decoded: dict[str, Any] = jwt.decode(
access_token,
signing_key.key,
algorithms=config.id_token_signing_alg_values_supported,
verify=True,
audience=config.client_audience,
issuer=config.issuer,
)
request.state.decoded_access_token = decoded

return inner


TRACER = get_tracer("interface")


async def on_key_error_404(_: Request, __: Exception):
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
Expand Down
60 changes: 55 additions & 5 deletions tests/unit_tests/service/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@
import pytest
import responses
import respx
from fastapi import HTTPException
from pydantic import SecretStr
from starlette.status import HTTP_200_OK, HTTP_403_FORBIDDEN

from blueapi.config import OIDCConfig, ServiceAccount
from blueapi.service import main
from blueapi.service import authentication
from blueapi.service.authentication import (
SessionCacheManager,
SessionManager,
TiledAuth,
access_token,
build_access_token_check,
unchecked_bearer_token,
)


Expand Down Expand Up @@ -124,18 +128,18 @@ def test_poll_for_token_timeout(
def test_server_raises_exception_for_invalid_token(
oidc_config: OIDCConfig, mock_authn_server: responses.RequestsMock
):
inner = main.decode_access_token(oidc_config)
inner = authentication.build_access_token_check(oidc_config)
with pytest.raises(jwt.PyJWTError):
inner(Mock(), access_token="Invalid Token")
inner(Mock(), token="Invalid Token")


def test_processes_valid_token(
oidc_config: OIDCConfig,
mock_authn_server: responses.RequestsMock,
valid_token_with_jwt,
):
inner = main.decode_access_token(oidc_config)
inner(Mock(), access_token=valid_token_with_jwt["access_token"])
inner = authentication.build_access_token_check(oidc_config)
inner(Mock(), token=valid_token_with_jwt["access_token"])


def test_session_cache_manager_returns_writable_file_path(tmp_path):
Expand Down Expand Up @@ -182,3 +186,49 @@ def test_tiled_auth_sync_auth_flow():
result = next(flow)

assert result.headers["Authorization"] == f"Bearer {access_token}"


@pytest.mark.parametrize(
"header,token",
[
(None, None),
("ApiKey foobar", None),
("Bearer foobar", "foobar"),
("Bearer with_whitespace ", "with_whitespace"),
("Bearerfoobar", None),
],
)
def test_unchecked_bearer_token(header: str | None, token: str | None):
req = Mock()
req.headers.get.side_effect = lambda key: header if key == "Authorization" else None

assert unchecked_bearer_token(req) == token


def test_access_token():
req = Mock()
req.state.decoded_access_token = {"foo": "bar"}

assert access_token(req) == {"foo": "bar"}


def test_access_token_without_token():
req = Mock()
del req.state.decoded_access_token

assert access_token(req) is None


@patch("blueapi.service.authentication.jwt")
def test_build_access_token(mock_jwt: Mock):
# Return None when building client to ensure no field/method access
mock_jwt.PyJWKClient.return_value = None
oidc_config = Mock()
req = Mock()

validate_fn = build_access_token_check(oidc_config)

with pytest.raises(HTTPException, match="401"):
validate_fn(req, token=None)

mock_jwt.decode.assert_not_called()
Loading