diff --git a/src/dstack/_internal/cli/services/configurators/fleet.py b/src/dstack/_internal/cli/services/configurators/fleet.py index 9bfb399aca..b501f0b9de 100644 --- a/src/dstack/_internal/cli/services/configurators/fleet.py +++ b/src/dstack/_internal/cli/services/configurators/fleet.py @@ -35,6 +35,7 @@ ) from dstack._internal.core.models.instances import InstanceAvailability, InstanceStatus, SSHKey from dstack._internal.core.models.repos.base import Repo +from dstack._internal.core.services.diff import diff_models from dstack._internal.utils.common import local_time from dstack._internal.utils.logging import get_logger from dstack._internal.utils.ssh import convert_ssh_key_to_pem, generate_public_key, pkey_from_str @@ -82,7 +83,18 @@ def apply_configuration( confirm_message += "Create the fleet?" else: action_message += f"Found fleet [code]{plan.spec.configuration.name}[/]." - if plan.current_resource.spec.configuration == plan.spec.configuration: + diff = diff_models( + old=plan.current_resource.spec.configuration, + new=plan.spec.configuration, + ignore={ + "ssh_config": { + "ssh_key": True, + "proxy_jump": {"ssh_key"}, + "hosts": {"__all__": {"ssh_key": True, "proxy_jump": {"ssh_key"}}}, + } + }, + ) + if not diff: if command_args.yes and not command_args.force: # --force is required only with --yes, # otherwise we may ask for force apply interactively. diff --git a/src/dstack/_internal/core/compatibility/fleets.py b/src/dstack/_internal/core/compatibility/fleets.py index 4ba2a92c42..b8a8738924 100644 --- a/src/dstack/_internal/core/compatibility/fleets.py +++ b/src/dstack/_internal/core/compatibility/fleets.py @@ -1,19 +1,20 @@ -from typing import Any, Dict, Optional +from typing import Optional +from dstack._internal.core.models.common import IncludeExcludeDictType, IncludeExcludeSetType from dstack._internal.core.models.fleets import ApplyFleetPlanInput, FleetSpec from dstack._internal.core.models.instances import Instance -def get_get_plan_excludes(fleet_spec: FleetSpec) -> Dict: - get_plan_excludes = {} +def get_get_plan_excludes(fleet_spec: FleetSpec) -> IncludeExcludeDictType: + get_plan_excludes: IncludeExcludeDictType = {} spec_excludes = get_fleet_spec_excludes(fleet_spec) if spec_excludes: get_plan_excludes["spec"] = spec_excludes return get_plan_excludes -def get_apply_plan_excludes(plan_input: ApplyFleetPlanInput) -> Dict: - apply_plan_excludes = {} +def get_apply_plan_excludes(plan_input: ApplyFleetPlanInput) -> IncludeExcludeDictType: + apply_plan_excludes: IncludeExcludeDictType = {} spec_excludes = get_fleet_spec_excludes(plan_input.spec) if spec_excludes: apply_plan_excludes["spec"] = spec_excludes @@ -28,23 +29,23 @@ def get_apply_plan_excludes(plan_input: ApplyFleetPlanInput) -> Dict: return {"plan": apply_plan_excludes} -def get_create_fleet_excludes(fleet_spec: FleetSpec) -> Dict: - create_fleet_excludes = {} +def get_create_fleet_excludes(fleet_spec: FleetSpec) -> IncludeExcludeDictType: + create_fleet_excludes: IncludeExcludeDictType = {} spec_excludes = get_fleet_spec_excludes(fleet_spec) if spec_excludes: create_fleet_excludes["spec"] = spec_excludes return create_fleet_excludes -def get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[Dict]: +def get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[IncludeExcludeDictType]: """ Returns `fleet_spec` exclude mapping to exclude certain fields from the request. Use this method to exclude new fields when they are not set to keep clients backward-compatibility with older servers. """ - spec_excludes: Dict[str, Any] = {} - configuration_excludes: Dict[str, Any] = {} - profile_excludes: set[str] = set() + spec_excludes: IncludeExcludeDictType = {} + configuration_excludes: IncludeExcludeDictType = {} + profile_excludes: IncludeExcludeSetType = set() profile = fleet_spec.profile if profile.fleets is None: profile_excludes.add("fleets") diff --git a/src/dstack/_internal/core/compatibility/gateways.py b/src/dstack/_internal/core/compatibility/gateways.py index 228666ed30..b9fe163838 100644 --- a/src/dstack/_internal/core/compatibility/gateways.py +++ b/src/dstack/_internal/core/compatibility/gateways.py @@ -1,34 +1,35 @@ -from typing import Dict - +from dstack._internal.core.models.common import IncludeExcludeDictType from dstack._internal.core.models.gateways import GatewayConfiguration, GatewaySpec -def get_gateway_spec_excludes(gateway_spec: GatewaySpec) -> Dict: +def get_gateway_spec_excludes(gateway_spec: GatewaySpec) -> IncludeExcludeDictType: """ Returns `gateway_spec` exclude mapping to exclude certain fields from the request. Use this method to exclude new fields when they are not set to keep clients backward-compatibility with older servers. """ - spec_excludes = {} + spec_excludes: IncludeExcludeDictType = {} spec_excludes["configuration"] = _get_gateway_configuration_excludes( gateway_spec.configuration ) return spec_excludes -def get_create_gateway_excludes(configuration: GatewayConfiguration) -> Dict: +def get_create_gateway_excludes(configuration: GatewayConfiguration) -> IncludeExcludeDictType: """ Returns an exclude mapping to exclude certain fields from the create gateway request. Use this method to exclude new fields when they are not set to keep clients backward-compatibility with older servers. """ - create_gateway_excludes = {} + create_gateway_excludes: IncludeExcludeDictType = {} create_gateway_excludes["configuration"] = _get_gateway_configuration_excludes(configuration) return create_gateway_excludes -def _get_gateway_configuration_excludes(configuration: GatewayConfiguration) -> Dict: - configuration_excludes = {} +def _get_gateway_configuration_excludes( + configuration: GatewayConfiguration, +) -> IncludeExcludeDictType: + configuration_excludes: IncludeExcludeDictType = {} if configuration.tags is None: configuration_excludes["tags"] = True return configuration_excludes diff --git a/src/dstack/_internal/core/compatibility/logs.py b/src/dstack/_internal/core/compatibility/logs.py index d6c2d3b3bb..7d499047ae 100644 --- a/src/dstack/_internal/core/compatibility/logs.py +++ b/src/dstack/_internal/core/compatibility/logs.py @@ -1,15 +1,16 @@ -from typing import Dict, Optional +from typing import Optional +from dstack._internal.core.models.common import IncludeExcludeDictType from dstack._internal.server.schemas.logs import PollLogsRequest -def get_poll_logs_excludes(request: PollLogsRequest) -> Optional[Dict]: +def get_poll_logs_excludes(request: PollLogsRequest) -> Optional[IncludeExcludeDictType]: """ Returns exclude mapping to exclude certain fields from the request. Use this method to exclude new fields when they are not set to keep clients backward-compatibility with older servers. """ - excludes = {} + excludes: IncludeExcludeDictType = {} if request.next_token is None: excludes["next_token"] = True return excludes if excludes else None diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py index f4e6c6acfb..1deea7ee02 100644 --- a/src/dstack/_internal/core/compatibility/runs.py +++ b/src/dstack/_internal/core/compatibility/runs.py @@ -1,29 +1,30 @@ -from typing import Any, Dict, Optional +from typing import Optional +from dstack._internal.core.models.common import IncludeExcludeDictType, IncludeExcludeSetType from dstack._internal.core.models.configurations import ServiceConfiguration from dstack._internal.core.models.runs import ApplyRunPlanInput, JobSpec, JobSubmission, RunSpec from dstack._internal.server.schemas.runs import GetRunPlanRequest -def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[Dict]: +def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[IncludeExcludeDictType]: """ Returns `plan` exclude mapping to exclude certain fields from the request. Use this method to exclude new fields when they are not set to keep clients backward-compatibility with older servers. """ - apply_plan_excludes = {} + apply_plan_excludes: IncludeExcludeDictType = {} run_spec_excludes = get_run_spec_excludes(plan.run_spec) if run_spec_excludes is not None: apply_plan_excludes["run_spec"] = run_spec_excludes current_resource = plan.current_resource if current_resource is not None: - current_resource_excludes = {} + current_resource_excludes: IncludeExcludeDictType = {} current_resource_excludes["status_message"] = True if current_resource.deployment_num == 0: current_resource_excludes["deployment_num"] = True apply_plan_excludes["current_resource"] = current_resource_excludes current_resource_excludes["run_spec"] = get_run_spec_excludes(current_resource.run_spec) - job_submissions_excludes = {} + job_submissions_excludes: IncludeExcludeDictType = {} current_resource_excludes["jobs"] = { "__all__": { "job_spec": get_job_spec_excludes([job.job_spec for job in current_resource.jobs]), @@ -45,7 +46,7 @@ def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[Dict]: job_submissions_excludes["deployment_num"] = True latest_job_submission = current_resource.latest_job_submission if latest_job_submission is not None: - latest_job_submission_excludes = {} + latest_job_submission_excludes: IncludeExcludeDictType = {} current_resource_excludes["latest_job_submission"] = latest_job_submission_excludes if _should_exclude_job_submission_jpd_cpu_arch(latest_job_submission): latest_job_submission_excludes["job_provisioning_data"] = { @@ -62,12 +63,12 @@ def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[Dict]: return {"plan": apply_plan_excludes} -def get_get_plan_excludes(request: GetRunPlanRequest) -> Optional[Dict]: +def get_get_plan_excludes(request: GetRunPlanRequest) -> Optional[IncludeExcludeDictType]: """ Excludes new fields when they are not set to keep clients backward-compatibility with older servers. """ - get_plan_excludes = {} + get_plan_excludes: IncludeExcludeDictType = {} run_spec_excludes = get_run_spec_excludes(request.run_spec) if run_spec_excludes is not None: get_plan_excludes["run_spec"] = run_spec_excludes @@ -76,15 +77,15 @@ def get_get_plan_excludes(request: GetRunPlanRequest) -> Optional[Dict]: return get_plan_excludes -def get_run_spec_excludes(run_spec: RunSpec) -> Optional[Dict]: +def get_run_spec_excludes(run_spec: RunSpec) -> IncludeExcludeDictType: """ Returns `run_spec` exclude mapping to exclude certain fields from the request. Use this method to exclude new fields when they are not set to keep clients backward-compatibility with older servers. """ - spec_excludes: dict[str, Any] = {} - configuration_excludes: dict[str, Any] = {} - profile_excludes: set[str] = set() + spec_excludes: IncludeExcludeDictType = {} + configuration_excludes: IncludeExcludeDictType = {} + profile_excludes: IncludeExcludeSetType = set() configuration = run_spec.configuration profile = run_spec.profile @@ -121,18 +122,16 @@ def get_run_spec_excludes(run_spec: RunSpec) -> Optional[Dict]: spec_excludes["configuration"] = configuration_excludes if profile_excludes: spec_excludes["profile"] = profile_excludes - if spec_excludes: - return spec_excludes - return None + return spec_excludes -def get_job_spec_excludes(job_specs: list[JobSpec]) -> Optional[dict]: +def get_job_spec_excludes(job_specs: list[JobSpec]) -> IncludeExcludeDictType: """ Returns `job_spec` exclude mapping to exclude certain fields from the request. Use this method to exclude new fields when they are not set to keep clients backward-compatibility with older servers. """ - spec_excludes: dict[str, Any] = {} + spec_excludes: IncludeExcludeDictType = {} if all(s.repo_code_hash is None for s in job_specs): spec_excludes["repo_code_hash"] = True @@ -141,9 +140,7 @@ def get_job_spec_excludes(job_specs: list[JobSpec]) -> Optional[dict]: if all(not s.file_archives for s in job_specs): spec_excludes["file_archives"] = True - if spec_excludes: - return spec_excludes - return None + return spec_excludes def _should_exclude_job_submission_jpd_cpu_arch(job_submission: JobSubmission) -> bool: diff --git a/src/dstack/_internal/core/compatibility/volumes.py b/src/dstack/_internal/core/compatibility/volumes.py index 6c765a273b..7395674f93 100644 --- a/src/dstack/_internal/core/compatibility/volumes.py +++ b/src/dstack/_internal/core/compatibility/volumes.py @@ -1,32 +1,33 @@ -from typing import Dict - +from dstack._internal.core.models.common import IncludeExcludeDictType from dstack._internal.core.models.volumes import VolumeConfiguration, VolumeSpec -def get_volume_spec_excludes(volume_spec: VolumeSpec) -> Dict: +def get_volume_spec_excludes(volume_spec: VolumeSpec) -> IncludeExcludeDictType: """ Returns `volume_spec` exclude mapping to exclude certain fields from the request. Use this method to exclude new fields when they are not set to keep clients backward-compatibility with older servers. """ - spec_excludes = {} + spec_excludes: IncludeExcludeDictType = {} spec_excludes["configuration"] = _get_volume_configuration_excludes(volume_spec.configuration) return spec_excludes -def get_create_volume_excludes(configuration: VolumeConfiguration) -> Dict: +def get_create_volume_excludes(configuration: VolumeConfiguration) -> IncludeExcludeDictType: """ Returns an exclude mapping to exclude certain fields from the create volume request. Use this method to exclude new fields when they are not set to keep clients backward-compatibility with older servers. """ - create_volume_excludes = {} + create_volume_excludes: IncludeExcludeDictType = {} create_volume_excludes["configuration"] = _get_volume_configuration_excludes(configuration) return create_volume_excludes -def _get_volume_configuration_excludes(configuration: VolumeConfiguration) -> Dict: - configuration_excludes = {} +def _get_volume_configuration_excludes( + configuration: VolumeConfiguration, +) -> IncludeExcludeDictType: + configuration_excludes: IncludeExcludeDictType = {} if configuration.tags is None: configuration_excludes["tags"] = True return configuration_excludes diff --git a/src/dstack/_internal/core/models/common.py b/src/dstack/_internal/core/models/common.py index 22b0b26b7f..c347cf0d32 100644 --- a/src/dstack/_internal/core/models/common.py +++ b/src/dstack/_internal/core/models/common.py @@ -6,6 +6,13 @@ from pydantic_duality import DualBaseModel from typing_extensions import Annotated +IncludeExcludeFieldType = Union[int, str] +IncludeExcludeSetType = set[IncludeExcludeFieldType] +IncludeExcludeDictType = dict[ + IncludeExcludeFieldType, Union[bool, IncludeExcludeSetType, "IncludeExcludeDictType"] +] +IncludeExcludeType = Union[IncludeExcludeSetType, IncludeExcludeDictType] + # DualBaseModel creates two classes for the model: # one with extra = "forbid" (CoreModel/CoreModel.__request__), diff --git a/src/dstack/_internal/core/services/diff.py b/src/dstack/_internal/core/services/diff.py index 06211a76e2..d50ab90e5b 100644 --- a/src/dstack/_internal/core/services/diff.py +++ b/src/dstack/_internal/core/services/diff.py @@ -1,14 +1,47 @@ -from typing import Any, Dict +from typing import Any, Optional, TypedDict from pydantic import BaseModel +from dstack._internal.core.models.common import IncludeExcludeType + + +class ModelFieldDiff(TypedDict): + old: Any + new: Any + + +ModelDiff = dict[str, ModelFieldDiff] + # TODO: calculate nested diffs -def diff_models(old: BaseModel, new: BaseModel) -> Dict[str, Any]: +def diff_models( + old: BaseModel, new: BaseModel, ignore: Optional[IncludeExcludeType] = None +) -> ModelDiff: + """ + Returns a diff of model instances fields. + + NOTE: `ignore` is implemented as `BaseModel.parse_obj(BaseModel.dict(exclude=ignore))`, + that is, the "ignored" fields are actually not ignored but reset to the default values + before comparison, meaning that 1) any field in `ignore` must have a default value, + 2) the default value must be equal to itself (e.g. `math.nan` != `math.nan`). + + Args: + old: The "old" model instance. + new: The "new" model instance. + ignore: Optional fields to ignore. + + Returns: + A dict of changed fields in the form of + `{: {"old": old_value, "new": new_value}}` + """ if type(old) is not type(new): raise TypeError("Both instances must be of the same Pydantic model class.") - changes = {} + if ignore is not None: + old = type(old).parse_obj(old.dict(exclude=ignore)) + new = type(new).parse_obj(new.dict(exclude=ignore)) + + changes: ModelDiff = {} for field in old.__fields__: old_value = getattr(old, field) new_value = getattr(new, field)