diff --git a/runner/internal/executor/executor.go b/runner/internal/executor/executor.go index 14f46d035c..c720eb02e9 100644 --- a/runner/internal/executor/executor.go +++ b/runner/internal/executor/executor.go @@ -276,6 +276,14 @@ func (ex *RunExecutor) SetRunnerState(state string) { ex.state = state } +func (ex *RunExecutor) getRepoData() schemas.RepoData { + if ex.jobSpec.RepoData == nil { + // jobs submitted before 0.19.17 do not have jobSpec.RepoData + return ex.run.RunSpec.RepoData + } + return *ex.jobSpec.RepoData +} + func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error { node_rank := ex.jobSpec.JobNum nodes_num := ex.jobSpec.JobsPerReplica diff --git a/runner/internal/executor/executor_test.go b/runner/internal/executor/executor_test.go index 2165521bff..8d275b1375 100644 --- a/runner/internal/executor/executor_test.go +++ b/runner/internal/executor/executor_test.go @@ -132,7 +132,7 @@ func TestExecutor_RemoteRepo(t *testing.T) { var b bytes.Buffer ex := makeTestExecutor(t) - ex.run.RunSpec.RepoData = schemas.RepoData{ + ex.jobSpec.RepoData = &schemas.RepoData{ RepoType: "remote", RepoBranch: "main", RepoHash: "2b83592e506ed6fe8e49f4eaa97c3866bc9402b1", @@ -148,7 +148,7 @@ func TestExecutor_RemoteRepo(t *testing.T) { err = ex.execJob(context.TODO(), io.Writer(&b)) assert.NoError(t, err) - expected := fmt.Sprintf("%s\r\n%s\r\n%s\r\n", ex.run.RunSpec.RepoData.RepoHash, ex.run.RunSpec.RepoData.RepoConfigName, ex.run.RunSpec.RepoData.RepoConfigEmail) + expected := fmt.Sprintf("%s\r\n%s\r\n%s\r\n", ex.getRepoData().RepoHash, ex.getRepoData().RepoConfigName, ex.getRepoData().RepoConfigEmail) assert.Equal(t, expected, b.String()) } @@ -178,6 +178,7 @@ func makeTestExecutor(t *testing.T) *RunExecutor { Env: make(map[string]string), MaxDuration: 0, // no timeout WorkingDir: &workingDir, + RepoData: &schemas.RepoData{RepoType: "local"}, }, Secrets: make(map[string]string), RepoCredentials: &schemas.RepoCredentials{ diff --git 
a/runner/internal/executor/repo.go b/runner/internal/executor/repo.go index 5afbb2515a..4e6271f29e 100644 --- a/runner/internal/executor/repo.go +++ b/runner/internal/executor/repo.go @@ -40,7 +40,7 @@ func (ex *RunExecutor) setupRepo(ctx context.Context) error { err = gerrors.Wrap(err_) } }() - switch ex.run.RunSpec.RepoData.RepoType { + switch ex.getRepoData().RepoType { case "remote": log.Trace(ctx, "Fetching git repository") if err := ex.prepareGit(ctx); err != nil { @@ -52,7 +52,7 @@ func (ex *RunExecutor) setupRepo(ctx context.Context) error { return gerrors.Wrap(err) } default: - return gerrors.Newf("unknown RepoType: %s", ex.run.RunSpec.RepoData.RepoType) + return gerrors.Newf("unknown RepoType: %s", ex.getRepoData().RepoType) } return err } @@ -61,8 +61,8 @@ func (ex *RunExecutor) prepareGit(ctx context.Context) error { repoManager := repo.NewManager( ctx, ex.repoCredentials.CloneURL, - ex.run.RunSpec.RepoData.RepoBranch, - ex.run.RunSpec.RepoData.RepoHash, + ex.getRepoData().RepoBranch, + ex.getRepoData().RepoHash, ex.jobSpec.SingleBranch, ).WithLocalPath(ex.workingDir) if ex.repoCredentials != nil { @@ -92,7 +92,7 @@ func (ex *RunExecutor) prepareGit(ctx context.Context) error { if err := repoManager.Checkout(); err != nil { return gerrors.Wrap(err) } - if err := repoManager.SetConfig(ex.run.RunSpec.RepoData.RepoConfigName, ex.run.RunSpec.RepoData.RepoConfigEmail); err != nil { + if err := repoManager.SetConfig(ex.getRepoData().RepoConfigName, ex.getRepoData().RepoConfigEmail); err != nil { return gerrors.Wrap(err) } diff --git a/runner/internal/schemas/schemas.go b/runner/internal/schemas/schemas.go index 8d4cef2f1c..e3b950639c 100644 --- a/runner/internal/schemas/schemas.go +++ b/runner/internal/schemas/schemas.go @@ -68,6 +68,10 @@ type JobSpec struct { MaxDuration int `json:"max_duration"` SSHKey *SSHKey `json:"ssh_key"` WorkingDir *string `json:"working_dir"` + // `RepoData` is optional for compatibility with jobs submitted before 0.19.17. 
+ // Use `RunExecutor.getRepoData()` to get non-nil `RepoData`. + // TODO: make required when supporting jobs submitted before 0.19.17 is no longer relevant. + RepoData *RepoData `json:"repo_data"` } type ClusterInfo struct { diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index b23d227a5a..7adc3b90be 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -41,12 +41,13 @@ ) from dstack._internal.core.models.repos.base import Repo from dstack._internal.core.models.resources import CPUSpec -from dstack._internal.core.models.runs import JobStatus, JobSubmission, RunStatus +from dstack._internal.core.models.runs import JobStatus, JobSubmission, RunSpec, RunStatus from dstack._internal.core.services.configs import ConfigManager from dstack._internal.core.services.diff import diff_models from dstack._internal.utils.common import local_time from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator from dstack._internal.utils.logging import get_logger +from dstack._internal.utils.nested_list import NestedList, NestedListItem from dstack.api._public.repos import get_ssh_keypair from dstack.api._public.runs import Run from dstack.api.utils import load_profile @@ -102,25 +103,20 @@ def apply_configuration( confirm_message = f"Submit the run [code]{conf.name}[/]?" 
stop_run_name = None if run_plan.current_resource is not None: - changed_fields = [] - if run_plan.action == ApplyAction.UPDATE: - diff = diff_models( - run_plan.get_effective_run_spec().configuration, - run_plan.current_resource.run_spec.configuration, - ) - changed_fields = list(diff.keys()) - if run_plan.action == ApplyAction.UPDATE and len(changed_fields) > 0: + diff = render_run_spec_diff( + run_plan.get_effective_run_spec(), + run_plan.current_resource.run_spec, + ) + if run_plan.action == ApplyAction.UPDATE and diff is not None: console.print( f"Active run [code]{conf.name}[/] already exists." - " Detected configuration changes that can be updated in-place:" - f" {changed_fields}" + f" Detected changes that [code]can[/] be updated in-place:\n{diff}" ) confirm_message = "Update the run?" - elif run_plan.action == ApplyAction.UPDATE and len(changed_fields) == 0: + elif run_plan.action == ApplyAction.UPDATE and diff is None: stop_run_name = run_plan.current_resource.run_spec.run_name console.print( - f"Active run [code]{conf.name}[/] already exists." - " Detected no configuration changes." + f"Active run [code]{conf.name}[/] already exists. Detected no changes." ) if command_args.yes and not command_args.force: console.print("Use --force to apply anyway.") @@ -129,7 +125,8 @@ def apply_configuration( elif not run_plan.current_resource.status.is_finished(): stop_run_name = run_plan.current_resource.run_spec.run_name console.print( - f"Active run [code]{conf.name}[/] already exists and cannot be updated in-place." + f"Active run [code]{conf.name}[/] already exists." + f" Detected changes that [error]cannot[/] be updated in-place:\n{diff}" ) confirm_message = "Stop and override the run?" 
def render_run_spec_diff(old_spec: RunSpec, new_spec: RunSpec) -> Optional[str]:
    """Render a human-readable nested list of fields that differ between two run specs.

    Args:
        old_spec: The run spec to compare against.
        new_spec: The run spec being compared.

    Returns:
        A rendered nested list of changed fields, or `None` if nothing changed.
    """
    changed_spec_fields = list(diff_models(old_spec, new_spec))
    if not changed_spec_fields:
        return None
    friendly_spec_field_names = {
        "repo_id": "Repo ID",
        "repo_code_hash": "Repo files",
        "repo_data": "Repo state (branch, commit, or other)",
        "ssh_key_pub": "Public SSH key",
    }
    nested_list = NestedList()
    for spec_field in changed_spec_fields:
        if spec_field == "merged_profile":
            # Derived from `profile` and `configuration`; its changes are already
            # reported via those fields, so don't show it separately.
            continue
        elif spec_field == "configuration":
            item = _render_model_field_diff(
                old_spec.configuration,
                new_spec.configuration,
                type_changed_label="Configuration type",
                properties_label="Configuration properties:",
            )
        elif spec_field == "profile":
            item = _render_model_field_diff(
                old_spec.profile,
                new_spec.profile,
                type_changed_label="Profile",
                properties_label="Profile properties:",
            )
        elif spec_field in friendly_spec_field_names:
            item = NestedListItem(friendly_spec_field_names[spec_field])
        else:
            # Fall back to a prettified version of the raw field name.
            item = NestedListItem(spec_field.replace("_", " ").capitalize())
        nested_list.children.append(item)
    return nested_list.render()


def _render_model_field_diff(old_model, new_model, type_changed_label: str, properties_label: str) -> NestedListItem:
    """Build a list item for a nested model field.

    Returns a bare label when the model type changed (per-property diff would be
    meaningless across types), otherwise a per-property breakdown of the changes.
    """
    if type(old_model) is not type(new_model):
        return NestedListItem(type_changed_label)
    return NestedListItem(
        properties_label,
        children=[
            NestedListItem(field) for field in diff_models(old_model, new_model)
        ],
    )
def get_job_spec_excludes(job_specs: list[JobSpec]) -> Optional[dict]:
    """
    Build a `job_spec` exclude mapping for request serialization.

    Newly added fields are excluded when they are unset in every job spec,
    keeping clients backward-compatible with older servers that do not
    know about those fields.
    """
    spec_excludes: dict[str, Any] = {
        field_name: True
        for field_name in ("repo_code_hash", "repo_data")
        if all(getattr(spec, field_name) is None for spec in job_specs)
    }
    return spec_excludes or None
def _get_repo_code_hash(run: Run, job: Job) -> Optional[str]:
    # TODO: drop this function when supporting jobs submitted before 0.19.17 is no longer relevant.
    code_hash = job.job_spec.repo_code_hash
    if code_hash is not None:
        return code_hash
    if (
        run.run_spec.repo_code_hash is not None
        and job.job_submissions[-1].deployment_num == run.deployment_num
    ):
        # The job spec lacks `repo_code_hash` because it was submitted before
        # 0.19.17. Fall back to the run-level value, but only while the job
        # still belongs to the run's current deployment.
        return run.run_spec.repo_code_hash
    return None
_check_can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec): spec_diff = diff_models(current_run_spec, new_run_spec) changed_spec_fields = list(spec_diff.keys()) + updatable_spec_fields = _UPDATABLE_SPEC_FIELDS + _TYPE_SPECIFIC_UPDATABLE_SPEC_FIELDS.get( + new_run_spec.configuration.type, [] + ) for key in changed_spec_fields: - if key not in _UPDATABLE_SPEC_FIELDS: + if key not in updatable_spec_fields: raise ServerClientError( f"Failed to update fields {changed_spec_fields}." - f" Can only update {_UPDATABLE_SPEC_FIELDS}." + f" Can only update {updatable_spec_fields}." ) _check_can_update_configuration(current_run_spec.configuration, new_run_spec.configuration) diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 7aadb48979..f22de90f95 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -1,5 +1,6 @@ import json import uuid +from collections.abc import Callable from contextlib import contextmanager from datetime import datetime, timezone from typing import Dict, List, Literal, Optional, Union @@ -252,18 +253,19 @@ async def create_file_archive( def get_run_spec( run_name: str, repo_id: str, - profile: Optional[Profile] = None, + configuration_path: str = "dstack.yaml", + profile: Union[Profile, Callable[[], Profile], None] = lambda: Profile(name="default"), configuration: Optional[AnyRunConfiguration] = None, ) -> RunSpec: - if profile is None: - profile = Profile(name="default") + if callable(profile): + profile = profile() return RunSpec( run_name=run_name, repo_id=repo_id, repo_data=LocalRunRepoData(repo_dir="/"), repo_code_hash=None, working_dir=".", - configuration_path="dstack.yaml", + configuration_path=configuration_path, configuration=configuration or DevEnvironmentConfiguration(ide="vscode"), profile=profile, ssh_key_pub="user_ssh_key", diff --git a/src/dstack/_internal/utils/nested_list.py 
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class NestedListItem:
    """A labeled item in a `NestedList`, possibly holding nested child items."""

    label: str
    children: list["NestedListItem"] = field(default_factory=list)

    def render(self, indent: int = 0, visited: Optional[set[int]] = None) -> str:
        """Render this item and its descendants as Markdown-like list lines.

        Args:
            indent: Nesting depth; each level adds two spaces of indentation.
            visited: Ids of the ancestors on the current path (cycle detection).

        Returns:
            The rendered lines, each terminated by a newline.

        Raises:
            ValueError: If the item graph contains a cycle.
        """
        if visited is None:
            visited = set()

        item_id = id(self)
        if item_id in visited:
            raise ValueError(f"Cycle detected at item: {self.label}")

        # Track only the current path: add self before descending and remove it
        # afterwards. This detects cycles within each path without flagging items
        # shared between unrelated branches, and avoids the O(n^2) per-child set
        # copies of a copy-on-descend approach.
        visited.add(item_id)
        lines = ["  " * indent + "- " + self.label + "\n"]
        for child in self.children:
            lines.append(child.render(indent + 1, visited))
        visited.discard(item_id)
        return "".join(lines)


@dataclass
class NestedList:
    """
    A nested list that can be rendered in Markdown-like format:

    - Item 1
    - Item 2
      - Item 2.1
      - Item 2.2
        - Item 2.2.1
    - Item 3
    """

    children: list[NestedListItem] = field(default_factory=list)

    def render(self) -> str:
        """Render all top-level items and their descendants as one string."""
        return "".join(child.render() for child in self.children)
_prepare_code_file(repo) as (_, repo_code_hash): + pass run_spec = RunSpec( run_name=configuration.name, repo_id=repo.repo_id, repo_data=repo.run_repo_data, - repo_code_hash=None, # `apply_plan` will fill it + repo_code_hash=repo_code_hash, working_dir=configuration.working_dir, configuration_path=configuration_path, configuration=configuration, @@ -500,13 +506,11 @@ def apply_plan( else: # Do not upload the diff without a repo (a default virtual repo) # since upload_code() requires a repo to be initialized. - with tempfile.TemporaryFile("w+b") as fp: - run_spec.repo_code_hash = repo.write_code_file(fp) - fp.seek(0) + with _prepare_code_file(repo) as (fp, repo_code_hash): self._api_client.repos.upload_code( project_name=self._project, repo_id=repo.repo_id, - code_hash=run_spec.repo_code_hash, + code_hash=repo_code_hash, fp=fp, ) @@ -647,6 +651,10 @@ def get_plan( logger.warning("The get_plan() method is deprecated in favor of get_run_plan().") if repo is None: repo = VirtualRepo() + repo_code_hash = None + else: + with _prepare_code_file(repo) as (_, repo_code_hash): + pass if working_dir is None: working_dir = "." 
@contextmanager
def _prepare_code_file(repo: Repo) -> Iterator[tuple[BinaryIO, str]]:
    """Write the repo's code archive into a temporary file and yield it with its hash.

    The file position is rewound to the start so the caller can read or upload
    the content directly. The temporary file is removed on exit.
    """
    with tempfile.TemporaryFile("w+b") as fp:
        code_hash = repo.write_code_file(fp)
        fp.seek(0)
        yield fp, code_hash
class TestRenderRunSpecDiff:
    """Tests for `render_run_spec_diff` covering changed fields, type changes, and no-op diffs."""

    def test_diff(self):
        # Changed: repo_id, configuration_path, two configuration properties,
        # and one profile property — each should appear under a friendly label.
        old = get_run_spec(
            run_name="test",
            repo_id="test-1",
            configuration_path="1.dstack.yml",
            profile=Profile(
                backends=[BackendType.AWS],
                regions=["us-west-1"],
                name="test",
                default=True,
            ),
            configuration=DevEnvironmentConfiguration(
                name="test",
                ide="vscode",
                inactivity_duration=60,
            ),
        )
        new = get_run_spec(
            run_name="test",
            repo_id="test-2",
            configuration_path="2.dstack.yml",
            profile=Profile(
                backends=[BackendType.AWS],
                regions=["us-west-2"],
                name="test",
                default=True,
            ),
            configuration=DevEnvironmentConfiguration(
                name="test",
                ide="cursor",
                inactivity_duration=None,
            ),
        )
        assert (
            render_run_spec_diff(old, new)
            == dedent(
                """
                - Repo ID
                - Configuration path
                - Configuration properties:
                  - ide
                  - inactivity_duration
                - Profile properties:
                  - regions
                """
            ).lstrip()
        )

    def test_field_type_change(self):
        # When the configuration or profile *type* changes, the diff shows a bare
        # label instead of a per-property breakdown.
        old = get_run_spec(
            run_name="test",
            repo_id="test",
            profile=Profile(name="test"),
            configuration=DevEnvironmentConfiguration(
                name="test",
                ide="vscode",
            ),
        )
        new = get_run_spec(
            run_name="test",
            repo_id="test",
            profile=None,
            configuration=TaskConfiguration(
                name="test",
                commands=["sleep infinity"],
            ),
        )
        assert (
            render_run_spec_diff(old, new)
            == dedent(
                """
                - Configuration type
                - Profile
                """
            ).lstrip()
        )

    def test_no_diff(self):
        # Identical specs produce no diff at all (None, not an empty string).
        old = get_run_spec(run_name="test", repo_id="test")
        new = get_run_spec(run_name="test", repo_id="test")
        assert render_run_spec_diff(old, new) is None
from textwrap import dedent

import pytest

from dstack._internal.utils.nested_list import NestedList, NestedListItem


def test_render_flat_list():
    # A list with no nesting renders one "- " line per item.
    nested = NestedList(
        children=[NestedListItem("Item 1"), NestedListItem("Item 2"), NestedListItem("Item 3")]
    )
    expected = "- Item 1\n- Item 2\n- Item 3\n"
    assert nested.render() == expected


def test_render_nested_list():
    # Each nesting level is indented by two additional spaces.
    nested = NestedList(
        children=[
            NestedListItem("Item 1"),
            NestedListItem(
                "Item 2",
                [
                    NestedListItem("Item 2.1"),
                    NestedListItem("Item 2.2", [NestedListItem("Item 2.2.1")]),
                ],
            ),
            NestedListItem("Item 3"),
        ]
    )
    expected = dedent(
        """
        - Item 1
        - Item 2
          - Item 2.1
          - Item 2.2
            - Item 2.2.1
        - Item 3
        """
    ).lstrip()
    assert nested.render() == expected


def test_render_empty_list():
    # An empty list renders to an empty string, not a newline.
    nested = NestedList()
    assert nested.render() == ""


def test_cycle_detection():
    # A cycle in the item graph must be detected instead of recursing forever.
    a = NestedListItem("A")
    b = NestedListItem("B", [a])
    a.children.append(b)  # Introduce a cycle: A → B → A

    nested = NestedList(children=[a])

    with pytest.raises(ValueError, match="Cycle detected at item: A"):
        nested.render()