From f19b0fbd3944ba0ad01eae08d1e6a01011708ccf Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Wed, 22 Oct 2025 20:56:37 +0200 Subject: [PATCH] Use server-managed user SSH keys for new runs This commit updates the CLI and the server to use server-managed user SSH keys when starting new runs. This allows users to attach to the run from different machines, since the SSH key is automatically replicated to all clients. Implementation details: - Server: - If the user key is missing, generate it when the user first calls `/get_my_user`. - Client: - Before applying or getting a run plan, call `/get_my_user` to check if the user key is available. If it is, use it. - Cache the downloaded keys in `~/.dstack/ssh` to avoid repeated `/get_my_user` calls. - Switch from `warn` to logger messages, since this code is part of the Python API, so its output should be configurable. --- .../core/services/ssh/key_manager.py | 56 ++++++++++ src/dstack/_internal/server/models.py | 3 + src/dstack/_internal/server/routers/users.py | 7 ++ src/dstack/_internal/server/services/users.py | 6 +- src/dstack/_internal/server/testing/common.py | 4 + src/dstack/api/_public/__init__.py | 4 +- src/dstack/api/_public/runs.py | 75 +++++++------ src/dstack/api/server/__init__.py | 4 + .../core/services/ssh/test_key_manager.py | 100 ++++++++++++++++++ .../_internal/server/routers/test_users.py | 29 ++++- 10 files changed, 242 insertions(+), 46 deletions(-) create mode 100644 src/dstack/_internal/core/services/ssh/key_manager.py create mode 100644 src/tests/_internal/core/services/ssh/test_key_manager.py diff --git a/src/dstack/_internal/core/services/ssh/key_manager.py b/src/dstack/_internal/core/services/ssh/key_manager.py new file mode 100644 index 0000000000..98e941638a --- /dev/null +++ b/src/dstack/_internal/core/services/ssh/key_manager.py @@ -0,0 +1,56 @@ +import os +from dataclasses import dataclass +from datetime import datetime, timedelta +from pathlib import Path +from typing import TYPE_CHECKING, 
Optional + +from dstack._internal.core.models.users import UserWithCreds + +if TYPE_CHECKING: + from dstack.api.server import APIClient + +KEY_REFRESH_RATE = timedelta(minutes=10) # redownload the key periodically in case it was rotated + + +@dataclass +class UserSSHKey: + public_key: str + private_key_path: Path + + +class UserSSHKeyManager: + def __init__(self, api_client: "APIClient", ssh_keys_dir: Path) -> None: + self._api_client = api_client + self._key_path = ssh_keys_dir / api_client.get_token_hash() + self._pub_key_path = self._key_path.with_suffix(".pub") + + def get_user_key(self) -> Optional[UserSSHKey]: + """ + Return the up-to-date user key, or None if the user has no key (if created before 0.19.33) + """ + if ( + not self._key_path.exists() + or not self._pub_key_path.exists() + or datetime.now() - datetime.fromtimestamp(self._key_path.stat().st_mtime) + > KEY_REFRESH_RATE + ): + if not self._download_user_key(): + return None + return UserSSHKey( + public_key=self._pub_key_path.read_text(), private_key_path=self._key_path + ) + + def _download_user_key(self) -> bool: + user = self._api_client.users.get_my_user() + if not (isinstance(user, UserWithCreds) and user.ssh_public_key and user.ssh_private_key): + return False + + def key_opener(path, flags): + return os.open(path, flags, 0o600) + + with open(self._key_path, "w", opener=key_opener) as f: + f.write(user.ssh_private_key) + with open(self._pub_key_path, "w") as f: + f.write(user.ssh_public_key) + + return True diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 31f44d3692..c6d97b810e 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -190,6 +190,9 @@ class UserModel(BaseModel): # deactivated users cannot access API active: Mapped[bool] = mapped_column(Boolean, default=True) + # SSH keys can be null for users created before 0.19.33. 
+ # Keys for those users are being gradually generated on /get_my_user calls. + # TODO: make keys required in a future version. ssh_private_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True) ssh_public_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True) diff --git a/src/dstack/_internal/server/routers/users.py b/src/dstack/_internal/server/routers/users.py index 706d3db8b5..d58472275d 100644 --- a/src/dstack/_internal/server/routers/users.py +++ b/src/dstack/_internal/server/routers/users.py @@ -38,8 +38,15 @@ async def list_users( @router.post("/get_my_user", response_model=UserWithCreds) async def get_my_user( + session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), ): + if user.ssh_private_key is None or user.ssh_public_key is None: + # Generate keys for pre-0.19.33 users + updated_user = await users.refresh_ssh_key(session=session, user=user, username=user.name) + if updated_user is None: + raise ResourceNotExistsError() + user = updated_user return CustomORJSONResponse(users.user_model_to_user_with_creds(user)) diff --git a/src/dstack/_internal/server/services/users.py b/src/dstack/_internal/server/services/users.py index 12fb0e00cf..9fdbe3b4e8 100644 --- a/src/dstack/_internal/server/services/users.py +++ b/src/dstack/_internal/server/services/users.py @@ -20,8 +20,8 @@ from dstack._internal.server.models import DecryptedString, UserModel from dstack._internal.server.services.permissions import get_default_permissions from dstack._internal.server.utils.routers import error_forbidden +from dstack._internal.utils import crypto from dstack._internal.utils.common import run_async -from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -88,7 +88,7 @@ async def create_user( raise ResourceExistsError() if token is None: token = str(uuid.uuid4()) - private_bytes, public_bytes = await 
run_async(generate_rsa_key_pair_bytes, username) + private_bytes, public_bytes = await run_async(crypto.generate_rsa_key_pair_bytes, username) user = UserModel( id=uuid.uuid4(), name=username, @@ -135,7 +135,7 @@ async def refresh_ssh_key( logger.debug("Refreshing SSH key for user [code]%s[/code]", username) if user.global_role != GlobalRole.ADMIN and user.name != username: raise error_forbidden() - private_bytes, public_bytes = await run_async(generate_rsa_key_pair_bytes, username) + private_bytes, public_bytes = await run_async(crypto.generate_rsa_key_pair_bytes, username) await session.execute( update(UserModel) .where(UserModel.name == username) diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 0a8adaa428..e6de272911 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -126,6 +126,8 @@ async def create_user( global_role: GlobalRole = GlobalRole.ADMIN, token: Optional[str] = None, email: Optional[str] = None, + ssh_public_key: Optional[str] = None, + ssh_private_key: Optional[str] = None, active: bool = True, ) -> UserModel: if token is None: @@ -137,6 +139,8 @@ async def create_user( token=DecryptedString(plaintext=token), token_hash=get_token_hash(token), email=email, + ssh_public_key=ssh_public_key, + ssh_private_key=ssh_private_key, active=active, ) session.add(user) diff --git a/src/dstack/api/_public/__init__.py b/src/dstack/api/_public/__init__.py index 510555a898..1d9ab353d2 100644 --- a/src/dstack/api/_public/__init__.py +++ b/src/dstack/api/_public/__init__.py @@ -6,7 +6,7 @@ from dstack._internal.utils.path import PathLike as PathLike from dstack.api._public.backends import BackendCollection from dstack.api._public.repos import RepoCollection -from dstack.api._public.runs import RunCollection, warn +from dstack.api._public.runs import RunCollection from dstack.api.server import APIClient logger = get_logger(__name__) @@ -42,7 
+42,7 @@ def __init__( self._backends = BackendCollection(api_client, project_name) self._runs = RunCollection(api_client, project_name, self) if ssh_identity_file is not None: - warn( + logger.warning( "[code]ssh_identity_file[/code] in [code]Client[/code] is deprecated and ignored; will be removed" " since 0.19.40" ) diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index c51c899d30..b583e676ec 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -1,6 +1,4 @@ import base64 -import hashlib -import os import queue import tempfile import threading @@ -17,7 +15,6 @@ from websocket import WebSocketApp import dstack.api as api -from dstack._internal.cli.utils.common import warn from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_RUNNER_SSH_PORT from dstack._internal.core.errors import ClientError, ConfigurationError, ResourceNotExistsError from dstack._internal.core.models.backends.base import BackendType @@ -48,10 +45,10 @@ get_service_port, ) from dstack._internal.core.models.runs import Run as RunModel -from dstack._internal.core.models.users import UserWithCreds from dstack._internal.core.services.configs import ConfigManager from dstack._internal.core.services.logs import URLReplacer from dstack._internal.core.services.ssh.attach import SSHAttach +from dstack._internal.core.services.ssh.key_manager import UserSSHKeyManager from dstack._internal.core.services.ssh.ports import PortsLock from dstack._internal.server.schemas.logs import PollLogsRequest from dstack._internal.utils.common import get_or_error, make_proxy_url @@ -88,7 +85,7 @@ def __init__( self._ports_lock: Optional[PortsLock] = ports_lock self._ssh_attach: Optional[SSHAttach] = None if ssh_identity_file is not None: - warn( + logger.warning( "[code]ssh_identity_file[/code] in [code]Run[/code] is deprecated and ignored; will be removed" " since 0.19.40" ) @@ -281,31 +278,20 @@ def attach( dstack.api.PortUsedError: If ports are 
in use or the run is attached by another process. """ if not ssh_identity_file: - user = self._api_client.users.get_my_user() - run_ssh_key_pub = self._run.run_spec.ssh_key_pub config_manager = ConfigManager() - if isinstance(user, UserWithCreds) and user.ssh_public_key == run_ssh_key_pub: - token_hash = hashlib.sha1(user.creds.token.encode()).hexdigest()[:8] - config_manager.dstack_ssh_dir.mkdir(parents=True, exist_ok=True) - ssh_identity_file = config_manager.dstack_ssh_dir / token_hash - - def key_opener(path, flags): - return os.open(path, flags, 0o600) - - with open(ssh_identity_file, "wb", opener=key_opener) as f: - assert user.ssh_private_key - f.write(user.ssh_private_key.encode()) + key_manager = UserSSHKeyManager(self._api_client, config_manager.dstack_ssh_dir) + if ( + user_key := key_manager.get_user_key() + ) and user_key.public_key == self._run.run_spec.ssh_key_pub: + ssh_identity_file = user_key.private_key_path else: if config_manager.dstack_key_path.exists(): # TODO: Remove since 0.19.40 - warn( - f"Using legacy [code]{config_manager.dstack_key_path}[/code]." - " Future versions will use the user SSH key from the server.", - ) + logger.debug(f"Using legacy [code]{config_manager.dstack_key_path}[/code].") ssh_identity_file = config_manager.dstack_key_path else: raise ConfigurationError( - f"User SSH key doen't match; default SSH key ({config_manager.dstack_key_path}) doesn't exist" + f"User SSH key doesn't match; default SSH key ({config_manager.dstack_key_path}) doesn't exist" ) ssh_identity_file = str(ssh_identity_file) @@ -504,15 +490,19 @@ def get_run_plan( ssh_key_pub = Path(ssh_identity_file).with_suffix(".pub").read_text() else: config_manager = ConfigManager() - if not config_manager.dstack_key_path.exists(): - generate_rsa_key_pair(private_key_path=config_manager.dstack_key_path) - warn( - f"Using legacy [code]{config_manager.dstack_key_path.with_suffix('.pub')}[/code]." 
- " Future versions will use the user SSH key from the server.", - ) - ssh_key_pub = config_manager.dstack_key_path.with_suffix(".pub").read_text() - # TODO: Uncomment after 0.19.40 - # ssh_key_pub = None + key_manager = UserSSHKeyManager(self._api_client, config_manager.dstack_ssh_dir) + if key_manager.get_user_key(): + ssh_key_pub = None # using the server-managed user key + else: + if not config_manager.dstack_key_path.exists(): + generate_rsa_key_pair(private_key_path=config_manager.dstack_key_path) + logger.warning( + f"Using legacy [code]{config_manager.dstack_key_path.with_suffix('.pub')}[/code]." + " You will only be able to attach to the run from this client." + " Update the [code]dstack[/] server to [code]0.19.34[/]+ to switch to user keys" + " automatically replicated to all clients.", + ) + ssh_key_pub = config_manager.dstack_key_path.with_suffix(".pub").read_text() run_spec = RunSpec( run_name=configuration.name, repo_id=repo.repo_id, @@ -760,12 +750,19 @@ def get_plan( idle_duration=idle_duration, # type: ignore[assignment] ) config_manager = ConfigManager() - if not config_manager.dstack_key_path.exists(): - generate_rsa_key_pair(private_key_path=config_manager.dstack_key_path) - warn( - f"Using legacy [code]{config_manager.dstack_key_path.with_suffix('.pub')}[/code]." - " Future versions will use the user SSH key from the server.", - ) + key_manager = UserSSHKeyManager(self._api_client, config_manager.dstack_ssh_dir) + if key_manager.get_user_key(): + ssh_key_pub = None # using the server-managed user key + else: + if not config_manager.dstack_key_path.exists(): + generate_rsa_key_pair(private_key_path=config_manager.dstack_key_path) + logger.warning( + f"Using legacy [code]{config_manager.dstack_key_path.with_suffix('.pub')}[/code]." + " You will only be able to attach to the run from this client." 
+ " Update the [code]dstack[/] server to [code]0.19.34[/]+ to switch to user keys" + " automatically replicated to all clients.", + ) + ssh_key_pub = config_manager.dstack_key_path.with_suffix(".pub").read_text() run_spec = RunSpec( run_name=run_name, repo_id=repo.repo_id, @@ -775,7 +772,7 @@ def get_plan( configuration_path=configuration_path, configuration=configuration, profile=profile, - ssh_key_pub=config_manager.dstack_key_path.with_suffix(".pub").read_text(), + ssh_key_pub=ssh_key_pub, ) logger.debug("Getting run plan") run_plan = self._api_client.runs.get_plan(self._project, run_spec) diff --git a/src/dstack/api/server/__init__.py b/src/dstack/api/server/__init__.py index ce0328c2ef..be0e586e62 100644 --- a/src/dstack/api/server/__init__.py +++ b/src/dstack/api/server/__init__.py @@ -1,3 +1,4 @@ +import hashlib import os import pprint import time @@ -121,6 +122,9 @@ def volumes(self) -> VolumesAPIClient: def files(self) -> FilesAPIClient: return FilesAPIClient(self._request, self._logger) + def get_token_hash(self) -> str: + return hashlib.sha1(self._token.encode()).hexdigest()[:8] + def _request( self, path: str, diff --git a/src/tests/_internal/core/services/ssh/test_key_manager.py b/src/tests/_internal/core/services/ssh/test_key_manager.py new file mode 100644 index 0000000000..d727ba6258 --- /dev/null +++ b/src/tests/_internal/core/services/ssh/test_key_manager.py @@ -0,0 +1,100 @@ +import os +import time +import uuid +from datetime import datetime +from pathlib import Path +from unittest.mock import Mock + +from dstack._internal.core.models.users import ( + GlobalRole, + User, + UserPermissions, + UserTokenCreds, + UserWithCreds, +) +from dstack._internal.core.services.ssh.key_manager import ( + KEY_REFRESH_RATE, + UserSSHKeyManager, +) + +SAMPLE_USER = UserWithCreds( + id=uuid.uuid4(), + username="test", + created_at=datetime.now(), + global_role=GlobalRole.USER, + active=True, + email="test@example.com", + 
permissions=UserPermissions(can_create_projects=False), + creds=UserTokenCreds(token="7f92121b-a1b9-4ff2-8c0e-39070ffcd964"), + ssh_public_key="ssh-rsa AAA.public", + ssh_private_key="-----BEGIN PRIVATE KEY-----\nPRIVATE\n-----END PRIVATE KEY-----", +) +SAMPLE_USER_TOKEN_HASH = "4f010545" # sha1(SAMPLE_USER.creds.token.encode()).hexdigest()[:8] + + +def make_api_client(user: User, token_hash: str): + api_client = Mock() + api_client.get_token_hash.return_value = token_hash + api_client.users = Mock() + api_client.users.get_my_user.return_value = user + return api_client + + +def set_mtime(path: Path, ts: float): + os.utime(path, (ts, ts)) + + +def test_get_user_key_returns_none_when_no_user_creds(tmp_path: Path): + api_client = make_api_client( + user=User.__response__.parse_obj(SAMPLE_USER.dict()), token_hash=SAMPLE_USER_TOKEN_HASH + ) + manager = UserSSHKeyManager(api_client, tmp_path) + + assert manager.get_user_key() is None + assert not (tmp_path / SAMPLE_USER_TOKEN_HASH).exists() + assert not (tmp_path / f"{SAMPLE_USER_TOKEN_HASH}.pub").exists() + + +def test_get_user_key_downloads_keys(tmp_path: Path): + api_client = make_api_client(user=SAMPLE_USER, token_hash=SAMPLE_USER_TOKEN_HASH) + manager = UserSSHKeyManager(api_client, tmp_path) + + key = manager.get_user_key() + assert key is not None + assert key.public_key == SAMPLE_USER.ssh_public_key + assert key.private_key_path == tmp_path / SAMPLE_USER_TOKEN_HASH + assert (tmp_path / SAMPLE_USER_TOKEN_HASH).read_text() == SAMPLE_USER.ssh_private_key + assert (tmp_path / f"{SAMPLE_USER_TOKEN_HASH}.pub").read_text() == SAMPLE_USER.ssh_public_key + + +def test_get_user_key_uses_existing_key(tmp_path: Path): + api_client = make_api_client(user=SAMPLE_USER, token_hash=SAMPLE_USER_TOKEN_HASH) + (tmp_path / SAMPLE_USER_TOKEN_HASH).write_text("private-contents") + (tmp_path / f"{SAMPLE_USER_TOKEN_HASH}.pub").write_text("public-contents") + + manager = UserSSHKeyManager(api_client, tmp_path) + key =
manager.get_user_key() + + assert api_client.users.get_my_user.call_count == 0 + assert key is not None + assert key.public_key == "public-contents" + assert key.private_key_path == (tmp_path / SAMPLE_USER_TOKEN_HASH) + + +def test_get_user_key_redownloads_expired_key(tmp_path: Path): + api_client = make_api_client(user=SAMPLE_USER, token_hash=SAMPLE_USER_TOKEN_HASH) + priv = tmp_path / SAMPLE_USER_TOKEN_HASH + pub = tmp_path / f"{SAMPLE_USER_TOKEN_HASH}.pub" + priv.write_text("old-private") + pub.write_text("old-public") + stale_ts = time.time() - KEY_REFRESH_RATE.total_seconds() - 10 + set_mtime(priv, stale_ts) + set_mtime(pub, stale_ts) + + manager = UserSSHKeyManager(api_client, tmp_path) + key = manager.get_user_key() + assert key is not None + assert key.public_key == SAMPLE_USER.ssh_public_key + assert key.private_key_path == priv + assert priv.read_text() == SAMPLE_USER.ssh_private_key + assert pub.read_text() == SAMPLE_USER.ssh_public_key diff --git a/src/tests/_internal/server/routers/test_users.py b/src/tests/_internal/server/routers/test_users.py index ec38c94f04..54da638039 100644 --- a/src/tests/_internal/server/routers/test_users.py +++ b/src/tests/_internal/server/routers/test_users.py @@ -131,6 +131,8 @@ async def test_returns_logged_in_user( user = await create_user( session=session, created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + ssh_public_key="public-key", + ssh_private_key="private-key", ) response = await client.post( "/api/users/get_my_user", headers=get_auth_headers(user.token) @@ -147,10 +149,33 @@ async def test_returns_logged_in_user( "permissions": { "can_create_projects": True, }, - "ssh_private_key": None, - "ssh_public_key": None, + "ssh_private_key": "private-key", + "ssh_public_key": "public-key", } + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_generates_ssh_key_if_missing( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = 
await create_user( + session=session, + ssh_public_key=None, + ssh_private_key=None, + ) + with patch("dstack._internal.utils.crypto.generate_rsa_key_pair_bytes") as gen_mock: + gen_mock.return_value = (b"private-key", b"ssh-rsa AAA.public-key user\n") + response = await client.post( + "/api/users/get_my_user", headers=get_auth_headers(user.token) + ) + assert response.status_code == 200 + data = response.json() + assert data["ssh_private_key"] == "private-key" + assert data["ssh_public_key"] == "ssh-rsa AAA.public-key user\n" + await session.refresh(user) + assert user.ssh_private_key == data["ssh_private_key"] + assert user.ssh_public_key == data["ssh_public_key"] + class TestGetUser: @pytest.mark.asyncio