Skip to content

Commit 38eee94

Browse files
Bihan  RanaBihan  Rana
authored andcommitted
Minor Update
1 parent 860ea23 commit 38eee94

2 files changed

Lines changed: 9 additions & 5 deletions

File tree

src/dstack/_internal/core/backends/base/compute.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
SSHKey,
4040
)
4141
from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData
42+
from dstack._internal.core.models.routers import RouterType
4243
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
4344
from dstack._internal.core.models.volumes import (
4445
Volume,
@@ -923,7 +924,7 @@ def get_run_shim_script(
923924
]
924925

925926

926-
def get_gateway_user_data(authorized_key: str, router: Optional[str] = None) -> str:
927+
def get_gateway_user_data(authorized_key: str, router: Optional[RouterType] = None) -> str:
927928
return get_cloud_config(
928929
package_update=True,
929930
packages=[
@@ -1035,7 +1036,7 @@ def get_latest_runner_build() -> Optional[str]:
10351036
return None
10361037

10371038

1038-
def get_dstack_gateway_wheel(build: str, router: Optional[str] = None) -> str:
1039+
def get_dstack_gateway_wheel(build: str, router: Optional[RouterType] = None) -> str:
10391040
channel = "release" if settings.DSTACK_RELEASE else "stgn"
10401041
base_url = f"https://dstack-gateway-downloads.s3.amazonaws.com/{channel}"
10411042
if build == "latest":
@@ -1044,11 +1045,11 @@ def get_dstack_gateway_wheel(build: str, router: Optional[str] = None) -> str:
10441045
wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl"
10451046
# Build package spec with extras if router is specified
10461047
if router:
1047-
return f"dstack-gateway[{router}] @ {wheel}"
1048+
return f"dstack-gateway[{router.value}] @ {wheel}"
10481049
return f"dstack-gateway @ {wheel}"
10491050

10501051

1051-
def get_dstack_gateway_commands(router: Optional[str] = None) -> List[str]:
1052+
def get_dstack_gateway_commands(router: Optional[RouterType] = None) -> List[str]:
10521053
build = get_dstack_runner_version() or "latest"
10531054
gateway_package = get_dstack_gateway_wheel(build, router)
10541055
return [

src/dstack/_internal/core/backends/kubernetes/compute.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
)
6767
from dstack._internal.core.models.placement import PlacementGroup
6868
from dstack._internal.core.models.resources import CPUSpec, GPUSpec
69+
from dstack._internal.core.models.routers import RouterType
6970
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
7071
from dstack._internal.core.models.volumes import Volume
7172
from dstack._internal.utils.common import get_or_error
@@ -862,7 +863,9 @@ def _wait_for_load_balancer_address(
862863
time.sleep(1)
863864

864865

865-
def _get_gateway_commands(authorized_keys: List[str], router: Optional[str] = None) -> List[str]:
866+
def _get_gateway_commands(
867+
authorized_keys: List[str], router: Optional[RouterType] = None
868+
) -> List[str]:
866869
authorized_keys_content = "\n".join(authorized_keys).strip()
867870
gateway_commands = " && ".join(get_dstack_gateway_commands(router=router))
868871
quoted_gateway_commands = shlex.quote(gateway_commands)

0 commit comments

Comments
 (0)