Skip to content

Commit 09c59ed

Browse files
committed
Fix in-place update when files are used
Fixes: #3265
1 parent fa6875b commit 09c59ed

File tree

2 files changed

+28
-18
lines changed

2 files changed

+28
-18
lines changed

src/dstack/_internal/server/services/runs.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,11 +1117,17 @@ def _check_can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec)
11171117
f"Failed to update fields {changed_spec_fields}."
11181118
f" Can only update {updatable_spec_fields}."
11191119
)
1120-
_check_can_update_configuration(current_run_spec.configuration, new_run_spec.configuration)
1120+
# We don't allow update if the order of archives has been changed, as even if the archives
1121+
# are the same (the same id => hash => content and the same container path), the order of
1122+
# unpacking matters when one path is a subpath of another.
1123+
ignore_files = current_run_spec.file_archives == new_run_spec.file_archives
1124+
_check_can_update_configuration(
1125+
current_run_spec.configuration, new_run_spec.configuration, ignore_files
1126+
)
11211127

11221128

11231129
def _check_can_update_configuration(
1124-
current: AnyRunConfiguration, new: AnyRunConfiguration
1130+
current: AnyRunConfiguration, new: AnyRunConfiguration, ignore_files: bool
11251131
) -> None:
11261132
if current.type != new.type:
11271133
raise ServerClientError(
@@ -1130,6 +1136,13 @@ def _check_can_update_configuration(
11301136
updatable_fields = _CONF_UPDATABLE_FIELDS + _TYPE_SPECIFIC_CONF_UPDATABLE_FIELDS.get(
11311137
new.type, []
11321138
)
1139+
if ignore_files:
1140+
# We ignore files diff if the file archives are the same. It allows the user to move
1141+
# local files/dirs as long as their name(*), content, and the container path stay the same.
1142+
# (*) We could also ignore local name changes if the names didn't change in the tarballs.
1143+
# Currently, the client preserves the original file/dir name it the tarball, but it could
1144+
# use some generic names like "file"/"directory" instead.
1145+
updatable_fields.append("files")
11331146
diff = diff_models(current, new)
11341147
changed_fields = list(diff.keys())
11351148
for key in changed_fields:

src/dstack/api/_public/runs.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,18 @@ def get_run_plan(
490490
if repo_dir is None and configuration.repos:
491491
repo_dir = configuration.repos[0].path
492492

493+
self._validate_configuration_files(configuration, configuration_path)
494+
file_archives: list[FileArchiveMapping] = []
495+
for file_mapping in configuration.files:
496+
with tempfile.TemporaryFile("w+b") as fp:
497+
try:
498+
archive_hash = create_file_archive(file_mapping.local_path, fp)
499+
except OSError as e:
500+
raise ClientError(f"failed to archive '{file_mapping.local_path}': {e}") from e
501+
fp.seek(0)
502+
archive = self._api_client.files.upload_archive(hash=archive_hash, fp=fp)
503+
file_archives.append(FileArchiveMapping(id=archive.id, path=file_mapping.path))
504+
493505
if ssh_identity_file:
494506
ssh_key_pub = Path(ssh_identity_file).with_suffix(".pub").read_text()
495507
else:
@@ -513,6 +525,7 @@ def get_run_plan(
513525
repo_data=repo.run_repo_data,
514526
repo_code_hash=repo_code_hash,
515527
repo_dir=repo_dir,
528+
file_archives=file_archives,
516529
# Server doesn't use this field since 0.19.27, but we still send it for compatibility
517530
# with older servers
518531
working_dir=configuration.working_dir,
@@ -549,22 +562,6 @@ def apply_plan(
549562
# TODO handle multiple jobs
550563
ports_lock = _reserve_ports(run_plan.job_plans[0].job_spec)
551564

552-
run_spec = run_plan.run_spec
553-
configuration = run_spec.configuration
554-
555-
self._validate_configuration_files(configuration, run_spec.configuration_path)
556-
for file_mapping in configuration.files:
557-
with tempfile.TemporaryFile("w+b") as fp:
558-
try:
559-
archive_hash = create_file_archive(file_mapping.local_path, fp)
560-
except OSError as e:
561-
raise ClientError(f"failed to archive '{file_mapping.local_path}': {e}") from e
562-
fp.seek(0)
563-
archive = self._api_client.files.upload_archive(hash=archive_hash, fp=fp)
564-
run_spec.file_archives.append(
565-
FileArchiveMapping(id=archive.id, path=file_mapping.path)
566-
)
567-
568565
if repo is None:
569566
repo = VirtualRepo()
570567
else:

0 commit comments

Comments
 (0)