Skip to content

Commit 4ef9dfe

Browse files
Add token refresh mechanism for Execution API (#59553)
Tasks waiting in Celery queue may have their JWT tokens expire before execution starts. This adds a token refresh endpoint that allows the supervisor to refresh expired tokens before task execution. Changes: - Add /token/refresh endpoint to Execution API - Add client-side token refresh logic in supervisor.py - Add tests for the new endpoint Fixes: #59553
1 parent 27047f9 commit 4ef9dfe

6 files changed

Lines changed: 681 additions & 0 deletions

File tree

airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,16 @@
2929
hitl,
3030
task_instances,
3131
task_reschedules,
32+
tokens,
3233
variables,
3334
xcoms,
3435
)
3536

3637
execution_api_router = APIRouter()
3738
execution_api_router.include_router(health.router, prefix="/health", tags=["Health"])
3839

40+
execution_api_router.include_router(tokens.router, prefix="/token", tags=["Token"])
41+
3942
# _Every_ single endpoint under here must be authenticated. Some do further checks on top of these
4043
authenticated_router = VersionedAPIRouter(dependencies=[JWTBearerDep]) # type: ignore[list-item]
4144

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""
19+
Token refresh endpoint for the Execution API.
20+
21+
This module provides an endpoint for workers to refresh expired JWT tokens.
22+
When a task waits in a queue (e.g., Celery) for longer than the token validity
23+
period, the worker can use this endpoint to obtain a fresh token before
24+
starting task execution.
25+
"""
26+
27+
from __future__ import annotations
28+
29+
import jwt
30+
import structlog
31+
from fastapi import APIRouter, Body, HTTPException, status
32+
from pydantic import BaseModel
33+
from sqlalchemy.exc import NoResultFound
34+
from sqlalchemy.sql import select
35+
36+
from airflow.api_fastapi.common.db.common import SessionDep
37+
from airflow.api_fastapi.execution_api.app import _jwt_generator, _jwt_validator
38+
from airflow.models.taskinstance import TaskInstance as TI
39+
from airflow.utils.state import TaskInstanceState
40+
41+
log = structlog.get_logger(logger_name=__name__)
42+
43+
router = APIRouter()
44+
45+
46+
class TokenRefreshRequest(BaseModel):
47+
"""Request body for token refresh."""
48+
49+
token: str
50+
"""The expired (or about to expire) JWT token to refresh."""
51+
52+
53+
class TokenRefreshResponse(BaseModel):
54+
"""Response body for token refresh."""
55+
56+
access_token: str
57+
"""The new JWT token."""
58+
59+
60+
@router.post(
61+
"/refresh",
62+
status_code=status.HTTP_200_OK,
63+
responses={
64+
status.HTTP_400_BAD_REQUEST: {"description": "Invalid token format or signature"},
65+
status.HTTP_403_FORBIDDEN: {"description": "Task is not in a valid state for token refresh"},
66+
status.HTTP_404_NOT_FOUND: {"description": "Task instance not found"},
67+
},
68+
)
69+
def refresh_token(
70+
session: SessionDep,
71+
body: TokenRefreshRequest = Body(...),
72+
) -> TokenRefreshResponse:
73+
"""
74+
Refresh an expired or expiring JWT token for task execution.
75+
76+
This endpoint allows workers to obtain a fresh token when the original token
77+
(embedded in the workload) has expired while waiting in a queue. The endpoint:
78+
79+
1. Validates the token signature (but ignores expiration)
80+
2. Extracts the task instance ID from the token
81+
3. Verifies the task is in QUEUED or RUNNING state
82+
4. Issues a fresh token
83+
84+
This is necessary for distributed executors like Celery where tasks may wait
85+
in queues longer than the token validity period.
86+
"""
87+
token = body.token
88+
validator = _jwt_validator()
89+
generator = _jwt_generator()
90+
91+
try:
92+
if validator.secret_key:
93+
key = validator.secret_key
94+
elif validator.jwks:
95+
header = jwt.get_unverified_header(token)
96+
kid = header.get("kid")
97+
if not kid:
98+
raise HTTPException(
99+
status_code=status.HTTP_400_BAD_REQUEST,
100+
detail="Token missing 'kid' header",
101+
)
102+
from asgiref.sync import async_to_sync
103+
104+
key = async_to_sync(validator.jwks.get_key)(kid)
105+
else:
106+
raise HTTPException(
107+
status_code=status.HTTP_400_BAD_REQUEST,
108+
detail="No validation key configured",
109+
)
110+
111+
claims = jwt.decode(
112+
token,
113+
key,
114+
audience=validator.audience,
115+
issuer=validator.issuer,
116+
options={
117+
"verify_exp": False,
118+
"require": ["sub"],
119+
},
120+
algorithms=validator.algorithm,
121+
leeway=validator.leeway,
122+
)
123+
except jwt.InvalidTokenError as e:
124+
log.warning("Invalid token for refresh", error=str(e))
125+
raise HTTPException(
126+
status_code=status.HTTP_400_BAD_REQUEST,
127+
detail=f"Invalid token: {e}",
128+
)
129+
130+
task_instance_id = claims.get("sub")
131+
if not task_instance_id:
132+
raise HTTPException(
133+
status_code=status.HTTP_400_BAD_REQUEST,
134+
detail="Token missing 'sub' claim (task instance ID)",
135+
)
136+
137+
try:
138+
ti_state = session.execute(select(TI.state).where(TI.id == task_instance_id)).scalar_one()
139+
except NoResultFound:
140+
log.warning("Task instance not found for token refresh", task_instance_id=task_instance_id)
141+
raise HTTPException(
142+
status_code=status.HTTP_404_NOT_FOUND,
143+
detail="Task instance not found",
144+
)
145+
146+
valid_states = {TaskInstanceState.QUEUED, TaskInstanceState.RUNNING}
147+
if ti_state not in valid_states:
148+
log.warning(
149+
"Task not in valid state for token refresh",
150+
task_instance_id=task_instance_id,
151+
state=ti_state,
152+
)
153+
raise HTTPException(
154+
status_code=status.HTTP_403_FORBIDDEN,
155+
detail=f"Task is in '{ti_state}' state, token refresh only allowed for QUEUED or RUNNING tasks",
156+
)
157+
158+
new_token = generator.generate({"sub": task_instance_id})
159+
160+
log.info("Token refreshed successfully", task_instance_id=task_instance_id)
161+
162+
return TokenRefreshResponse(access_token=new_token)

0 commit comments

Comments
 (0)