Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/dstack/_internal/cli/commands/offer.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def _command(self, args: argparse.Namespace):

run_spec = RunSpec(
configuration=conf,
ssh_key_pub="(dummy)",
profile=profile,
ssh_key_pub="(dummy)", # TODO: Remove since 0.19.40
)

if args.group_by:
Expand Down
6 changes: 1 addition & 5 deletions src/dstack/_internal/cli/services/configurators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.nested_list import NestedList, NestedListItem
from dstack._internal.utils.path import is_absolute_posix_path
from dstack.api._public.repos import get_ssh_keypair
from dstack.api._public.runs import Run
from dstack.api.server import APIClient
from dstack.api.utils import load_profile
Expand Down Expand Up @@ -135,17 +134,14 @@ def apply_configuration(

config_manager = ConfigManager()
repo = self.get_repo(conf, configuration_path, configurator_args, config_manager)
self.api.ssh_identity_file = get_ssh_keypair(
configurator_args.ssh_identity_file,
config_manager.dstack_key_path,
)
profile = load_profile(Path.cwd(), configurator_args.profile)
with console.status("Getting apply plan..."):
run_plan = self.api.runs.get_run_plan(
configuration=conf,
repo=repo,
configuration_path=configuration_path,
profile=profile,
ssh_identity_file=configurator_args.ssh_identity_file,
)

print_run_plan(run_plan, max_offers=configurator_args.max_offers)
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/kubernetes/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def run_job(
volumes: List[Volume],
) -> JobProvisioningData:
instance_name = generate_unique_instance_name_for_job(run, job)
assert run.run_spec.ssh_key_pub is not None
commands = get_docker_commands(
[run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip()]
)
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/runpod/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def run_job(
project_ssh_private_key: str,
volumes: List[Volume],
) -> JobProvisioningData:
assert run.run_spec.ssh_key_pub is not None
instance_config = InstanceConfiguration(
project_name=run.project_name,
instance_name=get_job_instance_name(run, job),
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/vastai/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def run_job(
instance_name = generate_unique_instance_name_for_job(
run, job, max_length=MAX_INSTANCE_NAME_LEN
)
assert run.run_spec.ssh_key_pub is not None
commands = get_docker_commands(
[run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip()]
)
Expand Down
5 changes: 3 additions & 2 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,11 +462,12 @@ class RunSpec(generate_dual_core_model(RunSpecConfig)):
configuration: Annotated[AnyRunConfiguration, Field(discriminator="type")]
profile: Annotated[Optional[Profile], Field(description="The profile parameters")] = None
ssh_key_pub: Annotated[
str,
Optional[str],
Field(
description="The contents of the SSH public key that will be used to connect to the run."
" Can be empty only before the run is submitted."
),
]
] = None
# merged_profile stores profile parameters merged from profile and configuration.
# Read profile parameters from merged_profile instead of profile directly.
# TODO: make merged_profile a computed field after migrating to pydanticV2
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/core/models/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class User(CoreModel):
email: Optional[str]
active: bool
permissions: UserPermissions
ssh_public_key: Optional[str] = None


class UserTokenCreds(CoreModel):
Expand All @@ -38,3 +39,4 @@ class UserTokenCreds(CoreModel):

class UserWithCreds(User):
creds: UserTokenCreds
ssh_private_key: Optional[str] = None
1 change: 1 addition & 0 deletions src/dstack/_internal/core/services/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def dstack_ssh_dir(self) -> Path:

@property
def dstack_key_path(self) -> Path:
# TODO: Remove since 0.19.40
return self.dstack_ssh_dir / "id_rsa"

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
job_submission.age,
)
ssh_user = job_provisioning_data.username
assert run.run_spec.ssh_key_pub is not None
user_ssh_key = run.run_spec.ssh_key_pub.strip()
public_keys = [project.ssh_public_key.strip(), user_ssh_key]
if job_provisioning_data.backend == BackendType.LOCAL:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""user.ssh_key

Revision ID: ff1d94f65b08
Revises: 2498ab323443
Create Date: 2025-10-09 20:31:31.166786

"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "ff1d94f65b08"
down_revision = "2498ab323443"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("users", schema=None) as batch_op:
batch_op.add_column(sa.Column("ssh_private_key", sa.Text(), nullable=True))
batch_op.add_column(sa.Column("ssh_public_key", sa.Text(), nullable=True))

# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("users", schema=None) as batch_op:
batch_op.drop_column("ssh_public_key")
batch_op.drop_column("ssh_private_key")

# ### end Alembic commands ###
3 changes: 3 additions & 0 deletions src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@ class UserModel(BaseModel):
# deactivated users cannot access API
active: Mapped[bool] = mapped_column(Boolean, default=True)

ssh_private_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
ssh_public_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

email: Mapped[Optional[str]] = mapped_column(String(200), nullable=True)

projects_quota: Mapped[int] = mapped_column(
Expand Down
6 changes: 5 additions & 1 deletion src/dstack/_internal/server/routers/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
SubmitRunRequest,
)
from dstack._internal.server.security.permissions import Authenticated, ProjectMember
from dstack._internal.server.services import runs
from dstack._internal.server.services import runs, users
from dstack._internal.server.utils.routers import (
CustomORJSONResponse,
get_base_api_additional_responses,
Expand Down Expand Up @@ -111,6 +111,8 @@ async def get_plan(
This is an optional step before calling `/apply`.
"""
user, project = user_project
if not user.ssh_public_key and not body.run_spec.ssh_key_pub:
await users.refresh_ssh_key(session=session, user=user, username=user.name)
run_plan = await runs.get_plan(
session=session,
project=project,
Expand All @@ -137,6 +139,8 @@ async def apply_plan(
If the existing run is active and cannot be updated, it must be stopped first.
"""
user, project = user_project
if not user.ssh_public_key and not body.plan.run_spec.ssh_key_pub:
await users.refresh_ssh_key(session=session, user=user, username=user.name)
return CustomORJSONResponse(
await runs.apply_plan(
session=session,
Expand Down
16 changes: 14 additions & 2 deletions src/dstack/_internal/server/routers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ async def list_users(
return CustomORJSONResponse(await users.list_users_for_user(session=session, user=user))


@router.post("/get_my_user", response_model=User)
@router.post("/get_my_user", response_model=UserWithCreds)
async def get_my_user(
user: UserModel = Depends(Authenticated()),
):
return CustomORJSONResponse(users.user_model_to_user(user))
return CustomORJSONResponse(users.user_model_to_user_with_creds(user))


@router.post("/get_user", response_model=UserWithCreds)
Expand Down Expand Up @@ -91,6 +91,18 @@ async def update_user(
return CustomORJSONResponse(users.user_model_to_user(res))


@router.post("/refresh_ssh_key", response_model=UserWithCreds)
async def refresh_ssh_key(
body: RefreshTokenRequest,
session: AsyncSession = Depends(get_session),
user: UserModel = Depends(Authenticated()),
):
res = await users.refresh_ssh_key(session=session, user=user, username=body.username)
if res is None:
raise ResourceNotExistsError()
return CustomORJSONResponse(users.user_model_to_user_with_creds(res))


@router.post("/refresh_token", response_model=UserWithCreds)
async def refresh_token(
body: RefreshTokenRequest,
Expand Down
13 changes: 9 additions & 4 deletions src/dstack/_internal/server/services/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ async def get_plan(
spec=effective_run_spec,
)
effective_run_spec = RunSpec.parse_obj(effective_run_spec.dict())
_validate_run_spec_and_set_defaults(effective_run_spec)
_validate_run_spec_and_set_defaults(user, effective_run_spec)

profile = effective_run_spec.merged_profile
creation_policy = profile.creation_policy
Expand Down Expand Up @@ -422,7 +422,7 @@ async def apply_plan(
)
# Spec must be copied by parsing to calculate merged_profile
run_spec = RunSpec.parse_obj(run_spec.dict())
_validate_run_spec_and_set_defaults(run_spec)
_validate_run_spec_and_set_defaults(user, run_spec)
if run_spec.run_name is None:
return await submit_run(
session=session,
Expand Down Expand Up @@ -489,7 +489,7 @@ async def submit_run(
project: ProjectModel,
run_spec: RunSpec,
) -> Run:
_validate_run_spec_and_set_defaults(run_spec)
_validate_run_spec_and_set_defaults(user, run_spec)
repo = await _get_run_repo_or_error(
session=session,
project=project,
Expand Down Expand Up @@ -981,7 +981,7 @@ def _get_job_submission_cost(job_submission: JobSubmission) -> float:
return job_submission.job_provisioning_data.price * duration_hours


def _validate_run_spec_and_set_defaults(run_spec: RunSpec):
def _validate_run_spec_and_set_defaults(user: UserModel, run_spec: RunSpec):
# This function may set defaults for null run_spec values,
# although most defaults are resolved when building job_spec
# so that we can keep both the original user-supplied value (null in run_spec)
Expand Down Expand Up @@ -1031,6 +1031,11 @@ def _validate_run_spec_and_set_defaults(run_spec: RunSpec):
if run_spec.configuration.priority is None:
run_spec.configuration.priority = RUN_PRIORITY_DEFAULT
set_resources_defaults(run_spec.configuration.resources)
if run_spec.ssh_key_pub is None:
if user.ssh_public_key:
run_spec.ssh_key_pub = user.ssh_public_key
else:
raise ServerClientError("ssh_key_pub must be set if the user has no ssh_public_key")


_UPDATABLE_SPEC_FIELDS = ["configuration_path", "configuration"]
Expand Down
29 changes: 29 additions & 0 deletions src/dstack/_internal/server/services/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from dstack._internal.server.models import DecryptedString, UserModel
from dstack._internal.server.services.permissions import get_default_permissions
from dstack._internal.server.utils.routers import error_forbidden
from dstack._internal.utils.common import run_async
from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -84,6 +86,7 @@ async def create_user(
raise ResourceExistsError()
if token is None:
token = str(uuid.uuid4())
private_bytes, public_bytes = await run_async(generate_rsa_key_pair_bytes, username)
user = UserModel(
id=uuid.uuid4(),
name=username,
Expand All @@ -92,6 +95,8 @@ async def create_user(
token_hash=get_token_hash(token),
email=email,
active=active,
ssh_private_key=private_bytes.decode(),
ssh_public_key=public_bytes.decode(),
)
session.add(user)
await session.commit()
Expand Down Expand Up @@ -120,6 +125,27 @@ async def update_user(
return await get_user_model_by_name_or_error(session=session, username=username)


async def refresh_ssh_key(
session: AsyncSession,
user: UserModel,
username: str,
) -> Optional[UserModel]:
logger.debug("Refreshing SSH key for user [code]%s[/code]", username)
if user.global_role != GlobalRole.ADMIN and user.name != username:
raise error_forbidden()
private_bytes, public_bytes = await run_async(generate_rsa_key_pair_bytes, username)
await session.execute(
update(UserModel)
.where(UserModel.name == username)
.values(
ssh_private_key=private_bytes.decode(),
ssh_public_key=public_bytes.decode(),
)
)
await session.commit()
return await get_user_model_by_name(session=session, username=username)


async def refresh_user_token(
session: AsyncSession,
user: UserModel,
Expand Down Expand Up @@ -199,6 +225,7 @@ def user_model_to_user(user_model: UserModel) -> User:
email=user_model.email,
active=user_model.active,
permissions=get_user_permissions(user_model),
ssh_public_key=user_model.ssh_public_key,
)


Expand All @@ -211,7 +238,9 @@ def user_model_to_user_with_creds(user_model: UserModel) -> UserWithCreds:
email=user_model.email,
active=user_model.active,
permissions=get_user_permissions(user_model),
ssh_public_key=user_model.ssh_public_key,
creds=UserTokenCreds(token=user_model.token.get_plaintext_or_error()),
ssh_private_key=user_model.ssh_private_key,
)


Expand Down
21 changes: 9 additions & 12 deletions src/dstack/api/_public/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

import dstack._internal.core.services.api_client as api_client_service
from dstack._internal.core.errors import ConfigurationError
from dstack._internal.core.services.configs import ConfigManager
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.path import PathLike
from dstack._internal.utils.path import PathLike as PathLike
from dstack.api._public.backends import BackendCollection
from dstack.api._public.repos import RepoCollection, get_ssh_keypair
from dstack.api._public.runs import RunCollection
from dstack.api._public.repos import RepoCollection
from dstack.api._public.runs import RunCollection, warn
from dstack.api.server import APIClient

logger = get_logger(__name__)
Expand Down Expand Up @@ -35,24 +34,24 @@ def __init__(
# Args:
# api_client: low-level server API client
# project_name: project name used for runs
# ssh_identity_file: SSH keypair to access instances
# ssh_identity_file: deprecated and will be removed in 0.19.40
# """
self._client = api_client
self._project = project_name
self._repos = RepoCollection(api_client, project_name)
self._backends = BackendCollection(api_client, project_name)
self._runs = RunCollection(api_client, project_name, self)
if ssh_identity_file:
self.ssh_identity_file = str(ssh_identity_file)
else:
self.ssh_identity_file = get_ssh_keypair(None, ConfigManager().dstack_key_path)
if ssh_identity_file is not None:
warn(
"[code]ssh_identity_file[/code] in [code]Client[/code] is deprecated and ignored; will be removed"
" since 0.19.40"
)

@staticmethod
def from_config(
project_name: Optional[str] = None,
server_url: Optional[str] = None,
user_token: Optional[str] = None,
ssh_identity_file: Optional[PathLike] = None,
) -> "Client":
"""
Creates a Client using the default configuration from `~/.dstack/config.yml` if it exists.
Expand All @@ -61,7 +60,6 @@ def from_config(
project_name: The name of the project. required if `server_url` and `user_token` are specified.
server_url: The dstack server URL (e.g. `http://localhost:3000/` or `https://sky.dstack.ai`).
user_token: The dstack user token.
ssh_identity_file: The private SSH key path for SSH tunneling.

Returns:
A client instance.
Expand All @@ -75,7 +73,6 @@ def from_config(
return Client(
api_client=api_client,
project_name=project_name,
ssh_identity_file=ssh_identity_file,
)

@property
Expand Down
Loading