Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/docs/concepts/gateways.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ If you disable [public IP](#public-ip) (e.g. to make the gateway private) or if

* `lets-encrypt` (default) — Automatic certificates via [Let's Encrypt](https://letsencrypt.org/). Requires a [public IP](#public-ip).
* `acm` — Certificates managed by [AWS Certificate Manager](https://aws.amazon.com/certificate-manager/). AWS-only. TLS is terminated at the load balancer, not at the gateway.
Requires a VPC with at least two subnets in different availability zones to provision a load balancer. If `public_ip: False`, subnets must be private and have a route to NAT gateway.
* `null` — No certificate. Services will use HTTP.

### Public IP
Expand Down
67 changes: 46 additions & 21 deletions src/dstack/_internal/core/backends/aws/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
ComputeWithVolumeSupport,
generate_unique_gateway_instance_name,
generate_unique_instance_name,
generate_unique_short_backend_name,
generate_unique_volume_name,
get_gateway_user_data,
get_user_data,
Expand Down Expand Up @@ -140,7 +141,7 @@ def __init__(
if zones_cache is None:
zones_cache = ComputeCache(cache=Cache(maxsize=10))
self._regions_to_zones_cache = zones_cache
self._vpc_id_subnet_id_cache = ComputeTTLCache(cache=TTLCache(maxsize=100, ttl=600))
self._vpc_id_subnets_ids_cache = ComputeTTLCache(cache=TTLCache(maxsize=100, ttl=600))
self._maximum_efa_interfaces_cache = ComputeCache(cache=Cache(maxsize=100))
self._subnets_availability_zones_cache = ComputeCache(cache=Cache(maxsize=100))
self._security_group_cache = ComputeTTLCache(cache=TTLCache(maxsize=100, ttl=600))
Expand Down Expand Up @@ -265,7 +266,7 @@ def create_instance(
enable_efa = max_efa_interfaces > 0
is_capacity_block = False
try:
vpc_id, subnet_ids = self._get_vpc_id_subnet_id_or_error(
vpc_id, subnets_ids = self._get_vpc_id_subnets_ids_or_error(
ec2_client=ec2_client,
config=self.config,
region=instance_offer.region,
Expand All @@ -275,7 +276,7 @@ def create_instance(
subnet_id_to_az_map = self._get_subnets_availability_zones(
ec2_client=ec2_client,
region=instance_offer.region,
subnet_ids=subnet_ids,
subnets_ids=subnets_ids,
)
if instance_config.reservation:
reservation = aws_resources.get_reservation(
Expand Down Expand Up @@ -497,7 +498,7 @@ def create_gateway(
tags = aws_resources.filter_invalid_tags(tags)
tags = aws_resources.make_tags(tags)

vpc_id, subnets_ids = self._get_vpc_id_subnet_id_or_error(
vpc_id, subnets_ids = self._get_vpc_id_subnets_ids_or_error(
ec2_client=ec2_client,
config=self.config,
region=configuration.region,
Expand Down Expand Up @@ -548,15 +549,21 @@ def create_gateway(

elb_client = self.session.client("elbv2", region_name=configuration.region)

if len(subnets_ids) < 2:
lb_subnets_ids = self._get_gateway_lb_subnets_ids(
ec2_client=ec2_client, region=configuration.region, subnets_ids=subnets_ids
)
if len(lb_subnets_ids) < 2:
raise ComputeError(
"Deploying gateway with ACM certificate requires at least two subnets in different AZs"
)

# Using short names as LB and target groups have length limit of 32.
resources_name_prefix = generate_unique_short_backend_name()

logger.debug("Creating ALB for gateway %s...", configuration.instance_name)
response = elb_client.create_load_balancer(
Name=f"{instance_name}-lb",
Subnets=subnets_ids,
Name=f"{resources_name_prefix}-lb",
Subnets=lb_subnets_ids,
SecurityGroups=[security_group_id],
Scheme="internet-facing" if configuration.public_ip else "internal",
Tags=tags,
Expand All @@ -570,7 +577,7 @@ def create_gateway(

logger.debug("Creating Target Group for gateway %s...", configuration.instance_name)
response = elb_client.create_target_group(
Name=f"{instance_name}-tg",
Name=f"{resources_name_prefix}-tg",
Protocol="HTTP",
Port=80,
VpcId=vpc_id,
Expand Down Expand Up @@ -877,7 +884,7 @@ def _get_regions_to_zones(
) -> Dict[str, List[str]]:
return _get_regions_to_zones(session=session, regions=regions)

def _get_vpc_id_subnet_id_or_error_cache_key(
def _get_vpc_id_subnets_ids_or_error_cache_key(
self,
ec2_client: botocore.client.BaseClient,
config: AWSConfig,
Expand All @@ -890,19 +897,19 @@ def _get_vpc_id_subnet_id_or_error_cache_key(
)

@cachedmethod(
cache=lambda self: self._vpc_id_subnet_id_cache.cache,
key=_get_vpc_id_subnet_id_or_error_cache_key,
lock=lambda self: self._vpc_id_subnet_id_cache.lock,
cache=lambda self: self._vpc_id_subnets_ids_cache.cache,
key=_get_vpc_id_subnets_ids_or_error_cache_key,
lock=lambda self: self._vpc_id_subnets_ids_cache.lock,
)
def _get_vpc_id_subnet_id_or_error(
def _get_vpc_id_subnets_ids_or_error(
self,
ec2_client: botocore.client.BaseClient,
config: AWSConfig,
region: str,
allocate_public_ip: bool,
availability_zones: Optional[List[str]] = None,
) -> Tuple[str, List[str]]:
return get_vpc_id_subnet_id_or_error(
return get_vpc_id_subnets_ids_or_error(
ec2_client=ec2_client,
config=config,
region=region,
Expand Down Expand Up @@ -930,9 +937,9 @@ def _get_subnets_availability_zones_key(
self,
ec2_client: botocore.client.BaseClient,
region: str,
subnet_ids: List[str],
subnets_ids: List[str],
) -> tuple:
return hashkey(region, tuple(subnet_ids))
return hashkey(region, tuple(subnets_ids))

@cachedmethod(
cache=lambda self: self._subnets_availability_zones_cache.cache,
Expand All @@ -943,11 +950,11 @@ def _get_subnets_availability_zones(
self,
ec2_client: botocore.client.BaseClient,
region: str,
subnet_ids: List[str],
subnets_ids: List[str],
) -> Dict[str, str]:
return aws_resources.get_subnets_availability_zones(
ec2_client=ec2_client,
subnet_ids=subnet_ids,
subnets_ids=subnets_ids,
)

@cachedmethod(
Expand Down Expand Up @@ -1000,8 +1007,26 @@ def _get_image_id_and_username(
image_config=image_config,
)

def _get_gateway_lb_subnets_ids(
self,
ec2_client: botocore.client.BaseClient,
region: str,
subnets_ids: List[str],
) -> List[str]:
"""
Returns subnet IDs to be used for gateway Load Balancer among `subnets_ids`.
Filters out subnets from the same AZ since Load Balancer requires all subnets to be in different AZ.
"""
subnet_id_to_az_map = self._get_subnets_availability_zones(
ec2_client=ec2_client,
region=region,
subnets_ids=subnets_ids,
)
az_to_subnet_id_map = {az: subnet_id for subnet_id, az in subnet_id_to_az_map.items()}
return list(az_to_subnet_id_map.values())


def get_vpc_id_subnet_id_or_error(
def get_vpc_id_subnets_ids_or_error(
ec2_client: botocore.client.BaseClient,
config: AWSConfig,
region: str,
Expand Down Expand Up @@ -1032,7 +1057,7 @@ def get_vpc_id_subnet_id_or_error(
if not config.use_default_vpcs:
raise ComputeError(f"No VPC ID configured for region {region}")

return _get_vpc_id_subnet_id_by_vpc_name_or_error(
return _get_vpc_id_subnets_ids_by_vpc_name_or_error(
ec2_client=ec2_client,
vpc_name=config.vpc_name,
region=region,
Expand All @@ -1041,7 +1066,7 @@ def get_vpc_id_subnet_id_or_error(
)


def _get_vpc_id_subnet_id_by_vpc_name_or_error(
def _get_vpc_id_subnets_ids_by_vpc_name_or_error(
ec2_client: botocore.client.BaseClient,
vpc_name: Optional[str],
region: str,
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/core/backends/aws/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def _check_config_vpc(self, session: Session, config: AWSBackendConfigWithCreds)
for region in regions:
ec2_client = session.client("ec2", region_name=region)
future = executor.submit(
compute.get_vpc_id_subnet_id_or_error,
compute.get_vpc_id_subnets_ids_or_error,
ec2_client=ec2_client,
config=AWSConfig.parse_obj(config),
region=region,
Expand Down
5 changes: 3 additions & 2 deletions src/dstack/_internal/core/backends/aws/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ def get_subnets_ids_for_vpc(
"""
If `allocate_public_ip` is True, returns public subnets found in the VPC.
If `allocate_public_ip` is False, returns subnets with NAT found in the VPC.
Returns
"""
subnets = _get_subnets_by_vpc_id(
ec2_client=ec2_client,
Expand Down Expand Up @@ -423,9 +424,9 @@ def get_availability_zone_by_subnet_id(


def get_subnets_availability_zones(
ec2_client: botocore.client.BaseClient, subnet_ids: List[str]
ec2_client: botocore.client.BaseClient, subnets_ids: List[str]
) -> Dict[str, str]:
response = ec2_client.describe_subnets(SubnetIds=subnet_ids)
response = ec2_client.describe_subnets(SubnetIds=subnets_ids)
subnet_id_to_az_map = {
subnet["SubnetId"]: subnet["AvailabilityZone"] for subnet in response["Subnets"]
}
Expand Down
15 changes: 12 additions & 3 deletions src/dstack/_internal/core/backends/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,12 +670,21 @@ def generate_unique_backend_name(
)


def generate_unique_short_backend_name() -> str:
"""
Generates a unique 15-char resource name of the form "dstack-12345678".
Can be used for resources that have a very small length limit like AWS LBs.
"""
return _generate_unique_backend_name_with_prefix("dstack")


def _generate_unique_backend_name_with_prefix(
prefix: str,
max_length: int,
max_length: Optional[int] = None,
) -> str:
prefix_len = max_length - _CLOUD_RESOURCE_SUFFIX_LEN - 1
prefix = prefix[:prefix_len]
if max_length is not None:
prefix_len = max_length - _CLOUD_RESOURCE_SUFFIX_LEN - 1
prefix = prefix[:prefix_len]
suffix = "".join(
random.choice(string.ascii_lowercase + string.digits)
for _ in range(_CLOUD_RESOURCE_SUFFIX_LEN)
Expand Down
2 changes: 1 addition & 1 deletion src/tests/_internal/core/backends/aws/test_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_validate_config_valid(self):
)
with (
patch("dstack._internal.core.backends.aws.auth.authenticate"),
patch("dstack._internal.core.backends.aws.compute.get_vpc_id_subnet_id_or_error"),
patch("dstack._internal.core.backends.aws.compute.get_vpc_id_subnets_ids_or_error"),
):
AWSConfigurator().validate_config(config, default_creds_enabled=True)

Expand Down
10 changes: 5 additions & 5 deletions src/tests/_internal/server/routers/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ async def test_creates_aws_backend(self, test_db, session: AsyncSession, client:
}
with (
patch("dstack._internal.core.backends.aws.auth.authenticate"),
patch("dstack._internal.core.backends.aws.compute.get_vpc_id_subnet_id_or_error"),
patch("dstack._internal.core.backends.aws.compute.get_vpc_id_subnets_ids_or_error"),
):
response = await client.post(
f"/api/project/{project.name}/backends/create",
Expand Down Expand Up @@ -542,7 +542,7 @@ async def test_returns_400_if_backend_exists(
}
with (
patch("dstack._internal.core.backends.aws.auth.authenticate"),
patch("dstack._internal.core.backends.aws.compute.get_vpc_id_subnet_id_or_error"),
patch("dstack._internal.core.backends.aws.compute.get_vpc_id_subnets_ids_or_error"),
):
response = await client.post(
f"/api/project/{project.name}/backends/create",
Expand Down Expand Up @@ -605,7 +605,7 @@ async def test_updates_backend(self, test_db, session: AsyncSession, client: Asy
}
with (
patch("dstack._internal.core.backends.aws.auth.authenticate"),
patch("dstack._internal.core.backends.aws.compute.get_vpc_id_subnet_id_or_error"),
patch("dstack._internal.core.backends.aws.compute.get_vpc_id_subnets_ids_or_error"),
):
response = await client.post(
f"/api/project/{project.name}/backends/update",
Expand Down Expand Up @@ -857,7 +857,7 @@ async def test_creates_aws_backend(self, test_db, session: AsyncSession, client:
body = {"config_yaml": yaml.dump(config_dict)}
with (
patch("dstack._internal.core.backends.aws.auth.authenticate"),
patch("dstack._internal.core.backends.aws.compute.get_vpc_id_subnet_id_or_error"),
patch("dstack._internal.core.backends.aws.compute.get_vpc_id_subnets_ids_or_error"),
):
response = await client.post(
f"/api/project/{project.name}/backends/create_yaml",
Expand Down Expand Up @@ -945,7 +945,7 @@ async def test_updates_aws_backend(self, test_db, session: AsyncSession, client:
body = {"config_yaml": yaml.dump(config_dict)}
with (
patch("dstack._internal.core.backends.aws.auth.authenticate"),
patch("dstack._internal.core.backends.aws.compute.get_vpc_id_subnet_id_or_error"),
patch("dstack._internal.core.backends.aws.compute.get_vpc_id_subnets_ids_or_error"),
):
response = await client.post(
f"/api/project/{project.name}/backends/update_yaml",
Expand Down
4 changes: 3 additions & 1 deletion src/tests/_internal/server/services/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ async def test_creates_backend(self, test_db, session: AsyncSession, tmp_path: P
with (
patch("boto3.session.Session"),
patch.object(settings, "SERVER_CONFIG_FILE_PATH", config_filepath),
patch("dstack._internal.core.backends.aws.compute.get_vpc_id_subnet_id_or_error"),
patch(
"dstack._internal.core.backends.aws.compute.get_vpc_id_subnets_ids_or_error"
),
):
manager = ServerConfigManager()
manager.load_config()
Expand Down
Loading