Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/blueapi/service/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,3 +325,12 @@ def validate_bearer_token(request: Request, token: UncheckedBearerToken):
def access_token(request: Request) -> Mapping[str, Any] | None:
"""Get the decoded and verified access token of the user making the request"""
return getattr(request.state, "decoded_access_token", None)


def fedid(
access_token: Annotated[Mapping[str, Any] | None, Depends(access_token)],
) -> str | None:
return access_token.get("fedid") if access_token else None


Fedid = Annotated[str | None, Depends(fedid)]
15 changes: 4 additions & 11 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import urllib.parse
from collections.abc import Awaitable, Callable
from contextlib import asynccontextmanager
from typing import Annotated, Any
from typing import Annotated

import jwt
from fastapi import (
Expand Down Expand Up @@ -36,7 +36,7 @@
from blueapi import __version__
from blueapi.config import ApplicationConfig, OIDCConfig, Tag
from blueapi.service import interface
from blueapi.service.authentication import build_access_token_check
from blueapi.service.authentication import Fedid, build_access_token_check
from blueapi.worker import TrackableTask, WorkerState
from blueapi.worker.event import TaskStatusEnum

Expand Down Expand Up @@ -267,18 +267,11 @@ def submit_task(
response: Response,
task_request: Annotated[TaskRequest, Body(..., examples=[example_task_request])],
runner: Annotated[WorkerDispatcher, Depends(_runner)],
user: Fedid,
) -> TaskResponse:
"""Submit a task to the worker."""
try:
# Extract user from jwt if using OIDC (if jwt exists)
access_token: dict[str, Any] | None = getattr(
request.state, "decoded_access_token", None
)
if access_token:
user: str = access_token.get("fedid", "Unknown")
else:
user = "Unknown"

user = user or "Unknown"
task_id: str = runner.run(interface.submit_task, task_request, {"user": user})
response.headers["Location"] = f"{request.url}/{task_id}"
return TaskResponse(task_id=task_id)
Expand Down
Loading