From a5a68cd9fbba17c7f413cb90e1126890304b5001 Mon Sep 17 00:00:00 2001 From: Anish Date: Sat, 14 Mar 2026 04:37:56 -0500 Subject: [PATCH 1/7] layed out two mechanism based on new security scope arichitecture --- .../src/airflow/api_fastapi/auth/tokens.py | 22 +++++++++- .../execution_api/routes/task_instances.py | 14 +++++- .../src/airflow/config_templates/config.yml | 9 ++++ .../src/airflow/executors/workloads/base.py | 2 +- .../unit/api_fastapi/auth/test_tokens.py | 43 +++++++++++++++++++ .../api_fastapi/execution_api/conftest.py | 10 +++++ .../versions/head/test_task_instances.py | 30 +++++++++++++ .../tests_common/test_utils/mock_executor.py | 1 + task-sdk/src/airflow/sdk/api/client.py | 5 ++- task-sdk/tests/task_sdk/api/test_client.py | 32 ++++++++++++++ 10 files changed, 162 insertions(+), 6 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/auth/tokens.py b/airflow-core/src/airflow/api_fastapi/auth/tokens.py index 3375853a29a5d..0fac48d9f4ea0 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/tokens.py +++ b/airflow-core/src/airflow/api_fastapi/auth/tokens.py @@ -447,15 +447,33 @@ def signing_arg(self) -> AllowedPrivateKeys | str: assert self._secret_key return self._secret_key - def generate(self, extras: dict[str, Any] | None = None, headers: dict[str, Any] | None = None) -> str: + def generate_workload_token(self, sub: str) -> str: + """Generate a long-lived workload token for executor queues.""" + from airflow.configuration import conf + + workload_valid_for = conf.getint( + "execution_api", "jwt_workload_token_expiration_time", fallback=86400 + ) + return self.generate( + extras={"sub": sub, "scope": "workload"}, + valid_for=workload_valid_for, + ) + + def generate( + self, + extras: dict[str, Any] | None = None, + headers: dict[str, Any] | None = None, + valid_for: float | None = None, + ) -> str: """Generate a signed JWT for the subject.""" now = int(datetime.now(tz=timezone.utc).timestamp()) + effective_valid_for = valid_for if valid_for is not None else self.valid_for claims = { "jti": uuid.uuid4().hex, "iss": self.issuer, "aud": self.audience, "nbf": now, - "exp": int(now + self.valid_for), + "exp": int(now + effective_valid_for), "iat": now, } diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 5f5073c916b68..b925486da06d6 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -28,7 +28,7 @@ import attrs import structlog from cadwyn import VersionedAPIRouter -from fastapi import Body, HTTPException, Query, Security, status +from fastapi import Body, HTTPException, Query, Response, Security, status from opentelemetry import trace from opentelemetry.trace import StatusCode from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator @@ -42,6 +42,7 @@ from airflow._shared.observability.traces import override_ids from airflow._shared.timezones import timezone +from airflow.api_fastapi.auth.tokens import JWTGenerator from airflow.api_fastapi.common.dagbag import DagBagDep, get_latest_version_of_dag from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.types import UtcDateTime @@ -63,6 +64,7 @@ TISuccessStatePayload, TITerminalStatePayload, ) +from airflow.api_fastapi.execution_api.deps import DepContainer from airflow.api_fastapi.execution_api.security import ExecutionAPIRoute, require_auth from airflow.exceptions import TaskNotFound from airflow.models.asset import AssetActive @@ -97,6 +99,7 @@ @ti_id_router.patch( "/{task_instance_id}/run", status_code=status.HTTP_200_OK, + dependencies=[Security(require_auth, scopes=["token:execution", "token:workload"])], responses={ status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"}, status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"}, @@ -104,11 +107,13 @@ }, response_model_exclude_unset=True, ) -def ti_run( +async def ti_run( task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, Body()], + response: Response, session: SessionDep, dag_bag: DagBagDep, + services=DepContainer, ) -> TIRunContext: """ Run a TaskInstance. @@ -287,6 +292,11 @@ def ti_run( context.next_method = ti.next_method context.next_kwargs = ti.next_kwargs + # Issue a short-lived execution token for subsequent API calls + generator: JWTGenerator = await services.aget(JWTGenerator) + execution_token = generator.generate(extras={"sub": str(task_instance_id)}) + response.headers["X-Execution-Token"] = execution_token + return context except SQLAlchemyError: log.exception("Error marking Task Instance state as running") diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 55313198d5b39..6a1e9d2883f75 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -2079,6 +2079,15 @@ execution_api: type: integer example: ~ default: "600" + jwt_workload_token_expiration_time: + description: | + Seconds until workload JWT tokens expire. These long-lived tokens are sent + with task workloads to executors and can only call the /run endpoint. + Set long enough to cover maximum expected queue wait time. + version_added: ~ + type: integer + example: ~ + default: "86400" jwt_audience: version_added: 3.0.0 description: | diff --git a/airflow-core/src/airflow/executors/workloads/base.py b/airflow-core/src/airflow/executors/workloads/base.py index 97cf16ebaf64d..69bf85819edae 100644 --- a/airflow-core/src/airflow/executors/workloads/base.py +++ b/airflow-core/src/airflow/executors/workloads/base.py @@ -74,7 +74,7 @@ class BaseWorkloadSchema(BaseModel): @staticmethod def generate_token(sub_id: str, generator: JWTGenerator | None = None) -> str: - return generator.generate({"sub": sub_id}) if generator else "" + return generator.generate_workload_token(sub_id) if generator else "" class BaseDagBundleWorkload(BaseWorkloadSchema, ABC): diff --git a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py index 6b848f723a004..a9b9612ad43b7 100644 --- a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py +++ b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py @@ -160,6 +160,49 @@ def test_secret_key_with_configured_kid(): assert header["kid"] == "my-custom-kid" +def test_generate_workload_token(): + """generate_workload_token() produces a token with scope 'workload' and 24h expiry.""" + generator = JWTGenerator(secret_key="test-secret", audience="test", valid_for=60) + + with patch.dict( + "os.environ", + {"AIRFLOW__EXECUTION_API__JWT_WORKLOAD_TOKEN_EXPIRATION_TIME": "86400"}, + ): + token = generator.generate_workload_token(sub="ti-123") + + claims = jwt.decode(token, "test-secret", algorithms=["HS512"], audience="test") + assert claims["sub"] == "ti-123" + assert claims["scope"] == "workload" + # Workload token should have ~24h validity, not the generator's default 60s + assert claims["exp"] - claims["iat"] == 86400 + + +def test_generate_with_custom_valid_for(): + """generate() accepts a valid_for override.""" + generator = JWTGenerator(secret_key="test-secret", audience="test", valid_for=60) + token = generator.generate(extras={"sub": "user"}, valid_for=3600) + claims = jwt.decode(token, "test-secret", algorithms=["HS512"], audience="test") + assert claims["exp"] - claims["iat"] == 3600 + + +def test_workload_token_vs_regular_token_scope(): + """Regular tokens have no scope, workload tokens have scope 'workload'.""" + generator = JWTGenerator(secret_key="test-secret", audience="test", valid_for=60) + + regular = generator.generate(extras={"sub": "user"}) + regular_claims = jwt.decode(regular, "test-secret", algorithms=["HS512"], audience="test") + assert "scope" not in regular_claims + + with patch.dict( + "os.environ", + {"AIRFLOW__EXECUTION_API__JWT_WORKLOAD_TOKEN_EXPIRATION_TIME": "86400"}, + ): + workload = generator.generate_workload_token(sub="ti-123") + + workload_claims = jwt.decode(workload, "test-secret", algorithms=["HS512"], audience="test") + assert workload_claims["scope"] == "workload" + + @pytest.fixture def jwt_generator(ed25519_private_key: Ed25519PrivateKey): key = ed25519_private_key diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py index 78bd0548df9d2..2c0f50e31c5b4 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py @@ -16,11 +16,15 @@ # under the License. from __future__ import annotations +from unittest.mock import MagicMock + import pytest from fastapi import FastAPI, Request from fastapi.testclient import TestClient from airflow.api_fastapi.app import cached_app +from airflow.api_fastapi.auth.tokens import JWTGenerator +from airflow.api_fastapi.execution_api.app import lifespan from airflow.api_fastapi.execution_api.datamodels.token import TIToken from airflow.api_fastapi.execution_api.security import _jwt_bearer @@ -53,6 +57,12 @@ async def mock_jwt_bearer(request: Request): exec_app.dependency_overrides[_jwt_bearer] = mock_jwt_bearer with TestClient(app, headers={"Authorization": "Bearer fake"}) as client: + # Register mock JWTGenerator after lifespan starts so endpoints can issue tokens + mock_generator = MagicMock(spec=JWTGenerator) + mock_generator.generate.return_value = "mock-execution-token" + mock_generator.generate_workload_token.return_value = "mock-workload-token" + lifespan.registry.register_value(JWTGenerator, mock_generator) + yield client exec_app.dependency_overrides.pop(_jwt_bearer, None) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 7f766ede71e4d..96759a3ec15c0 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -254,6 +254,36 @@ def test_ti_run_state_to_running( ) assert response.status_code == 409 + def test_ti_run_returns_execution_token(self, client, session, create_task_instance, time_machine): + """PATCH /run should return an X-Execution-Token header on success.""" + instant = timezone.parse("2024-10-31T12:00:00Z") + time_machine.move_to(instant, tick=False) + + ti = create_task_instance( + task_id="test_exec_token", + state=State.QUEUED, + dagrun_state=DagRunState.RUNNING, + session=session, + start_date=instant, + dag_id=str(uuid4()), + ) + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/run", + json={ + "state": "running", + "hostname": "test-host", + "unixname": "test-user", + "pid": 100, + "start_date": "2024-10-31T12:00:00Z", + }, + ) + + assert response.status_code == 200 + assert "X-Execution-Token" in response.headers + assert response.headers["X-Execution-Token"] == "mock-execution-token" + def test_dynamic_task_mapping_with_parse_time_value(self, client, dag_maker): """Test that dynamic task mapping works correctly with parse-time values.""" with dag_maker("test_dynamic_task_mapping_with_parse_time_value", serialized=True): diff --git a/devel-common/src/tests_common/test_utils/mock_executor.py b/devel-common/src/tests_common/test_utils/mock_executor.py index 4e95ed3a4eea7..cf53e9ae56d33 100644 --- a/devel-common/src/tests_common/test_utils/mock_executor.py +++ b/devel-common/src/tests_common/test_utils/mock_executor.py @@ -59,6 +59,7 @@ def __init__(self, do_update=True, *args, **kwargs): # Mock JWT generator for token generation mock_jwt_generator = MagicMock() mock_jwt_generator.generate.return_value = "mock-token" + mock_jwt_generator.generate_workload_token.return_value = "mock-workload-token" self.jwt_generator = mock_jwt_generator diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 19e691281f73b..71e3d59575281 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -963,7 +963,10 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, * ) def _update_auth(self, response: httpx.Response): - if new_token := response.headers.get("Refreshed-API-Token"): + if new_token := response.headers.get("X-Execution-Token"): + log.debug("Received execution token, swapping auth") + self.auth = BearerAuth(new_token) + elif new_token := response.headers.get("Refreshed-API-Token"): log.debug("Execution API issued us a refreshed Task token") self.auth = BearerAuth(new_token) diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index 0df8839c55f30..883844da6958a 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -250,6 +250,38 @@ def test_token_renewal(self): assert response.status_code == 200 assert response.request.headers["Authorization"] == "Bearer abc" + def test_execution_token_swap(self): + """X-Execution-Token header should replace the auth token.""" + responses: list[httpx.Response] = [ + httpx.Response(200, json={"ok": "1"}, headers={"X-Execution-Token": "exec-token-123"}), + httpx.Response(200, json={"ok": "2"}), + ] + client = make_client_w_responses(responses) + response = client.get("/") + assert response.status_code == 200 + + # Auth should have been swapped to the execution token + assert client.auth is not None + assert client.auth.token == "exec-token-123" + + # Next request should use the new token + response = client.get("/") + assert response.status_code == 200 + assert response.request.headers["Authorization"] == "Bearer exec-token-123" + + def test_execution_token_takes_priority_over_refreshed_token(self): + """When both headers present, X-Execution-Token should take priority.""" + responses: list[httpx.Response] = [ + httpx.Response( + 200, + json={"ok": "1"}, + headers={"X-Execution-Token": "exec-tok", "Refreshed-API-Token": "refresh-tok"}, + ), + ] + client = make_client_w_responses(responses) + client.get("/") + assert client.auth.token == "exec-tok" + @pytest.mark.parametrize( ("status_code", "description"), [ From 8e82295033394f41206a3d6ea4230a5f9ffc2ccf Mon Sep 17 00:00:00 2001 From: Anish Date: Sat, 14 Mar 2026 04:44:51 -0500 Subject: [PATCH 2/7] clean ups --- .../airflow/api_fastapi/execution_api/routes/task_instances.py | 1 - airflow-core/tests/unit/api_fastapi/execution_api/conftest.py | 1 - task-sdk/tests/task_sdk/api/test_client.py | 2 -- 3 files changed, 4 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index b925486da06d6..efaca9ba3fab4 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -292,7 +292,6 @@ async def ti_run( context.next_method = ti.next_method context.next_kwargs = ti.next_kwargs - # Issue a short-lived execution token for subsequent API calls generator: JWTGenerator = await services.aget(JWTGenerator) execution_token = generator.generate(extras={"sub": str(task_instance_id)}) response.headers["X-Execution-Token"] = execution_token diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py index 2c0f50e31c5b4..8ce394c7d6163 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py @@ -57,7 +57,6 @@ async def mock_jwt_bearer(request: Request): exec_app.dependency_overrides[_jwt_bearer] = mock_jwt_bearer with TestClient(app, headers={"Authorization": "Bearer fake"}) as client: - # Register mock JWTGenerator after lifespan starts so endpoints can issue tokens mock_generator = MagicMock(spec=JWTGenerator) mock_generator.generate.return_value = "mock-execution-token" mock_generator.generate_workload_token.return_value = "mock-workload-token" diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index 883844da6958a..a6bcf89d4d92d 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -260,11 +260,9 @@ def test_execution_token_swap(self): response = client.get("/") assert response.status_code == 200 - # Auth should have been swapped to the execution token assert client.auth is not None assert client.auth.token == "exec-token-123" - # Next request should use the new token response = client.get("/") assert response.status_code == 200 assert response.request.headers["Authorization"] == "Bearer exec-token-123" From 54448b1ecf2d3103747a89776cea82dd57d3a6ce Mon Sep 17 00:00:00 2001 From: Anish Date: Sun, 15 Mar 2026 20:03:15 -0500 Subject: [PATCH 3/7] fixing test --- .../airflow/api_fastapi/execution_api/app.py | 20 +++++++++++++++++++ .../tests/unit/jobs/test_scheduler_job.py | 1 + 2 files changed, 21 insertions(+) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index c7a9593c3c82f..4d050c47763e3 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -311,9 +311,13 @@ class InProcessExecutionAPI: @cached_property def app(self): if not self._app: + import svcs + + from airflow.api_fastapi.auth.tokens import JWTGenerator, get_signing_args from airflow.api_fastapi.common.dagbag import create_dag_bag from airflow.api_fastapi.execution_api.app import create_task_execution_api_app from airflow.api_fastapi.execution_api.datamodels.token import TIToken + from airflow.api_fastapi.execution_api.deps import _container from airflow.api_fastapi.execution_api.routes.connections import has_connection_access from airflow.api_fastapi.execution_api.routes.variables import has_variable_access from airflow.api_fastapi.execution_api.routes.xcoms import has_xcom_access @@ -332,10 +336,26 @@ async def always_allow(request: Request): ) return TIToken(id=ti_id, claims={"scope": "execution"}) + # Override DepContainer (the svcs service locator) for in-process use. + # Cadwyn's versioned sub-apps don't share the main app's + # state.svcs_registry, so the default _container dependency fails. + # Any service resolved via DepContainer in routes called during + # dag.test() must be registered here. + registry = svcs.Registry() + registry.register_value( + JWTGenerator, + JWTGenerator(valid_for=600, audience="urn:airflow.apache.org:task", **get_signing_args()), + ) + + async def _test_container(request: Request): + async with svcs.Container(registry) as cont: + yield cont + self._app.dependency_overrides[_jwt_bearer] = always_allow self._app.dependency_overrides[has_connection_access] = always_allow self._app.dependency_overrides[has_variable_access] = always_allow self._app.dependency_overrides[has_xcom_access] = always_allow + self._app.dependency_overrides[_container] = _test_container return self._app diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index fd77644bc44b5..2842ab4025905 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -302,6 +302,7 @@ def set_instance_attrs(self) -> Generator: def mock_executors(self): mock_jwt_generator = MagicMock(spec=JWTGenerator) mock_jwt_generator.generate.return_value = "mock-token" + mock_jwt_generator.generate_workload_token.return_value = "mock-workload-token" default_executor = mock.MagicMock(name="DefaultExecutor", slots_available=8, slots_occupied=0) default_executor.name = ExecutorName(alias="default_exec", module_path="default.exec.module.path") From 6d5397c766ebbcb82ef9911623509027a9ecec26 Mon Sep 17 00:00:00 2001 From: Anish Date: Sun, 15 Mar 2026 21:25:45 -0500 Subject: [PATCH 4/7] refactor on cleanups --- airflow-core/src/airflow/api_fastapi/execution_api/app.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index 4d050c47763e3..562bed24798d7 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -313,7 +313,7 @@ def app(self): if not self._app: import svcs - from airflow.api_fastapi.auth.tokens import JWTGenerator, get_signing_args + from airflow.api_fastapi.auth.tokens import JWTGenerator from airflow.api_fastapi.common.dagbag import create_dag_bag from airflow.api_fastapi.execution_api.app import create_task_execution_api_app from airflow.api_fastapi.execution_api.datamodels.token import TIToken @@ -342,10 +342,7 @@ async def always_allow(request: Request): # Any service resolved via DepContainer in routes called during # dag.test() must be registered here. registry = svcs.Registry() - registry.register_value( - JWTGenerator, - JWTGenerator(valid_for=600, audience="urn:airflow.apache.org:task", **get_signing_args()), - ) + registry.register_factory(JWTGenerator, _jwt_generator) async def _test_container(request: Request): async with svcs.Container(registry) as cont: From 9c657d280a9466eb7c4490010c74de8b91cea4dd Mon Sep 17 00:00:00 2001 From: Anish Date: Sun, 15 Mar 2026 21:36:09 -0500 Subject: [PATCH 5/7] some more precise clarification on invariant overrides --- .../src/airflow/api_fastapi/execution_api/app.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index 562bed24798d7..0264853661c40 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -336,15 +336,15 @@ async def always_allow(request: Request): ) return TIToken(id=ti_id, claims={"scope": "execution"}) - # Override DepContainer (the svcs service locator) for in-process use. - # Cadwyn's versioned sub-apps don't share the main app's - # state.svcs_registry, so the default _container dependency fails. - # Any service resolved via DepContainer in routes called during - # dag.test() must be registered here. + # Override _container (the svcs service locator behind DepContainer). + # The default _container reads request.app.state.svcs_registry, but + # Cadwyn's versioned sub-apps don't inherit the main app's state, + # so lookups raise ServiceNotFoundError. This registry provides + # services needed by routes called during dag.test(). registry = svcs.Registry() registry.register_factory(JWTGenerator, _jwt_generator) - async def _test_container(request: Request): + async def _in_process_container(request: Request): async with svcs.Container(registry) as cont: yield cont @@ -352,7 +352,7 @@ async def _test_container(request: Request): self._app.dependency_overrides[has_connection_access] = always_allow self._app.dependency_overrides[has_variable_access] = always_allow self._app.dependency_overrides[has_xcom_access] = always_allow - self._app.dependency_overrides[_container] = _test_container + self._app.dependency_overrides[_container] = _in_process_container return self._app From d07fcffa10246d5d9da92ef617fe16dd632c1f01 Mon Sep 17 00:00:00 2001 From: Anish Date: Wed, 18 Mar 2026 22:37:38 -0500 Subject: [PATCH 6/7] address review feeback --- .../src/airflow/api_fastapi/auth/tokens.py | 16 ++------ .../airflow/api_fastapi/execution_api/app.py | 15 ++++++- .../execution_api/routes/task_instances.py | 14 +++---- .../src/airflow/config_templates/config.yml | 2 +- .../src/airflow/executors/workloads/base.py | 7 +++- .../unit/api_fastapi/auth/test_tokens.py | 39 ++++++------------- .../api_fastapi/execution_api/conftest.py | 3 +- .../versions/head/test_task_instances.py | 38 ++++++++++++++++++ .../tests/unit/executors/test_workloads.py | 23 ++++++++++- .../tests/unit/jobs/test_scheduler_job.py | 1 - .../tests_common/test_utils/mock_executor.py | 4 +- 11 files changed, 108 insertions(+), 54 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/auth/tokens.py b/airflow-core/src/airflow/api_fastapi/auth/tokens.py index 0fac48d9f4ea0..116f52d5e4f57 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/tokens.py +++ b/airflow-core/src/airflow/api_fastapi/auth/tokens.py @@ -418,6 +418,10 @@ class JWTGenerator: kid: str = attrs.field(default=attrs.Factory(_generate_kid, takes_self=True)) valid_for: float + workload_valid_for: float = attrs.field( + factory=_conf_factory("execution_api", "jwt_workload_token_expiration_time", fallback="86400"), + converter=float, + ) audience: str issuer: str | list[str] | None = attrs.field( factory=_conf_list_factory("api_auth", "jwt_issuer", first_only=True, fallback=None) @@ -447,18 +451,6 @@ def signing_arg(self) -> AllowedPrivateKeys | str: assert self._secret_key return self._secret_key - def generate_workload_token(self, sub: str) -> str: - """Generate a long-lived workload token for executor queues.""" - from airflow.configuration import conf - - workload_valid_for = conf.getint( - "execution_api", "jwt_workload_token_expiration_time", fallback=86400 - ) - return self.generate( - extras={"sub": sub, "scope": "workload"}, - valid_for=workload_valid_for, - ) - def generate( self, extras: dict[str, Any] | None = None, diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index 0264853661c40..8675c70b77c2e 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -18,6 +18,7 @@ from __future__ import annotations import json +import secrets import time from contextlib import AsyncExitStack from functools import cached_property @@ -76,6 +77,7 @@ def _jwt_generator(): generator = JWTGenerator( valid_for=conf.getint("execution_api", "jwt_expiration_time"), + workload_valid_for=conf.getint("execution_api", "jwt_workload_token_expiration_time", fallback=86400), audience=conf.get_mandatory_list_value("execution_api", "jwt_audience")[0], issuer=conf.get("api_auth", "jwt_issuer", fallback=None), # Since this one is used across components/server, there is no point trying to generate one, error @@ -142,6 +144,11 @@ async def dispatch(self, request: Request, call_next): validator: JWTValidator = await services.aget(JWTValidator) claims = await validator.avalidated_claims(token, {}) + # Workload tokens are long-lived and meant to survive queue + # wait times so avoid refreshing them. + if claims.get("scope") == "workload": + return response + now = int(time.time()) validity = conf.getint("execution_api", "jwt_expiration_time") refresh_when_less_than = max(int(validity * 0.20), 30) @@ -341,8 +348,14 @@ async def always_allow(request: Request): # Cadwyn's versioned sub-apps don't inherit the main app's state, # so lookups raise ServiceNotFoundError. This registry provides # services needed by routes called during dag.test(). + # + stub_generator = JWTGenerator( + secret_key=secrets.token_urlsafe(32), + audience="in-process", + valid_for=3600, + ) registry = svcs.Registry() - registry.register_factory(JWTGenerator, _jwt_generator) + registry.register_value(JWTGenerator, stub_generator) async def _in_process_container(request: Request): async with svcs.Container(registry) as cont: diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index efaca9ba3fab4..b7ffd521f9878 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -107,7 +107,7 @@ }, response_model_exclude_unset=True, ) -async def ti_run( +def ti_run( task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, Body()], response: Response, @@ -291,18 +291,18 @@ async def ti_run( if ti.next_method: context.next_method = ti.next_method context.next_kwargs = ti.next_kwargs - - generator: JWTGenerator = await services.aget(JWTGenerator) - execution_token = generator.generate(extras={"sub": str(task_instance_id)}) - response.headers["X-Execution-Token"] = execution_token - - return context except SQLAlchemyError: log.exception("Error marking Task Instance state as running") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred" ) + generator: JWTGenerator = services.get(JWTGenerator) + execution_token = generator.generate(extras={"sub": str(task_instance_id)}) + response.headers["X-Execution-Token"] = execution_token + + return context + @ti_id_router.patch( "/{task_instance_id}/state", diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 6a1e9d2883f75..f9467bed5b43b 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -2084,7 +2084,7 @@ execution_api: Seconds until workload JWT tokens expire. These long-lived tokens are sent with task workloads to executors and can only call the /run endpoint. Set long enough to cover maximum expected queue wait time. - version_added: ~ + version_added: 3.2.0 type: integer example: ~ default: "86400" diff --git a/airflow-core/src/airflow/executors/workloads/base.py b/airflow-core/src/airflow/executors/workloads/base.py index 69bf85819edae..12a5574da3e40 100644 --- a/airflow-core/src/airflow/executors/workloads/base.py +++ b/airflow-core/src/airflow/executors/workloads/base.py @@ -74,7 +74,12 @@ class BaseWorkloadSchema(BaseModel): @staticmethod def generate_token(sub_id: str, generator: JWTGenerator | None = None) -> str: - return generator.generate_workload_token(sub_id) if generator else "" + if not generator: + return "" + return generator.generate( + extras={"sub": sub_id, "scope": "workload"}, + valid_for=generator.workload_valid_for, + ) class BaseDagBundleWorkload(BaseWorkloadSchema, ABC): diff --git a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py index a9b9612ad43b7..89677cc1ef00b 100644 --- a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py +++ b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py @@ -160,48 +160,33 @@ def test_secret_key_with_configured_kid(): assert header["kid"] == "my-custom-kid" -def test_generate_workload_token(): - """generate_workload_token() produces a token with scope 'workload' and 24h expiry.""" +def test_generate_with_custom_valid_for(): + """generate() accepts a valid_for override.""" generator = JWTGenerator(secret_key="test-secret", audience="test", valid_for=60) + token = generator.generate(extras={"sub": "user"}, valid_for=3600) + claims = jwt.decode(token, "test-secret", algorithms=["HS512"], audience="test") + assert claims["exp"] - claims["iat"] == 3600 - with patch.dict( - "os.environ", - {"AIRFLOW__EXECUTION_API__JWT_WORKLOAD_TOKEN_EXPIRATION_TIME": "86400"}, - ): - token = generator.generate_workload_token(sub="ti-123") +def test_generate_workload_scope_via_extras(): + """generate() with scope='workload' in extras produces a workload-scoped token.""" + generator = JWTGenerator(secret_key="test-secret", audience="test", valid_for=60) + + token = generator.generate(extras={"sub": "ti-123", "scope": "workload"}, valid_for=86400) claims = jwt.decode(token, "test-secret", algorithms=["HS512"], audience="test") assert claims["sub"] == "ti-123" assert claims["scope"] == "workload" - # Workload token should have ~24h validity, not the generator's default 60s assert claims["exp"] - claims["iat"] == 86400 -def test_generate_with_custom_valid_for(): - """generate() accepts a valid_for override.""" - generator = JWTGenerator(secret_key="test-secret", audience="test", valid_for=60) - token = generator.generate(extras={"sub": "user"}, valid_for=3600) - claims = jwt.decode(token, "test-secret", algorithms=["HS512"], audience="test") - assert claims["exp"] - claims["iat"] == 3600 - - -def test_workload_token_vs_regular_token_scope(): - """Regular tokens have no scope, workload tokens have scope 'workload'.""" +def test_regular_token_has_no_scope(): + """Regular tokens without scope in extras have no scope claim.""" generator = JWTGenerator(secret_key="test-secret", audience="test", valid_for=60) regular = generator.generate(extras={"sub": "user"}) regular_claims = jwt.decode(regular, "test-secret", algorithms=["HS512"], audience="test") assert "scope" not in regular_claims - with patch.dict( - "os.environ", - {"AIRFLOW__EXECUTION_API__JWT_WORKLOAD_TOKEN_EXPIRATION_TIME": "86400"}, - ): - workload = generator.generate_workload_token(sub="ti-123") - - workload_claims = jwt.decode(workload, "test-secret", algorithms=["HS512"], audience="test") - assert workload_claims["scope"] == "workload" - @pytest.fixture def jwt_generator(ed25519_private_key: Ed25519PrivateKey): diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py index 8ce394c7d6163..f7dacc40ac9cf 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py @@ -59,9 +59,10 @@ async def mock_jwt_bearer(request: Request): with TestClient(app, headers={"Authorization": "Bearer fake"}) as client: mock_generator = MagicMock(spec=JWTGenerator) mock_generator.generate.return_value = "mock-execution-token" - mock_generator.generate_workload_token.return_value = "mock-workload-token" lifespan.registry.register_value(JWTGenerator, mock_generator) yield client + lifespan.registry.close() + exec_app.dependency_overrides.pop(_jwt_bearer, None) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 96759a3ec15c0..f00ddd9f20971 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -3264,6 +3264,44 @@ def test_invalid_scope_value_rejected(self, client, session, create_task_instanc assert resp.status_code == 403 assert "Invalid token scope" in resp.json()["detail"] + def test_workload_scope_accepted_on_run_endpoint( + self, client, session, create_task_instance, time_machine + ): + """workload scoped tokens should be accepted on the /run endpoint.""" + instant = timezone.parse("2024-10-31T12:00:00Z") + time_machine.move_to(instant, tick=False) + + ti = create_task_instance( + task_id="test_workload_run", + state=State.QUEUED, + dagrun_state=DagRunState.RUNNING, + session=session, + start_date=instant, + dag_id=str(uuid4()), + ) + session.commit() + + validator = mock.AsyncMock(spec=JWTValidator) + validator.avalidated_claims.side_effect = lambda cred, validators: { + "sub": str(ti.id), + "scope": "workload", + "exp": 9999999999, + "iat": 1000000000, + } + lifespan.registry.register_value(JWTValidator, validator) + + resp = client.patch( + f"/execution/task-instances/{ti.id}/run", + json={ + "state": "running", + "hostname": "test-host", + "unixname": "test-user", + "pid": 100, + "start_date": "2024-10-31T12:00:00Z", + }, + ) + assert resp.status_code == 200 + def test_no_scope_defaults_to_execution(self, client, session, create_task_instance): """Tokens without scope claim should default to 'execution'.""" ti = create_task_instance(task_id="test_no_scope", state=State.RUNNING) diff --git a/airflow-core/tests/unit/executors/test_workloads.py b/airflow-core/tests/unit/executors/test_workloads.py index 1a67ab96d4073..7e9745a763bfb 100644 --- a/airflow-core/tests/unit/executors/test_workloads.py +++ b/airflow-core/tests/unit/executors/test_workloads.py @@ -20,9 +20,12 @@ from pathlib import PurePosixPath from uuid import uuid4 +import jwt + +from airflow.api_fastapi.auth.tokens import JWTGenerator from airflow.executors import workloads from airflow.executors.workloads import TaskInstance, TaskInstanceDTO -from airflow.executors.workloads.base import BundleInfo +from airflow.executors.workloads.base import BaseWorkloadSchema, BundleInfo from airflow.executors.workloads.task import ExecuteTask @@ -61,3 +64,21 @@ def test_token_excluded_from_workload_repr(): assert fake_token not in workload_repr, f"JWT token leaked into repr! Found token in: {workload_repr}" # But token should still be accessible as an attribute assert workload.token == fake_token + + +def test_generate_token_produces_workload_scope(): + """generate_token should create a JWT with scope 'workload' and workload_valid_for expiry.""" + generator = JWTGenerator( + secret_key="test-secret", audience="test", valid_for=60, workload_valid_for=86400 + ) + token = BaseWorkloadSchema.generate_token("ti-123", generator) + + claims = jwt.decode(token, "test-secret", algorithms=["HS512"], audience="test") + assert claims["sub"] == "ti-123" + assert claims["scope"] == "workload" + assert claims["exp"] - claims["iat"] == 86400 + + +def test_generate_token_without_generator(): + """generate_token should return empty string when no generator is provided.""" + assert BaseWorkloadSchema.generate_token("ti-123", None) == "" diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 2842ab4025905..fd77644bc44b5 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -302,7 +302,6 @@ def set_instance_attrs(self) -> Generator: def mock_executors(self): mock_jwt_generator = MagicMock(spec=JWTGenerator) mock_jwt_generator.generate.return_value = "mock-token" - mock_jwt_generator.generate_workload_token.return_value = "mock-workload-token" default_executor = mock.MagicMock(name="DefaultExecutor", slots_available=8, slots_occupied=0) default_executor.name = ExecutorName(alias="default_exec", module_path="default.exec.module.path") diff --git a/devel-common/src/tests_common/test_utils/mock_executor.py b/devel-common/src/tests_common/test_utils/mock_executor.py index cf53e9ae56d33..c7a2f26315234 100644 --- a/devel-common/src/tests_common/test_utils/mock_executor.py +++ b/devel-common/src/tests_common/test_utils/mock_executor.py @@ -22,6 +22,7 @@ from typing import TYPE_CHECKING from unittest.mock import MagicMock +from airflow.api_fastapi.auth.tokens import JWTGenerator from airflow.executors.base_executor import BaseExecutor from airflow.executors.executor_utils import ExecutorName from airflow.models.taskinstance import TaskInstance @@ -57,9 +58,8 @@ def __init__(self, do_update=True, *args, **kwargs): self.mock_task_results = defaultdict(self.success) # Mock JWT generator for token generation - mock_jwt_generator = MagicMock() + mock_jwt_generator = MagicMock(spec=JWTGenerator) mock_jwt_generator.generate.return_value = "mock-token" - mock_jwt_generator.generate_workload_token.return_value = "mock-workload-token" self.jwt_generator = mock_jwt_generator From 9eaf6ddba9569c3c211eec79242c27ff6a5608ba Mon Sep 17 00:00:00 2001 From: Anish Date: Wed, 25 Mar 2026 22:35:25 -0500 Subject: [PATCH 7/7] adress review comments --- .../airflow/api_fastapi/execution_api/app.py | 8 +- .../execution_api/routes/task_instances.py | 12 ++- .../src/airflow/config_templates/config.yml | 2 +- .../api_fastapi/execution_api/conftest.py | 2 - .../versions/head/test_task_instances.py | 97 ++++++++++--------- task-sdk/src/airflow/sdk/api/client.py | 5 +- task-sdk/tests/task_sdk/api/test_client.py | 30 ------ 7 files changed, 66 insertions(+), 90 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index 8675c70b77c2e..feb7f7d33a8e1 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -77,7 +77,7 @@ def _jwt_generator(): generator = JWTGenerator( valid_for=conf.getint("execution_api", "jwt_expiration_time"), - workload_valid_for=conf.getint("execution_api", "jwt_workload_token_expiration_time", fallback=86400), + # workload_valid_for uses the attrs default factory which reads the same config key audience=conf.get_mandatory_list_value("execution_api", "jwt_audience")[0], issuer=conf.get("api_auth", "jwt_issuer", fallback=None), # Since this one is used across components/server, there is no point trying to generate one, error @@ -145,7 +145,8 @@ async def dispatch(self, request: Request, call_next): claims = await validator.avalidated_claims(token, {}) # Workload tokens are long-lived and meant to survive queue - # wait times so avoid refreshing them. + # wait times so avoid refreshing them. If avalidated_claims + # raises for a workload token, the outer except handles it. if claims.get("scope") == "workload": return response @@ -348,7 +349,8 @@ async def always_allow(request: Request): # Cadwyn's versioned sub-apps don't inherit the main app's state, # so lookups raise ServiceNotFoundError. This registry provides # services needed by routes called during dag.test(). - # + # Note: tokens generated by this stub are never validated since + # _jwt_bearer is overridden with always_allow in dag.test() mode. stub_generator = JWTGenerator( secret_key=secrets.token_urlsafe(32), audience="in-process", diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index b7ffd521f9878..c8f95b79b5f8f 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -297,9 +297,15 @@ def ti_run( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred" ) - generator: JWTGenerator = services.get(JWTGenerator) - execution_token = generator.generate(extras={"sub": str(task_instance_id)}) - response.headers["X-Execution-Token"] = execution_token + try: + generator: JWTGenerator = services.get(JWTGenerator) + execution_token = generator.generate(extras={"sub": str(task_instance_id), "scope": "execution"}) + except Exception: + log.exception("Failed to generate execution token for task instance %s", task_instance_id) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Token generation failed" + ) + response.headers["Refreshed-API-Token"] = execution_token return context diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index f9467bed5b43b..2b8c2726557d3 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -2084,7 +2084,7 @@ execution_api: Seconds until workload JWT tokens expire. These long-lived tokens are sent with task workloads to executors and can only call the /run endpoint. Set long enough to cover maximum expected queue wait time. - version_added: 3.2.0 + version_added: 3.2.1 type: integer example: ~ default: "86400" diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py index f7dacc40ac9cf..0bd48bb766245 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py @@ -63,6 +63,4 @@ async def mock_jwt_bearer(request: Request): yield client - lifespan.registry.close() - exec_app.dependency_overrides.pop(_jwt_bearer, None) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index f00ddd9f20971..7814b9bb321e2 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -255,7 +255,7 @@ def test_ti_run_state_to_running( assert response.status_code == 409 def test_ti_run_returns_execution_token(self, client, session, create_task_instance, time_machine): - """PATCH /run should return an X-Execution-Token header on success.""" + """PATCH /run should return a Refreshed-API-Token header on success.""" instant = timezone.parse("2024-10-31T12:00:00Z") time_machine.move_to(instant, tick=False) @@ -281,8 +281,8 @@ def test_ti_run_returns_execution_token(self, client, session, create_task_insta ) assert response.status_code == 200 - assert "X-Execution-Token" in response.headers - assert response.headers["X-Execution-Token"] == "mock-execution-token" + assert "Refreshed-API-Token" in response.headers + assert response.headers["Refreshed-API-Token"] == "mock-execution-token" def test_dynamic_task_mapping_with_parse_time_value(self, client, dag_maker): """Test that dynamic task mapping works correctly with parse-time values.""" @@ -3199,40 +3199,63 @@ def test_ti_patch_rendered_map_index_empty_string(self, client, session, create_ @pytest.mark.usefixtures("_use_real_jwt_bearer") class TestTokenTypeValidation: - """Test token scope enforcement (workload vs execution).""" + """Test token scope enforcement (workload vs execution). - def test_workload_scope_rejected_on_default_endpoints(self, client, session, create_task_instance): - """workload scoped tokens should be rejected on endpoints without token:workload Security scope.""" - ti = create_task_instance(task_id="test_ti_run_heartbeat", state=State.RUNNING) - session.commit() + Uses _use_real_jwt_bearer to remove the conftest's mock _jwt_bearer + override, then registers a JWTValidator mock on the shared lifespan + registry that returns claims with specific scope values. + """ + def _register_scoped_validator(self, ti_id, scope): + """Register a JWTValidator mock returning claims with the given scope.""" validator = mock.AsyncMock(spec=JWTValidator) - validator.avalidated_claims.side_effect = lambda cred, validators: { - "sub": str(ti.id), - "scope": "workload", - "exp": 9999999999, - "iat": 1000000000, - } + claims = {"sub": str(ti_id), "exp": 9999999999, "iat": 1000000000} + if scope is not None: + claims["scope"] = scope + validator.avalidated_claims.side_effect = lambda cred, validators: claims lifespan.registry.register_value(JWTValidator, validator) + def test_workload_scope_rejected_on_heartbeat_endpoint(self, client, session, create_task_instance): + """Workload scoped tokens should be rejected on /heartbeat.""" + ti = create_task_instance(task_id="test_ti_run_heartbeat", state=State.RUNNING) + session.commit() + + self._register_scoped_validator(ti.id, "workload") + payload = {"hostname": "test-host", "pid": 100} resp = client.put(f"/execution/task-instances/{ti.id}/heartbeat", json=payload) assert resp.status_code == 403 assert "Token type 'workload' not allowed" in resp.json()["detail"] + def test_workload_scope_rejected_on_state_endpoint(self, client, session, create_task_instance): + """Workload scoped tokens should be rejected on PATCH /state.""" + ti = create_task_instance(task_id="test_workload_state", state=State.RUNNING) + session.commit() + + self._register_scoped_validator(ti.id, "workload") + + payload = {"state": "success", "end_date": "2024-10-31T13:00:00Z"} + resp = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload) + assert resp.status_code == 403 + assert "Token type 'workload' not allowed" in resp.json()["detail"] + + def test_workload_scope_rejected_on_connections_endpoint(self, client, session, create_task_instance): + """Workload scoped tokens should be rejected on GET /connections (different router).""" + ti = create_task_instance(task_id="test_workload_conn", state=State.RUNNING) + session.commit() + + self._register_scoped_validator(ti.id, "workload") + + resp = client.get("/execution/connections/test_conn") + assert resp.status_code == 403 + assert "Token type 'workload' not allowed" in resp.json()["detail"] + def test_execution_scope_accepted_on_all_endpoints(self, client, session, create_task_instance): - """execution scoped tokens should be able to call all endpoints.""" + """Execution scoped tokens should be accepted on all endpoints.""" ti = create_task_instance(task_id="test_ti_star", state=State.RUNNING) session.commit() - validator = mock.AsyncMock(spec=JWTValidator) - validator.avalidated_claims.side_effect = lambda cred, validators: { - "sub": str(ti.id), - "scope": "execution", - "exp": 9999999999, - "iat": 1000000000, - } - lifespan.registry.register_value(JWTValidator, validator) + self._register_scoped_validator(ti.id, "execution") payload = {"state": "success", "end_date": "2024-10-31T13:00:00Z"} resp = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload) @@ -3243,14 +3266,7 @@ def test_invalid_scope_value_rejected(self, client, session, create_task_instanc ti = create_task_instance(task_id="test_invalid_scope", state=State.QUEUED) session.commit() - validator = mock.AsyncMock(spec=JWTValidator) - validator.avalidated_claims.side_effect = lambda cred, validators: { - "sub": str(ti.id), - "scope": "bogus:scope", - "exp": 9999999999, - "iat": 1000000000, - } - lifespan.registry.register_value(JWTValidator, validator) + self._register_scoped_validator(ti.id, "bogus:scope") payload = { "state": "running", @@ -3267,7 +3283,7 @@ def test_invalid_scope_value_rejected(self, client, session, create_task_instanc def test_workload_scope_accepted_on_run_endpoint( self, client, session, create_task_instance, time_machine ): - """workload scoped tokens should be accepted on the /run endpoint.""" + """Workload scoped tokens should be accepted on the /run endpoint.""" instant = timezone.parse("2024-10-31T12:00:00Z") time_machine.move_to(instant, tick=False) @@ -3281,14 +3297,7 @@ def test_workload_scope_accepted_on_run_endpoint( ) session.commit() - validator = mock.AsyncMock(spec=JWTValidator) - validator.avalidated_claims.side_effect = lambda cred, validators: { - "sub": str(ti.id), - "scope": "workload", - "exp": 9999999999, - "iat": 1000000000, - } - lifespan.registry.register_value(JWTValidator, validator) + self._register_scoped_validator(ti.id, "workload") resp = client.patch( f"/execution/task-instances/{ti.id}/run", @@ -3307,13 +3316,7 @@ def test_no_scope_defaults_to_execution(self, client, session, create_task_insta ti = create_task_instance(task_id="test_no_scope", state=State.RUNNING) session.commit() - validator = mock.AsyncMock(spec=JWTValidator) - validator.avalidated_claims.side_effect = lambda cred, validators: { - "sub": str(ti.id), - "exp": 9999999999, - "iat": 1000000000, - } - lifespan.registry.register_value(JWTValidator, validator) + self._register_scoped_validator(ti.id, None) payload = {"state": "success", "end_date": "2024-10-31T13:00:00Z"} resp = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload) diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 71e3d59575281..19e691281f73b 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -963,10 +963,7 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, * ) def _update_auth(self, response: httpx.Response): - if new_token := response.headers.get("X-Execution-Token"): - log.debug("Received execution token, swapping auth") - self.auth = BearerAuth(new_token) - elif new_token := response.headers.get("Refreshed-API-Token"): + if new_token := response.headers.get("Refreshed-API-Token"): log.debug("Execution API issued us a refreshed Task token") self.auth = BearerAuth(new_token) diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index a6bcf89d4d92d..0df8839c55f30 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -250,36 +250,6 @@ def test_token_renewal(self): assert response.status_code == 200 assert response.request.headers["Authorization"] == "Bearer abc" - def test_execution_token_swap(self): - """X-Execution-Token header should replace the auth token.""" - responses: list[httpx.Response] = [ - httpx.Response(200, json={"ok": "1"}, headers={"X-Execution-Token": "exec-token-123"}), - httpx.Response(200, json={"ok": "2"}), - ] - client = make_client_w_responses(responses) - response = client.get("/") - assert response.status_code == 200 - - assert client.auth is not None - assert client.auth.token == "exec-token-123" - - response = client.get("/") - assert response.status_code == 200 - assert response.request.headers["Authorization"] == "Bearer exec-token-123" - - def test_execution_token_takes_priority_over_refreshed_token(self): - """When both headers present, X-Execution-Token should take priority.""" - responses: list[httpx.Response] = [ - httpx.Response( - 200, - json={"ok": "1"}, - headers={"X-Execution-Token": "exec-tok", "Refreshed-API-Token": "refresh-tok"}, - ), - ] - client = make_client_w_responses(responses) - client.get("/") - assert client.auth.token == "exec-tok" - @pytest.mark.parametrize( ("status_code", "description"), [