diff --git a/airflow-core/src/airflow/api_fastapi/auth/tokens.py b/airflow-core/src/airflow/api_fastapi/auth/tokens.py index 3375853a29a5d..116f52d5e4f57 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/tokens.py +++ b/airflow-core/src/airflow/api_fastapi/auth/tokens.py @@ -418,6 +418,10 @@ class JWTGenerator: kid: str = attrs.field(default=attrs.Factory(_generate_kid, takes_self=True)) valid_for: float + workload_valid_for: float = attrs.field( + factory=_conf_factory("execution_api", "jwt_workload_token_expiration_time", fallback="86400"), + converter=float, + ) audience: str issuer: str | list[str] | None = attrs.field( factory=_conf_list_factory("api_auth", "jwt_issuer", first_only=True, fallback=None) @@ -447,15 +451,21 @@ def signing_arg(self) -> AllowedPrivateKeys | str: assert self._secret_key return self._secret_key - def generate(self, extras: dict[str, Any] | None = None, headers: dict[str, Any] | None = None) -> str: + def generate( + self, + extras: dict[str, Any] | None = None, + headers: dict[str, Any] | None = None, + valid_for: float | None = None, + ) -> str: """Generate a signed JWT for the subject.""" now = int(datetime.now(tz=timezone.utc).timestamp()) + effective_valid_for = valid_for if valid_for is not None else self.valid_for claims = { "jti": uuid.uuid4().hex, "iss": self.issuer, "aud": self.audience, "nbf": now, - "exp": int(now + self.valid_for), + "exp": int(now + effective_valid_for), "iat": now, } diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index c7a9593c3c82f..feb7f7d33a8e1 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -18,6 +18,7 @@ from __future__ import annotations import json +import secrets import time from contextlib import AsyncExitStack from functools import cached_property @@ -76,6 +77,7 @@ def _jwt_generator(): generator = JWTGenerator( 
valid_for=conf.getint("execution_api", "jwt_expiration_time"), + # workload_valid_for defaults via attrs from execution_api.jwt_workload_token_expiration_time audience=conf.get_mandatory_list_value("execution_api", "jwt_audience")[0], issuer=conf.get("api_auth", "jwt_issuer", fallback=None), # Since this one is used across components/server, there is no point trying to generate one, error @@ -142,6 +144,12 @@ async def dispatch(self, request: Request, call_next): validator: JWTValidator = await services.aget(JWTValidator) claims = await validator.avalidated_claims(token, {}) + # Workload tokens are long-lived and meant to survive queue + # wait times so avoid refreshing them. If avalidated_claims + # raises for a workload token, the outer except handles it. + if claims.get("scope") == "workload": + return response + now = int(time.time()) validity = conf.getint("execution_api", "jwt_expiration_time") refresh_when_less_than = max(int(validity * 0.20), 30) @@ -311,9 +319,13 @@ class InProcessExecutionAPI: @cached_property def app(self): if not self._app: + import svcs + + from airflow.api_fastapi.auth.tokens import JWTGenerator from airflow.api_fastapi.common.dagbag import create_dag_bag from airflow.api_fastapi.execution_api.app import create_task_execution_api_app from airflow.api_fastapi.execution_api.datamodels.token import TIToken + from airflow.api_fastapi.execution_api.deps import _container from airflow.api_fastapi.execution_api.routes.connections import has_connection_access from airflow.api_fastapi.execution_api.routes.variables import has_variable_access from airflow.api_fastapi.execution_api.routes.xcoms import has_xcom_access @@ -332,10 +344,30 @@ async def always_allow(request: Request): ) return TIToken(id=ti_id, claims={"scope": "execution"}) + # Override _container (the svcs service locator behind DepContainer). 
+ # The default _container reads request.app.state.svcs_registry, but + # Cadwyn's versioned sub-apps don't inherit the main app's state, + # so lookups raise ServiceNotFoundError. This registry provides + # services needed by routes called during dag.test(). + # Note: tokens generated by this stub are never validated since + # _jwt_bearer is overridden with always_allow in dag.test() mode. + stub_generator = JWTGenerator( + secret_key=secrets.token_urlsafe(32), + audience="in-process", + valid_for=3600, + ) + registry = svcs.Registry() + registry.register_value(JWTGenerator, stub_generator) + + async def _in_process_container(request: Request): + async with svcs.Container(registry) as cont: + yield cont + self._app.dependency_overrides[_jwt_bearer] = always_allow self._app.dependency_overrides[has_connection_access] = always_allow self._app.dependency_overrides[has_variable_access] = always_allow self._app.dependency_overrides[has_xcom_access] = always_allow + self._app.dependency_overrides[_container] = _in_process_container return self._app diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 5f5073c916b68..c8f95b79b5f8f 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -28,7 +28,7 @@ import attrs import structlog from cadwyn import VersionedAPIRouter -from fastapi import Body, HTTPException, Query, Security, status +from fastapi import Body, HTTPException, Query, Response, Security, status from opentelemetry import trace from opentelemetry.trace import StatusCode from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator @@ -42,6 +42,7 @@ from airflow._shared.observability.traces import override_ids from airflow._shared.timezones import timezone +from airflow.api_fastapi.auth.tokens import 
JWTGenerator from airflow.api_fastapi.common.dagbag import DagBagDep, get_latest_version_of_dag from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.types import UtcDateTime @@ -63,6 +64,7 @@ TISuccessStatePayload, TITerminalStatePayload, ) +from airflow.api_fastapi.execution_api.deps import DepContainer from airflow.api_fastapi.execution_api.security import ExecutionAPIRoute, require_auth from airflow.exceptions import TaskNotFound from airflow.models.asset import AssetActive @@ -97,6 +99,7 @@ @ti_id_router.patch( "/{task_instance_id}/run", status_code=status.HTTP_200_OK, + dependencies=[Security(require_auth, scopes=["token:execution", "token:workload"])], responses={ status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"}, status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"}, @@ -107,8 +110,10 @@ def ti_run( task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, Body()], + response: Response, session: SessionDep, dag_bag: DagBagDep, + services=DepContainer, ) -> TIRunContext: """ Run a TaskInstance. 
@@ -286,14 +291,24 @@ def ti_run( if ti.next_method: context.next_method = ti.next_method context.next_kwargs = ti.next_kwargs - - return context except SQLAlchemyError: log.exception("Error marking Task Instance state as running") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred" ) + try: + generator: JWTGenerator = services.get(JWTGenerator) + execution_token = generator.generate(extras={"sub": str(task_instance_id), "scope": "execution"}) + except Exception: + log.exception("Failed to generate execution token for task instance %s", task_instance_id) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Token generation failed" + ) + response.headers["Refreshed-API-Token"] = execution_token + + return context + @ti_id_router.patch( "/{task_instance_id}/state", diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 55313198d5b39..2b8c2726557d3 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -2079,6 +2079,15 @@ execution_api: type: integer example: ~ default: "600" + jwt_workload_token_expiration_time: + description: | + Seconds until workload JWT tokens expire. These long-lived tokens are sent + with task workloads to executors and can only call the /run endpoint. + Set long enough to cover maximum expected queue wait time. 
+ version_added: 3.2.1 + type: integer + example: ~ + default: "86400" jwt_audience: version_added: 3.0.0 description: | diff --git a/airflow-core/src/airflow/executors/workloads/base.py b/airflow-core/src/airflow/executors/workloads/base.py index 97cf16ebaf64d..12a5574da3e40 100644 --- a/airflow-core/src/airflow/executors/workloads/base.py +++ b/airflow-core/src/airflow/executors/workloads/base.py @@ -74,7 +74,12 @@ class BaseWorkloadSchema(BaseModel): @staticmethod def generate_token(sub_id: str, generator: JWTGenerator | None = None) -> str: - return generator.generate({"sub": sub_id}) if generator else "" + if not generator: + return "" + return generator.generate( + extras={"sub": sub_id, "scope": "workload"}, + valid_for=generator.workload_valid_for, + ) class BaseDagBundleWorkload(BaseWorkloadSchema, ABC): diff --git a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py index 6b848f723a004..89677cc1ef00b 100644 --- a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py +++ b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py @@ -160,6 +160,34 @@ def test_secret_key_with_configured_kid(): assert header["kid"] == "my-custom-kid" +def test_generate_with_custom_valid_for(): + """generate() accepts a valid_for override.""" + generator = JWTGenerator(secret_key="test-secret", audience="test", valid_for=60) + token = generator.generate(extras={"sub": "user"}, valid_for=3600) + claims = jwt.decode(token, "test-secret", algorithms=["HS512"], audience="test") + assert claims["exp"] - claims["iat"] == 3600 + + +def test_generate_workload_scope_via_extras(): + """generate() with scope='workload' in extras produces a workload-scoped token.""" + generator = JWTGenerator(secret_key="test-secret", audience="test", valid_for=60) + + token = generator.generate(extras={"sub": "ti-123", "scope": "workload"}, valid_for=86400) + claims = jwt.decode(token, "test-secret", algorithms=["HS512"], audience="test") + 
assert claims["sub"] == "ti-123" + assert claims["scope"] == "workload" + assert claims["exp"] - claims["iat"] == 86400 + + +def test_regular_token_has_no_scope(): + """Regular tokens without scope in extras have no scope claim.""" + generator = JWTGenerator(secret_key="test-secret", audience="test", valid_for=60) + + regular = generator.generate(extras={"sub": "user"}) + regular_claims = jwt.decode(regular, "test-secret", algorithms=["HS512"], audience="test") + assert "scope" not in regular_claims + + @pytest.fixture def jwt_generator(ed25519_private_key: Ed25519PrivateKey): key = ed25519_private_key diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py index 78bd0548df9d2..0bd48bb766245 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py @@ -16,11 +16,15 @@ # under the License. from __future__ import annotations +from unittest.mock import MagicMock + import pytest from fastapi import FastAPI, Request from fastapi.testclient import TestClient from airflow.api_fastapi.app import cached_app +from airflow.api_fastapi.auth.tokens import JWTGenerator +from airflow.api_fastapi.execution_api.app import lifespan from airflow.api_fastapi.execution_api.datamodels.token import TIToken from airflow.api_fastapi.execution_api.security import _jwt_bearer @@ -53,6 +57,10 @@ async def mock_jwt_bearer(request: Request): exec_app.dependency_overrides[_jwt_bearer] = mock_jwt_bearer with TestClient(app, headers={"Authorization": "Bearer fake"}) as client: + mock_generator = MagicMock(spec=JWTGenerator) + mock_generator.generate.return_value = "mock-execution-token" + lifespan.registry.register_value(JWTGenerator, mock_generator) + yield client exec_app.dependency_overrides.pop(_jwt_bearer, None) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 7f766ede71e4d..7814b9bb321e2 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -254,6 +254,36 @@ def test_ti_run_state_to_running( ) assert response.status_code == 409 + def test_ti_run_returns_execution_token(self, client, session, create_task_instance, time_machine): + """PATCH /run should return a Refreshed-API-Token header on success.""" + instant = timezone.parse("2024-10-31T12:00:00Z") + time_machine.move_to(instant, tick=False) + + ti = create_task_instance( + task_id="test_exec_token", + state=State.QUEUED, + dagrun_state=DagRunState.RUNNING, + session=session, + start_date=instant, + dag_id=str(uuid4()), + ) + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/run", + json={ + "state": "running", + "hostname": "test-host", + "unixname": "test-user", + "pid": 100, + "start_date": "2024-10-31T12:00:00Z", + }, + ) + + assert response.status_code == 200 + assert "Refreshed-API-Token" in response.headers + assert response.headers["Refreshed-API-Token"] == "mock-execution-token" + def test_dynamic_task_mapping_with_parse_time_value(self, client, dag_maker): """Test that dynamic task mapping works correctly with parse-time values.""" with dag_maker("test_dynamic_task_mapping_with_parse_time_value", serialized=True): @@ -3169,40 +3199,63 @@ def test_ti_patch_rendered_map_index_empty_string(self, client, session, create_ @pytest.mark.usefixtures("_use_real_jwt_bearer") class TestTokenTypeValidation: - """Test token scope enforcement (workload vs execution).""" + """Test token scope enforcement (workload vs execution). 
- def test_workload_scope_rejected_on_default_endpoints(self, client, session, create_task_instance): - """workload scoped tokens should be rejected on endpoints without token:workload Security scope.""" - ti = create_task_instance(task_id="test_ti_run_heartbeat", state=State.RUNNING) - session.commit() + Uses _use_real_jwt_bearer to remove the conftest's mock _jwt_bearer + override, then registers a JWTValidator mock on the shared lifespan + registry that returns claims with specific scope values. + """ + def _register_scoped_validator(self, ti_id, scope): + """Register a JWTValidator mock returning claims with the given scope.""" validator = mock.AsyncMock(spec=JWTValidator) - validator.avalidated_claims.side_effect = lambda cred, validators: { - "sub": str(ti.id), - "scope": "workload", - "exp": 9999999999, - "iat": 1000000000, - } + claims = {"sub": str(ti_id), "exp": 9999999999, "iat": 1000000000} + if scope is not None: + claims["scope"] = scope + validator.avalidated_claims.side_effect = lambda cred, validators: claims lifespan.registry.register_value(JWTValidator, validator) + def test_workload_scope_rejected_on_heartbeat_endpoint(self, client, session, create_task_instance): + """Workload scoped tokens should be rejected on /heartbeat.""" + ti = create_task_instance(task_id="test_ti_run_heartbeat", state=State.RUNNING) + session.commit() + + self._register_scoped_validator(ti.id, "workload") + payload = {"hostname": "test-host", "pid": 100} resp = client.put(f"/execution/task-instances/{ti.id}/heartbeat", json=payload) assert resp.status_code == 403 assert "Token type 'workload' not allowed" in resp.json()["detail"] + def test_workload_scope_rejected_on_state_endpoint(self, client, session, create_task_instance): + """Workload scoped tokens should be rejected on PATCH /state.""" + ti = create_task_instance(task_id="test_workload_state", state=State.RUNNING) + session.commit() + + self._register_scoped_validator(ti.id, "workload") + + payload = {"state": 
"success", "end_date": "2024-10-31T13:00:00Z"} + resp = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload) + assert resp.status_code == 403 + assert "Token type 'workload' not allowed" in resp.json()["detail"] + + def test_workload_scope_rejected_on_connections_endpoint(self, client, session, create_task_instance): + """Workload scoped tokens should be rejected on GET /connections (different router).""" + ti = create_task_instance(task_id="test_workload_conn", state=State.RUNNING) + session.commit() + + self._register_scoped_validator(ti.id, "workload") + + resp = client.get("/execution/connections/test_conn") + assert resp.status_code == 403 + assert "Token type 'workload' not allowed" in resp.json()["detail"] + def test_execution_scope_accepted_on_all_endpoints(self, client, session, create_task_instance): - """execution scoped tokens should be able to call all endpoints.""" + """Execution scoped tokens should be accepted on all endpoints.""" ti = create_task_instance(task_id="test_ti_star", state=State.RUNNING) session.commit() - validator = mock.AsyncMock(spec=JWTValidator) - validator.avalidated_claims.side_effect = lambda cred, validators: { - "sub": str(ti.id), - "scope": "execution", - "exp": 9999999999, - "iat": 1000000000, - } - lifespan.registry.register_value(JWTValidator, validator) + self._register_scoped_validator(ti.id, "execution") payload = {"state": "success", "end_date": "2024-10-31T13:00:00Z"} resp = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload) @@ -3213,14 +3266,7 @@ def test_invalid_scope_value_rejected(self, client, session, create_task_instanc ti = create_task_instance(task_id="test_invalid_scope", state=State.QUEUED) session.commit() - validator = mock.AsyncMock(spec=JWTValidator) - validator.avalidated_claims.side_effect = lambda cred, validators: { - "sub": str(ti.id), - "scope": "bogus:scope", - "exp": 9999999999, - "iat": 1000000000, - } - lifespan.registry.register_value(JWTValidator, 
validator) + self._register_scoped_validator(ti.id, "bogus:scope") payload = { "state": "running", @@ -3234,18 +3280,43 @@ def test_invalid_scope_value_rejected(self, client, session, create_task_instanc assert resp.status_code == 403 assert "Invalid token scope" in resp.json()["detail"] + def test_workload_scope_accepted_on_run_endpoint( + self, client, session, create_task_instance, time_machine + ): + """Workload scoped tokens should be accepted on the /run endpoint.""" + instant = timezone.parse("2024-10-31T12:00:00Z") + time_machine.move_to(instant, tick=False) + + ti = create_task_instance( + task_id="test_workload_run", + state=State.QUEUED, + dagrun_state=DagRunState.RUNNING, + session=session, + start_date=instant, + dag_id=str(uuid4()), + ) + session.commit() + + self._register_scoped_validator(ti.id, "workload") + + resp = client.patch( + f"/execution/task-instances/{ti.id}/run", + json={ + "state": "running", + "hostname": "test-host", + "unixname": "test-user", + "pid": 100, + "start_date": "2024-10-31T12:00:00Z", + }, + ) + assert resp.status_code == 200 + def test_no_scope_defaults_to_execution(self, client, session, create_task_instance): """Tokens without scope claim should default to 'execution'.""" ti = create_task_instance(task_id="test_no_scope", state=State.RUNNING) session.commit() - validator = mock.AsyncMock(spec=JWTValidator) - validator.avalidated_claims.side_effect = lambda cred, validators: { - "sub": str(ti.id), - "exp": 9999999999, - "iat": 1000000000, - } - lifespan.registry.register_value(JWTValidator, validator) + self._register_scoped_validator(ti.id, None) payload = {"state": "success", "end_date": "2024-10-31T13:00:00Z"} resp = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload) diff --git a/airflow-core/tests/unit/executors/test_workloads.py b/airflow-core/tests/unit/executors/test_workloads.py index 1a67ab96d4073..7e9745a763bfb 100644 --- a/airflow-core/tests/unit/executors/test_workloads.py +++ 
b/airflow-core/tests/unit/executors/test_workloads.py @@ -20,9 +20,12 @@ from pathlib import PurePosixPath from uuid import uuid4 +import jwt + +from airflow.api_fastapi.auth.tokens import JWTGenerator from airflow.executors import workloads from airflow.executors.workloads import TaskInstance, TaskInstanceDTO -from airflow.executors.workloads.base import BundleInfo +from airflow.executors.workloads.base import BaseWorkloadSchema, BundleInfo from airflow.executors.workloads.task import ExecuteTask @@ -61,3 +64,21 @@ def test_token_excluded_from_workload_repr(): assert fake_token not in workload_repr, f"JWT token leaked into repr! Found token in: {workload_repr}" # But token should still be accessible as an attribute assert workload.token == fake_token + + +def test_generate_token_produces_workload_scope(): + """generate_token should create a JWT with scope 'workload' and workload_valid_for expiry.""" + generator = JWTGenerator( + secret_key="test-secret", audience="test", valid_for=60, workload_valid_for=86400 + ) + token = BaseWorkloadSchema.generate_token("ti-123", generator) + + claims = jwt.decode(token, "test-secret", algorithms=["HS512"], audience="test") + assert claims["sub"] == "ti-123" + assert claims["scope"] == "workload" + assert claims["exp"] - claims["iat"] == 86400 + + +def test_generate_token_without_generator(): + """generate_token should return empty string when no generator is provided.""" + assert BaseWorkloadSchema.generate_token("ti-123", None) == "" diff --git a/devel-common/src/tests_common/test_utils/mock_executor.py b/devel-common/src/tests_common/test_utils/mock_executor.py index 4e95ed3a4eea7..c7a2f26315234 100644 --- a/devel-common/src/tests_common/test_utils/mock_executor.py +++ b/devel-common/src/tests_common/test_utils/mock_executor.py @@ -22,6 +22,7 @@ from typing import TYPE_CHECKING from unittest.mock import MagicMock +from airflow.api_fastapi.auth.tokens import JWTGenerator from airflow.executors.base_executor import 
BaseExecutor from airflow.executors.executor_utils import ExecutorName from airflow.models.taskinstance import TaskInstance @@ -57,7 +58,7 @@ def __init__(self, do_update=True, *args, **kwargs): self.mock_task_results = defaultdict(self.success) # Mock JWT generator for token generation - mock_jwt_generator = MagicMock() + mock_jwt_generator = MagicMock(spec=JWTGenerator) mock_jwt_generator.generate.return_value = "mock-token" self.jwt_generator = mock_jwt_generator