Skip to content

Commit d7ea72c

Browse files
authored
Fix in-place update when files are used (#3289)
Fixes: #3265
1 parent a172672 commit d7ea72c

File tree

2 files changed

+28
-18
lines changed

2 files changed

+28
-18
lines changed

src/dstack/_internal/server/services/runs/spec.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,13 @@ def check_can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec):
127127
f"Failed to update fields {changed_spec_fields}."
128128
f" Can only update {updatable_spec_fields}."
129129
)
130-
_check_can_update_configuration(current_run_spec.configuration, new_run_spec.configuration)
130+
# We don't allow update if the order of archives has been changed, as even if the archives
131+
# are the same (the same id => hash => content and the same container path), the order of
132+
# unpacking matters when one path is a subpath of another.
133+
ignore_files = current_run_spec.file_archives == new_run_spec.file_archives
134+
_check_can_update_configuration(
135+
current_run_spec.configuration, new_run_spec.configuration, ignore_files
136+
)
131137

132138

133139
def can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec) -> bool:
@@ -159,7 +165,7 @@ def check_run_spec_requires_instance_mounts(run_spec: RunSpec) -> bool:
159165

160166

161167
def _check_can_update_configuration(
162-
current: AnyRunConfiguration, new: AnyRunConfiguration
168+
current: AnyRunConfiguration, new: AnyRunConfiguration, ignore_files: bool
163169
) -> None:
164170
if current.type != new.type:
165171
raise ServerClientError(
@@ -168,6 +174,13 @@ def _check_can_update_configuration(
168174
updatable_fields = _CONF_UPDATABLE_FIELDS + _TYPE_SPECIFIC_CONF_UPDATABLE_FIELDS.get(
169175
new.type, []
170176
)
177+
if ignore_files:
178+
# We ignore files diff if the file archives are the same. It allows the user to move
179+
# local files/dirs as long as their name(*), content, and the container path stay the same.
180+
# (*) We could also ignore local name changes if the names didn't change in the tarballs.
181+
# Currently, the client preserves the original file/dir name it the tarball, but it could
182+
# use some generic names like "file"/"directory" instead.
183+
updatable_fields.append("files")
171184
diff = diff_models(current, new)
172185
changed_fields = list(diff.keys())
173186
for key in changed_fields:

src/dstack/api/_public/runs.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,18 @@ def get_run_plan(
490490
if repo_dir is None and configuration.repos:
491491
repo_dir = configuration.repos[0].path
492492

493+
self._validate_configuration_files(configuration, configuration_path)
494+
file_archives: list[FileArchiveMapping] = []
495+
for file_mapping in configuration.files:
496+
with tempfile.TemporaryFile("w+b") as fp:
497+
try:
498+
archive_hash = create_file_archive(file_mapping.local_path, fp)
499+
except OSError as e:
500+
raise ClientError(f"failed to archive '{file_mapping.local_path}': {e}") from e
501+
fp.seek(0)
502+
archive = self._api_client.files.upload_archive(hash=archive_hash, fp=fp)
503+
file_archives.append(FileArchiveMapping(id=archive.id, path=file_mapping.path))
504+
493505
if ssh_identity_file:
494506
ssh_key_pub = Path(ssh_identity_file).with_suffix(".pub").read_text()
495507
else:
@@ -513,6 +525,7 @@ def get_run_plan(
513525
repo_data=repo.run_repo_data,
514526
repo_code_hash=repo_code_hash,
515527
repo_dir=repo_dir,
528+
file_archives=file_archives,
516529
# Server doesn't use this field since 0.19.27, but we still send it for compatibility
517530
# with older servers
518531
working_dir=configuration.working_dir,
@@ -549,22 +562,6 @@ def apply_plan(
549562
# TODO handle multiple jobs
550563
ports_lock = _reserve_ports(run_plan.job_plans[0].job_spec)
551564

552-
run_spec = run_plan.run_spec
553-
configuration = run_spec.configuration
554-
555-
self._validate_configuration_files(configuration, run_spec.configuration_path)
556-
for file_mapping in configuration.files:
557-
with tempfile.TemporaryFile("w+b") as fp:
558-
try:
559-
archive_hash = create_file_archive(file_mapping.local_path, fp)
560-
except OSError as e:
561-
raise ClientError(f"failed to archive '{file_mapping.local_path}': {e}") from e
562-
fp.seek(0)
563-
archive = self._api_client.files.upload_archive(hash=archive_hash, fp=fp)
564-
run_spec.file_archives.append(
565-
FileArchiveMapping(id=archive.id, path=file_mapping.path)
566-
)
567-
568565
if repo is None:
569566
repo = VirtualRepo()
570567
else:

0 commit comments

Comments
 (0)