diff --git a/airflow-core/src/airflow/executors/workloads/callback.py b/airflow-core/src/airflow/executors/workloads/callback.py
index 273c55953675b..6a664105bd3e5 100644
--- a/airflow-core/src/airflow/executors/workloads/callback.py
+++ b/airflow-core/src/airflow/executors/workloads/callback.py
@@ -21,7 +21,7 @@
 from enum import Enum
 from importlib import import_module
 from pathlib import Path
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING, ClassVar, Literal
 from uuid import UUID

 import structlog
@@ -74,6 +74,7 @@ class ExecuteCallback(BaseDagBundleWorkload):
     callback: CallbackDTO

     type: Literal["ExecuteCallback"] = Field(init=False, default="ExecuteCallback")
+    TYPE: ClassVar[str] = "ExecuteCallback"

     @classmethod
     def make(
diff --git a/airflow-core/src/airflow/executors/workloads/task.py b/airflow-core/src/airflow/executors/workloads/task.py
index 4ca8c310fb5c2..785697c4f71aa 100644
--- a/airflow-core/src/airflow/executors/workloads/task.py
+++ b/airflow-core/src/airflow/executors/workloads/task.py
@@ -20,7 +20,7 @@

 import uuid
 from pathlib import Path
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING, ClassVar, Literal

 from pydantic import BaseModel, Field

@@ -72,6 +72,7 @@ class ExecuteTask(BaseDagBundleWorkload):
     sentry_integration: str = ""

     type: Literal["ExecuteTask"] = Field(init=False, default="ExecuteTask")
+    TYPE: ClassVar[str] = "ExecuteTask"

     @classmethod
     def make(
diff --git a/providers/edge3/src/airflow/providers/edge3/cli/worker.py b/providers/edge3/src/airflow/providers/edge3/cli/worker.py
index f0cffa96c05ea..9208cb9d15ea5 100644
--- a/providers/edge3/src/airflow/providers/edge3/cli/worker.py
+++ b/providers/edge3/src/airflow/providers/edge3/cli/worker.py
@@ -51,6 +51,7 @@
     status_file_path,
     write_pid_to_pidfile,
 )
+from airflow.providers.edge3.executors.utils import is_callback_execute
 from airflow.providers.edge3.models.edge_worker import (
     EdgeWorkerDuplicateException,
     EdgeWorkerState,
@@ -60,7 +61,7 @@
 from airflow.utils.state import TaskInstanceState

 if TYPE_CHECKING:
-    from airflow.executors.workloads import ExecuteTask
+    from airflow.executors.workloads import ExecuteCallback, ExecuteTask

 logger = logging.getLogger(__name__)
 base_log_folder = conf.get("logging", "base_log_folder", fallback="NOT AVAILABLE")
@@ -226,15 +227,23 @@ def _run_job_via_supervisor(self, workload: ExecuteTask, results_queue: Queue) -
         results_queue.put(e)
         return 1

-    def _launch_job(self, workload: ExecuteTask) -> tuple[Process, Queue[Exception]]:
+    def _launch_job(self, workload: ExecuteTask | ExecuteCallback) -> tuple[Process, Queue[Exception]]:
         # Improvement: Use frozen GC to prevent child process from copying unnecessary memory
         # See _spawn_workers_with_gc_freeze() in airflow-core/src/airflow/executors/local_executor.py
         results_queue: Queue[Exception] = Queue()
-        process = Process(
-            target=self._run_job_via_supervisor,
-            kwargs={"workload": workload, "results_queue": results_queue},
-        )
-        process.start()
+        if is_callback_execute(workload):
+            process = Process(
+                # TODO: change the supervisor to use the one introduced in https://github.com/apache/airflow/pull/62645
+                target=self._run_job_via_supervisor,
+                kwargs={"workload": workload, "results_queue": results_queue},
+            )
+            process.start()
+        else:
+            process = Process(
+                target=self._run_job_via_supervisor,
+                kwargs={"workload": workload, "results_queue": results_queue},
+            )
+            process.start()
         return process, results_queue

     async def _push_logs_in_chunks(self, job: Job):
@@ -343,7 +352,7 @@ async def fetch_and_run_job(self) -> None:

         logger.info("Received job: %s", edge_job.identifier)
-        workload: ExecuteTask = edge_job.command
+        workload: ExecuteTask | ExecuteCallback = edge_job.command
         process, results_queue = self._launch_job(workload)
         if TYPE_CHECKING:
             assert workload.log_path  # We need to assume this is defined in here
diff --git a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
index bef49278c6a70..f1d97e9e9f65f 100644
--- a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
+++ b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
@@ -28,6 +28,7 @@
 from airflow.executors.base_executor import BaseExecutor
 from airflow.models.taskinstance import TaskInstance
 from airflow.providers.common.compat.sdk import Stats, conf, timezone
+from airflow.providers.edge3.executors.utils import is_callback_execute
 from airflow.providers.edge3.models.db import EdgeDBManager, check_db_manager_config
 from airflow.providers.edge3.models.edge_job import EdgeJobModel
 from airflow.providers.edge3.models.edge_logs import EdgeLogsModel
@@ -87,42 +88,71 @@ def queue_workload(
         session: Session = NEW_SESSION,
     ) -> None:
         """Put new workload to queue. Airflow 3 entry point to execute a task."""
-        if not isinstance(workload, workloads.ExecuteTask):
-            raise TypeError(f"Don't know how to queue workload of type {type(workload).__name__}")
+        if is_callback_execute(workload):
+            existing_job = session.scalars(
+                select(EdgeJobModel).where(
+                    EdgeJobModel.dag_id == workload.type,
+                    EdgeJobModel.task_id == workload.callback.key,
+                    EdgeJobModel.run_id == f"{workload.type}-{workload.callback.key}",
+                )
+            ).first()
+
+            if existing_job:
+                existing_job.state = TaskInstanceState.QUEUED
+                existing_job.command = workload.model_dump_json()
+            else:
+                session.add(
+                    EdgeJobModel(
+                        dag_id=workload.type,
+                        task_id=workload.callback.key,
+                        run_id=f"{workload.type}-{workload.callback.key}",
+                        map_index=-1,
+                        try_number=0,
+                        queue=self.conf.get_mandatory_value("operators", "default_queue"),
+                        concurrency_slots=1,
+                        state=TaskInstanceState.QUEUED,
+                        command=workload.model_dump_json(),
+                    )
+                )
-        task_instance = workload.ti
-        key = task_instance.key
-
-        # Check if job already exists with same dag_id, task_id, run_id, map_index, try_number
-        existing_job = session.scalars(
-            select(EdgeJobModel).where(
-                EdgeJobModel.dag_id == key.dag_id,
-                EdgeJobModel.task_id == key.task_id,
-                EdgeJobModel.run_id == key.run_id,
-                EdgeJobModel.map_index == key.map_index,
-                EdgeJobModel.try_number == key.try_number,
-            )
-        ).first()
+        elif isinstance(workload, workloads.ExecuteTask):
+            # TODO: to be updated
+            task_instance = workload.ti
+            key = task_instance.key
+
+            # Check if job already exists with same dag_id, task_id, run_id, map_index, try_number
+            existing_job = session.scalars(
+                select(EdgeJobModel).where(
+                    EdgeJobModel.dag_id == key.dag_id,
+                    EdgeJobModel.task_id == key.task_id,
+                    EdgeJobModel.run_id == key.run_id,
+                    EdgeJobModel.map_index == key.map_index,
+                    EdgeJobModel.try_number == key.try_number,
+                )
+            ).first()
+
+            if existing_job:
+                existing_job.state = TaskInstanceState.QUEUED
+                existing_job.queue = task_instance.queue
+                existing_job.concurrency_slots = task_instance.pool_slots
+                existing_job.command = workload.model_dump_json()
+            else:
+                session.add(
+                    EdgeJobModel(
+                        dag_id=key.dag_id,
+                        task_id=key.task_id,
+                        run_id=key.run_id,
+                        map_index=key.map_index,
+                        try_number=key.try_number,
+                        state=TaskInstanceState.QUEUED,
+                        queue=task_instance.queue,
+                        concurrency_slots=task_instance.pool_slots,
+                        command=workload.model_dump_json(),
+                    )
+                )
-        if existing_job:
-            existing_job.state = TaskInstanceState.QUEUED
-            existing_job.queue = task_instance.queue
-            existing_job.concurrency_slots = task_instance.pool_slots
-            existing_job.command = workload.model_dump_json()
         else:
-            session.add(
-                EdgeJobModel(
-                    dag_id=key.dag_id,
-                    task_id=key.task_id,
-                    run_id=key.run_id,
-                    map_index=key.map_index,
-                    try_number=key.try_number,
-                    state=TaskInstanceState.QUEUED,
-                    queue=task_instance.queue,
-                    concurrency_slots=task_instance.pool_slots,
-                    command=workload.model_dump_json(),
-                )
-            )
+            raise TypeError(f"Don't know how to queue workload of type {type(workload).__name__}")

     def _check_worker_liveness(self, session: Session) -> bool:
         """Reset worker state if heartbeat timed out."""
diff --git a/providers/edge3/src/airflow/providers/edge3/executors/utils.py b/providers/edge3/src/airflow/providers/edge3/executors/utils.py
new file mode 100644
index 0000000000000..9aa3f34974049
--- /dev/null
+++ b/providers/edge3/src/airflow/providers/edge3/executors/utils.py
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Annotated, TypeGuard
+
+from pydantic import Discriminator, Tag
+
+from airflow.providers.edge3.version_compat import AIRFLOW_V_3_2_PLUS
+
+if not AIRFLOW_V_3_2_PLUS:
+    from airflow.executors.workloads import ExecuteTask
+
+    ExecuteTypeBody = ExecuteTask
+else:
+    from airflow.executors.workloads import ExecuteCallback, ExecuteTask
+
+    ExecuteTypeBody = Annotated[
+        Annotated[ExecuteTask, Tag("ExecuteTask")] | Annotated[ExecuteCallback, Tag("ExecuteCallback")],
+        Discriminator("type"),
+    ]
+
+
+def is_callback_execute(workload: ExecuteCallback | ExecuteTask) -> TypeGuard[ExecuteCallback]:
+    if AIRFLOW_V_3_2_PLUS:
+        from airflow.executors.workloads import ExecuteCallback
+
+        return isinstance(workload, ExecuteCallback)
+    return False
diff --git a/providers/edge3/src/airflow/providers/edge3/version_compat.py b/providers/edge3/src/airflow/providers/edge3/version_compat.py
index 27070ab292bad..f251f1ec357de 100644
--- a/providers/edge3/src/airflow/providers/edge3/version_compat.py
+++ b/providers/edge3/src/airflow/providers/edge3/version_compat.py
@@ -33,7 +33,6 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:

 AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0)
+AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0)

-__all__ = [
-    "AIRFLOW_V_3_1_PLUS",
-]
+__all__ = ["AIRFLOW_V_3_1_PLUS", "AIRFLOW_V_3_2_PLUS"]
diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py b/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py
index fc780b8766219..37a0c515808a6 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py
@@ -23,11 +23,24 @@
 )

 from fastapi import Path
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Discriminator, Field, Tag

-from airflow.executors.workloads import ExecuteTask  # noqa: TCH001
 from airflow.providers.common.compat.sdk import TaskInstanceKey
+from airflow.providers.edge3.executors.utils import ExecuteTypeBody
 from airflow.providers.edge3.models.edge_worker import EdgeWorkerState  # noqa: TCH001
+from airflow.providers.edge3.version_compat import AIRFLOW_V_3_2_PLUS
+
+if not AIRFLOW_V_3_2_PLUS:
+    from airflow.executors.workloads import ExecuteTask
+
+    ExecuteTypeBody = ExecuteTask
+else:
+    from airflow.executors.workloads import ExecuteCallback, ExecuteTask
+
+    ExecuteTypeBody = Annotated[
+        Annotated[ExecuteTask, Tag("ExecuteTask")] | Annotated[ExecuteCallback, Tag("ExecuteCallback")],
+        Discriminator("type"),
+    ]


 class WorkerApiDocs:
@@ -91,7 +104,7 @@ class EdgeJobFetched(EdgeJobBase):
     """Job that is to be executed on the edge worker."""

     command: Annotated[
-        ExecuteTask,
+        ExecuteTypeBody,
         Field(
             title="Command",
             description="Command line to use to execute the job in Airflow",
diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py
index 1191d068014a0..3b407d655a7db 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py
@@ -17,7 +17,7 @@

 from __future__ import annotations

-from typing import Annotated
+from typing import TYPE_CHECKING, Annotated

 from fastapi import Body, Depends, status
 from sqlalchemy import select, update
@@ -27,6 +27,7 @@
 from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
 from airflow.executors.workloads import ExecuteTask
 from airflow.providers.common.compat.sdk import Stats, timezone
+from airflow.providers.edge3.version_compat import AIRFLOW_V_3_2_PLUS

 try:
     from airflow.sdk.observability.stats import DualStatsManager
@@ -41,10 +42,19 @@
 )
 from airflow.utils.state import TaskInstanceState

+if TYPE_CHECKING:
+    from airflow.executors.workloads import ExecuteCallback
+
 jobs_router = AirflowRouter(tags=["Jobs"], prefix="/jobs")


-def parse_command(command: str) -> ExecuteTask:
+def parse_command(command: str, dag_id: str, run_id: str) -> ExecuteTask | ExecuteCallback:
+    if AIRFLOW_V_3_2_PLUS:
+        from airflow.executors.workloads import ExecuteCallback
+
+        if dag_id == ExecuteCallback.TYPE and run_id.startswith(ExecuteCallback.TYPE):
+            return ExecuteCallback.model_validate_json(command)
+
     return ExecuteTask.model_validate_json(command)

@@ -102,7 +112,7 @@ def fetch(
             run_id=job.run_id,
             map_index=job.map_index,
             try_number=job.try_number,
-            command=parse_command(job.command),
+            command=parse_command(job.command, job.dag_id, job.run_id),
             concurrency_slots=job.concurrency_slots,
         )
diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
index 2904a2e0d3d5d..0080bf274540f 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
@@ -964,6 +964,32 @@ components:
         - name
       title: BundleInfo
       description: Schema for telling task which bundle to run with.
+    CallbackDTO:
+      properties:
+        id:
+          type: string
+          title: Id
+        fetch_method:
+          $ref: '#/components/schemas/CallbackFetchMethod'
+        data:
+          additionalProperties: true
+          type: object
+          title: Data
+      type: object
+      required:
+        - id
+        - fetch_method
+        - data
+      title: CallbackDTO
+      description: Schema for Callback with minimal required fields needed for Executors
+        and Task SDK.
+    CallbackFetchMethod:
+      type: string
+      enum:
+        - dag_attribute
+        - import_path
+      title: CallbackFetchMethod
+      description: Methods used to fetch callback at runtime.
     ConcurrencyRequest:
       properties:
         concurrency:
@@ -1000,9 +1026,16 @@ components:
           title: Try Number
           description: The number of attempt to execute this task.
         command:
-          $ref: '#/components/schemas/ExecuteTask'
+          oneOf:
+            - $ref: '#/components/schemas/ExecuteTask'
+            - $ref: '#/components/schemas/ExecuteCallback'
           title: Command
           description: Command line to use to execute the job in Airflow
+          discriminator:
+            propertyName: type
+            mapping:
+              ExecuteCallback: '#/components/schemas/ExecuteCallback'
+              ExecuteTask: '#/components/schemas/ExecuteTask'
         concurrency_slots:
           type: integer
           title: Concurrency Slots
@@ -1035,6 +1068,38 @@ components:
         - offline maintenance
       title: EdgeWorkerState
       description: Status of a Edge Worker instance.
+    ExecuteCallback:
+      properties:
+        token:
+          type: string
+          title: Token
+        dag_rel_path:
+          type: string
+          format: path
+          title: Dag Rel Path
+        bundle_info:
+          $ref: '#/components/schemas/BundleInfo'
+        log_path:
+          anyOf:
+            - type: string
+            - type: 'null'
+          title: Log Path
+        callback:
+          $ref: '#/components/schemas/CallbackDTO'
+        type:
+          type: string
+          const: ExecuteCallback
+          title: Type
+          default: ExecuteCallback
+      type: object
+      required:
+        - token
+        - dag_rel_path
+        - bundle_info
+        - log_path
+        - callback
+      title: ExecuteCallback
+      description: Execute the given Callback.
     ExecuteTask:
       properties:
         token:
           type: string
diff --git a/providers/edge3/tests/unit/edge3/cli/test_worker.py b/providers/edge3/tests/unit/edge3/cli/test_worker.py
index b2b97034b36ba..ff17ea570f0d3 100644
--- a/providers/edge3/tests/unit/edge3/cli/test_worker.py
+++ b/providers/edge3/tests/unit/edge3/cli/test_worker.py
@@ -23,6 +23,7 @@
 from io import StringIO
 from multiprocessing import Process, Queue
 from pathlib import Path
+from typing import TYPE_CHECKING
 from unittest import mock
 from unittest.mock import call, patch

@@ -51,6 +52,13 @@
 from tests_common.test_utils.config import conf_vars
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS

+if TYPE_CHECKING:
+    from airflow.executors.workloads import ExecuteCallback
+
+if AIRFLOW_V_3_2_PLUS:
+    from airflow.executors.workloads import ExecuteCallback
+
+
 pytest.importorskip("pydantic", minversion="2.0.0")

 pytestmark = [pytest.mark.asyncio]
@@ -73,6 +81,23 @@
     "dag_rel_path": "mock.py",
     "log_path": "mock.log",
     "bundle_info": {"name": "hello", "version": "abc"},
+    "type": "ExecuteTask",
+}
+
+MOCK_CALLBACK_COMMAND = {
+    "token": "mock",
+    "callback": {
+        "id": "12345678-1234-5678-1234-567812345678",
+        "fetch_method": "import_path",
+        "data": {
+            "path": "builtins.dict",
+            "kwargs": {"a": 1, "b": 2, "c": 3},
+        },
+    },
+    "dag_rel_path": "test.py",
+    "log_path": "test.log",
+    "bundle_info": {"name": "test_bundle", "version": "1.0"},
+    "type": "ExecuteCallback",
 }

@@ -525,3 +550,25 @@ def test_list_edge_workers(self, mock_edgeworker: EdgeWorkerModel):
         ]:
             assert key in edge_workers[0]
         assert any("test_edge_worker" in h["worker_name"] for h in edge_workers)
+
+
+class TestEdgeJobFetchedSerialization:
+    """Test that EdgeJobFetched serializes and deserializes with both ExecuteTask and ExecuteCallback."""
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="The tests should be skipped for Airflow < 3.2")
+    def test_serialize_with_execute_callback(self):
+        fetched = EdgeJobFetched(
+            dag_id="ExecuteCallback",
+            task_id="12345678-1234-5678-1234-567812345678",
+            run_id="ExecuteCallback-12345678-1234-5678-1234-567812345678",
+            map_index=-1,
+            try_number=0,
+            concurrency_slots=1,
+            command=MOCK_CALLBACK_COMMAND,  # type: ignore[arg-type]
+        )
+        serialized = fetched.model_dump_json()
+        deserialized = EdgeJobFetched.model_validate_json(serialized)
+
+        assert deserialized.dag_id == "ExecuteCallback"
+        assert deserialized.command.type == ExecuteCallback.TYPE
+        assert isinstance(deserialized.command, ExecuteCallback)
diff --git a/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py b/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py
index c38dffed3e918..646fc75d63bd3 100644
--- a/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py
+++ b/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py
@@ -18,12 +18,15 @@

 from datetime import datetime, timedelta
 from unittest.mock import MagicMock, patch
+from uuid import uuid4

 import pytest
 import time_machine
 from sqlalchemy import delete, select

 from airflow.configuration import conf
+from airflow.executors.workloads import ExecuteTask, TaskInstanceDTO
+from airflow.executors.workloads.base import BundleInfo
 from airflow.models.taskinstancekey import TaskInstanceKey
 from airflow.providers.common.compat.sdk import Stats, timezone
 from airflow.providers.edge3.executors.edge_executor import EdgeExecutor
@@ -33,6 +36,10 @@
 from airflow.utils.state import TaskInstanceState

 from tests_common.test_utils.config import conf_vars
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS
+
+if AIRFLOW_V_3_2_PLUS:
+    from airflow.executors.workloads.callback import CallbackDTO, CallbackFetchMethod, ExecuteCallback

 pytestmark = pytest.mark.db_test
@@ -347,3 +354,122 @@ def test_revoke_task_nonexistent(self):
         # Verify nothing breaks
         assert key not in executor.running
         assert key not in executor.queued_tasks
+
+
+class TestQueueWorkload:
+    @pytest.fixture(autouse=True)
+    def setup(self):
+        with create_session() as session:
+            session.execute(delete(EdgeJobModel))
+            session.commit()
+
+    def _make_execute_task(self) -> ExecuteTask:
+        ti = TaskInstanceDTO(
+            id=uuid4(),
+            dag_version_id=uuid4(),
+            task_id="test_task",
+            dag_id="test_dag",
+            run_id="test_run",
+            try_number=1,
+            map_index=-1,
+            pool_slots=1,
+            queue="default",
+            priority_weight=1,
+        )
+        return ExecuteTask(
+            ti=ti,
+            dag_rel_path="test_dag.py",
+            token="test_token",
+            bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+            log_path="test.log",
+        )
+
+    def test_queue_workload_execute_task(self):
+        executor = EdgeExecutor()
+        workload = self._make_execute_task()
+
+        executor.queue_workload(workload)
+
+        with create_session() as session:
+            job = session.scalar(select(EdgeJobModel))
+            assert job is not None
+            assert job.dag_id == "test_dag"
+            assert job.task_id == "test_task"
+            assert job.run_id == "test_run"
+            assert job.state == TaskInstanceState.QUEUED
+            assert '"type":"ExecuteTask"' in job.command or '"type": "ExecuteTask"' in job.command
+
+    def test_queue_workload_execute_task_existing_job(self):
+        executor = EdgeExecutor()
+        workload = self._make_execute_task()
+
+        executor.queue_workload(workload)
+        executor.queue_workload(workload)
+
+        with create_session() as session:
+            jobs = session.scalars(select(EdgeJobModel)).all()
+            assert len(jobs) == 1
+            assert jobs[0].state == TaskInstanceState.QUEUED
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="ExecuteCallback requires Airflow 3.2+")
+    def test_queue_workload_execute_callback(self):
+        executor = EdgeExecutor()
+        id = str(uuid4())
+        callback_data = CallbackDTO(
+            id=id,
+            fetch_method=CallbackFetchMethod.IMPORT_PATH,
+            data={
+                "path": "builtins.dict",
+                "kwargs": {"a": 1, "b": 2, "c": 3},
+            },
+        )
+        workload = ExecuteCallback(
+            callback=callback_data,
+            dag_rel_path="test.py",
+            bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+            token="test_token",
+            log_path="test.log",
+        )
+
+        executor.queue_workload(workload)
+
+        with create_session() as session:
+            job = session.scalar(select(EdgeJobModel))
+            assert job is not None
+            assert job.dag_id == ExecuteCallback.TYPE
+            assert job.task_id == id
+            assert job.run_id == f"{ExecuteCallback.TYPE}-{id}"
+            assert job.state == TaskInstanceState.QUEUED
+            assert '"type":"ExecuteCallback"' in job.command or '"type": "ExecuteCallback"' in job.command
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="ExecuteCallback requires Airflow 3.2+")
+    def test_queue_workload_execute_callback_existing_job(self):
+        executor = EdgeExecutor()
+        callback_data = CallbackDTO(
+            id=str(uuid4()),
+            fetch_method=CallbackFetchMethod.IMPORT_PATH,
+            data={
+                "path": "builtins.dict",
+                "kwargs": {"a": 1, "b": 2, "c": 3},
+            },
+        )
+        workload = ExecuteCallback(
+            callback=callback_data,
+            dag_rel_path="test.py",
+            bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+            token="test_token",
+            log_path="test.log",
+        )
+
+        executor.queue_workload(workload)
+        executor.queue_workload(workload)
+
+        with create_session() as session:
+            jobs = session.scalars(select(EdgeJobModel)).all()
+            assert len(jobs) == 1
+            assert jobs[0].state == TaskInstanceState.QUEUED
+
+    def test_queue_workload_unknown_type_raises(self):
+        executor = EdgeExecutor()
+        with pytest.raises(TypeError, match="Don't know how to queue workload"):
+            executor.queue_workload(MagicMock(spec=[]))
diff --git a/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py b/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py
index 667ec184c0aa5..4fed90e6e4860 100644
--- a/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py
+++ b/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py
@@ -18,18 +18,29 @@

 from typing import TYPE_CHECKING
 from unittest.mock import patch
+from uuid import uuid4

 import pytest
 from sqlalchemy import delete, select

+from airflow.executors.workloads import ExecuteTask, TaskInstanceDTO
+from airflow.executors.workloads.base import BundleInfo
+from airflow.executors.workloads.callback import CallbackDTO
 from airflow.providers.edge3.models.edge_job import EdgeJobModel
-from airflow.providers.edge3.worker_api.routes.jobs import state
+from airflow.providers.edge3.worker_api.routes.jobs import parse_command, state
 from airflow.utils.session import create_session
 from airflow.utils.state import TaskInstanceState

+from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS
+
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session

+    from airflow.executors.workloads import CallbackFetchMethod, ExecuteCallback
+
+if AIRFLOW_V_3_2_PLUS:
+    from airflow.executors.workloads import CallbackFetchMethod, ExecuteCallback
+
 try:
     from airflow.sdk._shared.observability.metrics.dual_stats_manager import DualStatsManager  # noqa: F401
@@ -108,3 +119,76 @@ def test_state(self, mock_stats_incr, session: Session):
         db_job: EdgeJobModel | None = session.scalar(select(EdgeJobModel))
         assert db_job is not None
         assert db_job.state == TaskInstanceState.SUCCESS
+
+
+class TestParseCommand:
+    def _make_execute_task(self) -> ExecuteTask:
+        ti = TaskInstanceDTO(
+            id=uuid4(),
+            dag_version_id=uuid4(),
+            task_id="test_task",
+            dag_id="test_dag",
+            run_id="test_run",
+            try_number=1,
+            map_index=-1,
+            pool_slots=1,
+            queue="default",
+            priority_weight=1,
+        )
+        return ExecuteTask(
+            ti=ti,
+            dag_rel_path="test_dag.py",
+            token="test_token",
+            bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+            log_path="test.log",
+        )
+
+    def _make_execute_callback(self) -> ExecuteCallback:
+        callback_data = CallbackDTO(
+            id="12345678-1234-5678-1234-567812345678",
+            fetch_method=CallbackFetchMethod.IMPORT_PATH,
+            data={
+                "path": "builtins.dict",
+                "kwargs": {"a": 1, "b": 2, "c": 3},
+            },
+        )
+        return ExecuteCallback(
+            callback=callback_data,
+            dag_rel_path="test.py",
+            bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+            token="test_token",
+            log_path="test.log",
+        )
+
+    def test_parse_command_execute_task(self):
+        workload = self._make_execute_task()
+        command_json = workload.model_dump_json()
+
+        result = parse_command(command_json, dag_id="test_dag", run_id="test_run")
+
+        assert isinstance(result, ExecuteTask)
+        assert result.ti.dag_id == "test_dag"
+        assert result.ti.task_id == "test_task"
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="The tests should be skipped for Airflow < 3.2")
+    def test_parse_command_execute_callback(self):
+        workload = self._make_execute_callback()
+        command_json = workload.model_dump_json()
+        # Mimic how edge_executor stores callback jobs
+        dag_id = ExecuteCallback.TYPE
+        run_id = f"ExecuteCallback-{workload.callback.key}"
+
+        result = parse_command(command_json, dag_id=dag_id, run_id=run_id)
+
+        assert isinstance(result, ExecuteCallback)
+        assert result.callback.id == "12345678-1234-5678-1234-567812345678"
+        assert result.callback.fetch_method == CallbackFetchMethod.IMPORT_PATH
+
+    def test_parse_command_non_callback_dag_id_returns_execute_task(self):
+        """Even if run_id starts with ExecuteCallback, dag_id must also match."""
+        workload = self._make_execute_task()
+        command_json = workload.model_dump_json()
+
+        result = parse_command(command_json, dag_id="some_dag", run_id="ExecuteCallback-something")
+
+        assert isinstance(result, ExecuteTask)