Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions airflow-core/src/airflow/api_fastapi/auth/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,10 @@ class JWTGenerator:

kid: str = attrs.field(default=attrs.Factory(_generate_kid, takes_self=True))
valid_for: float
workload_valid_for: float = attrs.field(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The workload_valid_for default reads from config via _conf_factory, and _jwt_generator() in app.py also reads the same config key and passes it explicitly. The explicit kwarg takes precedence, so the default factory never runs in production. Having two code paths that reference the same config key is easy to get out of sync -- consider dropping the attrs default (make it required like valid_for) and always passing it explicitly, or drop the explicit kwarg in _jwt_generator() and let the default handle it.

factory=_conf_factory("execution_api", "jwt_workload_token_expiration_time", fallback="86400"),
converter=float,
)
audience: str
issuer: str | list[str] | None = attrs.field(
factory=_conf_list_factory("api_auth", "jwt_issuer", first_only=True, fallback=None)
Expand Down Expand Up @@ -447,15 +451,21 @@ def signing_arg(self) -> AllowedPrivateKeys | str:
assert self._secret_key
return self._secret_key

def generate(self, extras: dict[str, Any] | None = None, headers: dict[str, Any] | None = None) -> str:
def generate(
self,
extras: dict[str, Any] | None = None,
headers: dict[str, Any] | None = None,
valid_for: float | None = None,
) -> str:
"""Generate a signed JWT for the subject."""
now = int(datetime.now(tz=timezone.utc).timestamp())
effective_valid_for = valid_for if valid_for is not None else self.valid_for
claims = {
"jti": uuid.uuid4().hex,
"iss": self.issuer,
"aud": self.audience,
"nbf": now,
"exp": int(now + self.valid_for),
"exp": int(now + effective_valid_for),
"iat": now,
}

Expand Down
32 changes: 32 additions & 0 deletions airflow-core/src/airflow/api_fastapi/execution_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import json
import secrets
import time
from contextlib import AsyncExitStack
from functools import cached_property
Expand Down Expand Up @@ -76,6 +77,7 @@ def _jwt_generator():

generator = JWTGenerator(
valid_for=conf.getint("execution_api", "jwt_expiration_time"),
# workload_valid_for is intentionally not passed here: its attrs default factory reads
# [execution_api] jwt_workload_token_expiration_time on its own.
audience=conf.get_mandatory_list_value("execution_api", "jwt_audience")[0],
issuer=conf.get("api_auth", "jwt_issuer", fallback=None),
# Since this one is used across components/server, there is no point trying to generate one, error
Expand Down Expand Up @@ -142,6 +144,12 @@ async def dispatch(self, request: Request, call_next):
validator: JWTValidator = await services.aget(JWTValidator)
claims = await validator.avalidated_claims(token, {})

# Workload tokens are long-lived so that they survive queue wait
# times; never refresh them here. If avalidated_claims raises for
# a workload token, the outer except handles it.
if claims.get("scope") == "workload":
return response

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The early return for workload tokens skips the refresh logic (correct), but it also skips the except block below. If avalidated_claims raises for a workload token, execution falls into the outer except and the response still gets returned (with a warning log). Might be worth a comment clarifying that workload token validation errors are handled by the outer catch.

now = int(time.time())
validity = conf.getint("execution_api", "jwt_expiration_time")
refresh_when_less_than = max(int(validity * 0.20), 30)
Expand Down Expand Up @@ -311,9 +319,13 @@ class InProcessExecutionAPI:
@cached_property
def app(self):
if not self._app:
import svcs

from airflow.api_fastapi.auth.tokens import JWTGenerator
from airflow.api_fastapi.common.dagbag import create_dag_bag
from airflow.api_fastapi.execution_api.app import create_task_execution_api_app
from airflow.api_fastapi.execution_api.datamodels.token import TIToken
from airflow.api_fastapi.execution_api.deps import _container
from airflow.api_fastapi.execution_api.routes.connections import has_connection_access
from airflow.api_fastapi.execution_api.routes.variables import has_variable_access
from airflow.api_fastapi.execution_api.routes.xcoms import has_xcom_access
Expand All @@ -332,10 +344,30 @@ async def always_allow(request: Request):
)
return TIToken(id=ti_id, claims={"scope": "execution"})

# Override _container (the svcs service locator behind DepContainer).
# The default _container reads request.app.state.svcs_registry, but
# Cadwyn's versioned sub-apps don't inherit the main app's state,
# so lookups raise ServiceNotFoundError. This registry provides
# services needed by routes called during dag.test().
# Note: tokens generated by this stub are never validated since
# _jwt_bearer is overridden with always_allow in dag.test() mode.
stub_generator = JWTGenerator(
secret_key=secrets.token_urlsafe(32),
audience="in-process",
valid_for=3600,
)
registry = svcs.Registry()
registry.register_value(JWTGenerator, stub_generator)

async def _in_process_container(request: Request):
async with svcs.Container(registry) as cont:
yield cont

self._app.dependency_overrides[_jwt_bearer] = always_allow
self._app.dependency_overrides[has_connection_access] = always_allow
self._app.dependency_overrides[has_variable_access] = always_allow
self._app.dependency_overrides[has_xcom_access] = always_allow
self._app.dependency_overrides[_container] = _in_process_container

return self._app

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import attrs
import structlog
from cadwyn import VersionedAPIRouter
from fastapi import Body, HTTPException, Query, Security, status
from fastapi import Body, HTTPException, Query, Response, Security, status
from opentelemetry import trace
from opentelemetry.trace import StatusCode
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
Expand All @@ -42,6 +42,7 @@

from airflow._shared.observability.traces import override_ids
from airflow._shared.timezones import timezone
from airflow.api_fastapi.auth.tokens import JWTGenerator
from airflow.api_fastapi.common.dagbag import DagBagDep, get_latest_version_of_dag
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.types import UtcDateTime
Expand All @@ -63,6 +64,7 @@
TISuccessStatePayload,
TITerminalStatePayload,
)
from airflow.api_fastapi.execution_api.deps import DepContainer
from airflow.api_fastapi.execution_api.security import ExecutionAPIRoute, require_auth
from airflow.exceptions import TaskNotFound
from airflow.models.asset import AssetActive
Expand Down Expand Up @@ -97,6 +99,7 @@
@ti_id_router.patch(
"/{task_instance_id}/run",
status_code=status.HTTP_200_OK,
dependencies=[Security(require_auth, scopes=["token:execution", "token:workload"])],
responses={
status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"},
Expand All @@ -107,8 +110,10 @@
def ti_run(
task_instance_id: UUID,
ti_run_payload: Annotated[TIEnterRunningPayload, Body()],
response: Response,
session: SessionDep,
dag_bag: DagBagDep,
services=DepContainer,
) -> TIRunContext:
"""
Run a TaskInstance.
Expand Down Expand Up @@ -286,14 +291,24 @@ def ti_run(
if ti.next_method:
context.next_method = ti.next_method
context.next_kwargs = ti.next_kwargs

return context
except SQLAlchemyError:
log.exception("Error marking Task Instance state as running")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred"
)

try:
generator: JWTGenerator = services.get(JWTGenerator)
execution_token = generator.generate(extras={"sub": str(task_instance_id), "scope": "execution"})
except Exception:
log.exception("Failed to generate execution token for task instance %s", task_instance_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Token generation failed"
)
response.headers["Refreshed-API-Token"] = execution_token

return context


@ti_id_router.patch(
"/{task_instance_id}/state",
Expand Down
9 changes: 9 additions & 0 deletions airflow-core/src/airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2079,6 +2079,15 @@ execution_api:
type: integer
example: ~
default: "600"
jwt_workload_token_expiration_time:
description: |
Seconds until workload JWT tokens expire. These long-lived tokens are sent
with task workloads to executors and can only be used to call the /run endpoint.
Set this long enough to cover the maximum expected queue wait time.
version_added: 3.2.1
type: integer
example: ~
default: "86400"
jwt_audience:
version_added: 3.0.0
description: |
Expand Down
7 changes: 6 additions & 1 deletion airflow-core/src/airflow/executors/workloads/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,12 @@ class BaseWorkloadSchema(BaseModel):

@staticmethod
def generate_token(sub_id: str, generator: JWTGenerator | None = None) -> str:
    """
    Create the workload-scoped JWT that is shipped with a workload.

    :param sub_id: Value placed in the token's ``sub`` claim.
    :param generator: Generator used to sign the token. When it is missing
        (or otherwise falsy) no token is produced.
    :return: The signed token, or an empty string when no generator was given.
    """
    if not generator:
        return ""
    # Use the generator's dedicated workload lifetime rather than its regular
    # ``valid_for``, and tag the token with the "workload" scope.
    return generator.generate(
        extras={"sub": sub_id, "scope": "workload"},
        valid_for=generator.workload_valid_for,
    )


class BaseDagBundleWorkload(BaseWorkloadSchema, ABC):
Expand Down
28 changes: 28 additions & 0 deletions airflow-core/tests/unit/api_fastapi/auth/test_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,34 @@ def test_secret_key_with_configured_kid():
assert header["kid"] == "my-custom-kid"


def test_generate_with_custom_valid_for():
    """A per-call ``valid_for`` overrides the lifetime configured on the generator."""
    # The generator itself is configured for 60s; the call asks for 3600s.
    short_lived = JWTGenerator(secret_key="test-secret", audience="test", valid_for=60)
    encoded = short_lived.generate(extras={"sub": "user"}, valid_for=3600)

    decoded = jwt.decode(encoded, "test-secret", algorithms=["HS512"], audience="test")
    lifetime = decoded["exp"] - decoded["iat"]
    assert lifetime == 3600


def test_generate_workload_scope_via_extras():
    """Passing ``scope='workload'`` through extras yields a workload-scoped token."""
    signer = JWTGenerator(secret_key="test-secret", audience="test", valid_for=60)
    requested_extras = {"sub": "ti-123", "scope": "workload"}

    encoded = signer.generate(extras=requested_extras, valid_for=86400)
    decoded = jwt.decode(encoded, "test-secret", algorithms=["HS512"], audience="test")

    # Extras are carried into the claims, and the override drives the lifetime.
    assert decoded["sub"] == "ti-123"
    assert decoded["scope"] == "workload"
    assert decoded["exp"] - decoded["iat"] == 86400


def test_regular_token_has_no_scope():
    """A token generated without a scope in extras carries no ``scope`` claim."""
    signer = JWTGenerator(secret_key="test-secret", audience="test", valid_for=60)

    encoded = signer.generate(extras={"sub": "user"})
    decoded = jwt.decode(encoded, "test-secret", algorithms=["HS512"], audience="test")

    assert "scope" not in decoded


@pytest.fixture
def jwt_generator(ed25519_private_key: Ed25519PrivateKey):
key = ed25519_private_key
Expand Down
8 changes: 8 additions & 0 deletions airflow-core/tests/unit/api_fastapi/execution_api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@
# under the License.
from __future__ import annotations

from unittest.mock import MagicMock

import pytest
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient

from airflow.api_fastapi.app import cached_app
from airflow.api_fastapi.auth.tokens import JWTGenerator
from airflow.api_fastapi.execution_api.app import lifespan
from airflow.api_fastapi.execution_api.datamodels.token import TIToken
from airflow.api_fastapi.execution_api.security import _jwt_bearer

Expand Down Expand Up @@ -53,6 +57,10 @@ async def mock_jwt_bearer(request: Request):
exec_app.dependency_overrides[_jwt_bearer] = mock_jwt_bearer

with TestClient(app, headers={"Authorization": "Bearer fake"}) as client:
mock_generator = MagicMock(spec=JWTGenerator)
mock_generator.generate.return_value = "mock-execution-token"
lifespan.registry.register_value(JWTGenerator, mock_generator)

yield client

exec_app.dependency_overrides.pop(_jwt_bearer, None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lifespan.registry.close() is new here (no other test file does this), and the registry is shared across all tests via cached_app. Closing it could break subsequent tests that try to look up services from the same registry. The existing pattern in other test files (e.g., test_task_instances.py, test_router.py) registers values on lifespan.registry without closing it afterward. I'd drop this close() call to match what the rest of the test suite does.

Loading
Loading