class _RequestScopedServerContextApp:
    """
    ASGI callable that forwards to another app under the ``"server"`` process context.

    Every call is delegated to the wrapped application from inside
    ``override_process_context("server")``, so the override is active only
    while that single call is being handled.
    """

    def __init__(self, app: FastAPI) -> None:
        """Remember the application that calls are delegated to."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Handle one ASGI call by delegating to the wrapped app with the override active."""
        server_context = override_process_context("server")
        with server_context:
            await self.app(scope, receive, send)
name="InProcessExecutionAPI-loop", daemon=True) thread.start() - middleware = ASGIMiddleware(self.app, loop=loop) + middleware = ASGIMiddleware(cast("A2WSGIApp", self.asgi_app), loop=loop) # https://github.com/abersheeran/a2wsgi/discussions/64 async def start_lifespan(cm: AsyncExitStack, app: FastAPI): @@ -447,4 +468,4 @@ async def start_lifespan(cm: AsyncExitStack, app: FastAPI): def atransport(self) -> httpx.ASGITransport: import httpx - return httpx.ASGITransport(app=self.app) + return httpx.ASGITransport(app=self.asgi_app) diff --git a/airflow-core/src/airflow/models/connection.py b/airflow-core/src/airflow/models/connection.py index 1b4b0f8f86768..b1ec6beaee562 100644 --- a/airflow-core/src/airflow/models/connection.py +++ b/airflow-core/src/airflow/models/connection.py @@ -20,7 +20,6 @@ import json import logging import re -import sys import warnings from contextlib import suppress from json import JSONDecodeError @@ -50,6 +49,7 @@ class AirflowSecretsBackendAccessDenied(PermissionError): # type: ignore[no-red """Compat stub — never raised by task-sdk <1.2.2.""" +from airflow.process_context import should_use_task_sdk_api_path from airflow.utils.helpers import prune_dict from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session @@ -475,7 +475,7 @@ def get_connection_from_secrets(cls, conn_id: str, team_name: str | None = None) # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if should_use_task_sdk_api_path(): from airflow.sdk import Connection as TaskSDKConnection from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType @@ -566,7 +566,7 @@ def to_dict(self, *, prune_empty: bool = False, validate: bool = True) -> dict[s @classmethod def from_json(cls, value, conn_id=None) -> Connection: - if 
hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if should_use_task_sdk_api_path(): from airflow.sdk import Connection as TaskSDKConnection warnings.warn( diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py index b06e73cd5f50a..f29bdd55fc7eb 100644 --- a/airflow-core/src/airflow/models/variable.py +++ b/airflow-core/src/airflow/models/variable.py @@ -20,7 +20,6 @@ import contextlib import json import logging -import sys import warnings from typing import TYPE_CHECKING, Any @@ -47,6 +46,7 @@ class AirflowSecretsBackendAccessDenied(PermissionError): # type: ignore[no-red """Compat stub — never raised by task-sdk <1.2.2.""" +from airflow.process_context import should_use_task_sdk_api_path from airflow.secrets.metastore import MetastoreBackend from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, create_session, provide_session @@ -166,7 +166,7 @@ def get( # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if should_use_task_sdk_api_path(): warnings.warn( "Using Variable.get from `airflow.models` is deprecated." "Please use `get` on Variable from sdk(`airflow.sdk.Variable`) instead", @@ -226,7 +226,7 @@ def set( # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if should_use_task_sdk_api_path(): warnings.warn( "Using Variable.set from `airflow.models` is deprecated." 
"Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead", @@ -314,7 +314,7 @@ def update( # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if should_use_task_sdk_api_path(): warnings.warn( "Using Variable.update from `airflow.models` is deprecated." "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead as it is an upsert.", @@ -380,7 +380,7 @@ def delete(key: str, team_name: str | None = None, session: Session | None = Non # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if should_use_task_sdk_api_path(): warnings.warn( "Using Variable.delete from `airflow.models` is deprecated." "Please use `delete` on Variable from sdk(`airflow.sdk.Variable`) instead", diff --git a/airflow-core/src/airflow/process_context.py b/airflow-core/src/airflow/process_context.py new file mode 100644 index 0000000000000..1655d6e81fbb8 --- /dev/null +++ b/airflow-core/src/airflow/process_context.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +import sys +from collections.abc import Generator +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Literal + +__all__ = [ + "get_process_context", + "override_process_context", + "should_use_task_sdk_api_path", +] + +_PROCESS_CONTEXT_OVERRIDE: ContextVar[str | None] = ContextVar( + "_AIRFLOW_PROCESS_CONTEXT_OVERRIDE", + default=None, +) + + +def get_process_context() -> str | None: + """Return the current process context, preferring request-scoped overrides.""" + return _PROCESS_CONTEXT_OVERRIDE.get() or os.environ.get("_AIRFLOW_PROCESS_CONTEXT") + + +@contextmanager +def override_process_context(context: Literal["server", "client"]) -> Generator[None, None, None]: + """Temporarily override the current process context for the active execution flow.""" + token = _PROCESS_CONTEXT_OVERRIDE.set(context) + try: + yield + finally: + _PROCESS_CONTEXT_OVERRIDE.reset(token) + + +def should_use_task_sdk_api_path() -> bool: + """Return True when execution-context helpers should route through Task SDK APIs.""" + if get_process_context() == "server": + return False + + task_runner_module = sys.modules.get("airflow.sdk.execution_time.task_runner") + return bool(getattr(task_runner_module, "SUPERVISOR_COMMS", None)) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_connections.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_connections.py index a2e3cb51fab32..19e87bafd14c1 100644 --- 
@mock.patch.dict(
    "os.environ",
    {
        "AIRFLOW_CONN_TEST_CONN_SERVER": '{"uri": "http://root:admin@localhost:8080/https?headers=header"}',
        "_AIRFLOW_PROCESS_CONTEXT": "server",
    },
)
def test_connection_get_uses_server_path_when_supervisor_comms_exists(self, client):
    """With the server context set, a present SUPERVISOR_COMMS must not divert to the Task SDK."""
    # Stand-in for the task runner module, exposing a SUPERVISOR_COMMS attribute.
    stub_task_runner = mock.Mock(SUPERVISOR_COMMS=object())
    sdk_guard = AssertionError(
        "Execution API should not route through Task SDK Connection.get in server context"
    )
    expected_payload = {
        "conn_id": "test_conn_server",
        "conn_type": "http",
        "host": "localhost",
        "login": "root",
        "password": "admin",
        "schema": "https",
        "port": 8080,
        "extra": '{"headers": "header"}',
    }

    with mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": stub_task_runner}):
        # Any call into the SDK accessor fails the test loudly.
        with mock.patch("airflow.sdk.Connection.get", side_effect=sdk_guard):
            response = client.get("/execution/connections/test_conn_server")

    assert response.status_code == 200
    assert response.json() == expected_payload
@mock.patch.dict(
    "os.environ",
    {"AIRFLOW_VAR_KEY1": "VALUE", "_AIRFLOW_PROCESS_CONTEXT": "server"},
)
def test_variable_get_uses_server_path_when_supervisor_comms_exists(self, client):
    """Reading a variable in server context must not go through the Task SDK accessor."""
    # Stand-in for the task runner module, exposing a SUPERVISOR_COMMS attribute.
    stub_task_runner = mock.Mock(SUPERVISOR_COMMS=object())
    sdk_guard = AssertionError(
        "Execution API should not route through Task SDK Variable.get in server context"
    )

    with mock.patch.dict("sys.modules", {"airflow.sdk.execution_time.task_runner": stub_task_runner}):
        with mock.patch("airflow.sdk.Variable.get", side_effect=sdk_guard):
            response = client.get("/execution/variables/key1")

    assert response.status_code == 200
    assert response.json() == {"key": "key1", "value": "VALUE"}


@mock.patch.dict(
    "os.environ",
    {"_AIRFLOW_PROCESS_CONTEXT": "server"},
)
def test_variable_put_uses_server_path_when_supervisor_comms_exists(self, client, session):
    """Writing a variable in server context must persist it without calling the Task SDK."""
    stub_task_runner = mock.Mock(SUPERVISOR_COMMS=object())
    sdk_guard = AssertionError(
        "Execution API should not route through Task SDK Variable.set in server context"
    )

    with mock.patch.dict("sys.modules", {"airflow.sdk.execution_time.task_runner": stub_task_runner}):
        with mock.patch("airflow.sdk.Variable.set", side_effect=sdk_guard):
            response = client.put("/execution/variables/var_server_only", json={"value": "server_value"})

    assert response.status_code == 201
    assert response.json()["message"] == "Variable successfully set"

    # The row must be visible through the test session after the request.
    stored_query = select(Variable).where(Variable.key == "var_server_only")
    var_from_db = session.scalars(stored_query).first()
    assert var_from_db is not None
    assert var_from_db.val == "server_value"
@mock.patch.dict(
    "os.environ",
    {"_AIRFLOW_PROCESS_CONTEXT": "server"},
)
def test_variable_delete_uses_server_path_when_supervisor_comms_exists(self, client, session):
    """Deleting a variable in server context must remove the row without calling the Task SDK."""
    # Seed the row that the request is expected to delete.
    Variable.set(key="var_server_delete", value="to_delete", session=session)
    session.commit()

    stub_task_runner = mock.Mock(SUPERVISOR_COMMS=object())
    sdk_guard = AssertionError(
        "Execution API should not route through Task SDK Variable.delete in server context"
    )

    with mock.patch.dict("sys.modules", {"airflow.sdk.execution_time.task_runner": stub_task_runner}):
        with mock.patch("airflow.sdk.Variable.delete", side_effect=sdk_guard):
            response = client.delete("/execution/variables/var_server_delete")

    assert response.status_code == 204

    # Drop cached ORM state before re-querying for the row.
    session.expire_all()
    remaining = session.scalar(select(Variable).where(Variable.key == "var_server_delete"))
    assert remaining is None
@functools.lru_cache(maxsize=2)
def _cached_in_process_api_server(request_scoped_server_context: bool):
    """Build the in-process Execution API once per value of the flag."""
    # Imported lazily, as in the original function.
    from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI

    return InProcessExecutionAPI(request_scoped_server_context=request_scoped_server_context)


def in_process_api_server(*, request_scoped_server_context: bool = False):
    """
    Return the shared ``InProcessExecutionAPI`` instance for the requested mode.

    One instance is kept per value of ``request_scoped_server_context``, so callers
    that use different modes do not evict each other's instance (a single-slot
    cache would rebuild the API, and drop state held on it, on every alternation).
    The flag is normalized before it reaches the cache, so ``f()`` and
    ``f(request_scoped_server_context=False)`` return the same object.

    :param request_scoped_server_context: Passed through to ``InProcessExecutionAPI``.
    :return: The cached ``InProcessExecutionAPI`` for that flag value.
    """
    return _cached_in_process_api_server(bool(request_scoped_server_context))


# Keep the cache-management API that the ``lru_cache``-decorated original exposed.
in_process_api_server.cache_clear = _cached_in_process_api_server.cache_clear  # type: ignore[attr-defined]
in_process_api_server.cache_info = _cached_in_process_api_server.cache_info  # type: ignore[attr-defined]
def test_api_client_uses_request_scoped_server_context(self):
    """``_api_client`` must request the in-process API with request-scoped server context enabled."""

    def _always_ok(request):
        # Canned reply for any request sent through the stub transport.
        return httpx.Response(status_code=200, json={})

    stub_api = mock.Mock(transport=httpx.MockTransport(_always_ok))

    target = "airflow.sdk.execution_time.supervisor.in_process_api_server"
    with patch(target, return_value=stub_api) as factory:
        client = InProcessTestSupervisor._api_client()

    factory.assert_called_once_with(request_scoped_server_context=True)
    client.close()