Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/dstack/_internal/cli/services/configurators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,9 +398,10 @@ def validate_gpu_vendor_and_image(self, conf: BaseRunConfiguration) -> None:
else:
has_amd_gpu = vendor == gpuhunt.AcceleratorVendor.AMD
has_tt_gpu = vendor == gpuhunt.AcceleratorVendor.TENSTORRENT
if has_amd_gpu and conf.image is None:
# When docker=True, the system uses a Docker-in-Docker image, so no custom image is required
if has_amd_gpu and conf.image is None and conf.docker is not True:
raise ConfigurationError("`image` is required if `resources.gpu.vendor` is `amd`")
if has_tt_gpu and conf.image is None:
if has_tt_gpu and conf.image is None and conf.docker is not True:
raise ConfigurationError(
"`image` is required if `resources.gpu.vendor` is `tenstorrent`"
)
Expand Down
20 changes: 19 additions & 1 deletion src/tests/_internal/cli/services/configurators/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,11 @@ def test_interpolates_env(self):

class TestValidateGPUVendorAndImage:
def prepare_conf(
self, *, image: Optional[str] = None, gpu_spec: Optional[str] = None
self,
*,
image: Optional[str] = None,
gpu_spec: Optional[str] = None,
docker: Optional[bool] = None,
) -> BaseRunConfiguration:
conf_dict = {
"type": "none",
Expand All @@ -110,6 +114,8 @@ def prepare_conf(
conf_dict["resources"] = {
"gpu": gpu_spec,
}
if docker is not None:
conf_dict["docker"] = docker
return BaseRunConfiguration.parse_obj(conf_dict)

def validate(self, conf: BaseRunConfiguration) -> None:
Expand Down Expand Up @@ -199,6 +205,12 @@ def test_amd_vendor_declared_no_image(self):
):
self.validate(conf)

@pytest.mark.parametrize("gpu_spec", ["AMD", "MI300X"])
def test_amd_vendor_docker_true_no_image(self, gpu_spec):
    """Check that an AMD GPU spec with `docker=True` and no `image` validates.

    The configuration is built without an image; `validate` must return
    without raising, and the GPU vendor on the configuration must read as
    AMD afterwards.
    """
    dind_conf = self.prepare_conf(docker=True, gpu_spec=gpu_spec)
    # No `pytest.raises` here: any exception from validation fails the test.
    self.validate(dind_conf)
    resolved_vendor = dind_conf.resources.gpu.vendor
    assert resolved_vendor == AcceleratorVendor.AMD

@pytest.mark.parametrize("gpu_spec", ["MI300X", "MI300x", "mi300x"])
def test_amd_vendor_inferred_no_image(self, gpu_spec):
conf = self.prepare_conf(gpu_spec=gpu_spec)
Expand All @@ -222,6 +234,12 @@ def test_two_vendors_including_amd_inferred_no_image(self, gpu_spec):
):
self.validate(conf)

@pytest.mark.parametrize("gpu_spec", ["n150", "n300"])
def test_tenstorrent_docker_true_no_image(self, gpu_spec):
    """Check that a Tenstorrent GPU spec with `docker=True` and no `image` validates.

    The configuration is built without an image; `validate` must return
    without raising, and the GPU vendor on the configuration must read as
    Tenstorrent afterwards.
    """
    dind_conf = self.prepare_conf(docker=True, gpu_spec=gpu_spec)
    # No `pytest.raises` here: any exception from validation fails the test.
    self.validate(dind_conf)
    resolved_vendor = dind_conf.resources.gpu.vendor
    assert resolved_vendor == AcceleratorVendor.TENSTORRENT


class TestValidateCPUArchAndImage:
def prepare_conf(
Expand Down
Loading