diff --git a/src/dstack/_internal/cli/commands/attach.py b/src/dstack/_internal/cli/commands/attach.py index 12a036a5e5..b0d414a8c7 100644 --- a/src/dstack/_internal/cli/commands/attach.py +++ b/src/dstack/_internal/cli/commands/attach.py @@ -52,9 +52,8 @@ def _register(self): ) self._parser.add_argument( "--replica", - help="The replica number. Defaults to 0.", + help="The replica number. Defaults to any running replica.", type=int, - default=0, ) self._parser.add_argument( "--job", @@ -129,14 +128,15 @@ def _print_finished_message_when_available(run: Run) -> None: def _print_attached_message( run: Run, bind_address: Optional[str], - replica_num: int, + replica_num: Optional[int], job_num: int, ): if bind_address is None: bind_address = "localhost" - output = f"Attached to run [code]{run.name}[/] (replica={replica_num} job={job_num})\n" job = get_or_error(run._find_job(replica_num=replica_num, job_num=job_num)) + replica_num = job.job_spec.replica_num + output = f"Attached to run [code]{run.name}[/] (replica={replica_num} job={job_num})\n" name = run.name if replica_num != 0 or job_num != 0: name = job.job_spec.job_name diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index 28bc0a07f6..dcfe5e3477 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -599,6 +599,7 @@ def _is_ready_to_attach(run: Run) -> bool: ] or run._run.jobs[0].job_submissions[-1].status in [JobStatus.SUBMITTED, JobStatus.PROVISIONING, JobStatus.PULLING] + or run._run.is_deployment_in_progress() ) diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py index ad3af41fb0..e1cdbaa732 100644 --- a/src/dstack/_internal/cli/utils/run.py +++ b/src/dstack/_internal/cli/utils/run.py @@ -162,9 +162,16 @@ def get_runs_table( for run in runs: run = run._run # TODO(egor-s): make public attribute + show_deployment_num = ( + 
verbose + and run.run_spec.configuration.type == "service" + or run.is_deployment_in_progress() + ) + merge_job_rows = len(run.jobs) == 1 and not show_deployment_num run_row: Dict[Union[str, int], Any] = { - "NAME": run.run_spec.run_name, + "NAME": run.run_spec.run_name + + (f" [secondary]deployment={run.deployment_num}[/]" if show_deployment_num else ""), "SUBMITTED": format_date(run.submitted_at), "STATUS": ( run.latest_job_submission.status_message @@ -174,7 +181,7 @@ def get_runs_table( } if run.error: run_row["ERROR"] = run.error - if len(run.jobs) != 1: + if not merge_job_rows: add_row_from_dict(table, run_row) for job in run.jobs: @@ -184,7 +191,12 @@ def get_runs_table( inactive_for = format_duration_multiunit(latest_job_submission.inactivity_secs) status += f" (inactive for {inactive_for})" job_row: Dict[Union[str, int], Any] = { - "NAME": f" replica={job.job_spec.replica_num} job={job.job_spec.job_num}", + "NAME": f" replica={job.job_spec.replica_num} job={job.job_spec.job_num}" + + ( + f" deployment={latest_job_submission.deployment_num}" + if show_deployment_num + else "" + ), "STATUS": latest_job_submission.status_message, "SUBMITTED": format_date(latest_job_submission.submitted_at), "ERROR": latest_job_submission.error, @@ -208,7 +220,7 @@ def get_runs_table( "PRICE": f"${jpd.price:.4f}".rstrip("0").rstrip("."), } ) - if len(run.jobs) == 1: + if merge_job_rows: # merge rows job_row.update(run_row) add_row_from_dict(table, job_row, style="secondary" if len(run.jobs) != 1 else None) diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py index 09e9b4ea7a..385f9bd8fa 100644 --- a/src/dstack/_internal/core/compatibility/runs.py +++ b/src/dstack/_internal/core/compatibility/runs.py @@ -19,6 +19,8 @@ def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[Dict]: if current_resource is not None: current_resource_excludes = {} current_resource_excludes["status_message"] = True + if 
current_resource.deployment_num == 0: + current_resource_excludes["deployment_num"] = True apply_plan_excludes["current_resource"] = current_resource_excludes current_resource_excludes["run_spec"] = get_run_spec_excludes(current_resource.run_spec) job_submissions_excludes = {} @@ -36,6 +38,8 @@ def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[Dict]: } if all(js.exit_status is None for js in job_submissions): job_submissions_excludes["exit_status"] = True + if all(js.deployment_num == 0 for js in job_submissions): + job_submissions_excludes["deployment_num"] = True latest_job_submission = current_resource.latest_job_submission if latest_job_submission is not None: latest_job_submission_excludes = {} @@ -50,6 +54,8 @@ def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[Dict]: } if latest_job_submission.exit_status is None: latest_job_submission_excludes["exit_status"] = True + if latest_job_submission.deployment_num == 0: + latest_job_submission_excludes["deployment_num"] = True return {"plan": apply_plan_excludes} diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index c241044879..5cff6571f0 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -289,6 +289,7 @@ class ClusterInfo(CoreModel): class JobSubmission(CoreModel): id: UUID4 submission_num: int + deployment_num: int = 0 # default for compatibility with pre-0.19.14 servers submitted_at: datetime last_processed_at: datetime finished_at: Optional[datetime] @@ -516,6 +517,7 @@ class Run(CoreModel): latest_job_submission: Optional[JobSubmission] cost: float = 0 service: Optional[ServiceSpec] = None + deployment_num: int = 0 # default for compatibility with pre-0.19.14 servers # TODO: make error a computed field after migrating to pydanticV2 error: Optional[str] = None deleted: Optional[bool] = None @@ -578,6 +580,13 @@ def _get_status_message( return "retrying" return status.value + def 
is_deployment_in_progress(self) -> bool: + return any( + not j.job_submissions[-1].status.is_finished() + and j.job_submissions[-1].deployment_num != self.deployment_num + for j in self.jobs + ) + class JobPlan(CoreModel): job_spec: JobSpec diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py index 5237cf68eb..73dfbbe026 100644 --- a/src/dstack/_internal/server/background/tasks/process_runs.py +++ b/src/dstack/_internal/server/background/tasks/process_runs.py @@ -1,18 +1,17 @@ import asyncio import datetime -import itertools from typing import List, Optional, Set, Tuple from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload, selectinload -import dstack._internal.server.services.gateways as gateways import dstack._internal.server.services.services.autoscalers as autoscalers from dstack._internal.core.errors import ServerError from dstack._internal.core.models.profiles import RetryEvent, StopCriteria from dstack._internal.core.models.runs import ( Job, + JobSpec, JobStatus, JobTerminationReason, Run, @@ -24,22 +23,23 @@ from dstack._internal.server.models import JobModel, ProjectModel, RunModel from dstack._internal.server.services.jobs import ( find_job, - get_jobs_from_run_spec, + get_job_specs_from_run_spec, group_jobs_by_replica_latest, ) from dstack._internal.server.services.locking import get_locker from dstack._internal.server.services.runs import ( - create_job_model_for_new_submission, fmt, process_terminating_run, retry_run_replica_jobs, run_model_to_run, scale_run_replicas, ) +from dstack._internal.server.services.services import update_service_desired_replica_count from dstack._internal.utils import common from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) +ROLLING_DEPLOYMENT_MAX_SURGE = 1 # at most one extra replica during rolling deployment async def process_runs(batch_size: int 
= 1): @@ -133,46 +133,22 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel): logger.debug("%s: pending run is not yet ready for resubmission", fmt(run_model)) return - # TODO(egor-s) consolidate with `scale_run_replicas` if possible - replicas = 1 + run_model.desired_replica_count = 1 if run.run_spec.configuration.type == "service": - replicas = run.run_spec.configuration.replicas.min or 0 # new default - scaler = autoscalers.get_service_scaler(run.run_spec.configuration) - stats = None - if run_model.gateway_id is not None: - conn = await gateways.get_or_add_gateway_connection(session, run_model.gateway_id) - stats = await conn.get_stats(run_model.project.name, run_model.run_name) - # replicas info doesn't matter for now - replicas = scaler.scale([], stats) - if replicas == 0: + run_model.desired_replica_count = run.run_spec.configuration.replicas.min or 0 + await update_service_desired_replica_count( + session, + run_model, + run.run_spec.configuration, + # does not matter for pending services, since 0->n scaling should happen without delay + last_scaled_at=None, + ) + + if run_model.desired_replica_count == 0: # stay zero scaled return - scheduled_replicas = 0 - # Resubmit existing replicas - for replica_num, replica_jobs in itertools.groupby( - run.jobs, key=lambda j: j.job_spec.replica_num - ): - if scheduled_replicas >= replicas: - break - scheduled_replicas += 1 - for job in replica_jobs: - new_job_model = create_job_model_for_new_submission( - run_model=run_model, - job=job, - status=JobStatus.SUBMITTED, - ) - session.add(new_job_model) - # Create missing replicas - for replica_num in range(scheduled_replicas, replicas): - jobs = await get_jobs_from_run_spec(run.run_spec, replica_num=replica_num) - for job in jobs: - job_model = create_job_model_for_new_submission( - run_model=run_model, - job=job, - status=JobStatus.SUBMITTED, - ) - session.add(job_model) + await scale_run_replicas(session, run_model, 
replicas_diff=run_model.desired_replica_count) run_model.status = RunStatus.SUBMITTED logger.info("%s: run status has changed PENDING -> SUBMITTED", fmt(run_model)) @@ -340,27 +316,11 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel): job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER if new_status not in {RunStatus.TERMINATING, RunStatus.PENDING}: - # No need to retry if the run is terminating, + # No need to retry, scale, or redeploy replicas if the run is terminating, # pending run will retry replicas in `process_pending_run` - for _, replica_jobs in replicas_to_retry: - await retry_run_replica_jobs( - session, run_model, replica_jobs, only_failed=retry_single_job - ) - - if run_spec.configuration.type == "service": - scaler = autoscalers.get_service_scaler(run_spec.configuration) - stats = None - if run_model.gateway_id is not None: - conn = await gateways.get_or_add_gateway_connection(session, run_model.gateway_id) - stats = await conn.get_stats(run_model.project.name, run_model.run_name) - # use replicas_info from before retrying - replicas_diff = scaler.scale(replicas_info, stats) - if replicas_diff != 0: - # FIXME: potentially long write transaction - # Why do we flush here? 
- await session.flush() - await session.refresh(run_model) - await scale_run_replicas(session, run_model, replicas_diff) + await _handle_run_replicas( + session, run_model, run_spec, replicas_to_retry, retry_single_job, replicas_info + ) if run_model.status != new_status: logger.info( @@ -378,6 +338,130 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel): run_model.resubmission_attempt += 1 +async def _handle_run_replicas( + session: AsyncSession, + run_model: RunModel, + run_spec: RunSpec, + replicas_to_retry: list[tuple[int, list[JobModel]]], + retry_single_job: bool, + replicas_info: list[autoscalers.ReplicaInfo], +) -> None: + """ + Does ONE of: + - replica retry + - replica scaling + - replica rolling deployment + + Does not do everything at once to avoid conflicts between the stages and long DB transactions. + """ + + if replicas_to_retry: + for _, replica_jobs in replicas_to_retry: + await retry_run_replica_jobs( + session, run_model, replica_jobs, only_failed=retry_single_job + ) + return + + if run_spec.configuration.type == "service": + await update_service_desired_replica_count( + session, + run_model, + run_spec.configuration, + # FIXME: should only include scaling events, not retries and deployments + last_scaled_at=max((r.timestamp for r in replicas_info), default=None), + ) + + max_replica_count = run_model.desired_replica_count + if _has_out_of_date_replicas(run_model): + # allow extra replicas when deployment is in progress + max_replica_count += ROLLING_DEPLOYMENT_MAX_SURGE + + active_replica_count = sum(1 for r in replicas_info if r.active) + if active_replica_count not in range(run_model.desired_replica_count, max_replica_count + 1): + await scale_run_replicas( + session, + run_model, + replicas_diff=run_model.desired_replica_count - active_replica_count, + ) + return + + await _update_jobs_to_new_deployment_in_place(run_model, run_spec) + if _has_out_of_date_replicas(run_model): + non_terminated_replica_count = len( + 
{j.replica_num for j in run_model.jobs if not j.status.is_finished()} + ) + # Avoid using too much hardware during a deployment - never have + # more than max_replica_count non-terminated replicas. + if non_terminated_replica_count < max_replica_count: + # Start more up-to-date replicas that will eventually replace out-of-date replicas. + await scale_run_replicas( + session, + run_model, + replicas_diff=max_replica_count - non_terminated_replica_count, + ) + + replicas_to_stop_count = 0 + # stop any out-of-date replicas that are not running + replicas_to_stop_count += len( + { + j.replica_num + for j in run_model.jobs + if j.status + not in [JobStatus.RUNNING, JobStatus.TERMINATING] + JobStatus.finished_statuses() + and j.deployment_num < run_model.deployment_num + } + ) + running_replica_count = len( + {j.replica_num for j in run_model.jobs if j.status == JobStatus.RUNNING} + ) + if running_replica_count > run_model.desired_replica_count: + # stop excessive running out-of-date replicas + replicas_to_stop_count += running_replica_count - run_model.desired_replica_count + if replicas_to_stop_count: + await scale_run_replicas( + session, + run_model, + replicas_diff=-replicas_to_stop_count, + ) + + +async def _update_jobs_to_new_deployment_in_place(run_model: RunModel, run_spec: RunSpec) -> None: + """ + Bump deployment_num for jobs that do not require redeployment. 
+ """ + + for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs): + if all(j.status.is_finished() for j in job_models): + continue + if all(j.deployment_num == run_model.deployment_num for j in job_models): + continue + new_job_specs = await get_job_specs_from_run_spec( + run_spec=run_spec, + replica_num=replica_num, + ) + assert len(new_job_specs) == len(job_models), ( + "Changing the number of jobs within a replica is not yet supported" + ) + can_update_all_jobs = True + for old_job_model, new_job_spec in zip(job_models, new_job_specs): + old_job_spec = JobSpec.__response__.parse_raw(old_job_model.job_spec_data) + if new_job_spec != old_job_spec: + can_update_all_jobs = False + break + if can_update_all_jobs: + for job_model in job_models: + job_model.deployment_num = run_model.deployment_num + + +def _has_out_of_date_replicas(run: RunModel) -> bool: + for job in run.jobs: + if job.deployment_num < run.deployment_num and not ( + job.status.is_finished() or job.termination_reason == JobTerminationReason.SCALED_DOWN + ): + return True + return False + + def _should_retry_job(run: Run, job: Job, job_model: JobModel) -> Optional[datetime.timedelta]: """ Checks if the job should be retried. diff --git a/src/dstack/_internal/server/migrations/versions/35e90e1b0d3e_add_rolling_deployment_fields.py b/src/dstack/_internal/server/migrations/versions/35e90e1b0d3e_add_rolling_deployment_fields.py new file mode 100644 index 0000000000..a0cac48af3 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/35e90e1b0d3e_add_rolling_deployment_fields.py @@ -0,0 +1,42 @@ +"""Add rolling deployment fields + +Revision ID: 35e90e1b0d3e +Revises: 35f732ee4cf5 +Create Date: 2025-05-29 15:30:27.878569 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = "35e90e1b0d3e" +down_revision = "35f732ee4cf5" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.add_column(sa.Column("deployment_num", sa.Integer(), nullable=True)) + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.execute("UPDATE jobs SET deployment_num = 0") + batch_op.alter_column("deployment_num", nullable=False) + + with op.batch_alter_table("runs", schema=None) as batch_op: + batch_op.add_column(sa.Column("deployment_num", sa.Integer(), nullable=True)) + batch_op.add_column(sa.Column("desired_replica_count", sa.Integer(), nullable=True)) + with op.batch_alter_table("runs", schema=None) as batch_op: + batch_op.execute("UPDATE runs SET deployment_num = 0") + batch_op.execute("UPDATE runs SET desired_replica_count = 1") + batch_op.alter_column("deployment_num", nullable=False) + batch_op.alter_column("desired_replica_count", nullable=False) + + +def downgrade() -> None: + with op.batch_alter_table("runs", schema=None) as batch_op: + batch_op.drop_column("deployment_num") + batch_op.drop_column("desired_replica_count") + + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.drop_column("deployment_num") diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index cb39b70786..c5e4749c99 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -350,6 +350,8 @@ class RunModel(BaseModel): run_spec: Mapped[str] = mapped_column(Text) service_spec: Mapped[Optional[str]] = mapped_column(Text) priority: Mapped[int] = mapped_column(Integer, default=0) + deployment_num: Mapped[int] = mapped_column(Integer) + desired_replica_count: Mapped[int] = mapped_column(Integer) jobs: Mapped[List["JobModel"]] = relationship( back_populates="run", lazy="selectin", order_by="[JobModel.replica_num, JobModel.job_num]" @@ -404,6 +406,7 @@ class JobModel(BaseModel): 
instance: Mapped[Optional["InstanceModel"]] = relationship(back_populates="jobs") used_instance_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUIDType(binary=False)) replica_num: Mapped[int] = mapped_column(Integer) + deployment_num: Mapped[int] = mapped_column(Integer) job_runtime_data: Mapped[Optional[str]] = mapped_column(Text) diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index f25c193f87..157090dd0e 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -128,6 +128,7 @@ def job_model_to_job_submission(job_model: JobModel) -> JobSubmission: return JobSubmission( id=job_model.id, submission_num=job_model.submission_num, + deployment_num=job_model.deployment_num, submitted_at=job_model.submitted_at.replace(tzinfo=timezone.utc), last_processed_at=last_processed_at, finished_at=finished_at, diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index a1ec23e466..badc2a0276 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -439,6 +439,7 @@ async def apply_plan( .values( run_spec=run_spec.json(), priority=run_spec.configuration.priority, + deployment_num=current_resource.deployment_num + 1, ) ) run = await get_run_by_name( @@ -501,6 +502,8 @@ async def submit_run( run_spec=run_spec.json(), last_processed_at=submitted_at, priority=run_spec.configuration.priority, + deployment_num=0, + desired_replica_count=1, # a relevant value will be set in process_runs.py ) session.add(run_model) @@ -539,6 +542,7 @@ def create_job_model_for_new_submission( job_num=job.job_spec.job_num, job_name=f"{job.job_spec.job_name}", replica_num=job.job_spec.replica_num, + deployment_num=run_model.deployment_num, submission_num=len(job.job_submissions), submitted_at=now, last_processed_at=now, @@ -662,13 +666,9 @@ def 
run_model_to_run( for job_num, job_submissions in itertools.groupby( replica_submissions, key=lambda j: j.job_num ): - job_spec = None submissions = [] + job_model = None for job_model in job_submissions: - if job_spec is None: - job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data) - if not include_sensitive: - _remove_job_spec_sensitive_info(job_spec) if include_job_submissions: job_submission = job_model_to_job_submission(job_model) if return_in_api: @@ -680,7 +680,11 @@ def run_model_to_run( if job_submission.job_provisioning_data.ssh_port is None: job_submission.job_provisioning_data.ssh_port = 22 submissions.append(job_submission) - if job_spec is not None: + if job_model is not None: + # Use the spec from the latest submission. Submissions can have different specs + job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data) + if not include_sensitive: + _remove_job_spec_sensitive_info(job_spec) jobs.append(Job(job_spec=job_spec, job_submissions=submissions)) run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) @@ -707,6 +711,7 @@ def run_model_to_run( jobs=jobs, latest_job_submission=latest_job_submission, service=service_spec, + deployment_num=run_model.deployment_num, deleted=run_model.deleted, ) run.cost = _get_run_cost(run) @@ -897,9 +902,24 @@ def _validate_run_spec_and_set_defaults(run_spec: RunSpec): _CONF_UPDATABLE_FIELDS = ["priority"] _TYPE_SPECIFIC_CONF_UPDATABLE_FIELDS = { "dev-environment": ["inactivity_duration"], - # Most service fields can be updated via replica redeployment. - # TODO: Allow updating other fields when rolling deployment is supported. 
- "service": ["replicas", "scaling", "strip_prefix"], + "service": [ + # in-place + "replicas", + "scaling", + # rolling deployment + "resources", + "volumes", + "image", + "user", + "privileged", + "entrypoint", + "python", + "nvcc", + "single_branch", + "env", + "shell", + "commands", + ], } @@ -1004,34 +1024,33 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica abs(replicas_diff), ) - # lists of (importance, replica_num, jobs) + # lists of (importance, is_out_of_date, replica_num, jobs) active_replicas = [] inactive_replicas = [] for replica_num, replica_jobs in group_jobs_by_replica_latest(run_model.jobs): statuses = set(job.status for job in replica_jobs) + deployment_num = replica_jobs[0].deployment_num # same for all jobs + is_out_of_date = deployment_num < run_model.deployment_num if {JobStatus.TERMINATING, *JobStatus.finished_statuses()} & statuses: # if there are any terminating or finished jobs, the replica is inactive - inactive_replicas.append((0, replica_num, replica_jobs)) + inactive_replicas.append((0, is_out_of_date, replica_num, replica_jobs)) elif JobStatus.SUBMITTED in statuses: # if there are any submitted jobs, the replica is active and has the importance of 0 - active_replicas.append((0, replica_num, replica_jobs)) + active_replicas.append((0, is_out_of_date, replica_num, replica_jobs)) elif {JobStatus.PROVISIONING, JobStatus.PULLING} & statuses: # if there are any provisioning or pulling jobs, the replica is active and has the importance of 1 - active_replicas.append((1, replica_num, replica_jobs)) + active_replicas.append((1, is_out_of_date, replica_num, replica_jobs)) else: # all jobs are running, the replica is active and has the importance of 2 - active_replicas.append((2, replica_num, replica_jobs)) + active_replicas.append((2, is_out_of_date, replica_num, replica_jobs)) - # sort by importance (desc) and replica_num (asc) - active_replicas.sort(key=lambda r: (-r[0], r[1])) + # sort by is_out_of_date 
(up-to-date first), importance (desc), and replica_num (asc) + active_replicas.sort(key=lambda r: (r[1], -r[0], r[2])) run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) if replicas_diff < 0: - if len(active_replicas) + replicas_diff < run_spec.configuration.replicas.min: - raise ServerClientError("Can't scale down below the minimum number of replicas") - - for _, _, replica_jobs in reversed(active_replicas[-abs(replicas_diff) :]): + for _, _, _, replica_jobs in reversed(active_replicas[-abs(replicas_diff) :]): # scale down the less important replicas first for job in replica_jobs: if job.status.is_finished() or job.status == JobStatus.TERMINATING: @@ -1040,18 +1059,15 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica job.termination_reason = JobTerminationReason.SCALED_DOWN # background task will process the job later else: - if len(active_replicas) + replicas_diff > run_spec.configuration.replicas.max: - raise ServerClientError("Can't scale up above the maximum number of replicas") scheduled_replicas = 0 # rerun inactive replicas - for _, _, replica_jobs in inactive_replicas: + for _, _, _, replica_jobs in inactive_replicas: if scheduled_replicas == replicas_diff: break await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False) scheduled_replicas += 1 - # create new replicas for replica_num in range( len(active_replicas) + scheduled_replicas, len(active_replicas) + replicas_diff ): @@ -1068,7 +1084,14 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica async def retry_run_replica_jobs( session: AsyncSession, run_model: RunModel, latest_jobs: List[JobModel], *, only_failed: bool ): - for job_model in latest_jobs: + new_jobs = await get_jobs_from_run_spec( + RunSpec.__response__.parse_raw(run_model.run_spec), + replica_num=latest_jobs[0].replica_num, + ) + assert len(new_jobs) == len(latest_jobs), ( + "Changing the number of jobs within a replica is not yet 
supported"
+    )
+    for job_model, new_job in zip(latest_jobs, new_jobs):
         if not (job_model.status.is_finished() or job_model.status == JobStatus.TERMINATING):
             if only_failed:
                 # No need to resubmit, skip
@@ -1079,10 +1102,7 @@ async def retry_run_replica_jobs(
 
         new_job_model = create_job_model_for_new_submission(
             run_model=run_model,
-            job=Job(
-                job_spec=JobSpec.__response__.parse_raw(job_model.job_spec_data),
-                job_submissions=[],
-            ),
+            job=new_job,
             status=JobStatus.SUBMITTED,
         )
         # dirty hack to avoid passing all job submissions
diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py
index 3f0b335f98..d0c26c2a02 100644
--- a/src/dstack/_internal/server/services/services/__init__.py
+++ b/src/dstack/_internal/server/services/services/__init__.py
@@ -30,6 +30,7 @@
     get_project_gateway_model_by_name,
 )
 from dstack._internal.server.services.logging import fmt
+from dstack._internal.server.services.services.autoscalers import get_service_scaler
 from dstack._internal.server.services.services.options import get_service_options
 from dstack._internal.utils.logging import get_logger
 
@@ -258,3 +259,21 @@ def _get_gateway_https(configuration: GatewayConfiguration) -> bool:
     if configuration.certificate is not None and configuration.certificate.type == "lets-encrypt":
         return True
     return False
+
+
+async def update_service_desired_replica_count(
+    session: AsyncSession,
+    run_model: RunModel,
+    configuration: ServiceConfiguration,
+    last_scaled_at: Optional[datetime.datetime],
+) -> None:
+    scaler = get_service_scaler(configuration)
+    stats = None
+    if run_model.gateway_id is not None:
+        conn = await get_or_add_gateway_connection(session, run_model.gateway_id)
+        stats = await conn.get_stats(run_model.project.name, run_model.run_name)
+    run_model.desired_replica_count = scaler.get_desired_count(
+        current_desired_count=run_model.desired_replica_count,
+        stats=stats,
+        last_scaled_at=last_scaled_at,
+    )
diff --git 
a/src/dstack/_internal/server/services/services/autoscalers.py b/src/dstack/_internal/server/services/services/autoscalers.py index 0a61e830ab..47eabaab31 100644 --- a/src/dstack/_internal/server/services/services/autoscalers.py +++ b/src/dstack/_internal/server/services/services/autoscalers.py @@ -1,7 +1,7 @@ import datetime import math from abc import ABC, abstractmethod -from typing import List, Optional +from typing import Optional from pydantic import BaseModel @@ -23,14 +23,20 @@ class ReplicaInfo(BaseModel): class BaseServiceScaler(ABC): @abstractmethod - def scale(self, replicas: List[ReplicaInfo], stats: Optional[PerWindowStats]) -> int: + def get_desired_count( + self, + current_desired_count: int, + stats: Optional[PerWindowStats], + last_scaled_at: Optional[datetime.datetime], + ) -> int: """ Args: - replicas: list of all replicas stats: service usage stats + current_desired_count: currently used desired count + last_scaled_at: last time service was scaled, None if it was never scaled yet Returns: - diff: number of replicas to add or remove + desired_count: desired count of replicas """ pass @@ -49,12 +55,14 @@ def __init__( self.min_replicas = min_replicas self.max_replicas = max_replicas - def scale(self, replicas: List[ReplicaInfo], stats: Optional[PerWindowStats]) -> int: - active_replicas = [r for r in replicas if r.active] - target_replicas = len(active_replicas) - # clip the target replicas to the min and max values - target_replicas = min(max(target_replicas, self.min_replicas), self.max_replicas) - return target_replicas - len(active_replicas) + def get_desired_count( + self, + current_desired_count: int, + stats: Optional[PerWindowStats], + last_scaled_at: Optional[datetime.datetime], + ) -> int: + # clip the desired count to the min and max values + return min(max(current_desired_count, self.min_replicas), self.max_replicas) class RPSAutoscaler(BaseServiceScaler): @@ -72,40 +80,43 @@ def __init__( self.scale_up_delay = scale_up_delay 
self.scale_down_delay = scale_down_delay - def scale(self, replicas: List[ReplicaInfo], stats: Optional[PerWindowStats]) -> int: + def get_desired_count( + self, + current_desired_count: int, + stats: Optional[PerWindowStats], + last_scaled_at: Optional[datetime.datetime], + ) -> int: if not stats: - return 0 + return current_desired_count now = common_utils.get_current_datetime() - active_replicas = [r for r in replicas if r.active] - last_scaled_at = max((r.timestamp for r in replicas), default=None) # calculate the average RPS over the last minute rps = stats[60].requests / 60 - target_replicas = math.ceil(rps / self.target) - # clip the target replicas to the min and max values - target_replicas = min(max(target_replicas, self.min_replicas), self.max_replicas) + new_desired_count = math.ceil(rps / self.target) + # clip the desired count to the min and max values + new_desired_count = min(max(new_desired_count, self.min_replicas), self.max_replicas) - if target_replicas > len(active_replicas): - if len(active_replicas) == 0: + if new_desired_count > current_desired_count: + if current_desired_count == 0: # no replicas, scale up immediately - return target_replicas + return new_desired_count if ( last_scaled_at is not None and (now - last_scaled_at).total_seconds() < self.scale_up_delay ): # too early to scale up, wait for the delay - return 0 - return target_replicas - len(active_replicas) - elif target_replicas < len(active_replicas): + return current_desired_count + return new_desired_count + elif new_desired_count < current_desired_count: if ( last_scaled_at is not None and (now - last_scaled_at).total_seconds() < self.scale_down_delay ): # too early to scale down, wait for the delay - return 0 - return target_replicas - len(active_replicas) - return 0 + return current_desired_count + return new_desired_count + return new_desired_count def get_service_scaler(conf: ServiceConfiguration) -> BaseServiceScaler: diff --git 
a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 15cf5ffb31..31045a036b 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -265,6 +265,7 @@ async def create_run( run_id: Optional[UUID] = None, deleted: bool = False, priority: int = 0, + deployment_num: int = 0, ) -> RunModel: if run_spec is None: run_spec = get_run_spec( @@ -286,6 +287,8 @@ async def create_run( last_processed_at=submitted_at, jobs=[], priority=priority, + deployment_num=deployment_num, + desired_replica_count=1, ) session.add(run) await session.commit() @@ -305,9 +308,12 @@ async def create_job( instance: Optional[InstanceModel] = None, job_num: int = 0, replica_num: int = 0, + deployment_num: Optional[int] = None, instance_assigned: bool = False, disconnected_at: Optional[datetime] = None, ) -> JobModel: + if deployment_num is None: + deployment_num = run.deployment_num run_spec = RunSpec.parse_raw(run.run_spec) job_spec = (await get_job_specs_from_run_spec(run_spec, replica_num=replica_num))[0] job_spec.job_num = job_num @@ -318,6 +324,7 @@ async def create_job( job_num=job_num, job_name=run.run_name + f"-{job_num}-{replica_num}", replica_num=replica_num, + deployment_num=deployment_num, submission_num=submission_num, submitted_at=submitted_at, last_processed_at=last_processed_at, diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 8ce0b590ed..9069154705 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -31,6 +31,7 @@ from dstack._internal.core.models.runs import ( Job, JobSpec, + JobStatus, RunPlan, RunSpec, RunStatus, @@ -184,7 +185,7 @@ def logs( self, start_time: Optional[datetime] = None, diagnose: bool = False, - replica_num: int = 0, + replica_num: Optional[int] = None, job_num: int = 0, ) -> Iterable[bytes]: """ @@ -246,7 +247,7 @@ def attach( ssh_identity_file: Optional[PathLike] = None, bind_address: 
Optional[str] = None, ports_overrides: Optional[List[PortMapping]] = None, - replica_num: int = 0, + replica_num: Optional[int] = None, job_num: int = 0, ) -> bool: """ @@ -254,6 +255,7 @@ def attach( Args: ssh_identity_file: SSH keypair to access instances. + replica_num: replica_num or None to attach to any running replica. Raises: dstack.api.PortUsedError: If ports are in use or the run is attached by another process. @@ -265,7 +267,9 @@ def attach( job = self._find_job(replica_num=replica_num, job_num=job_num) if job is None: - raise ClientError(f"Failed to find replica={replica_num} job={job_num}") + replica_repr = replica_num if replica_num is not None else "" + raise ClientError(f"Failed to find replica={replica_repr} job={job_num}") + replica_num = job.job_spec.replica_num name = self.name if replica_num != 0 or job_num != 0: @@ -358,9 +362,14 @@ def detach(self): self._ssh_attach.detach() self._ssh_attach = None - def _find_job(self, replica_num: int, job_num: int) -> Optional[Job]: + def _find_job(self, replica_num: Optional[int], job_num: int) -> Optional[Job]: for j in self._run.jobs: - if j.job_spec.replica_num == replica_num and j.job_spec.job_num == job_num: + if ( + replica_num is not None + and j.job_spec.replica_num == replica_num + or replica_num is None + and j.job_submissions[-1].status == JobStatus.RUNNING + ) and j.job_spec.job_num == job_num: return j return None diff --git a/src/tests/_internal/server/background/tasks/test_process_runs.py b/src/tests/_internal/server/background/tasks/test_process_runs.py index bbef75b764..d616494b64 100644 --- a/src/tests/_internal/server/background/tasks/test_process_runs.py +++ b/src/tests/_internal/server/background/tasks/test_process_runs.py @@ -1,5 +1,5 @@ import datetime -from typing import Union +from typing import Union, cast from unittest.mock import patch import pytest @@ -12,8 +12,10 @@ from dstack._internal.core.models.profiles import Profile from dstack._internal.core.models.resources import 
Range from dstack._internal.core.models.runs import ( + JobSpec, JobStatus, JobTerminationReason, + RunSpec, RunStatus, RunTerminationReason, ) @@ -33,7 +35,11 @@ async def make_run( - session: AsyncSession, status: RunStatus = RunStatus.SUBMITTED, replicas: Union[str, int] = 1 + session: AsyncSession, + status: RunStatus = RunStatus.SUBMITTED, + replicas: Union[str, int] = 1, + deployment_num: int = 0, + image: str = "ubuntu:latest", ) -> RunModel: project = await create_project(session=session) user = await create_user(session=session) @@ -54,6 +60,7 @@ async def make_run( commands=["echo hello"], port=8000, replicas=parse_obj_as(Range[int], replicas), + image=image, ), ) run = await create_run( @@ -64,6 +71,7 @@ async def make_run( run_name=run_name, run_spec=run_spec, status=status, + deployment_num=deployment_num, ) run.project = project return run @@ -336,5 +344,419 @@ async def test_pending_to_submitted_adds_replicas(self, test_db, session: AsyncS assert run.jobs[2].replica_num == 1 +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestRollingDeployment: + @pytest.mark.parametrize( + ("run_status", "job_statuses"), + [ + (RunStatus.RUNNING, (JobStatus.RUNNING, JobStatus.RUNNING)), + (RunStatus.RUNNING, (JobStatus.RUNNING, JobStatus.PULLING)), + (RunStatus.PROVISIONING, (JobStatus.PROVISIONING, JobStatus.PULLING)), + (RunStatus.PROVISIONING, (JobStatus.PROVISIONING, JobStatus.PROVISIONING)), + ], + ) + async def test_updates_deployment_num_in_place( + self, + test_db, + session: AsyncSession, + run_status: RunStatus, + job_statuses: tuple[JobStatus, JobStatus], + ) -> None: + run = await make_run(session, status=run_status, replicas=2, deployment_num=1) + for replica_num, job_status in enumerate(job_statuses): + await create_job( + session=session, + run=run, + status=job_status, + replica_num=replica_num, + deployment_num=0, # out of date + ) + + await process_runs.process_runs() + await 
session.refresh(run) + assert run.status == run_status + assert len(run.jobs) == 2 + assert run.jobs[0].status == job_statuses[0] + assert run.jobs[0].replica_num == 0 + assert run.jobs[0].deployment_num == 1 # updated + assert run.jobs[1].status == job_statuses[1] + assert run.jobs[1].replica_num == 1 + assert run.jobs[1].deployment_num == 1 # updated + + async def test_not_updates_deployment_num_in_place_for_finished_replica( + self, test_db, session: AsyncSession + ) -> None: + run = await make_run(session, status=RunStatus.RUNNING, deployment_num=1) + await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=0, + deployment_num=0, # out of date + ) + await create_job( + session=session, + run=run, + status=JobStatus.TERMINATED, + termination_reason=JobTerminationReason.SCALED_DOWN, + replica_num=1, + deployment_num=0, # out of date + ) + + await process_runs.process_runs() + await session.refresh(run) + assert run.status == RunStatus.RUNNING + assert len(run.jobs) == 2 + assert run.jobs[0].status == JobStatus.RUNNING + assert run.jobs[0].replica_num == 0 + assert run.jobs[0].deployment_num == 1 # updated + assert run.jobs[1].status == JobStatus.TERMINATED + assert run.jobs[1].replica_num == 1 + assert run.jobs[1].deployment_num == 0 # not updated + + async def test_starts_new_replica(self, test_db, session: AsyncSession) -> None: + run = await make_run(session, status=RunStatus.RUNNING, replicas=2, image="old") + for replica_num in range(2): + await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=replica_num, + ) + + run_spec: RunSpec = RunSpec.__response__.parse_raw(run.run_spec) + assert isinstance(run_spec.configuration, ServiceConfiguration) + run_spec.configuration.image = "new" + run.run_spec = run_spec.json() + run.deployment_num += 1 + await session.commit() + + await process_runs.process_runs() + await session.refresh(run) + assert run.status == RunStatus.RUNNING + assert len(run.jobs) 
== 3 + # old replicas remain as-is + for replica_num in range(2): + assert run.jobs[replica_num].status == JobStatus.RUNNING + assert run.jobs[replica_num].replica_num == replica_num + assert run.jobs[replica_num].deployment_num == 0 + assert ( + cast( + JobSpec, JobSpec.__response__.parse_raw(run.jobs[replica_num].job_spec_data) + ).image_name + == "old" + ) + # an extra replica is submitted + assert run.jobs[2].status == JobStatus.SUBMITTED + assert run.jobs[2].replica_num == 2 + assert run.jobs[2].deployment_num == 1 + assert ( + cast(JobSpec, JobSpec.__response__.parse_raw(run.jobs[2].job_spec_data)).image_name + == "new" + ) + + @pytest.mark.parametrize( + "new_replica_status", [JobStatus.SUBMITTED, JobStatus.PROVISIONING, JobStatus.PULLING] + ) + async def test_not_stops_out_of_date_replica_until_new_replica_is_running( + self, test_db, session: AsyncSession, new_replica_status: JobStatus + ) -> None: + run = await make_run(session, status=RunStatus.RUNNING, replicas=2, image="old") + for replica_num in range(2): + await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=replica_num, + ) + + run_spec: RunSpec = RunSpec.__response__.parse_raw(run.run_spec) + assert isinstance(run_spec.configuration, ServiceConfiguration) + run_spec.configuration.image = "new" + run.run_spec = run_spec.json() + run.deployment_num += 1 + await create_job( + session=session, + run=run, + status=new_replica_status, + replica_num=2, + ) + await session.commit() + + await process_runs.process_runs() + await session.refresh(run) + assert run.status == RunStatus.RUNNING + assert len(run.jobs) == 3 + # All replicas remain as-is: + # - cannot yet start a new replica - there are already 3 non-terminated replicas + # (3 = 2 desired + 1 max_surge) + # - cannot yet stop an out-of-date replica - that would only leave one running replica, + # which is less than the desired count (2) + for replica_num in range(2): + assert run.jobs[replica_num].status == 
JobStatus.RUNNING + assert run.jobs[replica_num].replica_num == replica_num + assert run.jobs[replica_num].deployment_num == 0 + assert ( + cast( + JobSpec, JobSpec.__response__.parse_raw(run.jobs[replica_num].job_spec_data) + ).image_name + == "old" + ) + assert run.jobs[2].status == new_replica_status + assert run.jobs[2].replica_num == 2 + assert run.jobs[2].deployment_num == 1 + assert ( + cast(JobSpec, JobSpec.__response__.parse_raw(run.jobs[2].job_spec_data)).image_name + == "new" + ) + + async def test_stops_out_of_date_replica(self, test_db, session: AsyncSession) -> None: + run = await make_run(session, status=RunStatus.RUNNING, replicas=2, image="old") + for replica_num in range(2): + await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=replica_num, + ) + + run_spec: RunSpec = RunSpec.__response__.parse_raw(run.run_spec) + assert isinstance(run_spec.configuration, ServiceConfiguration) + run_spec.configuration.image = "new" + run.run_spec = run_spec.json() + run.deployment_num += 1 + await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=2, + ) + await session.commit() + + await process_runs.process_runs() + await session.refresh(run) + assert run.status == RunStatus.RUNNING + assert len(run.jobs) == 3 + # one old replica remains as-is + assert run.jobs[0].status == JobStatus.RUNNING + assert run.jobs[0].replica_num == 0 + assert run.jobs[0].deployment_num == 0 + assert ( + cast(JobSpec, JobSpec.__response__.parse_raw(run.jobs[0].job_spec_data)).image_name + == "old" + ) + # another old replica is terminated + assert run.jobs[1].status == JobStatus.TERMINATING + assert run.jobs[1].termination_reason == JobTerminationReason.SCALED_DOWN + assert run.jobs[1].replica_num == 1 + assert run.jobs[1].deployment_num == 0 + assert ( + cast(JobSpec, JobSpec.__response__.parse_raw(run.jobs[1].job_spec_data)).image_name + == "old" + ) + # the new replica remains as-is + assert run.jobs[2].status 
== JobStatus.RUNNING + assert run.jobs[2].replica_num == 2 + assert run.jobs[2].deployment_num == 1 + assert ( + cast(JobSpec, JobSpec.__response__.parse_raw(run.jobs[2].job_spec_data)).image_name + == "new" + ) + + async def test_not_starts_new_replica_until_out_of_date_replica_terminated( + self, test_db, session: AsyncSession + ) -> None: + run = await make_run(session, status=RunStatus.RUNNING, replicas=2, image="old") + await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=0, + ) + await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.SCALED_DOWN, + replica_num=1, + ) + + run_spec: RunSpec = RunSpec.__response__.parse_raw(run.run_spec) + assert isinstance(run_spec.configuration, ServiceConfiguration) + run_spec.configuration.image = "new" + run.run_spec = run_spec.json() + run.deployment_num += 1 + await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=2, + ) + await session.commit() + + await process_runs.process_runs() + await session.refresh(run) + assert run.status == RunStatus.RUNNING + assert len(run.jobs) == 3 + # All replicas remain as-is: + # - cannot yet start a new replica - there are already 3 non-terminated replicas + # (3 = 2 desired + 1 max_surge) + # - cannot yet stop an out-of-date replica - that would only leave one running replica, + # which is less than the desired count (2) + assert run.jobs[0].status == JobStatus.RUNNING + assert run.jobs[0].replica_num == 0 + assert run.jobs[0].deployment_num == 0 + assert ( + cast(JobSpec, JobSpec.__response__.parse_raw(run.jobs[0].job_spec_data)).image_name + == "old" + ) + assert run.jobs[1].status == JobStatus.TERMINATING + assert run.jobs[1].termination_reason == JobTerminationReason.SCALED_DOWN + assert run.jobs[1].replica_num == 1 + assert run.jobs[1].deployment_num == 0 + assert ( + cast(JobSpec, 
JobSpec.__response__.parse_raw(run.jobs[1].job_spec_data)).image_name + == "old" + ) + assert run.jobs[2].status == JobStatus.RUNNING + assert run.jobs[2].replica_num == 2 + assert run.jobs[2].deployment_num == 1 + assert ( + cast(JobSpec, JobSpec.__response__.parse_raw(run.jobs[2].job_spec_data)).image_name + == "new" + ) + + async def test_reuses_vacant_replica_num_when_starting_new_replica( + self, test_db, session: AsyncSession + ) -> None: + run = await make_run(session, status=RunStatus.RUNNING, replicas=2, image="old") + await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=0, + ) + await create_job( + session=session, + run=run, + status=JobStatus.TERMINATED, + termination_reason=JobTerminationReason.SCALED_DOWN, + replica_num=1, + ) + + run_spec: RunSpec = RunSpec.__response__.parse_raw(run.run_spec) + assert isinstance(run_spec.configuration, ServiceConfiguration) + run_spec.configuration.image = "new" + run.run_spec = run_spec.json() + run.deployment_num += 1 + await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=2, + ) + await session.commit() + + await process_runs.process_runs() + await session.refresh(run) + run.jobs.sort(key=lambda j: (j.replica_num, j.submission_num)) + assert run.status == RunStatus.RUNNING + assert len(run.jobs) == 4 # 3 active submissions, 1 terminated submission + # The running old replica remains as-is + assert run.jobs[0].status == JobStatus.RUNNING + assert run.jobs[0].replica_num == 0 + assert run.jobs[0].deployment_num == 0 + assert ( + cast(JobSpec, JobSpec.__response__.parse_raw(run.jobs[0].job_spec_data)).image_name + == "old" + ) + # The terminated old replica remains as-is + assert run.jobs[1].status == JobStatus.TERMINATED + assert run.jobs[1].termination_reason == JobTerminationReason.SCALED_DOWN + assert run.jobs[1].replica_num == 1 + assert run.jobs[1].deployment_num == 0 + assert run.jobs[1].submission_num == 0 + assert ( + cast(JobSpec, 
JobSpec.__response__.parse_raw(run.jobs[1].job_spec_data)).image_name + == "old" + ) + # The replica_num of the terminated old replica (1) is reused for the new replica + assert run.jobs[2].status == JobStatus.SUBMITTED + assert run.jobs[2].replica_num == 1 + assert run.jobs[2].deployment_num == 1 + assert run.jobs[2].submission_num == 1 + assert ( + cast(JobSpec, JobSpec.__response__.parse_raw(run.jobs[2].job_spec_data)).image_name + == "new" + ) + # The running new replica remains as-is + assert run.jobs[3].status == JobStatus.RUNNING + assert run.jobs[3].replica_num == 2 + assert run.jobs[3].deployment_num == 1 + assert ( + cast(JobSpec, JobSpec.__response__.parse_raw(run.jobs[3].job_spec_data)).image_name + == "new" + ) + + @pytest.mark.parametrize( + "new_replica_status", [JobStatus.SUBMITTED, JobStatus.PROVISIONING, JobStatus.PULLING] + ) + async def test_stops_non_running_out_of_date_replicas_unconditionally( + self, test_db, session: AsyncSession, new_replica_status: JobStatus + ) -> None: + run = await make_run(session, status=RunStatus.PROVISIONING, replicas=2, image="old") + for replica_num in range(2): + await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + replica_num=replica_num, + ) + + run_spec: RunSpec = RunSpec.__response__.parse_raw(run.run_spec) + assert isinstance(run_spec.configuration, ServiceConfiguration) + run_spec.configuration.image = "new" + run.run_spec = run_spec.json() + run.deployment_num += 1 + await create_job( + session=session, + run=run, + status=new_replica_status, + replica_num=2, + ) + await session.commit() + + await process_runs.process_runs() + await session.refresh(run) + assert run.status == RunStatus.PROVISIONING + assert len(run.jobs) == 3 + # The two out of date replicas transition from pulling to terminating immediately. + # No need to keep these replicas - they don't contribute to reaching the desired count. 
+ assert run.jobs[0].status == JobStatus.TERMINATING + assert run.jobs[0].replica_num == 0 + assert run.jobs[0].deployment_num == 0 + assert ( + cast(JobSpec, JobSpec.__response__.parse_raw(run.jobs[0].job_spec_data)).image_name + == "old" + ) + assert run.jobs[1].status == JobStatus.TERMINATING + assert run.jobs[1].termination_reason == JobTerminationReason.SCALED_DOWN + assert run.jobs[1].replica_num == 1 + assert run.jobs[1].deployment_num == 0 + assert ( + cast(JobSpec, JobSpec.__response__.parse_raw(run.jobs[1].job_spec_data)).image_name + == "old" + ) + # The new replica remains as-is + assert run.jobs[2].status == new_replica_status + assert run.jobs[2].replica_num == 2 + assert run.jobs[2].deployment_num == 1 + assert ( + cast(JobSpec, JobSpec.__response__.parse_raw(run.jobs[2].job_spec_data)).image_name + == "new" + ) + + # TODO(egor-s): TestProcessRunsMultiNode # TODO(egor-s): TestProcessRunsAutoScaling diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 3c2181a209..65c4fcd131 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -385,6 +385,7 @@ def get_dev_env_run_dict( { "id": job_id, "submission_num": 0, + "deployment_num": 0, "submitted_at": submitted_at, "last_processed_at": last_processed_at, "finished_at": finished_at, @@ -404,6 +405,7 @@ def get_dev_env_run_dict( "latest_job_submission": { "id": job_id, "submission_num": 0, + "deployment_num": 0, "submitted_at": submitted_at, "last_processed_at": last_processed_at, "inactivity_secs": None, @@ -419,6 +421,7 @@ def get_dev_env_run_dict( }, "cost": 0.0, "service": None, + "deployment_num": 0, "termination_reason": None, "error": None, "deleted": deleted, @@ -520,6 +523,7 @@ async def test_lists_runs(self, test_db, session: AsyncSession, client: AsyncCli { "id": str(job.id), "submission_num": 0, + "deployment_num": 0, "submitted_at": run1_submitted_at.isoformat(), 
"last_processed_at": run1_submitted_at.isoformat(), "finished_at": None, @@ -539,6 +543,7 @@ async def test_lists_runs(self, test_db, session: AsyncSession, client: AsyncCli "latest_job_submission": { "id": str(job.id), "submission_num": 0, + "deployment_num": 0, "submitted_at": run1_submitted_at.isoformat(), "last_processed_at": run1_submitted_at.isoformat(), "finished_at": None, @@ -554,6 +559,7 @@ async def test_lists_runs(self, test_db, session: AsyncSession, client: AsyncCli }, "cost": 0, "service": None, + "deployment_num": 0, "termination_reason": None, "error": None, "deleted": False, @@ -571,6 +577,7 @@ async def test_lists_runs(self, test_db, session: AsyncSession, client: AsyncCli "latest_job_submission": None, "cost": 0, "service": None, + "deployment_num": 0, "termination_reason": None, "error": None, "deleted": False, @@ -1125,6 +1132,9 @@ async def test_updates_run(self, test_db, session: AsyncSession, client: AsyncCl assert response.status_code == 200, response.json() await session.refresh(run_model) updated_run = run_model_to_run(run_model) + assert run.deployment_num == 0 + assert updated_run.deployment_num == 1 + assert run.run_spec.configuration.replicas == Range(min=1, max=1) assert updated_run.run_spec.configuration.replicas == Range(min=2, max=2) diff --git a/src/tests/_internal/server/services/services/test_autoscalers.py b/src/tests/_internal/server/services/services/test_autoscalers.py index 1e65b68425..5df80b0c1e 100644 --- a/src/tests/_internal/server/services/services/test_autoscalers.py +++ b/src/tests/_internal/server/services/services/test_autoscalers.py @@ -4,7 +4,7 @@ import pytest from dstack._internal.proxy.gateway.schemas.stats import PerWindowStats, Stat -from dstack._internal.server.services.services.autoscalers import ReplicaInfo, RPSAutoscaler +from dstack._internal.server.services.services.autoscalers import BaseServiceScaler, RPSAutoscaler @pytest.fixture @@ -24,113 +24,118 @@ def stats(rps: float) -> PerWindowStats: 
return {60: Stat(requests=int(rps * 60), request_time=0.1)} -def replica(time: datetime.datetime, active: bool = True, timestamp: int = -3600) -> ReplicaInfo: - return ReplicaInfo( - active=active, - timestamp=time + datetime.timedelta(seconds=timestamp), - ) - - class TestRPSAutoscaler: - def test_do_not_scale(self, rps_scaler, time): - assert rps_scaler.scale([replica(time, active=True)], stats(rps=10)) == 0 - - def test_scale_up(self, rps_scaler, time): - assert rps_scaler.scale([replica(time, active=True)], stats(rps=20)) == 1 + def test_do_not_scale(self, rps_scaler: BaseServiceScaler, time: datetime.datetime) -> None: + assert ( + rps_scaler.get_desired_count( + current_desired_count=1, + stats=stats(rps=10), + last_scaled_at=time - datetime.timedelta(seconds=3600), + ) + == 1 + ) - def test_scale_up_high_load(self, rps_scaler, time): + def test_scale_up(self, rps_scaler: BaseServiceScaler, time: datetime.datetime) -> None: assert ( - rps_scaler.scale( - [ - replica(time, active=True), - replica(time, active=True), - ], - stats(rps=50), + rps_scaler.get_desired_count( + current_desired_count=1, + stats=stats(rps=20), + last_scaled_at=time - datetime.timedelta(seconds=3600), ) - == 3 + == 2 ) - def test_scale_up_replicas_limit(self, rps_scaler, time): + def test_scale_up_high_load( + self, rps_scaler: BaseServiceScaler, time: datetime.datetime + ) -> None: assert ( - rps_scaler.scale( - [ - replica(time, active=True), - replica(time, active=True), - ], - stats(rps=1000), + rps_scaler.get_desired_count( + current_desired_count=2, + stats=stats(rps=50), + last_scaled_at=time - datetime.timedelta(seconds=3600), ) - == 3 + == 5 ) - def test_scale_down(self, rps_scaler, time): + def test_scale_up_replicas_limit( + self, rps_scaler: BaseServiceScaler, time: datetime.datetime + ) -> None: assert ( - rps_scaler.scale( - [replica(time, active=True), replica(time, active=True)], stats(rps=5) + rps_scaler.get_desired_count( + current_desired_count=2, + 
stats=stats(rps=1000), + last_scaled_at=time - datetime.timedelta(seconds=3600), ) - == -1 + == 5 ) - def test_scale_up_delayed_running(self, rps_scaler, time): + def test_scale_down(self, rps_scaler: BaseServiceScaler, time: datetime.datetime) -> None: assert ( - rps_scaler.scale( - [ - # submitted 1 minute ago, but the delay is 5 minutes - replica(time, active=True, timestamp=-60), - ], - stats(rps=20), + rps_scaler.get_desired_count( + current_desired_count=2, + stats=stats(rps=5), + last_scaled_at=time - datetime.timedelta(seconds=3600), ) - == 0 + == 1 ) - def test_scale_up_delayed_terminated(self, rps_scaler, time): + def test_scale_up_delayed( + self, rps_scaler: BaseServiceScaler, time: datetime.datetime + ) -> None: assert ( - rps_scaler.scale( - [ - replica(time, active=True), - # terminated 1 minute ago, but the delay is 5 minutes - replica(time, active=False, timestamp=-60), - ], - stats(rps=20), + rps_scaler.get_desired_count( + current_desired_count=1, + stats=stats(rps=20), + # last scaled 1 minute ago, but the delay is 5 minutes + last_scaled_at=time - datetime.timedelta(seconds=60), ) - == 0 + == 1 ) - def test_scale_down_delayed(self, rps_scaler, time): + def test_scale_down_delayed( + self, rps_scaler: BaseServiceScaler, time: datetime.datetime + ) -> None: assert ( - rps_scaler.scale( - [ - replica(time, active=True), - # submitted 5 minutes ago, but the delay is 10 minutes - replica(time, active=True, timestamp=-5 * 60), - ], - stats(rps=5), + rps_scaler.get_desired_count( + current_desired_count=2, + stats=stats(rps=5), + # last scaled 5 minutes ago, but the delay is 10 minutes + last_scaled_at=time - datetime.timedelta(seconds=5 * 60), ) - == 0 + == 2 ) - def test_scale_from_zero_immediately(self, rps_scaler, time): - assert rps_scaler.scale([], stats(rps=5)) == 1 + def test_scale_from_zero_first_time( + self, rps_scaler: BaseServiceScaler, time: datetime.datetime + ) -> None: + assert ( + rps_scaler.get_desired_count( + 
current_desired_count=0, + stats=stats(rps=5), + last_scaled_at=None, + ) + == 1 + ) - def test_scale_from_zero_immediately_terminated(self, rps_scaler, time): + def test_scale_from_zero_immediately( + self, rps_scaler: BaseServiceScaler, time: datetime.datetime + ) -> None: assert ( - rps_scaler.scale( - [ - # terminated 1 minute ago, but there are requests - replica(time, active=False, timestamp=-60), - ], - stats(rps=5), + rps_scaler.get_desired_count( + current_desired_count=0, + stats=stats(rps=5), + # last scaled 1 second ago, but there are requests + last_scaled_at=time - datetime.timedelta(seconds=1), ) == 1 ) - def test_scale_to_zero(self, rps_scaler, time): + def test_scale_to_zero(self, rps_scaler: BaseServiceScaler, time: datetime.datetime) -> None: assert ( - rps_scaler.scale( - [ - replica(time, active=True), - replica(time, active=True), - ], - stats(rps=0), + rps_scaler.get_desired_count( + current_desired_count=2, + stats=stats(rps=0), + last_scaled_at=time - datetime.timedelta(seconds=3600), ) - == -2 + == 0 ) diff --git a/src/tests/_internal/server/services/test_runs.py b/src/tests/_internal/server/services/test_runs.py index 0782e3b8b0..7d3f2f4595 100644 --- a/src/tests/_internal/server/services/test_runs.py +++ b/src/tests/_internal/server/services/test_runs.py @@ -4,7 +4,7 @@ from pydantic import parse_obj_as from sqlalchemy.ext.asyncio import AsyncSession -from dstack._internal.core.errors import ServerClientError, ServerError +from dstack._internal.core.errors import ServerClientError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import ScalingSpec, ServiceConfiguration from dstack._internal.core.models.profiles import Profile @@ -176,32 +176,6 @@ async def test_downscale_greater_replica_num(self, test_db, session: AsyncSessio assert run.jobs[1].status == JobStatus.TERMINATING assert run.jobs[1].termination_reason == JobTerminationReason.SCALED_DOWN - @pytest.mark.asyncio 
- @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_no_downscale_below_limit(self, test_db, session: AsyncSession): - run = await make_run( - session, - [ - JobStatus.RUNNING, - ], - replicas="1..2", - ) - with pytest.raises(ServerError): - await scale_wrapper(session, run, -1) - - @pytest.mark.asyncio - @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_no_upscale_above_limit(self, test_db, session: AsyncSession): - run = await make_run( - session, - [ - JobStatus.RUNNING, - ], - replicas="0..1", - ) - with pytest.raises(ServerError): - await scale_wrapper(session, run, 1) - @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_upscale_mixed(self, test_db, session: AsyncSession):