From fe41d1867c732d98fc895d2ed213d56b923db57c Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Sun, 29 Jun 2025 18:21:46 +0200 Subject: [PATCH] Rolling deployments for repo updates - Support rolling deployments for services when the run repo is updated: new commits are added, the branch is changed, uncommitted files are updated, etc. Switching from one repo to another is not yet supported. > **Note**: A side effect of this change is that if the run configuration file is stored in the same repo that is used for the run, any changes to the configuration file will also be considered a change to the repo (`repo_code_hash`), and hence require a rolling deployment for services or a full restart for tasks and dev environments. Previously, this was not the case, since changes to `repo_code_hash` were ignored for existing jobs. This new behavior makes it more difficult to avoid redeployment when changing some configuration properties, namely `priority`, `inactivity_duration`, `replicas`, and `scaling`. However, we consider this acceptable, since changing these properties in-place is an advanced use case and can still be achieved by moving the configuration file out of the repo. - Improve run plan output in `dstack apply` when attempting an in-place update: - Show not only the list of changed configuration properties, but also other changes from the run spec, such as repo-related changes. ```shell $ dstack apply -f test-service.dstack.yml Active run test-service already exists. Detected changes that can be updated in-place: - Repo state (branch, commit, or other) - Repo files - Configuration properties: - env Update the run? [y/n]: ``` - Show the list of changes not only when an in-place update is possible, but also when it is not. This will help users understand why a run cannot be updated in-place. ```shell $ dstack apply -f test-service.dstack.yml Active run test-service already exists. 
Detected changes that cannot be updated in-place: - Repo files - Configuration properties: - gateway Stop and override the run? [y/n]: ``` Currently, all detected changes are listed together. An area for future improvement is highlighting the changes that prevent an in-place update. --- runner/internal/executor/executor.go | 8 ++ runner/internal/executor/executor_test.go | 5 +- runner/internal/executor/repo.go | 10 +- runner/internal/schemas/schemas.go | 4 + .../cli/services/configurators/run.py | 71 +++++++++++--- .../_internal/core/compatibility/runs.py | 25 ++++- src/dstack/_internal/core/models/runs.py | 8 ++ .../background/tasks/process_running_jobs.py | 17 +++- src/dstack/_internal/server/schemas/runner.py | 1 + .../services/jobs/configurators/base.py | 2 + src/dstack/_internal/server/services/runs.py | 16 +++- src/dstack/_internal/server/testing/common.py | 10 +- src/dstack/_internal/utils/nested_list.py | 47 ++++++++++ src/dstack/api/_public/runs.py | 30 ++++-- .../cli/services/configurators/test_run.py | 94 ++++++++++++++++++- .../_internal/server/routers/test_runs.py | 4 + src/tests/_internal/utils/test_nested_list.py | 56 +++++++++++ 17 files changed, 367 insertions(+), 41 deletions(-) create mode 100644 src/dstack/_internal/utils/nested_list.py create mode 100644 src/tests/_internal/utils/test_nested_list.py diff --git a/runner/internal/executor/executor.go b/runner/internal/executor/executor.go index 14f46d035c..c720eb02e9 100644 --- a/runner/internal/executor/executor.go +++ b/runner/internal/executor/executor.go @@ -276,6 +276,14 @@ func (ex *RunExecutor) SetRunnerState(state string) { ex.state = state } +func (ex *RunExecutor) getRepoData() schemas.RepoData { + if ex.jobSpec.RepoData == nil { + // jobs submitted before 0.19.17 do not have jobSpec.RepoData + return ex.run.RunSpec.RepoData + } + return *ex.jobSpec.RepoData +} + func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error { node_rank := ex.jobSpec.JobNum nodes_num := 
ex.jobSpec.JobsPerReplica diff --git a/runner/internal/executor/executor_test.go b/runner/internal/executor/executor_test.go index 2165521bff..8d275b1375 100644 --- a/runner/internal/executor/executor_test.go +++ b/runner/internal/executor/executor_test.go @@ -132,7 +132,7 @@ func TestExecutor_RemoteRepo(t *testing.T) { var b bytes.Buffer ex := makeTestExecutor(t) - ex.run.RunSpec.RepoData = schemas.RepoData{ + ex.jobSpec.RepoData = &schemas.RepoData{ RepoType: "remote", RepoBranch: "main", RepoHash: "2b83592e506ed6fe8e49f4eaa97c3866bc9402b1", @@ -148,7 +148,7 @@ func TestExecutor_RemoteRepo(t *testing.T) { err = ex.execJob(context.TODO(), io.Writer(&b)) assert.NoError(t, err) - expected := fmt.Sprintf("%s\r\n%s\r\n%s\r\n", ex.run.RunSpec.RepoData.RepoHash, ex.run.RunSpec.RepoData.RepoConfigName, ex.run.RunSpec.RepoData.RepoConfigEmail) + expected := fmt.Sprintf("%s\r\n%s\r\n%s\r\n", ex.getRepoData().RepoHash, ex.getRepoData().RepoConfigName, ex.getRepoData().RepoConfigEmail) assert.Equal(t, expected, b.String()) } @@ -178,6 +178,7 @@ func makeTestExecutor(t *testing.T) *RunExecutor { Env: make(map[string]string), MaxDuration: 0, // no timeout WorkingDir: &workingDir, + RepoData: &schemas.RepoData{RepoType: "local"}, }, Secrets: make(map[string]string), RepoCredentials: &schemas.RepoCredentials{ diff --git a/runner/internal/executor/repo.go b/runner/internal/executor/repo.go index 5afbb2515a..4e6271f29e 100644 --- a/runner/internal/executor/repo.go +++ b/runner/internal/executor/repo.go @@ -40,7 +40,7 @@ func (ex *RunExecutor) setupRepo(ctx context.Context) error { err = gerrors.Wrap(err_) } }() - switch ex.run.RunSpec.RepoData.RepoType { + switch ex.getRepoData().RepoType { case "remote": log.Trace(ctx, "Fetching git repository") if err := ex.prepareGit(ctx); err != nil { @@ -52,7 +52,7 @@ func (ex *RunExecutor) setupRepo(ctx context.Context) error { return gerrors.Wrap(err) } default: - return gerrors.Newf("unknown RepoType: %s", ex.run.RunSpec.RepoData.RepoType) 
+ return gerrors.Newf("unknown RepoType: %s", ex.getRepoData().RepoType) } return err } @@ -61,8 +61,8 @@ func (ex *RunExecutor) prepareGit(ctx context.Context) error { repoManager := repo.NewManager( ctx, ex.repoCredentials.CloneURL, - ex.run.RunSpec.RepoData.RepoBranch, - ex.run.RunSpec.RepoData.RepoHash, + ex.getRepoData().RepoBranch, + ex.getRepoData().RepoHash, ex.jobSpec.SingleBranch, ).WithLocalPath(ex.workingDir) if ex.repoCredentials != nil { @@ -92,7 +92,7 @@ func (ex *RunExecutor) prepareGit(ctx context.Context) error { if err := repoManager.Checkout(); err != nil { return gerrors.Wrap(err) } - if err := repoManager.SetConfig(ex.run.RunSpec.RepoData.RepoConfigName, ex.run.RunSpec.RepoData.RepoConfigEmail); err != nil { + if err := repoManager.SetConfig(ex.getRepoData().RepoConfigName, ex.getRepoData().RepoConfigEmail); err != nil { return gerrors.Wrap(err) } diff --git a/runner/internal/schemas/schemas.go b/runner/internal/schemas/schemas.go index 8d4cef2f1c..e3b950639c 100644 --- a/runner/internal/schemas/schemas.go +++ b/runner/internal/schemas/schemas.go @@ -68,6 +68,10 @@ type JobSpec struct { MaxDuration int `json:"max_duration"` SSHKey *SSHKey `json:"ssh_key"` WorkingDir *string `json:"working_dir"` + // `RepoData` is optional for compatibility with jobs submitted before 0.19.17. + // Use `RunExecutor.getRepoData()` to get non-nil `RepoData`. + // TODO: make required when supporting jobs submitted before 0.19.17 is no longer relevant. 
+ RepoData *RepoData `json:"repo_data"` } type ClusterInfo struct { diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index b23d227a5a..7adc3b90be 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -41,12 +41,13 @@ ) from dstack._internal.core.models.repos.base import Repo from dstack._internal.core.models.resources import CPUSpec -from dstack._internal.core.models.runs import JobStatus, JobSubmission, RunStatus +from dstack._internal.core.models.runs import JobStatus, JobSubmission, RunSpec, RunStatus from dstack._internal.core.services.configs import ConfigManager from dstack._internal.core.services.diff import diff_models from dstack._internal.utils.common import local_time from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator from dstack._internal.utils.logging import get_logger +from dstack._internal.utils.nested_list import NestedList, NestedListItem from dstack.api._public.repos import get_ssh_keypair from dstack.api._public.runs import Run from dstack.api.utils import load_profile @@ -102,25 +103,20 @@ def apply_configuration( confirm_message = f"Submit the run [code]{conf.name}[/]?" stop_run_name = None if run_plan.current_resource is not None: - changed_fields = [] - if run_plan.action == ApplyAction.UPDATE: - diff = diff_models( - run_plan.get_effective_run_spec().configuration, - run_plan.current_resource.run_spec.configuration, - ) - changed_fields = list(diff.keys()) - if run_plan.action == ApplyAction.UPDATE and len(changed_fields) > 0: + diff = render_run_spec_diff( + run_plan.get_effective_run_spec(), + run_plan.current_resource.run_spec, + ) + if run_plan.action == ApplyAction.UPDATE and diff is not None: console.print( f"Active run [code]{conf.name}[/] already exists." 
- " Detected configuration changes that can be updated in-place:" - f" {changed_fields}" + f" Detected changes that [code]can[/] be updated in-place:\n{diff}" ) confirm_message = "Update the run?" - elif run_plan.action == ApplyAction.UPDATE and len(changed_fields) == 0: + elif run_plan.action == ApplyAction.UPDATE and diff is None: stop_run_name = run_plan.current_resource.run_spec.run_name console.print( - f"Active run [code]{conf.name}[/] already exists." - " Detected no configuration changes." + f"Active run [code]{conf.name}[/] already exists. Detected no changes." ) if command_args.yes and not command_args.force: console.print("Use --force to apply anyway.") @@ -129,7 +125,8 @@ def apply_configuration( elif not run_plan.current_resource.status.is_finished(): stop_run_name = run_plan.current_resource.run_spec.run_name console.print( - f"Active run [code]{conf.name}[/] already exists and cannot be updated in-place." + f"Active run [code]{conf.name}[/] already exists." + f" Detected changes that [error]cannot[/] be updated in-place:\n{diff}" ) confirm_message = "Stop and override the run?" 
@@ -611,3 +608,47 @@ def _run_resubmitted(run: Run, current_job_submission: Optional[JobSubmission]) not run.status.is_finished() and run._run.latest_job_submission.submitted_at > current_job_submission.submitted_at ) + + +def render_run_spec_diff(old_spec: RunSpec, new_spec: RunSpec) -> Optional[str]: + changed_spec_fields = list(diff_models(old_spec, new_spec)) + if not changed_spec_fields: + return None + friendly_spec_field_names = { + "repo_id": "Repo ID", + "repo_code_hash": "Repo files", + "repo_data": "Repo state (branch, commit, or other)", + "ssh_key_pub": "Public SSH key", + } + nested_list = NestedList() + for spec_field in changed_spec_fields: + if spec_field == "merged_profile": + continue + elif spec_field == "configuration": + if type(old_spec.configuration) is not type(new_spec.configuration): + item = NestedListItem("Configuration type") + else: + item = NestedListItem( + "Configuration properties:", + children=[ + NestedListItem(field) + for field in diff_models(old_spec.configuration, new_spec.configuration) + ], + ) + elif spec_field == "profile": + if type(old_spec.profile) is not type(new_spec.profile): + item = NestedListItem("Profile") + else: + item = NestedListItem( + "Profile properties:", + children=[ + NestedListItem(field) + for field in diff_models(old_spec.profile, new_spec.profile) + ], + ) + elif spec_field in friendly_spec_field_names: + item = NestedListItem(friendly_spec_field_names[spec_field]) + else: + item = NestedListItem(spec_field.replace("_", " ").capitalize()) + nested_list.children.append(item) + return nested_list.render() diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py index 696cd7f025..172ba879b4 100644 --- a/src/dstack/_internal/core/compatibility/runs.py +++ b/src/dstack/_internal/core/compatibility/runs.py @@ -1,7 +1,7 @@ from typing import Any, Dict, Optional from dstack._internal.core.models.configurations import ServiceConfiguration -from 
dstack._internal.core.models.runs import ApplyRunPlanInput, JobSubmission, RunSpec +from dstack._internal.core.models.runs import ApplyRunPlanInput, JobSpec, JobSubmission, RunSpec from dstack._internal.server.schemas.runs import GetRunPlanRequest @@ -25,7 +25,10 @@ def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[Dict]: current_resource_excludes["run_spec"] = get_run_spec_excludes(current_resource.run_spec) job_submissions_excludes = {} current_resource_excludes["jobs"] = { - "__all__": {"job_submissions": {"__all__": job_submissions_excludes}} + "__all__": { + "job_spec": get_job_spec_excludes([job.job_spec for job in current_resource.jobs]), + "job_submissions": {"__all__": job_submissions_excludes}, + } } job_submissions = [js for j in current_resource.jobs for js in j.job_submissions] if all(map(_should_exclude_job_submission_jpd_cpu_arch, job_submissions)): @@ -123,6 +126,24 @@ def get_run_spec_excludes(run_spec: RunSpec) -> Optional[Dict]: return None +def get_job_spec_excludes(job_specs: list[JobSpec]) -> Optional[dict]: + """ + Returns `job_spec` exclude mapping to exclude certain fields from the request. + Use this method to exclude new fields when they are not set to keep + clients backward-compatible with older servers. 
+ """ + spec_excludes: dict[str, Any] = {} + + if all(s.repo_code_hash is None for s in job_specs): + spec_excludes["repo_code_hash"] = True + if all(s.repo_data is None for s in job_specs): + spec_excludes["repo_data"] = True + + if spec_excludes: + return spec_excludes + return None + + def _should_exclude_job_submission_jpd_cpu_arch(job_submission: JobSubmission) -> bool: try: return job_submission.job_provisioning_data.instance_type.resources.cpu_arch is None diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index df53247338..6cf80b9268 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -218,6 +218,14 @@ class JobSpec(CoreModel): volumes: Optional[List[MountPoint]] = None ssh_key: Optional[JobSSHKey] = None working_dir: Optional[str] + # `repo_data` is optional for client compatibility with pre-0.19.17 servers and for compatibility + # with jobs submitted before 0.19.17. All new jobs are expected to have non-None `repo_data`. + # For --no-repo runs, `repo_data` is `VirtualRunRepoData()`. + repo_data: Annotated[Optional[AnyRunRepoData], Field(discriminator="repo_type")] = None + # `repo_code_hash` can be None because it is not used for the repo or because the job was + # submitted before 0.19.17. See `_get_repo_code_hash` on how to get the correct `repo_code_hash` + # TODO: drop this comment when supporting jobs submitted before 0.19.17 is no longer relevant. 
+ repo_code_hash: Optional[str] = None class JobProvisioningData(CoreModel): diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index ef07becdfd..e740e53338 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -241,7 +241,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): session=session, project=project, repo=repo_model, - code_hash=run.run_spec.repo_code_hash, + code_hash=_get_repo_code_hash(run, job), ) success = await common_utils.run_async( @@ -293,7 +293,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): session=session, project=project, repo=repo_model, - code_hash=run.run_spec.repo_code_hash, + code_hash=_get_repo_code_hash(run, job), ) success = await common_utils.run_async( _process_pulling_with_shim, @@ -849,6 +849,19 @@ def _get_cluster_info( return cluster_info +def _get_repo_code_hash(run: Run, job: Job) -> Optional[str]: + # TODO: drop this function when supporting jobs submitted before 0.19.17 is no longer relevant. + if ( + job.job_spec.repo_code_hash is None + and run.run_spec.repo_code_hash is not None + and job.job_submissions[-1].deployment_num == run.deployment_num + ): + # The job spec does not have `repo_code_hash`, because it was submitted before 0.19.17. + # Use `repo_code_hash` from the run. 
+ return run.run_spec.repo_code_hash + return job.job_spec.repo_code_hash + + async def _get_job_code( session: AsyncSession, project: ProjectModel, repo: RepoModel, code_hash: Optional[str] ) -> bytes: diff --git a/src/dstack/_internal/server/schemas/runner.py b/src/dstack/_internal/server/schemas/runner.py index 3cd5a92809..a2efa2fcd3 100644 --- a/src/dstack/_internal/server/schemas/runner.py +++ b/src/dstack/_internal/server/schemas/runner.py @@ -78,6 +78,7 @@ class SubmitBody(CoreModel): "max_duration", "ssh_key", "working_dir", + "repo_data", } ), ] diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py index 465a24fb0d..b94aecaf7f 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/base.py +++ b/src/dstack/_internal/server/services/jobs/configurators/base.py @@ -134,6 +134,8 @@ async def _get_job_spec( working_dir=self._working_dir(), volumes=self._volumes(job_num), ssh_key=self._ssh_key(jobs_per_replica), + repo_data=self.run_spec.repo_data, + repo_code_hash=self.run_spec.repo_code_hash, ) return job_spec diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 385660f944..63129bc588 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -898,7 +898,14 @@ def _validate_run_spec_and_set_defaults(run_spec: RunSpec): set_resources_defaults(run_spec.configuration.resources) -_UPDATABLE_SPEC_FIELDS = ["repo_code_hash", "configuration"] +_UPDATABLE_SPEC_FIELDS = ["configuration"] +_TYPE_SPECIFIC_UPDATABLE_SPEC_FIELDS = { + "service": [ + # rolling deployment + "repo_data", + "repo_code_hash", + ], +} _CONF_UPDATABLE_FIELDS = ["priority"] _TYPE_SPECIFIC_CONF_UPDATABLE_FIELDS = { "dev-environment": ["inactivity_duration"], @@ -935,11 +942,14 @@ def _can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec) -> bo def 
_check_can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec): spec_diff = diff_models(current_run_spec, new_run_spec) changed_spec_fields = list(spec_diff.keys()) + updatable_spec_fields = _UPDATABLE_SPEC_FIELDS + _TYPE_SPECIFIC_UPDATABLE_SPEC_FIELDS.get( + new_run_spec.configuration.type, [] + ) for key in changed_spec_fields: - if key not in _UPDATABLE_SPEC_FIELDS: + if key not in updatable_spec_fields: raise ServerClientError( f"Failed to update fields {changed_spec_fields}." - f" Can only update {_UPDATABLE_SPEC_FIELDS}." + f" Can only update {updatable_spec_fields}." ) _check_can_update_configuration(current_run_spec.configuration, new_run_spec.configuration) diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 7aadb48979..f22de90f95 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -1,5 +1,6 @@ import json import uuid +from collections.abc import Callable from contextlib import contextmanager from datetime import datetime, timezone from typing import Dict, List, Literal, Optional, Union @@ -252,18 +253,19 @@ async def create_file_archive( def get_run_spec( run_name: str, repo_id: str, - profile: Optional[Profile] = None, + configuration_path: str = "dstack.yaml", + profile: Union[Profile, Callable[[], Profile], None] = lambda: Profile(name="default"), configuration: Optional[AnyRunConfiguration] = None, ) -> RunSpec: - if profile is None: - profile = Profile(name="default") + if callable(profile): + profile = profile() return RunSpec( run_name=run_name, repo_id=repo_id, repo_data=LocalRunRepoData(repo_dir="/"), repo_code_hash=None, working_dir=".", - configuration_path="dstack.yaml", + configuration_path=configuration_path, configuration=configuration or DevEnvironmentConfiguration(ide="vscode"), profile=profile, ssh_key_pub="user_ssh_key", diff --git a/src/dstack/_internal/utils/nested_list.py 
b/src/dstack/_internal/utils/nested_list.py new file mode 100644 index 0000000000..599298ed49 --- /dev/null +++ b/src/dstack/_internal/utils/nested_list.py @@ -0,0 +1,47 @@ +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class NestedListItem: + label: str + children: list["NestedListItem"] = field(default_factory=list) + + def render(self, indent: int = 0, visited: Optional[set[int]] = None) -> str: + if visited is None: + visited = set() + + item_id = id(self) + if item_id in visited: + raise ValueError(f"Cycle detected at item: {self.label}") + + visited.add(item_id) + prefix = " " * indent + "- " + output = f"{prefix}{self.label}\n" + for child in self.children: + # `visited.copy()` so that we only detect cycles within each path, + # rather than duplicate items in unrelated paths + output += child.render(indent + 1, visited.copy()) + return output + + +@dataclass +class NestedList: + """ + A nested list that can be rendered in Markdown-like format: + + - Item 1 + - Item 2 + - Item 2.1 + - Item 2.2 + - Item 2.2.1 + - Item 3 + """ + + children: list[NestedListItem] = field(default_factory=list) + + def render(self) -> str: + output = "" + for child in self.children: + output += child.render() + return output diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 36c5a5d11e..1a4e0e1e2e 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -4,10 +4,12 @@ import threading import time from abc import ABC +from collections.abc import Iterator +from contextlib import contextmanager from copy import copy from datetime import datetime from pathlib import Path -from typing import Dict, Iterable, List, Optional, Union +from typing import BinaryIO, Dict, Iterable, List, Optional, Union from urllib.parse import urlparse from websocket import WebSocketApp @@ -438,12 +440,16 @@ def get_run_plan( """ if repo is None: repo = VirtualRepo() + repo_code_hash = None + else: + with 
_prepare_code_file(repo) as (_, repo_code_hash): + pass run_spec = RunSpec( run_name=configuration.name, repo_id=repo.repo_id, repo_data=repo.run_repo_data, - repo_code_hash=None, # `apply_plan` will fill it + repo_code_hash=repo_code_hash, working_dir=configuration.working_dir, configuration_path=configuration_path, configuration=configuration, @@ -500,13 +506,11 @@ def apply_plan( else: # Do not upload the diff without a repo (a default virtual repo) # since upload_code() requires a repo to be initialized. - with tempfile.TemporaryFile("w+b") as fp: - run_spec.repo_code_hash = repo.write_code_file(fp) - fp.seek(0) + with _prepare_code_file(repo) as (fp, repo_code_hash): self._api_client.repos.upload_code( project_name=self._project, repo_id=repo.repo_id, - code_hash=run_spec.repo_code_hash, + code_hash=repo_code_hash, fp=fp, ) @@ -647,6 +651,10 @@ def get_plan( logger.warning("The get_plan() method is deprecated in favor of get_run_plan().") if repo is None: repo = VirtualRepo() + repo_code_hash = None + else: + with _prepare_code_file(repo) as (_, repo_code_hash): + pass if working_dir is None: working_dir = "." 
@@ -683,7 +691,7 @@ def get_plan( run_name=run_name, repo_id=repo.repo_id, repo_data=repo.run_repo_data, - repo_code_hash=None, # `exec_plan` will fill it + repo_code_hash=repo_code_hash, working_dir=working_dir, configuration_path=configuration_path, configuration=configuration, @@ -825,3 +833,11 @@ def _reserve_ports( ports[port_override.container_port] = port_override.local_port or 0 logger.debug("Reserving ports: %s", ports) return PortsLock(ports).acquire() + + +@contextmanager +def _prepare_code_file(repo: Repo) -> Iterator[tuple[BinaryIO, str]]: + with tempfile.TemporaryFile("w+b") as fp: + repo_code_hash = repo.write_code_file(fp) + fp.seek(0) + yield fp, repo_code_hash diff --git a/src/tests/_internal/cli/services/configurators/test_run.py b/src/tests/_internal/cli/services/configurators/test_run.py index bb01f1e1f8..14959f5845 100644 --- a/src/tests/_internal/cli/services/configurators/test_run.py +++ b/src/tests/_internal/cli/services/configurators/test_run.py @@ -1,4 +1,5 @@ import argparse +from textwrap import dedent from typing import List, Optional, Tuple from unittest.mock import Mock @@ -6,15 +7,22 @@ from gpuhunt import AcceleratorVendor from dstack._internal.cli.services.configurators import get_run_configurator_class -from dstack._internal.cli.services.configurators.run import BaseRunConfigurator +from dstack._internal.cli.services.configurators.run import ( + BaseRunConfigurator, + render_run_spec_diff, +) from dstack._internal.core.errors import ConfigurationError +from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import RegistryAuth from dstack._internal.core.models.configurations import ( BaseRunConfiguration, + DevEnvironmentConfiguration, PortMapping, TaskConfiguration, ) from dstack._internal.core.models.envs import Env +from dstack._internal.core.models.profiles import Profile +from dstack._internal.server.testing.common import get_run_spec class TestApplyArgs: @@ -288,3 +296,87 
@@ def test_arm_no_image(self, cpu_spec: str): def test_x86(self, cpu_spec: str, image: Optional[str]): conf = self.prepare_conf(cpu_spec=cpu_spec, gpu_spec="H100", image=image) self.validate(conf) + + +class TestRenderRunSpecDiff: + def test_diff(self): + old = get_run_spec( + run_name="test", + repo_id="test-1", + configuration_path="1.dstack.yml", + profile=Profile( + backends=[BackendType.AWS], + regions=["us-west-1"], + name="test", + default=True, + ), + configuration=DevEnvironmentConfiguration( + name="test", + ide="vscode", + inactivity_duration=60, + ), + ) + new = get_run_spec( + run_name="test", + repo_id="test-2", + configuration_path="2.dstack.yml", + profile=Profile( + backends=[BackendType.AWS], + regions=["us-west-2"], + name="test", + default=True, + ), + configuration=DevEnvironmentConfiguration( + name="test", + ide="cursor", + inactivity_duration=None, + ), + ) + assert ( + render_run_spec_diff(old, new) + == dedent( + """ + - Repo ID + - Configuration path + - Configuration properties: + - ide + - inactivity_duration + - Profile properties: + - regions + """ + ).lstrip() + ) + + def test_field_type_change(self): + old = get_run_spec( + run_name="test", + repo_id="test", + profile=Profile(name="test"), + configuration=DevEnvironmentConfiguration( + name="test", + ide="vscode", + ), + ) + new = get_run_spec( + run_name="test", + repo_id="test", + profile=None, + configuration=TaskConfiguration( + name="test", + commands=["sleep infinity"], + ), + ) + assert ( + render_run_spec_diff(old, new) + == dedent( + """ + - Configuration type + - Profile + """ + ).lstrip() + ) + + def test_no_diff(self): + old = get_run_spec(run_name="test", repo_id="test") + new = get_run_spec(run_name="test", repo_id="test") + assert render_run_spec_diff(old, new) is None diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 880c3be5df..63a71989e7 100644 --- a/src/tests/_internal/server/routers/test_runs.py 
+++ b/src/tests/_internal/server/routers/test_runs.py @@ -244,6 +244,8 @@ def get_dev_env_run_plan_dict( "volumes": volumes, "ssh_key": None, "working_dir": ".", + "repo_code_hash": None, + "repo_data": {"repo_dir": "/repo", "repo_type": "local"}, }, "offers": [json.loads(o.json()) for o in offers], "total_offers": total_offers, @@ -436,6 +438,8 @@ def get_dev_env_run_dict( "volumes": [], "ssh_key": None, "working_dir": ".", + "repo_code_hash": None, + "repo_data": {"repo_dir": "/repo", "repo_type": "local"}, }, "job_submissions": [ { diff --git a/src/tests/_internal/utils/test_nested_list.py b/src/tests/_internal/utils/test_nested_list.py new file mode 100644 index 0000000000..7a86962592 --- /dev/null +++ b/src/tests/_internal/utils/test_nested_list.py @@ -0,0 +1,56 @@ +from textwrap import dedent + +import pytest + +from dstack._internal.utils.nested_list import NestedList, NestedListItem + + +def test_render_flat_list(): + nested = NestedList( + children=[NestedListItem("Item 1"), NestedListItem("Item 2"), NestedListItem("Item 3")] + ) + expected = "- Item 1\n- Item 2\n- Item 3\n" + assert nested.render() == expected + + +def test_render_nested_list(): + nested = NestedList( + children=[ + NestedListItem("Item 1"), + NestedListItem( + "Item 2", + [ + NestedListItem("Item 2.1"), + NestedListItem("Item 2.2", [NestedListItem("Item 2.2.1")]), + ], + ), + NestedListItem("Item 3"), + ] + ) + expected = dedent( + """ + - Item 1 + - Item 2 + - Item 2.1 + - Item 2.2 + - Item 2.2.1 + - Item 3 + """ + ).lstrip() + assert nested.render() == expected + + +def test_render_empty_list(): + nested = NestedList() + assert nested.render() == "" + + +def test_cycle_detection(): + a = NestedListItem("A") + b = NestedListItem("B", [a]) + a.children.append(b) # Introduce a cycle: A → B → A + + nested = NestedList(children=[a]) + + with pytest.raises(ValueError, match="Cycle detected at item: A"): + nested.render()