Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions airflow-core/src/airflow/api_fastapi/execution_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,12 @@
get_sig_validation_args,
get_signing_args,
)
from airflow.process_context import override_process_context

if TYPE_CHECKING:
import httpx
from a2wsgi.asgi_typing import ASGIApp as A2WSGIApp
from starlette.types import ASGIApp, Receive, Scope, Send

import structlog
from structlog.contextvars import bind_contextvars
Expand Down Expand Up @@ -348,6 +351,17 @@ def get_extra_schemas() -> dict[str, dict]:
}


class _RequestScopedServerContextApp:
"""Wrap an ASGI app so in-process requests behave like server-side API handling."""

def __init__(self, app: FastAPI) -> None:
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
with override_process_context("server"):
await self.app(scope, receive, send)
Comment thread
henry3260 marked this conversation as resolved.


@attrs.define()
class InProcessExecutionAPI:
"""
Expand All @@ -357,11 +371,12 @@ class InProcessExecutionAPI:
needed so that we can use the sync httpx client
"""

request_scoped_server_context: bool = attrs.field(default=False, kw_only=True)
_app: FastAPI | None = None
_cm: AsyncExitStack | None = None

@cached_property
def app(self):
def app(self) -> FastAPI:
if not self._app:
from airflow.api_fastapi.common.dagbag import create_dag_bag
from airflow.api_fastapi.execution_api.datamodels.token import TIClaims, TIToken
Expand Down Expand Up @@ -391,14 +406,20 @@ async def always_allow(request: Request):

return self._app

@cached_property
def asgi_app(self) -> ASGIApp:
if self.request_scoped_server_context:
return _RequestScopedServerContextApp(self.app)
return self.app
Comment on lines 407 to +413

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems no need to introduce another cached_property, making change at end of existing app cached_property would be sufficient.

Suggested change
return self._app
@cached_property
def asgi_app(self) -> ASGIApp:
if self.request_scoped_server_context:
return _RequestScopedServerContextApp(self.app)
return self.app
if self.request_scoped_server_context:
return _RequestScopedServerContextApp(self._app)
return self._app


@cached_property
def transport(self) -> httpx.WSGITransport:
import asyncio

import httpx
from a2wsgi import ASGIMiddleware

middleware = ASGIMiddleware(self.app)
middleware = ASGIMiddleware(cast("A2WSGIApp", self.asgi_app))
Comment on lines -401 to +422

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then we can revert this line change.


# https://github.com/abersheeran/a2wsgi/discussions/64
async def start_lifespan(cm: AsyncExitStack, app: FastAPI):
Expand All @@ -413,4 +434,4 @@ async def start_lifespan(cm: AsyncExitStack, app: FastAPI):
def atransport(self) -> httpx.ASGITransport:
import httpx

return httpx.ASGITransport(app=self.app)
return httpx.ASGITransport(app=self.asgi_app)
6 changes: 3 additions & 3 deletions airflow-core/src/airflow/models/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import json
import logging
import re
import sys
import warnings
from contextlib import suppress
from json import JSONDecodeError
Expand Down Expand Up @@ -50,6 +49,7 @@ class AirflowSecretsBackendAccessDenied(PermissionError): # type: ignore[no-red
"""Compat stub — never raised by task-sdk <1.2.2."""


from airflow.process_context import should_use_task_sdk_api_path
from airflow.utils.helpers import prune_dict
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
Expand Down Expand Up @@ -475,7 +475,7 @@ def get_connection_from_secrets(cls, conn_id: str, team_name: str | None = None)

# If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if should_use_task_sdk_api_path():
from airflow.sdk import Connection as TaskSDKConnection
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType

Expand Down Expand Up @@ -566,7 +566,7 @@ def to_dict(self, *, prune_empty: bool = False, validate: bool = True) -> dict[s

@classmethod
def from_json(cls, value, conn_id=None) -> Connection:
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if should_use_task_sdk_api_path():
from airflow.sdk import Connection as TaskSDKConnection

warnings.warn(
Expand Down
10 changes: 5 additions & 5 deletions airflow-core/src/airflow/models/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import contextlib
import json
import logging
import sys
import warnings
from typing import TYPE_CHECKING, Any

Expand All @@ -47,6 +46,7 @@ class AirflowSecretsBackendAccessDenied(PermissionError): # type: ignore[no-red
"""Compat stub — never raised by task-sdk <1.2.2."""


from airflow.process_context import should_use_task_sdk_api_path
from airflow.secrets.metastore import MetastoreBackend
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, create_session, provide_session
Expand Down Expand Up @@ -166,7 +166,7 @@ def get(

# If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if should_use_task_sdk_api_path():
warnings.warn(
"Using Variable.get from `airflow.models` is deprecated."
"Please use `get` on Variable from sdk(`airflow.sdk.Variable`) instead",
Expand Down Expand Up @@ -226,7 +226,7 @@ def set(

# If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if should_use_task_sdk_api_path():
Comment thread
henry3260 marked this conversation as resolved.
warnings.warn(
"Using Variable.set from `airflow.models` is deprecated."
"Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead",
Expand Down Expand Up @@ -314,7 +314,7 @@ def update(

# If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if should_use_task_sdk_api_path():
warnings.warn(
"Using Variable.update from `airflow.models` is deprecated."
"Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead as it is an upsert.",
Expand Down Expand Up @@ -380,7 +380,7 @@ def delete(key: str, team_name: str | None = None, session: Session | None = Non

# If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if should_use_task_sdk_api_path():
warnings.warn(
"Using Variable.delete from `airflow.models` is deprecated."
"Please use `delete` on Variable from sdk(`airflow.sdk.Variable`) instead",
Expand Down
59 changes: 59 additions & 0 deletions airflow-core/src/airflow/process_context.py

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had similar consolidation in https://github.com/apache/airflow/pull/59876/changes#diff-7694d13e2f87c84d20b0b8b44797bf96d754ae270204217e518082decc74649bR102 (PR was closed because was not accepted by other maintainers and went stale) - I'd favor "hiding" this small utility in another module where fitting rather than adding a new module just for the context detection.

@jason810496 jason810496 Jun 15, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm on the other side to keep this as-is so the lifecycle of _PROCESS_CONTEXT_OVERRIDE will be more clear. But no strong opinion for this.

Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import os
import sys
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Literal

__all__ = [
"get_process_context",
"override_process_context",
"should_use_task_sdk_api_path",
]

_PROCESS_CONTEXT_OVERRIDE: ContextVar[str | None] = ContextVar(
"_AIRFLOW_PROCESS_CONTEXT_OVERRIDE",
default=None,
)
Comment thread
henry3260 marked this conversation as resolved.

Comment thread
jason810496 marked this conversation as resolved.

def get_process_context() -> str | None:
"""Return the current process context, preferring request-scoped overrides."""
return _PROCESS_CONTEXT_OVERRIDE.get() or os.environ.get("_AIRFLOW_PROCESS_CONTEXT")


@contextmanager
def override_process_context(context: Literal["server", "client"]) -> Generator[None, None, None]:
"""Temporarily override the current process context for the active execution flow."""
token = _PROCESS_CONTEXT_OVERRIDE.set(context)
try:
yield
finally:
_PROCESS_CONTEXT_OVERRIDE.reset(token)


def should_use_task_sdk_api_path() -> bool:
"""Return True when execution-context helpers should route through Task SDK APIs."""
if get_process_context() == "server":
return False

task_runner_module = sys.modules.get("airflow.sdk.execution_time.task_runner")
return bool(getattr(task_runner_module, "SUPERVISOR_COMMS", None))
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

import sys
from unittest import mock

import pytest
Expand Down Expand Up @@ -105,6 +106,40 @@ def test_connection_get_from_env_var(self, client, session):
"extra": '{"headers": "header"}',
}

@mock.patch.dict(
"os.environ",
{
"AIRFLOW_CONN_TEST_CONN_SERVER": '{"uri": "http://root:admin@localhost:8080/https?headers=header"}',
"_AIRFLOW_PROCESS_CONTEXT": "server",
},
)
def test_connection_get_uses_server_path_when_supervisor_comms_exists(self, client):
fake_task_runner = mock.Mock()
fake_task_runner.SUPERVISOR_COMMS = object()

with (
mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": fake_task_runner}),
mock.patch(
"airflow.sdk.Connection.get",
side_effect=AssertionError(
"Execution API should not route through Task SDK Connection.get in server context"
),
),
):
response = client.get("/execution/connections/test_conn_server")

assert response.status_code == 200
assert response.json() == {
"conn_id": "test_conn_server",
"conn_type": "http",
"host": "localhost",
"login": "root",
"password": "admin",
"schema": "https",
"port": 8080,
"extra": '{"headers": "header"}',
}

def test_connection_get_not_found(self, client):
response = client.get("/execution/connections/non_existent_test_conn")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,28 @@ def test_variable_get_from_env_var(self, client, session):
assert response.status_code == 200
assert response.json() == {"key": "key1", "value": "VALUE"}

@mock.patch.dict(
"os.environ",
{"AIRFLOW_VAR_KEY1": "VALUE", "_AIRFLOW_PROCESS_CONTEXT": "server"},
)
def test_variable_get_uses_server_path_when_supervisor_comms_exists(self, client):
fake_task_runner = mock.Mock()
fake_task_runner.SUPERVISOR_COMMS = object()

with (
mock.patch.dict("sys.modules", {"airflow.sdk.execution_time.task_runner": fake_task_runner}),
mock.patch(
"airflow.sdk.Variable.get",
side_effect=AssertionError(
"Execution API should not route through Task SDK Variable.get in server context"
),
),
):
response = client.get("/execution/variables/key1")

assert response.status_code == 200
assert response.json() == {"key": "key1", "value": "VALUE"}

@pytest.mark.parametrize(
"key",
[
Expand Down Expand Up @@ -158,6 +180,31 @@ def test_should_create_variable(self, client, key, payload, session):
if "description" in payload:
assert var_from_db.description == payload["description"]

@mock.patch.dict(
"os.environ",
{"_AIRFLOW_PROCESS_CONTEXT": "server"},
)
def test_variable_put_uses_server_path_when_supervisor_comms_exists(self, client, session):
fake_task_runner = mock.Mock()
fake_task_runner.SUPERVISOR_COMMS = object()

with (
mock.patch.dict("sys.modules", {"airflow.sdk.execution_time.task_runner": fake_task_runner}),
mock.patch(
"airflow.sdk.Variable.set",
side_effect=AssertionError(
"Execution API should not route through Task SDK Variable.set in server context"
),
),
):
response = client.put("/execution/variables/var_server_only", json={"value": "server_value"})

assert response.status_code == 201
assert response.json()["message"] == "Variable successfully set"
var_from_db = session.scalars(select(Variable).where(Variable.key == "var_server_only")).first()
assert var_from_db is not None
assert var_from_db.val == "server_value"

@pytest.mark.parametrize(
("key", "payload", "error_type"),
[
Expand Down Expand Up @@ -342,3 +389,29 @@ def test_should_not_delete_variable(self, client, session):

vars = session.scalars(select(Variable)).all()
assert len(vars) == 1

@mock.patch.dict(
"os.environ",
{"_AIRFLOW_PROCESS_CONTEXT": "server"},
)
def test_variable_delete_uses_server_path_when_supervisor_comms_exists(self, client, session):
Variable.set(key="var_server_delete", value="to_delete", session=session)
session.commit()

fake_task_runner = mock.Mock()
fake_task_runner.SUPERVISOR_COMMS = object()

with (
mock.patch.dict("sys.modules", {"airflow.sdk.execution_time.task_runner": fake_task_runner}),
mock.patch(
"airflow.sdk.Variable.delete",
side_effect=AssertionError(
"Execution API should not route through Task SDK Variable.delete in server context"
),
),
):
response = client.delete("/execution/variables/var_server_delete")

assert response.status_code == 204
session.expire_all()
assert session.scalar(select(Variable).where(Variable.key == "var_server_delete")) is None
Loading
Loading