Skip to content

Commit 145d125

Browse files
authored
Support external pipelines registration (#3691)
1 parent 109cab4 commit 145d125

1 file changed

Lines changed: 22 additions & 10 deletions

File tree

  • src/dstack/_internal/server/background/pipeline_tasks

src/dstack/_internal/server/background/pipeline_tasks/__init__.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424

2525
class PipelineManager:
2626
def __init__(self) -> None:
27-
self._pipelines: list[Pipeline] = [
27+
self._pipelines: list[Pipeline] = []
28+
self._hinter = PipelineHinter()
29+
for builtin_pipeline in [
2830
ComputeGroupPipeline(),
2931
FleetPipeline(),
3032
GatewayPipeline(),
@@ -35,8 +37,12 @@ def __init__(self) -> None:
3537
PlacementGroupPipeline(),
3638
RunPipeline(),
3739
VolumePipeline(),
38-
]
39-
self._hinter = PipelineHinter(self._pipelines)
40+
]:
41+
self.register_pipeline(builtin_pipeline)
42+
43+
def register_pipeline(self, pipeline: Pipeline):
44+
self._pipelines.append(pipeline)
45+
self._hinter.register_pipeline(pipeline)
4046

4147
def start(self):
4248
for pipeline in self._pipelines:
@@ -64,11 +70,11 @@ def hinter(self):
6470

6571

6672
class PipelineHinter:
67-
def __init__(self, pipelines: list[Pipeline]) -> None:
68-
self._pipelines = pipelines
73+
def __init__(self) -> None:
6974
self._hint_fetch_map: dict[str, list[Pipeline]] = {}
70-
for pipeline in self._pipelines:
71-
self._hint_fetch_map.setdefault(pipeline.hint_fetch_model_name, []).append(pipeline)
75+
76+
def register_pipeline(self, pipeline: Pipeline):
77+
self._hint_fetch_map.setdefault(pipeline.hint_fetch_model_name, []).append(pipeline)
7278

7379
def hint_fetch(self, model_name: str):
7480
pipelines = self._hint_fetch_map.get(model_name)
@@ -79,11 +85,17 @@ def hint_fetch(self, model_name: str):
7985
pipeline.hint_fetch()
8086

8187

88+
_pipeline_manager = PipelineManager()
89+
90+
91+
def get_pipeline_manager() -> PipelineManager:
92+
return _pipeline_manager
93+
94+
8295
def start_pipeline_tasks() -> PipelineManager:
8396
"""
8497
Start tasks processed by fetch-workers pipelines based on db + in-memory queues.
8598
Suitable for tasks that run frequently and need to lock rows for a long time.
8699
"""
87-
pipeline_manager = PipelineManager()
88-
pipeline_manager.start()
89-
return pipeline_manager
100+
_pipeline_manager.start()
101+
return _pipeline_manager

0 commit comments

Comments
 (0)