diff --git a/src/dstack/_internal/cli/services/configurators/fleet.py b/src/dstack/_internal/cli/services/configurators/fleet.py index b501f0b9de..2a7eeb4d59 100644 --- a/src/dstack/_internal/cli/services/configurators/fleet.py +++ b/src/dstack/_internal/cli/services/configurators/fleet.py @@ -25,6 +25,7 @@ ServerClientError, URLNotFoundError, ) +from dstack._internal.core.models.common import ApplyAction from dstack._internal.core.models.configurations import ApplyConfigurationType from dstack._internal.core.models.fleets import ( Fleet, @@ -72,7 +73,104 @@ def apply_configuration( spec=spec, ) _print_plan_header(plan) + if plan.action is not None: + self._apply_plan(plan, command_args) + else: + # Old servers don't support spec update + self._apply_plan_on_old_server(plan, command_args) + + def _apply_plan(self, plan: FleetPlan, command_args: argparse.Namespace): + delete_fleet_name: Optional[str] = None + action_message = "" + confirm_message = "" + if plan.current_resource is None: + if plan.spec.configuration.name is not None: + action_message += ( + f"Fleet [code]{plan.spec.configuration.name}[/] does not exist yet." + ) + confirm_message += "Create the fleet?" + else: + action_message += f"Found fleet [code]{plan.spec.configuration.name}[/]." + if plan.action == ApplyAction.CREATE: + delete_fleet_name = plan.current_resource.name + action_message += ( + " Configuration changes detected. Cannot update the fleet in-place" + ) + confirm_message += "Re-create the fleet?" + elif plan.current_resource.spec == plan.effective_spec: + if command_args.yes and not command_args.force: + # --force is required only with --yes, + # otherwise we may ask for force apply interactively. + console.print( + "No configuration changes detected. Use --force to apply anyway." + ) + return + delete_fleet_name = plan.current_resource.name + action_message += " No configuration changes detected." + confirm_message += "Re-create the fleet?" + else: + action_message += " Configuration changes detected." + confirm_message += "Update the fleet in-place?" + + console.print(action_message) + if not command_args.yes and not confirm_ask(confirm_message): + console.print("\nExiting...") + return + + if delete_fleet_name is not None: + with console.status("Deleting existing fleet..."): + self.api.client.fleets.delete( + project_name=self.api.project, names=[delete_fleet_name] + ) + # Fleet deletion is async. Wait for fleet to be deleted. + while True: + try: + self.api.client.fleets.get( + project_name=self.api.project, name=delete_fleet_name + ) + except ResourceNotExistsError: + break + else: + time.sleep(1) + + try: + with console.status("Applying plan..."): + fleet = self.api.client.fleets.apply_plan(project_name=self.api.project, plan=plan) + except ServerClientError as e: + raise CLIError(e.msg) + if command_args.detach: + console.print("Fleet configuration submitted. Exiting...") + return + try: + with MultiItemStatus( + f"Provisioning [code]{fleet.name}[/]...", console=console + ) as live: + while not _finished_provisioning(fleet): + table = get_fleets_table([fleet]) + live.update(table) + time.sleep(LIVE_TABLE_PROVISION_INTERVAL_SECS) + fleet = self.api.client.fleets.get(self.api.project, fleet.name) + except KeyboardInterrupt: + if confirm_ask("Delete the fleet before exiting?"): + with console.status("Deleting fleet..."): + self.api.client.fleets.delete( + project_name=self.api.project, names=[fleet.name] + ) + else: + console.print("Exiting... Fleet provisioning will continue in the background.") + return + console.print( + get_fleets_table( + [fleet], + verbose=_failed_provisioning(fleet), + format_date=local_time, + ) + ) + if _failed_provisioning(fleet): + console.print("\n[error]Some instances failed. Check the table above for errors.[/]") + exit(1) + def _apply_plan_on_old_server(self, plan: FleetPlan, command_args: argparse.Namespace): action_message = "" confirm_message = "" if plan.current_resource is None: @@ -86,7 +184,7 @@ def apply_configuration( diff = diff_models( old=plan.current_resource.spec.configuration, new=plan.spec.configuration, - ignore={ + reset={ "ssh_config": { "ssh_key": True, "proxy_jump": {"ssh_key"}, diff --git a/src/dstack/_internal/core/models/fleets.py b/src/dstack/_internal/core/models/fleets.py index 6cf970a955..fd616b7547 100644 --- a/src/dstack/_internal/core/models/fleets.py +++ b/src/dstack/_internal/core/models/fleets.py @@ -8,7 +8,7 @@ from typing_extensions import Annotated, Literal from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.common import CoreModel +from dstack._internal.core.models.common import ApplyAction, CoreModel from dstack._internal.core.models.envs import Env from dstack._internal.core.models.instances import Instance, InstanceOfferWithAvailability, SSHKey from dstack._internal.core.models.profiles import ( @@ -324,6 +324,7 @@ class FleetPlan(CoreModel): offers: List[InstanceOfferWithAvailability] total_offers: int max_offer_price: Optional[float] = None + action: Optional[ApplyAction] = None # default value for backward compatibility def get_effective_spec(self) -> FleetSpec: if self.effective_spec is not None: diff --git a/src/dstack/_internal/core/services/diff.py b/src/dstack/_internal/core/services/diff.py index d50ab90e5b..0d63cebc43 100644 --- a/src/dstack/_internal/core/services/diff.py +++ b/src/dstack/_internal/core/services/diff.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, TypedDict +from typing import Any, Optional, TypedDict, TypeVar from pydantic import BaseModel @@ -15,20 +15,19 @@ class ModelFieldDiff(TypedDict): # TODO: calculate nested diffs def diff_models( - old: BaseModel, new: BaseModel, ignore: Optional[IncludeExcludeType] = None + old: BaseModel, new: BaseModel, reset: Optional[IncludeExcludeType] = None ) -> ModelDiff: """ Returns a diff of model instances fields. - NOTE: `ignore` is implemented as `BaseModel.parse_obj(BaseModel.dict(exclude=ignore))`, - that is, the "ignored" fields are actually not ignored but reset to the default values - before comparison, meaning that 1) any field in `ignore` must have a default value, - 2) the default value must be equal to itself (e.g. `math.nan` != `math.nan`). + The fields specified in the `reset` option are reset to their default values, effectively + excluding them from comparison (assuming that the default value is equal to itself, e.g, + `None == None`, `"task" == "task"`, but `math.nan != math.nan`). Args: old: The "old" model instance. new: The "new" model instance. - ignore: Optional fields to ignore. + reset: Fields to reset to their default values before comparison. Returns: A dict of changed fields in the form of @@ -37,9 +36,9 @@ def diff_models( if type(old) is not type(new): raise TypeError("Both instances must be of the same Pydantic model class.") - if ignore is not None: - old = type(old).parse_obj(old.dict(exclude=ignore)) - new = type(new).parse_obj(new.dict(exclude=ignore)) + if reset is not None: + old = copy_model(old, reset=reset) + new = copy_model(new, reset=reset) changes: ModelDiff = {} for field in old.__fields__: @@ -49,3 +48,24 @@ def diff_models( changes[field] = {"old": old_value, "new": new_value} return changes + + +M = TypeVar("M", bound=BaseModel) + + +def copy_model(model: M, reset: Optional[IncludeExcludeType] = None) -> M: + """ + Returns a deep copy of the model instance. + + Implemented as `BaseModel.parse_obj(BaseModel.dict())`, thus, + unlike `BaseModel.copy(deep=True)`, runs all validations. + + The fields specified in the `reset` option are reset to their default values. + + Args: + reset: Fields to reset to their default values. + + Returns: + A deep copy of the model instance. + """ + return type(model).parse_obj(model.dict(exclude=reset)) diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 9925483e45..93ed23b868 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -1,6 +1,8 @@ import uuid +from collections.abc import Callable from datetime import datetime, timezone -from typing import List, Literal, Optional, Tuple, Union, cast +from functools import wraps +from typing import List, Literal, Optional, Tuple, TypeVar, Union, cast from sqlalchemy import and_, func, or_, select from sqlalchemy.ext.asyncio import AsyncSession @@ -13,10 +15,12 @@ ResourceExistsError, ServerClientError, ) +from dstack._internal.core.models.common import ApplyAction, CoreModel from dstack._internal.core.models.envs import Env from dstack._internal.core.models.fleets import ( ApplyFleetPlanInput, Fleet, + FleetConfiguration, FleetPlan, FleetSpec, FleetStatus, @@ -40,6 +44,7 @@ from dstack._internal.core.models.runs import Requirements, get_policy_map from dstack._internal.core.models.users import GlobalRole from dstack._internal.core.services import validate_dstack_resource_name +from dstack._internal.core.services.diff import ModelDiff, copy_model, diff_models from dstack._internal.server.db import get_db from dstack._internal.server.models import ( FleetModel, @@ -49,7 +54,10 @@ ) from dstack._internal.server.services import instances as instances_services from dstack._internal.server.services import offers as offers_services -from dstack._internal.server.services.instances import list_active_remote_instances +from dstack._internal.server.services.instances import ( + get_instance_remote_connection_info, + list_active_remote_instances, +) from dstack._internal.server.services.locking import ( get_locker, string_to_lock_id, @@ -178,8 +186,9 @@ async def list_project_fleet_models( async def get_fleet( session: AsyncSession, project: ProjectModel, - name: Optional[str], - fleet_id: Optional[uuid.UUID], + name: Optional[str] = None, + fleet_id: Optional[uuid.UUID] = None, + include_sensitive: bool = False, ) -> Optional[Fleet]: if fleet_id is not None: fleet_model = await get_project_fleet_model_by_id( @@ -193,7 +202,7 @@ async def get_fleet( raise ServerClientError("name or id must be specified") if fleet_model is None: return None - return fleet_model_to_fleet(fleet_model) + return fleet_model_to_fleet(fleet_model, include_sensitive=include_sensitive) async def get_project_fleet_model_by_id( @@ -236,23 +245,32 @@ async def get_plan( spec: FleetSpec, ) -> FleetPlan: # Spec must be copied by parsing to calculate merged_profile - effective_spec = FleetSpec.parse_obj(spec.dict()) + effective_spec = copy_model(spec) effective_spec = await apply_plugin_policies( user=user.name, project=project.name, spec=effective_spec, ) - effective_spec = FleetSpec.parse_obj(effective_spec.dict()) - _validate_fleet_spec_and_set_defaults(spec) + # Spec must be copied by parsing to calculate merged_profile + effective_spec = copy_model(effective_spec) + _validate_fleet_spec_and_set_defaults(effective_spec) + + action = ApplyAction.CREATE current_fleet: Optional[Fleet] = None current_fleet_id: Optional[uuid.UUID] = None + if effective_spec.configuration.name is not None: - current_fleet_model = await get_project_fleet_model_by_name( - session=session, project=project, name=effective_spec.configuration.name + current_fleet = await get_fleet( + session=session, + project=project, + name=effective_spec.configuration.name, + include_sensitive=True, ) - if current_fleet_model is not None: - current_fleet = fleet_model_to_fleet(current_fleet_model) - current_fleet_id = current_fleet_model.id + if current_fleet is not None: + _set_fleet_spec_defaults(current_fleet.spec) + if _can_update_fleet_spec(current_fleet.spec, effective_spec): + action = ApplyAction.UPDATE + current_fleet_id = current_fleet.id await _check_ssh_hosts_not_yet_added(session, effective_spec, current_fleet_id) offers = [] @@ -265,7 +283,10 @@ async def get_plan( blocks=effective_spec.configuration.blocks, ) offers = [offer for _, offer in offers_with_backends] + _remove_fleet_spec_sensitive_info(effective_spec) + if current_fleet is not None: + _remove_fleet_spec_sensitive_info(current_fleet.spec) plan = FleetPlan( project_name=project.name, user=user.name, @@ -275,6 +296,7 @@ async def get_plan( offers=offers[:50], total_offers=len(offers), max_offer_price=max((offer.price for offer in offers), default=None), + action=action, ) return plan @@ -327,11 +349,77 @@ async def apply_plan( plan: ApplyFleetPlanInput, force: bool, ) -> Fleet: - return await create_fleet( + spec = await apply_plugin_policies( + user=user.name, + project=project.name, + spec=plan.spec, + ) + # Spec must be copied by parsing to calculate merged_profile + spec = copy_model(spec) + _validate_fleet_spec_and_set_defaults(spec) + + if spec.configuration.ssh_config is not None: + _check_can_manage_ssh_fleets(user=user, project=project) + + configuration = spec.configuration + if configuration.name is None: + return await _create_fleet( + session=session, + project=project, + user=user, + spec=spec, + ) + + fleet_model = await get_project_fleet_model_by_name( + session=session, + project=project, + name=configuration.name, + ) + if fleet_model is None: + return await _create_fleet( + session=session, + project=project, + user=user, + spec=spec, + ) + + instances_ids = sorted(i.id for i in fleet_model.instances if not i.deleted) + await session.commit() + async with ( + get_locker(get_db().dialect_name).lock_ctx(FleetModel.__tablename__, [fleet_model.id]), + get_locker(get_db().dialect_name).lock_ctx(InstanceModel.__tablename__, instances_ids), + ): + # Refetch after lock + # TODO: Lock instances with FOR UPDATE? + res = await session.execute( + select(FleetModel) + .where( + FleetModel.project_id == project.id, + FleetModel.id == fleet_model.id, + FleetModel.deleted == False, + ) + .options(selectinload(FleetModel.instances)) + .options(selectinload(FleetModel.runs)) + .execution_options(populate_existing=True) + .order_by(FleetModel.id) # take locks in order + .with_for_update(key_share=True) + ) + fleet_model = res.scalars().unique().one_or_none() + if fleet_model is not None: + return await _update_fleet( + session=session, + project=project, + spec=spec, + current_resource=plan.current_resource, + force=force, + fleet_model=fleet_model, + ) + + return await _create_fleet( session=session, project=project, user=user, - spec=plan.spec, + spec=spec, ) @@ -341,73 +429,19 @@ async def create_fleet( user: UserModel, spec: FleetSpec, ) -> Fleet: - # Spec must be copied by parsing to calculate merged_profile spec = await apply_plugin_policies( user=user.name, project=project.name, spec=spec, ) - spec = FleetSpec.parse_obj(spec.dict()) + # Spec must be copied by parsing to calculate merged_profile + spec = copy_model(spec) _validate_fleet_spec_and_set_defaults(spec) if spec.configuration.ssh_config is not None: _check_can_manage_ssh_fleets(user=user, project=project) - lock_namespace = f"fleet_names_{project.name}" - if get_db().dialect_name == "sqlite": - # Start new transaction to see committed changes after lock - await session.commit() - elif get_db().dialect_name == "postgresql": - await session.execute( - select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace))) - ) - - lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace) - async with lock: - if spec.configuration.name is not None: - fleet_model = await get_project_fleet_model_by_name( - session=session, - project=project, - name=spec.configuration.name, - ) - if fleet_model is not None: - raise ResourceExistsError() - else: - spec.configuration.name = await generate_fleet_name(session=session, project=project) - - fleet_model = FleetModel( - id=uuid.uuid4(), - name=spec.configuration.name, - project=project, - status=FleetStatus.ACTIVE, - spec=spec.json(), - instances=[], - ) - session.add(fleet_model) - if spec.configuration.ssh_config is not None: - for i, host in enumerate(spec.configuration.ssh_config.hosts): - instances_model = await create_fleet_ssh_instance_model( - project=project, - spec=spec, - ssh_params=spec.configuration.ssh_config, - env=spec.configuration.env, - instance_num=i, - host=host, - ) - fleet_model.instances.append(instances_model) - else: - for i in range(_get_fleet_nodes_to_provision(spec)): - instance_model = await create_fleet_instance_model( - session=session, - project=project, - user=user, - spec=spec, - reservation=spec.configuration.reservation, - instance_num=i, - ) - fleet_model.instances.append(instance_model) - await session.commit() - return fleet_model_to_fleet(fleet_model) + return await _create_fleet(session=session, project=project, user=user, spec=spec) async def create_fleet_instance_model( @@ -600,6 +634,235 @@ def is_fleet_empty(fleet_model: FleetModel) -> bool: return len(active_instances) == 0 +async def _create_fleet( + session: AsyncSession, + project: ProjectModel, + user: UserModel, + spec: FleetSpec, +) -> Fleet: + lock_namespace = f"fleet_names_{project.name}" + if get_db().dialect_name == "sqlite": + # Start new transaction to see committed changes after lock + await session.commit() + elif get_db().dialect_name == "postgresql": + await session.execute( + select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace))) + ) + + lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace) + async with lock: + if spec.configuration.name is not None: + fleet_model = await get_project_fleet_model_by_name( + session=session, + project=project, + name=spec.configuration.name, + ) + if fleet_model is not None: + raise ResourceExistsError() + else: + spec.configuration.name = await generate_fleet_name(session=session, project=project) + + fleet_model = FleetModel( + id=uuid.uuid4(), + name=spec.configuration.name, + project=project, + status=FleetStatus.ACTIVE, + spec=spec.json(), + instances=[], + ) + session.add(fleet_model) + if spec.configuration.ssh_config is not None: + for i, host in enumerate(spec.configuration.ssh_config.hosts): + instances_model = await create_fleet_ssh_instance_model( + project=project, + spec=spec, + ssh_params=spec.configuration.ssh_config, + env=spec.configuration.env, + instance_num=i, + host=host, + ) + fleet_model.instances.append(instances_model) + else: + for i in range(_get_fleet_nodes_to_provision(spec)): + instance_model = await create_fleet_instance_model( + session=session, + project=project, + user=user, + spec=spec, + reservation=spec.configuration.reservation, + instance_num=i, + ) + fleet_model.instances.append(instance_model) + await session.commit() + return fleet_model_to_fleet(fleet_model) + + +async def _update_fleet( + session: AsyncSession, + project: ProjectModel, + spec: FleetSpec, + current_resource: Optional[Fleet], + force: bool, + fleet_model: FleetModel, +) -> Fleet: + fleet = fleet_model_to_fleet(fleet_model) + _set_fleet_spec_defaults(fleet.spec) + fleet_sensitive = fleet_model_to_fleet(fleet_model, include_sensitive=True) + _set_fleet_spec_defaults(fleet_sensitive.spec) + + if not force: + if current_resource is not None: + _set_fleet_spec_defaults(current_resource.spec) + if ( + current_resource is None + or current_resource.id != fleet.id + or current_resource.spec != fleet.spec + ): + raise ServerClientError( + "Failed to apply plan. Resource has been changed. Try again or use force apply." + ) + + _check_can_update_fleet_spec(fleet_sensitive.spec, spec) + + spec_json = spec.json() + fleet_model.spec = spec_json + + if ( + fleet_sensitive.spec.configuration.ssh_config is not None + and spec.configuration.ssh_config is not None + ): + added_hosts, removed_hosts, changed_hosts = _calculate_ssh_hosts_changes( + current=fleet_sensitive.spec.configuration.ssh_config.hosts, + new=spec.configuration.ssh_config.hosts, + ) + # `_check_can_update_fleet_spec` ensures hosts are not changed + assert not changed_hosts, changed_hosts + active_instance_nums: set[int] = set() + removed_instance_nums: list[int] = [] + if removed_hosts or added_hosts: + for instance_model in fleet_model.instances: + if instance_model.deleted: + continue + active_instance_nums.add(instance_model.instance_num) + rci = get_instance_remote_connection_info(instance_model) + if rci is None: + logger.error( + "Cloud instance %s in SSH fleet %s", + instance_model.id, + fleet_model.id, + ) + continue + if rci.host in removed_hosts: + removed_instance_nums.append(instance_model.instance_num) + if added_hosts: + await _check_ssh_hosts_not_yet_added(session, spec, fleet.id) + for host in added_hosts.values(): + instance_num = _get_next_instance_num(active_instance_nums) + instance_model = await create_fleet_ssh_instance_model( + project=project, + spec=spec, + ssh_params=spec.configuration.ssh_config, + env=spec.configuration.env, + instance_num=instance_num, + host=host, + ) + fleet_model.instances.append(instance_model) + active_instance_nums.add(instance_num) + if removed_instance_nums: + _terminate_fleet_instances(fleet_model, removed_instance_nums) + + await session.commit() + return fleet_model_to_fleet(fleet_model) + + +def _can_update_fleet_spec(current_fleet_spec: FleetSpec, new_fleet_spec: FleetSpec) -> bool: + try: + _check_can_update_fleet_spec(current_fleet_spec, new_fleet_spec) + except ServerClientError as e: + logger.debug("Run cannot be updated: %s", repr(e)) + return False + return True + + +M = TypeVar("M", bound=CoreModel) + + +def _check_can_update(*updatable_fields: str): + def decorator(fn: Callable[[M, M, ModelDiff], None]) -> Callable[[M, M], None]: + @wraps(fn) + def inner(current: M, new: M): + diff = _check_can_update_inner(current, new, updatable_fields) + fn(current, new, diff) + + return inner + + return decorator + + +def _check_can_update_inner(current: M, new: M, updatable_fields: tuple[str, ...]) -> ModelDiff: + diff = diff_models(current, new) + changed_fields = diff.keys() + if not (changed_fields <= set(updatable_fields)): + raise ServerClientError( + f"Failed to update fields {list(changed_fields)}." + f" Can only update {list(updatable_fields)}." + ) + return diff + + +@_check_can_update("configuration", "configuration_path") +def _check_can_update_fleet_spec(current: FleetSpec, new: FleetSpec, diff: ModelDiff): + if "configuration" in diff: + _check_can_update_fleet_configuration(current.configuration, new.configuration) + + +@_check_can_update("ssh_config") +def _check_can_update_fleet_configuration( + current: FleetConfiguration, new: FleetConfiguration, diff: ModelDiff +): + if "ssh_config" in diff: + current_ssh_config = current.ssh_config + new_ssh_config = new.ssh_config + if current_ssh_config is None: + if new_ssh_config is not None: + raise ServerClientError("Fleet type changed from Cloud to SSH, cannot update") + elif new_ssh_config is None: + raise ServerClientError("Fleet type changed from SSH to Cloud, cannot update") + else: + _check_can_update_ssh_config(current_ssh_config, new_ssh_config) + + +@_check_can_update("hosts") +def _check_can_update_ssh_config(current: SSHParams, new: SSHParams, diff: ModelDiff): + if "hosts" in diff: + _, _, changed_hosts = _calculate_ssh_hosts_changes(current.hosts, new.hosts) + if changed_hosts: + raise ServerClientError( + f"Hosts configuration changed, cannot update: {list(changed_hosts)}" + ) + + +def _calculate_ssh_hosts_changes( + current: list[Union[SSHHostParams, str]], new: list[Union[SSHHostParams, str]] +) -> tuple[dict[str, Union[SSHHostParams, str]], set[str], set[str]]: + current_hosts = {h if isinstance(h, str) else h.hostname: h for h in current} + new_hosts = {h if isinstance(h, str) else h.hostname: h for h in new} + added_hosts = {h: new_hosts[h] for h in new_hosts.keys() - current_hosts} + removed_hosts = current_hosts.keys() - new_hosts + changed_hosts: set[str] = set() + for host in current_hosts.keys() & new_hosts: + current_host = current_hosts[host] + new_host = new_hosts[host] + if isinstance(current_host, str) or isinstance(new_host, str): + if current_host != new_host: + changed_hosts.add(host) + elif diff_models( + current_host, new_host, reset={"identity_file": True, "proxy_jump": {"identity_file"}} + ): + changed_hosts.add(host) + return added_hosts, removed_hosts, changed_hosts + + def _check_can_manage_ssh_fleets(user: UserModel, project: ProjectModel): if user.global_role == GlobalRole.ADMIN: return @@ -654,6 +917,8 @@ def _validate_fleet_spec_and_set_defaults(spec: FleetSpec): validate_dstack_resource_name(spec.configuration.name) if spec.configuration.ssh_config is None and spec.configuration.nodes is None: raise ServerClientError("No ssh_config or nodes specified") + if spec.configuration.ssh_config is not None and spec.configuration.nodes is not None: + raise ServerClientError("ssh_config and nodes are mutually exclusive") if spec.configuration.ssh_config is not None: _validate_all_ssh_params_specified(spec.configuration.ssh_config) if spec.configuration.ssh_config.ssh_key is not None: @@ -662,6 +927,10 @@ def _validate_fleet_spec_and_set_defaults(spec: FleetSpec): if isinstance(host, SSHHostParams) and host.ssh_key is not None: _validate_ssh_key(host.ssh_key) _validate_internal_ips(spec.configuration.ssh_config) + _set_fleet_spec_defaults(spec) + + +def _set_fleet_spec_defaults(spec: FleetSpec): if spec.configuration.resources is not None: set_resources_defaults(spec.configuration.resources) @@ -734,3 +1003,16 @@ def _get_fleet_requirements(fleet_spec: FleetSpec) -> Requirements: reservation=fleet_spec.configuration.reservation, ) return requirements + + +def _get_next_instance_num(instance_nums: set[int]) -> int: + if not instance_nums: + return 0 + min_instance_num = min(instance_nums) + if min_instance_num > 0: + return 0 + instance_num = min_instance_num + 1 + while True: + if instance_num not in instance_nums: + return instance_num + instance_num += 1 diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py index 53ca55396a..c2dd5c2a1f 100644 --- a/src/dstack/_internal/server/services/instances.py +++ b/src/dstack/_internal/server/services/instances.py @@ -106,6 +106,14 @@ def get_instance_requirements(instance_model: InstanceModel) -> Requirements: return Requirements.__response__.parse_raw(instance_model.requirements) +def get_instance_remote_connection_info( + instance_model: InstanceModel, +) -> Optional[RemoteConnectionInfo]: + if instance_model.remote_connection_info is None: + return None + return RemoteConnectionInfo.__response__.parse_raw(instance_model.remote_connection_info) + + def get_instance_ssh_private_keys(instance_model: InstanceModel) -> tuple[str, Optional[str]]: """ Returns a pair of SSH private keys: host key and optional proxy jump key. diff --git a/src/dstack/_internal/server/services/locking.py b/src/dstack/_internal/server/services/locking.py index 37807b37a8..4c3b7f938a 100644 --- a/src/dstack/_internal/server/services/locking.py +++ b/src/dstack/_internal/server/services/locking.py @@ -172,7 +172,7 @@ async def _wait_to_lock_many( The keys must be sorted to prevent deadlock. """ left_to_lock = keys.copy() - while len(left_to_lock) > 0: + while True: async with lock: locked_now_num = 0 for key in left_to_lock: @@ -182,4 +182,6 @@ async def _wait_to_lock_many( locked.add(key) locked_now_num += 1 left_to_lock = left_to_lock[locked_now_num:] + if not left_to_lock: + return await asyncio.sleep(delay) diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 047adb5c14..7512fe1be7 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -31,6 +31,8 @@ FleetSpec, FleetStatus, InstanceGroupPlacement, + SSHHostParams, + SSHParams, ) from dstack._internal.core.models.gateways import GatewayComputeConfiguration, GatewayStatus from dstack._internal.core.models.instances import ( @@ -378,6 +380,7 @@ def get_job_provisioning_data( hostname: str = "127.0.0.4", internal_ip: Optional[str] = "127.0.0.4", price: float = 10.5, + instance_type: Optional[InstanceType] = None, ) -> JobProvisioningData: gpus = [ Gpu( @@ -386,14 +389,16 @@ def get_job_provisioning_data( vendor=gpuhunt.AcceleratorVendor.NVIDIA, ) ] * gpu_count - return JobProvisioningData( - backend=backend, - instance_type=InstanceType( + if instance_type is None: + instance_type = InstanceType( name="instance", resources=Resources( cpus=cpu_count, memory_mib=int(memory_gib * 1024), spot=spot, gpus=gpus ), - ), + ) + return JobProvisioningData( + backend=backend, + instance_type=instance_type, instance_id="instance_id", hostname=hostname, internal_ip=internal_ip, @@ -549,6 +554,31 @@ def get_fleet_configuration( ) +def get_ssh_fleet_configuration( + name: str = "test-fleet", + user: str = "ubuntu", + ssh_key: Optional[SSHKey] = None, + hosts: Optional[list[Union[SSHHostParams, str]]] = None, + network: Optional[str] = None, + placement: Optional[InstanceGroupPlacement] = None, +) -> FleetConfiguration: + if ssh_key is None: + ssh_key = SSHKey(public="", private=get_private_key_string()) + if hosts is None: + hosts = ["10.0.0.100"] + ssh_config = SSHParams( + user=user, + ssh_key=ssh_key, + hosts=hosts, + network=network, + ) + return FleetConfiguration( + name=name, + ssh_config=ssh_config, + placement=placement, + ) + + async def create_instance( session: AsyncSession, project: ProjectModel, @@ -590,7 +620,9 @@ async def create_instance( internal_ip=None, ) if offer == "auto": - offer = get_instance_offer_with_availability(backend=backend, region=region, spot=spot) + offer = get_instance_offer_with_availability( + backend=backend, region=region, spot=spot, price=price + ) if profile is None: profile = Profile(name="test_name") diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py index 87a970c73c..5c4758501d 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -10,7 +10,12 @@ from sqlalchemy.ext.asyncio import AsyncSession from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.fleets import FleetConfiguration, FleetStatus, SSHParams +from dstack._internal.core.models.fleets import ( + FleetConfiguration, + FleetStatus, + InstanceGroupPlacement, + SSHParams, +) from dstack._internal.core.models.instances import ( InstanceAvailability, InstanceOfferWithAvailability, @@ -21,6 +26,7 @@ ) from dstack._internal.core.models.users import GlobalRole, ProjectRole from dstack._internal.server.models import FleetModel, InstanceModel +from dstack._internal.server.services.fleets import fleet_model_to_fleet from dstack._internal.server.services.permissions import DefaultPermissions from dstack._internal.server.services.projects import add_project_member from dstack._internal.server.testing.common import ( @@ -35,7 +41,11 @@ get_auth_headers, get_fleet_configuration, get_fleet_spec, + get_instance_offer_with_availability, + get_job_provisioning_data, get_private_key_string, + get_remote_connection_info, + get_ssh_fleet_configuration, ) pytestmark = pytest.mark.usefixtures("image_config_mock") @@ -415,17 +425,14 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A await add_project_member( session=session, project=project, user=user, project_role=ProjectRole.USER ) - spec = get_fleet_spec( - conf=FleetConfiguration( - name="test-ssh-fleet", - ssh_config=SSHParams( - user="ubuntu", - ssh_key=SSHKey(public="", private=get_private_key_string()), - hosts=["1.1.1.1"], - network=None, - ), - ) + conf = get_ssh_fleet_configuration( + name="test-ssh-fleet", + user="ubuntu", + ssh_key=SSHKey(public="", private=get_private_key_string()), + hosts=["1.1.1.1"], + network=None, ) + spec = get_fleet_spec(conf=conf) with patch("uuid.uuid4") as m: m.return_value = UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e") response = await client.post( @@ -541,6 +548,212 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A instance = res.unique().scalar_one() assert instance.remote_connection_info is not None + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @freeze_time(datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), real_asyncio=True) + async def test_updates_ssh_fleet(self, test_db, session: AsyncSession, client: AsyncClient): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + current_conf = get_ssh_fleet_configuration( + name="test-ssh-fleet", + user="ubuntu", + ssh_key=SSHKey(public="", private=get_private_key_string()), + hosts=["10.0.0.100"], + network=None, + ) + current_spec = get_fleet_spec(conf=current_conf) + spec = current_spec.copy(deep=True) + # 10.0.0.100 removed, 10.0.0.101 added + spec.configuration.ssh_config.hosts = ["10.0.0.101"] + + fleet = await create_fleet(session=session, project=project, spec=current_spec) + instance_type = InstanceType( + name="ssh", + resources=Resources(cpus=2, memory_mib=8, gpus=[], spot=False), + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + backend=BackendType.REMOTE, + name="test-ssh-fleet-0", + region="remote", + price=0.0, + status=InstanceStatus.IDLE, + offer=get_instance_offer_with_availability( + backend=BackendType.REMOTE, + region="remote", + price=0.0, + ), + job_provisioning_data=get_job_provisioning_data( + instance_type=instance_type, + hostname="10.0.0.100", + ), + remote_connection_info=get_remote_connection_info(host="10.0.0.100"), + ) + + with patch("uuid.uuid4") as m: + m.return_value = UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e") + response = await client.post( + f"/api/project/{project.name}/fleets/apply", + headers=get_auth_headers(user.token), + json={ + "plan": { + "spec": spec.dict(), + "current_resource": _fleet_model_to_json_dict(fleet), + }, + "force": False, + }, + ) + + assert response.status_code == 200, response.json() + assert response.json() == { + "id": str(fleet.id), + "name": spec.configuration.name, + "project_name": project.name, + "spec": { + "configuration_path": spec.configuration_path, + "configuration": { + "env": {}, + "ssh_config": { + "user": "ubuntu", + "port": None, + "identity_file": None, + "ssh_key": None, # should not return ssh_key + "proxy_jump": None, + "hosts": ["10.0.0.101"], + "network": None, + }, + "nodes": None, + "placement": None, + "resources": { + "cpu": {"min": 2, "max": None}, + "memory": {"min": 8.0, "max": None}, + "shm_size": None, + "gpu": None, + "disk": {"size": {"min": 100.0, "max": None}}, + }, + "backends": None, + "regions": None, + "availability_zones": None, + "instance_types": None, + "spot_policy": None, + "retry": None, + "max_price": None, + "idle_duration": None, + "type": "fleet", + "name": spec.configuration.name, + "reservation": None, + "blocks": 1, + "tags": None, + }, + "profile": { + "backends": None, + "regions": None, + "availability_zones": None, + "instance_types": None, + "spot_policy": None, + "retry": None, + "max_duration": None, + "stop_duration": None, + "max_price": None, + "creation_policy": None, + "idle_duration": None, + "utilization_policy": None, + "startup_order": None, + "stop_criteria": None, + "name": "", + "default": False, + "reservation": None, + "fleets": None, + "tags": None, + }, + "autocreated": False, + }, + "created_at": "2023-01-02T03:04:00+00:00", + "status": "active", + "status_message": None, + "instances": [ + { + "id": str(instance.id), + "project_name": project.name, + "backend": "remote", + "instance_type": { + "name": "ssh", + "resources": { + "cpu_arch": None, + "cpus": 2, + "memory_mib": 8, + "gpus": [], + "spot": False, + "disk": {"size_mib": 102400}, + "description": "cpu=2 mem=0GB disk=100GB", + }, + }, + "name": "test-ssh-fleet-0", + "fleet_id": str(fleet.id), + "fleet_name": "test-ssh-fleet", + "instance_num": 0, + "job_name": None, + "hostname": "10.0.0.100", + "status": "terminating", + "unreachable": False, + "termination_reason": None, + "created": "2023-01-02T03:04:00+00:00", + "region": "remote", + "availability_zone": None, + "price": 0.0, + "total_blocks": 1, + "busy_blocks": 0, + }, + { + "id": "1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e", + "project_name": project.name, + "backend": "remote", + "instance_type": { + "name": "ssh", + "resources": { + "cpu_arch": None, + "cpus": 2, + "memory_mib": 8, + "gpus": [], + "spot": False, + "disk": {"size_mib": 102400}, + "description": "cpu=2 mem=0GB disk=100GB", + }, + }, + "name": "test-ssh-fleet-1", + "fleet_id": str(fleet.id), + "fleet_name": "test-ssh-fleet", + "instance_num": 1, + "job_name": None, + "hostname": "10.0.0.101", + "status": "pending", + "unreachable": False, + "termination_reason": None, + "created": "2023-01-02T03:04:00+00:00", + "region": "remote", + "availability_zone": None, + "price": 0.0, + "total_blocks": 1, + "busy_blocks": 0, + }, + ], + } + res = await session.execute(select(FleetModel)) + assert res.scalar_one() + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATING + res = await session.execute( + select(InstanceModel).where(InstanceModel.id == "1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e") + ) + instance = res.unique().scalar_one() + assert instance.status == InstanceStatus.PENDING + assert instance.remote_connection_info is not None + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @freeze_time(datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc)) @@ -820,7 +1033,9 @@ async def test_returns_40x_if_not_authenticated( @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_returns_plan(self, test_db, session: AsyncSession, client: AsyncClient): + async def test_returns_create_plan_for_new_fleet( + self, test_db, session: AsyncSession, client: AsyncClient + ): user = await create_user(session=session, global_role=GlobalRole.USER) project = await create_project(session=session, owner=user) await add_project_member( @@ -861,4 +1076,85 @@ async def test_returns_plan(self, test_db, session: AsyncSession, client: AsyncC "offers": [json.loads(o.json()) for o in offers], "total_offers": len(offers), "max_offer_price": 1.0, + "action": "create", + } + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_update_plan_for_existing_fleet( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + conf = get_ssh_fleet_configuration(hosts=["10.0.0.100"]) + spec = get_fleet_spec(conf=conf) + effective_spec = spec.copy(deep=True) + effective_spec.configuration.ssh_config.ssh_key = None + current_spec = spec.copy(deep=True) + # `hosts` can be updated in-place + current_spec.configuration.ssh_config.hosts = ["10.0.0.100", "10.0.0.101"] + fleet = await create_fleet(session=session, project=project, spec=current_spec) + + response = await client.post( + f"/api/project/{project.name}/fleets/get_plan", + headers=get_auth_headers(user.token), + json={"spec": spec.dict()}, + ) + + assert response.status_code == 200 + assert response.json() == { + "project_name": project.name, + "user": user.name, + "spec": spec.dict(), + "effective_spec": effective_spec.dict(), + "current_resource": _fleet_model_to_json_dict(fleet), + "offers": [], + "total_offers": 0, + "max_offer_price": None, + "action": "update", + } + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_create_plan_for_existing_fleet( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + conf = get_ssh_fleet_configuration(placement=InstanceGroupPlacement.ANY) + spec = get_fleet_spec(conf=conf) + effective_spec = spec.copy(deep=True) + effective_spec.configuration.ssh_config.ssh_key = None + current_spec = spec.copy(deep=True) + # `placement` cannot be updated in-place + current_spec.configuration.placement = InstanceGroupPlacement.CLUSTER + fleet = await create_fleet(session=session, project=project, spec=current_spec) + + response = await client.post( + f"/api/project/{project.name}/fleets/get_plan", + headers=get_auth_headers(user.token), + json={"spec": spec.dict()}, + ) + + assert response.status_code == 200 + assert response.json() == { + "project_name": project.name, + "user": user.name, + "spec": spec.dict(), + "effective_spec": effective_spec.dict(), + "current_resource": _fleet_model_to_json_dict(fleet), + "offers": [], + "total_offers": 0, + "max_offer_price": None, + "action": "create", } + + +def _fleet_model_to_json_dict(fleet: FleetModel) -> dict: + return json.loads(fleet_model_to_fleet(fleet).json())