From 8258a1a2b4bf7a5dbd124951f67a3b18c5ff0c78 Mon Sep 17 00:00:00 2001 From: Henry Chen Date: Tue, 21 Apr 2026 14:55:20 +0800 Subject: [PATCH 1/3] fix infinite loop for Variable.get --- .../airflow/api_fastapi/execution_api/app.py | 22 +++++- airflow-core/src/airflow/models/connection.py | 6 +- airflow-core/src/airflow/models/variable.py | 15 ++-- .../src/airflow/utils/process_context.py | 53 ++++++++++++++ .../versions/head/test_connections.py | 35 +++++++++ .../versions/head/test_variables.py | 73 +++++++++++++++++++ .../tests/unit/models/test_connection.py | 22 ++++++ 7 files changed, 215 insertions(+), 11 deletions(-) create mode 100644 airflow-core/src/airflow/utils/process_context.py diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index 88c7d23fff26a..2ba5b9fe07378 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -42,6 +42,7 @@ get_sig_validation_args, get_signing_args, ) +from airflow.utils.process_context import override_process_context if TYPE_CHECKING: import httpx @@ -348,6 +349,17 @@ def get_extra_schemas() -> dict[str, dict]: } +class _RequestScopedServerContextApp: + """Wrap an ASGI app so in-process requests behave like server-side API handling.""" + + def __init__(self, app: FastAPI) -> None: + self.app = app + + async def __call__(self, scope: Any, receive: Any, send: Any) -> None: + with override_process_context("server"): + await self.app(scope, receive, send) + + @attrs.define() class InProcessExecutionAPI: """ @@ -361,7 +373,7 @@ class InProcessExecutionAPI: _cm: AsyncExitStack | None = None @cached_property - def app(self): + def app(self) -> FastAPI: if not self._app: from airflow.api_fastapi.common.dagbag import create_dag_bag from airflow.api_fastapi.execution_api.datamodels.token import TIClaims, TIToken @@ -391,6 +403,10 @@ async def always_allow(request: Request): return self._app + @cached_property + def request_scoped_app(self) -> _RequestScopedServerContextApp: + return _RequestScopedServerContextApp(self.app) + @cached_property def transport(self) -> httpx.WSGITransport: import asyncio @@ -398,7 +414,7 @@ def transport(self) -> httpx.WSGITransport: import httpx from a2wsgi import ASGIMiddleware - middleware = ASGIMiddleware(self.app) + middleware = ASGIMiddleware(self.request_scoped_app) # https://github.com/abersheeran/a2wsgi/discussions/64 async def start_lifespan(cm: AsyncExitStack, app: FastAPI): @@ -413,4 +429,4 @@ async def start_lifespan(cm: AsyncExitStack, app: FastAPI): def atransport(self) -> httpx.ASGITransport: import httpx - return httpx.ASGITransport(app=self.app) + return httpx.ASGITransport(app=self.request_scoped_app) diff --git a/airflow-core/src/airflow/models/connection.py b/airflow-core/src/airflow/models/connection.py index 1b4b0f8f86768..1ca24c336d5ed 100644 --- a/airflow-core/src/airflow/models/connection.py +++ b/airflow-core/src/airflow/models/connection.py @@ -20,7 +20,6 @@ import json import logging import re -import sys import warnings from contextlib import suppress from json import JSONDecodeError @@ -52,6 +51,7 @@ class AirflowSecretsBackendAccessDenied(PermissionError): # type: ignore[no-red from airflow.utils.helpers import prune_dict from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.process_context import should_use_task_sdk_api_path from airflow.utils.session import NEW_SESSION, provide_session log = logging.getLogger(__name__) @@ -475,7 +475,7 @@ def get_connection_from_secrets(cls, conn_id: str, team_name: str | None = None) # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if should_use_task_sdk_api_path(): from airflow.sdk import Connection as TaskSDKConnection from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType @@ -566,7 +566,7 @@ def to_dict(self, *, prune_empty: bool = False, validate: bool = True) -> dict[s @classmethod def from_json(cls, value, conn_id=None) -> Connection: - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if should_use_task_sdk_api_path(): from airflow.sdk import Connection as TaskSDKConnection warnings.warn( diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py index b06e73cd5f50a..6346c92b8fe66 100644 --- a/airflow-core/src/airflow/models/variable.py +++ b/airflow-core/src/airflow/models/variable.py @@ -20,7 +20,6 @@ import contextlib import json import logging -import sys import warnings from typing import TYPE_CHECKING, Any @@ -49,6 +48,7 @@ class AirflowSecretsBackendAccessDenied(PermissionError): # type: ignore[no-red from airflow.secrets.metastore import MetastoreBackend from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.process_context import should_use_task_sdk_api_path from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.sqlalchemy import build_upsert_stmt, get_dialect_name @@ -120,6 +120,11 @@ def val(cls): """Get Airflow Variable from Metadata DB and decode it using the Fernet Key.""" return synonym("_val", descriptor=property(cls.get_val, cls.set_val)) + @staticmethod + def _should_use_task_sdk_api_path() -> bool: + """Return True when Variable operations should be routed through Task SDK APIs.""" + return should_use_task_sdk_api_path() + @classmethod def setdefault(cls, key, default, description=None, deserialize_json=False): """ @@ -166,7 +171,7 @@ def get( # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if cls._should_use_task_sdk_api_path(): warnings.warn( "Using Variable.get from `airflow.models` is deprecated." "Please use `get` on Variable from sdk(`airflow.sdk.Variable`) instead", @@ -226,7 +231,7 @@ def set( # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if Variable._should_use_task_sdk_api_path(): warnings.warn( "Using Variable.set from `airflow.models` is deprecated." "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead", @@ -314,7 +319,7 @@ def update( # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if Variable._should_use_task_sdk_api_path(): warnings.warn( "Using Variable.update from `airflow.models` is deprecated." "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead as it is an upsert.", @@ -380,7 +385,7 @@ def delete(key: str, team_name: str | None = None, session: Session | None = Non # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if Variable._should_use_task_sdk_api_path(): warnings.warn( "Using Variable.delete from `airflow.models` is deprecated." "Please use `delete` on Variable from sdk(`airflow.sdk.Variable`) instead", diff --git a/airflow-core/src/airflow/utils/process_context.py b/airflow-core/src/airflow/utils/process_context.py new file mode 100644 index 0000000000000..7b98abd4dac0b --- /dev/null +++ b/airflow-core/src/airflow/utils/process_context.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +import sys +from collections.abc import Generator +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Literal + +_PROCESS_CONTEXT_OVERRIDE: ContextVar[str | None] = ContextVar( + "_AIRFLOW_PROCESS_CONTEXT_OVERRIDE", + default=None, +) + + +def get_process_context() -> str | None: + """Return the current process context, preferring request-scoped overrides.""" + return _PROCESS_CONTEXT_OVERRIDE.get() or os.environ.get("_AIRFLOW_PROCESS_CONTEXT") + + +@contextmanager +def override_process_context(context: Literal["server", "client"]) -> Generator[None, None, None]: + """Temporarily override the current process context for the active execution flow.""" + token = _PROCESS_CONTEXT_OVERRIDE.set(context) + try: + yield + finally: + _PROCESS_CONTEXT_OVERRIDE.reset(token) + + +def should_use_task_sdk_api_path() -> bool: + """Return True when execution-context helpers should route through Task SDK APIs.""" + if get_process_context() == "server": + return False + + task_runner_module = sys.modules.get("airflow.sdk.execution_time.task_runner") + return bool(getattr(task_runner_module, "SUPERVISOR_COMMS", None)) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_connections.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_connections.py index a2e3cb51fab32..19e87bafd14c1 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_connections.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_connections.py @@ -17,6 +17,7 @@ from __future__ import annotations +import sys from unittest import mock import pytest @@ -105,6 +106,40 @@ def test_connection_get_from_env_var(self, client, session): "extra": '{"headers": "header"}', } + @mock.patch.dict( + "os.environ", + { + "AIRFLOW_CONN_TEST_CONN_SERVER": '{"uri": "http://root:admin@localhost:8080/https?headers=header"}', + "_AIRFLOW_PROCESS_CONTEXT": "server", + }, + ) + def test_connection_get_uses_server_path_when_supervisor_comms_exists(self, client): + fake_task_runner = mock.Mock() + fake_task_runner.SUPERVISOR_COMMS = object() + + with ( + mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": fake_task_runner}), + mock.patch( + "airflow.sdk.Connection.get", + side_effect=AssertionError( + "Execution API should not route through Task SDK Connection.get in server context" + ), + ), + ): + response = client.get("/execution/connections/test_conn_server") + + assert response.status_code == 200 + assert response.json() == { + "conn_id": "test_conn_server", + "conn_type": "http", + "host": "localhost", + "login": "root", + "password": "admin", + "schema": "https", + "port": 8080, + "extra": '{"headers": "header"}', + } + def test_connection_get_not_found(self, client): response = client.get("/execution/connections/non_existent_test_conn") diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py index f078d6c2fe06f..b5c6c6985fdd6 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py @@ -96,6 +96,28 @@ def test_variable_get_from_env_var(self, client, session): assert response.status_code == 200 assert response.json() == {"key": "key1", "value": "VALUE"} + @mock.patch.dict( + "os.environ", + {"AIRFLOW_VAR_KEY1": "VALUE", "_AIRFLOW_PROCESS_CONTEXT": "server"}, + ) + def test_variable_get_uses_server_path_when_supervisor_comms_exists(self, client): + fake_task_runner = mock.Mock() + fake_task_runner.SUPERVISOR_COMMS = object() + + with ( + mock.patch.dict("sys.modules", {"airflow.sdk.execution_time.task_runner": fake_task_runner}), + mock.patch( + "airflow.sdk.Variable.get", + side_effect=AssertionError( + "Execution API should not route through Task SDK Variable.get in server context" + ), + ), + ): + response = client.get("/execution/variables/key1") + + assert response.status_code == 200 + assert response.json() == {"key": "key1", "value": "VALUE"} + @pytest.mark.parametrize( "key", [ @@ -158,6 +180,31 @@ def test_should_create_variable(self, client, key, payload, session): if "description" in payload: assert var_from_db.description == payload["description"] + @mock.patch.dict( + "os.environ", + {"_AIRFLOW_PROCESS_CONTEXT": "server"}, + ) + def test_variable_put_uses_server_path_when_supervisor_comms_exists(self, client, session): + fake_task_runner = mock.Mock() + fake_task_runner.SUPERVISOR_COMMS = object() + + with ( + mock.patch.dict("sys.modules", {"airflow.sdk.execution_time.task_runner": fake_task_runner}), + mock.patch( + "airflow.sdk.Variable.set", + side_effect=AssertionError( + "Execution API should not route through Task SDK Variable.set in server context" + ), + ), + ): + response = client.put("/execution/variables/var_server_only", json={"value": "server_value"}) + + assert response.status_code == 201 + assert response.json()["message"] == "Variable successfully set" + var_from_db = session.scalars(select(Variable).where(Variable.key == "var_server_only")).first() + assert var_from_db is not None + assert var_from_db.val == "server_value" + @pytest.mark.parametrize( ("key", "payload", "error_type"), [ @@ -342,3 +389,29 @@ def test_should_not_delete_variable(self, client, session): vars = session.scalars(select(Variable)).all() assert len(vars) == 1 + + @mock.patch.dict( + "os.environ", + {"_AIRFLOW_PROCESS_CONTEXT": "server"}, + ) + def test_variable_delete_uses_server_path_when_supervisor_comms_exists(self, client, session): + Variable.set(key="var_server_delete", value="to_delete", session=session) + session.commit() + + fake_task_runner = mock.Mock() + fake_task_runner.SUPERVISOR_COMMS = object() + + with ( + mock.patch.dict("sys.modules", {"airflow.sdk.execution_time.task_runner": fake_task_runner}), + mock.patch( + "airflow.sdk.Variable.delete", + side_effect=AssertionError( + "Execution API should not route through Task SDK Variable.delete in server context" + ), + ), + ): + response = client.delete("/execution/variables/var_server_delete") + + assert response.status_code == 204 + session.expire_all() + assert session.scalar(select(Variable).where(Variable.key == "var_server_delete")) is None diff --git a/airflow-core/tests/unit/models/test_connection.py b/airflow-core/tests/unit/models/test_connection.py index 94cabe5e4daf4..93c823bd3dcf0 100644 --- a/airflow-core/tests/unit/models/test_connection.py +++ b/airflow-core/tests/unit/models/test_connection.py @@ -455,6 +455,28 @@ def test_get_connection_from_secrets_task_sdk_not_found(self, mock_task_sdk_conn with pytest.raises(AirflowNotFoundException): Connection.get_connection_from_secrets("test_conn") + @mock.patch.dict("os.environ", {"_AIRFLOW_PROCESS_CONTEXT": "server"}) + def test_connection_from_json_uses_core_path_when_server_context(self): + """Server context should prefer core Connection.from_json even if comms exist.""" + fake_task_runner = mock.MagicMock() + fake_task_runner.SUPERVISOR_COMMS = True + + with ( + mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": fake_task_runner}), + mock.patch( + "airflow.sdk.Connection.from_json", + side_effect=AssertionError( + "Connection.from_json should not route through Task SDK in server context" + ), + ), + ): + result = Connection.from_json('{"conn_type": "http", "host": "localhost"}', conn_id="test_conn") + + assert isinstance(result, Connection) + assert result.conn_id == "test_conn" + assert result.conn_type == "http" + assert result.host == "localhost" + @mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": None}) @mock.patch("airflow.sdk.Connection") @mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connection") From 5d0b0204eaa07e31ad5e03df39c973e19efcf1fb Mon Sep 17 00:00:00 2001 From: Henry Chen Date: Tue, 21 Apr 2026 23:41:02 +0800 Subject: [PATCH 2/3] fix ci error --- airflow-core/src/airflow/models/connection.py | 2 +- airflow-core/src/airflow/models/variable.py | 15 +++++---------- .../src/airflow/{utils => }/process_context.py | 6 ++++++ airflow-core/tests/unit/models/test_connection.py | 6 ++++++ 4 files changed, 18 insertions(+), 11 deletions(-) rename airflow-core/src/airflow/{utils => }/process_context.py (94%) diff --git a/airflow-core/src/airflow/models/connection.py b/airflow-core/src/airflow/models/connection.py index 1ca24c336d5ed..b1ec6beaee562 100644 --- a/airflow-core/src/airflow/models/connection.py +++ b/airflow-core/src/airflow/models/connection.py @@ -49,9 +49,9 @@ class AirflowSecretsBackendAccessDenied(PermissionError): # type: ignore[no-red """Compat stub — never raised by task-sdk <1.2.2.""" +from airflow.process_context import should_use_task_sdk_api_path from airflow.utils.helpers import prune_dict from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.process_context import should_use_task_sdk_api_path from airflow.utils.session import NEW_SESSION, provide_session log = logging.getLogger(__name__) diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py index 6346c92b8fe66..f29bdd55fc7eb 100644 --- a/airflow-core/src/airflow/models/variable.py +++ b/airflow-core/src/airflow/models/variable.py @@ -46,9 +46,9 @@ class AirflowSecretsBackendAccessDenied(PermissionError): # type: ignore[no-red """Compat stub — never raised by task-sdk <1.2.2.""" +from airflow.process_context import should_use_task_sdk_api_path from airflow.secrets.metastore import MetastoreBackend from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.process_context import should_use_task_sdk_api_path from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.sqlalchemy import build_upsert_stmt, get_dialect_name @@ -120,11 +120,6 @@ def val(cls): """Get Airflow Variable from Metadata DB and decode it using the Fernet Key.""" return synonym("_val", descriptor=property(cls.get_val, cls.set_val)) - @staticmethod - def _should_use_task_sdk_api_path() -> bool: - """Return True when Variable operations should be routed through Task SDK APIs.""" - return should_use_task_sdk_api_path() - @classmethod def setdefault(cls, key, default, description=None, deserialize_json=False): """ @@ -171,7 +166,7 @@ def get( # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if cls._should_use_task_sdk_api_path(): + if should_use_task_sdk_api_path(): warnings.warn( "Using Variable.get from `airflow.models` is deprecated." "Please use `get` on Variable from sdk(`airflow.sdk.Variable`) instead", @@ -231,7 +226,7 @@ def set( # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if Variable._should_use_task_sdk_api_path(): + if should_use_task_sdk_api_path(): warnings.warn( "Using Variable.set from `airflow.models` is deprecated." "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead", @@ -319,7 +314,7 @@ def update( # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if Variable._should_use_task_sdk_api_path(): + if should_use_task_sdk_api_path(): warnings.warn( "Using Variable.update from `airflow.models` is deprecated." "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead as it is an upsert.", @@ -385,7 +380,7 @@ def delete(key: str, team_name: str | None = None, session: Session | None = Non # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if Variable._should_use_task_sdk_api_path(): + if should_use_task_sdk_api_path(): warnings.warn( "Using Variable.delete from `airflow.models` is deprecated." "Please use `delete` on Variable from sdk(`airflow.sdk.Variable`) instead", diff --git a/airflow-core/src/airflow/utils/process_context.py b/airflow-core/src/airflow/process_context.py similarity index 94% rename from airflow-core/src/airflow/utils/process_context.py rename to airflow-core/src/airflow/process_context.py index 7b98abd4dac0b..1655d6e81fbb8 100644 --- a/airflow-core/src/airflow/utils/process_context.py +++ b/airflow-core/src/airflow/process_context.py @@ -23,6 +23,12 @@ from contextvars import ContextVar from typing import Literal +__all__ = [ + "get_process_context", + "override_process_context", + "should_use_task_sdk_api_path", +] + _PROCESS_CONTEXT_OVERRIDE: ContextVar[str | None] = ContextVar( "_AIRFLOW_PROCESS_CONTEXT_OVERRIDE", default=None, diff --git a/airflow-core/tests/unit/models/test_connection.py b/airflow-core/tests/unit/models/test_connection.py index 93c823bd3dcf0..1fd0d755cae7d 100644 --- a/airflow-core/tests/unit/models/test_connection.py +++ b/airflow-core/tests/unit/models/test_connection.py @@ -49,6 +49,12 @@ def clear_fernet_cache(self): yield get_fernet.cache_clear() + @pytest.fixture(autouse=True) + def clear_process_context(self, monkeypatch): + """Isolate tests from process-wide execution context left behind by other imports.""" + monkeypatch.delenv("_AIRFLOW_PROCESS_CONTEXT", raising=False) + monkeypatch.delitem(sys.modules, "airflow.sdk.execution_time.task_runner", raising=False) + @pytest.mark.parametrize( ( "uri", From 348bc84472a635665044b824182c9ab1e73253a4 Mon Sep 17 00:00:00 2001 From: henry3260 Date: Thu, 7 May 2026 19:11:18 +0800 Subject: [PATCH 3/3] add unit tests to make sure expected behavior --- .../airflow/api_fastapi/execution_api/app.py | 17 +++-- .../airflow/sdk/execution_time/supervisor.py | 6 +- task-sdk/src/airflow/sdk/log.py | 13 +++- .../task_sdk/definitions/test_variables.py | 76 ++++++++++++++++++- .../execution_time/test_supervisor.py | 14 +++- 5 files changed, 114 insertions(+), 12 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index 2ba5b9fe07378..b403d1598fd73 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -42,10 +42,12 @@ get_sig_validation_args, get_signing_args, ) -from airflow.utils.process_context import override_process_context +from airflow.process_context import override_process_context if TYPE_CHECKING: import httpx + from a2wsgi.asgi_typing import ASGIApp as A2WSGIApp + from starlette.types import ASGIApp, Receive, Scope, Send import structlog from structlog.contextvars import bind_contextvars @@ -355,7 +357,7 @@ class _RequestScopedServerContextApp: def __init__(self, app: FastAPI) -> None: self.app = app - async def __call__(self, scope: Any, receive: Any, send: Any) -> None: + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: with override_process_context("server"): await self.app(scope, receive, send) @@ -369,6 +371,7 @@ class InProcessExecutionAPI: needed so that we can use the sync httpx client """ + request_scoped_server_context: bool = attrs.field(default=False, kw_only=True) _app: FastAPI | None = None _cm: AsyncExitStack | None = None @@ -404,8 +407,10 @@ async def always_allow(request: Request): return self._app @cached_property - def request_scoped_app(self) -> _RequestScopedServerContextApp: - return _RequestScopedServerContextApp(self.app) + def asgi_app(self) -> ASGIApp: + if self.request_scoped_server_context: + return _RequestScopedServerContextApp(self.app) + return self.app @cached_property def transport(self) -> httpx.WSGITransport: @@ -414,7 +419,7 @@ def transport(self) -> httpx.WSGITransport: import httpx from a2wsgi import ASGIMiddleware - middleware = ASGIMiddleware(self.request_scoped_app) + middleware = ASGIMiddleware(cast("A2WSGIApp", self.asgi_app)) # https://github.com/abersheeran/a2wsgi/discussions/64 async def start_lifespan(cm: AsyncExitStack, app: FastAPI): @@ -429,4 +434,4 @@ async def start_lifespan(cm: AsyncExitStack, app: FastAPI): def atransport(self) -> httpx.ASGITransport: import httpx - return httpx.ASGITransport(app=self.request_scoped_app) + return httpx.ASGITransport(app=self.asgi_app) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 6527e651041b8..bd3aabffd27af 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1921,10 +1921,10 @@ def _send_new_log_fd(self, req_id: int) -> None: @functools.lru_cache(maxsize=1) -def in_process_api_server(): +def in_process_api_server(*, request_scoped_server_context: bool = False): from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI - api = InProcessExecutionAPI() + api = InProcessExecutionAPI(request_scoped_server_context=request_scoped_server_context) return api @@ -2082,7 +2082,7 @@ def start( # type: ignore[override] @staticmethod def _api_client(dag=None): - api = in_process_api_server() + api = in_process_api_server(request_scoped_server_context=True) from airflow.api_fastapi.common.dagbag import dag_bag_from_app if dag is not None: diff --git a/task-sdk/src/airflow/sdk/log.py b/task-sdk/src/airflow/sdk/log.py index b843173fd65dc..f8724d8479544 100644 --- a/task-sdk/src/airflow/sdk/log.py +++ b/task-sdk/src/airflow/sdk/log.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +import os from functools import cache from pathlib import Path from typing import TYPE_CHECKING, Any, BinaryIO, TextIO @@ -265,6 +266,15 @@ def upload_to_remote(logger: FilteringBoundLogger, ti: RuntimeTI | None = None): ) +def _is_supervisor_comms_ready_for_mask_secret(comms: Any) -> bool: + """Return whether mask notifications can be sent without using stale virtualenv comms.""" + if not comms: + return False + if os.environ.get("PYTHON_OPERATORS_VIRTUAL_ENV_MODE"): + return getattr(comms, "socket", None) is not None + return True + + def mask_secret(secret: JsonValue, name: str | None = None) -> None: """ Mask a secret in both task process and supervisor process. @@ -284,7 +294,8 @@ def mask_secret(secret: JsonValue, name: str | None = None) -> None: from airflow.sdk.execution_time import task_runner from airflow.sdk.execution_time.comms import MaskSecret - if comms := getattr(task_runner, "SUPERVISOR_COMMS", None): + if _is_supervisor_comms_ready_for_mask_secret(getattr(task_runner, "SUPERVISOR_COMMS", None)): + comms = task_runner.SUPERVISOR_COMMS comms.send(MaskSecret(value=secret, name=name)) diff --git a/task-sdk/tests/task_sdk/definitions/test_variables.py b/task-sdk/tests/task_sdk/definitions/test_variables.py index 6e94ccf503f8c..2047bace99e35 100644 --- a/task-sdk/tests/task_sdk/definitions/test_variables.py +++ b/task-sdk/tests/task_sdk/definitions/test_variables.py @@ -18,6 +18,8 @@ from __future__ import annotations import json +import queue +import threading from unittest import mock from unittest.mock import patch @@ -25,7 +27,13 @@ from airflow.sdk import Variable from airflow.sdk.configuration import initialize_secrets_backends -from airflow.sdk.execution_time.comms import GetVariableKeys, PutVariable, VariableKeysResult, VariableResult +from airflow.sdk.execution_time.comms import ( + GetVariableKeys, + MaskSecret, + PutVariable, + VariableKeysResult, + VariableResult, +) from airflow.sdk.execution_time.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS from tests_common.test_utils.config import conf_vars @@ -263,6 +271,72 @@ def test_get_variable_env_var(self, mock_env_get, mock_supervisor_comms): Variable.get(key="fake_var_key") mock_env_get.assert_called_once_with(key="fake_var_key") + def test_get_variable_env_var_in_virtualenv_does_not_wait_for_supervisor_comms(self, monkeypatch): + """Regression test for Variable.get() hanging in PythonVirtualenvOperator child processes.""" + from airflow.sdk.execution_time import task_runner + from airflow.sdk.execution_time.cache import SecretCache + + events = queue.Queue() + release_supervisor_comms = threading.Event() + + class BlockingSupervisorComms: + def send(self, *args, **kwargs): + events.put("supervisor_comms") + release_supervisor_comms.wait(timeout=5) + + SecretCache.reset() + monkeypatch.setenv("PYTHON_OPERATORS_VIRTUAL_ENV_MODE", "1") + monkeypatch.setenv("AIRFLOW_VAR_DEMO_MESSAGE", "hello from env") + monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", BlockingSupervisorComms(), raising=False) + + result = {} + error = {} + + def get_variable(): + try: + result["value"] = Variable.get(key="DEMO_MESSAGE") + except BaseException as exc: + error["exception"] = exc + finally: + events.put("done") + + thread = threading.Thread(target=get_variable, daemon=True) + thread.start() + first_event = events.get(timeout=5) + + release_supervisor_comms.set() + thread.join(timeout=5) + + assert first_event == "done", ( + "Variable.get() should not wait for supervisor comms when an env var backend returns the value " + "inside a PythonVirtualenvOperator child process." + ) + assert error == {} + assert result == {"value": "hello from env"} + + def test_get_variable_env_var_in_virtualenv_notifies_reinitialized_supervisor_comms(self, monkeypatch): + from airflow.sdk.execution_time import task_runner + from airflow.sdk.execution_time.cache import SecretCache + + class ReinitializedSupervisorComms: + socket = object() + + def __init__(self): + self.sent_messages = [] + + def send(self, msg): + self.sent_messages.append(msg) + + comms = ReinitializedSupervisorComms() + + SecretCache.reset() + monkeypatch.setenv("PYTHON_OPERATORS_VIRTUAL_ENV_MODE", "1") + monkeypatch.setenv("AIRFLOW_VAR_DEMO_MESSAGE", "hello from env") + monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False) + + assert Variable.get(key="DEMO_MESSAGE") == "hello from env" + assert comms.sent_messages == [MaskSecret(value="hello from env", name="DEMO_MESSAGE")] + @conf_vars( { ("workers", "secrets_backend"): "airflow.secrets.local_filesystem.LocalFilesystemBackend", diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 62a9d00fdf286..4b0c22e267276 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -3481,6 +3481,18 @@ def execute(self, context: Context): assert isinstance(result.error, _Failure) assert isinstance(collected[0], _Failure) + def test_api_client_uses_request_scoped_server_context(self): + api = mock.Mock() + api.transport = httpx.MockTransport(lambda request: httpx.Response(status_code=200, json={})) + + with patch( + "airflow.sdk.execution_time.supervisor.in_process_api_server", return_value=api + ) as factory: + client = InProcessTestSupervisor._api_client() + + factory.assert_called_once_with(request_scoped_server_context=True) + client.close() + class TestInProcessClient: def test_no_retries(self): @@ -4029,7 +4041,7 @@ def test_api_client_clears_dag_bag_override_when_dag_is_none(): # First call with a dag sets the override mock_dag = MagicMock() InProcessTestSupervisor._api_client(dag=mock_dag) - api = in_process_api_server() + api = in_process_api_server(request_scoped_server_context=True) assert dag_bag_from_app in api.app.dependency_overrides # Second call with dag=None should remove it