diff --git a/src/dstack/_internal/core/models/fleets.py b/src/dstack/_internal/core/models/fleets.py index a56296bba..0326ec584 100644 --- a/src/dstack/_internal/core/models/fleets.py +++ b/src/dstack/_internal/core/models/fleets.py @@ -78,17 +78,17 @@ class SSHHostParams(CoreModel): ssh_key: Optional[SSHKey] = None blocks: Annotated[ - Union[Literal["auto"], int], + Optional[Union[Literal["auto"], int]], Field( description=( "The amount of blocks to split the instance into, a number or `auto`." " `auto` means as many as possible." " The number of GPUs and CPUs must be divisible by the number of blocks." - " Defaults to `1`, i.e. do not split" + " Defaults to the top-level `blocks` value." ), ge=1, ), - ] = 1 + ] = None @validator("internal_ip") def validate_internal_ip(cls, value): diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 702086f6d..385bdec41 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -674,6 +674,7 @@ async def create_fleet_ssh_instance_model( spec: FleetSpec, ssh_params: SSHParams, env: Env, + blocks: Union[int, Literal["auto"]], instance_num: int, host: Union[SSHHostParams, str], ) -> InstanceModel: @@ -684,7 +685,6 @@ async def create_fleet_ssh_instance_model( port = ssh_params.port proxy_jump = ssh_params.proxy_jump internal_ip = None - blocks = 1 else: hostname = host.hostname ssh_user = host.user or ssh_params.user @@ -692,7 +692,8 @@ async def create_fleet_ssh_instance_model( port = host.port or ssh_params.port proxy_jump = host.proxy_jump or ssh_params.proxy_jump internal_ip = host.internal_ip - blocks = host.blocks + if host.blocks is not None: + blocks = host.blocks if ssh_user is None or ssh_key is None: # This should not be reachable but checked by fleet spec validation @@ -1042,6 +1043,7 @@ async def _create_fleet( spec=spec, ssh_params=spec.configuration.ssh_config, env=spec.configuration.env, + blocks=spec.configuration.blocks, instance_num=i, host=host, ) @@ -1152,6 +1154,7 @@ async def _update_fleet( spec=spec, ssh_params=spec.configuration.ssh_config, env=spec.configuration.env, + blocks=spec.configuration.blocks, instance_num=instance_num, host=host, ) diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index e38907b01..17500246e 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -3,7 +3,7 @@ from collections.abc import Callable from contextlib import contextmanager from datetime import datetime, timezone -from typing import Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from uuid import UUID import gpuhunt @@ -703,6 +703,7 @@ def get_ssh_fleet_configuration( hosts: Optional[list[Union[SSHHostParams, str]]] = None, network: Optional[str] = None, placement: Optional[InstanceGroupPlacement] = None, + blocks: Optional[Union[int, Literal["auto"]]] = None, ) -> FleetConfiguration: if ssh_key is None: ssh_key = SSHKey(public="", private=get_private_key_string()) @@ -714,10 +715,14 @@ def get_ssh_fleet_configuration( hosts=hosts, network=network, ) + optional_properties: dict[str, Any] = {} + if blocks is not None: + optional_properties["blocks"] = blocks return FleetConfiguration( name=name, ssh_config=ssh_config, placement=placement, + **optional_properties, ) diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py index 47544b921..77dacef91 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -1,6 +1,6 @@ import json from datetime import datetime, timezone -from typing import Optional, Union +from typing import Literal, Optional, Union from unittest.mock import Mock, patch from uuid import uuid4 @@ -16,6 +16,7 @@ FleetConfiguration, FleetStatus, InstanceGroupPlacement, + SSHHostParams, SSHParams, ) from dstack._internal.core.models.instances import ( @@ -1178,6 +1179,56 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A instance = res.unique().scalar_one() assert instance.remote_connection_info is not None + @pytest.mark.parametrize( + ["top_level_blocks", "host_blocks", "host_type", "expected_blocks"], + [ + pytest.param(None, None, str, 1, id="global-default-string"), + pytest.param(None, None, SSHHostParams, 1, id="global-default-object"), + pytest.param(4, None, str, 4, id="top-level-int-string"), + pytest.param(4, None, SSHHostParams, 4, id="top-level-int-object"), + pytest.param("auto", None, str, None, id="top-level-auto-string"), + pytest.param("auto", None, SSHHostParams, None, id="top-level-auto-object"), + pytest.param("auto", 4, SSHHostParams, 4, id="host-level-int"), + pytest.param(4, "auto", SSHHostParams, None, id="host-level-auto"), + ], + ) + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_creates_ssh_fleet_with_blocks( + self, + test_db, + session: AsyncSession, + client: AsyncClient, + top_level_blocks: Optional[Union[int, Literal["auto"]]], + host_blocks: Optional[Union[int, Literal["auto"]]], + host_type: Union[type[str], type[SSHHostParams]], + expected_blocks: Optional[int], + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + if host_type is str: + host = "1.1.1.1" + elif host_blocks is None: + host = SSHHostParams(hostname="1.1.1.1") + else: + host = SSHHostParams(hostname="1.1.1.1", blocks=host_blocks) + conf = get_ssh_fleet_configuration(blocks=top_level_blocks, hosts=[host]) + spec = get_fleet_spec(conf=conf) + response = await client.post( + f"/api/project/{project.name}/fleets/apply", + headers=get_auth_headers(user.token), + json={"plan": {"spec": spec.dict()}, "force": False}, + ) + assert response.status_code == 200, response.json() + res = await session.execute(select(FleetModel)) + assert len(res.scalars().all()) == 1 + res = await session.execute(select(InstanceModel)) + instance = res.scalar_one() + assert instance.total_blocks == expected_blocks + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @freeze_time(datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), real_asyncio=True)