Skip to content

Commit e879863

Browse files
further enhanced the implementation
1 parent 09da7cc commit e879863

10 files changed

Lines changed: 100 additions & 152 deletions

File tree

airflow-core/src/airflow/api_fastapi/auth/tokens.py

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
"JWKS",
4747
"JWTGenerator",
4848
"JWTValidator",
49-
"TOKEN_SCOPE_QUEUE",
49+
"TOKEN_SCOPE_WORKLOAD",
5050
"generate_private_key",
5151
"get_sig_validation_args",
5252
"get_signing_args",
@@ -55,7 +55,7 @@
5555
"key_to_jwk_dict",
5656
]
5757

58-
TOKEN_SCOPE_QUEUE = "queue"
58+
TOKEN_SCOPE_WORKLOAD = "ExecuteTaskWorkload"
5959

6060

6161
class InvalidClaimError(ValueError):
@@ -437,15 +437,28 @@ def signing_arg(self) -> AllowedPrivateKeys | str:
437437
assert self._secret_key
438438
return self._secret_key
439439

440-
def generate(self, extras: dict[str, Any] | None = None, headers: dict[str, Any] | None = None) -> str:
441-
"""Generate a signed JWT for the subject."""
440+
def generate(
441+
self,
442+
extras: dict[str, Any] | None = None,
443+
headers: dict[str, Any] | None = None,
444+
expiry: int | None = None,
445+
) -> str:
446+
"""
447+
Generate a signed JWT.
448+
449+
Args:
450+
extras: Additional claims to include in the token. These are merged with default claims.
451+
headers: Additional headers to include in the JWT.
452+
expiry: Optional custom expiry time in seconds. If not provided, uses self.valid_for.
453+
"""
442454
now = int(datetime.now(tz=timezone.utc).timestamp())
455+
valid_for = expiry if expiry is not None else self.valid_for
443456
claims = {
444457
"jti": uuid.uuid4().hex,
445458
"iss": self.issuer,
446459
"aud": self.audience,
447460
"nbf": now,
448-
"exp": int(now + self.valid_for),
461+
"exp": int(now + valid_for),
449462
"iat": now,
450463
}
451464

@@ -461,38 +474,20 @@ def generate(self, extras: dict[str, Any] | None = None, headers: dict[str, Any]
461474
headers["kid"] = self.kid
462475
return jwt.encode(claims, self.signing_arg, algorithm=self.algorithm, headers=headers)
463476

464-
def generate_queue_token(self, sub: str) -> str:
477+
def generate_workload_token(self, sub: str) -> str:
465478
"""
466-
Generate a long-lived queue token for task workloads.
479+
Generate a long-lived workload token for task execution.
467480
468-
Queue tokens have a special 'scope' claim that restricts them to the /run endpoint only.
469-
They are valid for longer (default 24h) to survive queue wait times.
481+
Workload tokens have a special 'scope' claim that restricts them to the /run endpoint only.
482+
They are valid for longer (default 24h) to survive executor queue wait times.
470483
"""
471484
from airflow.configuration import conf
472485

473-
queue_expiry = conf.getint("execution_api", "jwt_queue_token_expiration_time", fallback=86400)
474-
now = int(datetime.now(tz=timezone.utc).timestamp())
475-
476-
claims = {
477-
"jti": uuid.uuid4().hex,
478-
"iss": self.issuer,
479-
"aud": self.audience,
480-
"nbf": now,
481-
"exp": now + queue_expiry,
482-
"iat": now,
483-
"sub": sub,
484-
"scope": TOKEN_SCOPE_QUEUE,
485-
}
486-
487-
if claims["iss"] is None:
488-
del claims["iss"]
489-
if claims["aud"] is None:
490-
del claims["aud"]
491-
492-
headers = {"alg": self.algorithm}
493-
if self._private_key:
494-
headers["kid"] = self.kid
495-
return jwt.encode(claims, self.signing_arg, algorithm=self.algorithm, headers=headers)
486+
workload_expiry = conf.getint("execution_api", "jwt_workload_token_expiration_time", fallback=86400)
487+
return self.generate(
488+
extras={"sub": sub, "scope": TOKEN_SCOPE_WORKLOAD},
489+
expiry=workload_expiry,
490+
)
496491

497492

498493
def generate_private_key(key_type: str = "RSA", key_size: int = 2048):

airflow-core/src/airflow/api_fastapi/execution_api/app.py

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -296,24 +296,16 @@ class InProcessExecutionAPI:
296296
@cached_property
297297
def app(self):
298298
if not self._app:
299-
from unittest.mock import AsyncMock, MagicMock
300-
301-
from airflow.api_fastapi.auth.tokens import JWTValidator
302299
from airflow.api_fastapi.common.dagbag import create_dag_bag
300+
from airflow.api_fastapi.execution_api.app import create_task_execution_api_app
303301
from airflow.api_fastapi.execution_api.deps import (
304-
DepContainer,
305302
JWTBearerDep,
306-
JWTBearerQueueDep,
307303
JWTBearerTIPathDep,
308304
)
309305
from airflow.api_fastapi.execution_api.routes.connections import has_connection_access
306+
from airflow.api_fastapi.execution_api.routes.task_instances import JWTBearerWorkloadDep
310307
from airflow.api_fastapi.execution_api.routes.variables import has_variable_access
311308
from airflow.api_fastapi.execution_api.routes.xcoms import has_xcom_access
312-
from airflow.configuration import conf
313-
314-
# Set a dummy JWT secret so the lifespan can create JWT services without failing.
315-
if not conf.get("api_auth", "jwt_secret", fallback=None):
316-
conf.set("api_auth", "jwt_secret", "in-process-test-secret-key")
317309

318310
self._app = create_task_execution_api_app()
319311

@@ -324,33 +316,11 @@ async def always_allow(): ...
324316

325317
self._app.dependency_overrides[JWTBearerDep.dependency] = always_allow
326318
self._app.dependency_overrides[JWTBearerTIPathDep.dependency] = always_allow
327-
self._app.dependency_overrides[JWTBearerQueueDep.dependency] = always_allow
319+
self._app.dependency_overrides[JWTBearerWorkloadDep.dependency] = always_allow
328320
self._app.dependency_overrides[has_connection_access] = always_allow
329321
self._app.dependency_overrides[has_variable_access] = always_allow
330322
self._app.dependency_overrides[has_xcom_access] = always_allow
331323

332-
# Create a mock container that provides mock JWT services
333-
mock_jwt_generator = MagicMock(spec=JWTGenerator)
334-
mock_jwt_generator.generate.return_value = "mock-execution-token"
335-
336-
mock_jwt_validator = AsyncMock(spec=JWTValidator)
337-
mock_jwt_validator.avalidated_claims.return_value = {"sub": "test", "exp": 9999999999}
338-
339-
class MockContainer:
340-
"""A mock svcs container that returns mock services."""
341-
342-
async def aget(self, svc_type):
343-
if svc_type is JWTGenerator:
344-
return mock_jwt_generator
345-
if svc_type is JWTValidator:
346-
return mock_jwt_validator
347-
raise ValueError(f"Unknown service type: {svc_type}")
348-
349-
async def mock_container_dep():
350-
return MockContainer()
351-
352-
self._app.dependency_overrides[DepContainer.dependency] = mock_container_dep
353-
354324
return self._app
355325

356326
@cached_property

airflow-core/src/airflow/api_fastapi/execution_api/deps.py

Lines changed: 36 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from fastapi.security import HTTPBearer
2727
from sqlalchemy import select
2828

29-
from airflow.api_fastapi.auth.tokens import TOKEN_SCOPE_QUEUE, JWTValidator
29+
from airflow.api_fastapi.auth.tokens import TOKEN_SCOPE_WORKLOAD, JWTValidator
3030
from airflow.api_fastapi.common.db.common import AsyncSessionDep
3131
from airflow.api_fastapi.execution_api.datamodels.token import TIToken
3232
from airflow.configuration import conf
@@ -46,15 +46,12 @@ async def _container(request: Request):
4646
DepContainer: svcs.Container = Depends(_container)
4747

4848

49-
class JWTBearer(HTTPBearer):
49+
class _BaseJWTBearer(HTTPBearer):
5050
"""
51-
A FastAPI security dependency that validates JWT tokens for the Execution API.
51+
Base class for JWT validation in the Execution API.
5252
53-
This validates tokens are signed and that the ``sub`` is a UUID. Queue-scoped tokens
54-
(with scope="queue") are rejected - they can only be used on the /run endpoint.
55-
56-
The dependency result will be a `TIToken` object containing the ``id`` UUID (from the ``sub``)
57-
and other validated claims.
53+
Validates JWT tokens are properly signed and extracts claims. Subclasses
54+
handle scope-specific validation.
5855
"""
5956

6057
def __init__(
@@ -88,14 +85,8 @@ async def __call__( # type: ignore[override]
8885
validators = self.required_claims
8986
claims = await validator.avalidated_claims(creds.credentials, validators)
9087

91-
# Reject queue-scoped tokens - they can only be used on /run endpoint
92-
# Only check if scope claim is present (allows backwards compatibility with tests)
93-
scope = claims.get("scope")
94-
if scope is not None and scope == TOKEN_SCOPE_QUEUE:
95-
raise HTTPException(
96-
status_code=status.HTTP_403_FORBIDDEN,
97-
detail="Queue tokens cannot access this endpoint. Use the token from /run response.",
98-
)
88+
# Let subclasses validate scope
89+
self._check_scope(claims)
9990

10091
return TIToken(id=claims["sub"], claims=claims)
10192
except HTTPException:
@@ -104,64 +95,51 @@ async def __call__( # type: ignore[override]
10495
log.warning("Failed to validate JWT", exc_info=True)
10596
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid auth token: {err}")
10697

98+
def _check_scope(self, claims: dict[str, Any]) -> None:
99+
"""Override in subclasses to validate scope. Raise HTTPException if invalid."""
100+
pass
107101

108-
class JWTBearerQueueScope(HTTPBearer):
109-
"""
110-
JWT auth dependency that ONLY accepts queue-scoped tokens.
111102

112-
Used exclusively by the /run endpoint. Queue tokens have scope="queue" and are
113-
long-lived to survive executor queue wait times. The /run endpoint validates
114-
the queue token and issues a short-lived execution token for subsequent API calls.
103+
class JWTBearer(_BaseJWTBearer):
115104
"""
105+
JWT validation that rejects workload-scoped tokens.
116106
117-
def __init__(self, path_param_name: str | None = None):
118-
super().__init__(auto_error=False)
119-
self.path_param_name = path_param_name
107+
Used for most Execution API endpoints. Workload-scoped tokens can only be used
108+
on the /run endpoint, which exchanges them for short-lived execution tokens.
109+
"""
120110

121-
async def __call__( # type: ignore[override]
122-
self,
123-
request: Request,
124-
services=DepContainer,
125-
) -> TIToken | None:
126-
creds = await super().__call__(request)
127-
if not creds:
128-
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing auth token")
111+
def _check_scope(self, claims: dict[str, Any]) -> None:
112+
if claims.get("scope") == TOKEN_SCOPE_WORKLOAD:
113+
raise HTTPException(
114+
status_code=status.HTTP_403_FORBIDDEN,
115+
detail="Workload tokens cannot access this endpoint. Use the token from /run response.",
116+
)
129117

130-
validator: JWTValidator = await services.aget(JWTValidator)
131118

132-
try:
133-
if self.path_param_name:
134-
id = request.path_params[self.path_param_name]
135-
validators: dict[str, Any] = {"sub": {"essential": True, "value": id}}
136-
else:
137-
validators = {}
138-
claims = await validator.avalidated_claims(creds.credentials, validators)
119+
class JWTBearerWorkloadScope(_BaseJWTBearer):
120+
"""
121+
JWT validation that ONLY accepts workload-scoped tokens.
139122
140-
# Only accept queue-scoped tokens (if scope claim is present)
141-
# This allows backwards compatibility with tests that don't set scope
142-
scope = claims.get("scope")
143-
if scope is not None and scope != TOKEN_SCOPE_QUEUE:
144-
raise HTTPException(
145-
status_code=status.HTTP_403_FORBIDDEN,
146-
detail="This endpoint requires a queue-scoped token",
147-
)
123+
Used exclusively by the /run endpoint. Workload tokens have scope="ExecuteTaskWorkload"
124+
and are long-lived to survive executor queue wait times. The /run endpoint validates
125+
the workload token and issues a short-lived execution token for subsequent API calls.
126+
"""
148127

149-
return TIToken(id=claims["sub"], claims=claims)
150-
except HTTPException:
151-
raise
152-
except Exception as err:
153-
log.warning("Failed to validate JWT", exc_info=True)
154-
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid auth token: {err}")
128+
def _check_scope(self, claims: dict[str, Any]) -> None:
129+
scope = claims.get("scope")
130+
# Reject if scope is explicitly set to something other than workload scope
131+
if scope is not None and scope != TOKEN_SCOPE_WORKLOAD:
132+
raise HTTPException(
133+
status_code=status.HTTP_403_FORBIDDEN,
134+
detail="This endpoint requires a workload-scoped token",
135+
)
155136

156137

157138
JWTBearerDep: TIToken = Depends(JWTBearer())
158139

159140
# This checks that the UUID in the url matches the one in the token for us.
160141
JWTBearerTIPathDep = Depends(JWTBearer(path_param_name="task_instance_id"))
161142

162-
# For /run endpoint only - accepts queue-scoped tokens and validates task_instance_id
163-
JWTBearerQueueDep = Depends(JWTBearerQueueScope(path_param_name="task_instance_id"))
164-
165143

166144
async def get_team_name_dep(session: AsyncSessionDep, token=JWTBearerDep) -> str | None:
167145
"""Return the team name associated to the task (if any)."""

airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import attrs
2929
import structlog
3030
from cadwyn import VersionedAPIRouter
31-
from fastapi import Body, HTTPException, Query, Response, status
31+
from fastapi import Body, Depends, HTTPException, Query, Response, status
3232
from pydantic import JsonValue
3333
from sqlalchemy import func, or_, tuple_, update
3434
from sqlalchemy.engine import CursorResult, Row
@@ -60,7 +60,7 @@
6060
TISuccessStatePayload,
6161
TITerminalStatePayload,
6262
)
63-
from airflow.api_fastapi.execution_api.deps import DepContainer, JWTBearerQueueDep, JWTBearerTIPathDep
63+
from airflow.api_fastapi.execution_api.deps import DepContainer, JWTBearerTIPathDep, JWTBearerWorkloadScope
6464
from airflow.exceptions import TaskNotFound
6565
from airflow.models.asset import AssetActive
6666
from airflow.models.dag import DagModel
@@ -90,6 +90,9 @@
9090

9191
log = structlog.get_logger(__name__)
9292

93+
# For /run endpoint only - accepts workload-scoped tokens and validates task_instance_id
94+
JWTBearerWorkloadDep = Depends(JWTBearerWorkloadScope(path_param_name="task_instance_id"))
95+
9396

9497
@router.patch(
9598
"/{task_instance_id}/run",
@@ -100,7 +103,7 @@
100103
HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid payload for the state transition"},
101104
},
102105
response_model_exclude_unset=True,
103-
dependencies=[JWTBearerQueueDep],
106+
dependencies=[JWTBearerWorkloadDep],
104107
)
105108
async def ti_run(
106109
task_instance_id: UUID,

airflow-core/src/airflow/config_templates/config.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1822,14 +1822,14 @@ execution_api:
18221822
type: integer
18231823
example: ~
18241824
default: "600"
1825-
jwt_queue_token_expiration_time:
1825+
jwt_workload_token_expiration_time:
18261826
description: |
1827-
Number in seconds until the queue JWT token expires. Queue tokens are long-lived tokens
1827+
Number in seconds until the workload JWT token expires. Workload tokens are long-lived tokens
18281828
sent with task workloads to executors (e.g., Celery). They can only be used to call
18291829
the /run endpoint, which then issues a short-lived execution token.
18301830
18311831
This should be set long enough to cover the maximum expected queue wait time.
1832-
version_added: 3.1.0
1832+
version_added: 3.1.7
18331833
type: integer
18341834
example: ~
18351835
default: "86400"

airflow-core/src/airflow/executors/workloads.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,12 @@ class BaseWorkload(BaseModel):
4646
@staticmethod
4747
def generate_token(sub_id: str, generator: JWTGenerator | None = None) -> str:
4848
"""
49-
Generate a queue-scoped token for this workload.
49+
Generate a workload-scoped token for this workload.
5050
51-
Queue tokens are long-lived and can only be used on the /run endpoint,
51+
Workload tokens are long-lived and can only be used on the /run endpoint,
5252
which exchanges them for short-lived execution tokens.
5353
"""
54-
return generator.generate_queue_token(sub_id) if generator else ""
54+
return generator.generate_workload_token(sub_id) if generator else ""
5555

5656

5757
class BundleInfo(BaseModel):

0 commit comments

Comments
 (0)