Skip to content

Commit ff8f59c

Browse files
address review feeback
1 parent d1101d4 commit ff8f59c

11 files changed

Lines changed: 108 additions & 54 deletions

File tree

airflow-core/src/airflow/api_fastapi/auth/tokens.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,10 @@ class JWTGenerator:
418418

419419
kid: str = attrs.field(default=attrs.Factory(_generate_kid, takes_self=True))
420420
valid_for: float
421+
workload_valid_for: float = attrs.field(
422+
factory=_conf_factory("execution_api", "jwt_workload_token_expiration_time", fallback="86400"),
423+
converter=float,
424+
)
421425
audience: str
422426
issuer: str | list[str] | None = attrs.field(
423427
factory=_conf_list_factory("api_auth", "jwt_issuer", first_only=True, fallback=None)
@@ -447,18 +451,6 @@ def signing_arg(self) -> AllowedPrivateKeys | str:
447451
assert self._secret_key
448452
return self._secret_key
449453

450-
def generate_workload_token(self, sub: str) -> str:
451-
"""Generate a long-lived workload token for executor queues."""
452-
from airflow.configuration import conf
453-
454-
workload_valid_for = conf.getint(
455-
"execution_api", "jwt_workload_token_expiration_time", fallback=86400
456-
)
457-
return self.generate(
458-
extras={"sub": sub, "scope": "workload"},
459-
valid_for=workload_valid_for,
460-
)
461-
462454
def generate(
463455
self,
464456
extras: dict[str, Any] | None = None,

airflow-core/src/airflow/api_fastapi/execution_api/app.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import annotations
1919

2020
import json
21+
import secrets
2122
import time
2223
from contextlib import AsyncExitStack
2324
from functools import cached_property
@@ -76,6 +77,7 @@ def _jwt_generator():
7677

7778
generator = JWTGenerator(
7879
valid_for=conf.getint("execution_api", "jwt_expiration_time"),
80+
workload_valid_for=conf.getint("execution_api", "jwt_workload_token_expiration_time", fallback=86400),
7981
audience=conf.get_mandatory_list_value("execution_api", "jwt_audience")[0],
8082
issuer=conf.get("api_auth", "jwt_issuer", fallback=None),
8183
# Since this one is used across components/server, there is no point trying to generate one, error
@@ -142,6 +144,11 @@ async def dispatch(self, request: Request, call_next):
142144
validator: JWTValidator = await services.aget(JWTValidator)
143145
claims = await validator.avalidated_claims(token, {})
144146

147+
# Workload tokens are long-lived and meant to survive queue
148+
# wait times so avoid refreshing them.
149+
if claims.get("scope") == "workload":
150+
return response
151+
145152
now = int(time.time())
146153
validity = conf.getint("execution_api", "jwt_expiration_time")
147154
refresh_when_less_than = max(int(validity * 0.20), 30)
@@ -341,8 +348,14 @@ async def always_allow(request: Request):
341348
# Cadwyn's versioned sub-apps don't inherit the main app's state,
342349
# so lookups raise ServiceNotFoundError. This registry provides
343350
# services needed by routes called during dag.test().
351+
#
352+
stub_generator = JWTGenerator(
353+
secret_key=secrets.token_urlsafe(32),
354+
audience="in-process",
355+
valid_for=3600,
356+
)
344357
registry = svcs.Registry()
345-
registry.register_factory(JWTGenerator, _jwt_generator)
358+
registry.register_value(JWTGenerator, stub_generator)
346359

347360
async def _in_process_container(request: Request):
348361
async with svcs.Container(registry) as cont:

airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@
102102
},
103103
response_model_exclude_unset=True,
104104
)
105-
async def ti_run(
105+
def ti_run(
106106
task_instance_id: UUID,
107107
ti_run_payload: Annotated[TIEnterRunningPayload, Body()],
108108
response: Response,
@@ -286,18 +286,18 @@ async def ti_run(
286286
if ti.next_method:
287287
context.next_method = ti.next_method
288288
context.next_kwargs = ti.next_kwargs
289-
290-
generator: JWTGenerator = await services.aget(JWTGenerator)
291-
execution_token = generator.generate(extras={"sub": str(task_instance_id)})
292-
response.headers["X-Execution-Token"] = execution_token
293-
294-
return context
295289
except SQLAlchemyError:
296290
log.exception("Error marking Task Instance state as running")
297291
raise HTTPException(
298292
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred"
299293
)
300294

295+
generator: JWTGenerator = services.get(JWTGenerator)
296+
execution_token = generator.generate(extras={"sub": str(task_instance_id)})
297+
response.headers["X-Execution-Token"] = execution_token
298+
299+
return context
300+
301301

302302
@ti_id_router.patch(
303303
"/{task_instance_id}/state",

airflow-core/src/airflow/config_templates/config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2074,7 +2074,7 @@ execution_api:
20742074
Seconds until workload JWT tokens expire. These long-lived tokens are sent
20752075
with task workloads to executors and can only call the /run endpoint.
20762076
Set long enough to cover maximum expected queue wait time.
2077-
version_added: ~
2077+
version_added: 3.2.0
20782078
type: integer
20792079
example: ~
20802080
default: "86400"

airflow-core/src/airflow/executors/workloads/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,12 @@ class BaseWorkloadSchema(BaseModel):
7474

7575
@staticmethod
7676
def generate_token(sub_id: str, generator: JWTGenerator | None = None) -> str:
77-
return generator.generate_workload_token(sub_id) if generator else ""
77+
if not generator:
78+
return ""
79+
return generator.generate(
80+
extras={"sub": sub_id, "scope": "workload"},
81+
valid_for=generator.workload_valid_for,
82+
)
7883

7984

8085
class BaseDagBundleWorkload(BaseWorkloadSchema, ABC):

airflow-core/tests/unit/api_fastapi/auth/test_tokens.py

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -160,48 +160,33 @@ def test_secret_key_with_configured_kid():
160160
assert header["kid"] == "my-custom-kid"
161161

162162

163-
def test_generate_workload_token():
164-
"""generate_workload_token() produces a token with scope 'workload' and 24h expiry."""
163+
def test_generate_with_custom_valid_for():
164+
"""generate() accepts a valid_for override."""
165165
generator = JWTGenerator(secret_key="test-secret", audience="test", valid_for=60)
166+
token = generator.generate(extras={"sub": "user"}, valid_for=3600)
167+
claims = jwt.decode(token, "test-secret", algorithms=["HS512"], audience="test")
168+
assert claims["exp"] - claims["iat"] == 3600
166169

167-
with patch.dict(
168-
"os.environ",
169-
{"AIRFLOW__EXECUTION_API__JWT_WORKLOAD_TOKEN_EXPIRATION_TIME": "86400"},
170-
):
171-
token = generator.generate_workload_token(sub="ti-123")
172170

171+
def test_generate_workload_scope_via_extras():
172+
"""generate() with scope='workload' in extras produces a workload-scoped token."""
173+
generator = JWTGenerator(secret_key="test-secret", audience="test", valid_for=60)
174+
175+
token = generator.generate(extras={"sub": "ti-123", "scope": "workload"}, valid_for=86400)
173176
claims = jwt.decode(token, "test-secret", algorithms=["HS512"], audience="test")
174177
assert claims["sub"] == "ti-123"
175178
assert claims["scope"] == "workload"
176-
# Workload token should have ~24h validity, not the generator's default 60s
177179
assert claims["exp"] - claims["iat"] == 86400
178180

179181

180-
def test_generate_with_custom_valid_for():
181-
"""generate() accepts a valid_for override."""
182-
generator = JWTGenerator(secret_key="test-secret", audience="test", valid_for=60)
183-
token = generator.generate(extras={"sub": "user"}, valid_for=3600)
184-
claims = jwt.decode(token, "test-secret", algorithms=["HS512"], audience="test")
185-
assert claims["exp"] - claims["iat"] == 3600
186-
187-
188-
def test_workload_token_vs_regular_token_scope():
189-
"""Regular tokens have no scope, workload tokens have scope 'workload'."""
182+
def test_regular_token_has_no_scope():
183+
"""Regular tokens without scope in extras have no scope claim."""
190184
generator = JWTGenerator(secret_key="test-secret", audience="test", valid_for=60)
191185

192186
regular = generator.generate(extras={"sub": "user"})
193187
regular_claims = jwt.decode(regular, "test-secret", algorithms=["HS512"], audience="test")
194188
assert "scope" not in regular_claims
195189

196-
with patch.dict(
197-
"os.environ",
198-
{"AIRFLOW__EXECUTION_API__JWT_WORKLOAD_TOKEN_EXPIRATION_TIME": "86400"},
199-
):
200-
workload = generator.generate_workload_token(sub="ti-123")
201-
202-
workload_claims = jwt.decode(workload, "test-secret", algorithms=["HS512"], audience="test")
203-
assert workload_claims["scope"] == "workload"
204-
205190

206191
@pytest.fixture
207192
def jwt_generator(ed25519_private_key: Ed25519PrivateKey):

airflow-core/tests/unit/api_fastapi/execution_api/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,10 @@ async def mock_jwt_bearer(request: Request):
5959
with TestClient(app, headers={"Authorization": "Bearer fake"}) as client:
6060
mock_generator = MagicMock(spec=JWTGenerator)
6161
mock_generator.generate.return_value = "mock-execution-token"
62-
mock_generator.generate_workload_token.return_value = "mock-workload-token"
6362
lifespan.registry.register_value(JWTGenerator, mock_generator)
6463

6564
yield client
6665

66+
lifespan.registry.close()
67+
6768
exec_app.dependency_overrides.pop(_jwt_bearer, None)

airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3256,6 +3256,44 @@ def test_invalid_scope_value_rejected(self, client, session, create_task_instanc
32563256
assert resp.status_code == 403
32573257
assert "Invalid token scope" in resp.json()["detail"]
32583258

3259+
def test_workload_scope_accepted_on_run_endpoint(
3260+
self, client, session, create_task_instance, time_machine
3261+
):
3262+
"""workload scoped tokens should be accepted on the /run endpoint."""
3263+
instant = timezone.parse("2024-10-31T12:00:00Z")
3264+
time_machine.move_to(instant, tick=False)
3265+
3266+
ti = create_task_instance(
3267+
task_id="test_workload_run",
3268+
state=State.QUEUED,
3269+
dagrun_state=DagRunState.RUNNING,
3270+
session=session,
3271+
start_date=instant,
3272+
dag_id=str(uuid4()),
3273+
)
3274+
session.commit()
3275+
3276+
validator = mock.AsyncMock(spec=JWTValidator)
3277+
validator.avalidated_claims.side_effect = lambda cred, validators: {
3278+
"sub": str(ti.id),
3279+
"scope": "workload",
3280+
"exp": 9999999999,
3281+
"iat": 1000000000,
3282+
}
3283+
lifespan.registry.register_value(JWTValidator, validator)
3284+
3285+
resp = client.patch(
3286+
f"/execution/task-instances/{ti.id}/run",
3287+
json={
3288+
"state": "running",
3289+
"hostname": "test-host",
3290+
"unixname": "test-user",
3291+
"pid": 100,
3292+
"start_date": "2024-10-31T12:00:00Z",
3293+
},
3294+
)
3295+
assert resp.status_code == 200
3296+
32593297
def test_no_scope_defaults_to_execution(self, client, session, create_task_instance):
32603298
"""Tokens without scope claim should default to 'execution'."""
32613299
ti = create_task_instance(task_id="test_no_scope", state=State.RUNNING)

airflow-core/tests/unit/executors/test_workloads.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@
2020
from pathlib import PurePosixPath
2121
from uuid import uuid4
2222

23+
import jwt
24+
25+
from airflow.api_fastapi.auth.tokens import JWTGenerator
2326
from airflow.executors import workloads
2427
from airflow.executors.workloads import TaskInstance, TaskInstanceDTO
25-
from airflow.executors.workloads.base import BundleInfo
28+
from airflow.executors.workloads.base import BaseWorkloadSchema, BundleInfo
2629
from airflow.executors.workloads.task import ExecuteTask
2730

2831

@@ -61,3 +64,21 @@ def test_token_excluded_from_workload_repr():
6164
assert fake_token not in workload_repr, f"JWT token leaked into repr! Found token in: {workload_repr}"
6265
# But token should still be accessible as an attribute
6366
assert workload.token == fake_token
67+
68+
69+
def test_generate_token_produces_workload_scope():
70+
"""generate_token should create a JWT with scope 'workload' and workload_valid_for expiry."""
71+
generator = JWTGenerator(
72+
secret_key="test-secret", audience="test", valid_for=60, workload_valid_for=86400
73+
)
74+
token = BaseWorkloadSchema.generate_token("ti-123", generator)
75+
76+
claims = jwt.decode(token, "test-secret", algorithms=["HS512"], audience="test")
77+
assert claims["sub"] == "ti-123"
78+
assert claims["scope"] == "workload"
79+
assert claims["exp"] - claims["iat"] == 86400
80+
81+
82+
def test_generate_token_without_generator():
83+
"""generate_token should return empty string when no generator is provided."""
84+
assert BaseWorkloadSchema.generate_token("ti-123", None) == ""

airflow-core/tests/unit/jobs/test_scheduler_job.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,6 @@ def set_instance_attrs(self) -> Generator:
302302
def mock_executors(self):
303303
mock_jwt_generator = MagicMock(spec=JWTGenerator)
304304
mock_jwt_generator.generate.return_value = "mock-token"
305-
mock_jwt_generator.generate_workload_token.return_value = "mock-workload-token"
306305

307306
default_executor = mock.MagicMock(name="DefaultExecutor", slots_available=8, slots_occupied=0)
308307
default_executor.name = ExecutorName(alias="default_exec", module_path="default.exec.module.path")

0 commit comments

Comments
 (0)