From d494c0b9500197e2ecf1d803a4f63e3abafae4f1 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Tue, 11 Mar 2025 10:31:17 -0700 Subject: [PATCH 1/3] fix eager error --- flytekit/core/worker_queue.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytekit/core/worker_queue.py b/flytekit/core/worker_queue.py index edb8c8802b..2c0b889a85 100644 --- a/flytekit/core/worker_queue.py +++ b/flytekit/core/worker_queue.py @@ -298,8 +298,8 @@ def _apply_updates(self, update_items: typing.Dict[uuid.UUID, Update]) -> None: ) exc = EagerException( - f"Error executing {update.work_item.entity.name} with error:" - f" {update.wf_exec.closure.error}" + f"Error executing {update.work_item.entity.name} with error Type:" + f" {update.wf_exec.closure.error.code}. Message: {update.wf_exec.closure.error.message}" ) item.error = exc From a75365538df1ed95aa4251a01073caff8872a72c Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Wed, 12 Mar 2025 18:19:34 -0700 Subject: [PATCH 2/3] wip --- .gitignore | 2 +- flytekit/core/context_manager.py | 10 +- flytekit/core/controller.py | 394 ++++++++++++++++++++++++++ flytekit/core/eager_controller.py | 104 +++++++ flytekit/core/history_store.py | 97 +++++++ flytekit/core/promise.py | 3 +- flytekit/core/python_function_task.py | 20 +- flytekit/core/store/__init__.py | 0 flytekit/core/store/graph_pb2.py | 80 ++++++ flytekit/core/store/graph_pb2.pyi | 254 +++++++++++++++++ flytekit/core/store/graph_pb2_grpc.py | 330 +++++++++++++++++++++ flytekit/core/worker_queue.py | 24 +- 12 files changed, 1298 insertions(+), 20 deletions(-) create mode 100644 flytekit/core/controller.py create mode 100644 flytekit/core/eager_controller.py create mode 100644 flytekit/core/history_store.py create mode 100644 flytekit/core/store/__init__.py create mode 100644 flytekit/core/store/graph_pb2.py create mode 100644 flytekit/core/store/graph_pb2.pyi create mode 100644 flytekit/core/store/graph_pb2_grpc.py diff --git a/.gitignore b/.gitignore index 
0db8768ef2..99b0d35ca4 100644 --- a/.gitignore +++ b/.gitignore @@ -39,4 +39,4 @@ coverage.xml # Version file is auto-generated by setuptools_scm flytekit/_version.py -testing +samplecode diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index c8d4d92b40..0fce05afe6 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -42,7 +42,7 @@ if typing.TYPE_CHECKING: from flytekit.clients import friendly as friendly_client # noqa from flytekit.clients.friendly import SynchronousFlyteClient - from flytekit.core.worker_queue import Controller + from flytekit.core.eager_controller import EagerController from flytekit.deck.deck import Deck # TODO: resolve circular import from flytekit.core.python_auto_container import TaskResolverMixin @@ -699,7 +699,7 @@ class FlyteContext(object): in_a_condition: bool = False origin_stackframe: Optional[traceback.FrameSummary] = None output_metadata_tracker: Optional[OutputMetadataTracker] = None - worker_queue: Optional[Controller] = None + worker_queue: Optional[EagerController] = None @property def user_space_params(self) -> Optional[ExecutionParameters]: @@ -751,7 +751,7 @@ def with_serialization_settings(self, ss: SerializationSettings) -> Builder: def with_output_metadata_tracker(self, t: OutputMetadataTracker) -> Builder: return self.new_builder().with_output_metadata_tracker(t) - def with_worker_queue(self, wq: Controller) -> Builder: + def with_worker_queue(self, wq: EagerController) -> Builder: return self.new_builder().with_worker_queue(wq) def with_client(self, c: SynchronousFlyteClient) -> Builder: @@ -820,7 +820,7 @@ class Builder(object): serialization_settings: Optional[SerializationSettings] = None in_a_condition: bool = False output_metadata_tracker: Optional[OutputMetadataTracker] = None - worker_queue: Optional[Controller] = None + worker_queue: Optional[EagerController] = None def build(self) -> FlyteContext: return FlyteContext( @@ -881,7 +881,7 @@ def 
with_output_metadata_tracker(self, t: OutputMetadataTracker) -> FlyteContext self.output_metadata_tracker = t return self - def with_worker_queue(self, wq: Controller) -> FlyteContext.Builder: + def with_worker_queue(self, wq: EagerController) -> FlyteContext.Builder: self.worker_queue = wq return self diff --git a/flytekit/core/controller.py b/flytekit/core/controller.py new file mode 100644 index 0000000000..9dd97ecd7d --- /dev/null +++ b/flytekit/core/controller.py @@ -0,0 +1,394 @@ +import asyncio +import threading +from asyncio import Queue, Semaphore, Event +from typing import Protocol, TypeVar, Generic, Dict, Any, Tuple, Set +import logging +from datetime import datetime +import time + +# Setup basic logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Define the resource protocol +class ResourceProtocol(Protocol): + name: str + last_updated: datetime + + async def sync(self) -> 'ResourceProtocol': + """Sync the resource state and return updated version (polling mode)""" + ... + + def is_terminal(self) -> bool: + """Check if resource has reached terminal state""" + ... + + def is_started(self) -> bool: + """Check if resource has been started.""" + ... + + async def launch(self) -> bool: + """Launch the resource, invoked only once for resources that need to be started. + Returns: True if successfully launched, False otherwise.""" + ... 
+ +# Type variable for generic resource +R = TypeVar('R', bound=ResourceProtocol) + +class Informer(Generic[R]): + """Base informer with either polling or watch mode""" + + def __init__(self, shared_queue: Queue, sync_period: float = 5.0, use_watch: bool = False): + self.sync_period = sync_period if not use_watch else None + self.resources: Dict[str, R] = {} + self.shared_queue = shared_queue + self.running = False + self.use_watch = use_watch + self._lock = asyncio.Lock() + self._watch_task: asyncio.Task = None + + async def add_resource(self, resource: R): + """Add a new resource to watch""" + async with self._lock: + if resource.name not in self.resources: + self.resources[resource.name] = resource + logger.info(f"Added resource to informer: {resource.name}") + await self.shared_queue.put(resource) + + async def remove_resource(self, resource_name: str): + """Remove a resource from watching""" + async with self._lock: + if resource_name in self.resources: + del self.resources[resource_name] + logger.info(f"Removed resource from informer: {resource_name}") + + async def _on_resource_update(self, resource: R): + """Callback for watch-based updates""" + async with self._lock: + if resource.name in self.resources: + self.resources[resource.name] = resource + await self.shared_queue.put(resource) + else: + logger.debug(f"Received update for unknown resource: {resource.name}") + + async def watch(self): + """Watch for updates on all resources - to be implemented by subclasses for watch mode""" + raise NotImplementedError("Watch mode requires a custom informer implementation") + + async def start(self): + """Start the informer""" + if self.running: + logger.warning("Informer already running") + return + self.running = True + logger.info(f"Started informer in {'watch' if self.use_watch else 'polling'} mode") + if self.use_watch: + self._watch_task = asyncio.create_task(self.watch()) + else: + asyncio.create_task(self._run_polling()) + + async def stop(self): + """Stop 
the informer""" + self.running = False + if self.use_watch and self._watch_task: + self._watch_task.cancel() + self._watch_task = None + logger.info("Stopped informer") + + async def _run_polling(self): + """Background resource update loop for polling mode""" + while self.running: + try: + async with self._lock: + resources_to_sync = list(self.resources.items()) + + for name, resource in resources_to_sync: + new_resource = await resource.sync() + async with self._lock: + if name in self.resources: + self.resources[name] = new_resource + await self.shared_queue.put(new_resource) + + await asyncio.sleep(self.sync_period) + except Exception as e: + logger.error(f"Error in polling loop: {e}") + await asyncio.sleep(self.sync_period) + +class Controller(Generic[R]): + """Generic controller with high-level submit API""" + + def __init__(self, informer: Informer[R], shared_queue: Queue, max_concurrent_launches: int = 2): + self.informer = informer + self.shared_queue = shared_queue + self.running = False + self.max_concurrent_launches = max_concurrent_launches + self.launch_semaphore = Semaphore(max_concurrent_launches) + self.launching_tasks: Dict[str, asyncio.Task] = {} + self.completion_events: Dict[str, Event] = {} # Track completion events + + @classmethod + def for_watchable_resource(cls, informer: Informer[R], max_concurrent_launches: int = 2) -> 'Controller[R]': + if not informer.use_watch: + raise ValueError("Informer must be in watch mode") + watch_queue = Queue() + return cls(informer, watch_queue, max_concurrent_launches) + + @classmethod + def for_sync_resource(cls, max_concurrent_launches: int = 2) -> 'Controller[R]': + poll_queue = Queue() + poll_informer = Informer[R](poll_queue, sync_period=2.0, use_watch=False) + return cls(poll_informer, poll_queue, max_concurrent_launches) + + async def add_resource(self, resource: R): + """Public API to add a resource without waiting for completion""" + await self.informer.add_resource(resource) + await 
self.shared_queue.put(resource.name) + + async def submit_resource(self, resource: R) -> R: + """Submit a resource and await its completion, returning the final state""" + async with self.informer._lock: + if resource.name in self.informer.resources: + raise ValueError(f"Resource {resource.name} already exists") + if resource.name in self.completion_events: + raise ValueError(f"Resource {resource.name} is already being processed") + + # Create completion event and add resource + self.completion_events[resource.name] = Event() + await self.add_resource(resource) + + # Wait for completion + await self.completion_events[resource.name].wait() + + # Get final resource state and clean up + async with self.informer._lock: + final_resource = self.informer.resources.get(resource.name) + del self.completion_events[resource.name] + if final_resource: + await self.informer.remove_resource(resource.name) + return final_resource + + async def launch_resource(self, resource: R): + """Attempt to launch a resource until successful""" + async with self.launch_semaphore: + while not resource.is_started() and self.running: + logger.info(f"Attempting to launch resource: {resource.name}") + success = await resource.launch() + if success: + logger.info(f"Successfully launched resource: {resource.name}") + await self.shared_queue.put(resource) + break + else: + logger.warning(f"Failed to launch resource: {resource.name}, retrying...") + await asyncio.sleep(1) + + async def process_resource(self, resource: R): + """Process resource updates""" + logger.info(f"Processing resource: name={resource.name}, " + f"started={resource.is_started()}") + + if not resource.is_started(): + if resource.name not in self.launching_tasks: + task = asyncio.create_task(self.launch_resource(resource)) + self.launching_tasks[resource.name] = task + task.add_done_callback(lambda t: self.launching_tasks.pop(resource.name, None)) + elif resource.is_terminal(): + if resource.name in self.completion_events: + 
self.completion_events[resource.name].set() # Signal completion + await self.informer.remove_resource(resource.name) + + async def get_resource_status(self) -> Tuple[Set[str], Set[str]]: + """Return current set of launched (running) and waiting-to-be-launched resources""" + async with self.informer._lock: + launched = {name for name, res in self.informer.resources.items() if res.is_started()} + waiting = {name for name, res in self.informer.resources.items() if not res.is_started()} + return launched, waiting + + async def _log_resource_stats(self): + """Periodically log resource stats if debug is enabled""" + while self.running: + if logger.isEnabledFor(logging.DEBUG): + launched, waiting = await self.get_resource_status() + logger.debug(f"Resource stats: Launched={launched}, Waiting={waiting}") + await asyncio.sleep(2.0) + + async def run(self): + """Run loop with resource status logging""" + if self.running: + logger.warning("Controller already running") + return + + print("Controller running") + self.running = True + await self.informer.start() + asyncio.create_task(self._log_resource_stats()) + + while self.running: + try: + item = await self.shared_queue.get() + + if isinstance(item, str): + logger.info(f"Received resource name: {item}") + async with self.informer._lock: + if item in self.informer.resources: + await self.process_resource(self.informer.resources[item]) + else: + await self.process_resource(item) + + self.shared_queue.task_done() + + except Exception as e: + logger.error(f"Error in controller loop: {e}") + await asyncio.sleep(1.0) + + def _run(self, loop): + """Run the controller in the background""" + loop.run_until_complete(self.run()) + + def start(self): + """Synchronously start the controller in the background.""" + print("Starting controller") + thread = threading.Thread(target=self._run, kwargs={"loop" :asyncio.new_event_loop()}, daemon=True) + thread.start() + + + async def stop(self): + """Stop the controller""" + self.running = False 
+ for task in self.launching_tasks.values(): + task.cancel() + self.launching_tasks.clear() + for event in self.completion_events.values(): + event.set() # Unblock any waiting submit calls + self.completion_events.clear() + await self.informer.stop() + +# ----------------- DEMO ----------------- + +class PollingResource: + def __init__(self, name: str, data: Dict[str, Any] = None): + self.name = name + self.data = data or {"status": "initial"} + self.last_updated = datetime.now() + self._started = False + + async def sync(self) -> 'PollingResource': + await asyncio.sleep(0.1) + current_status = self.data.get("status", "initial") + if current_status == "initial": + new_status = "running" if self._started else "initial" + elif current_status == "running": + new_status = "completed" + else: + new_status = current_status + + return PollingResource( + name=self.name, + data={"status": new_status, "timestamp": str(datetime.now())} + ) + + def is_terminal(self) -> bool: + return self.data.get("status") == "completed" + + def is_started(self) -> bool: + return self._started + + async def launch(self) -> bool: + await asyncio.sleep(0.1) + if not self._started and self.data["status"] == "initial": + self._started = True + self.data["status"] = "running" + self.last_updated = datetime.now() + return True + return False + +class WatchResource: + def __init__(self, name: str, data: Dict[str, Any] = None): + self.name = name + self.data = data or {"phase": "pending"} + self.last_updated = datetime.now() + self._started = False + + async def sync(self) -> 'WatchResource': + raise NotImplementedError("Sync not supported for watch-based resource") + + def is_terminal(self) -> bool: + return self.data.get("phase") == "done" + + def is_started(self) -> bool: + return self._started + + async def launch(self) -> bool: + await asyncio.sleep(0.1) + if not self._started and self.data["phase"] == "pending": + if time.time() % 2 < 1: # Fail half the time + return False + self._started = 
True + self.data["phase"] = "active" + self.last_updated = datetime.now() + return True + return False + +class WatchInformer(Informer[WatchResource]): + async def watch(self): + """Simulated watch implementation for all resources""" + while self.running: + await asyncio.sleep(1.0) + async with self._lock: + resources = list(self.resources.items()) + + for name, resource in resources: + if resource.is_started(): + states = ["active", "done"] + current_idx = states.index(resource.data["phase"]) if resource.data["phase"] in states else -1 + if current_idx < len(states) - 1: + current_idx += 1 + updated = WatchResource( + name=name, + data={"phase": states[current_idx], "time": str(datetime.now())} + ) + updated._started = True + await self._on_resource_update(updated) + +async def main(): + # Enable debug logging + logger.setLevel(logging.DEBUG) + + # Polling-based controller + poll_controller = Controller.for_sync_resource(max_concurrent_launches=2) + + # Watch-based controller + watch_queue = Queue() + watch_informer = WatchInformer(watch_queue, use_watch=True) + watch_controller = Controller.for_watchable_resource(watch_informer, max_concurrent_launches=2) + + # Start both controllers + poll_task = asyncio.create_task(poll_controller.run()) + watch_task = asyncio.create_task(watch_controller.run()) + + # Submit resources and await completion + async def submit_and_log(controller: Controller, resource: ResourceProtocol): + logger.info(f"Submitting {resource.name}") + final_resource = await controller.submit_resource(resource) + logger.info(f"Completed {resource.name}: {final_resource.data if final_resource else 'None'}") + + await asyncio.gather( + submit_and_log(poll_controller, PollingResource("poll-1")), + submit_and_log(watch_controller, WatchResource("watch-1")), + submit_and_log(poll_controller, PollingResource("poll-2")), + submit_and_log(watch_controller, WatchResource("watch-2")), + submit_and_log(poll_controller, PollingResource("poll-3")), + 
submit_and_log(watch_controller, WatchResource("watch-3")) + ) + + # Cleanup + await poll_controller.stop() + await watch_controller.stop() + poll_task.cancel() + watch_task.cancel() + try: + await asyncio.gather(poll_task, watch_task) + except asyncio.CancelledError: + pass + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/flytekit/core/eager_controller.py b/flytekit/core/eager_controller.py new file mode 100644 index 0000000000..b03ebe6426 --- /dev/null +++ b/flytekit/core/eager_controller.py @@ -0,0 +1,104 @@ +import asyncio +import hashlib +import typing +from dataclasses import dataclass +from datetime import datetime +from functools import cached_property +from typing import Optional, Any, Dict + +from flytekit.core.launch_plan import LaunchPlan +from flytekit.core.base_task import PythonTask +from flytekit.core.controller import Controller, ResourceProtocol, Informer +from flytekit.core.reference_entity import ReferenceEntity +from flytekit.core.workflow import WorkflowBase + +RunnableEntity = typing.Union[WorkflowBase, LaunchPlan, PythonTask, ReferenceEntity] + + +@dataclass +class LocalNode(ResourceProtocol): + root_exec_id: str + entity: RunnableEntity + last_updated: datetime = None + input_kwargs: Optional[Dict[str, Any]] = None + result: Any = None + node_group: Optional[str] = None + error: Optional[BaseException] = None + _task: Optional[asyncio.Task] = None # Private field to store the background task + name: str = None + + def __post_init__(self): + self.last_updated = datetime.now() + if self.input_kwargs is None: + self.input_kwargs = {} + self.name = self.key # Cache the key + + async def _run_entity(self): + """Internal async method to run the entity and store results/errors""" + try: + # Assuming RunnableEntity.__call__ is synchronous + # If it's async, remove the await asyncio.to_thread() + result = await asyncio.to_thread(self.entity.__call__, **self.input_kwargs) + self.result = result + self.error 
= None + except Exception as e: + self.result = None + self.error = e + + async def sync(self) -> 'LocalNode': + """ + Non-blocking sync that always returns self. + Checks task status and updates results if complete. + """ + return self + + async def launch(self) -> bool: + """ + Launch the RunnableEntity in the background + """ + if self._task is None: + self._task = asyncio.create_task(self._run_entity()) + return True + + def is_terminal(self) -> bool: + """ + Returns True if the node has completed (successfully or with error) + """ + if self._task is None: + return False + return self._task.done() + + @cached_property + def key(self) -> str: + """Make a deterministic name""" + components = f"{self.root_exec_id}-{self.name}-{self.input_kwargs}" + ( + f"-{self.node_group}" if self.node_group else "") + + # hash the components into something deterministic + hex = hashlib.md5(components.encode()).hexdigest() + exec_name = f"{hex}" + return exec_name + + +class EagerController(Controller[LocalNode]): + + def __init__(self, root_exec_id: str): + poll_queue = asyncio.Queue() + poll_informer = Informer[LocalNode](poll_queue, sync_period=2.0, use_watch=False) + super(EagerController, self).__init__(informer=poll_informer, shared_queue=poll_queue, + max_concurrent_launches=2) + self.root_exec_id = root_exec_id + + async def _process_resource(self, resource: LocalNode): + """Process resource updates - can be overridden by subclasses""" + pass + + async def submit_node(self, entity: RunnableEntity, kwargs) -> Any: + """ + Submit a node to the controller + """ + node = LocalNode(root_exec_id=self.root_exec_id, entity=entity, input_kwargs=kwargs) + n = await self.submit_resource(node) + if n.error: + raise n.error + return n.result diff --git a/flytekit/core/history_store.py b/flytekit/core/history_store.py new file mode 100644 index 0000000000..19f9e4db2e --- /dev/null +++ b/flytekit/core/history_store.py @@ -0,0 +1,97 @@ +import hashlib +import shelve +import typing +from
dataclasses import dataclass +from enum import Enum + +import grpc + +from flytekit.core.store.graph_pb2 import CreateNodeRequest, Node, Status, ExecutionID, UpdateNodeStatusRequest, NodeID, \ + CreateExecutionRequest, TaskID +from flytekit.core.store.graph_pb2_grpc import GraphServiceStub +from google.protobuf import timestamp_pb2 + + +def _calc_key(exec_id: str, name: str, input_kwargs: dict[str, typing.Any], node_group: str | None = None) -> str: + """Make a deterministic name""" + components = f"{exec_id}-{name}-{input_kwargs}" + (f"-{node_group}" if node_group else "") + + # hash the components into something deterministic + hex = hashlib.md5(components.encode()).hexdigest() + exec_name = f"{hex}" + return exec_name + + +class ItemStatus(Enum): + PENDING = "Pending" + RUNNING = "Running" + SUCCESS = "Success" + FAILED = "Failed" + + +@dataclass +class Data: + input_kwargs: dict[str, typing.Any] + result: typing.Any = None + error: typing.Optional[BaseException] = None + status: ItemStatus = ItemStatus.PENDING + + +class Store: + + def __init__(self, exec_id: str = "testing2"): + self._store = shelve.open(f"{exec_id}.db") + self._client = GraphServiceStub(grpc.insecure_channel('localhost:8080')) + self._exec_id = exec_id + self._create_execution() + + def _create_execution(self): + print(f"Creating execution {self._exec_id}") + self._client.CreateExecution( + CreateExecutionRequest( + execution_id=ExecutionID(id=self._exec_id), + task_id=TaskID(id="root"), + ) + ) + + def add(self, name: str, value: Data, parent_key: str | None = None, + node_group: str | None = None): + print(f"Adding node {name} to store") + key = _calc_key(self._exec_id, name, value.input_kwargs, node_group) + n = Node( + id=key, + name=name, + start_time=timestamp_pb2.Timestamp().GetCurrentTime(), + status=Status.QUEUED, + parent_node_id=parent_key if parent_key else None, + ) + req = CreateNodeRequest( + execution_id=ExecutionID(id=self._exec_id), + node=n, + parent_node_id=parent_key if
parent_key else None, + node_group=node_group if node_group else None, + ) + res = self._client.CreateNode(req) + self._store[key] = value + + def has(self, name: str, input_kwargs: dict[str, typing.Any], node_group: str | None = None) -> bool: + key = _calc_key(self._exec_id, name, input_kwargs, node_group) + return key in self._store + + def get(self, name: str, input_kwargs: dict[str, typing.Any], node_group: str | None = None) -> Data: + key = _calc_key(self._exec_id, name, input_kwargs, node_group) + return self._store[key] + + def update(self, name: str, value: Data, status: Status, + node_group: str | None = None, error: str | None = None): + print(f"Updating node {name} to store") + key = _calc_key(self._exec_id, name, value.input_kwargs, node_group) + self._client.UpdateNodeStatus( + UpdateNodeStatusRequest( + execution_id=ExecutionID(id=self._exec_id), + node_id=NodeID(key), + status=status, + error_message=error if error else None, + ) + ) + self._store[key] = value diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index a9b8dd284b..c17e8698ed 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -1435,7 +1435,8 @@ async def async_flyte_entity_call_handler( # for both nested eager, async, and sync tasks, submit to the informer. 
if not ctx.worker_queue: raise AssertionError("Worker queue missing, must be set when trying to execute tasks in an eager workflow") - result = await ctx.worker_queue.add(entity, input_kwargs=kwargs) + print(f"Submitting node {entity.name} to worker queue") + result = await ctx.worker_queue.submit_node(entity, kwargs) return result # eager local execution, and all other call patterns are handled by the sync version diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 864934da9f..d5de395a1a 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -17,6 +17,7 @@ from __future__ import annotations +import asyncio import inspect import os import signal @@ -34,6 +35,7 @@ from flytekit.core.constants import EAGER_ROOT_ENV_NAME from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.docstring import Docstring +from flytekit.core.eager_controller import EagerController from flytekit.core.interface import Interface, transform_function_to_interface from flytekit.core.promise import ( Promise, @@ -516,7 +518,7 @@ async def async_execute(self, *args, **kwargs) -> Any: # Args is present because the asyn helper function passes it, but everything should be in kwargs by this point assert len(args) == 1 ctx = FlyteContextManager.current_context() - is_local_execution = cast(ExecutionState, ctx.execution_state).is_local_execution() + is_local_execution = False # cast(ExecutionState, ctx.execution_state).is_local_execution() if not is_local_execution: # a real execution return await self.run_with_backend(**kwargs) @@ -530,9 +532,10 @@ async def async_execute(self, *args, **kwargs) -> Any: def execute(self, **kwargs) -> Any: ctx = FlyteContextManager.current_context() - is_local_execution = cast(ExecutionState, ctx.execution_state).is_local_execution() + is_local_execution = False # cast(ExecutionState, ctx.execution_state).is_local_execution() 
builder = ctx.new_builder() if not is_local_execution: + print("Executing remote!") # ensure that the worker queue is in context if not ctx.worker_queue: from flytekit.configuration.plugin import get_plugin @@ -578,11 +581,15 @@ def execute(self, **kwargs) -> Any: # Note: The construction of this object is in this function because this function should be on the # main thread of pyflyte-execute. It needs to be on the main thread because signal handlers can only # be installed on the main thread. - c = Controller(remote=remote, ss=ss, tag=tag, root_tag=root_tag, exec_prefix=prefix) - handler = c.get_signal_handler() - signal.signal(signal.SIGINT, handler) - signal.signal(signal.SIGTERM, handler) + # c = Controller(remote=remote, ss=ss, tag=tag, root_tag=root_tag, exec_prefix=prefix) + # handler = c.get_signal_handler() + # signal.signal(signal.SIGINT, handler) + # signal.signal(signal.SIGTERM, handler) + # builder = ctx.with_worker_queue(c) + c = EagerController(prefix) builder = ctx.with_worker_queue(c) + c.start() + print("Added Controller!") else: raise AssertionError("Worker queue should not be already present in the context for eager execution") with FlyteContextManager.with_context(builder): @@ -613,6 +620,7 @@ async def run_with_backend(self, **kwargs): Deck("Eager Executions", html) if base_error: + print(base_error.with_traceback(None)) # now have to fail this eager task, because we don't want it to show up as succeeded. raise FlyteNonRecoverableSystemException(base_error) return result diff --git a/flytekit/core/store/__init__.py b/flytekit/core/store/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flytekit/core/store/graph_pb2.py b/flytekit/core/store/graph_pb2.py new file mode 100644 index 0000000000..34799aa31b --- /dev/null +++ b/flytekit/core/store/graph_pb2.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: proto/graph.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11proto/graph.proto\x12\x08graph.v1\x1a\x1fgoogle/protobuf/timestamp.proto\"\x18\n\x06TaskID\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\"\x1d\n\x0b\x45xecutionID\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\"\x18\n\x06NodeID\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\"\xaa\x01\n\x04Task\x12 \n\x02id\x18\x01 \x01(\x0b\x32\x10.graph.v1.TaskIDR\x02id\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x03 \x01(\tR\x0b\x64\x65scription\x12\x18\n\x07\x63ommand\x18\x04 \x01(\tR\x07\x63ommand\x12\x16\n\x06inputs\x18\x05 \x03(\tR\x06inputs\x12\x18\n\x07outputs\x18\x06 \x03(\tR\x07outputs\"\xe5\x01\n\rExecutionInfo\x12\x38\n\x0c\x65xecution_id\x18\x01 \x01(\x0b\x32\x15.graph.v1.ExecutionIDR\x0b\x65xecutionId\x12(\n\x06status\x18\x02 \x01(\x0e\x32\x10.graph.v1.StatusR\x06status\x12\x39\n\nstart_time\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\tstartTime\x12\x35\n\x08\x65nd_time\x18\x04 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\x07\x65ndTime\"\xeb\x02\n\x04Node\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12(\n\x06status\x18\x03 \x01(\x0e\x32\x10.graph.v1.StatusR\x06status\x12\x39\n\nstart_time\x18\x04 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\tstartTime\x12\x35\n\x08\x65nd_time\x18\x05 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\x07\x65ndTime\x12$\n\x0eparent_node_id\x18\x06 \x01(\tR\x0cparentNodeId\x12>\n\nproperties\x18\x07 
\x03(\x0b\x32\x1e.graph.v1.Node.PropertiesEntryR\nproperties\x1a=\n\x0fPropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xb5\x01\n\x04\x45\x64ge\x12\x16\n\x06source\x18\x01 \x01(\tR\x06source\x12\x16\n\x06target\x18\x02 \x01(\tR\x06target\x12>\n\nproperties\x18\x03 \x03(\x0b\x32\x1e.graph.v1.Edge.PropertiesEntryR\nproperties\x1a=\n\x0fPropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"V\n\x08NodeInfo\x12$\n\x05nodes\x18\x01 \x03(\x0b\x32\x0e.graph.v1.NodeR\x05nodes\x12$\n\x05\x65\x64ges\x18\x02 \x03(\x0b\x32\x0e.graph.v1.EdgeR\x05\x65\x64ges\"}\n\x16\x43reateExecutionRequest\x12)\n\x07task_id\x18\x01 \x01(\x0b\x32\x10.graph.v1.TaskIDR\x06taskId\x12\x38\n\x0c\x65xecution_id\x18\x02 \x01(\x0b\x32\x15.graph.v1.ExecutionIDR\x0b\x65xecutionId\"\xe2\x01\n\x11\x43reateNodeRequest\x12\x38\n\x0c\x65xecution_id\x18\x01 \x01(\x0b\x32\x15.graph.v1.ExecutionIDR\x0b\x65xecutionId\x12\"\n\x04node\x18\x02 \x01(\x0b\x32\x0e.graph.v1.NodeR\x04node\x12)\n\x0eparent_node_id\x18\x03 \x01(\tH\x00R\x0cparentNodeId\x88\x01\x01\x12\"\n\nnode_group\x18\x04 \x01(\tH\x01R\tnodeGroup\x88\x01\x01\x42\x11\n\x0f_parent_node_idB\r\n\x0b_node_group\"\x18\n\x06Worker\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\"\x91\x02\n\x0c\x45rrorMessage\x12\x12\n\x04\x63ode\x18\x01 \x01(\tR\x04\x63ode\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12?\n\nerror_kind\x18\x05 \x01(\x0e\x32 .graph.v1.ErrorMessage.ErrorKindR\terrorKind\x12\x38\n\ttimestamp\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\x12(\n\x06worker\x18\x07 \x01(\x0b\x32\x10.graph.v1.WorkerR\x06worker\".\n\tErrorKind\x12\x0b\n\x07UNKNOWN\x10\x00\x12\n\n\x06SYSTEM\x10\x01\x12\x08\n\x04USER\x10\x02\"\xf4\x02\n\x17UpdateNodeStatusRequest\x12\x38\n\x0c\x65xecution_id\x18\x01 \x01(\x0b\x32\x15.graph.v1.ExecutionIDR\x0b\x65xecutionId\x12)\n\x07node_id\x18\x02 
\x01(\x0b\x32\x10.graph.v1.NodeIDR\x06nodeId\x12(\n\x06status\x18\x03 \x01(\x0e\x32\x10.graph.v1.StatusR\x06status\x12;\n\x0eparent_node_id\x18\x04 \x01(\x0b\x32\x10.graph.v1.NodeIDH\x00R\x0cparentNodeId\x88\x01\x01\x12@\n\rerror_message\x18\x05 \x01(\x0b\x32\x16.graph.v1.ErrorMessageH\x01R\x0c\x65rrorMessage\x88\x01\x01\x12\x1b\n\x06reason\x18\x06 \x01(\tH\x02R\x06reason\x88\x01\x01\x42\x11\n\x0f_parent_node_idB\x10\n\x0e_error_messageB\t\n\x07_reason\"\x89\x01\n\x19WatchAllExecutionsRequest\x12.\n\x07task_id\x18\x01 \x01(\x0b\x32\x10.graph.v1.TaskIDH\x00R\x06taskId\x88\x01\x01\x12\x30\n\x05\x61\x66ter\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\x05\x61\x66terB\n\n\x08_task_id\"Y\n\x1dWatchNodesForExecutionRequest\x12\x38\n\x0c\x65xecution_id\x18\x01 \x01(\x0b\x32\x15.graph.v1.ExecutionIDR\x0b\x65xecutionId\"\xfe\x01\n\x1cUpdateExecutionStatusRequest\x12\x38\n\x0c\x65xecution_id\x18\x01 \x01(\x0b\x32\x15.graph.v1.ExecutionIDR\x0b\x65xecutionId\x12(\n\x06status\x18\x02 \x01(\x0e\x32\x10.graph.v1.StatusR\x06status\x12@\n\rerror_message\x18\x03 \x01(\x0b\x32\x16.graph.v1.ErrorMessageH\x00R\x0c\x65rrorMessage\x88\x01\x01\x12\x1b\n\x06reason\x18\x04 \x01(\tH\x01R\x06reason\x88\x01\x01\x42\x10\n\x0e_error_messageB\t\n\x07_reason\"|\n\x15GetNodeDetailsRequest\x12\x38\n\x0c\x65xecution_id\x18\x01 \x01(\x0b\x32\x15.graph.v1.ExecutionIDR\x0b\x65xecutionId\x12)\n\x07node_id\x18\x02 \x01(\x0b\x32\x10.graph.v1.NodeIDR\x06nodeId\"\x95\x04\n\x0bNodeDetails\x12\"\n\x04node\x18\x01 \x01(\x0b\x32\x0e.graph.v1.NodeR\x04node\x12(\n\x06status\x18\x02 \x01(\x0e\x32\x10.graph.v1.StatusR\x06status\x12@\n\rerror_message\x18\x03 \x01(\x0b\x32\x16.graph.v1.ErrorMessageH\x00R\x0c\x65rrorMessage\x88\x01\x01\x12\x1b\n\x06reason\x18\x04 \x01(\tH\x01R\x06reason\x88\x01\x01\x12\x39\n\nstart_time\x18\x05 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\tstartTime\x12:\n\x08\x65nd_time\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.TimestampH\x02R\x07\x65ndTime\x88\x01\x01\x12)\n\x07task_id\x18\x07 
\x01(\x0b\x32\x10.graph.v1.TaskIDR\x06taskId\x12L\n\rother_details\x18\x08 \x03(\x0b\x32\'.graph.v1.NodeDetails.OtherDetailsEntryR\x0cotherDetails\x1a?\n\x11OtherDetailsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x10\n\x0e_error_messageB\t\n\x07_reasonB\x0b\n\t_end_time*\xa1\x01\n\x06Status\x12\r\n\tUNDEFINED\x10\x00\x12\n\n\x06QUEUED\x10\x01\x12\x0b\n\x07RUNNING\x10\x02\x12\x0e\n\nSUCCEEDING\x10\x03\x12\r\n\tSUCCEEDED\x10\x04\x12\x0b\n\x07\x46\x41ILING\x10\x05\x12\n\n\x06\x46\x41ILED\x10\x06\x12\x0b\n\x07\x41\x42ORTED\x10\x07\x12\r\n\tTIMED_OUT\x10\x08\x12\r\n\tRECOVERED\x10\t\x12\x0c\n\x08\x41\x42ORTING\x10\n2\xa0\x05\n\x0cGraphService\x12V\n\x12WatchAllExecutions\x12#.graph.v1.WatchAllExecutionsRequest\x1a\x17.graph.v1.ExecutionInfo\"\x00\x30\x01\x12Y\n\x16WatchNodesForExecution\x12\'.graph.v1.WatchNodesForExecutionRequest\x1a\x12.graph.v1.NodeInfo\"\x00\x30\x01\x12\x30\n\nCreateTask\x12\x0e.graph.v1.Task\x1a\x10.graph.v1.TaskID\"\x00\x12-\n\x07GetTask\x12\x10.graph.v1.TaskID\x1a\x0e.graph.v1.Task\"\x00\x12L\n\x0f\x43reateExecution\x12 .graph.v1.CreateExecutionRequest\x1a\x15.graph.v1.ExecutionID\"\x00\x12=\n\nCreateNode\x12\x1b.graph.v1.CreateNodeRequest\x1a\x10.graph.v1.NodeID\"\x00\x12G\n\x10UpdateNodeStatus\x12!.graph.v1.UpdateNodeStatusRequest\x1a\x0e.graph.v1.Node\"\x00\x12Z\n\x15UpdateExecutionStatus\x12&.graph.v1.UpdateExecutionStatusRequest\x1a\x17.graph.v1.ExecutionInfo\"\x00\x12J\n\x0eGetNodeDetails\x12\x1f.graph.v1.GetNodeDetailsRequest\x1a\x15.graph.v1.NodeDetails\"\x00\x42\x14Z\x12graph-viewer/protob\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'proto.graph_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'Z\022graph-viewer/proto' + _NODE_PROPERTIESENTRY._options = None + 
_NODE_PROPERTIESENTRY._serialized_options = b'8\001' + _EDGE_PROPERTIESENTRY._options = None + _EDGE_PROPERTIESENTRY._serialized_options = b'8\001' + _NODEDETAILS_OTHERDETAILSENTRY._options = None + _NODEDETAILS_OTHERDETAILSENTRY._serialized_options = b'8\001' + _globals['_STATUS']._serialized_start=3374 + _globals['_STATUS']._serialized_end=3535 + _globals['_TASKID']._serialized_start=64 + _globals['_TASKID']._serialized_end=88 + _globals['_EXECUTIONID']._serialized_start=90 + _globals['_EXECUTIONID']._serialized_end=119 + _globals['_NODEID']._serialized_start=121 + _globals['_NODEID']._serialized_end=145 + _globals['_TASK']._serialized_start=148 + _globals['_TASK']._serialized_end=318 + _globals['_EXECUTIONINFO']._serialized_start=321 + _globals['_EXECUTIONINFO']._serialized_end=550 + _globals['_NODE']._serialized_start=553 + _globals['_NODE']._serialized_end=916 + _globals['_NODE_PROPERTIESENTRY']._serialized_start=855 + _globals['_NODE_PROPERTIESENTRY']._serialized_end=916 + _globals['_EDGE']._serialized_start=919 + _globals['_EDGE']._serialized_end=1100 + _globals['_EDGE_PROPERTIESENTRY']._serialized_start=855 + _globals['_EDGE_PROPERTIESENTRY']._serialized_end=916 + _globals['_NODEINFO']._serialized_start=1102 + _globals['_NODEINFO']._serialized_end=1188 + _globals['_CREATEEXECUTIONREQUEST']._serialized_start=1190 + _globals['_CREATEEXECUTIONREQUEST']._serialized_end=1315 + _globals['_CREATENODEREQUEST']._serialized_start=1318 + _globals['_CREATENODEREQUEST']._serialized_end=1544 + _globals['_WORKER']._serialized_start=1546 + _globals['_WORKER']._serialized_end=1570 + _globals['_ERRORMESSAGE']._serialized_start=1573 + _globals['_ERRORMESSAGE']._serialized_end=1846 + _globals['_ERRORMESSAGE_ERRORKIND']._serialized_start=1800 + _globals['_ERRORMESSAGE_ERRORKIND']._serialized_end=1846 + _globals['_UPDATENODESTATUSREQUEST']._serialized_start=1849 + _globals['_UPDATENODESTATUSREQUEST']._serialized_end=2221 + 
_globals['_WATCHALLEXECUTIONSREQUEST']._serialized_start=2224 + _globals['_WATCHALLEXECUTIONSREQUEST']._serialized_end=2361 + _globals['_WATCHNODESFOREXECUTIONREQUEST']._serialized_start=2363 + _globals['_WATCHNODESFOREXECUTIONREQUEST']._serialized_end=2452 + _globals['_UPDATEEXECUTIONSTATUSREQUEST']._serialized_start=2455 + _globals['_UPDATEEXECUTIONSTATUSREQUEST']._serialized_end=2709 + _globals['_GETNODEDETAILSREQUEST']._serialized_start=2711 + _globals['_GETNODEDETAILSREQUEST']._serialized_end=2835 + _globals['_NODEDETAILS']._serialized_start=2838 + _globals['_NODEDETAILS']._serialized_end=3371 + _globals['_NODEDETAILS_OTHERDETAILSENTRY']._serialized_start=3266 + _globals['_NODEDETAILS_OTHERDETAILSENTRY']._serialized_end=3329 + _globals['_GRAPHSERVICE']._serialized_start=3538 + _globals['_GRAPHSERVICE']._serialized_end=4210 +# @@protoc_insertion_point(module_scope) diff --git a/flytekit/core/store/graph_pb2.pyi b/flytekit/core/store/graph_pb2.pyi new file mode 100644 index 0000000000..6b199fae6c --- /dev/null +++ b/flytekit/core/store/graph_pb2.pyi @@ -0,0 +1,254 @@ +from google.protobuf import timestamp_pb2 as _timestamp_pb2 +from google.protobuf.internal import containers as _containers +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class Status(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = [] + UNDEFINED: _ClassVar[Status] + QUEUED: _ClassVar[Status] + RUNNING: _ClassVar[Status] + SUCCEEDING: _ClassVar[Status] + SUCCEEDED: _ClassVar[Status] + FAILING: _ClassVar[Status] + FAILED: _ClassVar[Status] + ABORTED: _ClassVar[Status] + TIMED_OUT: _ClassVar[Status] + RECOVERED: _ClassVar[Status] + ABORTING: _ClassVar[Status] +UNDEFINED: Status 
+QUEUED: Status +RUNNING: Status +SUCCEEDING: Status +SUCCEEDED: Status +FAILING: Status +FAILED: Status +ABORTED: Status +TIMED_OUT: Status +RECOVERED: Status +ABORTING: Status + +class TaskID(_message.Message): + __slots__ = ["id"] + ID_FIELD_NUMBER: _ClassVar[int] + id: str + def __init__(self, id: _Optional[str] = ...) -> None: ... + +class ExecutionID(_message.Message): + __slots__ = ["id"] + ID_FIELD_NUMBER: _ClassVar[int] + id: str + def __init__(self, id: _Optional[str] = ...) -> None: ... + +class NodeID(_message.Message): + __slots__ = ["id"] + ID_FIELD_NUMBER: _ClassVar[int] + id: str + def __init__(self, id: _Optional[str] = ...) -> None: ... + +class Task(_message.Message): + __slots__ = ["id", "name", "description", "command", "inputs", "outputs"] + ID_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + COMMAND_FIELD_NUMBER: _ClassVar[int] + INPUTS_FIELD_NUMBER: _ClassVar[int] + OUTPUTS_FIELD_NUMBER: _ClassVar[int] + id: TaskID + name: str + description: str + command: str + inputs: _containers.RepeatedScalarFieldContainer[str] + outputs: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, id: _Optional[_Union[TaskID, _Mapping]] = ..., name: _Optional[str] = ..., description: _Optional[str] = ..., command: _Optional[str] = ..., inputs: _Optional[_Iterable[str]] = ..., outputs: _Optional[_Iterable[str]] = ...) -> None: ... 
+ +class ExecutionInfo(_message.Message): + __slots__ = ["execution_id", "status", "start_time", "end_time"] + EXECUTION_ID_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + START_TIME_FIELD_NUMBER: _ClassVar[int] + END_TIME_FIELD_NUMBER: _ClassVar[int] + execution_id: ExecutionID + status: Status + start_time: _timestamp_pb2.Timestamp + end_time: _timestamp_pb2.Timestamp + def __init__(self, execution_id: _Optional[_Union[ExecutionID, _Mapping]] = ..., status: _Optional[_Union[Status, str]] = ..., start_time: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., end_time: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ...) -> None: ... + +class Node(_message.Message): + __slots__ = ["id", "name", "status", "start_time", "end_time", "parent_node_id", "properties"] + class PropertiesEntry(_message.Message): + __slots__ = ["key", "value"] + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + ID_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + START_TIME_FIELD_NUMBER: _ClassVar[int] + END_TIME_FIELD_NUMBER: _ClassVar[int] + PARENT_NODE_ID_FIELD_NUMBER: _ClassVar[int] + PROPERTIES_FIELD_NUMBER: _ClassVar[int] + id: str + name: str + status: Status + start_time: _timestamp_pb2.Timestamp + end_time: _timestamp_pb2.Timestamp + parent_node_id: str + properties: _containers.ScalarMap[str, str] + def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., status: _Optional[_Union[Status, str]] = ..., start_time: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., end_time: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., parent_node_id: _Optional[str] = ..., properties: _Optional[_Mapping[str, str]] = ...) -> None: ... 
+ +class Edge(_message.Message): + __slots__ = ["source", "target", "properties"] + class PropertiesEntry(_message.Message): + __slots__ = ["key", "value"] + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + SOURCE_FIELD_NUMBER: _ClassVar[int] + TARGET_FIELD_NUMBER: _ClassVar[int] + PROPERTIES_FIELD_NUMBER: _ClassVar[int] + source: str + target: str + properties: _containers.ScalarMap[str, str] + def __init__(self, source: _Optional[str] = ..., target: _Optional[str] = ..., properties: _Optional[_Mapping[str, str]] = ...) -> None: ... + +class NodeInfo(_message.Message): + __slots__ = ["nodes", "edges"] + NODES_FIELD_NUMBER: _ClassVar[int] + EDGES_FIELD_NUMBER: _ClassVar[int] + nodes: _containers.RepeatedCompositeFieldContainer[Node] + edges: _containers.RepeatedCompositeFieldContainer[Edge] + def __init__(self, nodes: _Optional[_Iterable[_Union[Node, _Mapping]]] = ..., edges: _Optional[_Iterable[_Union[Edge, _Mapping]]] = ...) -> None: ... + +class CreateExecutionRequest(_message.Message): + __slots__ = ["task_id", "execution_id"] + TASK_ID_FIELD_NUMBER: _ClassVar[int] + EXECUTION_ID_FIELD_NUMBER: _ClassVar[int] + task_id: TaskID + execution_id: ExecutionID + def __init__(self, task_id: _Optional[_Union[TaskID, _Mapping]] = ..., execution_id: _Optional[_Union[ExecutionID, _Mapping]] = ...) -> None: ... 
+ +class CreateNodeRequest(_message.Message): + __slots__ = ["execution_id", "node", "parent_node_id", "node_group"] + EXECUTION_ID_FIELD_NUMBER: _ClassVar[int] + NODE_FIELD_NUMBER: _ClassVar[int] + PARENT_NODE_ID_FIELD_NUMBER: _ClassVar[int] + NODE_GROUP_FIELD_NUMBER: _ClassVar[int] + execution_id: ExecutionID + node: Node + parent_node_id: str + node_group: str + def __init__(self, execution_id: _Optional[_Union[ExecutionID, _Mapping]] = ..., node: _Optional[_Union[Node, _Mapping]] = ..., parent_node_id: _Optional[str] = ..., node_group: _Optional[str] = ...) -> None: ... + +class Worker(_message.Message): + __slots__ = ["id"] + ID_FIELD_NUMBER: _ClassVar[int] + id: str + def __init__(self, id: _Optional[str] = ...) -> None: ... + +class ErrorMessage(_message.Message): + __slots__ = ["code", "message", "error_kind", "timestamp", "worker"] + class ErrorKind(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = [] + UNKNOWN: _ClassVar[ErrorMessage.ErrorKind] + SYSTEM: _ClassVar[ErrorMessage.ErrorKind] + USER: _ClassVar[ErrorMessage.ErrorKind] + UNKNOWN: ErrorMessage.ErrorKind + SYSTEM: ErrorMessage.ErrorKind + USER: ErrorMessage.ErrorKind + CODE_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + ERROR_KIND_FIELD_NUMBER: _ClassVar[int] + TIMESTAMP_FIELD_NUMBER: _ClassVar[int] + WORKER_FIELD_NUMBER: _ClassVar[int] + code: str + message: str + error_kind: ErrorMessage.ErrorKind + timestamp: _timestamp_pb2.Timestamp + worker: Worker + def __init__(self, code: _Optional[str] = ..., message: _Optional[str] = ..., error_kind: _Optional[_Union[ErrorMessage.ErrorKind, str]] = ..., timestamp: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., worker: _Optional[_Union[Worker, _Mapping]] = ...) -> None: ... 
+ +class UpdateNodeStatusRequest(_message.Message): + __slots__ = ["execution_id", "node_id", "status", "parent_node_id", "error_message", "reason"] + EXECUTION_ID_FIELD_NUMBER: _ClassVar[int] + NODE_ID_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + PARENT_NODE_ID_FIELD_NUMBER: _ClassVar[int] + ERROR_MESSAGE_FIELD_NUMBER: _ClassVar[int] + REASON_FIELD_NUMBER: _ClassVar[int] + execution_id: ExecutionID + node_id: NodeID + status: Status + parent_node_id: NodeID + error_message: ErrorMessage + reason: str + def __init__(self, execution_id: _Optional[_Union[ExecutionID, _Mapping]] = ..., node_id: _Optional[_Union[NodeID, _Mapping]] = ..., status: _Optional[_Union[Status, str]] = ..., parent_node_id: _Optional[_Union[NodeID, _Mapping]] = ..., error_message: _Optional[_Union[ErrorMessage, _Mapping]] = ..., reason: _Optional[str] = ...) -> None: ... + +class WatchAllExecutionsRequest(_message.Message): + __slots__ = ["task_id", "after"] + TASK_ID_FIELD_NUMBER: _ClassVar[int] + AFTER_FIELD_NUMBER: _ClassVar[int] + task_id: TaskID + after: _timestamp_pb2.Timestamp + def __init__(self, task_id: _Optional[_Union[TaskID, _Mapping]] = ..., after: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ...) -> None: ... + +class WatchNodesForExecutionRequest(_message.Message): + __slots__ = ["execution_id"] + EXECUTION_ID_FIELD_NUMBER: _ClassVar[int] + execution_id: ExecutionID + def __init__(self, execution_id: _Optional[_Union[ExecutionID, _Mapping]] = ...) -> None: ... 
+ +class UpdateExecutionStatusRequest(_message.Message): + __slots__ = ["execution_id", "status", "error_message", "reason"] + EXECUTION_ID_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + ERROR_MESSAGE_FIELD_NUMBER: _ClassVar[int] + REASON_FIELD_NUMBER: _ClassVar[int] + execution_id: ExecutionID + status: Status + error_message: ErrorMessage + reason: str + def __init__(self, execution_id: _Optional[_Union[ExecutionID, _Mapping]] = ..., status: _Optional[_Union[Status, str]] = ..., error_message: _Optional[_Union[ErrorMessage, _Mapping]] = ..., reason: _Optional[str] = ...) -> None: ... + +class GetNodeDetailsRequest(_message.Message): + __slots__ = ["execution_id", "node_id"] + EXECUTION_ID_FIELD_NUMBER: _ClassVar[int] + NODE_ID_FIELD_NUMBER: _ClassVar[int] + execution_id: ExecutionID + node_id: NodeID + def __init__(self, execution_id: _Optional[_Union[ExecutionID, _Mapping]] = ..., node_id: _Optional[_Union[NodeID, _Mapping]] = ...) -> None: ... + +class NodeDetails(_message.Message): + __slots__ = ["node", "status", "error_message", "reason", "start_time", "end_time", "task_id", "other_details"] + class OtherDetailsEntry(_message.Message): + __slots__ = ["key", "value"] + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... 
+ NODE_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + ERROR_MESSAGE_FIELD_NUMBER: _ClassVar[int] + REASON_FIELD_NUMBER: _ClassVar[int] + START_TIME_FIELD_NUMBER: _ClassVar[int] + END_TIME_FIELD_NUMBER: _ClassVar[int] + TASK_ID_FIELD_NUMBER: _ClassVar[int] + OTHER_DETAILS_FIELD_NUMBER: _ClassVar[int] + node: Node + status: Status + error_message: ErrorMessage + reason: str + start_time: _timestamp_pb2.Timestamp + end_time: _timestamp_pb2.Timestamp + task_id: TaskID + other_details: _containers.ScalarMap[str, str] + def __init__(self, node: _Optional[_Union[Node, _Mapping]] = ..., status: _Optional[_Union[Status, str]] = ..., error_message: _Optional[_Union[ErrorMessage, _Mapping]] = ..., reason: _Optional[str] = ..., start_time: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., end_time: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., task_id: _Optional[_Union[TaskID, _Mapping]] = ..., other_details: _Optional[_Mapping[str, str]] = ...) -> None: ... diff --git a/flytekit/core/store/graph_pb2_grpc.py b/flytekit/core/store/graph_pb2_grpc.py new file mode 100644 index 0000000000..dc9aaecdb6 --- /dev/null +++ b/flytekit/core/store/graph_pb2_grpc.py @@ -0,0 +1,330 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from flytekit.core.store import graph_pb2 as proto_dot_graph__pb2 + + +class GraphServiceStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.WatchAllExecutions = channel.unary_stream( + '/graph.v1.GraphService/WatchAllExecutions', + request_serializer=proto_dot_graph__pb2.WatchAllExecutionsRequest.SerializeToString, + response_deserializer=proto_dot_graph__pb2.ExecutionInfo.FromString, + ) + self.WatchNodesForExecution = channel.unary_stream( + '/graph.v1.GraphService/WatchNodesForExecution', + request_serializer=proto_dot_graph__pb2.WatchNodesForExecutionRequest.SerializeToString, + response_deserializer=proto_dot_graph__pb2.NodeInfo.FromString, + ) + self.CreateTask = channel.unary_unary( + '/graph.v1.GraphService/CreateTask', + request_serializer=proto_dot_graph__pb2.Task.SerializeToString, + response_deserializer=proto_dot_graph__pb2.TaskID.FromString, + ) + self.GetTask = channel.unary_unary( + '/graph.v1.GraphService/GetTask', + request_serializer=proto_dot_graph__pb2.TaskID.SerializeToString, + response_deserializer=proto_dot_graph__pb2.Task.FromString, + ) + self.CreateExecution = channel.unary_unary( + '/graph.v1.GraphService/CreateExecution', + request_serializer=proto_dot_graph__pb2.CreateExecutionRequest.SerializeToString, + response_deserializer=proto_dot_graph__pb2.ExecutionID.FromString, + ) + self.CreateNode = channel.unary_unary( + '/graph.v1.GraphService/CreateNode', + request_serializer=proto_dot_graph__pb2.CreateNodeRequest.SerializeToString, + response_deserializer=proto_dot_graph__pb2.NodeID.FromString, + ) + self.UpdateNodeStatus = channel.unary_unary( + '/graph.v1.GraphService/UpdateNodeStatus', + request_serializer=proto_dot_graph__pb2.UpdateNodeStatusRequest.SerializeToString, + response_deserializer=proto_dot_graph__pb2.Node.FromString, + ) + self.UpdateExecutionStatus = channel.unary_unary( + '/graph.v1.GraphService/UpdateExecutionStatus', + request_serializer=proto_dot_graph__pb2.UpdateExecutionStatusRequest.SerializeToString, + response_deserializer=proto_dot_graph__pb2.ExecutionInfo.FromString, + ) + self.GetNodeDetails = channel.unary_unary( + 
'/graph.v1.GraphService/GetNodeDetails', + request_serializer=proto_dot_graph__pb2.GetNodeDetailsRequest.SerializeToString, + response_deserializer=proto_dot_graph__pb2.NodeDetails.FromString, + ) + + +class GraphServiceServicer(object): + """Missing associated documentation comment in .proto file.""" + + def WatchAllExecutions(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def WatchNodesForExecution(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CreateTask(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetTask(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CreateExecution(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CreateNode(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def UpdateNodeStatus(self, request, context): + """Missing associated documentation comment 
in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def UpdateExecutionStatus(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetNodeDetails(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_GraphServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'WatchAllExecutions': grpc.unary_stream_rpc_method_handler( + servicer.WatchAllExecutions, + request_deserializer=proto_dot_graph__pb2.WatchAllExecutionsRequest.FromString, + response_serializer=proto_dot_graph__pb2.ExecutionInfo.SerializeToString, + ), + 'WatchNodesForExecution': grpc.unary_stream_rpc_method_handler( + servicer.WatchNodesForExecution, + request_deserializer=proto_dot_graph__pb2.WatchNodesForExecutionRequest.FromString, + response_serializer=proto_dot_graph__pb2.NodeInfo.SerializeToString, + ), + 'CreateTask': grpc.unary_unary_rpc_method_handler( + servicer.CreateTask, + request_deserializer=proto_dot_graph__pb2.Task.FromString, + response_serializer=proto_dot_graph__pb2.TaskID.SerializeToString, + ), + 'GetTask': grpc.unary_unary_rpc_method_handler( + servicer.GetTask, + request_deserializer=proto_dot_graph__pb2.TaskID.FromString, + response_serializer=proto_dot_graph__pb2.Task.SerializeToString, + ), + 'CreateExecution': grpc.unary_unary_rpc_method_handler( + servicer.CreateExecution, + request_deserializer=proto_dot_graph__pb2.CreateExecutionRequest.FromString, + response_serializer=proto_dot_graph__pb2.ExecutionID.SerializeToString, + ), + 'CreateNode': 
grpc.unary_unary_rpc_method_handler( + servicer.CreateNode, + request_deserializer=proto_dot_graph__pb2.CreateNodeRequest.FromString, + response_serializer=proto_dot_graph__pb2.NodeID.SerializeToString, + ), + 'UpdateNodeStatus': grpc.unary_unary_rpc_method_handler( + servicer.UpdateNodeStatus, + request_deserializer=proto_dot_graph__pb2.UpdateNodeStatusRequest.FromString, + response_serializer=proto_dot_graph__pb2.Node.SerializeToString, + ), + 'UpdateExecutionStatus': grpc.unary_unary_rpc_method_handler( + servicer.UpdateExecutionStatus, + request_deserializer=proto_dot_graph__pb2.UpdateExecutionStatusRequest.FromString, + response_serializer=proto_dot_graph__pb2.ExecutionInfo.SerializeToString, + ), + 'GetNodeDetails': grpc.unary_unary_rpc_method_handler( + servicer.GetNodeDetails, + request_deserializer=proto_dot_graph__pb2.GetNodeDetailsRequest.FromString, + response_serializer=proto_dot_graph__pb2.NodeDetails.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'graph.v1.GraphService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. 
+class GraphService(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def WatchAllExecutions(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/graph.v1.GraphService/WatchAllExecutions', + proto_dot_graph__pb2.WatchAllExecutionsRequest.SerializeToString, + proto_dot_graph__pb2.ExecutionInfo.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def WatchNodesForExecution(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/graph.v1.GraphService/WatchNodesForExecution', + proto_dot_graph__pb2.WatchNodesForExecutionRequest.SerializeToString, + proto_dot_graph__pb2.NodeInfo.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CreateTask(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/graph.v1.GraphService/CreateTask', + proto_dot_graph__pb2.Task.SerializeToString, + proto_dot_graph__pb2.TaskID.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetTask(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, 
'/graph.v1.GraphService/GetTask', + proto_dot_graph__pb2.TaskID.SerializeToString, + proto_dot_graph__pb2.Task.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CreateExecution(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/graph.v1.GraphService/CreateExecution', + proto_dot_graph__pb2.CreateExecutionRequest.SerializeToString, + proto_dot_graph__pb2.ExecutionID.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CreateNode(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/graph.v1.GraphService/CreateNode', + proto_dot_graph__pb2.CreateNodeRequest.SerializeToString, + proto_dot_graph__pb2.NodeID.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def UpdateNodeStatus(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/graph.v1.GraphService/UpdateNodeStatus', + proto_dot_graph__pb2.UpdateNodeStatusRequest.SerializeToString, + proto_dot_graph__pb2.Node.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def UpdateExecutionStatus(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + 
compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/graph.v1.GraphService/UpdateExecutionStatus', + proto_dot_graph__pb2.UpdateExecutionStatusRequest.SerializeToString, + proto_dot_graph__pb2.ExecutionInfo.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetNodeDetails(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/graph.v1.GraphService/GetNodeDetails', + proto_dot_graph__pb2.GetNodeDetailsRequest.SerializeToString, + proto_dot_graph__pb2.NodeDetails.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/flytekit/core/worker_queue.py b/flytekit/core/worker_queue.py index 2c0b889a85..2b2f85eb60 100644 --- a/flytekit/core/worker_queue.py +++ b/flytekit/core/worker_queue.py @@ -4,6 +4,7 @@ import atexit import hashlib import re +import shelve import threading import time import typing @@ -12,9 +13,11 @@ from enum import Enum from flytekit.configuration import ImageConfig, SerializationSettings +from flytekit.core import history_store from flytekit.core.base_task import PythonTask from flytekit.core.constants import EAGER_ROOT_ENV_NAME, EAGER_TAG_KEY, EAGER_TAG_ROOT_KEY from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.history_store import ItemStatus, Data from flytekit.core.launch_plan import LaunchPlan from flytekit.core.options import Options from flytekit.core.reference_entity import ReferenceEntity @@ -72,13 +75,6 @@ """ -class ItemStatus(Enum): - PENDING = "Pending" - RUNNING = "Running" - SUCCESS = "Success" - FAILED = "Failed" - - @dataclass class Update: # The item to update @@ -89,6 
+85,7 @@ class Update: status: typing.Optional[ItemStatus] = None wf_exec: typing.Optional[FlyteWorkflowExecution] = None error: typing.Optional[BaseException] = None + outputs: typing.Optional[typing.Any] = None @dataclass(unsafe_hash=True) @@ -182,6 +179,7 @@ def __init__( self.remote = remote self.ss = ss self.exec_prefix = exec_prefix + self._store = history_store.Store() self.entries_lock = threading.Lock() # Import this to ensure context is loaded... python is reloading this module because its in a different thread @@ -284,6 +282,7 @@ def _apply_updates(self, update_items: typing.Dict[uuid.UUID, Update]) -> None: if update.wf_exec is None: raise AssertionError(f"update's wf_exec missing for {item.entity.name}") item.result = update.wf_exec.outputs.as_python_native(item.python_interface) + self._store.update(entity_name, Data(input_kwargs=item.input_kwargs, result=item.result, status=item.status), history_store.Status.SUCCEEDED) elif update.status == ItemStatus.FAILED: # If update object already has an error, then use that, otherwise look for one in the # execution closure. @@ -302,6 +301,7 @@ def _apply_updates(self, update_items: typing.Dict[uuid.UUID, Update]) -> None: f" {update.wf_exec.closure.error.code}. 
Message: {update.wf_exec.closure.error.message}" ) item.error = exc + self._store.update(entity_name, Data(item.input_kwargs, error=item.error, status=item.status), history_store.Status.FAILED, str(item.error)) # otherwise it's still pending or running @@ -379,6 +379,7 @@ def launch_execution(self, wi: WorkItem, idx: int) -> FlyteWorkflowExecution: options = Options(labels=l) exec_name = self.get_execution_name(wi.entity, idx, wi.input_kwargs) logger.info(f"Generated execution name {exec_name} for {idx}th call of {wi.entity.name}") + from flytekit.remote.remote_callable import RemoteEntity if isinstance(wi.entity, RemoteEntity): @@ -409,12 +410,21 @@ async def add(self, entity: RunnableEntity, input_kwargs: dict[str, typing.Any]) """ Add an entity along with the requested inputs to be submitted to Admin for running and return a future """ + print(f"Adding {entity.name} with {input_kwargs}") # need to also check to see if the entity has already been registered, and if not, register it. i = WorkItem(entity=entity, input_kwargs=input_kwargs) with self.entries_lock: if entity.name not in self.entries: self.entries[entity.name] = [] + if not self._store.has(entity.name, input_kwargs): + self._store.add(entity.name, Data(input_kwargs), node_group=None) + else: + v = self._store.get(entity.name, input_kwargs) + if v.status == ItemStatus.SUCCESS: + return v.result + elif v.status == ItemStatus.FAILED: + raise v.error self.entries[entity.name].append(i) # wait for it to finish one way or another From 8f3431273d6dbe9d544e3395634d01461854e422 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Wed, 12 Mar 2025 22:39:43 -0700 Subject: [PATCH 3/3] working --- flytekit/core/controller.py | 46 +++++++++++++-------------- flytekit/core/eager_controller.py | 11 ++++++- flytekit/core/python_function_task.py | 24 +++++++++++--- 3 files changed, 52 insertions(+), 29 deletions(-) diff --git a/flytekit/core/controller.py b/flytekit/core/controller.py index 9dd97ecd7d..9fad5ff088 100644 --- 
a/flytekit/core/controller.py +++ b/flytekit/core/controller.py @@ -7,7 +7,7 @@ import time # Setup basic logging -logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) # Define the resource protocol @@ -35,6 +35,7 @@ async def launch(self) -> bool: # Type variable for generic resource R = TypeVar('R', bound=ResourceProtocol) + class Informer(Generic[R]): """Base informer with either polling or watch mode""" @@ -141,12 +142,16 @@ def for_sync_resource(cls, max_concurrent_launches: int = 2) -> 'Controller[R]': async def add_resource(self, resource: R): """Public API to add a resource without waiting for completion""" + print(f"{threading.current_thread().name} Adding resource {resource.name}") await self.informer.add_resource(resource) await self.shared_queue.put(resource.name) + print(f"{threading.current_thread().name} Done adding resource {resource.name}") async def submit_resource(self, resource: R) -> R: """Submit a resource and await its completion, returning the final state""" + print(f"{threading.current_thread().name} Submitting resource {resource.name}") async with self.informer._lock: + # TODO change informer to have get and other methods. 
we should never access resources directly if resource.name in self.informer.resources: raise ValueError(f"Resource {resource.name} already exists") if resource.name in self.completion_events: @@ -156,16 +161,20 @@ async def submit_resource(self, resource: R) -> R: self.completion_events[resource.name] = Event() await self.add_resource(resource) + print(f"{threading.current_thread().name} Waiting for completion of {resource.name}") # Wait for completion await self.completion_events[resource.name].wait() + print(f"{threading.current_thread().name} Resource {resource.name} completed") # Get final resource state and clean up - async with self.informer._lock: - final_resource = self.informer.resources.get(resource.name) - del self.completion_events[resource.name] - if final_resource: - await self.informer.remove_resource(resource.name) - return final_resource + final_resource = self.informer.resources.get(resource.name) + if final_resource is None: + raise ValueError(f"Resource {resource.name} not found") + del self.completion_events[resource.name] + print(f"{threading.current_thread().name} Removed completion event for {resource.name}") + await self.informer.remove_resource(resource.name) + print(f"{threading.current_thread().name} Removed resource {resource.name}, final={final_resource}") + return final_resource async def launch_resource(self, resource: R): """Attempt to launch a resource until successful""" @@ -194,7 +203,6 @@ async def process_resource(self, resource: R): elif resource.is_terminal(): if resource.name in self.completion_events: self.completion_events[resource.name].set() # Signal completion - await self.informer.remove_resource(resource.name) async def get_resource_status(self) -> Tuple[Set[str], Set[str]]: """Return current set of launched (running) and waiting-to-be-launched resources""" @@ -206,25 +214,28 @@ async def get_resource_status(self) -> Tuple[Set[str], Set[str]]: async def _log_resource_stats(self): """Periodically log resource stats if 
debug is enabled""" while self.running: - if logger.isEnabledFor(logging.DEBUG): - launched, waiting = await self.get_resource_status() - logger.debug(f"Resource stats: Launched={launched}, Waiting={waiting}") + launched, waiting = await self.get_resource_status() + logger.debug(f"Resource stats: Launched={launched}, Waiting={waiting}") await asyncio.sleep(2.0) async def run(self): """Run loop with resource status logging""" + print(f"{threading.current_thread().name} Controller starting") if self.running: logger.warning("Controller already running") return - print("Controller running") + print(f"{threading.current_thread().name} Controller running") self.running = True await self.informer.start() + print(f"{threading.current_thread().name} Informer started") asyncio.create_task(self._log_resource_stats()) while self.running: try: + print(f"{threading.current_thread().name} Waiting for resource") item = await self.shared_queue.get() + print(f"{threading.current_thread().name} Got resource {item}") if isinstance(item, str): logger.info(f"Received resource name: {item}") @@ -240,17 +251,6 @@ async def run(self): logger.error(f"Error in controller loop: {e}") await asyncio.sleep(1.0) - def _run(self, loop): - """Run the controller in the background""" - loop.run_until_complete(self.run()) - - def start(self): - """Synchronously start the controller in the background.""" - print("Starting controller") - thread = threading.Thread(target=self._run, kwargs={"loop" :asyncio.new_event_loop()}, daemon=True) - thread.start() - - async def stop(self): """Stop the controller""" self.running = False diff --git a/flytekit/core/eager_controller.py b/flytekit/core/eager_controller.py index b03ebe6426..ba54b6db61 100644 --- a/flytekit/core/eager_controller.py +++ b/flytekit/core/eager_controller.py @@ -38,7 +38,7 @@ async def _run_entity(self): try: # Assuming RunnableEntity.__call__ is synchronous # If it's async, remove the await asyncio.to_thread() - result = await 
asyncio.to_thread(self.entity.__call__, **self.input_kwargs) + result = await self.entity._task_function(**self.input_kwargs) self.result = result self.error = None except Exception as e: @@ -56,16 +56,21 @@ async def launch(self) -> bool: """ Launch the RunnableEntity in the background """ + print(f"Launching {self.key}") if self._task is None: self._task = asyncio.create_task(self._run_entity()) + print(f"Launched {self.key} with task {self._task}") return True def is_terminal(self) -> bool: """ Returns True if the node has completed (successfully or with error) """ + print(f"Checking terminal status for {self.key}") if self._task is None: + print(f"Task not found for {self.key}") return False + print(f"Task status for {self.key}: {self._task.done()}") return self._task.done() @cached_property @@ -79,6 +84,10 @@ def key(self) -> str: exec_name = f"{hex}" return exec_name + def is_started(self) -> bool: + """Check if resource has been started.""" + return self._task is not None + class EagerController(Controller[LocalNode]): diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index d5de395a1a..5e7fd3b36f 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -470,6 +470,8 @@ def __init__( else: kwargs["metadata"] = TaskMetadata(is_eager=True) + self._controller = None + super().__init__( task_config, task_function, @@ -518,6 +520,18 @@ async def async_execute(self, *args, **kwargs) -> Any: # Args is present because the asyn helper function passes it, but everything should be in kwargs by this point assert len(args) == 1 ctx = FlyteContextManager.current_context() + if self._controller is None: + tag = ctx.user_space_params.execution_id.name + root_tag = os.environ.get(EAGER_ROOT_ENV_NAME, tag) + + # Prefix is a combination of the name of this eager workflow, and the current execution id. 
+ prefix = self.name.split(".")[-1][:8] + prefix = f"e-{prefix}-{tag[:5]}" + prefix = _dnsify(prefix) + controller = EagerController(prefix) + self._controller = asyncio.create_task(controller.run()) + ctx = ctx.with_worker_queue(controller).build() + FlyteContextManager.push_context(ctx) is_local_execution = False # cast(ExecutionState, ctx.execution_state).is_local_execution() if not is_local_execution: # a real execution @@ -586,9 +600,9 @@ def execute(self, **kwargs) -> Any: # signal.signal(signal.SIGINT, handler) # signal.signal(signal.SIGTERM, handler) # builder = ctx.with_worker_queue(c) - c = EagerController(prefix) - builder = ctx.with_worker_queue(c) - c.start() + # c = EagerController(prefix) + # builder = ctx.with_worker_queue(c) + # c.start() print("Added Controller!") else: raise AssertionError("Worker queue should not be already present in the context for eager execution") @@ -616,8 +630,8 @@ async def run_with_backend(self, **kwargs): logger.error(f"Leaving eager execution because of {ee}") base_error = ee - html = cast(Controller, ctx.worker_queue).render_html() - Deck("Eager Executions", html) + # html = cast(Controller, ctx.worker_queue).render_html() + # Deck("Eager Executions", html) if base_error: print(base_error.with_traceback(None))