From 74697bbf7b2a8504551f1fe1dda14bf6c8a9c6f7 Mon Sep 17 00:00:00 2001 From: peterschmidt85 Date: Fri, 10 Oct 2025 00:59:57 +0200 Subject: [PATCH 1/6] [Feature]: Store user SSH key on the server #2053 --- src/dstack/_internal/cli/commands/offer.py | 1 - .../cli/services/configurators/run.py | 6 +- .../core/backends/kubernetes/compute.py | 1 + .../_internal/core/backends/runpod/compute.py | 1 + .../_internal/core/backends/vastai/compute.py | 1 + src/dstack/_internal/core/models/runs.py | 5 +- src/dstack/_internal/core/models/users.py | 2 + .../core/services/configs/__init__.py | 1 + .../background/tasks/process_running_jobs.py | 1 + .../versions/ff1d94f65b08_user_ssh_key.py | 34 +++++++++++ src/dstack/_internal/server/models.py | 3 + src/dstack/_internal/server/routers/runs.py | 6 +- src/dstack/_internal/server/routers/users.py | 16 ++++- src/dstack/_internal/server/services/runs.py | 13 +++-- src/dstack/_internal/server/services/users.py | 29 ++++++++++ src/dstack/api/_public/__init__.py | 12 +--- src/dstack/api/_public/repos.py | 21 ------- src/dstack/api/_public/runs.py | 58 ++++++++++++++++--- src/dstack/api/server/_users.py | 4 +- 19 files changed, 157 insertions(+), 58 deletions(-) create mode 100644 src/dstack/_internal/server/migrations/versions/ff1d94f65b08_user_ssh_key.py diff --git a/src/dstack/_internal/cli/commands/offer.py b/src/dstack/_internal/cli/commands/offer.py index e6d02a0154..17600b0607 100644 --- a/src/dstack/_internal/cli/commands/offer.py +++ b/src/dstack/_internal/cli/commands/offer.py @@ -104,7 +104,6 @@ def _command(self, args: argparse.Namespace): run_spec = RunSpec( configuration=conf, - ssh_key_pub="(dummy)", profile=profile, ) diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index 0403a57a64..b51f1470d7 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -62,7 +62,6 @@ from dstack._internal.utils.logging import get_logger from dstack._internal.utils.nested_list import NestedList, NestedListItem from dstack._internal.utils.path import is_absolute_posix_path -from dstack.api._public.repos import get_ssh_keypair from dstack.api._public.runs import Run from dstack.api.server import APIClient from dstack.api.utils import load_profile @@ -135,10 +134,6 @@ def apply_configuration( config_manager = ConfigManager() repo = self.get_repo(conf, configuration_path, configurator_args, config_manager) - self.api.ssh_identity_file = get_ssh_keypair( - configurator_args.ssh_identity_file, - config_manager.dstack_key_path, - ) profile = load_profile(Path.cwd(), configurator_args.profile) with console.status("Getting apply plan..."): run_plan = self.api.runs.get_run_plan( @@ -146,6 +141,7 @@ def apply_configuration( repo=repo, configuration_path=configuration_path, profile=profile, + ssh_identity_file=configurator_args.ssh_identity_file, ) print_run_plan(run_plan, max_offers=configurator_args.max_offers) diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index 9668a17f31..601472127a 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -161,6 +161,7 @@ def run_job( volumes: List[Volume], ) -> JobProvisioningData: instance_name = generate_unique_instance_name_for_job(run, job) + assert run.run_spec.ssh_key_pub is not None commands = get_docker_commands( [run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip()] ) diff --git a/src/dstack/_internal/core/backends/runpod/compute.py b/src/dstack/_internal/core/backends/runpod/compute.py index 9b7fa6e652..3b9e022fd0 100644 --- a/src/dstack/_internal/core/backends/runpod/compute.py +++ b/src/dstack/_internal/core/backends/runpod/compute.py @@ -86,6 +86,7 @@ def run_job( project_ssh_private_key: str, volumes: List[Volume], ) -> JobProvisioningData: + assert run.run_spec.ssh_key_pub is not None instance_config = InstanceConfiguration( project_name=run.project_name, instance_name=get_job_instance_name(run, job), diff --git a/src/dstack/_internal/core/backends/vastai/compute.py b/src/dstack/_internal/core/backends/vastai/compute.py index ec853b69ee..b019fbb9be 100644 --- a/src/dstack/_internal/core/backends/vastai/compute.py +++ b/src/dstack/_internal/core/backends/vastai/compute.py @@ -85,6 +85,7 @@ def run_job( instance_name = generate_unique_instance_name_for_job( run, job, max_length=MAX_INSTANCE_NAME_LEN ) + assert run.run_spec.ssh_key_pub is not None commands = get_docker_commands( [run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip()] ) diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 969b336b9d..765888e349 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -462,11 +462,12 @@ class RunSpec(generate_dual_core_model(RunSpecConfig)): configuration: Annotated[AnyRunConfiguration, Field(discriminator="type")] profile: Annotated[Optional[Profile], Field(description="The profile parameters")] = None ssh_key_pub: Annotated[ - str, + Optional[str], Field( description="The contents of the SSH public key that will be used to connect to the run." + " Can be empty only before the run is submitted." ), - ] + ] = None # merged_profile stores profile parameters merged from profile and configuration. # Read profile parameters from merged_profile instead of profile directly. # TODO: make merged_profile a computed field after migrating to pydanticV2 diff --git a/src/dstack/_internal/core/models/users.py b/src/dstack/_internal/core/models/users.py index 042f286627..449c57f8b2 100644 --- a/src/dstack/_internal/core/models/users.py +++ b/src/dstack/_internal/core/models/users.py @@ -38,3 +38,5 @@ class UserTokenCreds(CoreModel): class UserWithCreds(User): creds: UserTokenCreds + ssh_public_key: Optional[str] = None + ssh_private_key: Optional[str] = None diff --git a/src/dstack/_internal/core/services/configs/__init__.py b/src/dstack/_internal/core/services/configs/__init__.py index 6ebf7e7654..13ed49c9b5 100644 --- a/src/dstack/_internal/core/services/configs/__init__.py +++ b/src/dstack/_internal/core/services/configs/__init__.py @@ -117,6 +117,7 @@ def dstack_ssh_dir(self) -> Path: @property def dstack_key_path(self) -> Path: + # TODO: Remove since 0.19.35 return self.dstack_ssh_dir / "id_rsa" @property diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index aa69f2f335..fb2deed3eb 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -243,6 +243,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): job_submission.age, ) ssh_user = job_provisioning_data.username + assert run.run_spec.ssh_key_pub is not None user_ssh_key = run.run_spec.ssh_key_pub.strip() public_keys = [project.ssh_public_key.strip(), user_ssh_key] if job_provisioning_data.backend == BackendType.LOCAL: diff --git a/src/dstack/_internal/server/migrations/versions/ff1d94f65b08_user_ssh_key.py b/src/dstack/_internal/server/migrations/versions/ff1d94f65b08_user_ssh_key.py new file mode 100644 index 0000000000..fc79b58b08 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/ff1d94f65b08_user_ssh_key.py @@ -0,0 +1,34 @@ +"""user.ssh_key + +Revision ID: ff1d94f65b08 +Revises: 2498ab323443 +Create Date: 2025-10-09 20:31:31.166786 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "ff1d94f65b08" +down_revision = "2498ab323443" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("users", schema=None) as batch_op: + batch_op.add_column(sa.Column("ssh_private_key", sa.Text(), nullable=True)) + batch_op.add_column(sa.Column("ssh_public_key", sa.Text(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("users", schema=None) as batch_op: + batch_op.drop_column("ssh_public_key") + batch_op.drop_column("ssh_private_key") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index b21ba81a4d..31f44d3692 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -190,6 +190,9 @@ class UserModel(BaseModel): # deactivated users cannot access API active: Mapped[bool] = mapped_column(Boolean, default=True) + ssh_private_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + ssh_public_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + email: Mapped[Optional[str]] = mapped_column(String(200), nullable=True) projects_quota: Mapped[int] = mapped_column( diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index 8f3909503c..11cbbfa0b0 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -17,7 +17,7 @@ SubmitRunRequest, ) from dstack._internal.server.security.permissions import Authenticated, ProjectMember -from dstack._internal.server.services import runs +from dstack._internal.server.services import runs, users from dstack._internal.server.utils.routers import ( CustomORJSONResponse, get_base_api_additional_responses, @@ -111,6 +111,8 @@ async def get_plan( This is an optional step before calling `/apply`. """ user, project = user_project + if not user.ssh_public_key and not body.run_spec.ssh_key_pub: + await users.refresh_ssh_key(session=session, user=user, username=user.name) run_plan = await runs.get_plan( session=session, project=project, @@ -137,6 +139,8 @@ async def apply_plan( If the existing run is active and cannot be updated, it must be stopped first. """ user, project = user_project + if not user.ssh_public_key and not body.plan.run_spec.ssh_key_pub: + await users.refresh_ssh_key(session=session, user=user, username=user.name) return CustomORJSONResponse( await runs.apply_plan( session=session, diff --git a/src/dstack/_internal/server/routers/users.py b/src/dstack/_internal/server/routers/users.py index abb6729141..706d3db8b5 100644 --- a/src/dstack/_internal/server/routers/users.py +++ b/src/dstack/_internal/server/routers/users.py @@ -36,11 +36,11 @@ async def list_users( return CustomORJSONResponse(await users.list_users_for_user(session=session, user=user)) -@router.post("/get_my_user", response_model=User) +@router.post("/get_my_user", response_model=UserWithCreds) async def get_my_user( user: UserModel = Depends(Authenticated()), ): - return CustomORJSONResponse(users.user_model_to_user(user)) + return CustomORJSONResponse(users.user_model_to_user_with_creds(user)) @router.post("/get_user", response_model=UserWithCreds) @@ -91,6 +91,18 @@ async def update_user( return CustomORJSONResponse(users.user_model_to_user(res)) +@router.post("/refresh_ssh_key", response_model=UserWithCreds) +async def refresh_ssh_key( + body: RefreshTokenRequest, + session: AsyncSession = Depends(get_session), + user: UserModel = Depends(Authenticated()), +): + res = await users.refresh_ssh_key(session=session, user=user, username=body.username) + if res is None: + raise ResourceNotExistsError() + return CustomORJSONResponse(users.user_model_to_user_with_creds(res)) + + @router.post("/refresh_token", response_model=UserWithCreds) async def refresh_token( body: RefreshTokenRequest, diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index e45d76ef33..c43f802fc8 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -317,7 +317,7 @@ async def get_plan( spec=effective_run_spec, ) effective_run_spec = RunSpec.parse_obj(effective_run_spec.dict()) - _validate_run_spec_and_set_defaults(effective_run_spec) + _validate_run_spec_and_set_defaults(user, effective_run_spec) profile = effective_run_spec.merged_profile creation_policy = profile.creation_policy @@ -422,7 +422,7 @@ async def apply_plan( ) # Spec must be copied by parsing to calculate merged_profile run_spec = RunSpec.parse_obj(run_spec.dict()) - _validate_run_spec_and_set_defaults(run_spec) + _validate_run_spec_and_set_defaults(user, run_spec) if run_spec.run_name is None: return await submit_run( session=session, @@ -489,7 +489,7 @@ async def submit_run( project: ProjectModel, run_spec: RunSpec, ) -> Run: - _validate_run_spec_and_set_defaults(run_spec) + _validate_run_spec_and_set_defaults(user, run_spec) repo = await _get_run_repo_or_error( session=session, project=project, @@ -981,7 +981,7 @@ def _get_job_submission_cost(job_submission: JobSubmission) -> float: return job_submission.job_provisioning_data.price * duration_hours -def _validate_run_spec_and_set_defaults(run_spec: RunSpec): +def _validate_run_spec_and_set_defaults(user: UserModel, run_spec: RunSpec): # This function may set defaults for null run_spec values, # although most defaults are resolved when building job_spec # so that we can keep both the original user-supplied value (null in run_spec) @@ -1031,6 +1031,11 @@ def _validate_run_spec_and_set_defaults(run_spec: RunSpec): if run_spec.configuration.priority is None: run_spec.configuration.priority = RUN_PRIORITY_DEFAULT set_resources_defaults(run_spec.configuration.resources) + if run_spec.ssh_key_pub is None: + if user.ssh_public_key: + run_spec.ssh_key_pub = user.ssh_public_key + else: + raise ServerClientError("ssh_key_pub must be set if the user has no ssh_public_key") _UPDATABLE_SPEC_FIELDS = ["configuration_path", "configuration"] diff --git a/src/dstack/_internal/server/services/users.py b/src/dstack/_internal/server/services/users.py index 28cd5f85d2..d612bc9022 100644 --- a/src/dstack/_internal/server/services/users.py +++ b/src/dstack/_internal/server/services/users.py @@ -19,6 +19,8 @@ from dstack._internal.server.models import DecryptedString, UserModel from dstack._internal.server.services.permissions import get_default_permissions from dstack._internal.server.utils.routers import error_forbidden +from dstack._internal.utils.common import run_async +from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -84,6 +86,7 @@ async def create_user( raise ResourceExistsError() if token is None: token = str(uuid.uuid4()) + private_bytes, public_bytes = await run_async(generate_rsa_key_pair_bytes, username) user = UserModel( id=uuid.uuid4(), name=username, @@ -92,6 +95,8 @@ async def create_user( token_hash=get_token_hash(token), email=email, active=active, + ssh_private_key=private_bytes.decode(), + ssh_public_key=public_bytes.decode(), ) session.add(user) await session.commit() @@ -107,6 +112,7 @@ async def update_user( email: Optional[str] = None, active: bool = True, ) -> UserModel: + # TODO: Allow to update ssh_private_key and ssh_public_key await session.execute( update(UserModel) .where(UserModel.name == username) @@ -120,6 +126,27 @@ async def update_user( return await get_user_model_by_name_or_error(session=session, username=username) +async def refresh_ssh_key( + session: AsyncSession, + user: UserModel, + username: str, +) -> Optional[UserModel]: + logger.debug("Refreshing SSH key for user [code]%s[/code]", username) + if user.global_role != GlobalRole.ADMIN and user.name != username: + raise error_forbidden() + private_bytes, public_bytes = await run_async(generate_rsa_key_pair_bytes, username) + await session.execute( + update(UserModel) + .where(UserModel.name == username) + .values( + ssh_private_key=private_bytes.decode(), + ssh_public_key=public_bytes.decode(), + ) + ) + await session.commit() + return await get_user_model_by_name(session=session, username=username) + + async def refresh_user_token( session: AsyncSession, user: UserModel, @@ -212,6 +239,8 @@ def user_model_to_user_with_creds(user_model: UserModel) -> UserWithCreds: active=user_model.active, permissions=get_user_permissions(user_model), creds=UserTokenCreds(token=user_model.token.get_plaintext_or_error()), + ssh_public_key=user_model.ssh_public_key, + ssh_private_key=user_model.ssh_private_key, ) diff --git a/src/dstack/api/_public/__init__.py b/src/dstack/api/_public/__init__.py index 44ad85597e..c5f7982c90 100644 --- a/src/dstack/api/_public/__init__.py +++ b/src/dstack/api/_public/__init__.py @@ -2,11 +2,10 @@ import dstack._internal.core.services.api_client as api_client_service from dstack._internal.core.errors import ConfigurationError -from dstack._internal.core.services.configs import ConfigManager from dstack._internal.utils.logging import get_logger from dstack._internal.utils.path import PathLike from dstack.api._public.backends import BackendCollection -from dstack.api._public.repos import RepoCollection, get_ssh_keypair +from dstack.api._public.repos import RepoCollection from dstack.api._public.runs import RunCollection from dstack.api.server import APIClient @@ -29,30 +28,23 @@ def __init__( self, api_client: APIClient, project_name: str, - ssh_identity_file: Optional[PathLike] = None, ): # """ # Args: # api_client: low-level server API client # project_name: project name used for runs - # ssh_identity_file: SSH keypair to access instances # """ self._client = api_client self._project = project_name self._repos = RepoCollection(api_client, project_name) self._backends = BackendCollection(api_client, project_name) self._runs = RunCollection(api_client, project_name, self) - if ssh_identity_file: - self.ssh_identity_file = str(ssh_identity_file) - else: - self.ssh_identity_file = get_ssh_keypair(None, ConfigManager().dstack_key_path) @staticmethod def from_config( project_name: Optional[str] = None, server_url: Optional[str] = None, user_token: Optional[str] = None, - ssh_identity_file: Optional[PathLike] = None, ) -> "Client": """ Creates a Client using the default configuration from `~/.dstack/config.yml` if it exists. @@ -61,7 +53,6 @@ def from_config( project_name: The name of the project. required if `server_url` and `user_token` are specified. server_url: The dstack server URL (e.g. `http://localhost:3000/` or `https://sky.dstack.ai`). user_token: The dstack user token. - ssh_identity_file: The private SSH key path for SSH tunneling. Returns: A client instance. @@ -75,7 +66,6 @@ def from_config( return Client( api_client=api_client, project_name=project_name, - ssh_identity_file=ssh_identity_file, ) @property diff --git a/src/dstack/api/_public/repos.py b/src/dstack/api/_public/repos.py index d7519bbcf2..9015bc69c0 100644 --- a/src/dstack/api/_public/repos.py +++ b/src/dstack/api/_public/repos.py @@ -1,4 +1,3 @@ -from pathlib import Path from typing import Literal, Optional, Union, overload from git import InvalidGitRepositoryError @@ -18,7 +17,6 @@ get_repo_creds_and_default_branch, load_repo, ) -from dstack._internal.utils.crypto import generate_rsa_key_pair from dstack._internal.utils.logging import get_logger from dstack._internal.utils.path import PathLike from dstack.api.server import APIClient @@ -209,22 +207,3 @@ def get( return method(self._project, repo_id) except ResourceNotExistsError: return None - - -def get_ssh_keypair(key_path: Optional[PathLike], dstack_key_path: Path) -> str: - """Returns a path to the private key""" - if key_path is not None: - key_path = Path(key_path).expanduser().resolve() - pub_key = ( - key_path - if key_path.suffix == ".pub" - else key_path.with_suffix(key_path.suffix + ".pub") - ) - private_key = pub_key.with_suffix("") - if pub_key.exists() and private_key.exists(): - return str(private_key) - raise ConfigurationError(f"Make sure valid keypair exists: {private_key}(.pub)") - - if not dstack_key_path.exists(): - generate_rsa_key_pair(private_key_path=dstack_key_path) - return str(dstack_key_path) diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index c6b7863373..6a63496780 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -1,4 +1,6 @@ import base64 +import hashlib +import os import queue import tempfile import threading @@ -15,6 +17,7 @@ from websocket import WebSocketApp import dstack.api as api +from dstack._internal.cli.utils.common import warn from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_RUNNER_SSH_PORT from dstack._internal.core.errors import ClientError, ConfigurationError, ResourceNotExistsError from dstack._internal.core.models.backends.base import BackendType @@ -45,6 +48,7 @@ get_service_port, ) from dstack._internal.core.models.runs import Run as RunModel +from dstack._internal.core.services.configs import ConfigManager from dstack._internal.core.services.logs import URLReplacer from dstack._internal.core.services.ssh.attach import SSHAttach from dstack._internal.core.services.ssh.ports import PortsLock @@ -53,6 +57,7 @@ from dstack._internal.utils.files import create_file_archive from dstack._internal.utils.logging import get_logger from dstack._internal.utils.path import PathLike, path_in_dir +from dstack.api._public.repos import generate_rsa_key_pair from dstack.api.server import APIClient logger = get_logger(__name__) @@ -72,13 +77,11 @@ def __init__( self, api_client: APIClient, project: str, - ssh_identity_file: Optional[PathLike], run: RunModel, ports_lock: Optional[PortsLock] = None, ): self._api_client = api_client self._project = project - self._ssh_identity_file = ssh_identity_file self._run = run self._ports_lock: Optional[PortsLock] = ports_lock self._ssh_attach: Optional[SSHAttach] = None @@ -270,9 +273,33 @@ def attach( Raises: dstack.api.PortUsedError: If ports are in use or the run is attached by another process. """ - ssh_identity_file = ssh_identity_file or self._ssh_identity_file - if ssh_identity_file is None: - raise ConfigurationError("SSH identity file is required to attach to the run") + if not ssh_identity_file: + user = self._api_client.users.get_my_user() + run_ssh_key_pub = self._run.run_spec.ssh_key_pub + config_manager = ConfigManager() + if user.ssh_public_key == run_ssh_key_pub: + token_hash = hashlib.sha1(user.creds.token.encode()).hexdigest()[:8] + config_manager.dstack_ssh_dir.mkdir(parents=True, exist_ok=True) + ssh_identity_file = config_manager.dstack_ssh_dir / token_hash + + def key_opener(path, flags): + return os.open(path, flags, 0o600) + + with open(ssh_identity_file, "wb", opener=key_opener) as f: + assert user.ssh_private_key + f.write(user.ssh_private_key.encode()) + else: + if config_manager.dstack_key_path.exists(): + # TODO: Remove since 0.19.35 + warn( + f"Using legacy [code]{config_manager.dstack_key_path}[/code]." + " Future versions will use the user SSH key from the server.", + ) + ssh_identity_file = config_manager.dstack_key_path + else: + raise ConfigurationError( + f"User SSH key doen't match; default SSH key ({config_manager.dstack_key_path}) doesn't exist" + ) ssh_identity_file = str(ssh_identity_file) job = self._find_job(replica_num=replica_num, job_num=job_num) @@ -434,6 +461,7 @@ def get_run_plan( profile: Optional[Profile] = None, configuration_path: Optional[str] = None, repo_dir: Optional[str] = None, + ssh_identity_file: Optional[PathLike] = None, ) -> RunPlan: """ Get a run plan. @@ -465,6 +493,19 @@ def get_run_plan( if repo_dir is None and configuration.repos: repo_dir = configuration.repos[0].path + if ssh_identity_file: + ssh_key_pub = Path(ssh_identity_file).with_suffix(".pub").read_text() + else: + config_manager = ConfigManager() + if not config_manager.dstack_key_path.exists(): + generate_rsa_key_pair(private_key_path=config_manager.dstack_key_path) + warn( + f"Using legacy [code]{config_manager.dstack_key_path.with_suffix('.pub')}[/code]." + " Future versions will use the user SSH key from the server.", + ) + ssh_key_pub = config_manager.dstack_key_path.with_suffix(".pub").read_text() + # TODO: Uncomment after 0.19.35 + # ssh_key_pub = None run_spec = RunSpec( run_name=configuration.name, repo_id=repo.repo_id, @@ -477,7 +518,7 @@ def get_run_plan( configuration_path=configuration_path, configuration=configuration, profile=profile, - ssh_key_pub=Path(self._client.ssh_identity_file + ".pub").read_text().strip(), + ssh_key_pub=ssh_key_pub, ) logger.debug("Getting run plan") run_plan = self._api_client.runs.get_plan(self._project, run_spec) @@ -546,6 +587,7 @@ def apply_configuration( profile: Optional[Profile] = None, configuration_path: Optional[str] = None, reserve_ports: bool = True, + ssh_identity_file: Optional[PathLike] = None, ) -> Run: """ Apply the run configuration. @@ -567,6 +609,7 @@ def apply_configuration( repo=repo, profile=profile, configuration_path=configuration_path, + ssh_identity_file=ssh_identity_file, ) run = self.apply_plan( run_plan=run_plan, @@ -718,7 +761,6 @@ def get_plan( configuration_path=configuration_path, configuration=configuration, profile=profile, - ssh_key_pub=Path(self._client.ssh_identity_file + ".pub").read_text().strip(), ) logger.debug("Getting run plan") run_plan = self._api_client.runs.get_plan(self._project, run_spec) @@ -800,7 +842,6 @@ def _model_to_run(self, run: RunModel) -> Run: return Run( self._api_client, self._project, - self._client.ssh_identity_file, run, ) @@ -808,7 +849,6 @@ def _model_to_submitted_run(self, run: RunModel, ports_lock: Optional[PortsLock] return Run( self._api_client, self._project, - self._client.ssh_identity_file, run, ports_lock, ) diff --git a/src/dstack/api/server/_users.py b/src/dstack/api/server/_users.py index 5c2d58b0b8..6082636c4b 100644 --- a/src/dstack/api/server/_users.py +++ b/src/dstack/api/server/_users.py @@ -17,9 +17,9 @@ def list(self) -> List[User]: resp = self._request("/api/users/list") return parse_obj_as(List[User.__response__], resp.json()) - def get_my_user(self) -> User: + def get_my_user(self) -> UserWithCreds: resp = self._request("/api/users/get_my_user") - return parse_obj_as(User.__response__, resp.json()) + return parse_obj_as(UserWithCreds.__response__, resp.json()) def get_user(self, username: str) -> User: body = GetUserRequest(username=username) From d2f42905dbe1d7ce1e9f2915dceb130ba36f4157 Mon Sep 17 00:00:00 2001 From: peterschmidt85 Date: Fri, 10 Oct 2025 01:21:20 +0200 Subject: [PATCH 2/6] [Feature]: Store user SSH key on the server #2053 Fixing tests --- src/dstack/api/_public/__init__.py | 2 +- src/tests/_internal/server/routers/test_users.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/dstack/api/_public/__init__.py b/src/dstack/api/_public/__init__.py index c5f7982c90..b46e91aa21 100644 --- a/src/dstack/api/_public/__init__.py +++ b/src/dstack/api/_public/__init__.py @@ -3,7 +3,7 @@ import dstack._internal.core.services.api_client as api_client_service from dstack._internal.core.errors import ConfigurationError from dstack._internal.utils.logging import get_logger -from dstack._internal.utils.path import PathLike +from dstack._internal.utils.path import PathLike as PathLike from dstack.api._public.backends import BackendCollection from dstack.api._public.repos import RepoCollection from dstack.api._public.runs import RunCollection diff --git a/src/tests/_internal/server/routers/test_users.py b/src/tests/_internal/server/routers/test_users.py index df4a8f3c06..f08ee780fb 100644 --- a/src/tests/_internal/server/routers/test_users.py +++ b/src/tests/_internal/server/routers/test_users.py @@ -139,10 +139,13 @@ async def test_returns_logged_in_user( "created_at": "2023-01-02T03:04:00+00:00", "global_role": user.global_role, "email": None, + "creds": {"token": user.token.get_plaintext_or_error()}, "active": True, "permissions": { "can_create_projects": True, }, + "ssh_private_key": None, + "ssh_public_key": None, } @@ -196,6 +199,8 @@ async def test_returns_logged_in_user( "permissions": { "can_create_projects": True, }, + "ssh_private_key": None, + "ssh_public_key": None, } From f65983bf8901cc565e977217b01c151407575165 Mon Sep 17 00:00:00 2001 From: peterschmidt85 Date: Fri, 10 Oct 2025 01:27:14 +0200 Subject: [PATCH 3/6] [Feature]: Store user SSH key on the server #2053 Bugfix --- src/dstack/api/_public/runs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 6a63496780..3312aa7e2b 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -54,10 +54,10 @@ from dstack._internal.core.services.ssh.ports import PortsLock from dstack._internal.server.schemas.logs import PollLogsRequest from dstack._internal.utils.common import get_or_error, make_proxy_url +from dstack._internal.utils.crypto import generate_rsa_key_pair from dstack._internal.utils.files import create_file_archive from dstack._internal.utils.logging import get_logger from dstack._internal.utils.path import PathLike, path_in_dir -from dstack.api._public.repos import generate_rsa_key_pair from dstack.api.server import APIClient logger = get_logger(__name__) From 1b6a7f4d2143b8174c9974b846da90259c046da1 Mon Sep 17 00:00:00 2001 From: Andrey Cheptsov <54148038+peterschmidt85@users.noreply.github.com> Date: Tue, 14 Oct 2025 14:22:42 -0700 Subject: [PATCH 4/6] Update src/dstack/_internal/server/services/users.py Co-authored-by: jvstme <36324149+jvstme@users.noreply.github.com> --- src/dstack/_internal/server/services/users.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/dstack/_internal/server/services/users.py b/src/dstack/_internal/server/services/users.py index d612bc9022..d5a9ef8904 100644 --- a/src/dstack/_internal/server/services/users.py +++ b/src/dstack/_internal/server/services/users.py @@ -112,7 +112,6 @@ async def update_user( email: Optional[str] = None, active: bool = True, ) -> UserModel: - # TODO: Allow to update ssh_private_key and ssh_public_key await session.execute( update(UserModel) .where(UserModel.name == username) From 035cfe81ff57ff7c8b739227deb22e9f31d19f94 Mon Sep 17 00:00:00 2001 From: peterschmidt85 Date: Tue, 14 Oct 2025 15:34:26 -0700 Subject: [PATCH 5/6] [Feature]: Store user SSH key on the server #2053 Review: move `ssh_public_key` to `User` model (from `UserWithCreds`) --- src/dstack/_internal/core/models/users.py | 2 +- src/dstack/_internal/server/services/users.py | 3 ++- src/tests/_internal/server/routers/test_projects.py | 9 +++++++++ src/tests/_internal/server/routers/test_users.py | 13 +++++++++++-- 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/dstack/_internal/core/models/users.py b/src/dstack/_internal/core/models/users.py index 449c57f8b2..d09d032bf2 100644 --- a/src/dstack/_internal/core/models/users.py +++ b/src/dstack/_internal/core/models/users.py @@ -30,6 +30,7 @@ class User(CoreModel): email: Optional[str] active: bool permissions: UserPermissions + ssh_public_key: Optional[str] = None class UserTokenCreds(CoreModel): @@ -38,5 +39,4 @@ class UserTokenCreds(CoreModel): class UserWithCreds(User): creds: UserTokenCreds - ssh_public_key: Optional[str] = None ssh_private_key: Optional[str] = None diff --git a/src/dstack/_internal/server/services/users.py b/src/dstack/_internal/server/services/users.py index d5a9ef8904..f4c57425b7 100644 --- a/src/dstack/_internal/server/services/users.py +++ b/src/dstack/_internal/server/services/users.py @@ -225,6 +225,7 @@ def user_model_to_user(user_model: UserModel) -> User: email=user_model.email, active=user_model.active, permissions=get_user_permissions(user_model), + ssh_public_key=user_model.ssh_public_key, ) @@ -237,8 +238,8 @@ def user_model_to_user_with_creds(user_model: UserModel) -> UserWithCreds: email=user_model.email, active=user_model.active, permissions=get_user_permissions(user_model), - creds=UserTokenCreds(token=user_model.token.get_plaintext_or_error()), ssh_public_key=user_model.ssh_public_key, + creds=UserTokenCreds(token=user_model.token.get_plaintext_or_error()), ssh_private_key=user_model.ssh_private_key, ) diff --git a/src/tests/_internal/server/routers/test_projects.py b/src/tests/_internal/server/routers/test_projects.py index 2fe7cc1888..d3b0426960 100644 --- a/src/tests/_internal/server/routers/test_projects.py +++ b/src/tests/_internal/server/routers/test_projects.py @@ -77,6 +77,7 @@ async def test_returns_projects(self, test_db, session: AsyncSession, client: As "permissions": { "can_create_projects": True, }, + "ssh_public_key": None, }, "created_at": "2023-01-02T03:04:00+00:00", "backends": [], @@ -244,6 +245,7 @@ async def test_creates_project(self, test_db, session: AsyncSession, client: Asy "permissions": { "can_create_projects": True, }, + "ssh_public_key": user.ssh_public_key, }, "created_at": "2023-01-02T03:04:00+00:00", "backends": [], @@ -259,6 +261,7 @@ async def test_creates_project(self, test_db, session: AsyncSession, client: Asy "permissions": { "can_create_projects": True, }, + "ssh_public_key": user.ssh_public_key, }, "project_role": ProjectRole.ADMIN, "permissions": { @@ -693,6 +696,7 @@ async def test_returns_project(self, test_db, session: AsyncSession, client: Asy "permissions": { "can_create_projects": True, }, + "ssh_public_key": None, }, "created_at": "2023-01-02T03:04:00+00:00", "backends": [], @@ -708,6 +712,7 @@ async def test_returns_project(self, test_db, session: AsyncSession, client: Asy "permissions": { "can_create_projects": True, }, + "ssh_public_key": None, }, "project_role": ProjectRole.ADMIN, "permissions": { @@ -937,6 +942,7 @@ async def test_sets_project_members(self, test_db, session: AsyncSession, client "permissions": { "can_create_projects": True, }, + "ssh_public_key": admin.ssh_public_key, }, "project_role": ProjectRole.ADMIN, "permissions": { @@ -954,6 +960,7 @@ async def test_sets_project_members(self, test_db, session: AsyncSession, client "permissions": { "can_create_projects": True, }, + "ssh_public_key": user1.ssh_public_key, }, "project_role": ProjectRole.ADMIN, "permissions": { @@ -971,6 +978,7 @@ async def test_sets_project_members(self, test_db, session: AsyncSession, client "permissions": { "can_create_projects": True, }, + "ssh_public_key": user2.ssh_public_key, }, "project_role": ProjectRole.USER, "permissions": { @@ -1027,6 +1035,7 @@ async def test_sets_project_members_by_email( "permissions": { "can_create_projects": True, }, + "ssh_public_key": user1.ssh_public_key, }, "project_role": ProjectRole.ADMIN, "permissions": { diff --git a/src/tests/_internal/server/routers/test_users.py b/src/tests/_internal/server/routers/test_users.py index f08ee780fb..6fc39ed712 100644 --- a/src/tests/_internal/server/routers/test_users.py +++ b/src/tests/_internal/server/routers/test_users.py @@ -48,6 +48,7 @@ async def test_admins_see_all_users(self, test_db, session: AsyncSession, client "permissions": { "can_create_projects": True, }, + "ssh_public_key": None, }, { "id": str(other_user.id), @@ -59,6 +60,7 @@ async def test_admins_see_all_users(self, test_db, session: AsyncSession, client "permissions": { "can_create_projects": True, }, + "ssh_public_key": None, }, ] @@ -92,6 +94,7 @@ async def test_non_admins_see_only_themselves( "permissions": { "can_create_projects": True, }, + "ssh_public_key": None, } ] @@ -229,7 +232,9 @@ async def test_creates_user(self, test_db, session: AsyncSession, client: AsyncC }, ) assert response.status_code == 200 - assert response.json() == { + user_data = response.json() + ssh_public_key = user_data["ssh_public_key"] + assert user_data == { "id": "1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e", "username": "test", "created_at": "2023-01-02T03:04:00+00:00", @@ -239,6 +244,7 @@ async def test_creates_user(self, test_db, session: AsyncSession, client: AsyncC "permissions": { "can_create_projects": True, }, + "ssh_public_key": ssh_public_key, } res = await session.execute(select(UserModel).where(UserModel.name == "test")) assert len(res.scalars().all()) == 1 @@ -264,7 +270,9 @@ async def test_return_400_if_username_taken( }, ) assert response.status_code == 200 - assert response.json() == { + user_data = response.json() + ssh_public_key = user_data["ssh_public_key"] + assert user_data == { "id": "1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e", "username": "Test", "created_at": "2023-01-02T03:04:00+00:00", @@ -274,6 +282,7 @@ async def test_return_400_if_username_taken( "permissions": { "can_create_projects": True, }, + "ssh_public_key": ssh_public_key, } # Username uniqueness check should be case insensitive for username in ["test", "Test", "TesT"]: From 4d42a86e0123890688128cb0aa58da7acffbd00d Mon Sep 17 00:00:00 2001 From: peterschmidt85 Date: Tue, 14 Oct 2025 16:21:45 -0700 Subject: [PATCH 6/6] [Feature]: Store user SSH key on the server #2053 Review --- src/dstack/_internal/cli/commands/offer.py | 1 + .../core/services/configs/__init__.py | 2 +- src/dstack/api/_public/__init__.py | 9 ++++++++- src/dstack/api/_public/runs.py | 18 ++++++++++++++++-- .../_internal/server/routers/test_users.py | 2 +- 5 files changed, 27 insertions(+), 5 deletions(-) diff --git a/src/dstack/_internal/cli/commands/offer.py b/src/dstack/_internal/cli/commands/offer.py index 17600b0607..d173982dab 100644 --- a/src/dstack/_internal/cli/commands/offer.py +++ b/src/dstack/_internal/cli/commands/offer.py @@ -105,6 +105,7 @@ def _command(self, args: argparse.Namespace): run_spec = RunSpec( configuration=conf, profile=profile, + ssh_key_pub="(dummy)", # TODO: Remove since 0.19.40 ) if args.group_by: diff --git a/src/dstack/_internal/core/services/configs/__init__.py b/src/dstack/_internal/core/services/configs/__init__.py index 13ed49c9b5..c6d333cf84 100644 --- a/src/dstack/_internal/core/services/configs/__init__.py +++ b/src/dstack/_internal/core/services/configs/__init__.py @@ -117,7 +117,7 @@ def dstack_ssh_dir(self) -> Path: @property def dstack_key_path(self) -> Path: - # TODO: Remove since 0.19.35 + # TODO: Remove since 0.19.40 return self.dstack_ssh_dir / "id_rsa" @property diff --git a/src/dstack/api/_public/__init__.py b/src/dstack/api/_public/__init__.py index b46e91aa21..510555a898 100644 --- a/src/dstack/api/_public/__init__.py +++ b/src/dstack/api/_public/__init__.py @@ -6,7 +6,7 @@ from dstack._internal.utils.path import PathLike as PathLike from dstack.api._public.backends import BackendCollection from dstack.api._public.repos import RepoCollection -from dstack.api._public.runs import RunCollection +from dstack.api._public.runs import RunCollection, warn from dstack.api.server import APIClient logger = get_logger(__name__) @@ -28,17 +28,24 @@ def __init__( self, api_client: APIClient, project_name: str, + ssh_identity_file: Optional[PathLike] = None, ): # """ # Args: # api_client: low-level server API client # project_name: project name used for runs + # ssh_identity_file: deprecated and will be removed in 0.19.40 # """ self._client = api_client self._project = project_name self._repos = RepoCollection(api_client, project_name) self._backends = BackendCollection(api_client, project_name) self._runs = RunCollection(api_client, project_name, self) + if ssh_identity_file is not None: + warn( + "[code]ssh_identity_file[/code] in [code]Client[/code] is deprecated and ignored; will be removed" + " since 0.19.40" + ) @staticmethod def from_config( diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 3312aa7e2b..5ebc8c2c42 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -79,12 +79,18 @@ def __init__( project: str, run: RunModel, ports_lock: Optional[PortsLock] = None, + ssh_identity_file: Optional[PathLike] = None, ): self._api_client = api_client self._project = project self._run = run self._ports_lock: Optional[PortsLock] = ports_lock self._ssh_attach: Optional[SSHAttach] = None + if ssh_identity_file is not None: + warn( + "[code]ssh_identity_file[/code] in [code]Run[/code] is deprecated and ignored; will be removed" + " since 0.19.40" + ) @property def name(self) -> str: @@ -290,7 +296,7 @@ def key_opener(path, flags): f.write(user.ssh_private_key.encode()) else: if config_manager.dstack_key_path.exists(): - # TODO: Remove since 0.19.35 + # TODO: Remove since 0.19.40 warn( f"Using legacy [code]{config_manager.dstack_key_path}[/code]." " Future versions will use the user SSH key from the server.", @@ -504,7 +510,7 @@ def get_run_plan( " Future versions will use the user SSH key from the server.", ) ssh_key_pub = config_manager.dstack_key_path.with_suffix(".pub").read_text() - # TODO: Uncomment after 0.19.35 + # TODO: Uncomment after 0.19.40 # ssh_key_pub = None run_spec = RunSpec( run_name=configuration.name, @@ -752,6 +758,13 @@ def get_plan( creation_policy=creation_policy, idle_duration=idle_duration, # type: ignore[assignment] ) + config_manager = ConfigManager() + if not config_manager.dstack_key_path.exists(): + generate_rsa_key_pair(private_key_path=config_manager.dstack_key_path) + warn( + f"Using legacy [code]{config_manager.dstack_key_path.with_suffix('.pub')}[/code]." + " Future versions will use the user SSH key from the server.", + ) run_spec = RunSpec( run_name=run_name, repo_id=repo.repo_id, @@ -761,6 +774,7 @@ def get_plan( configuration_path=configuration_path, configuration=configuration, profile=profile, + ssh_key_pub=config_manager.dstack_key_path.with_suffix(".pub").read_text(), ) logger.debug("Getting run plan") run_plan = self._api_client.runs.get_plan(self._project, run_spec) diff --git a/src/tests/_internal/server/routers/test_users.py b/src/tests/_internal/server/routers/test_users.py index 6fc39ed712..ec38c94f04 100644 --- a/src/tests/_internal/server/routers/test_users.py +++ b/src/tests/_internal/server/routers/test_users.py @@ -270,7 +270,7 @@ async def test_return_400_if_username_taken( }, ) assert response.status_code == 200 - user_data = response.json() + user_data = response.json() ssh_public_key = user_data["ssh_public_key"] assert user_data == { "id": "1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e",