Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/docs/concepts/backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ There are two ways to configure AWS: using an access key or using the default cr
* `user` with passwordless sudo access
* Docker is installed
* (For NVIDIA instances) NVIDIA/CUDA drivers and NVIDIA Container Toolkit are installed
* The firewall (`iptables`, `ufw`, etc.) must allow incoming external traffic to port 22 (SSH) and all traffic within the private subnet, and should forbid any other incoming external traffic.

## Azure

Expand Down
7 changes: 6 additions & 1 deletion src/dstack/_internal/core/backends/aws/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,12 @@ def create_instance(
image_id=image_id,
instance_type=instance_offer.instance.name,
iam_instance_profile=self.config.iam_instance_profile,
user_data=get_user_data(authorized_keys=instance_config.get_public_keys()),
user_data=get_user_data(
authorized_keys=instance_config.get_public_keys(),
# Custom OS images may lack ufw, so don't attempt to set up the firewall.
# Rely on security groups and the image's built-in firewall rules instead.
skip_firewall_setup=self.config.os_images is not None,
),
tags=aws_resources.make_tags(tags),
security_group_id=security_group_id,
spot=instance_offer.instance.resources.spot,
Expand Down
33 changes: 30 additions & 3 deletions src/dstack/_internal/core/backends/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import string
import threading
from abc import ABC, abstractmethod
from collections.abc import Iterable
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Literal, Optional
Expand Down Expand Up @@ -45,6 +46,7 @@

DSTACK_SHIM_BINARY_NAME = "dstack-shim"
DSTACK_RUNNER_BINARY_NAME = "dstack-runner"
DEFAULT_PRIVATE_SUBNETS = ("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16")

GoArchType = Literal["amd64", "arm64"]

Expand Down Expand Up @@ -507,12 +509,16 @@ def get_user_data(
base_path: Optional[PathLike] = None,
bin_path: Optional[PathLike] = None,
backend_shim_env: Optional[Dict[str, str]] = None,
skip_firewall_setup: bool = False,
firewall_allow_from_subnets: Iterable[str] = DEFAULT_PRIVATE_SUBNETS,
) -> str:
shim_commands = get_shim_commands(
authorized_keys=authorized_keys,
base_path=base_path,
bin_path=bin_path,
backend_shim_env=backend_shim_env,
skip_firewall_setup=skip_firewall_setup,
firewall_allow_from_subnets=firewall_allow_from_subnets,
)
commands = (backend_specific_commands or []) + shim_commands
return get_cloud_config(
Expand Down Expand Up @@ -554,8 +560,13 @@ def get_shim_commands(
bin_path: Optional[PathLike] = None,
backend_shim_env: Optional[Dict[str, str]] = None,
arch: Optional[str] = None,
skip_firewall_setup: bool = False,
firewall_allow_from_subnets: Iterable[str] = DEFAULT_PRIVATE_SUBNETS,
) -> List[str]:
commands = get_setup_cloud_instance_commands()
commands = get_setup_cloud_instance_commands(
skip_firewall_setup=skip_firewall_setup,
firewall_allow_from_subnets=firewall_allow_from_subnets,
)
commands += get_shim_pre_start_commands(
base_path=base_path,
bin_path=bin_path,
Expand Down Expand Up @@ -638,8 +649,11 @@ def get_dstack_shim_download_url(arch: Optional[str] = None) -> str:
return url_template.format(version=version, arch=arch)


def get_setup_cloud_instance_commands() -> list[str]:
return [
def get_setup_cloud_instance_commands(
skip_firewall_setup: bool,
firewall_allow_from_subnets: Iterable[str],
) -> list[str]:
commands = [
# Workaround for https://github.com/NVIDIA/nvidia-container-toolkit/issues/48
# Attempts to patch /etc/docker/daemon.json while keeping any custom settings it may have.
(
Expand All @@ -653,6 +667,19 @@ def get_setup_cloud_instance_commands() -> list[str]:
"'"
),
]
if not skip_firewall_setup:
commands += [
"ufw --force reset", # Some OS images have default rules like `allow 80`. Delete them
"ufw default deny incoming",
"ufw default allow outgoing",
"ufw allow ssh",
]
for subnet in firewall_allow_from_subnets:
commands.append(f"ufw allow from {subnet}")
commands += [
"ufw --force enable",
]
return commands


def get_shim_pre_start_commands(
Expand Down
11 changes: 2 additions & 9 deletions src/dstack/_internal/core/backends/digitalocean_base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,6 @@
logger = get_logger(__name__)

MAX_INSTANCE_NAME_LEN = 60

# Setup commands for DigitalOcean instances
SETUP_COMMANDS = [
"sudo ufw delete limit ssh",
"sudo ufw allow ssh",
]

DOCKER_INSTALL_COMMANDS = [
"export DEBIAN_FRONTEND=noninteractive",
"mkdir -p /etc/apt/keyrings",
Expand Down Expand Up @@ -92,9 +85,9 @@ def create_instance(
size_slug = instance_offer.instance.name

if not instance_offer.instance.resources.gpus:
backend_specific_commands = SETUP_COMMANDS + DOCKER_INSTALL_COMMANDS
backend_specific_commands = DOCKER_INSTALL_COMMANDS
else:
backend_specific_commands = SETUP_COMMANDS
backend_specific_commands = None

project_id = None
if self.config.project_name:
Expand Down
40 changes: 32 additions & 8 deletions src/dstack/_internal/core/backends/gcp/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import threading
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, Dict, List, Literal, Optional, Tuple

import google.api_core.exceptions
Expand Down Expand Up @@ -285,16 +286,18 @@ def create_instance(
)
raise NoCapacityError()

image = _get_image(
instance_type_name=instance_offer.instance.name,
cuda=len(instance_offer.instance.resources.gpus) > 0,
)

for zone in zones:
request = compute_v1.InsertInstanceRequest()
request.zone = zone
request.project = self.config.project_id
request.instance_resource = gcp_resources.create_instance_struct(
disk_size=disk_size,
image_id=_get_image_id(
instance_type_name=instance_offer.instance.name,
cuda=len(instance_offer.instance.resources.gpus) > 0,
),
image_id=image.id,
machine_type=instance_offer.instance.name,
accelerators=gcp_resources.get_accelerators(
project_id=self.config.project_id,
Expand All @@ -305,6 +308,7 @@ def create_instance(
user_data=_get_user_data(
authorized_keys=authorized_keys,
instance_type_name=instance_offer.instance.name,
is_ufw_installed=image.is_ufw_installed,
),
authorized_keys=authorized_keys,
labels=labels,
Expand Down Expand Up @@ -889,24 +893,41 @@ def _get_vpc_subnet(
)


def _get_image_id(instance_type_name: str, cuda: bool) -> str:
@dataclass
class GCPImage:
    """A boot image selected for a GCP instance, with firewall tooling info."""

    # Full image resource path, e.g. "projects/<project>/global/images/<name>".
    id: str
    # True when the image ships with ufw, so an instance-level firewall
    # can be configured from user data.
    is_ufw_installed: bool


def _get_image(instance_type_name: str, cuda: bool) -> GCPImage:
if instance_type_name == "a3-megagpu-8g":
image_name = "dstack-a3mega-5"
is_ufw_installed = False
elif instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]:
return "projects/cos-cloud/global/images/cos-105-17412-535-78"
return GCPImage(
id="projects/cos-cloud/global/images/cos-105-17412-535-78",
is_ufw_installed=False,
)
elif cuda:
image_name = f"dstack-cuda-{version.base_image}"
is_ufw_installed = True
else:
image_name = f"dstack-{version.base_image}"
is_ufw_installed = True
image_name = image_name.replace(".", "-")
return f"projects/dstack/global/images/{image_name}"
return GCPImage(
id=f"projects/dstack/global/images/{image_name}",
is_ufw_installed=is_ufw_installed,
)


def _get_gateway_image_id() -> str:
return "projects/ubuntu-os-cloud/global/images/ubuntu-2204-jammy-v20230714"


def _get_user_data(authorized_keys: List[str], instance_type_name: str) -> str:
def _get_user_data(
authorized_keys: List[str], instance_type_name: str, is_ufw_installed: bool
) -> str:
base_path = None
bin_path = None
backend_shim_env = None
Expand All @@ -929,6 +950,9 @@ def _get_user_data(authorized_keys: List[str], instance_type_name: str) -> str:
base_path=base_path,
bin_path=bin_path,
backend_shim_env=backend_shim_env,
# An instance-level firewall is optional on GCP: the main protection comes from GCP firewall rules.
# Therefore, only set up the instance-level firewall as an additional measure when ufw is available.
skip_firewall_setup=not is_ufw_installed,
)


Expand Down
7 changes: 0 additions & 7 deletions src/dstack/_internal/core/backends/nebius/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,6 @@
"exec-opts": ["native.cgroupdriver=cgroupfs"],
}
SETUP_COMMANDS = [
"ufw allow ssh",
"ufw allow from 10.0.0.0/8",
"ufw allow from 172.16.0.0/12",
"ufw allow from 192.168.0.0/16",
"ufw default deny incoming",
"ufw default allow outgoing",
"ufw enable",
'sed -i "s/.*AllowTcpForwarding.*/AllowTcpForwarding yes/g" /etc/ssh/sshd_config',
"service ssh restart",
f"echo {shlex.quote(json.dumps(DOCKER_DAEMON_CONFIG))} > /etc/docker/daemon.json",
Expand Down
9 changes: 4 additions & 5 deletions src/dstack/_internal/core/backends/oci/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,10 @@ def create_instance(
security_group.id, region.virtual_network_client
)

setup_commands = [
f"sudo iptables -I INPUT -s {resources.VCN_CIDR} -j ACCEPT",
"sudo netfilter-persistent save",
]
cloud_init_user_data = get_user_data(instance_config.get_public_keys(), setup_commands)
cloud_init_user_data = get_user_data(
authorized_keys=instance_config.get_public_keys(),
firewall_allow_from_subnets=[resources.VCN_CIDR],
)

display_name = generate_unique_instance_name(instance_config)
try:
Expand Down
6 changes: 1 addition & 5 deletions src/dstack/_internal/core/backends/vultr/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,13 @@ def create_instance(
subnet = vpc["v4_subnet"]
subnet_mask = vpc["v4_subnet_mask"]

setup_commands = [
f"sudo ufw allow from {subnet}/{subnet_mask}",
"sudo ufw reload",
]
instance_id = self.api_client.launch_instance(
region=instance_offer.region,
label=instance_name,
plan=instance_offer.instance.name,
user_data=get_user_data(
authorized_keys=instance_config.get_public_keys(),
backend_specific_commands=setup_commands,
firewall_allow_from_subnets=[f"{subnet}/{subnet_mask}"],
),
vpc_id=vpc["id"],
)
Expand Down
Loading