From 093b86b18f1d4d592ac4bfcab1164a30743f2455 Mon Sep 17 00:00:00 2001 From: Dmitry Meyer Date: Wed, 2 Jul 2025 15:11:26 +0000 Subject: [PATCH] Ignore SSH keys when calculating fleet conf diff Since `/api/project/{project_name}/fleets/get_plan` returns the plan with sensitive info (SSH keys) removed, the configurations are never equal. To work around this issue, we exclude (strictly speaking, not exclude but reset to the default values, `None` in this case) fields containing SSH key content when calculating the fleet configurations diff. Fixes: https://github.com/dstackai/dstack/issues/2222 --- .../cli/services/configurators/fleet.py | 14 ++++++- .../_internal/core/compatibility/fleets.py | 23 +++++------ .../_internal/core/compatibility/gateways.py | 17 ++++---- .../_internal/core/compatibility/logs.py | 7 ++-- .../_internal/core/compatibility/runs.py | 37 ++++++++---------- .../_internal/core/compatibility/volumes.py | 17 ++++---- src/dstack/_internal/core/models/common.py | 7 ++++ src/dstack/_internal/core/services/diff.py | 39 +++++++++++++++++-- 8 files changed, 107 insertions(+), 54 deletions(-) diff --git a/src/dstack/_internal/cli/services/configurators/fleet.py b/src/dstack/_internal/cli/services/configurators/fleet.py index 9bfb399aca..b501f0b9de 100644 --- a/src/dstack/_internal/cli/services/configurators/fleet.py +++ b/src/dstack/_internal/cli/services/configurators/fleet.py @@ -35,6 +35,7 @@ ) from dstack._internal.core.models.instances import InstanceAvailability, InstanceStatus, SSHKey from dstack._internal.core.models.repos.base import Repo +from dstack._internal.core.services.diff import diff_models from dstack._internal.utils.common import local_time from dstack._internal.utils.logging import get_logger from dstack._internal.utils.ssh import convert_ssh_key_to_pem, generate_public_key, pkey_from_str @@ -82,7 +83,18 @@ def apply_configuration( confirm_message += "Create the fleet?" 
else: action_message += f"Found fleet [code]{plan.spec.configuration.name}[/]." - if plan.current_resource.spec.configuration == plan.spec.configuration: + diff = diff_models( + old=plan.current_resource.spec.configuration, + new=plan.spec.configuration, + ignore={ + "ssh_config": { + "ssh_key": True, + "proxy_jump": {"ssh_key"}, + "hosts": {"__all__": {"ssh_key": True, "proxy_jump": {"ssh_key"}}}, + } + }, + ) + if not diff: if command_args.yes and not command_args.force: # --force is required only with --yes, # otherwise we may ask for force apply interactively. diff --git a/src/dstack/_internal/core/compatibility/fleets.py b/src/dstack/_internal/core/compatibility/fleets.py index 4ba2a92c42..b8a8738924 100644 --- a/src/dstack/_internal/core/compatibility/fleets.py +++ b/src/dstack/_internal/core/compatibility/fleets.py @@ -1,19 +1,20 @@ -from typing import Any, Dict, Optional +from typing import Optional +from dstack._internal.core.models.common import IncludeExcludeDictType, IncludeExcludeSetType from dstack._internal.core.models.fleets import ApplyFleetPlanInput, FleetSpec from dstack._internal.core.models.instances import Instance -def get_get_plan_excludes(fleet_spec: FleetSpec) -> Dict: - get_plan_excludes = {} +def get_get_plan_excludes(fleet_spec: FleetSpec) -> IncludeExcludeDictType: + get_plan_excludes: IncludeExcludeDictType = {} spec_excludes = get_fleet_spec_excludes(fleet_spec) if spec_excludes: get_plan_excludes["spec"] = spec_excludes return get_plan_excludes -def get_apply_plan_excludes(plan_input: ApplyFleetPlanInput) -> Dict: - apply_plan_excludes = {} +def get_apply_plan_excludes(plan_input: ApplyFleetPlanInput) -> IncludeExcludeDictType: + apply_plan_excludes: IncludeExcludeDictType = {} spec_excludes = get_fleet_spec_excludes(plan_input.spec) if spec_excludes: apply_plan_excludes["spec"] = spec_excludes @@ -28,23 +29,23 @@ def get_apply_plan_excludes(plan_input: ApplyFleetPlanInput) -> Dict: return {"plan": apply_plan_excludes} -def 
get_create_fleet_excludes(fleet_spec: FleetSpec) -> Dict: - create_fleet_excludes = {} +def get_create_fleet_excludes(fleet_spec: FleetSpec) -> IncludeExcludeDictType: + create_fleet_excludes: IncludeExcludeDictType = {} spec_excludes = get_fleet_spec_excludes(fleet_spec) if spec_excludes: create_fleet_excludes["spec"] = spec_excludes return create_fleet_excludes -def get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[Dict]: +def get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[IncludeExcludeDictType]: """ Returns `fleet_spec` exclude mapping to exclude certain fields from the request. Use this method to exclude new fields when they are not set to keep clients backward-compatibility with older servers. """ - spec_excludes: Dict[str, Any] = {} - configuration_excludes: Dict[str, Any] = {} - profile_excludes: set[str] = set() + spec_excludes: IncludeExcludeDictType = {} + configuration_excludes: IncludeExcludeDictType = {} + profile_excludes: IncludeExcludeSetType = set() profile = fleet_spec.profile if profile.fleets is None: profile_excludes.add("fleets") diff --git a/src/dstack/_internal/core/compatibility/gateways.py b/src/dstack/_internal/core/compatibility/gateways.py index 228666ed30..b9fe163838 100644 --- a/src/dstack/_internal/core/compatibility/gateways.py +++ b/src/dstack/_internal/core/compatibility/gateways.py @@ -1,34 +1,35 @@ -from typing import Dict - +from dstack._internal.core.models.common import IncludeExcludeDictType from dstack._internal.core.models.gateways import GatewayConfiguration, GatewaySpec -def get_gateway_spec_excludes(gateway_spec: GatewaySpec) -> Dict: +def get_gateway_spec_excludes(gateway_spec: GatewaySpec) -> IncludeExcludeDictType: """ Returns `gateway_spec` exclude mapping to exclude certain fields from the request. Use this method to exclude new fields when they are not set to keep clients backward-compatibility with older servers. 
""" - spec_excludes = {} + spec_excludes: IncludeExcludeDictType = {} spec_excludes["configuration"] = _get_gateway_configuration_excludes( gateway_spec.configuration ) return spec_excludes -def get_create_gateway_excludes(configuration: GatewayConfiguration) -> Dict: +def get_create_gateway_excludes(configuration: GatewayConfiguration) -> IncludeExcludeDictType: """ Returns an exclude mapping to exclude certain fields from the create gateway request. Use this method to exclude new fields when they are not set to keep clients backward-compatibility with older servers. """ - create_gateway_excludes = {} + create_gateway_excludes: IncludeExcludeDictType = {} create_gateway_excludes["configuration"] = _get_gateway_configuration_excludes(configuration) return create_gateway_excludes -def _get_gateway_configuration_excludes(configuration: GatewayConfiguration) -> Dict: - configuration_excludes = {} +def _get_gateway_configuration_excludes( + configuration: GatewayConfiguration, +) -> IncludeExcludeDictType: + configuration_excludes: IncludeExcludeDictType = {} if configuration.tags is None: configuration_excludes["tags"] = True return configuration_excludes diff --git a/src/dstack/_internal/core/compatibility/logs.py b/src/dstack/_internal/core/compatibility/logs.py index d6c2d3b3bb..7d499047ae 100644 --- a/src/dstack/_internal/core/compatibility/logs.py +++ b/src/dstack/_internal/core/compatibility/logs.py @@ -1,15 +1,16 @@ -from typing import Dict, Optional +from typing import Optional +from dstack._internal.core.models.common import IncludeExcludeDictType from dstack._internal.server.schemas.logs import PollLogsRequest -def get_poll_logs_excludes(request: PollLogsRequest) -> Optional[Dict]: +def get_poll_logs_excludes(request: PollLogsRequest) -> Optional[IncludeExcludeDictType]: """ Returns exclude mapping to exclude certain fields from the request. Use this method to exclude new fields when they are not set to keep clients backward-compatibility with older servers. 
""" - excludes = {} + excludes: IncludeExcludeDictType = {} if request.next_token is None: excludes["next_token"] = True return excludes if excludes else None diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py index f4e6c6acfb..1deea7ee02 100644 --- a/src/dstack/_internal/core/compatibility/runs.py +++ b/src/dstack/_internal/core/compatibility/runs.py @@ -1,29 +1,30 @@ -from typing import Any, Dict, Optional +from typing import Optional +from dstack._internal.core.models.common import IncludeExcludeDictType, IncludeExcludeSetType from dstack._internal.core.models.configurations import ServiceConfiguration from dstack._internal.core.models.runs import ApplyRunPlanInput, JobSpec, JobSubmission, RunSpec from dstack._internal.server.schemas.runs import GetRunPlanRequest -def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[Dict]: +def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[IncludeExcludeDictType]: """ Returns `plan` exclude mapping to exclude certain fields from the request. Use this method to exclude new fields when they are not set to keep clients backward-compatibility with older servers. 
""" - apply_plan_excludes = {} + apply_plan_excludes: IncludeExcludeDictType = {} run_spec_excludes = get_run_spec_excludes(plan.run_spec) if run_spec_excludes is not None: apply_plan_excludes["run_spec"] = run_spec_excludes current_resource = plan.current_resource if current_resource is not None: - current_resource_excludes = {} + current_resource_excludes: IncludeExcludeDictType = {} current_resource_excludes["status_message"] = True if current_resource.deployment_num == 0: current_resource_excludes["deployment_num"] = True apply_plan_excludes["current_resource"] = current_resource_excludes current_resource_excludes["run_spec"] = get_run_spec_excludes(current_resource.run_spec) - job_submissions_excludes = {} + job_submissions_excludes: IncludeExcludeDictType = {} current_resource_excludes["jobs"] = { "__all__": { "job_spec": get_job_spec_excludes([job.job_spec for job in current_resource.jobs]), @@ -45,7 +46,7 @@ def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[Dict]: job_submissions_excludes["deployment_num"] = True latest_job_submission = current_resource.latest_job_submission if latest_job_submission is not None: - latest_job_submission_excludes = {} + latest_job_submission_excludes: IncludeExcludeDictType = {} current_resource_excludes["latest_job_submission"] = latest_job_submission_excludes if _should_exclude_job_submission_jpd_cpu_arch(latest_job_submission): latest_job_submission_excludes["job_provisioning_data"] = { @@ -62,12 +63,12 @@ def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[Dict]: return {"plan": apply_plan_excludes} -def get_get_plan_excludes(request: GetRunPlanRequest) -> Optional[Dict]: +def get_get_plan_excludes(request: GetRunPlanRequest) -> Optional[IncludeExcludeDictType]: """ Excludes new fields when they are not set to keep clients backward-compatibility with older servers. 
""" - get_plan_excludes = {} + get_plan_excludes: IncludeExcludeDictType = {} run_spec_excludes = get_run_spec_excludes(request.run_spec) if run_spec_excludes is not None: get_plan_excludes["run_spec"] = run_spec_excludes @@ -76,15 +77,15 @@ def get_get_plan_excludes(request: GetRunPlanRequest) -> Optional[Dict]: return get_plan_excludes -def get_run_spec_excludes(run_spec: RunSpec) -> Optional[Dict]: +def get_run_spec_excludes(run_spec: RunSpec) -> IncludeExcludeDictType: """ Returns `run_spec` exclude mapping to exclude certain fields from the request. Use this method to exclude new fields when they are not set to keep clients backward-compatibility with older servers. """ - spec_excludes: dict[str, Any] = {} - configuration_excludes: dict[str, Any] = {} - profile_excludes: set[str] = set() + spec_excludes: IncludeExcludeDictType = {} + configuration_excludes: IncludeExcludeDictType = {} + profile_excludes: IncludeExcludeSetType = set() configuration = run_spec.configuration profile = run_spec.profile @@ -121,18 +122,16 @@ def get_run_spec_excludes(run_spec: RunSpec) -> Optional[Dict]: spec_excludes["configuration"] = configuration_excludes if profile_excludes: spec_excludes["profile"] = profile_excludes - if spec_excludes: - return spec_excludes - return None + return spec_excludes -def get_job_spec_excludes(job_specs: list[JobSpec]) -> Optional[dict]: +def get_job_spec_excludes(job_specs: list[JobSpec]) -> IncludeExcludeDictType: """ Returns `job_spec` exclude mapping to exclude certain fields from the request. Use this method to exclude new fields when they are not set to keep clients backward-compatibility with older servers. 
""" - spec_excludes: dict[str, Any] = {} + spec_excludes: IncludeExcludeDictType = {} if all(s.repo_code_hash is None for s in job_specs): spec_excludes["repo_code_hash"] = True @@ -141,9 +140,7 @@ def get_job_spec_excludes(job_specs: list[JobSpec]) -> Optional[dict]: if all(not s.file_archives for s in job_specs): spec_excludes["file_archives"] = True - if spec_excludes: - return spec_excludes - return None + return spec_excludes def _should_exclude_job_submission_jpd_cpu_arch(job_submission: JobSubmission) -> bool: diff --git a/src/dstack/_internal/core/compatibility/volumes.py b/src/dstack/_internal/core/compatibility/volumes.py index 6c765a273b..7395674f93 100644 --- a/src/dstack/_internal/core/compatibility/volumes.py +++ b/src/dstack/_internal/core/compatibility/volumes.py @@ -1,32 +1,33 @@ -from typing import Dict - +from dstack._internal.core.models.common import IncludeExcludeDictType from dstack._internal.core.models.volumes import VolumeConfiguration, VolumeSpec -def get_volume_spec_excludes(volume_spec: VolumeSpec) -> Dict: +def get_volume_spec_excludes(volume_spec: VolumeSpec) -> IncludeExcludeDictType: """ Returns `volume_spec` exclude mapping to exclude certain fields from the request. Use this method to exclude new fields when they are not set to keep clients backward-compatibility with older servers. """ - spec_excludes = {} + spec_excludes: IncludeExcludeDictType = {} spec_excludes["configuration"] = _get_volume_configuration_excludes(volume_spec.configuration) return spec_excludes -def get_create_volume_excludes(configuration: VolumeConfiguration) -> Dict: +def get_create_volume_excludes(configuration: VolumeConfiguration) -> IncludeExcludeDictType: """ Returns an exclude mapping to exclude certain fields from the create volume request. Use this method to exclude new fields when they are not set to keep clients backward-compatibility with older servers. 
""" - create_volume_excludes = {} + create_volume_excludes: IncludeExcludeDictType = {} create_volume_excludes["configuration"] = _get_volume_configuration_excludes(configuration) return create_volume_excludes -def _get_volume_configuration_excludes(configuration: VolumeConfiguration) -> Dict: - configuration_excludes = {} +def _get_volume_configuration_excludes( + configuration: VolumeConfiguration, +) -> IncludeExcludeDictType: + configuration_excludes: IncludeExcludeDictType = {} if configuration.tags is None: configuration_excludes["tags"] = True return configuration_excludes diff --git a/src/dstack/_internal/core/models/common.py b/src/dstack/_internal/core/models/common.py index 22b0b26b7f..c347cf0d32 100644 --- a/src/dstack/_internal/core/models/common.py +++ b/src/dstack/_internal/core/models/common.py @@ -6,6 +6,13 @@ from pydantic_duality import DualBaseModel from typing_extensions import Annotated +IncludeExcludeFieldType = Union[int, str] +IncludeExcludeSetType = set[IncludeExcludeFieldType] +IncludeExcludeDictType = dict[ + IncludeExcludeFieldType, Union[bool, IncludeExcludeSetType, "IncludeExcludeDictType"] +] +IncludeExcludeType = Union[IncludeExcludeSetType, IncludeExcludeDictType] + # DualBaseModel creates two classes for the model: # one with extra = "forbid" (CoreModel/CoreModel.__request__), diff --git a/src/dstack/_internal/core/services/diff.py b/src/dstack/_internal/core/services/diff.py index 06211a76e2..d50ab90e5b 100644 --- a/src/dstack/_internal/core/services/diff.py +++ b/src/dstack/_internal/core/services/diff.py @@ -1,14 +1,47 @@ -from typing import Any, Dict +from typing import Any, Optional, TypedDict from pydantic import BaseModel +from dstack._internal.core.models.common import IncludeExcludeType + + +class ModelFieldDiff(TypedDict): + old: Any + new: Any + + +ModelDiff = dict[str, ModelFieldDiff] + # TODO: calculate nested diffs -def diff_models(old: BaseModel, new: BaseModel) -> Dict[str, Any]: +def diff_models( + old: BaseModel, 
new: BaseModel, ignore: Optional[IncludeExcludeType] = None +) -> ModelDiff: + """ + Returns a diff of model instances fields. + + NOTE: `ignore` is implemented as `BaseModel.parse_obj(BaseModel.dict(exclude=ignore))`, + that is, the "ignored" fields are actually not ignored but reset to the default values + before comparison, meaning that 1) any field in `ignore` must have a default value, + 2) the default value must be equal to itself (e.g. `math.nan` != `math.nan`). + + Args: + old: The "old" model instance. + new: The "new" model instance. + ignore: Optional fields to ignore. + + Returns: + A dict of changed fields in the form of + `{<field_name>: {"old": old_value, "new": new_value}}` + """ if type(old) is not type(new): raise TypeError("Both instances must be of the same Pydantic model class.") - changes = {} + if ignore is not None: + old = type(old).parse_obj(old.dict(exclude=ignore)) + new = type(new).parse_obj(new.dict(exclude=ignore)) + + changes: ModelDiff = {} for field in old.__fields__: old_value = getattr(old, field) new_value = getattr(new, field)