diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index 3bb8a70fa2712..2997d55d8bb3b 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -32,7 +32,9 @@ from airflow.configuration import conf from airflow.executors import workloads from airflow.executors.executor_loader import ExecutorLoader +from airflow.executors.workloads.task import TaskInstanceDTO from airflow.models import Log +from airflow.models.callback import CallbackKey from airflow.observability.metrics import stats_utils from airflow.observability.trace import Trace from airflow.utils.log.logging_mixin import LoggingMixin @@ -52,6 +54,7 @@ from airflow.callbacks.callback_requests import CallbackRequest from airflow.cli.cli_config import GroupCommand from airflow.executors.executor_utils import ExecutorName + from airflow.executors.workloads.types import WorkloadKey from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey @@ -143,6 +146,7 @@ class BaseExecutor(LoggingMixin): active_spans = ThreadSafeDict() supports_ad_hoc_ti_run: bool = False + supports_callbacks: bool = False supports_multi_team: bool = False sentry_integration: str = "" @@ -186,8 +190,9 @@ def __init__(self, parallelism: int = PARALLELISM, team_name: str | None = None) self.parallelism: int = parallelism self.team_name: str | None = team_name self.queued_tasks: dict[TaskInstanceKey, workloads.ExecuteTask] = {} - self.running: set[TaskInstanceKey] = set() - self.event_buffer: dict[TaskInstanceKey, EventBufferValueType] = {} + self.queued_callbacks: dict[str, workloads.ExecuteCallback] = {} + self.running: set[WorkloadKey] = set() + self.event_buffer: dict[WorkloadKey, EventBufferValueType] = {} self._task_event_logs: deque[Log] = deque() self.conf = ExecutorConf(team_name) @@ -203,7 +208,7 @@ def __init__(self, parallelism: int = PARALLELISM, team_name: str | None = None) :meta private: """ - self.attempts: dict[TaskInstanceKey, RunningRetryAttemptType] = defaultdict(RunningRetryAttemptType) + self.attempts: dict[WorkloadKey, RunningRetryAttemptType] = defaultdict(RunningRetryAttemptType) def __repr__(self): _repr = f"{self.__class__.__name__}(parallelism={self.parallelism}" @@ -224,10 +229,47 @@ def log_task_event(self, *, event: str, extra: str, ti_key: TaskInstanceKey): self._task_event_logs.append(Log(event=event, task_instance=ti_key, extra=extra)) def queue_workload(self, workload: workloads.All, session: Session) -> None: - if not isinstance(workload, workloads.ExecuteTask): - raise ValueError(f"Un-handled workload kind {type(workload).__name__!r} in {type(self).__name__}") - ti = workload.ti - self.queued_tasks[ti.key] = workload + if isinstance(workload, workloads.ExecuteTask): + ti = workload.ti + self.queued_tasks[ti.key] = workload + elif isinstance(workload, workloads.ExecuteCallback): + if not self.supports_callbacks: + raise NotImplementedError( + f"{type(self).__name__} does not support ExecuteCallback workloads. " + f"Set supports_callbacks = True and implement callback handling in _process_workloads(). " + f"See LocalExecutor or CeleryExecutor for reference implementation." + ) + self.queued_callbacks[workload.callback.id] = workload + else: + raise ValueError( + f"Un-handled workload type {type(workload).__name__!r} in {type(self).__name__}. " + f"Workload must be one of: ExecuteTask, ExecuteCallback." + ) + + def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, workloads.All]]: + """ + Select and return the next batch of workloads to schedule, respecting priority policy. + + Priority Policy: Callbacks are scheduled before tasks (callbacks complete existing work). + Callbacks are processed in FIFO order. Tasks are sorted by priority_weight (higher priority first). + + :param open_slots: Number of available execution slots + """ + workloads_to_schedule: list[tuple[WorkloadKey, workloads.All]] = [] + + if self.queued_callbacks: + for key, workload in self.queued_callbacks.items(): + if len(workloads_to_schedule) >= open_slots: + break + workloads_to_schedule.append((key, workload)) + + if open_slots > len(workloads_to_schedule) and self.queued_tasks: + for task_key, task_workload in self.order_queued_tasks_by_priority(): + if len(workloads_to_schedule) >= open_slots: + break + workloads_to_schedule.append((task_key, task_workload)) + + return workloads_to_schedule def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: """ @@ -266,10 +308,10 @@ def heartbeat(self) -> None: """Heartbeat sent to trigger new jobs.""" open_slots = self.parallelism - len(self.running) - num_running_tasks = len(self.running) - num_queued_tasks = len(self.queued_tasks) + num_running_workloads = len(self.running) + num_queued_workloads = len(self.queued_tasks) + len(self.queued_callbacks) - self._emit_metrics(open_slots, num_running_tasks, num_queued_tasks) + self._emit_metrics(open_slots, num_running_workloads, num_queued_workloads) self.trigger_tasks(open_slots) # Calling child class sync method @@ -350,16 +392,16 @@ def order_queued_tasks_by_priority(self) -> list[tuple[TaskInstanceKey, workload def trigger_tasks(self, open_slots: int) -> None: """ - Initiate async execution of the queued tasks, up to the number of available slots. + Initiate async execution of queued workloads (tasks and callbacks), up to the number of available slots. + + Callbacks are prioritized over tasks to complete existing work before starting new work. :param open_slots: Number of open slots """ - sorted_queue = self.order_queued_tasks_by_priority() + workloads_to_schedule = self._get_workloads_to_schedule(open_slots) workload_list = [] - for _ in range(min((open_slots, len(self.queued_tasks)))): - key, item = sorted_queue.pop() - + for key, workload in workloads_to_schedule: # If a task makes it here but is still understood by the executor # to be running, it generally means that the task has been killed # externally and not yet been marked as failed. @@ -373,12 +415,12 @@ def trigger_tasks(self, open_slots: int) -> None: if key in self.attempts: del self.attempts[key] - if isinstance(item, workloads.ExecuteTask) and hasattr(item, "ti"): - ti = item.ti + if isinstance(workload, workloads.ExecuteTask) and hasattr(workload, "ti"): + ti = workload.ti # If it's None, then the span for the current id hasn't been started. if self.active_spans is not None and self.active_spans.get("ti:" + str(ti.id)) is None: - if isinstance(ti, workloads.TaskInstance): + if isinstance(ti, TaskInstanceDTO): parent_context = Trace.extract(ti.parent_context_carrier) else: parent_context = Trace.extract(ti.dag_run.context_carrier) @@ -397,7 +439,8 @@ def trigger_tasks(self, open_slots: int) -> None: carrier = Trace.inject() ti.context_carrier = carrier - workload_list.append(item) + workload_list.append(workload) + if workload_list: self._process_workloads(workload_list) @@ -459,24 +502,25 @@ def running_state(self, key: TaskInstanceKey, info=None) -> None: """ self.change_state(key, TaskInstanceState.RUNNING, info, remove_running=False) - def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, EventBufferValueType]: + def get_event_buffer(self, dag_ids=None) -> dict[WorkloadKey, EventBufferValueType]: """ Return and flush the event buffer. In case dag_ids is specified it will only return and flush events for the given dag_ids. Otherwise, it returns and flushes all events. + Note: Callback events (with string keys) are always included regardless of dag_ids filter. :param dag_ids: the dag_ids to return events for; returns all if given ``None``. :return: a dict of events """ - cleared_events: dict[TaskInstanceKey, EventBufferValueType] = {} + cleared_events: dict[WorkloadKey, EventBufferValueType] = {} if dag_ids is None: cleared_events = self.event_buffer self.event_buffer = {} else: - for ti_key in list(self.event_buffer.keys()): - if ti_key.dag_id in dag_ids: - cleared_events[ti_key] = self.event_buffer.pop(ti_key) + for key in list(self.event_buffer.keys()): + if isinstance(key, CallbackKey) or key.dag_id in dag_ids: + cleared_events[key] = self.event_buffer.pop(key) return cleared_events @@ -529,21 +573,26 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task @property def slots_available(self): - """Number of new tasks this executor instance can accept.""" - return self.parallelism - len(self.running) - len(self.queued_tasks) + """Number of new workloads (tasks and callbacks) this executor instance can accept.""" + return self.parallelism - len(self.running) - len(self.queued_tasks) - len(self.queued_callbacks) @property def slots_occupied(self): - """Number of tasks this executor instance is currently managing.""" - return len(self.running) + len(self.queued_tasks) + """Number of workloads (tasks and callbacks) this executor instance is currently managing.""" + return len(self.running) + len(self.queued_tasks) + len(self.queued_callbacks) def debug_dump(self): """Get called in response to SIGUSR2 by the scheduler.""" self.log.info( - "executor.queued (%d)\n\t%s", + "executor.queued_tasks (%d)\n\t%s", len(self.queued_tasks), "\n\t".join(map(repr, self.queued_tasks.items())), ) + self.log.info( + "executor.queued_callbacks (%d)\n\t%s", + len(self.queued_callbacks), + "\n\t".join(map(repr, self.queued_callbacks.items())), + ) self.log.info("executor.running (%d)\n\t%s", len(self.running), "\n\t".join(map(repr, self.running))) self.log.info( "executor.event_buffer (%d)\n\t%s", diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index 604de7c7f00f4..9b5939a0bd2e7 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -37,7 +37,8 @@ from airflow.executors import workloads from airflow.executors.base_executor import BaseExecutor -from airflow.utils.state import TaskInstanceState +from airflow.executors.workloads.callback import execute_callback_workload +from airflow.utils.state import CallbackState, TaskInstanceState # add logger to parameter of setproctitle to support logging if sys.platform == "darwin": @@ -50,13 +51,23 @@ if TYPE_CHECKING: from structlog.typing import FilteringBoundLogger as Logger - TaskInstanceStateType = tuple[workloads.TaskInstance, TaskInstanceState, Exception | None] + from airflow.executors.workloads.types import WorkloadResultType + + +def _get_executor_process_title_prefix(team_name: str | None) -> str: + """ + Build the process title prefix for LocalExecutor workers. + + :param team_name: Team name from executor configuration + """ + team_suffix = f" [{team_name}]" if team_name else "" + return f"airflow worker -- LocalExecutor{team_suffix}:" def _run_worker( logger_name: str, input: SimpleQueue[workloads.All | None], - output: Queue[TaskInstanceStateType], + output: Queue[WorkloadResultType], unread_messages: multiprocessing.sharedctypes.Synchronized[int], team_conf, ): @@ -68,11 +79,8 @@ def _run_worker( log = structlog.get_logger(logger_name) log.info("Worker starting up pid=%d", os.getpid()) - # Create team suffix for process title - team_suffix = f" [{team_conf.team_name}]" if team_conf.team_name else "" - while True: - setproctitle(f"airflow worker -- LocalExecutor{team_suffix}: ", log) + setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)} ", log) try: workload = input.get() except EOFError: @@ -87,25 +95,30 @@ def _run_worker( # Received poison pill, no more tasks to run return - if not isinstance(workload, workloads.ExecuteTask): - raise ValueError(f"LocalExecutor does not know how to handle {type(workload)}") - # Decrement this as soon as we pick up a message off the queue with unread_messages: unread_messages.value -= 1 - key = None - if ti := getattr(workload, "ti", None): - key = ti.key - else: - raise TypeError(f"Don't know how to get ti key from {type(workload).__name__}") - try: - _execute_work(log, workload, team_conf) + # Handle different workload types + if isinstance(workload, workloads.ExecuteTask): + try: + _execute_work(log, workload, team_conf) + output.put((workload.ti.key, TaskInstanceState.SUCCESS, None)) + except Exception as e: + log.exception("Task execution failed.") + output.put((workload.ti.key, TaskInstanceState.FAILED, e)) + + elif isinstance(workload, workloads.ExecuteCallback): + output.put((workload.callback.id, CallbackState.RUNNING, None)) + try: + _execute_callback(log, workload, team_conf) + output.put((workload.callback.id, CallbackState.SUCCESS, None)) + except Exception as e: + log.exception("Callback execution failed") + output.put((workload.callback.id, CallbackState.FAILED, e)) - output.put((key, TaskInstanceState.SUCCESS, None)) - except Exception as e: - log.exception("uhoh") - output.put((key, TaskInstanceState.FAILED, e)) + else: + raise ValueError(f"LocalExecutor does not know how to handle {type(workload)}") def _execute_work(log: Logger, workload: workloads.ExecuteTask, team_conf) -> None: @@ -118,9 +131,7 @@ def _execute_work(log: Logger, workload: workloads.ExecuteTask, team_conf) -> No """ from airflow.sdk.execution_time.supervisor import supervise - # Create team suffix for process title - team_suffix = f" [{team_conf.team_name}]" if team_conf.team_name else "" - setproctitle(f"airflow worker -- LocalExecutor{team_suffix}: {workload.ti.id}", log) + setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.ti.id}", log) base_url = team_conf.get("api", "base_url", fallback="/") # If it's a relative URL, use localhost:8080 as the default @@ -141,6 +152,22 @@ def _execute_work(log: Logger, workload: workloads.ExecuteTask, team_conf) -> No ) +def _execute_callback(log: Logger, workload: workloads.ExecuteCallback, team_conf) -> None: + """ + Execute a callback workload. + + :param log: Logger instance + :param workload: The ExecuteCallback workload to execute + :param team_conf: Team-specific executor configuration + """ + setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.callback.id}", log) + + success, error_msg = execute_callback_workload(workload.callback, log) + + if not success: + raise RuntimeError(error_msg or "Callback execution failed") + + class LocalExecutor(BaseExecutor): """ LocalExecutor executes tasks locally in parallel. @@ -155,9 +182,10 @@ class LocalExecutor(BaseExecutor): supports_multi_team: bool = True serve_logs: bool = True + supports_callbacks: bool = True activity_queue: SimpleQueue[workloads.All | None] - result_queue: SimpleQueue[TaskInstanceStateType] + result_queue: SimpleQueue[WorkloadResultType] workers: dict[int, multiprocessing.Process] _unread_messages: multiprocessing.sharedctypes.Synchronized[int] @@ -300,10 +328,14 @@ def end(self) -> None: def terminate(self): """Terminate the executor is not doing anything.""" - def _process_workloads(self, workloads): - for workload in workloads: + def _process_workloads(self, workload_list): + for workload in workload_list: self.activity_queue.put(workload) - del self.queued_tasks[workload.ti.key] + # Remove from appropriate queue based on workload type + if isinstance(workload, workloads.ExecuteTask): + del self.queued_tasks[workload.ti.key] + elif isinstance(workload, workloads.ExecuteCallback): + del self.queued_callbacks[workload.callback.id] with self._unread_messages: - self._unread_messages.value += len(workloads) + self._unread_messages.value += len(workload_list) self._check_workers() diff --git a/airflow-core/src/airflow/executors/workloads.py b/airflow-core/src/airflow/executors/workloads.py deleted file mode 100644 index 7cf1aae60ff21..0000000000000 --- a/airflow-core/src/airflow/executors/workloads.py +++ /dev/null @@ -1,210 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import os -import uuid -from abc import ABC -from datetime import datetime -from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Literal - -import structlog -from pydantic import BaseModel, Field - -if TYPE_CHECKING: - from airflow.api_fastapi.auth.tokens import JWTGenerator - from airflow.models import DagRun - from airflow.models.callback import Callback as CallbackModel, CallbackFetchMethod - from airflow.models.taskinstance import TaskInstance as TIModel - from airflow.models.taskinstancekey import TaskInstanceKey - - -__all__ = ["All", "ExecuteTask", "ExecuteCallback"] - -log = structlog.get_logger(__name__) - - -class BaseWorkload(BaseModel): - token: str - """The identity token for this workload""" - - @staticmethod - def generate_token(sub_id: str, generator: JWTGenerator | None = None) -> str: - return generator.generate({"sub": sub_id}) if generator else "" - - -class BundleInfo(BaseModel): - """Schema for telling task which bundle to run with.""" - - name: str - version: str | None = None - - -class TaskInstance(BaseModel): - """Schema for TaskInstance with minimal required fields needed for Executors and Task SDK.""" - - id: uuid.UUID - dag_version_id: uuid.UUID - task_id: str - dag_id: str - run_id: str - try_number: int - map_index: int = -1 - - pool_slots: int - queue: str - priority_weight: int - executor_config: dict | None = Field(default=None, exclude=True) - - parent_context_carrier: dict | None = None - context_carrier: dict | None = None - - # TODO: Task-SDK: Can we replace TastInstanceKey with just the uuid across the codebase? - @property - def key(self) -> TaskInstanceKey: - from airflow.models.taskinstancekey import TaskInstanceKey - - return TaskInstanceKey( - dag_id=self.dag_id, - task_id=self.task_id, - run_id=self.run_id, - try_number=self.try_number, - map_index=self.map_index, - ) - - -class Callback(BaseModel): - """Schema for Callback with minimal required fields needed for Executors and Task SDK.""" - - id: uuid.UUID - fetch_type: CallbackFetchMethod - data: dict - - -class BaseDagBundleWorkload(BaseWorkload, ABC): - """Base class for Workloads that are associated with a DAG bundle.""" - - dag_rel_path: os.PathLike[str] - """The filepath where the DAG can be found (likely prefixed with `DAG_FOLDER/`)""" - - bundle_info: BundleInfo - - log_path: str | None - """The rendered relative log filename template the task logs should be written to""" - - -class ExecuteTask(BaseDagBundleWorkload): - """Execute the given Task.""" - - ti: TaskInstance - sentry_integration: str = "" - - type: Literal["ExecuteTask"] = Field(init=False, default="ExecuteTask") - - @classmethod - def make( - cls, - ti: TIModel, - dag_rel_path: Path | None = None, - generator: JWTGenerator | None = None, - bundle_info: BundleInfo | None = None, - sentry_integration: str = "", - ) -> ExecuteTask: - from airflow.utils.helpers import log_filename_template_renderer - - ser_ti = TaskInstance.model_validate(ti, from_attributes=True) - ser_ti.parent_context_carrier = ti.dag_run.context_carrier - if not bundle_info: - bundle_info = BundleInfo( - name=ti.dag_model.bundle_name, - version=ti.dag_run.bundle_version, - ) - fname = log_filename_template_renderer()(ti=ti) - - return cls( - ti=ser_ti, - dag_rel_path=dag_rel_path or Path(ti.dag_model.relative_fileloc or ""), - token=cls.generate_token(str(ti.id), generator), - log_path=fname, - bundle_info=bundle_info, - sentry_integration=sentry_integration, - ) - - -class ExecuteCallback(BaseDagBundleWorkload): - """Execute the given Callback.""" - - callback: Callback - - type: Literal["ExecuteCallback"] = Field(init=False, default="ExecuteCallback") - - @classmethod - def make( - cls, - callback: CallbackModel, - dag_run: DagRun, - dag_rel_path: Path | None = None, - generator: JWTGenerator | None = None, - bundle_info: BundleInfo | None = None, - ) -> ExecuteCallback: - if not bundle_info: - bundle_info = BundleInfo( - name=dag_run.dag_model.bundle_name, - version=dag_run.bundle_version, - ) - fname = f"executor_callbacks/{callback.id}" # TODO: better log file template - - return cls( - callback=Callback.model_validate(callback, from_attributes=True), - dag_rel_path=dag_rel_path or Path(dag_run.dag_model.relative_fileloc or ""), - token=cls.generate_token(str(callback.id), generator), - log_path=fname, - bundle_info=bundle_info, - ) - - -class RunTrigger(BaseModel): - """Execute an async "trigger" process that yields events.""" - - id: int - - ti: TaskInstance | None - """ - The task instance associated with this trigger. - - Could be none for asset-based triggers. - """ - - classpath: str - """ - Dot-separated name of the module+fn to import and run this workload. - - Consumers of this Workload must perform their own validation of this input. - """ - - encrypted_kwargs: str - - timeout_after: datetime | None = None - - type: Literal["RunTrigger"] = Field(init=False, default="RunTrigger") - - -All = Annotated[ - ExecuteTask | RunTrigger, - Field(discriminator="type"), -] diff --git a/airflow-core/src/airflow/executors/workloads/__init__.py b/airflow-core/src/airflow/executors/workloads/__init__.py new file mode 100644 index 0000000000000..dca4c991f637b --- /dev/null +++ b/airflow-core/src/airflow/executors/workloads/__init__.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Workload schemas for executor communication.""" + +from __future__ import annotations + +from typing import Annotated + +from pydantic import Field + +from airflow.executors.workloads.base import BaseWorkload, BundleInfo +from airflow.executors.workloads.callback import CallbackFetchMethod, ExecuteCallback +from airflow.executors.workloads.task import ExecuteTask +from airflow.executors.workloads.trigger import RunTrigger + +All = Annotated[ + ExecuteTask | ExecuteCallback | RunTrigger, + Field(discriminator="type"), +] + +__all__ = ["All", "BaseWorkload", "BundleInfo", "CallbackFetchMethod", "ExecuteCallback", "ExecuteTask"] diff --git a/airflow-core/src/airflow/executors/workloads/base.py b/airflow-core/src/airflow/executors/workloads/base.py new file mode 100644 index 0000000000000..cf622209d67ba --- /dev/null +++ b/airflow-core/src/airflow/executors/workloads/base.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""ORM models and Pydantic schemas for BaseWorkload.""" + +from __future__ import annotations + +import os +from abc import ABC +from typing import TYPE_CHECKING + +from pydantic import BaseModel, ConfigDict + +if TYPE_CHECKING: + from airflow.api_fastapi.auth.tokens import JWTGenerator + + +class BaseWorkload: + """ + Mixin for ORM models that can be scheduled as workloads. + + This mixin defines the interface that scheduler workloads (TaskInstance, + ExecutorCallback, etc.) must implement to provide routing information to the scheduler. + + Subclasses must override: + - get_dag_id() -> str | None + - get_executor_name() -> str | None + """ + + def get_dag_id(self) -> str | None: + """ + Return the DAG ID for scheduler routing. + + Must be implemented by subclasses. + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement get_dag_id()") + + def get_executor_name(self) -> str | None: + """ + Return the executor name for scheduler routing. + + Must be implemented by subclasses. + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement get_executor_name()") + + +class BundleInfo(BaseModel): + """Schema for telling task which bundle to run with.""" + + name: str + version: str | None = None + + +class BaseWorkloadSchema(BaseModel): + """Base Pydantic schema for executor workload DTOs.""" + + model_config = ConfigDict(populate_by_name=True) + + token: str + """The identity token for this workload""" + + @staticmethod + def generate_token(sub_id: str, generator: JWTGenerator | None = None) -> str: + return generator.generate({"sub": sub_id}) if generator else "" + + +class BaseDagBundleWorkload(BaseWorkloadSchema, ABC): + """Base class for Workloads that are associated with a DAG bundle.""" + + dag_rel_path: os.PathLike[str] # Filepath where the DAG can be found (likely prefixed with `DAG_FOLDER/`) + bundle_info: BundleInfo + log_path: str | None # Rendered relative log filename template the task logs should be written to. diff --git a/airflow-core/src/airflow/executors/workloads/callback.py b/airflow-core/src/airflow/executors/workloads/callback.py new file mode 100644 index 0000000000000..c15bb33fba70e --- /dev/null +++ b/airflow-core/src/airflow/executors/workloads/callback.py @@ -0,0 +1,158 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Callback workload schemas for executor communication.""" + +from __future__ import annotations + +from enum import Enum +from importlib import import_module +from pathlib import Path +from typing import TYPE_CHECKING, Literal +from uuid import UUID + +import structlog +from pydantic import BaseModel, Field, field_validator + +from airflow.executors.workloads.base import BaseDagBundleWorkload, BundleInfo + +if TYPE_CHECKING: + from airflow.api_fastapi.auth.tokens import JWTGenerator + from airflow.models import DagRun + from airflow.models.callback import Callback as CallbackModel, CallbackKey + +log = structlog.get_logger(__name__) + + +class CallbackFetchMethod(str, Enum): + """Methods used to fetch callback at runtime.""" + + # For future use once Dag Processor callbacks (on_success_callback/on_failure_callback) get moved to executors + DAG_ATTRIBUTE = "dag_attribute" + + # For deadline callbacks since they import callbacks through the import path + IMPORT_PATH = "import_path" + + +class CallbackDTO(BaseModel): + """Schema for Callback with minimal required fields needed for Executors and Task SDK.""" + + id: str # A uuid.UUID stored as a string + fetch_method: CallbackFetchMethod + data: dict + + @field_validator("id", mode="before") + @classmethod + def validate_id(cls, v): + """Convert UUID to str if needed.""" + if isinstance(v, UUID): + return str(v) + return v + + @property + def key(self) -> CallbackKey: + """Return callback ID as key (CallbackKey = str).""" + return self.id + + +class ExecuteCallback(BaseDagBundleWorkload): + """Execute the given Callback.""" + + callback: CallbackDTO + + type: Literal["ExecuteCallback"] = Field(init=False, default="ExecuteCallback") + + @classmethod + def make( + cls, + callback: CallbackModel, + dag_run: DagRun, + dag_rel_path: Path | None = None, + generator: JWTGenerator | None = None, + bundle_info: BundleInfo | None = None, + ) -> ExecuteCallback: + """Create an ExecuteCallback workload from a Callback ORM model.""" + if not bundle_info: + bundle_info = BundleInfo( + name=dag_run.dag_model.bundle_name, + version=dag_run.bundle_version, + ) + fname = f"executor_callbacks/{callback.id}" # TODO: better log file template + + return cls( + callback=CallbackDTO.model_validate(callback, from_attributes=True), + dag_rel_path=dag_rel_path or Path(dag_run.dag_model.relative_fileloc or ""), + token=cls.generate_token(str(callback.id), generator), + log_path=fname, + bundle_info=bundle_info, + ) + + +def execute_callback_workload( + callback: CallbackDTO, + log, +) -> tuple[bool, str | None]: + """ + Execute a callback function by importing and calling it, returning the success state. + + Supports two patterns: + 1. Functions - called directly with kwargs + 2. Classes that return callable instances (like BaseNotifier) - instantiated then called with context + + Example: + # Function callback + callback.data = {"path": "my_module.alert_func", "kwargs": {"msg": "Alert!", "context": {...}}} + execute_callback_workload(callback, log) # Calls alert_func(msg="Alert!", context={...}) + + # Notifier callback + callback.data = {"path": "airflow.providers.slack...SlackWebhookNotifier", "kwargs": {"text": "Alert!", "context": {...}}} + execute_callback_workload(callback, log) # SlackWebhookNotifier(text=..., context=...) then calls instance(context) + + :param callback: The Callback schema containing path and kwargs + :param log: Logger instance for recording execution + :return: Tuple of (success: bool, error_message: str | None) + """ + callback_path = callback.data.get("path") + callback_kwargs = callback.data.get("kwargs", {}) + + if not callback_path: + return False, "Callback path not found in data." + + try: + # Import the callback callable + # Expected format: "module.path.to.function_or_class" + module_path, function_name = callback_path.rsplit(".", 1) + module = import_module(module_path) + callback_callable = getattr(module, function_name) + + log.debug("Executing callback %s(%s)...", callback_path, callback_kwargs) + + # If the callback is a callable, call it. If it is a class, instantiate it. + result = callback_callable(**callback_kwargs) + + # If the callback is a class then it is now instantiated and callable, call it. + if callable(result): + context = callback_kwargs.get("context", {}) + log.debug("Calling result with context for %s", callback_path) + result = result(context) + + log.info("Callback %s executed successfully.", callback_path) + return True, None + + except Exception as e: + error_msg = f"Callback execution failed: {type(e).__name__}: {str(e)}" + log.exception("Callback %s(%s) execution failed: %s", callback_path, callback_kwargs, error_msg) + return False, error_msg diff --git a/airflow-core/src/airflow/executors/workloads/task.py b/airflow-core/src/airflow/executors/workloads/task.py new file mode 100644 index 0000000000000..d691dcb6f0968 --- /dev/null +++ b/airflow-core/src/airflow/executors/workloads/task.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Task workload schemas for executor communication.""" + +from __future__ import annotations + +import uuid +from pathlib import Path +from typing import TYPE_CHECKING, Literal + +from pydantic import BaseModel, Field + +from airflow.executors.workloads.base import BaseDagBundleWorkload, BundleInfo + +if TYPE_CHECKING: + from airflow.api_fastapi.auth.tokens import JWTGenerator + from airflow.models.taskinstance import TaskInstance as TIModel + from airflow.models.taskinstancekey import TaskInstanceKey + + +class TaskInstanceDTO(BaseModel): + """Schema for TaskInstance with minimal required fields needed for Executors and Task SDK.""" + + id: uuid.UUID + dag_version_id: uuid.UUID + task_id: str + dag_id: str + run_id: str + try_number: int + map_index: int = -1 + + pool_slots: int + queue: str + priority_weight: int + executor_config: dict | None = Field(default=None, exclude=True) + + parent_context_carrier: dict | None = None + context_carrier: dict | None = None + + # TODO: Task-SDK: Can we replace TaskInstanceKey with just the uuid across the codebase? + @property + def key(self) -> TaskInstanceKey: + from airflow.models.taskinstancekey import TaskInstanceKey + + return TaskInstanceKey( + dag_id=self.dag_id, + task_id=self.task_id, + run_id=self.run_id, + try_number=self.try_number, + map_index=self.map_index, + ) + + +class ExecuteTask(BaseDagBundleWorkload): + """Execute the given Task.""" + + ti: TaskInstanceDTO + sentry_integration: str = "" + + type: Literal["ExecuteTask"] = Field(init=False, default="ExecuteTask") + + @classmethod + def make( + cls, + ti: TIModel, + dag_rel_path: Path | None = None, + generator: JWTGenerator | None = None, + bundle_info: BundleInfo | None = None, + sentry_integration: str = "", + ) -> ExecuteTask: + """Create an ExecuteTask workload from a TaskInstance ORM model.""" + from airflow.utils.helpers import log_filename_template_renderer + + ser_ti = TaskInstanceDTO.model_validate(ti, from_attributes=True) + ser_ti.parent_context_carrier = ti.dag_run.context_carrier + if not bundle_info: + bundle_info = BundleInfo( + name=ti.dag_model.bundle_name, + version=ti.dag_run.bundle_version, + ) + fname = log_filename_template_renderer()(ti=ti) + + return cls( + ti=ser_ti, + dag_rel_path=dag_rel_path or Path(ti.dag_model.relative_fileloc or ""), + token=cls.generate_token(str(ti.id), generator), + log_path=fname, + bundle_info=bundle_info, + sentry_integration=sentry_integration, + ) diff --git a/airflow-core/src/airflow/executors/workloads/trigger.py b/airflow-core/src/airflow/executors/workloads/trigger.py new file mode 100644 index 0000000000000..25bca9ce44b13 --- /dev/null +++ b/airflow-core/src/airflow/executors/workloads/trigger.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Trigger workload schemas for executor communication.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Literal + +from pydantic import BaseModel, Field + +# Using noqa because Ruff wants this in a TYPE_CHECKING block but Pydantic fails if it is. +from airflow.executors.workloads.task import TaskInstanceDTO # noqa: TCH001 + + +class RunTrigger(BaseModel): + """ + Execute an async "trigger" process that yields events. + + Consumers of this Workload must perform their own validation of the classpath input. + """ + + id: int + ti: TaskInstanceDTO | None # Could be none for asset-based triggers. + classpath: str # Dot-separated name of the module+fn to import and run this workload. + encrypted_kwargs: str + timeout_after: datetime | None = None + type: Literal["RunTrigger"] = Field(init=False, default="RunTrigger") diff --git a/airflow-core/src/airflow/executors/workloads/types.py b/airflow-core/src/airflow/executors/workloads/types.py new file mode 100644 index 0000000000000..31cda7028466f --- /dev/null +++ b/airflow-core/src/airflow/executors/workloads/types.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Type aliases for Workloads.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeAlias + +from airflow.models.callback import ExecutorCallback +from airflow.models.taskinstance import TaskInstance + +if TYPE_CHECKING: + from airflow.models.callback import CallbackKey + from airflow.models.taskinstancekey import TaskInstanceKey + from airflow.utils.state import CallbackState, TaskInstanceState + + # Type aliases for workload keys and states (used by executor layer) + WorkloadKey: TypeAlias = TaskInstanceKey | CallbackKey + WorkloadState: TypeAlias = TaskInstanceState | CallbackState + + # Type alias for executor workload results (used by executor implementations) + WorkloadResultType: TypeAlias = tuple[WorkloadKey, WorkloadState, Exception | None] + +# Type alias for scheduler workloads (ORM models that can be routed to executors) +# Must be outside TYPE_CHECKING for use in function signatures +SchedulerWorkload: TypeAlias = TaskInstance | ExecutorCallback diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 4117153d4e7cd..d192798004211 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -85,7 +85,7 @@ TaskOutletAssetReference, ) from airflow.models.backfill import Backfill -from airflow.models.callback import Callback +from airflow.models.callback import Callback, CallbackType, ExecutorCallback from airflow.models.dag import DagModel from airflow.models.dag_version import DagVersion from airflow.models.dagbag import DBDagBag @@ -95,6 +95,7 @@ from airflow.models.pool import normalize_pool_name_for_stats from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance +from airflow.models.taskinstancekey import TaskInstanceKey from airflow.models.team import Team from airflow.models.trigger import TRIGGER_FAIL_REPR, Trigger, TriggerFailureReason from airflow.observability.metrics import stats_utils @@ -115,7 +116,7 @@ prohibit_commit, with_row_locks, ) -from airflow.utils.state import DagRunState, State, TaskInstanceState +from airflow.utils.state import CallbackState, DagRunState, State, TaskInstanceState from airflow.utils.thread_safe_dict import ThreadSafeDict from airflow.utils.types import DagRunTriggeredByType, DagRunType @@ -130,7 +131,7 @@ from airflow._shared.logging.types import Logger from airflow.executors.base_executor import BaseExecutor from airflow.executors.executor_utils import ExecutorName - from airflow.models.taskinstance import TaskInstanceKey + from airflow.executors.workloads.types import SchedulerWorkload from airflow.serialization.definitions.dag import SerializedDAG from airflow.utils.sqlalchemy import CommitProhibitorGuard @@ -359,32 +360,35 @@ def _get_team_names_for_dag_ids( # Return dict with all None values to ensure graceful degradation return {} - def _get_task_team_name(self, task_instance: TaskInstance, session: Session) -> str | None: + def _get_workload_team_name(self, workload: SchedulerWorkload, session: Session) -> str | None: """ - Resolve team name for a task instance using the DAG > Bundle > Team relationship chain. + Resolve team name for a workload using the DAG > Bundle > Team relationship chain. - TaskInstance > DagModel (via dag_id) > DagBundleModel (via bundle_name) > Team + Workload > DagModel (via dag_id) > DagBundleModel (via bundle_name) > Team - :param task_instance: The TaskInstance to resolve team name for + :param workload: The Workload to resolve team name for :param session: Database session for queries :return: Team name if found or None """ # Use the batch query function with a single DAG ID - dag_id_to_team_name = self._get_team_names_for_dag_ids([task_instance.dag_id], session) - team_name = dag_id_to_team_name.get(task_instance.dag_id) + if dag_id := workload.get_dag_id(): + dag_id_to_team_name = self._get_team_names_for_dag_ids([dag_id], session) + team_name = dag_id_to_team_name.get(dag_id) + else: + team_name = None # mypy didn't like the implicit defaulting to None if team_name: self.log.debug( - "Resolved team name '%s' for task %s (dag_id=%s)", + "Resolved team name '%s' for task or callback %s (dag_id=%s)", team_name, - task_instance.task_id, - task_instance.dag_id, + workload, + dag_id, ) else: self.log.debug( - "No team found for task %s (dag_id=%s) - DAG may not have bundle or team association", - task_instance.task_id, - task_instance.dag_id, + "No team found for task or callback %s (dag_id=%s) - DAG may not have bundle or team association", + workload, + dag_id, ) return team_name @@ -981,7 +985,7 @@ def _critical_section_enqueue_task_instances(self, session: Session) -> int: queued_tis = self._executable_task_instances_to_queued(max_tis, session=session) # Sort queued TIs to their respective executor - executor_to_queued_tis = self._executor_to_tis(queued_tis, session) + executor_to_queued_tis = self._executor_to_workloads(queued_tis, session) for executor, queued_tis_per_executor in executor_to_queued_tis.items(): self.log.info( "Trying to enqueue tasks: %s for executor: %s", @@ -993,6 +997,75 @@ def _critical_section_enqueue_task_instances(self, session: Session) -> int: return len(queued_tis) + def _enqueue_executor_callbacks(self, session: Session) -> None: + """ + Enqueue ExecutorCallback workloads to executors. + + Similar to _enqueue_task_instances, but for callbacks that need to run on executors. + Queries for QUEUED ExecutorCallback instances and routes them to the appropriate executor. + + :param session: The database session + """ + num_occupied_slots = sum(executor.slots_occupied for executor in self.executors) + max_callbacks = conf.getint("core", "parallelism") - num_occupied_slots + + if max_callbacks <= 0: + self.log.debug("No available slots for callbacks; all executors at capacity") + return + + pending_callbacks = session.scalars( + select(ExecutorCallback) + .where(ExecutorCallback.type == CallbackType.EXECUTOR) + .where(ExecutorCallback.state == CallbackState.PENDING) + .order_by(ExecutorCallback.priority_weight.desc()) + .limit(max_callbacks) + ).all() + + if not pending_callbacks: + return + + # Route callbacks to executors using the generalized routing method + executor_to_callbacks = self._executor_to_workloads(pending_callbacks, session) + + # Enqueue callbacks for each executor + for executor, callbacks in executor_to_callbacks.items(): + for callback in callbacks: + if not isinstance(callback, ExecutorCallback): + # Can't happen since we queried ExecutorCallback, but satisfies mypy. + continue + + # TODO: Add dagrun_id as a proper ORM foreign key on the callback table instead of storing in data dict. + # This would eliminate this reconstruction step. For now, all ExecutorCallbacks + # are expected to have dag_run_id set in their data dict (e.g., by Deadline.handle_miss). + if not isinstance(callback.data, dict) or "dag_run_id" not in callback.data: + self.log.error( + "ExecutorCallback %s is missing required 'dag_run_id' in data dict. " + "This indicates a bug in callback creation. Skipping callback.", + callback.id, + ) + continue + + dag_run_id = callback.data["dag_run_id"] + dag_run = session.get(DagRun, dag_run_id) + + if dag_run is None: + self.log.warning( + "Could not find DagRun with id=%s for callback %s. DagRun may have been deleted.", + dag_run_id, + callback.id, + ) + continue + + workload = workloads.ExecuteCallback.make( + callback=callback, + dag_run=dag_run, + generator=executor.jwt_generator, + ) + + executor.queue_workload(workload, session=session) + callback.state = CallbackState.QUEUED + session.add(callback) + @staticmethod def _process_task_event_logs(log_records: deque[Log], session: Session): objects = (log_records.popleft() for _ in range(len(log_records))) @@ -1055,21 +1128,50 @@ def process_executor_events( ti_primary_key_to_try_number_map: dict[tuple[str, str, str, int], int] = {} event_buffer = executor.get_event_buffer() tis_with_right_state: list[TaskInstanceKey] = [] + callback_keys_with_events: list[str] = [] + + # Report execution - handle both task and callback events + for key, (state, _) in event_buffer.items(): + if isinstance(key, TaskInstanceKey): + ti_primary_key_to_try_number_map[key.primary] = key.try_number + cls.logger().info("Received executor event with state %s for task instance %s", state, key) + if state in ( + TaskInstanceState.FAILED, + TaskInstanceState.SUCCESS, + TaskInstanceState.QUEUED, + TaskInstanceState.RUNNING, + TaskInstanceState.RESTARTING, + ): + tis_with_right_state.append(key) + else: + # Callback event (key is string UUID) + cls.logger().info("Received executor event with state %s for callback %s", state, key) + if state in (CallbackState.RUNNING, CallbackState.FAILED, CallbackState.SUCCESS): + callback_keys_with_events.append(key) + + # Handle callback state events + for callback_id in callback_keys_with_events: + state, info = event_buffer.pop(callback_id) + callback = session.get(Callback, callback_id) + if not callback: + # This should not normally happen - we just received an event for this callback. + # Only possible if callback was deleted mid-execution (e.g., cascade delete from DagRun deletion). + cls.logger().warning( + "Callback %s not found in database (may have been cascade deleted)", callback_id + ) + continue - # Report execution - for ti_key, (state, _) in event_buffer.items(): - # We create map (dag_id, task_id, logical_date) -> in-memory try_number - ti_primary_key_to_try_number_map[ti_key.primary] = ti_key.try_number - - cls.logger().info("Received executor event with state %s for task instance %s", state, ti_key) - if state in ( - TaskInstanceState.FAILED, - TaskInstanceState.SUCCESS, - TaskInstanceState.QUEUED, - TaskInstanceState.RUNNING, - TaskInstanceState.RESTARTING, - ): - tis_with_right_state.append(ti_key) + if state == CallbackState.RUNNING: + callback.state = CallbackState.RUNNING + cls.logger().info("Callback %s is currently running", callback_id) + elif state == CallbackState.SUCCESS: + callback.state = CallbackState.SUCCESS + cls.logger().info("Callback %s completed successfully", callback_id) + elif state == CallbackState.FAILED: + callback.state = CallbackState.FAILED + callback.output = str(info) if info else "Execution failed" + cls.logger().error("Callback %s failed: %s", callback_id, callback.output) + session.add(callback) # Return if no finished tasks if not tis_with_right_state: @@ -1657,6 +1759,9 @@ def _run_scheduler_loop(self) -> None: ): deadline.handle_miss(session) + # Route ExecutorCallback workloads to executors (similar to task routing) + self._enqueue_executor_callbacks(session) + # Heartbeat the scheduler periodically perform_heartbeat( job=self.job, heartbeat_callback=self.heartbeat_callback, only_if_necessary=True @@ -2434,7 +2539,7 @@ def _handle_tasks_stuck_in_queued(self, session: Session = NEW_SESSION) -> None: scheduled) up to 2 times before failing the task. """ tasks_stuck_in_queued = self._get_tis_stuck_in_queued(session) - for executor, stuck_tis in self._executor_to_tis(tasks_stuck_in_queued, session).items(): + for executor, stuck_tis in self._executor_to_workloads(tasks_stuck_in_queued, session).items(): try: for ti in stuck_tis: executor.revoke_task(ti=ti) @@ -2725,7 +2830,7 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) -> int: ) to_reset: list[TaskInstance] = [] - exec_to_tis = self._executor_to_tis(tis_to_adopt_or_reset, session) + exec_to_tis = self._executor_to_workloads(tis_to_adopt_or_reset, session) for executor, tis in exec_to_tis.items(): to_reset.extend(executor.try_adopt_task_instances(tis)) @@ -3074,50 +3179,57 @@ def _activate_assets_generate_warnings() -> Iterator[tuple[str, str]]: session.add(warning) existing_warned_dag_ids.add(warning.dag_id) - def _executor_to_tis( + def _executor_to_workloads( self, - tis: Iterable[TaskInstance], + workloads: Iterable[SchedulerWorkload], session, dag_id_to_team_name: dict[str, str | None] | None = None, - ) -> dict[BaseExecutor, list[TaskInstance]]: - """Organize TIs into lists per their respective executor.""" - tis_iter: Iterable[TaskInstance] + ) -> dict[BaseExecutor, list[SchedulerWorkload]]: + """Organize workloads into lists per their respective executor.""" + workloads_iter: Iterable[SchedulerWorkload] if conf.getboolean("core", "multi_team"): if dag_id_to_team_name is None: - if isinstance(tis, list): - tis_list = tis + if isinstance(workloads, list): + workloads_list = workloads else: - tis_list = list(tis) - if tis_list: + workloads_list = list(workloads) + if workloads_list: dag_id_to_team_name = self._get_team_names_for_dag_ids( - {ti.dag_id for ti in tis_list}, session + { + dag_id + for workload in workloads_list + if (dag_id := workload.get_dag_id()) is not None + }, + session, ) else: dag_id_to_team_name = {} - tis_iter = tis_list + workloads_iter = workloads_list else: - tis_iter = tis + workloads_iter = workloads else: dag_id_to_team_name = {} - tis_iter = tis + workloads_iter = workloads - _executor_to_tis: defaultdict[BaseExecutor, list[TaskInstance]] = defaultdict(list) - for ti in tis_iter: - if executor_obj := self._try_to_load_executor( - ti, session, team_name=dag_id_to_team_name.get(ti.dag_id, NOTSET) - ): - _executor_to_tis[executor_obj].append(ti) + _executor_to_workloads: defaultdict[BaseExecutor, list[SchedulerWorkload]] = defaultdict(list) + for workload in workloads_iter: + _dag_id = workload.get_dag_id() + _team = dag_id_to_team_name.get(_dag_id, NOTSET) if _dag_id else NOTSET + if executor_obj := self._try_to_load_executor(workload, session, team_name=_team): + _executor_to_workloads[executor_obj].append(workload) - return _executor_to_tis + return _executor_to_workloads - def _try_to_load_executor(self, ti: TaskInstance, session, team_name=NOTSET) -> BaseExecutor | None: + def _try_to_load_executor( + self, workload: SchedulerWorkload, session, team_name=NOTSET + ) -> BaseExecutor | None: """ Try to load the given executor. In this context, we don't want to fail if the executor does not exist. Catch the exception and log to the user. - :param ti: TaskInstance to load executor for + :param workload: SchedulerWorkload (TaskInstance or ExecutorCallback) to load executor for :param session: Database session for queries :param team_name: Optional pre-resolved team name. If NOTSET and multi-team is enabled, will query the database to resolve team name. None indicates global team. @@ -3126,17 +3238,16 @@ def _try_to_load_executor(self, ti: TaskInstance, session, team_name=NOTSET) -> if conf.getboolean("core", "multi_team"): # Use provided team_name if available, otherwise query the database if team_name is NOTSET: - team_name = self._get_task_team_name(ti, session) + team_name = self._get_workload_team_name(workload, session) else: team_name = None - # Firstly, check if there is no executor set on the TaskInstance, if not, we need to fetch the default - # (either globally or for the team) - if ti.executor is None: + # If there is no executor set on the workload fetch the default (either globally or for the team) + if workload.get_executor_name() is None: if not team_name: - # No team is specified, so just use the global default executor + # No team is specified, use the global default executor executor = self.executor else: - # We do have a team, so we need to find the default executor for that team + # We do have a team, use the default executor for that team for _executor in self.executors: # First executor that resolves should be the default for that team if _executor.team_name == team_name: @@ -3146,22 +3257,30 @@ def _try_to_load_executor(self, ti: TaskInstance, session, team_name=NOTSET) -> # No executor found for that team, fall back to global default executor = self.executor else: - # An executor is specified on the TaskInstance (as a str), so we need to find it in the list of executors + # An executor is specified on the workload (as a str), so we need to find it in the list of executors for _executor in self.executors: - if _executor.name and ti.executor in (_executor.name.alias, _executor.name.module_path): + if _executor.name and workload.get_executor_name() in ( + _executor.name.alias, + _executor.name.module_path, + ): # The executor must either match the team or be global (i.e. team_name is None) if team_name and _executor.team_name == team_name or _executor.team_name is None: executor = _executor if executor is not None: - self.log.debug("Found executor %s for task %s (team: %s)", executor.name, ti, team_name) + self.log.debug( + "Found executor %s for task or callback %s (team: %s)", executor.name, workload, team_name + ) else: # This case should not happen unless some (as of now unknown) edge case occurs or direct DB # modification, since the DAG parser will validate the tasks in the DAG and ensure the executor # they request is available and if not, disallow the DAG to be scheduled. # Keeping this exception handling because this is a critical issue if we do somehow find # ourselves here and the user should get some feedback about that. - self.log.warning("Executor, %s, was not found but a Task was configured to use it", ti.executor) + self.log.warning( + "Executor, %s, was not found but a Task or Callback was configured to use it", + workload.get_executor_name(), + ) return executor diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index be8213c423fd6..5567e4763ea1d 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -45,6 +45,7 @@ from airflow._shared.timezones import timezone from airflow.configuration import conf from airflow.executors import workloads +from airflow.executors.workloads.task import TaskInstanceDTO from airflow.jobs.base_job_runner import BaseJobRunner from airflow.jobs.job import perform_heartbeat from airflow.models.trigger import Trigger @@ -687,9 +688,7 @@ def update_triggers(self, requested_trigger_ids: set[int]): ti_id=new_trigger_orm.task_instance.id, ) continue - ser_ti = workloads.TaskInstance.model_validate( - new_trigger_orm.task_instance, from_attributes=True - ) + ser_ti = TaskInstanceDTO.model_validate(new_trigger_orm.task_instance, from_attributes=True) # When producing logs from TIs, include the job id producing the logs to disambiguate it. self.logger_cache[new_id] = TriggerLoggingFactory( log_path=f"{log_path}.trigger.{self.job.id}.log", diff --git a/airflow-core/src/airflow/models/callback.py b/airflow-core/src/airflow/models/callback.py index ea45a10f1f1ae..ea482ab7ba8d5 100644 --- a/airflow-core/src/airflow/models/callback.py +++ b/airflow-core/src/airflow/models/callback.py @@ -29,8 +29,13 @@ from airflow._shared.observability.metrics.stats import Stats from airflow._shared.timezones import timezone +from airflow.executors.workloads import BaseWorkload +from airflow.executors.workloads.callback import CallbackFetchMethod from airflow.models import Base from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime +from airflow.utils.state import CallbackState + +CallbackKey = str # Callback keys are str(UUID) if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -41,20 +46,7 @@ log = structlog.get_logger(__name__) -class CallbackState(str, Enum): - """All possible states of callbacks.""" - - PENDING = "pending" - QUEUED = "queued" - RUNNING = "running" - SUCCESS = "success" - FAILED = "failed" - - def __str__(self) -> str: - return self.value - - -ACTIVE_STATES = frozenset((CallbackState.QUEUED, CallbackState.RUNNING)) +ACTIVE_STATES = frozenset((CallbackState.PENDING, CallbackState.QUEUED, CallbackState.RUNNING)) TERMINAL_STATES = frozenset((CallbackState.SUCCESS, CallbackState.FAILED)) @@ -70,16 +62,6 @@ class CallbackType(str, Enum): DAG_PROCESSOR = "dag_processor" -class CallbackFetchMethod(str, Enum): - """Methods used to fetch callback at runtime.""" - - # For future use once Dag Processor callbacks (on_success_callback/on_failure_callback) get moved to executors - DAG_ATTRIBUTE = "dag_attribute" - - # For deadline callbacks since they import callbacks through the import path - IMPORT_PATH = "import_path" - - class CallbackDefinitionProtocol(Protocol): """Protocol for TaskSDK Callback definition.""" @@ -103,7 +85,7 @@ class ImportPathExecutorCallbackDefProtocol(ImportPathCallbackDefProtocol, Proto executor: str | None -class Callback(Base): +class Callback(Base, BaseWorkload): """Base class for callbacks.""" __tablename__ = "callback" @@ -147,7 +129,7 @@ def __init__(self, priority_weight: int = 1, prefix: str = "", **kwargs): :param prefix: Optional prefix for metric names :param kwargs: Additional data emitted in metric tags """ - self.state = CallbackState.PENDING + self.state = CallbackState.SCHEDULED self.priority_weight = priority_weight self.data = kwargs # kwargs can be used to include additional info in metric tags if prefix: @@ -169,6 +151,14 @@ def get_metric_info(self, status: CallbackState, result: Any) -> dict: return {"stat": name, "tags": tags} + def get_dag_id(self) -> str | None: + """Return the DAG ID for scheduler routing.""" + return self.data.get("dag_id") + + def get_executor_name(self) -> str | None: + """Return the executor name for scheduler routing.""" + return self.data.get("executor") + @staticmethod def create_from_sdk_def(callback_def: CallbackDefinitionProtocol, **kwargs) -> Callback: # Cannot check actual type using isinstance() because that would require SDK import diff --git a/airflow-core/src/airflow/models/deadline.py b/airflow-core/src/airflow/models/deadline.py index bbf24cc2842d2..debfe949b314c 100644 --- a/airflow-core/src/airflow/models/deadline.py +++ b/airflow-core/src/airflow/models/deadline.py @@ -32,10 +32,15 @@ from airflow._shared.observability.metrics.stats import Stats from airflow._shared.timezones import timezone from airflow.models.base import Base -from airflow.models.callback import Callback, CallbackDefinitionProtocol +from airflow.models.callback import ( + Callback, + ExecutorCallback, + TriggererCallback, +) from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import provide_session from airflow.utils.sqlalchemy import UtcDateTime, get_dialect_name +from airflow.utils.state import CallbackState if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -224,9 +229,36 @@ def get_simple_context(): "deadline": {"id": self.id, "deadline_time": self.deadline_time}, } - self.callback.data["kwargs"] = self.callback.data["kwargs"] | {"context": get_simple_context()} + if isinstance(self.callback, TriggererCallback): + # Update the callback with context before queuing + if "kwargs" not in self.callback.data: + self.callback.data["kwargs"] = {} + self.callback.data["kwargs"] = (self.callback.data.get("kwargs") or {}) | { + "context": get_simple_context() + } + + self.callback.queue() + session.add(self.callback) + session.flush() + + elif isinstance(self.callback, ExecutorCallback): + if "kwargs" not in self.callback.data: + self.callback.data["kwargs"] = {} + self.callback.data["kwargs"] = (self.callback.data.get("kwargs") or {}) | { + "context": get_simple_context() + } + self.callback.data["deadline_id"] = str(self.id) + self.callback.data["dag_run_id"] = str(self.dagrun.id) + self.callback.data["dag_id"] = self.dagrun.dag_id + + self.callback.state = CallbackState.PENDING + session.add(self.callback) + session.flush() + + else: + raise TypeError(f"Unknown Callback type: {type(self.callback).__name__}") + self.missed = True - self.callback.queue() session.add(self) Stats.incr( "deadline_alerts.deadline_missed", diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 475cbd7ae68fb..8b12db483affa 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -70,6 +70,7 @@ from airflow._shared.timezones import timezone from airflow.assets.manager import asset_manager from airflow.configuration import conf +from airflow.executors.workloads import BaseWorkload from airflow.listeners.listener import get_listener_manager from airflow.models.asset import AssetModel from airflow.models.base import Base, StringID, TaskInstanceDependencies @@ -406,7 +407,7 @@ def uuid7() -> UUID: return uuid6.uuid7() -class TaskInstance(Base, LoggingMixin): +class TaskInstance(Base, LoggingMixin, BaseWorkload): """ Task instances store the state of a task instance. @@ -802,6 +803,14 @@ def key(self) -> TaskInstanceKey: """Returns a tuple that identifies the task instance uniquely.""" return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index) + def get_dag_id(self) -> str: + """Return the DAG ID for scheduler routing.""" + return self.dag_id + + def get_executor_name(self) -> str | None: + """Return the executor name for scheduler routing.""" + return self.executor + @provide_session def set_state(self, state: str | None, session: Session = NEW_SESSION) -> bool: """ diff --git a/airflow-core/src/airflow/utils/state.py b/airflow-core/src/airflow/utils/state.py index b392a02352574..332efb105533d 100644 --- a/airflow-core/src/airflow/utils/state.py +++ b/airflow-core/src/airflow/utils/state.py @@ -20,6 +20,20 @@ from enum import Enum +class CallbackState(str, Enum): + """All possible states of callbacks.""" + + SCHEDULED = "scheduled" + PENDING = "pending" + QUEUED = "queued" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + + def __str__(self) -> str: + return self.value + + class TerminalTIState(str, Enum): """States that a Task Instance can be in that indicate it has reached a terminal state.""" diff --git a/airflow-core/tests/unit/executors/test_base_executor.py b/airflow-core/tests/unit/executors/test_base_executor.py index 5c2a3d6d549df..fa0f311d018fe 100644 --- a/airflow-core/tests/unit/executors/test_base_executor.py +++ b/airflow-core/tests/unit/executors/test_base_executor.py @@ -25,6 +25,7 @@ import pendulum import pytest +import structlog import time_machine from airflow._shared.timezones import timezone @@ -34,6 +35,9 @@ from airflow.executors import workloads from airflow.executors.base_executor import BaseExecutor, RunningRetryAttemptType from airflow.executors.local_executor import LocalExecutor +from airflow.executors.workloads.base import BundleInfo +from airflow.executors.workloads.callback import CallbackDTO, execute_callback_workload +from airflow.models.callback import CallbackFetchMethod from airflow.models.taskinstance import TaskInstance, TaskInstanceKey from airflow.sdk import BaseOperator from airflow.serialization.definitions.baseoperator import SerializedBaseOperator @@ -573,3 +577,148 @@ def test_executor_conf_get_mandatory_value(self): team_executor_conf = ExecutorConf(team_name="test_team") assert team_executor_conf.get_mandatory_value("celery", "broker_url") == "redis://team-redis" + + +class TestCallbackSupport: + def test_supports_callbacks_flag_default_false(self): + executor = BaseExecutor() + assert executor.supports_callbacks is False + + def test_local_executor_supports_callbacks_true(self): + """Test that LocalExecutor sets supports_callbacks to True.""" + executor = LocalExecutor() + assert executor.supports_callbacks is True + + @pytest.mark.db_test + def test_queue_callback_without_support_raises_error(self, dag_maker, session): + executor = BaseExecutor() # supports_callbacks = False by default + callback_data = CallbackDTO( + id="12345678-1234-5678-1234-567812345678", + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={"path": "test.func", "kwargs": {}}, + ) + callback_workload = workloads.ExecuteCallback( + callback=callback_data, + dag_rel_path="test.py", + bundle_info=BundleInfo(name="test_bundle", version="1.0"), + token="test_token", + log_path="test.log", + ) + + with pytest.raises(NotImplementedError, match="does not support ExecuteCallback"): + executor.queue_workload(callback_workload, session) + + @pytest.mark.db_test + def test_queue_workload_with_execute_callback(self, dag_maker, session): + executor = BaseExecutor() + executor.supports_callbacks = True # Enable for this test + callback_data = CallbackDTO( + id="12345678-1234-5678-1234-567812345678", + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={"path": "test.func", "kwargs": {}}, + ) + callback_workload = workloads.ExecuteCallback( + callback=callback_data, + dag_rel_path="test.py", + bundle_info=BundleInfo(name="test_bundle", version="1.0"), + token="test_token", + log_path="test.log", + ) + + executor.queue_workload(callback_workload, session) + + assert len(executor.queued_callbacks) == 1 + assert callback_data.id in executor.queued_callbacks + + @pytest.mark.db_test + def test_get_workloads_prioritizes_callbacks(self, dag_maker, session): + executor = BaseExecutor() + executor.supports_callbacks = True # Enable for this test + dagrun = setup_dagrun(dag_maker) + callback_data = CallbackDTO( + id="12345678-1234-5678-1234-567812345678", + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={"path": "test.func", "kwargs": {}}, + ) + callback_workload = workloads.ExecuteCallback( + callback=callback_data, + dag_rel_path="test.py", + bundle_info=BundleInfo(name="test_bundle", version="1.0"), + token="test_token", + log_path="test.log", + ) + executor.queue_workload(callback_workload, session) + + for ti in dagrun.task_instances: + task_workload = workloads.ExecuteTask.make(ti) + executor.queue_workload(task_workload, session) + + workloads_to_schedule = executor._get_workloads_to_schedule(open_slots=10) + + assert len(workloads_to_schedule) == 4 # 1 callback + 3 tasks + _, first_workload = workloads_to_schedule[0] + assert isinstance(first_workload, workloads.ExecuteCallback) # Assert callback comes first + + +class TestExecuteCallbackWorkload: + def test_execute_function_callback_success(self): + callback_data = CallbackDTO( + id="12345678-1234-5678-1234-567812345678", + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={ + "path": "builtins.dict", + "kwargs": {"a": 1, "b": 2, "c": 3}, + }, + ) + log = structlog.get_logger() + + success, error = execute_callback_workload(callback_data, log) + + assert success is True + assert error is None + + def test_execute_callback_missing_path(self): + callback_data = CallbackDTO( + id="12345678-1234-5678-1234-567812345678", + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={"kwargs": {}}, # Missing 'path' + ) + log = structlog.get_logger() + + success, error = execute_callback_workload(callback_data, log) + + assert success is False + assert "Callback path not found" in error + + def test_execute_callback_import_error(self): + callback_data = CallbackDTO( + id="12345678-1234-5678-1234-567812345678", + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={ + "path": "nonexistent.module.function", + "kwargs": {}, + }, + ) + log = structlog.get_logger() + + success, error = execute_callback_workload(callback_data, log) + + assert success is False + assert "ModuleNotFoundError" in error + + def test_execute_callback_execution_error(self): + # Use a function that will raise an error; len() requires an argument + callback_data = CallbackDTO( + id="12345678-1234-5678-1234-567812345678", + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={ + "path": "builtins.len", + "kwargs": {}, + }, + ) + log = structlog.get_logger() + + success, error = execute_callback_workload(callback_data, log) + + assert success is False + assert "TypeError" in error diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index 5f216cca2e767..34e8f818aa94c 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -29,6 +29,10 @@ from airflow._shared.timezones import timezone from airflow.executors import workloads from airflow.executors.local_executor import LocalExecutor, _execute_work +from airflow.executors.workloads.base import BundleInfo +from airflow.executors.workloads.callback import CallbackDTO +from airflow.executors.workloads.task import TaskInstanceDTO +from airflow.models.callback import CallbackFetchMethod from airflow.settings import Session from airflow.utils.state import State @@ -81,7 +85,7 @@ def test_executor_worker_spawned(self, mock_freeze, mock_unfreeze): @mock.patch("airflow.sdk.execution_time.supervisor.supervise") def test_execution(self, mock_supervise): success_tis = [ - workloads.TaskInstance( + TaskInstanceDTO( id=uuid7(), dag_version_id=uuid7(), task_id=f"success_{i}", @@ -327,3 +331,40 @@ def test_global_executor_without_team_name(self): assert len(executor.workers) == 2 executor.end() + + +class TestLocalExecutorCallbackSupport: + def test_supports_callbacks_flag_is_true(self): + executor = LocalExecutor() + assert executor.supports_callbacks is True + + @skip_spawn_mp_start + @mock.patch("airflow.executors.workloads.callback.execute_callback_workload") + def test_process_callback_workload(self, mock_execute_callback): + mock_execute_callback.return_value = (True, None) + + executor = LocalExecutor(parallelism=1) + callback_data = CallbackDTO( + id="12345678-1234-5678-1234-567812345678", + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={"path": "test.func", "kwargs": {}}, + ) + callback_workload = workloads.ExecuteCallback( + callback=callback_data, + dag_rel_path="test.py", + bundle_info=BundleInfo(name="test_bundle", version="1.0"), + token="test_token", + log_path="test.log", + ) + + executor.start() + + try: + executor.queued_callbacks[callback_data.id] = callback_workload + executor._process_workloads([callback_workload]) + assert len(executor.queued_callbacks) == 0 + # We can't easily verify worker execution without running the worker, + # but we can verify the helper is called via mock + + finally: + executor.end() diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index ea9c51220ca6b..abbd5f20067b9 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -65,6 +65,7 @@ PartitionedAssetKeyLog, ) from airflow.models.backfill import Backfill, _create_backfill +from airflow.models.callback import ExecutorCallback from airflow.models.dag import DagModel, get_last_dagrun, infer_automated_data_interval from airflow.models.dag_version import DagVersion from airflow.models.dagbundle import DagBundleModel @@ -92,7 +93,7 @@ from airflow.timetables.base import DagRunInfo, DataInterval from airflow.utils.session import create_session, provide_session from airflow.utils.span_status import SpanStatus -from airflow.utils.state import DagRunState, State, TaskInstanceState +from airflow.utils.state import CallbackState, DagRunState, State, TaskInstanceState from airflow.utils.thread_safe_dict import ThreadSafeDict from airflow.utils.types import DagRunTriggeredByType, DagRunType @@ -556,6 +557,48 @@ def test_process_executor_events_with_no_callback(self, mock_stats_incr, mock_ta any_order=True, ) + def test_enqueue_executor_callbacks_only_selects_pending_state(self, dag_maker, session): + def test_callback(): + pass + + def create_callback_in_state(state: CallbackState): + callback = Deadline( + deadline_time=timezone.utcnow(), + callback=SyncCallback(test_callback), + dagrun_id=dag_run.id, + deadline_alert_id=None, + ).callback + callback.state = state + callback.data["dag_run_id"] = dag_run.id + callback.data["dag_id"] = dag_run.dag_id + return callback + + with dag_maker(dag_id="test_callback_states"): + pass + dag_run = dag_maker.create_dagrun() + + scheduled_callback = create_callback_in_state(CallbackState.SCHEDULED) + pending_callback = create_callback_in_state(CallbackState.PENDING) + queued_callback = create_callback_in_state(CallbackState.QUEUED) + running_callback = create_callback_in_state(CallbackState.RUNNING) + session.add_all([scheduled_callback, pending_callback, queued_callback, running_callback]) + session.flush() + + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job) + + # Verify initial state before calling _enqueue_executor_callbacks + assert session.get(ExecutorCallback, pending_callback.id).state == CallbackState.PENDING + + self.job_runner._enqueue_executor_callbacks(session) + # PENDING should progress to QUEUED after _enqueue_executor_callbacks + assert session.get(ExecutorCallback, pending_callback.id).state == CallbackState.QUEUED + + # Other callbacks should remain in their original states + assert session.get(ExecutorCallback, scheduled_callback.id).state == CallbackState.SCHEDULED + assert session.get(ExecutorCallback, queued_callback.id).state == CallbackState.QUEUED + assert session.get(ExecutorCallback, running_callback.id).state == CallbackState.RUNNING + @mock.patch("airflow.jobs.scheduler_job_runner.TaskCallbackRequest") @mock.patch("airflow.jobs.scheduler_job_runner.Stats.incr") def test_process_executor_events_with_callback( @@ -1311,7 +1354,7 @@ def test_find_executable_task_instances_executor_with_teams(self, dag_maker, moc assert len(res) == 5 # Verify that each task is routed to the correct executor - executor_to_tis = self.job_runner._executor_to_tis(res, session) + executor_to_tis = self.job_runner._executor_to_workloads(res, session) # Team pi tasks should go to mock_executors[0] (configured for team_pi) a_tis_in_executor = [ti for ti in executor_to_tis.get(mock_executors[0], []) if ti.dag_id == "dag_a"] @@ -7909,7 +7952,7 @@ def test_multi_team_get_team_names_for_dag_ids_database_error(self, mock_log, da assert result == {} mock_log.exception.assert_called_once() - def test_multi_team_get_task_team_name_success(self, dag_maker, session): + def test_multi_team_get_workload_team_name_success(self, dag_maker, session): """Test successful team name resolution for a single task.""" clear_db_teams() clear_db_dag_bundles() @@ -7932,10 +7975,10 @@ def test_multi_team_get_task_team_name_success(self, dag_maker, session): scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job) - result = self.job_runner._get_task_team_name(ti, session) + result = self.job_runner._get_workload_team_name(ti, session) assert result == "team_a" - def test_multi_team_get_task_team_name_no_team(self, dag_maker, session): + def test_multi_team_get_workload_team_name_no_team(self, dag_maker, session): """Test team resolution when no team is associated with the DAG.""" with dag_maker(dag_id="dag_no_team", session=session): task = EmptyOperator(task_id="task_no_team") @@ -7946,10 +7989,10 @@ def test_multi_team_get_task_team_name_no_team(self, dag_maker, session): scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job) - result = self.job_runner._get_task_team_name(ti, session) + result = self.job_runner._get_workload_team_name(ti, session) assert result is None - def test_multi_team_get_task_team_name_database_error(self, dag_maker, session): + def test_multi_team_get_workload_team_name_database_error(self, dag_maker, session): """Test graceful error handling when individual task team resolution fails. This code should _not_ fail the scheduler.""" with dag_maker(dag_id="dag_test", session=session): task = EmptyOperator(task_id="task_test") @@ -7962,7 +8005,7 @@ def test_multi_team_get_task_team_name_database_error(self, dag_maker, session): # Mock _get_team_names_for_dag_ids to return empty dict (simulates database error handling in that function) with mock.patch.object(self.job_runner, "_get_team_names_for_dag_ids", return_value={}) as mock_batch: - result = self.job_runner._get_task_team_name(ti, session) + result = self.job_runner._get_workload_team_name(ti, session) mock_batch.assert_called_once_with([ti.dag_id], session) # Should return None when batch function returns empty dict @@ -7980,7 +8023,7 @@ def test_multi_team_try_to_load_executor_multi_team_disabled(self, dag_maker, mo scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job) - with mock.patch.object(self.job_runner, "_get_task_team_name") as mock_team_resolve: + with mock.patch.object(self.job_runner, "_get_workload_team_name") as mock_team_resolve: result = self.job_runner._try_to_load_executor(ti, session) # Should not call team resolution when multi_team is disabled mock_team_resolve.assert_not_called() @@ -8177,7 +8220,8 @@ def test_multi_team_try_to_load_executor_explicit_executor_team_mismatch( # Should log a warning when no executor is found mock_log.warning.assert_called_once_with( - "Executor, %s, was not found but a Task was configured to use it", "secondary_exec" + "Executor, %s, was not found but a Task or Callback was configured to use it", + "secondary_exec", ) # Should return None since we failed to resolve an executor due to the mismatch. In practice, this @@ -8229,7 +8273,7 @@ def test_multi_team_try_to_load_executor_team_name_pre_resolved(self, dag_maker, self.job_runner = SchedulerJobRunner(job=scheduler_job) # Call with pre-resolved team name (as done in the scheduling loop) - with mock.patch.object(self.job_runner, "_get_task_team_name") as mock_team_resolve: + with mock.patch.object(self.job_runner, "_get_workload_team_name") as mock_team_resolve: result = self.job_runner._try_to_load_executor(ti, session, team_name="team_a") mock_team_resolve.assert_not_called() # We don't query for the team if it is pre-resolved @@ -8320,13 +8364,13 @@ def test_multi_team_executor_to_tis_batch_optimization(self, dag_maker, mock_exe with ( assert_queries_count(1, session=session), - mock.patch.object(self.job_runner, "_get_task_team_name") as mock_single, + mock.patch.object(self.job_runner, "_get_workload_team_name") as mock_single, ): - executor_to_tis = self.job_runner._executor_to_tis([ti1, ti2], session) + executor_to_workloads = self.job_runner._executor_to_workloads([ti1, ti2], session) mock_single.assert_not_called() - assert executor_to_tis[mock_executors[0]] == [ti1] - assert executor_to_tis[mock_executors[1]] == [ti2] + assert executor_to_workloads[mock_executors[0]] == [ti1] + assert executor_to_workloads[mock_executors[1]] == [ti2] @conf_vars({("core", "multi_team"): "false"}) def test_multi_team_config_disabled_uses_legacy_behavior(self, dag_maker, mock_executors, session): @@ -8342,7 +8386,7 @@ def test_multi_team_config_disabled_uses_legacy_behavior(self, dag_maker, mock_e scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job) - with mock.patch.object(self.job_runner, "_get_task_team_name") as mock_team_resolve: + with mock.patch.object(self.job_runner, "_get_workload_team_name") as mock_team_resolve: result1 = self.job_runner._try_to_load_executor(ti1, session) result2 = self.job_runner._try_to_load_executor(ti2, session) diff --git a/airflow-core/tests/unit/models/test_callback.py b/airflow-core/tests/unit/models/test_callback.py index dfc19fc61a354..6ab6ad2d02df7 100644 --- a/airflow-core/tests/unit/models/test_callback.py +++ b/airflow-core/tests/unit/models/test_callback.py @@ -123,7 +123,7 @@ def test_polymorphic_serde(self, session): assert isinstance(retrieved, TriggererCallback) assert retrieved.fetch_method == CallbackFetchMethod.IMPORT_PATH assert retrieved.data == TEST_ASYNC_CALLBACK.serialize() - assert retrieved.state == CallbackState.PENDING.value + assert retrieved.state == CallbackState.SCHEDULED.value assert retrieved.output is None assert retrieved.priority_weight == 1 assert retrieved.created_at is not None @@ -131,7 +131,7 @@ def test_polymorphic_serde(self, session): def test_queue(self, session): callback = TriggererCallback(TEST_ASYNC_CALLBACK) - assert callback.state == CallbackState.PENDING + assert callback.state == CallbackState.SCHEDULED assert callback.trigger is None callback.queue() @@ -193,7 +193,7 @@ def test_polymorphic_serde(self, session): assert isinstance(retrieved, ExecutorCallback) assert retrieved.fetch_method == CallbackFetchMethod.IMPORT_PATH assert retrieved.data == TEST_SYNC_CALLBACK.serialize() - assert retrieved.state == CallbackState.PENDING.value + assert retrieved.state == CallbackState.SCHEDULED.value assert retrieved.output is None assert retrieved.priority_weight == 1 assert retrieved.created_at is not None @@ -201,7 +201,7 @@ def test_polymorphic_serde(self, session): def test_queue(self): callback = ExecutorCallback(TEST_SYNC_CALLBACK, fetch_method=CallbackFetchMethod.DAG_ATTRIBUTE) - assert callback.state == CallbackState.PENDING + assert callback.state == CallbackState.SCHEDULED callback.queue() assert callback.state == CallbackState.QUEUED diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index e7bef2859afcc..12f7e3ac427f4 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -2092,6 +2092,7 @@ winrm WIT workgroup workgroups +WorkloadKey workspaces writeable wsman diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py index 7144425b2c3d7..cd142acc7a182 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py @@ -39,7 +39,7 @@ from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.executors.base_executor import BaseExecutor -from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS from airflow.providers.common.compat.sdk import AirflowTaskTimeout, Stats from airflow.utils.state import TaskInstanceState @@ -52,13 +52,11 @@ if TYPE_CHECKING: from collections.abc import Sequence - from sqlalchemy.orm import Session - from airflow.cli.cli_config import GroupCommand from airflow.executors import workloads from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey - from airflow.providers.celery.executors.celery_executor_utils import TaskInstanceInCelery, TaskTuple + from airflow.providers.celery.executors.celery_executor_utils import TaskTuple, WorkloadInCelery # PEP562 @@ -91,16 +89,17 @@ class CeleryExecutor(BaseExecutor): """ supports_ad_hoc_ti_run: bool = True + supports_callbacks: bool = True sentry_integration: str = "sentry_sdk.integrations.celery.CeleryIntegration" # TODO: Remove this flag once providers depend on Airflow 3.2. supports_sentry: bool = True supports_multi_team: bool = True - if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS: - # In the v3 path, we store workloads, not commands as strings. - # TODO: TaskSDK: move this type change into BaseExecutor - queued_tasks: dict[TaskInstanceKey, workloads.All] # type: ignore[assignment] + if TYPE_CHECKING: + if AIRFLOW_V_3_0_PLUS: + # TODO: TaskSDK: move this type change into BaseExecutor + queued_tasks: dict[TaskInstanceKey, workloads.All] # type: ignore[assignment] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -160,18 +159,25 @@ def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: # Airflow V3 version -- have to delay imports until we know we are on v3 from airflow.executors.workloads import ExecuteTask - tasks = [ - (workload.ti.key, workload, workload.ti.queue, self.team_name) - for workload in workloads - if isinstance(workload, ExecuteTask) - ] - if len(tasks) != len(workloads): - invalid = list(workload for workload in workloads if not isinstance(workload, ExecuteTask)) - raise ValueError(f"{type(self)}._process_workloads cannot handle {invalid}") + if AIRFLOW_V_3_2_PLUS: + from airflow.executors.workloads import ExecuteCallback + + tasks: list[WorkloadInCelery] = [] + for workload in workloads: + if isinstance(workload, ExecuteTask): + tasks.append((workload.ti.key, workload, workload.ti.queue, self.team_name)) + elif AIRFLOW_V_3_2_PLUS and isinstance(workload, ExecuteCallback): + # Use default queue for callbacks, or extract from callback data if available + queue = "default" + if isinstance(workload.callback.data, dict) and "queue" in workload.callback.data: + queue = workload.callback.data["queue"] + tasks.append((workload.callback.key, workload, queue, self.team_name)) + else: + raise ValueError(f"{type(self)}._process_workloads cannot handle {type(workload)}") self._send_tasks(tasks) - def _send_tasks(self, task_tuples_to_send: Sequence[TaskInstanceInCelery]): + def _send_tasks(self, task_tuples_to_send: Sequence[WorkloadInCelery]): # Celery state queries will be stuck if we do not use one same backend # for all tasks. cached_celery_backend = self.celery_app.backend @@ -195,7 +201,10 @@ def _send_tasks(self, task_tuples_to_send: Sequence[TaskInstanceInCelery]): ) self.task_publish_retries[key] = retries + 1 continue - self.queued_tasks.pop(key) + if key in self.queued_tasks: + self.queued_tasks.pop(key) + else: + self.queued_callbacks.pop(key, None) self.task_publish_retries.pop(key, None) if isinstance(result, ExceptionWithTraceback): self.log.error("%s: %s\n%s\n", CELERY_SEND_ERR_MSG_HEADER, result.exception, result.traceback) @@ -210,7 +219,7 @@ def _send_tasks(self, task_tuples_to_send: Sequence[TaskInstanceInCelery]): # which point we don't need the ID anymore anyway self.event_buffer[key] = (TaskInstanceState.QUEUED, result.task_id) - def _send_tasks_to_celery(self, task_tuples_to_send: Sequence[TaskInstanceInCelery]): + def _send_tasks_to_celery(self, task_tuples_to_send: Sequence[WorkloadInCelery]): from airflow.providers.celery.executors.celery_executor_utils import send_task_to_executor if len(task_tuples_to_send) == 1 or self._sync_parallelism == 1: @@ -375,11 +384,3 @@ def get_cli_commands() -> list[GroupCommand]: from airflow.providers.celery.cli.definition import get_celery_cli_commands return get_celery_cli_commands() - - def queue_workload(self, workload: workloads.All, session: Session | None) -> None: - from airflow.executors import workloads - - if not isinstance(workload, workloads.ExecuteTask): - raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}") - ti = workload.ti - self.queued_tasks[ti.key] = workload diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index b09737701f3f7..578d0a909acc1 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -53,6 +53,9 @@ except ImportError: from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager +if AIRFLOW_V_3_2_PLUS: + from airflow.executors.workloads.callback import execute_callback_workload + log = logging.getLogger(__name__) if sys.platform == "darwin": @@ -67,16 +70,21 @@ from airflow.executors import workloads from airflow.executors.base_executor import EventBufferValueType, ExecutorConf + from airflow.executors.workloads.types import WorkloadKey from airflow.models.taskinstance import TaskInstanceKey # We can't use `if AIRFLOW_V_3_0_PLUS` conditions in type checks, so unfortunately we just have to define # the type as the union of both kinds CommandType = Sequence[str] - TaskInstanceInCelery: TypeAlias = tuple[ - TaskInstanceKey, workloads.All | CommandType, str | None, str | None + WorkloadInCelery: TypeAlias = tuple[WorkloadKey, workloads.All | CommandType, str | None, str | None] + WorkloadInCeleryResult: TypeAlias = tuple[ + WorkloadKey, CommandType, AsyncResult | "ExceptionWithTraceback" ] + # Deprecated alias for backward compatibility + TaskInstanceInCelery: TypeAlias = WorkloadInCelery + TaskTuple = tuple[TaskInstanceKey, CommandType, str | None, Any | None] OPERATION_TIMEOUT = conf.getfloat("celery", "operation_timeout") @@ -182,9 +190,6 @@ def execute_workload(input: str) -> None: celery_task_id = app.current_task.request.id - if not isinstance(workload, workloads.ExecuteTask): - raise ValueError(f"CeleryExecutor does not know how to handle {type(workload)}") - log.info("[%s] Executing workload in Celery: %s", celery_task_id, workload) base_url = conf.get("api", "base_url", fallback="/") @@ -193,15 +198,22 @@ def execute_workload(input: str) -> None: base_url = f"http://localhost:8080{base_url}" default_execution_api_server = f"{base_url.rstrip('/')}/execution/" - supervise( - # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. - ti=workload.ti, # type: ignore[arg-type] - dag_rel_path=workload.dag_rel_path, - bundle_info=workload.bundle_info, - token=workload.token, - server=conf.get("core", "execution_api_server_url", fallback=default_execution_api_server), - log_path=workload.log_path, - ) + if isinstance(workload, workloads.ExecuteTask): + supervise( + # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. + ti=workload.ti, # type: ignore[arg-type] + dag_rel_path=workload.dag_rel_path, + bundle_info=workload.bundle_info, + token=workload.token, + server=conf.get("core", "execution_api_server_url", fallback=default_execution_api_server), + log_path=workload.log_path, + ) + elif isinstance(workload, workloads.ExecuteCallback): + success, error_msg = execute_callback_workload(workload.callback, log) + if not success: + raise RuntimeError(error_msg or "Callback execution failed") + else: + raise ValueError(f"CeleryExecutor does not know how to handle {type(workload)}") if not AIRFLOW_V_3_0_PLUS: @@ -303,16 +315,16 @@ def __init__(self, exception: BaseException, exception_traceback: str): self.traceback = exception_traceback -def send_task_to_executor( - task_tuple: TaskInstanceInCelery, -) -> tuple[TaskInstanceKey, CommandType, AsyncResult | ExceptionWithTraceback]: +def send_workload_to_executor( + workload_tuple: WorkloadInCelery, +) -> WorkloadInCeleryResult: """ - Send task to executor. + Send workload to executor. This function is called in ProcessPoolExecutor subprocesses. To avoid pickling issues with team-specific Celery apps, we pass the team_name and reconstruct the Celery app here. """ - key, args, queue, team_name = task_tuple + key, args, queue, team_name = workload_tuple # Reconstruct the Celery app from configuration, which may or may not be team-specific. # ExecutorConf wraps config access to automatically use team-specific config where present. @@ -326,7 +338,6 @@ def send_task_to_executor( else: # Airflow <3.2 ExecutorConf doesn't exist (at least not with the required attributes), fall back to global conf _conf = conf - # Create the Celery app with the correct configuration celery_app = create_celery_app(_conf) @@ -362,6 +373,10 @@ def send_task_to_executor( return key, args, result +# Backward compatibility alias +send_task_to_executor = send_workload_to_executor + + def fetch_celery_task_state(async_result: AsyncResult) -> tuple[str, str | ExceptionWithTraceback, Any]: """ Fetch and return the state of the given celery task. diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py index 49ae5b35b6f52..34cfb27e86a81 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py @@ -104,18 +104,18 @@ def _task_event_logs(self, value): def queued_tasks(self) -> dict[TaskInstanceKey, Any]: """Return queued tasks from celery and kubernetes executor.""" queued_tasks = self.celery_executor.queued_tasks.copy() - queued_tasks.update(self.kubernetes_executor.queued_tasks) + queued_tasks.update(self.kubernetes_executor.queued_tasks) # type: ignore[arg-type] - return queued_tasks + return queued_tasks # type: ignore[return-value] @queued_tasks.setter def queued_tasks(self, value) -> None: """Not implemented for hybrid executors.""" - @property + @property # type: ignore[override] def running(self) -> set[TaskInstanceKey]: """Return running tasks from celery and kubernetes executor.""" - return self.celery_executor.running.union(self.kubernetes_executor.running) + return self.celery_executor.running.union(self.kubernetes_executor.running) # type: ignore[return-value, arg-type] @running.setter def running(self, value) -> None: @@ -225,7 +225,7 @@ def heartbeat(self) -> None: self.celery_executor.heartbeat() self.kubernetes_executor.heartbeat() - def get_event_buffer( + def get_event_buffer( # type: ignore[override] self, dag_ids: list[str] | None = None ) -> dict[TaskInstanceKey, EventBufferValueType]: """ @@ -237,7 +237,7 @@ def get_event_buffer( cleared_events_from_celery = self.celery_executor.get_event_buffer(dag_ids) cleared_events_from_kubernetes = self.kubernetes_executor.get_event_buffer(dag_ids) - return {**cleared_events_from_celery, **cleared_events_from_kubernetes} + return {**cleared_events_from_celery, **cleared_events_from_kubernetes} # type: ignore[dict-item] def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]: """ diff --git a/providers/celery/tests/integration/celery/test_celery_executor.py b/providers/celery/tests/integration/celery/test_celery_executor.py index 3b055dd0b4d66..da8ee15571b10 100644 --- a/providers/celery/tests/integration/celery/test_celery_executor.py +++ b/providers/celery/tests/integration/celery/test_celery_executor.py @@ -42,6 +42,7 @@ from airflow._shared.timezones import timezone from airflow.configuration import conf from airflow.executors import workloads +from airflow.executors.workloads.task import TaskInstanceDTO from airflow.models.dag import DAG from airflow.models.taskinstance import TaskInstance from airflow.providers.common.compat.sdk import AirflowException, AirflowTaskTimeout, TaskInstanceKey @@ -197,7 +198,7 @@ def fake_execute(input: str) -> None: # Use same parameter name as Airflow 3 ve executor.start() with start_worker(app=app, logfile=sys.stdout, loglevel="info"): - ti_success = workloads.TaskInstance.model_construct( + ti_success = TaskInstanceDTO.model_construct( id=uuid7(), task_id="success", dag_id="id", @@ -257,7 +258,7 @@ def test_error_sending_task(self): else: ti = TaskInstance(task=task, run_id="abc") workload = workloads.ExecuteTask.model_construct( - ti=workloads.TaskInstance.model_validate(ti, from_attributes=True), + ti=TaskInstanceDTO.model_validate(ti, from_attributes=True), ) key = (task.dag.dag_id, task.task_id, ti.run_id, 0, -1) @@ -309,7 +310,7 @@ def test_retry_on_error_sending_task(self, caplog): else: ti = TaskInstance(task=task, run_id="abc") workload = workloads.ExecuteTask.model_construct( - ti=workloads.TaskInstance.model_validate(ti, from_attributes=True), + ti=TaskInstanceDTO.model_validate(ti, from_attributes=True), ) key = (task.dag.dag_id, task.task_id, ti.run_id, 0, -1) diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py index 114da7ec36fe8..f2eb64e46c588 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py @@ -108,10 +108,10 @@ def queued_tasks(self) -> dict[TaskInstanceKey, Any]: def queued_tasks(self, value) -> None: """Not implemented for hybrid executors.""" - @property + @property # type: ignore[override] def running(self) -> set[TaskInstanceKey]: """Return running tasks from local and kubernetes executor.""" - return self.local_executor.running.union(self.kubernetes_executor.running) + return self.local_executor.running.union(self.kubernetes_executor.running) # type: ignore[return-value, arg-type] @running.setter def running(self, value) -> None: @@ -219,7 +219,7 @@ def heartbeat(self) -> None: self.local_executor.heartbeat() self.kubernetes_executor.heartbeat() - def get_event_buffer( + def get_event_buffer( # type: ignore[override] self, dag_ids: list[str] | None = None ) -> dict[TaskInstanceKey, EventBufferValueType]: """ @@ -231,7 +231,7 @@ def get_event_buffer( cleared_events_from_local = self.local_executor.get_event_buffer(dag_ids) cleared_events_from_kubernetes = self.kubernetes_executor.get_event_buffer(dag_ids) - return {**cleared_events_from_local, **cleared_events_from_kubernetes} + return {**cleared_events_from_local, **cleared_events_from_kubernetes} # type: ignore[dict-item] def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]: """ diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml index 293d8700ec725..94feded9ff924 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml @@ -1005,7 +1005,7 @@ components: - type: 'null' title: Log Path ti: - $ref: '#/components/schemas/TaskInstance' + $ref: '#/components/schemas/TaskInstanceDTO' sentry_integration: type: string title: Sentry Integration @@ -1151,7 +1151,7 @@ components: - log_chunk_data title: PushLogsBody description: Incremental new log content from worker. - TaskInstance: + TaskInstanceDTO: properties: id: type: string @@ -1209,7 +1209,7 @@ components: - pool_slots - queue - priority_weight - title: TaskInstance + title: TaskInstanceDTO description: Schema for TaskInstance with minimal required fields needed for Executors and Task SDK. TaskInstanceState: diff --git a/task-sdk/src/airflow/sdk/definitions/deadline.py b/task-sdk/src/airflow/sdk/definitions/deadline.py index aeb1ff89010a2..8c55e10d45c86 100644 --- a/task-sdk/src/airflow/sdk/definitions/deadline.py +++ b/task-sdk/src/airflow/sdk/definitions/deadline.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING from airflow.models.deadline import DeadlineReferenceType, ReferenceModels -from airflow.sdk.definitions.callback import AsyncCallback, Callback +from airflow.sdk.definitions.callback import AsyncCallback, Callback, SyncCallback if TYPE_CHECKING: from collections.abc import Callable @@ -44,7 +44,7 @@ def __init__( self.reference = reference self.interval = interval - if not isinstance(callback, AsyncCallback): + if not isinstance(callback, (AsyncCallback, SyncCallback)): raise ValueError(f"Callbacks of type {type(callback).__name__} are not currently supported") self.callback = callback diff --git a/task-sdk/tests/task_sdk/definitions/test_deadline.py b/task-sdk/tests/task_sdk/definitions/test_deadline.py index 1025cfc27a3ac..8e9e816b30705 100644 --- a/task-sdk/tests/task_sdk/definitions/test_deadline.py +++ b/task-sdk/tests/task_sdk/definitions/test_deadline.py @@ -138,10 +138,27 @@ def test_deadline_alert_in_set(self): alert_set = {alert1, alert2} assert len(alert_set) == 1 - def test_deadline_alert_unsupported_callback(self): - with pytest.raises(ValueError, match="Callbacks of type SyncCallback are not currently supported"): + @pytest.mark.parametrize( + ("callback_class"), + [ + pytest.param(AsyncCallback, id="async_callback"), + pytest.param(SyncCallback, id="sync_callback"), + ], + ) + def test_deadline_alert_accepts_all_callbacks(self, callback_class): + alert = DeadlineAlert( + reference=DeadlineReference.DAGRUN_QUEUED_AT, + interval=timedelta(hours=1), + callback=callback_class(TEST_CALLBACK_PATH), + ) + assert alert.callback is not None + assert isinstance(alert.callback, callback_class) + + def test_deadline_alert_rejects_invalid_callback(self): + """Test that DeadlineAlert rejects non-callback types.""" + with pytest.raises(ValueError, match="Callbacks of type str are not currently supported"): DeadlineAlert( reference=DeadlineReference.DAGRUN_QUEUED_AT, interval=timedelta(hours=1), - callback=SyncCallback(TEST_CALLBACK_PATH), + callback="not_a_callback", # type: ignore )