Skip to content

Commit 5192caf

Browse files
committed
chore: run linter and fix failing tests
1 parent eece104 commit 5192caf

11 files changed

Lines changed: 190 additions & 147 deletions

File tree

src/dstack/_internal/cli/utils/run.py

Lines changed: 19 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -119,9 +119,9 @@ def th(s: str) -> str:
119119
if include_run_properties:
120120
props.add_row(th("Configuration"), run_spec.configuration_path)
121121
props.add_row(th("Type"), run_spec.configuration.type)
122-
122+
123123
from dstack._internal.core.models.configurations import ServiceConfiguration
124-
124+
125125
if (
126126
include_run_properties
127127
and isinstance(run_spec.configuration, ServiceConfiguration)
@@ -130,21 +130,21 @@ def th(s: str) -> str:
130130
groups_info = []
131131
for group in run_spec.configuration.replica_groups:
132132
group_parts = [f"[cyan]{group.name}[/cyan]"]
133-
133+
134134
if group.replicas.min == group.replicas.max:
135135
group_parts.append(f"×{group.replicas.max}")
136136
else:
137137
group_parts.append(f"×{group.replicas.min}..{group.replicas.max}")
138138
group_parts.append("[dim](autoscalable)[/dim]")
139-
139+
140140
group_parts.append(f"[dim]({group.resources.pretty_format()})[/dim]")
141-
141+
142142
groups_info.append(" ".join(group_parts))
143-
143+
144144
props.add_row(th("Replica groups"), "\n".join(groups_info))
145145
else:
146146
props.add_row(th("Resources"), pretty_req)
147-
147+
148148
props.add_row(th("Spot policy"), spot_policy)
149149
props.add_row(th("Max price"), max_price)
150150
if include_run_properties:
@@ -163,27 +163,27 @@ def th(s: str) -> str:
163163
offers.add_column("INSTANCE TYPE", style="grey58", no_wrap=True, ratio=2)
164164
offers.add_column("PRICE", style="grey58", ratio=1)
165165
offers.add_column()
166-
166+
167167
# For replica groups, show offers from all job plans
168168
if len(run_plan.job_plans) > 1:
169169
# Multiple jobs - aggregate offers from all groups
170170
all_offers = []
171171
groups_with_no_offers = []
172172
total_offers_count = 0
173-
173+
174174
for jp in run_plan.job_plans:
175175
group_name = jp.job_spec.replica_group_name or "default"
176176
if jp.total_offers == 0:
177177
groups_with_no_offers.append(group_name)
178178
for offer in jp.offers[:max_offers] if max_offers else jp.offers:
179179
all_offers.append((group_name, offer))
180180
total_offers_count += jp.total_offers
181-
181+
182182
# Sort by price
183183
all_offers.sort(key=lambda x: x[1].price)
184184
if max_offers:
185185
all_offers = all_offers[:max_offers]
186-
186+
187187
# Show groups with no offers FIRST
188188
for group_name in groups_with_no_offers:
189189
offers.add_row(
@@ -196,7 +196,7 @@ def th(s: str) -> str:
196196
"",
197197
style="secondary",
198198
)
199-
199+
200200
# Then show groups with offers
201201
for i, (group_name, offer) in enumerate(all_offers, start=1):
202202
r = offer.instance.resources
@@ -212,10 +212,10 @@ def th(s: str) -> str:
212212
instance = offer.instance.name
213213
if offer.total_blocks > 1:
214214
instance += f" ({offer.blocks}/{offer.total_blocks})"
215-
215+
216216
# Add group name prefix for multi-group display
217217
backend_display = f"[cyan]{group_name}[/cyan]: {offer.backend.replace('remote', 'ssh')} ({offer.region})"
218-
218+
219219
offers.add_row(
220220
f"{i}",
221221
backend_display,
@@ -225,7 +225,7 @@ def th(s: str) -> str:
225225
availability,
226226
style=None if i == 1 or not include_run_properties else "secondary",
227227
)
228-
228+
229229
if total_offers_count > len(all_offers):
230230
offers.add_row("", "...", style="secondary")
231231
else:
@@ -260,14 +260,14 @@ def th(s: str) -> str:
260260

261261
console.print(props)
262262
console.print()
263-
263+
264264
# Check if we have offers to display
265265
has_offers = False
266266
if len(run_plan.job_plans) > 1:
267267
has_offers = any(len(jp.offers) > 0 for jp in run_plan.job_plans)
268268
else:
269269
has_offers = len(job_plan.offers) > 0
270-
270+
271271
if has_offers:
272272
console.print(offers)
273273
# Show summary for multi-job plans
@@ -343,12 +343,12 @@ def get_runs_table(
343343
if verbose and latest_job_submission.inactivity_secs:
344344
inactive_for = format_duration_multiunit(latest_job_submission.inactivity_secs)
345345
status += f" (inactive for {inactive_for})"
346-
346+
347347
job_name_parts = [f" replica={job.job_spec.replica_num}"]
348348
if job.job_spec.replica_group_name:
349349
job_name_parts.append(f"[cyan]group={job.job_spec.replica_group_name}[/cyan]")
350350
job_name_parts.append(f"job={job.job_spec.job_num}")
351-
351+
352352
job_row: Dict[Union[str, int], Any] = {
353353
"NAME": " ".join(job_name_parts)
354354
+ (

src/dstack/_internal/core/models/configurations.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -846,32 +846,32 @@ def validate_gateway(
846846
def validate_replica_groups_xor_replicas(cls, values):
847847
replica_groups = values.get("replica_groups")
848848
replicas = values.get("replicas")
849-
849+
850850
# Check if user specified both
851851
has_groups = replica_groups is not None
852852
has_replicas = replicas != Range[int](min=1, max=1)
853-
853+
854854
if has_groups and has_replicas:
855855
raise ValueError("Cannot specify both 'replicas' and 'replica_groups'")
856-
856+
857857
if has_groups:
858858
# Validate unique names
859859
names = [g.name for g in replica_groups]
860860
if len(names) != len(set(names)):
861861
raise ValueError("Replica group names must be unique")
862-
862+
863863
# Validate at least one group
864864
if not replica_groups:
865865
raise ValueError("replica_groups cannot be empty")
866-
866+
867867
return values
868868

869869
@root_validator()
870870
def validate_scaling(cls, values):
871871
scaling = values.get("scaling")
872872
replicas = values.get("replicas")
873873
replica_groups = values.get("replica_groups")
874-
874+
875875
if replica_groups:
876876
# Check if any group has a range
877877
has_range = any(g.replicas.min != g.replicas.max for g in replica_groups)
@@ -883,7 +883,7 @@ def validate_scaling(cls, values):
883883
raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.")
884884
elif replicas and replicas.min == replicas.max and scaling:
885885
raise ValueError("To use `scaling`, `replicas` must be set to a range.")
886-
886+
887887
return values
888888

889889
@validator("rate_limits")

src/dstack/_internal/server/background/tasks/process_runs.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -484,6 +484,7 @@ async def _handle_run_replicas(
484484
session,
485485
run_model,
486486
replicas_diff=max_replica_count - non_terminated_replica_count,
487+
allow_exceeding_max=True, # Allow exceeding max for rolling deployments
487488
)
488489

489490
replicas_to_stop_count = 0
@@ -510,6 +511,7 @@ async def _handle_run_replicas(
510511
session,
511512
run_model,
512513
replicas_diff=-replicas_to_stop_count,
514+
allow_exceeding_max=True, # Allow terminating out-of-date replicas during rolling deployment
513515
)
514516

515517

src/dstack/_internal/server/services/runs.py

Lines changed: 67 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -340,7 +340,7 @@ async def get_plan(
340340
action = ApplyAction.UPDATE
341341

342342
secrets = await get_project_secrets_mapping(session=session, project=project)
343-
343+
344344
# For services with replica groups, create jobs for all groups during planning
345345
jobs = []
346346
if (
@@ -407,12 +407,12 @@ async def get_plan(
407407
job_plans = []
408408
for job in jobs:
409409
job_offers: List[InstanceOfferWithAvailability] = []
410-
410+
411411
# Filter pool offers to match this job's GPU requirements
412412
gpu_req = None
413413
if job.job_spec.requirements.resources and job.job_spec.requirements.resources.gpu:
414414
gpu_req = job.job_spec.requirements.resources.gpu.name
415-
415+
416416
matching_pool_offers = []
417417
for pool_offer in pool_offers:
418418
offer_gpus = pool_offer.instance.resources.gpus
@@ -424,9 +424,9 @@ async def get_plan(
424424
elif not gpu_req:
425425
# No GPU requirement, include all pool offers
426426
matching_pool_offers.append(pool_offer)
427-
427+
428428
job_offers.extend(matching_pool_offers)
429-
429+
430430
# Use shared offers if all jobs are identical, otherwise fetch per-job
431431
if shared_offers:
432432
job_offers.extend(offer for _, offer in shared_offers)
@@ -443,7 +443,7 @@ async def get_plan(
443443
instance_mounts=check_run_spec_requires_instance_mounts(effective_run_spec),
444444
)
445445
job_offers.extend(offer for _, offer in job_specific_offers)
446-
446+
447447
job_offers.sort(key=lambda offer: not offer.availability.is_available())
448448

449449
job_spec = job.job_spec
@@ -1287,7 +1287,21 @@ async def process_terminating_run(session: AsyncSession, run_model: RunModel):
12871287
)
12881288

12891289

1290-
async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replicas_diff: int):
1290+
async def scale_run_replicas(
1291+
session: AsyncSession,
1292+
run_model: RunModel,
1293+
replicas_diff: int,
1294+
allow_exceeding_max: bool = False,
1295+
):
1296+
"""
1297+
Scale run replicas up or down.
1298+
1299+
Args:
1300+
session: Database session
1301+
run_model: The run to scale
1302+
replicas_diff: Number of replicas to add (positive) or remove (negative)
1303+
allow_exceeding_max: If True, allow scaling beyond configured max (for rolling deployments)
1304+
"""
12911305
if replicas_diff == 0:
12921306
# nothing to do
12931307
return
@@ -1349,9 +1363,10 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
13491363
# Get group minimums
13501364
group_mins = {g.name: g.replicas.min for g in normalized_groups}
13511365

1352-
# Terminate from end (reversed), but skip if group not autoscalable or at minimum
1366+
# Terminate from end (reversed)
1367+
# For rolling deployments (allow_exceeding_max), prioritize terminating out-of-date replicas
13531368
terminated_count = 0
1354-
for _, _, _, replica_jobs in reversed(active_replicas):
1369+
for _, is_out_of_date, _, replica_jobs in reversed(active_replicas):
13551370
if terminated_count >= abs(replicas_diff):
13561371
break
13571372

@@ -1360,7 +1375,19 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
13601375

13611376
group_name = replica_jobs[0].replica_group_name or "default"
13621377

1363-
# Skip if not autoscalable
1378+
# For rolling deployment, allow terminating any out-of-date replica
1379+
if allow_exceeding_max and is_out_of_date:
1380+
# Terminate this replica (out-of-date during rolling deployment)
1381+
for job in replica_jobs:
1382+
if not job.status.is_finished() and job.status != JobStatus.TERMINATING:
1383+
job.status = JobStatus.TERMINATING
1384+
job.termination_reason = JobTerminationReason.SCALED_DOWN
1385+
1386+
group_counts[group_name] -= 1
1387+
terminated_count += 1
1388+
continue
1389+
1390+
# For normal scaling, skip if not autoscalable
13641391
if normalized_groups and group_name not in autoscalable_groups:
13651392
continue
13661393

@@ -1379,29 +1406,42 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
13791406
group_counts[group_name] -= 1
13801407
terminated_count += 1
13811408
else:
1382-
# SCALE UP: Choose from autoscalable groups
1383-
autoscalable_groups = [g for g in normalized_groups if g.replicas.min != g.replicas.max]
1384-
1385-
if normalized_groups and not autoscalable_groups:
1386-
# No autoscalable groups, cannot scale
1387-
logger.info("%s: no autoscalable groups available for scaling up", fmt(run_model))
1388-
return
1389-
1390-
# Count current replicas per group to respect maximums
1409+
# SCALE UP
1410+
# Count current replicas per group
13911411
group_counts = {}
13921412
for _, _, _, replica_jobs in active_replicas:
13931413
if replica_jobs:
13941414
group_name = replica_jobs[0].replica_group_name or "default"
13951415
group_counts[group_name] = group_counts.get(group_name, 0) + 1
13961416

1397-
# Filter groups that haven't reached maximum
1398-
eligible_groups = [
1399-
g for g in autoscalable_groups if group_counts.get(g.name, 0) < (g.replicas.max or float("inf"))
1400-
] if normalized_groups else normalized_groups
1417+
# First, identify groups below minimum (need to scale regardless of autoscalability)
1418+
below_min_groups = [
1419+
g for g in normalized_groups
1420+
if group_counts.get(g.name, 0) < (g.replicas.min or 0)
1421+
]
1422+
1423+
# Then, identify autoscalable groups that can scale beyond minimum
1424+
autoscalable_groups = [
1425+
g for g in normalized_groups
1426+
if g.replicas.min != g.replicas.max and (
1427+
allow_exceeding_max or group_counts.get(g.name, 0) < (g.replicas.max or float("inf"))
1428+
)
1429+
]
1430+
1431+
# Eligible groups are: below-min groups + autoscalable groups
1432+
eligible_groups = []
1433+
if below_min_groups:
1434+
eligible_groups.extend(below_min_groups)
1435+
elif autoscalable_groups:
1436+
# Only use autoscalable groups if no groups are below minimum
1437+
eligible_groups.extend(autoscalable_groups)
1438+
elif allow_exceeding_max and normalized_groups:
1439+
# For rolling deployments, allow exceeding max even for fixed groups
1440+
eligible_groups.extend(normalized_groups)
14011441

14021442
if normalized_groups and not eligible_groups:
1403-
# All groups at maximum
1404-
logger.info("%s: all autoscalable groups at maximum capacity", fmt(run_model))
1443+
# All groups at their limits
1444+
logger.info("%s: all replica groups at their limits (min/max)", fmt(run_model))
14051445
return
14061446

14071447
scheduled_replicas = 0
@@ -1410,10 +1450,10 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
14101450
for _, _, _, replica_jobs in inactive_replicas:
14111451
if scheduled_replicas == replicas_diff:
14121452
break
1413-
# Only reuse if from autoscalable group
1453+
# Only reuse if from eligible group
14141454
if replica_jobs:
14151455
group_name = replica_jobs[0].replica_group_name or "default"
1416-
if not normalized_groups or group_name in {g.name for g in autoscalable_groups}:
1456+
if not normalized_groups or group_name in {g.name for g in eligible_groups}:
14171457
await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False)
14181458
scheduled_replicas += 1
14191459

src/dstack/_internal/server/services/services/autoscalers.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -131,7 +131,7 @@ def get_service_scaler(conf: ServiceConfiguration) -> BaseServiceScaler:
131131
assert conf.replicas.max is not None
132132
min_replicas = conf.replicas.min
133133
max_replicas = conf.replicas.max
134-
134+
135135
if conf.scaling is None:
136136
return ManualScaler(
137137
min_replicas=min_replicas,

0 commit comments

Comments
 (0)