diff --git a/airflow-core/docs/howto/deadline-alerts.rst b/airflow-core/docs/howto/deadline-alerts.rst index ab1e9da5f69a6..643e17fc185fb 100644 --- a/airflow-core/docs/howto/deadline-alerts.rst +++ b/airflow-core/docs/howto/deadline-alerts.rst @@ -237,6 +237,13 @@ Triggerer's system path. Nested callables are not currently supported. * The Triggerer will need to be restarted when a callback is added or changed in order to reload the file. +.. note:: + **Airflow ``context``:** When a deadline is missed, Airflow automatically provides a ``context`` + kwarg into the callback containing information about the Dag run and the deadline. To receive it, + accept ``**kwargs`` in your callback and access ``kwargs["context"]``, or add a named ``context`` + parameter. Callbacks that don't need the context can omit it — Airflow will only pass kwargs that + the callable accepts. The ``context`` keyword is reserved and cannot be used in the ``kwargs`` + parameter of a ``Callback``; attempting to do so will raise a ``ValueError`` at DAG parse time. A **custom asynchronous callback** might look like this: diff --git a/airflow-core/src/airflow/executors/workloads/callback.py b/airflow-core/src/airflow/executors/workloads/callback.py index c15bb33fba70e..2563f9a78f553 100644 --- a/airflow-core/src/airflow/executors/workloads/callback.py +++ b/airflow-core/src/airflow/executors/workloads/callback.py @@ -125,6 +125,8 @@ def execute_callback_workload( :param log: Logger instance for recording execution :return: Tuple of (success: bool, error_message: str | None) """ + from airflow.models.callback import _accepts_context # circular import + callback_path = callback.data.get("path") callback_kwargs = callback.data.get("kwargs", {}) @@ -137,15 +139,19 @@ def execute_callback_workload( module_path, function_name = callback_path.rsplit(".", 1) module = import_module(module_path) callback_callable = getattr(module, function_name) + context = callback_kwargs.pop("context", None) log.debug("Executing callback %s(%s)...", callback_path, callback_kwargs) # If the callback is a callable, call it. If it is a class, instantiate it. - result = callback_callable(**callback_kwargs) + # Rather than forcing all custom callbacks to accept context, conditionally provide it only if supported. + if _accepts_context(callback_callable) and context is not None: + result = callback_callable(**callback_kwargs, context=context) + else: + result = callback_callable(**callback_kwargs) # If the callback is a class then it is now instantiated and callable, call it. if callable(result): - context = callback_kwargs.get("context", {}) log.debug("Calling result with context for %s", callback_path) result = result(context) diff --git a/airflow-core/src/airflow/models/callback.py b/airflow-core/src/airflow/models/callback.py index ea482ab7ba8d5..e2c46153a7170 100644 --- a/airflow-core/src/airflow/models/callback.py +++ b/airflow-core/src/airflow/models/callback.py @@ -16,6 +16,8 @@ # under the License. from __future__ import annotations +import inspect +from collections.abc import Callable from datetime import datetime from enum import Enum from importlib import import_module @@ -50,6 +52,16 @@ TERMINAL_STATES = frozenset((CallbackState.SUCCESS, CallbackState.FAILED)) +def _accepts_context(callback: Callable) -> bool: + """Check if callback accepts a 'context' parameter or **kwargs.""" + try: + sig = inspect.signature(callback) + except (ValueError, TypeError): + return True + params = sig.parameters + return "context" in params or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()) + + class CallbackType(str, Enum): """ Types of Callbacks. diff --git a/airflow-core/src/airflow/triggers/callback.py b/airflow-core/src/airflow/triggers/callback.py index aadfffe38cc6d..9c2470c77eae6 100644 --- a/airflow-core/src/airflow/triggers/callback.py +++ b/airflow-core/src/airflow/triggers/callback.py @@ -23,7 +23,7 @@ from typing import Any from airflow._shared.module_loading import import_string, qualname -from airflow.models.callback import CallbackState +from airflow.models.callback import CallbackState, _accepts_context from airflow.triggers.base import BaseTrigger, TriggerEvent log = logging.getLogger(__name__) @@ -52,9 +52,14 @@ async def run(self) -> AsyncIterator[TriggerEvent]: try: yield TriggerEvent({PAYLOAD_STATUS_KEY: CallbackState.RUNNING}) callback = import_string(self.callback_path) + # TODO: get full context and run template rendering. Right now, a simple context is included in `callback_kwargs` + context = self.callback_kwargs.pop("context", None) + + if _accepts_context(callback) and context is not None: + result = await callback(**self.callback_kwargs, context=context) + else: + result = await callback(**self.callback_kwargs) - # TODO: get full context and run template rendering. Right now, a simple context in included in `callback_kwargs` - result = await callback(**self.callback_kwargs) yield TriggerEvent({PAYLOAD_STATUS_KEY: CallbackState.SUCCESS, PAYLOAD_BODY_KEY: result}) except Exception as e: diff --git a/airflow-core/tests/unit/models/test_callback.py b/airflow-core/tests/unit/models/test_callback.py index 6ab6ad2d02df7..20bbba29fc108 100644 --- a/airflow-core/tests/unit/models/test_callback.py +++ b/airflow-core/tests/unit/models/test_callback.py @@ -16,6 +16,8 @@ # under the License. from __future__ import annotations +from unittest.mock import patch + import pytest from sqlalchemy import select @@ -26,6 +28,7 @@ CallbackState, ExecutorCallback, TriggererCallback, + _accepts_context, ) from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback from airflow.triggers.base import TriggerEvent @@ -208,3 +211,33 @@ def test_queue(self): # Note: class DagProcessorCallback is tested in airflow-core/tests/unit/dag_processing/test_manager.py + + +class TestAcceptsContext: + def test_true_when_var_keyword_present(self): + def func_with_var_keyword(**kwargs): + pass + + assert _accepts_context(func_with_var_keyword) is True + + def test_true_when_context_param_present(self): + def func_with_context(context, alert_type): + pass + + assert _accepts_context(func_with_context) is True + + def test_false_when_no_context_or_var_keyword(self): + def func_without_context(a, b): + pass + + assert _accepts_context(func_without_context) is False + + def test_false_when_no_params(self): + def func_no_params(): + pass + + assert _accepts_context(func_no_params) is False + + def test_true_for_uninspectable_callable(self): + with patch("airflow.models.callback.inspect.signature", side_effect=ValueError): + assert _accepts_context(lambda: None) is True diff --git a/airflow-core/tests/unit/triggers/test_callback.py b/airflow-core/tests/unit/triggers/test_callback.py index ca59ea735f8ff..99eca603323bb 100644 --- a/airflow-core/tests/unit/triggers/test_callback.py +++ b/airflow-core/tests/unit/triggers/test_callback.py @@ -28,7 +28,6 @@ TEST_MESSAGE = "test_message" TEST_CALLBACK_PATH = "classpath.test_callback" TEST_CALLBACK_KWARGS = {"message": TEST_MESSAGE, "context": {"dag_run": "test"}} -TEST_TRIGGER = CallbackTrigger(callback_path=TEST_CALLBACK_PATH, callback_kwargs=TEST_CALLBACK_KWARGS) class ExampleAsyncNotifier(BaseNotifier): @@ -46,6 +45,14 @@ def notify(self, context): class TestCallbackTrigger: + @pytest.fixture + def trigger(self): + """Create a fresh trigger per test to avoid shared mutable state.""" + return CallbackTrigger( + callback_path=TEST_CALLBACK_PATH, + callback_kwargs=dict(TEST_CALLBACK_KWARGS), + ) + @pytest.fixture def mock_import_string(self): with mock.patch("airflow.triggers.callback.import_string") as m: @@ -72,29 +79,30 @@ def test_serialization(self, callback_init_kwargs, expected_serialized_kwargs): } @pytest.mark.asyncio - async def test_run_success_with_async_function(self, mock_import_string): + async def test_run_success_with_async_function(self, trigger, mock_import_string): """Test trigger handles async functions correctly.""" callback_return_value = "some value" mock_callback = mock.AsyncMock(return_value=callback_return_value) mock_import_string.return_value = mock_callback - trigger_gen = TEST_TRIGGER.run() + trigger_gen = trigger.run() running_event = await anext(trigger_gen) assert running_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.RUNNING success_event = await anext(trigger_gen) mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH) + # AsyncMock accepts **kwargs, so _accepts_context returns True and context is passed through mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS) assert success_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS assert success_event.payload[PAYLOAD_BODY_KEY] == callback_return_value @pytest.mark.asyncio - async def test_run_success_with_notifier(self, mock_import_string): + async def test_run_success_with_notifier(self, trigger, mock_import_string): """Test trigger handles async notifier classes correctly.""" mock_import_string.return_value = ExampleAsyncNotifier - trigger_gen = TEST_TRIGGER.run() + trigger_gen = trigger.run() running_event = await anext(trigger_gen) assert running_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.RUNNING @@ -108,18 +116,19 @@ async def test_run_success_with_notifier(self, mock_import_string): ) @pytest.mark.asyncio - async def test_run_failure(self, mock_import_string): + async def test_run_failure(self, trigger, mock_import_string): exc_msg = "Something went wrong" mock_callback = mock.AsyncMock(side_effect=RuntimeError(exc_msg)) mock_import_string.return_value = mock_callback - trigger_gen = TEST_TRIGGER.run() + trigger_gen = trigger.run() running_event = await anext(trigger_gen) assert running_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.RUNNING failure_event = await anext(trigger_gen) mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH) + # AsyncMock accepts **kwargs, so _accepts_context returns True and context is passed through mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS) assert failure_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.FAILED assert all(s in failure_event.payload[PAYLOAD_BODY_KEY] for s in ["raise", "RuntimeError", exc_msg])