Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions airflow-core/docs/howto/deadline-alerts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,13 @@ Triggerer's system path.
Nested callables are not currently supported.
* The Triggerer will need to be restarted when a callback is added or changed in order to reload the file.

.. note::
**Airflow ``context``:** When a deadline is missed, Airflow automatically provides a ``context``
kwarg into the callback containing information about the Dag run and the deadline. To receive it,
accept ``**kwargs`` in your callback and access ``kwargs["context"]``, or add a named ``context``
parameter. Callbacks that don't need the context can omit it — Airflow will only pass kwargs that
the callable accepts. The ``context`` keyword is reserved and cannot be used in the ``kwargs``
parameter of a ``Callback``; attempting to do so will raise a ``ValueError`` at DAG parse time.

A **custom asynchronous callback** might look like this:

Expand Down
10 changes: 8 additions & 2 deletions airflow-core/src/airflow/executors/workloads/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ def execute_callback_workload(
:param log: Logger instance for recording execution
:return: Tuple of (success: bool, error_message: str | None)
"""
from airflow.models.callback import _accepts_context # circular import

callback_path = callback.data.get("path")
callback_kwargs = callback.data.get("kwargs", {})

Expand All @@ -137,15 +139,19 @@ def execute_callback_workload(
module_path, function_name = callback_path.rsplit(".", 1)
module = import_module(module_path)
callback_callable = getattr(module, function_name)
context = callback_kwargs.pop("context", None)

log.debug("Executing callback %s(%s)...", callback_path, callback_kwargs)

# If the callback is a callable, call it. If it is a class, instantiate it.
result = callback_callable(**callback_kwargs)
# Rather than forcing all custom callbacks to accept context, conditionally provide it only if supported.
if _accepts_context(callback_callable) and context is not None:
result = callback_callable(**callback_kwargs, context=context)
else:
result = callback_callable(**callback_kwargs)

# If the callback is a class then it is now instantiated and callable, call it.
if callable(result):
context = callback_kwargs.get("context", {})
log.debug("Calling result with context for %s", callback_path)
result = result(context)

Expand Down
12 changes: 12 additions & 0 deletions airflow-core/src/airflow/models/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
from __future__ import annotations

import inspect
from collections.abc import Callable
from datetime import datetime
from enum import Enum
from importlib import import_module
Expand Down Expand Up @@ -50,6 +52,16 @@
TERMINAL_STATES = frozenset((CallbackState.SUCCESS, CallbackState.FAILED))


def _accepts_context(callback: Callable) -> bool:
"""Check if callback accepts a 'context' parameter or **kwargs."""
try:
sig = inspect.signature(callback)
except (ValueError, TypeError):
return True
params = sig.parameters
return "context" in params or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())


class CallbackType(str, Enum):
"""
Types of Callbacks.
Expand Down
11 changes: 8 additions & 3 deletions airflow-core/src/airflow/triggers/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from typing import Any

from airflow._shared.module_loading import import_string, qualname
from airflow.models.callback import CallbackState
from airflow.models.callback import CallbackState, _accepts_context
from airflow.triggers.base import BaseTrigger, TriggerEvent

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -52,9 +52,14 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
try:
yield TriggerEvent({PAYLOAD_STATUS_KEY: CallbackState.RUNNING})
callback = import_string(self.callback_path)
# TODO: get full context and run template rendering. Right now, a simple context is included in `callback_kwargs`
context = self.callback_kwargs.pop("context", None)

if _accepts_context(callback) and context is not None:
result = await callback(**self.callback_kwargs, context=context)
else:
result = await callback(**self.callback_kwargs)

# TODO: get full context and run template rendering. Right now, a simple context in included in `callback_kwargs`
result = await callback(**self.callback_kwargs)
yield TriggerEvent({PAYLOAD_STATUS_KEY: CallbackState.SUCCESS, PAYLOAD_BODY_KEY: result})

except Exception as e:
Expand Down
33 changes: 33 additions & 0 deletions airflow-core/tests/unit/models/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
from __future__ import annotations

from unittest.mock import patch

import pytest
from sqlalchemy import select

Expand All @@ -26,6 +28,7 @@
CallbackState,
ExecutorCallback,
TriggererCallback,
_accepts_context,
)
from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback
from airflow.triggers.base import TriggerEvent
Expand Down Expand Up @@ -208,3 +211,33 @@ def test_queue(self):


# Note: class DagProcessorCallback is tested in airflow-core/tests/unit/dag_processing/test_manager.py


class TestAcceptsContext:
def test_true_when_var_keyword_present(self):
def func_with_var_keyword(**kwargs):
pass

assert _accepts_context(func_with_var_keyword) is True

def test_true_when_context_param_present(self):
def func_with_context(context, alert_type):
pass

assert _accepts_context(func_with_context) is True

def test_false_when_no_context_or_var_keyword(self):
def func_without_context(a, b):
pass

assert _accepts_context(func_without_context) is False

def test_false_when_no_params(self):
def func_no_params():
pass

assert _accepts_context(func_no_params) is False

def test_true_for_uninspectable_callable(self):
with patch("airflow.models.callback.inspect.signature", side_effect=ValueError):
assert _accepts_context(lambda: None) is True
23 changes: 16 additions & 7 deletions airflow-core/tests/unit/triggers/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
TEST_MESSAGE = "test_message"
TEST_CALLBACK_PATH = "classpath.test_callback"
TEST_CALLBACK_KWARGS = {"message": TEST_MESSAGE, "context": {"dag_run": "test"}}
TEST_TRIGGER = CallbackTrigger(callback_path=TEST_CALLBACK_PATH, callback_kwargs=TEST_CALLBACK_KWARGS)


class ExampleAsyncNotifier(BaseNotifier):
Expand All @@ -46,6 +45,14 @@ def notify(self, context):


class TestCallbackTrigger:
@pytest.fixture
def trigger(self):
"""Create a fresh trigger per test to avoid shared mutable state."""
return CallbackTrigger(
callback_path=TEST_CALLBACK_PATH,
callback_kwargs=dict(TEST_CALLBACK_KWARGS),
)

@pytest.fixture
def mock_import_string(self):
with mock.patch("airflow.triggers.callback.import_string") as m:
Expand All @@ -72,29 +79,30 @@ def test_serialization(self, callback_init_kwargs, expected_serialized_kwargs):
}

@pytest.mark.asyncio
async def test_run_success_with_async_function(self, mock_import_string):
async def test_run_success_with_async_function(self, trigger, mock_import_string):
"""Test trigger handles async functions correctly."""
callback_return_value = "some value"
mock_callback = mock.AsyncMock(return_value=callback_return_value)
mock_import_string.return_value = mock_callback

trigger_gen = TEST_TRIGGER.run()
trigger_gen = trigger.run()

running_event = await anext(trigger_gen)
assert running_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.RUNNING

success_event = await anext(trigger_gen)
mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH)
# AsyncMock accepts **kwargs, so _accepts_context returns True and context is passed through
mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS)
assert success_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS
assert success_event.payload[PAYLOAD_BODY_KEY] == callback_return_value

@pytest.mark.asyncio
async def test_run_success_with_notifier(self, mock_import_string):
async def test_run_success_with_notifier(self, trigger, mock_import_string):
"""Test trigger handles async notifier classes correctly."""
mock_import_string.return_value = ExampleAsyncNotifier

trigger_gen = TEST_TRIGGER.run()
trigger_gen = trigger.run()

running_event = await anext(trigger_gen)
assert running_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.RUNNING
Expand All @@ -108,18 +116,19 @@ async def test_run_success_with_notifier(self, mock_import_string):
)

@pytest.mark.asyncio
async def test_run_failure(self, mock_import_string):
async def test_run_failure(self, trigger, mock_import_string):
exc_msg = "Something went wrong"
mock_callback = mock.AsyncMock(side_effect=RuntimeError(exc_msg))
mock_import_string.return_value = mock_callback

trigger_gen = TEST_TRIGGER.run()
trigger_gen = trigger.run()

running_event = await anext(trigger_gen)
assert running_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.RUNNING

failure_event = await anext(trigger_gen)
mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH)
# AsyncMock accepts **kwargs, so _accepts_context returns True and context is passed through
mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS)
assert failure_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.FAILED
assert all(s in failure_event.payload[PAYLOAD_BODY_KEY] for s in ["raise", "RuntimeError", exc_msg])