Skip to content

Commit 24d976e

Browse files
Bihan  RanaBihan  Rana
authored andcommitted
Remove create_group_run_spec and use Job Configurator instead
1 parent a0e13f6 commit 24d976e

8 files changed

Lines changed: 108 additions & 111 deletions

File tree

src/dstack/_internal/core/models/configurations.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,7 @@ def convert_ports(cls, v) -> PortMapping:
611611
class ConfigurationWithCommandsParams(CoreModel):
612612
commands: Annotated[CommandsList, Field(description="The shell commands to run")] = []
613613

614-
@root_validator(pre=True)
614+
@root_validator
615615
def check_image_or_commands_present(cls, values):
616616
# If replicas is list, skip validation - commands come from replica groups
617617
replicas = values.get("replicas")
@@ -957,7 +957,6 @@ def validate_top_level_properties_with_replica_groups(cls, values):
957957
)
958958

959959
resources = values.get("resources")
960-
from dstack._internal.core.models.resources import ResourcesSpec
961960

962961
default_resources = ResourcesSpec()
963962
if resources and resources.dict() != default_resources.dict():

src/dstack/_internal/server/background/tasks/process_runs.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
from dstack._internal.server.services.locking import get_locker
4141
from dstack._internal.server.services.prometheus.client_metrics import run_metrics
4242
from dstack._internal.server.services.runs import (
43-
create_group_run_spec,
4443
fmt,
4544
process_terminating_run,
4645
run_model_to_run,
@@ -482,7 +481,7 @@ async def _handle_run_replicas(
482481
session=session,
483482
run_model=run_model,
484483
group=group,
485-
base_run_spec=run_spec,
484+
run_spec=run_spec,
486485
desired_replica_counts=counts,
487486
)
488487
return
@@ -563,7 +562,6 @@ async def _update_jobs_to_new_deployment_in_place(
563562
session=session,
564563
project=run_model.project,
565564
)
566-
base_run_spec = run_spec
567565

568566
for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs):
569567
if all(j.status.is_finished() for j in job_models):
@@ -573,22 +571,17 @@ async def _update_jobs_to_new_deployment_in_place(
573571

574572
# Determine which group this replica belongs to
575573
replica_group_name = None
576-
group_run_spec = base_run_spec
577574

578575
if replicas:
579576
job_spec = JobSpec.__response__.parse_raw(job_models[0].job_spec_data)
580577
replica_group_name = job_spec.replica_group
581578

582-
for group in replicas:
583-
if group.name == replica_group_name:
584-
group_run_spec = create_group_run_spec(base_run_spec, group)
585-
break
586-
587579
# FIXME: Handle getting image configuration errors or skip it.
588580
new_job_specs = await get_job_specs_from_run_spec(
589-
run_spec=group_run_spec,
581+
run_spec=run_spec,
590582
secrets=secrets,
591583
replica_num=replica_num,
584+
replica_group_name=replica_group_name,
592585
)
593586
assert len(new_job_specs) == len(job_models), (
594587
"Changing the number of jobs within a replica is not yet supported"
@@ -681,7 +674,7 @@ async def _handle_rolling_deployment_for_group(
681674
session: AsyncSession,
682675
run_model: RunModel,
683676
group: ReplicaGroup,
684-
base_run_spec: RunSpec,
677+
run_spec: RunSpec,
685678
desired_replica_counts: dict,
686679
) -> None:
687680
"""
@@ -726,7 +719,7 @@ async def _handle_rolling_deployment_for_group(
726719
run_model=run_model,
727720
group=group,
728721
replicas_diff=group_max_replica_count - non_terminated_replica_count,
729-
base_run_spec=base_run_spec,
722+
run_spec=run_spec,
730723
active_replicas=active_replicas,
731724
inactive_replicas=inactive_replicas,
732725
)
@@ -774,7 +767,7 @@ async def _handle_rolling_deployment_for_group(
774767
run_model=run_model,
775768
group=group,
776769
replicas_diff=-replicas_to_stop_count,
777-
base_run_spec=base_run_spec,
770+
run_spec=run_spec,
778771
active_replicas=active_replicas,
779772
inactive_replicas=inactive_replicas,
780773
)

src/dstack/_internal/server/services/jobs/__init__.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,22 +98,31 @@ def switch_job_status(
9898

9999

100100
async def get_jobs_from_run_spec(
101-
run_spec: RunSpec, secrets: Dict[str, str], replica_num: int
101+
run_spec: RunSpec,
102+
secrets: Dict[str, str],
103+
replica_num: int,
104+
replica_group_name: Optional[str] = None,
102105
) -> List[Job]:
103106
return [
104107
Job(job_spec=s, job_submissions=[])
105108
for s in await get_job_specs_from_run_spec(
106109
run_spec=run_spec,
107110
secrets=secrets,
108111
replica_num=replica_num,
112+
replica_group_name=replica_group_name,
109113
)
110114
]
111115

112116

113117
async def get_job_specs_from_run_spec(
114-
run_spec: RunSpec, secrets: Dict[str, str], replica_num: int
118+
run_spec: RunSpec,
119+
secrets: Dict[str, str],
120+
replica_num: int,
121+
replica_group_name: Optional[str] = None,
115122
) -> List[JobSpec]:
116-
job_configurator = _get_job_configurator(run_spec=run_spec, secrets=secrets)
123+
job_configurator = _get_job_configurator(
124+
run_spec=run_spec, secrets=secrets, replica_group_name=replica_group_name
125+
)
117126
job_specs = await job_configurator.get_job_specs(replica_num=replica_num)
118127
return job_specs
119128

@@ -241,10 +250,14 @@ def is_master_job(job: Job) -> bool:
241250
return job.job_spec.job_num == 0
242251

243252

244-
def _get_job_configurator(run_spec: RunSpec, secrets: Dict[str, str]) -> JobConfigurator:
253+
def _get_job_configurator(
254+
run_spec: RunSpec, secrets: Dict[str, str], replica_group_name: Optional[str] = None
255+
) -> JobConfigurator:
245256
configuration_type = RunConfigurationType(run_spec.configuration.type)
246257
configurator_class = _configuration_type_to_configurator_class_map[configuration_type]
247-
return configurator_class(run_spec=run_spec, secrets=secrets)
258+
return configurator_class(
259+
run_spec=run_spec, secrets=secrets, replica_group_name=replica_group_name
260+
)
248261

249262

250263
_job_configurator_classes = [

src/dstack/_internal/server/services/jobs/configurators/base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,11 @@ def __init__(
9090
self,
9191
run_spec: RunSpec,
9292
secrets: Optional[Dict[str, str]] = None,
93+
replica_group_name: Optional[str] = None,
9394
):
9495
self.run_spec = run_spec
9596
self.secrets = secrets or {}
97+
self.replica_group_name = replica_group_name
9698

9799
async def get_job_specs(self, replica_num: int) -> List[JobSpec]:
98100
job_spec = await self._get_job_spec(replica_num=replica_num, job_num=0, jobs_per_replica=1)
@@ -150,6 +152,7 @@ async def _get_job_spec(
150152
job_num=job_num,
151153
job_name=f"{self.run_spec.run_name}-{job_num}-{replica_num}",
152154
jobs_per_replica=jobs_per_replica,
155+
replica_group=self.replica_group_name or "default",
153156
app_specs=self._app_specs(),
154157
commands=await self._commands(),
155158
env=self._env(),
@@ -298,9 +301,15 @@ def _registry_auth(self) -> Optional[RegistryAuth]:
298301
return self.run_spec.configuration.registry_auth
299302

300303
def _requirements(self, jobs_per_replica: int) -> Requirements:
304+
resources = self.run_spec.configuration.resources
305+
if self.run_spec.configuration.type == "service":
306+
for group in self.run_spec.configuration.replica_groups:
307+
if group.name == self.replica_group_name:
308+
resources = group.resources
309+
break
301310
spot_policy = self._spot_policy()
302311
return Requirements(
303-
resources=self.run_spec.configuration.resources,
312+
resources=resources,
304313
max_price=self.run_spec.merged_profile.max_price,
305314
spot=None if spot_policy == SpotPolicy.AUTO else (spot_policy == SpotPolicy.SPOT),
306315
reservation=self.run_spec.merged_profile.reservation,

src/dstack/_internal/server/services/jobs/configurators/service.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ class ServiceJobConfigurator(JobConfigurator):
1010

1111
def _shell_commands(self) -> List[str]:
1212
assert self.run_spec.configuration.type == "service"
13+
for group in self.run_spec.configuration.replica_groups:
14+
if group.name == self.replica_group_name:
15+
return group.commands
1316
return self.run_spec.configuration.commands
1417

1518
def _default_single_branch(self) -> bool:

src/dstack/_internal/server/services/runs/__init__.py

Lines changed: 3 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
ServerClientError,
1919
)
2020
from dstack._internal.core.models.common import ApplyAction
21-
from dstack._internal.core.models.configurations import ReplicaGroup
2221
from dstack._internal.core.models.profiles import (
2322
RetryEvent,
2423
)
@@ -520,29 +519,22 @@ async def submit_run(
520519

521520
global_replica_num = 0 # Global counter across all groups for unique replica_num
522521

523-
for replica_group in service_config.replica_groups or []:
522+
for replica_group in service_config.replica_groups:
524523
if run_spec.merged_profile.schedule is not None:
525524
group_initial_replicas = 0
526525
else:
527526
group_initial_replicas = replica_group.count.min or 0
528527

529528
# Each replica in this group gets the same group-specific configuration
530529
for group_replica_num in range(group_initial_replicas):
531-
group_run_spec = create_group_run_spec(
532-
base_run_spec=run_spec,
533-
replica_group=replica_group,
534-
)
535530
jobs = await get_jobs_from_run_spec(
536-
run_spec=group_run_spec,
531+
run_spec=run_spec,
537532
secrets=secrets,
538533
replica_num=global_replica_num,
534+
replica_group_name=replica_group.name,
539535
)
540536

541537
for job in jobs:
542-
assert replica_group.name is not None, (
543-
"ReplicaGroup name should be set by validator"
544-
)
545-
job.job_spec.replica_group = replica_group.name
546538
job_model = create_job_model_for_new_submission(
547539
run_model=run_model,
548540
job=job,
@@ -590,36 +582,6 @@ async def submit_run(
590582
return common_utils.get_or_error(run)
591583

592584

593-
def create_group_run_spec(
594-
base_run_spec: RunSpec,
595-
replica_group: ReplicaGroup,
596-
) -> RunSpec:
597-
# Create a copy of the configuration as a dict
598-
config_dict = base_run_spec.configuration.dict()
599-
600-
# Remove replicas and scaling fields since we're creating a single-replica config
601-
# This prevents validation errors when commands/resources are added
602-
config_dict.pop("replicas", None)
603-
config_dict.pop("scaling", None)
604-
605-
# Override with group-specific values (only if provided)
606-
if replica_group.commands:
607-
config_dict["commands"] = replica_group.commands
608-
609-
if replica_group.resources:
610-
config_dict["resources"] = replica_group.resources
611-
612-
# Create new configuration object with merged values
613-
# Use the same class as base (ServiceConfiguration)
614-
new_config = base_run_spec.configuration.__class__.parse_obj(config_dict)
615-
616-
# Create new RunSpec with modified configuration
617-
# Preserve all other RunSpec properties (repo_data, file_archives, etc.)
618-
run_spec_dict = base_run_spec.dict()
619-
run_spec_dict["configuration"] = new_config
620-
return RunSpec.parse_obj(run_spec_dict)
621-
622-
623585
def create_job_model_for_new_submission(
624586
run_model: RunModel,
625587
job: Job,

0 commit comments

Comments
 (0)