Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion airflow-core/src/airflow/executors/workloads/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from enum import Enum
from importlib import import_module
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, ClassVar, Literal
from uuid import UUID

import structlog
Expand Down Expand Up @@ -74,6 +74,7 @@ class ExecuteCallback(BaseDagBundleWorkload):
callback: CallbackDTO

type: Literal["ExecuteCallback"] = Field(init=False, default="ExecuteCallback")
TYPE: ClassVar[str] = "ExecuteCallback"

@classmethod
def make(
Expand Down
3 changes: 2 additions & 1 deletion airflow-core/src/airflow/executors/workloads/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import uuid
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, ClassVar, Literal

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -72,6 +72,7 @@ class ExecuteTask(BaseDagBundleWorkload):
sentry_integration: str = ""

type: Literal["ExecuteTask"] = Field(init=False, default="ExecuteTask")
TYPE: ClassVar[str] = "ExecuteTask"

@classmethod
def make(
Expand Down
25 changes: 17 additions & 8 deletions providers/edge3/src/airflow/providers/edge3/cli/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
status_file_path,
write_pid_to_pidfile,
)
from airflow.providers.edge3.executors.utils import is_callback_execute
from airflow.providers.edge3.models.edge_worker import (
EdgeWorkerDuplicateException,
EdgeWorkerState,
Expand All @@ -60,7 +61,7 @@
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from airflow.executors.workloads import ExecuteTask
from airflow.executors.workloads import ExecuteCallback, ExecuteTask

logger = logging.getLogger(__name__)
base_log_folder = conf.get("logging", "base_log_folder", fallback="NOT AVAILABLE")
Expand Down Expand Up @@ -226,15 +227,23 @@ def _run_job_via_supervisor(self, workload: ExecuteTask, results_queue: Queue) -
results_queue.put(e)
return 1

def _launch_job(self, workload: ExecuteTask) -> tuple[Process, Queue[Exception]]:
def _launch_job(self, workload: ExecuteTask | ExecuteCallback) -> tuple[Process, Queue[Exception]]:
    """Spawn a child process that executes *workload* via the supervisor.

    :param workload: The task or callback workload fetched for this worker.
    :return: The started child process and the queue on which the child
        reports any exception raised while running the workload.
    """
    # Improvement: Use frozen GC to prevent child process from copying unnecessary memory
    # See _spawn_workers_with_gc_freeze() in airflow-core/src/airflow/executors/local_executor.py
    results_queue: Queue[Exception] = Queue()
    # Task and callback workloads currently share the same supervisor entry
    # point, so a single launch path suffices for both.
    # TODO: run callbacks through a dedicated supervisor entry point once
    # https://github.com/apache/airflow/pull/62645 lands.
    process = Process(
        target=self._run_job_via_supervisor,
        kwargs={"workload": workload, "results_queue": results_queue},
    )
    process.start()
    return process, results_queue

async def _push_logs_in_chunks(self, job: Job):
Expand Down Expand Up @@ -343,7 +352,7 @@ async def fetch_and_run_job(self) -> None:

logger.info("Received job: %s", edge_job.identifier)

workload: ExecuteTask = edge_job.command
workload: ExecuteTask | ExecuteCallback = edge_job.command
process, results_queue = self._launch_job(workload)
if TYPE_CHECKING:
assert workload.log_path # We need to assume this is defined in here
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from airflow.executors.base_executor import BaseExecutor
from airflow.models.taskinstance import TaskInstance
from airflow.providers.common.compat.sdk import Stats, conf, timezone
from airflow.providers.edge3.executors.utils import is_callback_execute
from airflow.providers.edge3.models.db import EdgeDBManager, check_db_manager_config
from airflow.providers.edge3.models.edge_job import EdgeJobModel
from airflow.providers.edge3.models.edge_logs import EdgeLogsModel
Expand Down Expand Up @@ -87,42 +88,71 @@ def queue_workload(
session: Session = NEW_SESSION,
) -> None:
"""Put new workload to queue. Airflow 3 entry point to execute a task."""
if not isinstance(workload, workloads.ExecuteTask):
raise TypeError(f"Don't know how to queue workload of type {type(workload).__name__}")
if is_callback_execute(workload):
existing_job = session.scalars(
select(EdgeJobModel).where(
EdgeJobModel.dag_id == workload.type,
EdgeJobModel.task_id == workload.callback.key,
EdgeJobModel.run_id == f"{workload.type}-{workload.callback.key}",
)
).first()

if existing_job:
existing_job.state = TaskInstanceState.QUEUED
existing_job.command = workload.model_dump_json()
else:
session.add(
EdgeJobModel(
dag_id=workload.type,
task_id=workload.callback.key,
run_id=f"{workload.type}-{workload.callback.key}",
map_index=-1,
try_number=0,
queue=self.conf.get_mandatory_value("operators", "default_queue"),
concurrency_slots=1,
state=TaskInstanceState.QUEUED,
command=workload.model_dump_json(),
)
)

task_instance = workload.ti
key = task_instance.key

# Check if job already exists with same dag_id, task_id, run_id, map_index, try_number
existing_job = session.scalars(
select(EdgeJobModel).where(
EdgeJobModel.dag_id == key.dag_id,
EdgeJobModel.task_id == key.task_id,
EdgeJobModel.run_id == key.run_id,
EdgeJobModel.map_index == key.map_index,
EdgeJobModel.try_number == key.try_number,
)
).first()
elif isinstance(workload, workloads.ExecuteTask):
# TODO: to be updated
task_instance = workload.ti
key = task_instance.key

# Check if job already exists with same dag_id, task_id, run_id, map_index, try_number
existing_job = session.scalars(
select(EdgeJobModel).where(
EdgeJobModel.dag_id == key.dag_id,
EdgeJobModel.task_id == key.task_id,
EdgeJobModel.run_id == key.run_id,
EdgeJobModel.map_index == key.map_index,
EdgeJobModel.try_number == key.try_number,
)
).first()

if existing_job:
existing_job.state = TaskInstanceState.QUEUED
existing_job.queue = task_instance.queue
existing_job.concurrency_slots = task_instance.pool_slots
existing_job.command = workload.model_dump_json()
else:
session.add(
EdgeJobModel(
dag_id=key.dag_id,
task_id=key.task_id,
run_id=key.run_id,
map_index=key.map_index,
try_number=key.try_number,
state=TaskInstanceState.QUEUED,
queue=task_instance.queue,
concurrency_slots=task_instance.pool_slots,
command=workload.model_dump_json(),
)
)

if existing_job:
existing_job.state = TaskInstanceState.QUEUED
existing_job.queue = task_instance.queue
existing_job.concurrency_slots = task_instance.pool_slots
existing_job.command = workload.model_dump_json()
else:
session.add(
EdgeJobModel(
dag_id=key.dag_id,
task_id=key.task_id,
run_id=key.run_id,
map_index=key.map_index,
try_number=key.try_number,
state=TaskInstanceState.QUEUED,
queue=task_instance.queue,
concurrency_slots=task_instance.pool_slots,
command=workload.model_dump_json(),
)
)
raise TypeError(f"Don't know how to queue workload of type {type(workload).__name__}")

def _check_worker_liveness(self, session: Session) -> bool:
"""Reset worker state if heartbeat timed out."""
Expand Down
43 changes: 43 additions & 0 deletions providers/edge3/src/airflow/providers/edge3/executors/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import Annotated, TypeGuard

from pydantic import Discriminator, Tag

from airflow.providers.edge3.version_compat import AIRFLOW_V_3_2_PLUS

# ``ExecuteCallback`` only exists from Airflow 3.2 onwards, so the accepted
# workload body type has to be chosen at import time based on the core version.
if not AIRFLOW_V_3_2_PLUS:
    from airflow.executors.workloads import ExecuteTask

    # Pre-3.2 cores can only dispatch task workloads to edge workers.
    ExecuteTypeBody = ExecuteTask
else:
    from airflow.executors.workloads import ExecuteCallback, ExecuteTask

    # Discriminated union keyed on the Pydantic ``type`` field so the correct
    # model is selected when deserializing a fetched job payload.
    ExecuteTypeBody = Annotated[
        Annotated[ExecuteTask, Tag("ExecuteTask")] | Annotated[ExecuteCallback, Tag("ExecuteCallback")],
        Discriminator("type"),
    ]


def is_callback_execute(workload: ExecuteCallback | ExecuteTask) -> TypeGuard[ExecuteCallback]:
    """Return True when *workload* is an ``ExecuteCallback`` instance.

    On Airflow cores older than 3.2 the ``ExecuteCallback`` workload does not
    exist, so the answer is always False there.
    """
    if not AIRFLOW_V_3_2_PLUS:
        return False
    # Imported lazily: the name is only importable on Airflow >= 3.2.
    from airflow.executors.workloads import ExecuteCallback

    return isinstance(workload, ExecuteCallback)
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:


AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0)
AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0)

__all__ = [
"AIRFLOW_V_3_1_PLUS",
]
__all__ = ["AIRFLOW_V_3_1_PLUS", "AIRFLOW_V_3_2_PLUS"]
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,24 @@
)

from fastapi import Path
from pydantic import BaseModel, Field
from pydantic import BaseModel, Discriminator, Field, Tag

from airflow.executors.workloads import ExecuteTask # noqa: TCH001
from airflow.providers.common.compat.sdk import TaskInstanceKey
from airflow.providers.edge3.executors.utils import ExecuteTypeBody
from airflow.providers.edge3.models.edge_worker import EdgeWorkerState # noqa: TCH001
from airflow.providers.edge3.version_compat import AIRFLOW_V_3_2_PLUS

if not AIRFLOW_V_3_2_PLUS:
from airflow.executors.workloads import ExecuteTask

ExecuteTypeBody = ExecuteTask
else:
from airflow.executors.workloads import ExecuteCallback, ExecuteTask

ExecuteTypeBody = Annotated[
Annotated[ExecuteTask, Tag("ExecuteTask")] | Annotated[ExecuteCallback, Tag("ExecuteCallback")],
Discriminator("type"),
]


class WorkerApiDocs:
Expand Down Expand Up @@ -91,7 +104,7 @@ class EdgeJobFetched(EdgeJobBase):
"""Job that is to be executed on the edge worker."""

command: Annotated[
ExecuteTask,
ExecuteTypeBody,
Field(
title="Command",
description="Command line to use to execute the job in Airflow",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from __future__ import annotations

from typing import Annotated
from typing import TYPE_CHECKING, Annotated

from fastapi import Body, Depends, status
from sqlalchemy import select, update
Expand All @@ -27,6 +27,7 @@
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.executors.workloads import ExecuteTask
from airflow.providers.common.compat.sdk import Stats, timezone
from airflow.providers.edge3.version_compat import AIRFLOW_V_3_2_PLUS

try:
from airflow.sdk.observability.stats import DualStatsManager
Expand All @@ -41,10 +42,19 @@
)
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from airflow.executors.workloads import ExecuteCallback

jobs_router = AirflowRouter(tags=["Jobs"], prefix="/jobs")


def parse_command(command: str, dag_id: str, run_id: str) -> ExecuteTask | ExecuteCallback:
    """Deserialize a stored workload JSON payload into its workload model.

    Callback jobs are stored with a sentinel ``dag_id`` equal to
    ``ExecuteCallback.TYPE`` and a ``run_id`` prefixed by it; everything else
    is treated as a task workload.
    """
    if AIRFLOW_V_3_2_PLUS:
        # ``ExecuteCallback`` is only importable on Airflow >= 3.2.
        from airflow.executors.workloads import ExecuteCallback

        looks_like_callback = dag_id == ExecuteCallback.TYPE and run_id.startswith(ExecuteCallback.TYPE)
        if looks_like_callback:
            return ExecuteCallback.model_validate_json(command)

    return ExecuteTask.model_validate_json(command)


Expand Down Expand Up @@ -102,7 +112,7 @@ def fetch(
run_id=job.run_id,
map_index=job.map_index,
try_number=job.try_number,
command=parse_command(job.command),
command=parse_command(job.command, job.dag_id, job.run_id),
concurrency_slots=job.concurrency_slots,
)

Expand Down
Loading
Loading