Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/dstack/_internal/cli/commands/attach.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def _register(self):
)
self._parser.add_argument(
"--replica",
help="The replica number. Defaults to 0.",
help="The replica number. Defaults to any running replica.",
type=int,
default=0,
required=None,
)
self._parser.add_argument(
"--job",
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/cli/services/configurators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,7 @@ def _is_ready_to_attach(run: Run) -> bool:
]
or run._run.jobs[0].job_submissions[-1].status
in [JobStatus.SUBMITTED, JobStatus.PROVISIONING, JobStatus.PULLING]
or run._run.is_deployment_in_progress()
)


Expand Down
20 changes: 16 additions & 4 deletions src/dstack/_internal/cli/utils/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,16 @@ def get_runs_table(

for run in runs:
run = run._run # TODO(egor-s): make public attribute
show_deployment_num = (
verbose
and run.run_spec.configuration.type == "service"
or run.is_deployment_in_progress()
)
merge_job_rows = len(run.jobs) == 1 and not show_deployment_num

run_row: Dict[Union[str, int], Any] = {
"NAME": run.run_spec.run_name,
"NAME": run.run_spec.run_name
+ (f" [secondary]deployment={run.deployment_num}[/]" if show_deployment_num else ""),
"SUBMITTED": format_date(run.submitted_at),
"STATUS": (
run.latest_job_submission.status_message
Expand All @@ -174,7 +181,7 @@ def get_runs_table(
}
if run.error:
run_row["ERROR"] = run.error
if len(run.jobs) != 1:
if not merge_job_rows:
add_row_from_dict(table, run_row)

for job in run.jobs:
Expand All @@ -184,7 +191,12 @@ def get_runs_table(
inactive_for = format_duration_multiunit(latest_job_submission.inactivity_secs)
status += f" (inactive for {inactive_for})"
job_row: Dict[Union[str, int], Any] = {
"NAME": f" replica={job.job_spec.replica_num} job={job.job_spec.job_num}",
"NAME": f" replica={job.job_spec.replica_num} job={job.job_spec.job_num}"
+ (
f" deployment={latest_job_submission.deployment_num}"
if show_deployment_num
else ""
),
"STATUS": latest_job_submission.status_message,
"SUBMITTED": format_date(latest_job_submission.submitted_at),
"ERROR": latest_job_submission.error,
Expand All @@ -208,7 +220,7 @@ def get_runs_table(
"PRICE": f"${jpd.price:.4f}".rstrip("0").rstrip("."),
}
)
if len(run.jobs) == 1:
if merge_job_rows:
# merge rows
job_row.update(run_row)
add_row_from_dict(table, job_row, style="secondary" if len(run.jobs) != 1 else None)
Expand Down
6 changes: 6 additions & 0 deletions src/dstack/_internal/core/compatibility/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[Dict]:
if current_resource is not None:
current_resource_excludes = {}
current_resource_excludes["status_message"] = True
if current_resource.deployment_num == 0:
current_resource_excludes["deployment_num"] = True
apply_plan_excludes["current_resource"] = current_resource_excludes
current_resource_excludes["run_spec"] = get_run_spec_excludes(current_resource.run_spec)
job_submissions_excludes = {}
Expand All @@ -36,6 +38,8 @@ def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[Dict]:
}
if all(js.exit_status is None for js in job_submissions):
job_submissions_excludes["exit_status"] = True
if all(js.deployment_num == 0 for js in job_submissions):
job_submissions_excludes["deployment_num"] = True
latest_job_submission = current_resource.latest_job_submission
if latest_job_submission is not None:
latest_job_submission_excludes = {}
Expand All @@ -50,6 +54,8 @@ def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[Dict]:
}
if latest_job_submission.exit_status is None:
latest_job_submission_excludes["exit_status"] = True
if latest_job_submission.deployment_num == 0:
latest_job_submission_excludes["deployment_num"] = True
return {"plan": apply_plan_excludes}


Expand Down
9 changes: 9 additions & 0 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ class ClusterInfo(CoreModel):
class JobSubmission(CoreModel):
id: UUID4
submission_num: int
deployment_num: int = 0 # default for compatibility with pre-TODO servers
submitted_at: datetime
last_processed_at: datetime
finished_at: Optional[datetime]
Expand Down Expand Up @@ -516,6 +517,7 @@ class Run(CoreModel):
latest_job_submission: Optional[JobSubmission]
cost: float = 0
service: Optional[ServiceSpec] = None
deployment_num: int = 0 # default for compatibility with pre-TODO servers
# TODO: make error a computed field after migrating to pydanticV2
error: Optional[str] = None
deleted: Optional[bool] = None
Expand Down Expand Up @@ -578,6 +580,13 @@ def _get_status_message(
return "retrying"
return status.value

def is_deployment_in_progress(self) -> bool:
    """
    Tell whether a rolling deployment of this run is still underway.

    A deployment counts as in progress while at least one job's latest
    submission is unfinished and carries a `deployment_num` that differs
    from the run's own `deployment_num`.

    Returns:
        True if such a job exists, False otherwise. Jobs without any
        submissions are ignored.
    """
    for job in self.jobs:
        if not job.job_submissions:
            # No latest submission to inspect; indexing [-1] on an empty
            # list would raise IndexError.
            continue
        latest = job.job_submissions[-1]
        if not latest.status.is_finished() and latest.deployment_num != self.deployment_num:
            return True
    return False


class JobPlan(CoreModel):
job_spec: JobSpec
Expand Down
204 changes: 144 additions & 60 deletions src/dstack/_internal/server/background/tasks/process_runs.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import asyncio
import datetime
import itertools
from typing import List, Optional, Set, Tuple

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload, selectinload

import dstack._internal.server.services.gateways as gateways
import dstack._internal.server.services.services.autoscalers as autoscalers
from dstack._internal.core.errors import ServerError
from dstack._internal.core.models.profiles import RetryEvent, StopCriteria
from dstack._internal.core.models.runs import (
Job,
JobSpec,
JobStatus,
JobTerminationReason,
Run,
Expand All @@ -24,22 +23,23 @@
from dstack._internal.server.models import JobModel, ProjectModel, RunModel
from dstack._internal.server.services.jobs import (
find_job,
get_jobs_from_run_spec,
get_job_specs_from_run_spec,
group_jobs_by_replica_latest,
)
from dstack._internal.server.services.locking import get_locker
from dstack._internal.server.services.runs import (
create_job_model_for_new_submission,
fmt,
process_terminating_run,
retry_run_replica_jobs,
run_model_to_run,
scale_run_replicas,
)
from dstack._internal.server.services.services import update_service_desired_replica_count
from dstack._internal.utils import common
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)
ROLLING_DEPLOYMENT_MAX_SURGE = 1 # at most one extra replica during rolling deployment


async def process_runs(batch_size: int = 1):
Expand Down Expand Up @@ -133,46 +133,22 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel):
logger.debug("%s: pending run is not yet ready for resubmission", fmt(run_model))
return

# TODO(egor-s) consolidate with `scale_run_replicas` if possible
replicas = 1
run_model.desired_replica_count = 1
if run.run_spec.configuration.type == "service":
replicas = run.run_spec.configuration.replicas.min or 0 # new default
scaler = autoscalers.get_service_scaler(run.run_spec.configuration)
stats = None
if run_model.gateway_id is not None:
conn = await gateways.get_or_add_gateway_connection(session, run_model.gateway_id)
stats = await conn.get_stats(run_model.project.name, run_model.run_name)
# replicas info doesn't matter for now
replicas = scaler.scale([], stats)
if replicas == 0:
run_model.desired_replica_count = run.run_spec.configuration.replicas.min or 0
await update_service_desired_replica_count(
session,
run_model,
run.run_spec.configuration,
# does not matter for pending services, since 0->n scaling should happen without delay
last_scaled_at=None,
)

if run_model.desired_replica_count == 0:
# stay zero scaled
return

scheduled_replicas = 0
# Resubmit existing replicas
for replica_num, replica_jobs in itertools.groupby(
run.jobs, key=lambda j: j.job_spec.replica_num
):
if scheduled_replicas >= replicas:
break
scheduled_replicas += 1
for job in replica_jobs:
new_job_model = create_job_model_for_new_submission(
run_model=run_model,
job=job,
status=JobStatus.SUBMITTED,
)
session.add(new_job_model)
# Create missing replicas
for replica_num in range(scheduled_replicas, replicas):
jobs = await get_jobs_from_run_spec(run.run_spec, replica_num=replica_num)
for job in jobs:
job_model = create_job_model_for_new_submission(
run_model=run_model,
job=job,
status=JobStatus.SUBMITTED,
)
session.add(job_model)
await scale_run_replicas(session, run_model, replicas_diff=run_model.desired_replica_count)

run_model.status = RunStatus.SUBMITTED
logger.info("%s: run status has changed PENDING -> SUBMITTED", fmt(run_model))
Expand Down Expand Up @@ -340,27 +316,11 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER

if new_status not in {RunStatus.TERMINATING, RunStatus.PENDING}:
# No need to retry if the run is terminating,
# No need to retry, scale, or redeploy replicas if the run is terminating,
# pending run will retry replicas in `process_pending_run`
for _, replica_jobs in replicas_to_retry:
await retry_run_replica_jobs(
session, run_model, replica_jobs, only_failed=retry_single_job
)

if run_spec.configuration.type == "service":
scaler = autoscalers.get_service_scaler(run_spec.configuration)
stats = None
if run_model.gateway_id is not None:
conn = await gateways.get_or_add_gateway_connection(session, run_model.gateway_id)
stats = await conn.get_stats(run_model.project.name, run_model.run_name)
# use replicas_info from before retrying
replicas_diff = scaler.scale(replicas_info, stats)
if replicas_diff != 0:
# FIXME: potentially long write transaction
# Why do we flush here?
await session.flush()
await session.refresh(run_model)
await scale_run_replicas(session, run_model, replicas_diff)
await _handle_run_replicas(
session, run_model, run_spec, replicas_to_retry, retry_single_job, replicas_info
)

if run_model.status != new_status:
logger.info(
Expand All @@ -378,6 +338,130 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
run_model.resubmission_attempt += 1


async def _handle_run_replicas(
session: AsyncSession,
run_model: RunModel,
run_spec: RunSpec,
replicas_to_retry: list[tuple[int, list[JobModel]]],
retry_single_job: bool,
replicas_info: list[autoscalers.ReplicaInfo],
) -> None:
"""
Does ONE of:
- replica retry
- replica scaling
- replica rolling deployment

Does not do everything at once to avoid conflicts between the stages and long DB transactions.

Args:
session: DB session handed to every helper awaited here.
run_model: the run being processed; its `desired_replica_count`,
`deployment_num` and `jobs` are read.
run_spec: the run's spec; only `configuration` is read here.
replicas_to_retry: pairs whose second item is the list of job models of
a replica to retry (the first item is ignored by this function).
retry_single_job: forwarded as `only_failed` to `retry_run_replica_jobs`.
replicas_info: per-replica info; `active` and `timestamp` are read.
"""

# Stage 1 - retry. When there is anything to retry, nothing else is done
# in this call.
if replicas_to_retry:
for _, replica_jobs in replicas_to_retry:
await retry_run_replica_jobs(
session, run_model, replica_jobs, only_failed=retry_single_job
)
return

# For services, refresh the desired replica count first. The most recent
# replica timestamp (or None when there are no replicas) is used as the
# time of the last scaling.
if run_spec.configuration.type == "service":
await update_service_desired_replica_count(
session,
run_model,
run_spec.configuration,
# FIXME: should only include scaling events, not retries and deployments
last_scaled_at=max((r.timestamp for r in replicas_info), default=None),
)

# Upper bound of tolerated replicas: the desired count, plus the surge
# allowance while out-of-date replicas exist.
max_replica_count = run_model.desired_replica_count
if _has_out_of_date_replicas(run_model):
# allow extra replicas when deployment is in progress
max_replica_count += ROLLING_DEPLOYMENT_MAX_SURGE

# Stage 2 - scaling. An active replica count outside
# [desired_replica_count, max_replica_count] is corrected towards the
# desired count.
active_replica_count = sum(1 for r in replicas_info if r.active)
if active_replica_count not in range(run_model.desired_replica_count, max_replica_count + 1):
await scale_run_replicas(
session,
run_model,
replicas_diff=run_model.desired_replica_count - active_replica_count,
)
return

# Stage 3 - rolling deployment. Jobs that do not require redeployment are
# moved to the new deployment in place before re-checking for out-of-date
# replicas.
await _update_jobs_to_new_deployment_in_place(run_model)
if _has_out_of_date_replicas(run_model):
# Number of distinct replicas that still have an unfinished job.
non_terminated_replica_count = len(
{j.replica_num for j in run_model.jobs if not j.status.is_finished()}
)
# Avoid using too much hardware during a deployment - never have
# more than max_replica_count non-terminated replicas.
if non_terminated_replica_count < max_replica_count:
# Start more up-to-date replicas that will eventually replace out-of-date replicas.
await scale_run_replicas(
session,
run_model,
replicas_diff=max_replica_count - non_terminated_replica_count,
)

replicas_to_stop_count = 0
# stop any out-of-date replicas that are not running
# (counted: distinct replicas with a job from an older deployment whose
# status is neither RUNNING, TERMINATING, nor a finished status)
replicas_to_stop_count += len(
{
j.replica_num
for j in run_model.jobs
if j.status
not in [JobStatus.RUNNING, JobStatus.TERMINATING] + JobStatus.finished_statuses()
and j.deployment_num < run_model.deployment_num
}
)
# Number of distinct replicas that have a RUNNING job.
running_replica_count = len(
{j.replica_num for j in run_model.jobs if j.status == JobStatus.RUNNING}
)
if running_replica_count > run_model.desired_replica_count:
# stop excessive running out-of-date replicas
replicas_to_stop_count += running_replica_count - run_model.desired_replica_count
# NOTE(review): presumably scale_run_replicas chooses which replicas to
# stop for a negative diff - confirm it prefers out-of-date ones.
if replicas_to_stop_count:
await scale_run_replicas(
session,
run_model,
replicas_diff=-replicas_to_stop_count,
)


async def _update_jobs_to_new_deployment_in_place(run_model: RunModel) -> None:
    """
    Bump deployment_num for jobs that do not require redeployment.

    Replicas (as grouped by `group_jobs_by_replica_latest`) are skipped when
    all of their jobs are finished or already carry the run's
    `deployment_num`. For the remaining replicas, job specs are regenerated
    from the run spec stored on `run_model`; if every regenerated spec equals
    the spec stored on the corresponding job, all jobs of the replica get the
    run's `deployment_num`. Replicas with any differing spec are left as is.

    Raises:
        AssertionError: if the number of regenerated job specs differs from
            the number of jobs in the replica.
    """
    # Parsed lazily and at most once: the stored run spec is the same for
    # every replica. Assumes get_job_specs_from_run_spec does not mutate its
    # run_spec argument - TODO confirm.
    run_spec = None
    for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs):
        if all(j.status.is_finished() for j in job_models):
            continue
        if all(j.deployment_num == run_model.deployment_num for j in job_models):
            continue
        if run_spec is None:
            run_spec = RunSpec.__response__.parse_raw(run_model.run_spec)
        new_job_specs = await get_job_specs_from_run_spec(
            run_spec=run_spec,
            replica_num=replica_num,
        )
        assert len(new_job_specs) == len(job_models), (
            "Changing the number of jobs within a replica is not yet supported"
        )
        # Stops parsing stored specs at the first mismatch.
        any_spec_changed = any(
            new_job_spec != JobSpec.__response__.parse_raw(old_job_model.job_spec_data)
            for old_job_model, new_job_spec in zip(job_models, new_job_specs)
        )
        if any_spec_changed:
            continue
        for job_model in job_models:
            job_model.deployment_num = run_model.deployment_num


def _has_out_of_date_replicas(run: RunModel) -> bool:
    """
    Report whether the run still has jobs from an older deployment.

    A job counts when its `deployment_num` is lower than the run's, it is
    not finished, and it was not terminated with the `SCALED_DOWN` reason.
    """
    return any(
        job.deployment_num < run.deployment_num
        and not job.status.is_finished()
        and job.termination_reason != JobTerminationReason.SCALED_DOWN
        for job in run.jobs
    )


def _should_retry_job(run: Run, job: Job, job_model: JobModel) -> Optional[datetime.timedelta]:
"""
Checks if the job should be retried.
Expand Down
Loading
Loading