@@ -340,7 +340,7 @@ async def get_plan(
340340 action = ApplyAction .UPDATE
341341
342342 secrets = await get_project_secrets_mapping (session = session , project = project )
343-
343+
344344 # For services with replica groups, create jobs for all groups during planning
345345 jobs = []
346346 if (
@@ -407,12 +407,12 @@ async def get_plan(
407407 job_plans = []
408408 for job in jobs :
409409 job_offers : List [InstanceOfferWithAvailability ] = []
410-
410+
411411 # Filter pool offers to match this job's GPU requirements
412412 gpu_req = None
413413 if job .job_spec .requirements .resources and job .job_spec .requirements .resources .gpu :
414414 gpu_req = job .job_spec .requirements .resources .gpu .name
415-
415+
416416 matching_pool_offers = []
417417 for pool_offer in pool_offers :
418418 offer_gpus = pool_offer .instance .resources .gpus
@@ -424,9 +424,9 @@ async def get_plan(
424424 elif not gpu_req :
425425 # No GPU requirement, include all pool offers
426426 matching_pool_offers .append (pool_offer )
427-
427+
428428 job_offers .extend (matching_pool_offers )
429-
429+
430430 # Use shared offers if all jobs are identical, otherwise fetch per-job
431431 if shared_offers :
432432 job_offers .extend (offer for _ , offer in shared_offers )
@@ -443,7 +443,7 @@ async def get_plan(
443443 instance_mounts = check_run_spec_requires_instance_mounts (effective_run_spec ),
444444 )
445445 job_offers .extend (offer for _ , offer in job_specific_offers )
446-
446+
447447 job_offers .sort (key = lambda offer : not offer .availability .is_available ())
448448
449449 job_spec = job .job_spec
@@ -1287,7 +1287,21 @@ async def process_terminating_run(session: AsyncSession, run_model: RunModel):
12871287 )
12881288
12891289
1290- async def scale_run_replicas (session : AsyncSession , run_model : RunModel , replicas_diff : int ):
1290+ async def scale_run_replicas (
1291+ session : AsyncSession ,
1292+ run_model : RunModel ,
1293+ replicas_diff : int ,
1294+ allow_exceeding_max : bool = False ,
1295+ ):
1296+ """
1297+ Scale run replicas up or down.
1298+
1299+ Args:
1300+ session: Database session
1301+ run_model: The run to scale
1302+ replicas_diff: Number of replicas to add (positive) or remove (negative)
1303+ allow_exceeding_max: If True, allow scaling beyond configured max (for rolling deployments)
1304+ """
12911305 if replicas_diff == 0 :
12921306 # nothing to do
12931307 return
@@ -1349,9 +1363,10 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
13491363 # Get group minimums
13501364 group_mins = {g .name : g .replicas .min for g in normalized_groups }
13511365
1352- # Terminate from end (reversed), but skip if group not autoscalable or at minimum
1366+ # Terminate from end (reversed)
1367+ # For rolling deployments (allow_exceeding_max), prioritize terminating out-of-date replicas
13531368 terminated_count = 0
1354- for _ , _ , _ , replica_jobs in reversed (active_replicas ):
1369+ for _ , is_out_of_date , _ , replica_jobs in reversed (active_replicas ):
13551370 if terminated_count >= abs (replicas_diff ):
13561371 break
13571372
@@ -1360,7 +1375,19 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
13601375
13611376 group_name = replica_jobs [0 ].replica_group_name or "default"
13621377
1363- # Skip if not autoscalable
1378+ # For rolling deployment, allow terminating any out-of-date replica
1379+ if allow_exceeding_max and is_out_of_date :
1380+ # Terminate this replica (out-of-date during rolling deployment)
1381+ for job in replica_jobs :
1382+ if not job .status .is_finished () and job .status != JobStatus .TERMINATING :
1383+ job .status = JobStatus .TERMINATING
1384+ job .termination_reason = JobTerminationReason .SCALED_DOWN
1385+
1386+ group_counts [group_name ] -= 1
1387+ terminated_count += 1
1388+ continue
1389+
1390+ # For normal scaling, skip if not autoscalable
13641391 if normalized_groups and group_name not in autoscalable_groups :
13651392 continue
13661393
@@ -1379,29 +1406,42 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
13791406 group_counts [group_name ] -= 1
13801407 terminated_count += 1
13811408 else :
1382- # SCALE UP: Choose from autoscalable groups
1383- autoscalable_groups = [g for g in normalized_groups if g .replicas .min != g .replicas .max ]
1384-
1385- if normalized_groups and not autoscalable_groups :
1386- # No autoscalable groups, cannot scale
1387- logger .info ("%s: no autoscalable groups available for scaling up" , fmt (run_model ))
1388- return
1389-
1390- # Count current replicas per group to respect maximums
1409+ # SCALE UP
1410+ # Count current replicas per group
13911411 group_counts = {}
13921412 for _ , _ , _ , replica_jobs in active_replicas :
13931413 if replica_jobs :
13941414 group_name = replica_jobs [0 ].replica_group_name or "default"
13951415 group_counts [group_name ] = group_counts .get (group_name , 0 ) + 1
13961416
1397- # Filter groups that haven't reached maximum
1398- eligible_groups = [
1399- g for g in autoscalable_groups if group_counts .get (g .name , 0 ) < (g .replicas .max or float ("inf" ))
1400- ] if normalized_groups else normalized_groups
1417+ # First, identify groups below minimum (need to scale regardless of autoscalability)
1418+ below_min_groups = [
1419+ g for g in normalized_groups
1420+ if group_counts .get (g .name , 0 ) < (g .replicas .min or 0 )
1421+ ]
1422+
1423+ # Then, identify autoscalable groups that can scale beyond minimum
1424+ autoscalable_groups = [
1425+ g for g in normalized_groups
1426+ if g .replicas .min != g .replicas .max and (
1427+ allow_exceeding_max or group_counts .get (g .name , 0 ) < (g .replicas .max or float ("inf" ))
1428+ )
1429+ ]
1430+
1431+ # Eligible groups are: below-min groups + autoscalable groups
1432+ eligible_groups = []
1433+ if below_min_groups :
1434+ eligible_groups .extend (below_min_groups )
1435+ elif autoscalable_groups :
1436+ # Only use autoscalable groups if no groups are below minimum
1437+ eligible_groups .extend (autoscalable_groups )
1438+ elif allow_exceeding_max and normalized_groups :
1439+ # For rolling deployments, allow exceeding max even for fixed groups
1440+ eligible_groups .extend (normalized_groups )
14011441
14021442 if normalized_groups and not eligible_groups :
1403- # All groups at maximum
1404- logger .info ("%s: all autoscalable groups at maximum capacity " , fmt (run_model ))
1443+ # All groups at their limits
1444+ logger .info ("%s: all replica groups at their limits (min/max) " , fmt (run_model ))
14051445 return
14061446
14071447 scheduled_replicas = 0
@@ -1410,10 +1450,10 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
14101450 for _ , _ , _ , replica_jobs in inactive_replicas :
14111451 if scheduled_replicas == replicas_diff :
14121452 break
1413- # Only reuse if from autoscalable group
1453+ # Only reuse if from eligible group
14141454 if replica_jobs :
14151455 group_name = replica_jobs [0 ].replica_group_name or "default"
1416- if not normalized_groups or group_name in {g .name for g in autoscalable_groups }:
1456+ if not normalized_groups or group_name in {g .name for g in eligible_groups }:
14171457 await retry_run_replica_jobs (session , run_model , replica_jobs , only_failed = False )
14181458 scheduled_replicas += 1
14191459
0 commit comments