-
Notifications
You must be signed in to change notification settings - Fork 0
WIP: add replacement for dask #1111
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -26,12 +26,15 @@ | |||||
| overload, | ||||||
| ) | ||||||
|
|
||||||
| from typing_extensions import TypeVar as TypeVarExtension | ||||||
|
|
||||||
| if TYPE_CHECKING: | ||||||
| from dimos.core.introspection.module import ModuleInfo | ||||||
|
|
||||||
| from typing import TypeVar | ||||||
|
|
||||||
| from dask.distributed import Actor, get_worker | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| from reactivex.disposable import CompositeDisposable | ||||||
| from typing_extensions import TypeVar | ||||||
|
|
||||||
| from dimos.core import colors | ||||||
| from dimos.core.core import T, rpc | ||||||
|
|
@@ -82,7 +85,7 @@ class ModuleConfig: | |||||
| frame_id: str | None = None | ||||||
|
|
||||||
|
|
||||||
| ModuleConfigT = TypeVar("ModuleConfigT", bound=ModuleConfig, default=ModuleConfig) | ||||||
| ModuleConfigT = TypeVarExtension("ModuleConfigT", bound=ModuleConfig, default=ModuleConfig) | ||||||
|
|
||||||
|
|
||||||
| class ModuleBase(Configurable[ModuleConfigT], SkillContainer, Resource): | ||||||
|
|
@@ -355,7 +358,7 @@ def get_rpc_calls(self, *methods: str) -> RpcCall | tuple[RpcCall, ...]: # type | |||||
| return result[0] if len(result) == 1 else result | ||||||
|
|
||||||
|
|
||||||
| class DaskModule(ModuleBase[ModuleConfigT]): | ||||||
| class Module(ModuleBase[ModuleConfigT]): | ||||||
| ref: Actor | ||||||
| worker: int | ||||||
|
|
||||||
|
|
@@ -454,5 +457,4 @@ def dask_register_subscriber(self, output_name: str, subscriber: RemoteIn[T]) -> | |||||
| getattr(self, output_name).transport.dask_register_subscriber(subscriber) | ||||||
|
|
||||||
|
|
||||||
| # global setting | ||||||
| Module = DaskModule | ||||||
| ModuleT = TypeVar("ModuleT", bound="Module") | ||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -12,22 +12,26 @@ | |||||
| # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | ||||||
|
|
||||||
| from concurrent.futures import ThreadPoolExecutor | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| import time | ||||||
| from typing import TypeVar | ||||||
|
|
||||||
| from traitlets import Any | ||||||
|
|
||||||
| from dimos import core | ||||||
| from dimos.core import DimosCluster, Module | ||||||
| from dimos.core import DimosCluster | ||||||
| from dimos.core.global_config import GlobalConfig | ||||||
| from dimos.core.module import Module, ModuleT | ||||||
| from dimos.core.resource import Resource | ||||||
|
|
||||||
| T = TypeVar("T", bound="Module") | ||||||
| from dimos.core.rpc_client import RPCClient | ||||||
| from dimos.core.worker_manager import WorkerManager | ||||||
|
|
||||||
|
|
||||||
| class ModuleCoordinator(Resource): | ||||||
| _client: DimosCluster | None = None | ||||||
| _client: DimosCluster | WorkerManager | None = None | ||||||
| _global_config: GlobalConfig | ||||||
| _n: int | None = None | ||||||
| _memory_limit: str = "auto" | ||||||
| _deployed_modules: dict[type[Module], Module] = {} | ||||||
| _deployed_modules: dict[type[Module], RPCClient] = {} | ||||||
|
|
||||||
| def __init__( | ||||||
| self, | ||||||
|
|
@@ -37,29 +41,55 @@ def __init__( | |||||
| cfg = global_config or GlobalConfig() | ||||||
| self._n = n if n is not None else cfg.n_dask_workers | ||||||
| self._memory_limit = cfg.memory_limit | ||||||
| self._global_config = cfg | ||||||
|
|
||||||
| def start(self) -> None: | ||||||
| self._client = core.start(self._n, self._memory_limit) | ||||||
| if self._global_config.dask: | ||||||
| self._client = core.start(self._n, self._memory_limit) | ||||||
| else: | ||||||
| self._client = WorkerManager() | ||||||
|
|
||||||
| def stop(self) -> None: | ||||||
| for module in reversed(self._deployed_modules.values()): | ||||||
| module.stop() | ||||||
|
|
||||||
| self._client.close_all() # type: ignore[union-attr] | ||||||
|
|
||||||
| def deploy(self, module_class: type[T], *args, **kwargs) -> T: # type: ignore[no-untyped-def] | ||||||
| def deploy(self, module_class: type[ModuleT], *args: Any, **kwargs: Any) -> RPCClient: | ||||||
| if not self._client: | ||||||
| raise ValueError("Not started") | ||||||
|
|
||||||
| module = self._client.deploy(module_class, *args, **kwargs) # type: ignore[attr-defined] | ||||||
| module = self._client.deploy(module_class, *args, **kwargs) # type: ignore[union-attr] | ||||||
| self._deployed_modules[module_class] = module | ||||||
| return module # type: ignore[no-any-return] | ||||||
| return module | ||||||
|
|
||||||
| def deploy_parallel( | ||||||
| self, module_specs: list[tuple[type[ModuleT], tuple[Any, ...], dict[str, Any]]] | ||||||
| ) -> list[RPCClient]: | ||||||
| if not self._client: | ||||||
| raise ValueError("Not started") | ||||||
|
|
||||||
| if isinstance(self._client, WorkerManager): | ||||||
| modules = self._client.deploy_parallel(module_specs) | ||||||
| for (module_class, _, _), module in zip(module_specs, modules, strict=True): | ||||||
| self._deployed_modules[module_class] = module | ||||||
| return modules # type: ignore[return-value] | ||||||
| else: | ||||||
| return [ | ||||||
| self.deploy(module_class, *args, **kwargs) | ||||||
| for module_class, args, kwargs in module_specs | ||||||
| ] | ||||||
|
|
||||||
| def start_all_modules(self) -> None: | ||||||
| for module in self._deployed_modules.values(): | ||||||
| module.start() | ||||||
| modules = list(self._deployed_modules.values()) | ||||||
| if isinstance(self._client, WorkerManager): | ||||||
| with ThreadPoolExecutor(max_workers=len(modules)) as executor: | ||||||
| list(executor.map(lambda m: m.start(), modules)) | ||||||
| else: | ||||||
| for module in modules: | ||||||
| module.start() | ||||||
|
|
||||||
| def get_instance(self, module: type[T]) -> T | None: | ||||||
| def get_instance(self, module: type[ModuleT]) -> ModuleT | None: | ||||||
| return self._deployed_modules.get(module) # type: ignore[return-value] | ||||||
|
|
||||||
| def loop(self) -> None: | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,153 @@ | ||
| # Copyright 2026 Dimensional Inc. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import pytest | ||
|
|
||
| from dimos.core import In, Module, Out, rpc | ||
| from dimos.core.worker_manager import WorkerManager | ||
| from dimos.msgs.geometry_msgs import Vector3 | ||
|
|
||
|
|
||
| class SimpleModule(Module): | ||
| output: Out[Vector3] | ||
| input: In[Vector3] | ||
|
|
||
| counter: int = 0 | ||
|
|
||
| @rpc | ||
| def start(self) -> None: | ||
| pass | ||
|
|
||
| @rpc | ||
| def increment(self) -> int: | ||
| self.counter += 1 | ||
| return self.counter | ||
|
|
||
| @rpc | ||
| def get_counter(self) -> int: | ||
| return self.counter | ||
|
|
||
|
|
||
| class AnotherModule(Module): | ||
| value: int = 100 | ||
|
|
||
| @rpc | ||
| def start(self) -> None: | ||
| pass | ||
|
|
||
| @rpc | ||
| def add(self, n: int) -> int: | ||
| self.value += n | ||
| return self.value | ||
|
|
||
| @rpc | ||
| def get_value(self) -> int: | ||
| return self.value | ||
|
|
||
|
|
||
| class ThirdModule(Module): | ||
| multiplier: int = 1 | ||
|
|
||
| @rpc | ||
| def start(self) -> None: | ||
| pass | ||
|
|
||
| @rpc | ||
| def multiply(self, n: int) -> int: | ||
| self.multiplier *= n | ||
| return self.multiplier | ||
|
|
||
| @rpc | ||
| def get_multiplier(self) -> int: | ||
| return self.multiplier | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def worker_manager(): | ||
| manager = WorkerManager() | ||
| try: | ||
| yield manager | ||
| finally: | ||
| manager.close_all() | ||
|
|
||
|
|
||
| @pytest.mark.integration | ||
| def test_worker_manager_basic(worker_manager): | ||
| module = worker_manager.deploy(SimpleModule) | ||
| module.start() | ||
|
|
||
| result = module.increment() | ||
| assert result == 1 | ||
|
|
||
| result = module.increment() | ||
| assert result == 2 | ||
|
|
||
| result = module.get_counter() | ||
| assert result == 2 | ||
|
|
||
| module.stop() | ||
|
|
||
|
|
||
| @pytest.mark.integration | ||
| def test_worker_manager_multiple_different_modules(worker_manager): | ||
| module1 = worker_manager.deploy(SimpleModule) | ||
| module2 = worker_manager.deploy(AnotherModule) | ||
|
|
||
| module1.start() | ||
| module2.start() | ||
|
|
||
| # Each module has its own state | ||
| module1.increment() | ||
| module1.increment() | ||
| module2.add(10) | ||
|
|
||
| assert module1.get_counter() == 2 | ||
| assert module2.get_value() == 110 | ||
|
|
||
| # Stop modules to clean up threads | ||
| module1.stop() | ||
| module2.stop() | ||
|
|
||
|
|
||
| @pytest.mark.integration | ||
| def test_worker_manager_parallel_deployment(worker_manager): | ||
| modules = worker_manager.deploy_parallel( | ||
| [ | ||
| (SimpleModule, (), {}), | ||
| (AnotherModule, (), {}), | ||
| (ThirdModule, (), {}), | ||
| ] | ||
| ) | ||
|
|
||
| assert len(modules) == 3 | ||
| module1, module2, module3 = modules | ||
|
|
||
| # Start all modules | ||
| module1.start() | ||
| module2.start() | ||
| module3.start() | ||
|
|
||
| # Each module has its own state | ||
| module1.increment() | ||
| module2.add(50) | ||
| module3.multiply(5) | ||
|
|
||
| assert module1.get_counter() == 1 | ||
| assert module2.get_value() == 150 | ||
| assert module3.get_multiplier() == 5 | ||
|
|
||
| # Stop modules | ||
| module1.stop() | ||
| module2.stop() | ||
| module3.stop() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`TypeVar` is imported twice - once from `typing_extensions` as `TypeVarExtension` (line 29), and again from `typing` (line 34). The second import shadows the first one.