From 8c806922a79819751e95bf909c086f39da68050e Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 24 Feb 2026 15:38:50 -0800 Subject: [PATCH 1/5] Enforce and document context injection into custom callbacks --- airflow-core/docs/howto/deadline-alerts.rst | 7 ++++ .../airflow/executors/workloads/callback.py | 3 +- airflow-core/src/airflow/triggers/callback.py | 5 ++- airflow-core/src/airflow/utils/helpers.py | 24 ++++++++++++ airflow-core/tests/unit/utils/test_helpers.py | 38 +++++++++++++++++++ 5 files changed, 74 insertions(+), 3 deletions(-) diff --git a/airflow-core/docs/howto/deadline-alerts.rst b/airflow-core/docs/howto/deadline-alerts.rst index ab1e9da5f69a6..4f097e7d9be2b 100644 --- a/airflow-core/docs/howto/deadline-alerts.rst +++ b/airflow-core/docs/howto/deadline-alerts.rst @@ -237,6 +237,13 @@ Triggerer's system path. Nested callables are not currently supported. * The Triggerer will need to be restarted when a callback is added or changed in order to reload the file. +.. note:: + **Airflow context injection:** When a deadline is missed, Airflow automatically injects a ``context`` + kwarg into the callback containing information about the Dag run and the deadline. To receive it, + accept ``**kwargs`` in your callback and access ``kwargs["context"]``, or add a named ``context`` + parameter. Callbacks that don't need the context can omit it — Airflow will only pass kwargs that + the callable accepts. The ``context`` keyword is reserved and cannot be used in the ``kwargs`` + parameter of a ``Callback``; attempting to do so will raise a ``ValueError`` at DAG parse time. A **custom asynchronous callback** might look like this: diff --git a/airflow-core/src/airflow/executors/workloads/callback.py b/airflow-core/src/airflow/executors/workloads/callback.py index c15bb33fba70e..cd391bc20e54e 100644 --- a/airflow-core/src/airflow/executors/workloads/callback.py +++ b/airflow-core/src/airflow/executors/workloads/callback.py @@ -28,6 +28,7 @@ from pydantic import BaseModel, Field, field_validator from airflow.executors.workloads.base import BaseDagBundleWorkload, BundleInfo +from airflow.utils.helpers import filter_kwargs if TYPE_CHECKING: from airflow.api_fastapi.auth.tokens import JWTGenerator @@ -141,7 +142,7 @@ def execute_callback_workload( log.debug("Executing callback %s(%s)...", callback_path, callback_kwargs) # If the callback is a callable, call it. If it is a class, instantiate it. - result = callback_callable(**callback_kwargs) + result = callback_callable(**filter_kwargs(callback_callable, callback_kwargs)) # If the callback is a class then it is now instantiated and callable, call it. if callable(result): diff --git a/airflow-core/src/airflow/triggers/callback.py b/airflow-core/src/airflow/triggers/callback.py index aadfffe38cc6d..856bcdeb59219 100644 --- a/airflow-core/src/airflow/triggers/callback.py +++ b/airflow-core/src/airflow/triggers/callback.py @@ -25,6 +25,7 @@ from airflow._shared.module_loading import import_string, qualname from airflow.models.callback import CallbackState from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.utils.helpers import filter_kwargs log = logging.getLogger(__name__) @@ -53,8 +54,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]: yield TriggerEvent({PAYLOAD_STATUS_KEY: CallbackState.RUNNING}) callback = import_string(self.callback_path) - # TODO: get full context and run template rendering. Right now, a simple context in included in `callback_kwargs` - result = await callback(**self.callback_kwargs) + # TODO: get full context and run template rendering. Right now, a simple context is included in `callback_kwargs` + result = await callback(**filter_kwargs(callback, self.callback_kwargs)) yield TriggerEvent({PAYLOAD_STATUS_KEY: CallbackState.SUCCESS, PAYLOAD_BODY_KEY: result}) except Exception as e: diff --git a/airflow-core/src/airflow/utils/helpers.py b/airflow-core/src/airflow/utils/helpers.py index 50bd8b82a4622..d8331664c7bed 100644 --- a/airflow-core/src/airflow/utils/helpers.py +++ b/airflow-core/src/airflow/utils/helpers.py @@ -18,6 +18,7 @@ from __future__ import annotations import copy +import inspect import itertools import re import signal @@ -295,6 +296,29 @@ def is_empty(x): return val +def filter_kwargs(callable_obj: object, kwargs: dict) -> dict: + """ + Filter kwargs to only include parameters the callable accepts. + + If the callable accepts **kwargs (VAR_KEYWORD), all kwargs are passed through. + Otherwise, only kwargs matching named parameters are passed. This is useful + when calling user-provided callables that may not accept all of the kwargs that + Airflow injects (e.g. context). + + :param callable_obj: The callable to inspect + :param kwargs: The full set of kwargs to filter + """ + try: + signature = inspect.signature(callable_obj) + except (ValueError, TypeError): + return kwargs + + if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()): + return kwargs + + return {k: v for k, v in kwargs.items() if k in set(signature.parameters.keys())} + + __deprecated_imports = { "render_template_as_native": "airflow.sdk.definitions.context", "render_template_to_string": "airflow.sdk.definitions.context", diff --git a/airflow-core/tests/unit/utils/test_helpers.py b/airflow-core/tests/unit/utils/test_helpers.py index 8e16d11869843..68f98dd38f535 100644 --- a/airflow-core/tests/unit/utils/test_helpers.py +++ b/airflow-core/tests/unit/utils/test_helpers.py @@ -31,6 +31,7 @@ at_most_one, build_airflow_dagrun_url, exactly_one, + filter_kwargs, merge_dicts, prune_dict, validate_key, @@ -233,6 +234,43 @@ def test_prune_dict(self, mode, expected): assert prune_dict(d2, mode=mode) == expected +class TestFilterKwargs: + def test_passes_all_when_var_keyword_present(self): + def func_with_var_keyword(**kwargs): + pass + + kwargs = {"a": 1, "b": 2, "context": {"dag_run": {}}} + assert filter_kwargs(func_with_var_keyword, kwargs) == kwargs + + def test_filters_to_named_params_only(self): + def func_with_named_params(a, b): + pass + + kwargs = {"a": 1, "b": 2, "context": {"dag_run": {}}} + assert filter_kwargs(func_with_named_params, kwargs) == {"a": 1, "b": 2} + + def test_no_params_returns_empty(self): + def func_no_params(): + pass + + kwargs = {"context": {"dag_run": {}}, "extra": "value"} + assert filter_kwargs(func_no_params, kwargs) == {} + + def test_uninspectable_callable_passes_all(self): + # built-in len() cannot be inspected with inspect.signature in some Python versions + kwargs = {"a": 1} + result = filter_kwargs(len, kwargs) + + assert isinstance(result, dict) + + def test_mixed_named_and_extra_kwargs(self): + def func(context, alert_type): + pass + + kwargs = {"context": {"dag_run": {}}, "alert_type": "deadline", "extra": "dropped"} + assert filter_kwargs(func, kwargs) == {"context": {"dag_run": {}}, "alert_type": "deadline"} + + class MockJobRunner(BaseJobRunner): job_type = "MockJob" From c08a2a2a5b653ff17e3b3445eed48f49fd6bed9e Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Sat, 28 Feb 2026 13:07:20 -0800 Subject: [PATCH 2/5] mypy type-hint fixes --- airflow-core/src/airflow/utils/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow-core/src/airflow/utils/helpers.py b/airflow-core/src/airflow/utils/helpers.py index d8331664c7bed..ec35e700ce48e 100644 --- a/airflow-core/src/airflow/utils/helpers.py +++ b/airflow-core/src/airflow/utils/helpers.py @@ -296,7 +296,7 @@ def is_empty(x): return val -def filter_kwargs(callable_obj: object, kwargs: dict) -> dict: +def filter_kwargs(callable_obj: Callable[..., Any], kwargs: dict) -> dict: """ Filter kwargs to only include parameters the callable accepts. From 96cb45b8ecb1c51d52d46ee5669b697deaf6fc80 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Mon, 2 Mar 2026 11:44:11 -0800 Subject: [PATCH 3/5] Move the kwargs filter helper into models/callbacks --- .../airflow/executors/workloads/callback.py | 3 +- airflow-core/src/airflow/models/callback.py | 25 ++++++++++++ airflow-core/src/airflow/triggers/callback.py | 3 +- airflow-core/src/airflow/utils/helpers.py | 24 ------------ .../tests/unit/models/test_callback.py | 38 +++++++++++++++++++ airflow-core/tests/unit/utils/test_helpers.py | 38 ------------------- 6 files changed, 67 insertions(+), 64 deletions(-) diff --git a/airflow-core/src/airflow/executors/workloads/callback.py b/airflow-core/src/airflow/executors/workloads/callback.py index cd391bc20e54e..bd2ce99506281 100644 --- a/airflow-core/src/airflow/executors/workloads/callback.py +++ b/airflow-core/src/airflow/executors/workloads/callback.py @@ -28,7 +28,6 @@ from pydantic import BaseModel, Field, field_validator from airflow.executors.workloads.base import BaseDagBundleWorkload, BundleInfo -from airflow.utils.helpers import filter_kwargs if TYPE_CHECKING: from airflow.api_fastapi.auth.tokens import JWTGenerator @@ -126,6 +125,8 @@ def execute_callback_workload( :param log: Logger instance for recording execution :return: Tuple of (success: bool, error_message: str | None) """ + from airflow.models.callback import filter_kwargs # circular import + callback_path = callback.data.get("path") callback_kwargs = callback.data.get("kwargs", {}) diff --git a/airflow-core/src/airflow/models/callback.py b/airflow-core/src/airflow/models/callback.py index ea482ab7ba8d5..7d8d6b9058227 100644 --- a/airflow-core/src/airflow/models/callback.py +++ b/airflow-core/src/airflow/models/callback.py @@ -16,6 +16,8 @@ # under the License. from __future__ import annotations +import inspect +from collections.abc import Callable from datetime import datetime from enum import Enum from importlib import import_module @@ -50,6 +52,29 @@ TERMINAL_STATES = frozenset((CallbackState.SUCCESS, CallbackState.FAILED)) +def filter_kwargs(callable_obj: Callable[..., Any], kwargs: dict) -> dict: + """ + Filter kwargs to only include parameters the callable accepts. + + If the callable accepts **kwargs (VAR_KEYWORD), all kwargs are passed through. + Otherwise, only kwargs matching named parameters are passed. This is useful + when calling user-provided callables that may not accept all of the kwargs that + Airflow injects (e.g. context). + + :param callable_obj: The callable to inspect + :param kwargs: The full set of kwargs to filter + """ + try: + signature = inspect.signature(callable_obj) + except (ValueError, TypeError): + return kwargs + + if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()): + return kwargs + + return {k: v for k, v in kwargs.items() if k in set(signature.parameters.keys())} + + class CallbackType(str, Enum): """ Types of Callbacks. diff --git a/airflow-core/src/airflow/triggers/callback.py b/airflow-core/src/airflow/triggers/callback.py index 856bcdeb59219..8f6c7e3a82107 100644 --- a/airflow-core/src/airflow/triggers/callback.py +++ b/airflow-core/src/airflow/triggers/callback.py @@ -25,7 +25,6 @@ from airflow._shared.module_loading import import_string, qualname from airflow.models.callback import CallbackState from airflow.triggers.base import BaseTrigger, TriggerEvent -from airflow.utils.helpers import filter_kwargs log = logging.getLogger(__name__) @@ -50,6 +49,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]: ) async def run(self) -> AsyncIterator[TriggerEvent]: + from airflow.models.callback import filter_kwargs # circular import + try: yield TriggerEvent({PAYLOAD_STATUS_KEY: CallbackState.RUNNING}) callback = import_string(self.callback_path) diff --git a/airflow-core/src/airflow/utils/helpers.py b/airflow-core/src/airflow/utils/helpers.py index ec35e700ce48e..50bd8b82a4622 100644 --- a/airflow-core/src/airflow/utils/helpers.py +++ b/airflow-core/src/airflow/utils/helpers.py @@ -18,7 +18,6 @@ from __future__ import annotations import copy -import inspect import itertools import re import signal @@ -296,29 +295,6 @@ def is_empty(x): return val -def filter_kwargs(callable_obj: Callable[..., Any], kwargs: dict) -> dict: - """ - Filter kwargs to only include parameters the callable accepts. - - If the callable accepts **kwargs (VAR_KEYWORD), all kwargs are passed through. - Otherwise, only kwargs matching named parameters are passed. This is useful - when calling user-provided callables that may not accept all of the kwargs that - Airflow injects (e.g. context). - - :param callable_obj: The callable to inspect - :param kwargs: The full set of kwargs to filter - """ - try: - signature = inspect.signature(callable_obj) - except (ValueError, TypeError): - return kwargs - - if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()): - return kwargs - - return {k: v for k, v in kwargs.items() if k in set(signature.parameters.keys())} - - __deprecated_imports = { "render_template_as_native": "airflow.sdk.definitions.context", "render_template_to_string": "airflow.sdk.definitions.context", diff --git a/airflow-core/tests/unit/models/test_callback.py b/airflow-core/tests/unit/models/test_callback.py index 6ab6ad2d02df7..4c2acefa89813 100644 --- a/airflow-core/tests/unit/models/test_callback.py +++ b/airflow-core/tests/unit/models/test_callback.py @@ -26,6 +26,7 @@ CallbackState, ExecutorCallback, TriggererCallback, + filter_kwargs, ) from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback from airflow.triggers.base import TriggerEvent @@ -208,3 +209,40 @@ def test_queue(self): # Note: class DagProcessorCallback is tested in airflow-core/tests/unit/dag_processing/test_manager.py + + +class TestFilterKwargs: + def test_passes_all_when_var_keyword_present(self): + def func_with_var_keyword(**kwargs): + pass + + kwargs = {"a": 1, "b": 2, "context": {"dag_run": {}}} + assert filter_kwargs(func_with_var_keyword, kwargs) == kwargs + + def test_filters_to_named_params_only(self): + def func_with_named_params(a, b): + pass + + kwargs = {"a": 1, "b": 2, "context": {"dag_run": {}}} + assert filter_kwargs(func_with_named_params, kwargs) == {"a": 1, "b": 2} + + def test_no_params_returns_empty(self): + def func_no_params(): + pass + + kwargs = {"context": {"dag_run": {}}, "extra": "value"} + assert filter_kwargs(func_no_params, kwargs) == {} + + def test_uninspectable_callable_passes_all(self): + # built-in len() cannot be inspected with inspect.signature in some Python versions + kwargs = {"a": 1} + result = filter_kwargs(len, kwargs) + + assert isinstance(result, dict) + + def test_mixed_named_and_extra_kwargs(self): + def func(context, alert_type): + pass + + kwargs = {"context": {"dag_run": {}}, "alert_type": "deadline", "extra": "dropped"} + assert filter_kwargs(func, kwargs) == {"context": {"dag_run": {}}, "alert_type": "deadline"} diff --git a/airflow-core/tests/unit/utils/test_helpers.py b/airflow-core/tests/unit/utils/test_helpers.py index 68f98dd38f535..8e16d11869843 100644 --- a/airflow-core/tests/unit/utils/test_helpers.py +++ b/airflow-core/tests/unit/utils/test_helpers.py @@ -31,7 +31,6 @@ at_most_one, build_airflow_dagrun_url, exactly_one, - filter_kwargs, merge_dicts, prune_dict, validate_key, @@ -234,43 +233,6 @@ def test_prune_dict(self, mode, expected): assert prune_dict(d2, mode=mode) == expected -class TestFilterKwargs: - def test_passes_all_when_var_keyword_present(self): - def func_with_var_keyword(**kwargs): - pass - - kwargs = {"a": 1, "b": 2, "context": {"dag_run": {}}} - assert filter_kwargs(func_with_var_keyword, kwargs) == kwargs - - def test_filters_to_named_params_only(self): - def func_with_named_params(a, b): - pass - - kwargs = {"a": 1, "b": 2, "context": {"dag_run": {}}} - assert filter_kwargs(func_with_named_params, kwargs) == {"a": 1, "b": 2} - - def test_no_params_returns_empty(self): - def func_no_params(): - pass - - kwargs = {"context": {"dag_run": {}}, "extra": "value"} - assert filter_kwargs(func_no_params, kwargs) == {} - - def test_uninspectable_callable_passes_all(self): - # built-in len() cannot be inspected with inspect.signature in some Python versions - kwargs = {"a": 1} - result = filter_kwargs(len, kwargs) - - assert isinstance(result, dict) - - def test_mixed_named_and_extra_kwargs(self): - def func(context, alert_type): - pass - - kwargs = {"context": {"dag_run": {}}, "alert_type": "deadline", "extra": "dropped"} - assert filter_kwargs(func, kwargs) == {"context": {"dag_run": {}}, "alert_type": "deadline"} - - class MockJobRunner(BaseJobRunner): job_type = "MockJob" From f16b00339197a292ee18cedd42035c501ce2e6b7 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Mon, 2 Mar 2026 12:48:42 -0800 Subject: [PATCH 4/5] replace `inject` with less-scary alternatives and rephrased things a little --- airflow-core/docs/howto/deadline-alerts.rst | 2 +- airflow-core/src/airflow/models/callback.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/airflow-core/docs/howto/deadline-alerts.rst b/airflow-core/docs/howto/deadline-alerts.rst index 4f097e7d9be2b..643e17fc185fb 100644 --- a/airflow-core/docs/howto/deadline-alerts.rst +++ b/airflow-core/docs/howto/deadline-alerts.rst @@ -238,7 +238,7 @@ Triggerer's system path. * The Triggerer will need to be restarted when a callback is added or changed in order to reload the file. .. note:: - **Airflow context injection:** When a deadline is missed, Airflow automatically injects a ``context`` + **Airflow ``context``:** When a deadline is missed, Airflow automatically provides a ``context`` kwarg into the callback containing information about the Dag run and the deadline. To receive it, accept ``**kwargs`` in your callback and access ``kwargs["context"]``, or add a named ``context`` parameter. Callbacks that don't need the context can omit it — Airflow will only pass kwargs that diff --git a/airflow-core/src/airflow/models/callback.py b/airflow-core/src/airflow/models/callback.py index 7d8d6b9058227..7e7211462ee36 100644 --- a/airflow-core/src/airflow/models/callback.py +++ b/airflow-core/src/airflow/models/callback.py @@ -52,20 +52,20 @@ TERMINAL_STATES = frozenset((CallbackState.SUCCESS, CallbackState.FAILED)) -def filter_kwargs(callable_obj: Callable[..., Any], kwargs: dict) -> dict: +def filter_kwargs(callback_obj: Callable[..., Any], kwargs: dict) -> dict: """ - Filter kwargs to only include parameters the callable accepts. + Filter kwargs to only include parameters the callback accepts. If the callable accepts **kwargs (VAR_KEYWORD), all kwargs are passed through. Otherwise, only kwargs matching named parameters are passed. This is useful - when calling user-provided callables that may not accept all of the kwargs that - Airflow injects (e.g. context). + when calling user-provided callbacks that may not accept all kwargs that + Airflow provides (e.g. context). - :param callable_obj: The callable to inspect + :param callback_obj: The callback to inspect :param kwargs: The full set of kwargs to filter """ try: - signature = inspect.signature(callable_obj) + signature = inspect.signature(callback_obj) except (ValueError, TypeError): return kwargs From 92124077307a81287ba5c0ea8fe015cf5f00ecc9 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 3 Mar 2026 20:14:23 -0800 Subject: [PATCH 5/5] clean up filter_kwargs --- .../airflow/executors/workloads/callback.py | 10 +++-- airflow-core/src/airflow/models/callback.py | 25 +++-------- airflow-core/src/airflow/triggers/callback.py | 13 +++--- .../tests/unit/models/test_callback.py | 43 ++++++++----------- .../tests/unit/triggers/test_callback.py | 23 +++++++--- 5 files changed, 56 insertions(+), 58 deletions(-) diff --git a/airflow-core/src/airflow/executors/workloads/callback.py b/airflow-core/src/airflow/executors/workloads/callback.py index bd2ce99506281..2563f9a78f553 100644 --- a/airflow-core/src/airflow/executors/workloads/callback.py +++ b/airflow-core/src/airflow/executors/workloads/callback.py @@ -125,7 +125,7 @@ def execute_callback_workload( :param log: Logger instance for recording execution :return: Tuple of (success: bool, error_message: str | None) """ - from airflow.models.callback import filter_kwargs # circular import + from airflow.models.callback import _accepts_context # circular import callback_path = callback.data.get("path") callback_kwargs = callback.data.get("kwargs", {}) @@ -139,15 +139,19 @@ def execute_callback_workload( module_path, function_name = callback_path.rsplit(".", 1) module = import_module(module_path) callback_callable = getattr(module, function_name) + context = callback_kwargs.pop("context", None) log.debug("Executing callback %s(%s)...", callback_path, callback_kwargs) # If the callback is a callable, call it. If it is a class, instantiate it. - result = callback_callable(**filter_kwargs(callback_callable, callback_kwargs)) + # Rather than forcing all custom callbacks to accept context, conditionally provide it only if supported. + if _accepts_context(callback_callable) and context is not None: + result = callback_callable(**callback_kwargs, context=context) + else: + result = callback_callable(**callback_kwargs) # If the callback is a class then it is now instantiated and callable, call it. if callable(result): - context = callback_kwargs.get("context", {}) log.debug("Calling result with context for %s", callback_path) result = result(context) diff --git a/airflow-core/src/airflow/models/callback.py b/airflow-core/src/airflow/models/callback.py index 7e7211462ee36..e2c46153a7170 100644 --- a/airflow-core/src/airflow/models/callback.py +++ b/airflow-core/src/airflow/models/callback.py @@ -52,27 +52,14 @@ TERMINAL_STATES = frozenset((CallbackState.SUCCESS, CallbackState.FAILED)) -def filter_kwargs(callback_obj: Callable[..., Any], kwargs: dict) -> dict: - """ - Filter kwargs to only include parameters the callback accepts. - - If the callable accepts **kwargs (VAR_KEYWORD), all kwargs are passed through. - Otherwise, only kwargs matching named parameters are passed. This is useful - when calling user-provided callbacks that may not accept all kwargs that - Airflow provides (e.g. context). - - :param callback_obj: The callback to inspect - :param kwargs: The full set of kwargs to filter - """ +def _accepts_context(callback: Callable) -> bool: + """Check if callback accepts a 'context' parameter or **kwargs.""" try: - signature = inspect.signature(callback_obj) + sig = inspect.signature(callback) except (ValueError, TypeError): - return kwargs - - if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()): - return kwargs - - return {k: v for k, v in kwargs.items() if k in set(signature.parameters.keys())} + return True + params = sig.parameters + return "context" in params or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()) class CallbackType(str, Enum): diff --git a/airflow-core/src/airflow/triggers/callback.py b/airflow-core/src/airflow/triggers/callback.py index 8f6c7e3a82107..9c2470c77eae6 100644 --- a/airflow-core/src/airflow/triggers/callback.py +++ b/airflow-core/src/airflow/triggers/callback.py @@ -23,7 +23,7 @@ from typing import Any from airflow._shared.module_loading import import_string, qualname -from airflow.models.callback import CallbackState +from airflow.models.callback import CallbackState, _accepts_context from airflow.triggers.base import BaseTrigger, TriggerEvent log = logging.getLogger(__name__) @@ -49,14 +49,17 @@ def serialize(self) -> tuple[str, dict[str, Any]]: ) async def run(self) -> AsyncIterator[TriggerEvent]: - from airflow.models.callback import filter_kwargs # circular import - try: yield TriggerEvent({PAYLOAD_STATUS_KEY: CallbackState.RUNNING}) callback = import_string(self.callback_path) - # TODO: get full context and run template rendering. Right now, a simple context is included in `callback_kwargs` - result = await callback(**filter_kwargs(callback, self.callback_kwargs)) + context = self.callback_kwargs.pop("context", None) + + if _accepts_context(callback) and context is not None: + result = await callback(**self.callback_kwargs, context=context) + else: + result = await callback(**self.callback_kwargs) + yield TriggerEvent({PAYLOAD_STATUS_KEY: CallbackState.SUCCESS, PAYLOAD_BODY_KEY: result}) except Exception as e: diff --git a/airflow-core/tests/unit/models/test_callback.py b/airflow-core/tests/unit/models/test_callback.py index 4c2acefa89813..20bbba29fc108 100644 --- a/airflow-core/tests/unit/models/test_callback.py +++ b/airflow-core/tests/unit/models/test_callback.py @@ -16,6 +16,8 @@ # under the License. from __future__ import annotations +from unittest.mock import patch + import pytest from sqlalchemy import select @@ -26,7 +28,7 @@ CallbackState, ExecutorCallback, TriggererCallback, - filter_kwargs, + _accepts_context, ) from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback from airflow.triggers.base import TriggerEvent @@ -211,38 +213,31 @@ def test_queue(self): # Note: class DagProcessorCallback is tested in airflow-core/tests/unit/dag_processing/test_manager.py -class TestFilterKwargs: - def test_passes_all_when_var_keyword_present(self): +class TestAcceptsContext: + def test_true_when_var_keyword_present(self): def func_with_var_keyword(**kwargs): pass - kwargs = {"a": 1, "b": 2, "context": {"dag_run": {}}} - assert filter_kwargs(func_with_var_keyword, kwargs) == kwargs + assert _accepts_context(func_with_var_keyword) is True - def test_filters_to_named_params_only(self): - def func_with_named_params(a, b): + def test_true_when_context_param_present(self): + def func_with_context(context, alert_type): pass - kwargs = {"a": 1, "b": 2, "context": {"dag_run": {}}} - assert filter_kwargs(func_with_named_params, kwargs) == {"a": 1, "b": 2} + assert _accepts_context(func_with_context) is True - def test_no_params_returns_empty(self): - def func_no_params(): + def test_false_when_no_context_or_var_keyword(self): + def func_without_context(a, b): pass - kwargs = {"context": {"dag_run": {}}, "extra": "value"} - assert filter_kwargs(func_no_params, kwargs) == {} - - def test_uninspectable_callable_passes_all(self): - # built-in len() cannot be inspected with inspect.signature in some Python versions - kwargs = {"a": 1} - result = filter_kwargs(len, kwargs) + assert _accepts_context(func_without_context) is False - assert isinstance(result, dict) - - def test_mixed_named_and_extra_kwargs(self): - def func(context, alert_type): + def test_false_when_no_params(self): + def func_no_params(): pass - kwargs = {"context": {"dag_run": {}}, "alert_type": "deadline", "extra": "dropped"} - assert filter_kwargs(func, kwargs) == {"context": {"dag_run": {}}, "alert_type": "deadline"} + assert _accepts_context(func_no_params) is False + + def test_true_for_uninspectable_callable(self): + with patch("airflow.models.callback.inspect.signature", side_effect=ValueError): + assert _accepts_context(lambda: None) is True diff --git a/airflow-core/tests/unit/triggers/test_callback.py b/airflow-core/tests/unit/triggers/test_callback.py index ca59ea735f8ff..99eca603323bb 100644 --- a/airflow-core/tests/unit/triggers/test_callback.py +++ b/airflow-core/tests/unit/triggers/test_callback.py @@ -28,7 +28,6 @@ TEST_MESSAGE = "test_message" TEST_CALLBACK_PATH = "classpath.test_callback" TEST_CALLBACK_KWARGS = {"message": TEST_MESSAGE, "context": {"dag_run": "test"}} -TEST_TRIGGER = CallbackTrigger(callback_path=TEST_CALLBACK_PATH, callback_kwargs=TEST_CALLBACK_KWARGS) class ExampleAsyncNotifier(BaseNotifier): @@ -46,6 +45,14 @@ def notify(self, context): class TestCallbackTrigger: + @pytest.fixture + def trigger(self): + """Create a fresh trigger per test to avoid shared mutable state.""" + return CallbackTrigger( + callback_path=TEST_CALLBACK_PATH, + callback_kwargs=dict(TEST_CALLBACK_KWARGS), + ) + @pytest.fixture def mock_import_string(self): with mock.patch("airflow.triggers.callback.import_string") as m: @@ -72,29 +79,30 @@ def test_serialization(self, callback_init_kwargs, expected_serialized_kwargs): } @pytest.mark.asyncio - async def test_run_success_with_async_function(self, mock_import_string): + async def test_run_success_with_async_function(self, trigger, mock_import_string): """Test trigger handles async functions correctly.""" callback_return_value = "some value" mock_callback = mock.AsyncMock(return_value=callback_return_value) mock_import_string.return_value = mock_callback - trigger_gen = TEST_TRIGGER.run() + trigger_gen = trigger.run() running_event = await anext(trigger_gen) assert running_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.RUNNING success_event = await anext(trigger_gen) mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH) + # AsyncMock accepts **kwargs, so _accepts_context returns True and context is passed through mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS) assert success_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS assert success_event.payload[PAYLOAD_BODY_KEY] == callback_return_value @pytest.mark.asyncio - async def test_run_success_with_notifier(self, mock_import_string): + async def test_run_success_with_notifier(self, trigger, mock_import_string): """Test trigger handles async notifier classes correctly.""" mock_import_string.return_value = ExampleAsyncNotifier - trigger_gen = TEST_TRIGGER.run() + trigger_gen = trigger.run() running_event = await anext(trigger_gen) assert running_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.RUNNING @@ -108,18 +116,19 @@ async def test_run_success_with_notifier(self, mock_import_string): ) @pytest.mark.asyncio - async def test_run_failure(self, mock_import_string): + async def test_run_failure(self, trigger, mock_import_string): exc_msg = "Something went wrong" mock_callback = mock.AsyncMock(side_effect=RuntimeError(exc_msg)) mock_import_string.return_value = mock_callback - trigger_gen = TEST_TRIGGER.run() + trigger_gen = trigger.run() running_event = await anext(trigger_gen) assert running_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.RUNNING failure_event = await anext(trigger_gen) mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH) + # AsyncMock accepts **kwargs, so _accepts_context returns True and context is passed through mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS) assert failure_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.FAILED assert all(s in failure_event.payload[PAYLOAD_BODY_KEY] for s in ["raise", "RuntimeError", exc_msg])