-
Notifications
You must be signed in to change notification settings - Fork 16.8k
Two-token mechanism for task execution to prevent token expiration while tasks wait in executor queues #60108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
48be6b9
7f6065c
8d536d4
03c53bc
b377ced
fcbe2c7
bb96119
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |
| from __future__ import annotations | ||
|
|
||
| import json | ||
| import secrets | ||
| import time | ||
| from contextlib import AsyncExitStack | ||
| from functools import cached_property | ||
|
|
@@ -76,6 +77,7 @@ def _jwt_generator(): | |
|
|
||
| generator = JWTGenerator( | ||
| valid_for=conf.getint("execution_api", "jwt_expiration_time"), | ||
| # workload_valid_for uses the attrs default factory which reads the same config key | ||
| audience=conf.get_mandatory_list_value("execution_api", "jwt_audience")[0], | ||
| issuer=conf.get("api_auth", "jwt_issuer", fallback=None), | ||
| # Since this one is used across components/server, there is no point trying to generate one, error | ||
|
|
@@ -142,6 +144,12 @@ async def dispatch(self, request: Request, call_next): | |
| validator: JWTValidator = await services.aget(JWTValidator) | ||
| claims = await validator.avalidated_claims(token, {}) | ||
|
|
||
| # Workload tokens are long-lived and meant to survive queue | ||
| # wait times so avoid refreshing them. If avalidated_claims | ||
| # raises for a workload token, the outer except handles it. | ||
| if claims.get("scope") == "workload": | ||
| return response | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The early return for workload tokens skips the refresh logic (correct), but it also skips the |
||
| now = int(time.time()) | ||
| validity = conf.getint("execution_api", "jwt_expiration_time") | ||
| refresh_when_less_than = max(int(validity * 0.20), 30) | ||
|
|
@@ -311,9 +319,13 @@ class InProcessExecutionAPI: | |
| @cached_property | ||
| def app(self): | ||
| if not self._app: | ||
| import svcs | ||
|
|
||
| from airflow.api_fastapi.auth.tokens import JWTGenerator | ||
| from airflow.api_fastapi.common.dagbag import create_dag_bag | ||
| from airflow.api_fastapi.execution_api.app import create_task_execution_api_app | ||
| from airflow.api_fastapi.execution_api.datamodels.token import TIToken | ||
| from airflow.api_fastapi.execution_api.deps import _container | ||
| from airflow.api_fastapi.execution_api.routes.connections import has_connection_access | ||
| from airflow.api_fastapi.execution_api.routes.variables import has_variable_access | ||
| from airflow.api_fastapi.execution_api.routes.xcoms import has_xcom_access | ||
|
|
@@ -332,10 +344,30 @@ async def always_allow(request: Request): | |
| ) | ||
| return TIToken(id=ti_id, claims={"scope": "execution"}) | ||
|
|
||
| # Override _container (the svcs service locator behind DepContainer). | ||
| # The default _container reads request.app.state.svcs_registry, but | ||
| # Cadwyn's versioned sub-apps don't inherit the main app's state, | ||
| # so lookups raise ServiceNotFoundError. This registry provides | ||
anishgirianish marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # services needed by routes called during dag.test(). | ||
| # Note: tokens generated by this stub are never validated since | ||
| # _jwt_bearer is overridden with always_allow in dag.test() mode. | ||
| stub_generator = JWTGenerator( | ||
| secret_key=secrets.token_urlsafe(32), | ||
| audience="in-process", | ||
| valid_for=3600, | ||
| ) | ||
| registry = svcs.Registry() | ||
| registry.register_value(JWTGenerator, stub_generator) | ||
|
|
||
| async def _in_process_container(request: Request): | ||
| async with svcs.Container(registry) as cont: | ||
| yield cont | ||
|
|
||
| self._app.dependency_overrides[_jwt_bearer] = always_allow | ||
| self._app.dependency_overrides[has_connection_access] = always_allow | ||
| self._app.dependency_overrides[has_variable_access] = always_allow | ||
| self._app.dependency_overrides[has_xcom_access] = always_allow | ||
| self._app.dependency_overrides[_container] = _in_process_container | ||
|
|
||
| return self._app | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,11 +16,15 @@ | |
| # under the License. | ||
| from __future__ import annotations | ||
|
|
||
| from unittest.mock import MagicMock | ||
|
|
||
| import pytest | ||
| from fastapi import FastAPI, Request | ||
| from fastapi.testclient import TestClient | ||
|
|
||
| from airflow.api_fastapi.app import cached_app | ||
| from airflow.api_fastapi.auth.tokens import JWTGenerator | ||
| from airflow.api_fastapi.execution_api.app import lifespan | ||
| from airflow.api_fastapi.execution_api.datamodels.token import TIToken | ||
| from airflow.api_fastapi.execution_api.security import _jwt_bearer | ||
|
|
||
|
|
@@ -53,6 +57,10 @@ async def mock_jwt_bearer(request: Request): | |
| exec_app.dependency_overrides[_jwt_bearer] = mock_jwt_bearer | ||
|
|
||
| with TestClient(app, headers={"Authorization": "Bearer fake"}) as client: | ||
| mock_generator = MagicMock(spec=JWTGenerator) | ||
| mock_generator.generate.return_value = "mock-execution-token" | ||
| lifespan.registry.register_value(JWTGenerator, mock_generator) | ||
anishgirianish marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| yield client | ||
|
|
||
| exec_app.dependency_overrides.pop(_jwt_bearer, None) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
workload_valid_fordefault reads from config via_conf_factory, and_jwt_generator()inapp.pyalso reads the same config key and passes it explicitly. The explicit kwarg takes precedence, so the default factory never runs in production. Having two code paths that reference the same config key is easy to get out of sync -- consider dropping the attrs default (make it required likevalid_for) and always passing it explicitly, or drop the explicit kwarg in_jwt_generator()and let the default handle it.