diff --git a/src/dstack/_internal/core/services/ssh/key_manager.py b/src/dstack/_internal/core/services/ssh/key_manager.py new file mode 100644 index 0000000000..98e941638a --- /dev/null +++ b/src/dstack/_internal/core/services/ssh/key_manager.py @@ -0,0 +1,56 @@ +import os +from dataclasses import dataclass +from datetime import datetime, timedelta +from pathlib import Path +from typing import TYPE_CHECKING, Optional + +from dstack._internal.core.models.users import UserWithCreds + +if TYPE_CHECKING: + from dstack.api.server import APIClient + +KEY_REFRESH_RATE = timedelta(minutes=10) # redownload the key periodically in case it was rotated + + +@dataclass +class UserSSHKey: + public_key: str + private_key_path: Path + + +class UserSSHKeyManager: + def __init__(self, api_client: "APIClient", ssh_keys_dir: Path) -> None: + self._api_client = api_client + self._key_path = ssh_keys_dir / api_client.get_token_hash() + self._pub_key_path = self._key_path.with_suffix(".pub") + + def get_user_key(self) -> Optional[UserSSHKey]: + """ + Return the up-to-date user key, or None if the user has no key (if created before 0.19.33) + """ + if ( + not self._key_path.exists() + or not self._pub_key_path.exists() + or datetime.now() - datetime.fromtimestamp(self._key_path.stat().st_mtime) + > KEY_REFRESH_RATE + ): + if not self._download_user_key(): + return None + return UserSSHKey( + public_key=self._pub_key_path.read_text(), private_key_path=self._key_path + ) + + def _download_user_key(self) -> bool: + user = self._api_client.users.get_my_user() + if not (isinstance(user, UserWithCreds) and user.ssh_public_key and user.ssh_private_key): + return False + + def key_opener(path, flags): + return os.open(path, flags, 0o600) + + with open(self._key_path, "w", opener=key_opener) as f: + f.write(user.ssh_private_key) + with open(self._pub_key_path, "w") as f: + f.write(user.ssh_public_key) + + return True diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 31f44d3692..c6d97b810e 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -190,6 +190,9 @@ class UserModel(BaseModel): # deactivated users cannot access API active: Mapped[bool] = mapped_column(Boolean, default=True) + # SSH keys can be null for users created before 0.19.33. + # Keys for those users are being gradually generated on /get_my_user calls. + # TODO: make keys required in a future version. ssh_private_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True) ssh_public_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True) diff --git a/src/dstack/_internal/server/routers/users.py b/src/dstack/_internal/server/routers/users.py index 706d3db8b5..d58472275d 100644 --- a/src/dstack/_internal/server/routers/users.py +++ b/src/dstack/_internal/server/routers/users.py @@ -38,8 +38,15 @@ async def list_users( @router.post("/get_my_user", response_model=UserWithCreds) async def get_my_user( + session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), ): + if user.ssh_private_key is None or user.ssh_public_key is None: + # Generate keys for pre-0.19.33 users + updated_user = await users.refresh_ssh_key(session=session, user=user, username=user.name) + if updated_user is None: + raise ResourceNotExistsError() + user = updated_user return CustomORJSONResponse(users.user_model_to_user_with_creds(user)) diff --git a/src/dstack/_internal/server/services/users.py b/src/dstack/_internal/server/services/users.py index 12fb0e00cf..9fdbe3b4e8 100644 --- a/src/dstack/_internal/server/services/users.py +++ b/src/dstack/_internal/server/services/users.py @@ -20,8 +20,8 @@ from dstack._internal.server.models import DecryptedString, UserModel from dstack._internal.server.services.permissions import get_default_permissions from dstack._internal.server.utils.routers import error_forbidden +from dstack._internal.utils import crypto from dstack._internal.utils.common import run_async -from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -88,7 +88,7 @@ async def create_user( raise ResourceExistsError() if token is None: token = str(uuid.uuid4()) - private_bytes, public_bytes = await run_async(generate_rsa_key_pair_bytes, username) + private_bytes, public_bytes = await run_async(crypto.generate_rsa_key_pair_bytes, username) user = UserModel( id=uuid.uuid4(), name=username, @@ -135,7 +135,7 @@ async def refresh_ssh_key( logger.debug("Refreshing SSH key for user [code]%s[/code]", username) if user.global_role != GlobalRole.ADMIN and user.name != username: raise error_forbidden() - private_bytes, public_bytes = await run_async(generate_rsa_key_pair_bytes, username) + private_bytes, public_bytes = await run_async(crypto.generate_rsa_key_pair_bytes, username) await session.execute( update(UserModel) .where(UserModel.name == username) diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 0a8adaa428..e6de272911 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -126,6 +126,8 @@ async def create_user( global_role: GlobalRole = GlobalRole.ADMIN, token: Optional[str] = None, email: Optional[str] = None, + ssh_public_key: Optional[str] = None, + ssh_private_key: Optional[str] = None, active: bool = True, ) -> UserModel: if token is None: @@ -137,6 +139,8 @@ async def create_user( token=DecryptedString(plaintext=token), token_hash=get_token_hash(token), email=email, + ssh_public_key=ssh_public_key, + ssh_private_key=ssh_private_key, active=active, ) session.add(user) diff --git a/src/dstack/api/_public/__init__.py b/src/dstack/api/_public/__init__.py index 510555a898..1d9ab353d2 100644 --- a/src/dstack/api/_public/__init__.py +++ b/src/dstack/api/_public/__init__.py @@ -6,7 +6,7 @@ from dstack._internal.utils.path import PathLike as PathLike from dstack.api._public.backends import BackendCollection from dstack.api._public.repos import RepoCollection -from dstack.api._public.runs import RunCollection, warn +from dstack.api._public.runs import RunCollection from dstack.api.server import APIClient logger = get_logger(__name__) @@ -42,7 +42,7 @@ def __init__( self._backends = BackendCollection(api_client, project_name) self._runs = RunCollection(api_client, project_name, self) if ssh_identity_file is not None: - warn( + logger.warning( "[code]ssh_identity_file[/code] in [code]Client[/code] is deprecated and ignored; will be removed" " since 0.19.40" ) diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index c51c899d30..b583e676ec 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -1,6 +1,4 @@ import base64 -import hashlib -import os import queue import tempfile import threading @@ -17,7 +15,6 @@ from websocket import WebSocketApp import dstack.api as api -from dstack._internal.cli.utils.common import warn from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_RUNNER_SSH_PORT from dstack._internal.core.errors import ClientError, ConfigurationError, ResourceNotExistsError from dstack._internal.core.models.backends.base import BackendType @@ -48,10 +45,10 @@ get_service_port, ) from dstack._internal.core.models.runs import Run as RunModel -from dstack._internal.core.models.users import UserWithCreds from dstack._internal.core.services.configs import ConfigManager from dstack._internal.core.services.logs import URLReplacer from dstack._internal.core.services.ssh.attach import SSHAttach +from dstack._internal.core.services.ssh.key_manager import UserSSHKeyManager from dstack._internal.core.services.ssh.ports import PortsLock from dstack._internal.server.schemas.logs import PollLogsRequest from dstack._internal.utils.common import get_or_error, make_proxy_url @@ -88,7 +85,7 @@ def __init__( self._ports_lock: Optional[PortsLock] = ports_lock self._ssh_attach: Optional[SSHAttach] = None if ssh_identity_file is not None: - warn( + logger.warning( "[code]ssh_identity_file[/code] in [code]Run[/code] is deprecated and ignored; will be removed" " since 0.19.40" ) @@ -281,31 +278,20 @@ def attach( dstack.api.PortUsedError: If ports are in use or the run is attached by another process. """ if not ssh_identity_file: - user = self._api_client.users.get_my_user() - run_ssh_key_pub = self._run.run_spec.ssh_key_pub config_manager = ConfigManager() - if isinstance(user, UserWithCreds) and user.ssh_public_key == run_ssh_key_pub: - token_hash = hashlib.sha1(user.creds.token.encode()).hexdigest()[:8] - config_manager.dstack_ssh_dir.mkdir(parents=True, exist_ok=True) - ssh_identity_file = config_manager.dstack_ssh_dir / token_hash - - def key_opener(path, flags): - return os.open(path, flags, 0o600) - - with open(ssh_identity_file, "wb", opener=key_opener) as f: - assert user.ssh_private_key - f.write(user.ssh_private_key.encode()) + key_manager = UserSSHKeyManager(self._api_client, config_manager.dstack_ssh_dir) + if ( + user_key := key_manager.get_user_key() + ) and user_key.public_key == self._run.run_spec.ssh_key_pub: + ssh_identity_file = user_key.private_key_path else: if config_manager.dstack_key_path.exists(): # TODO: Remove since 0.19.40 - warn( - f"Using legacy [code]{config_manager.dstack_key_path}[/code]." - " Future versions will use the user SSH key from the server.", - ) + logger.debug(f"Using legacy [code]{config_manager.dstack_key_path}[/code].") ssh_identity_file = config_manager.dstack_key_path else: raise ConfigurationError( - f"User SSH key doen't match; default SSH key ({config_manager.dstack_key_path}) doesn't exist" + f"User SSH key doesn't match; default SSH key ({config_manager.dstack_key_path}) doesn't exist" ) ssh_identity_file = str(ssh_identity_file) @@ -504,15 +490,19 @@ def get_run_plan( ssh_key_pub = Path(ssh_identity_file).with_suffix(".pub").read_text() else: config_manager = ConfigManager() - if not config_manager.dstack_key_path.exists(): - generate_rsa_key_pair(private_key_path=config_manager.dstack_key_path) - warn( - f"Using legacy [code]{config_manager.dstack_key_path.with_suffix('.pub')}[/code]." - " Future versions will use the user SSH key from the server.", - ) - ssh_key_pub = config_manager.dstack_key_path.with_suffix(".pub").read_text() - # TODO: Uncomment after 0.19.40 - # ssh_key_pub = None + key_manager = UserSSHKeyManager(self._api_client, config_manager.dstack_ssh_dir) + if key_manager.get_user_key(): + ssh_key_pub = None # using the server-managed user key + else: + if not config_manager.dstack_key_path.exists(): + generate_rsa_key_pair(private_key_path=config_manager.dstack_key_path) + logger.warning( + f"Using legacy [code]{config_manager.dstack_key_path.with_suffix('.pub')}[/code]." + " You will only be able to attach to the run from this client." + " Update the [code]dstack[/] server to [code]0.19.34[/]+ to switch to user keys" + " automatically replicated to all clients.", + ) + ssh_key_pub = config_manager.dstack_key_path.with_suffix(".pub").read_text() run_spec = RunSpec( run_name=configuration.name, repo_id=repo.repo_id, @@ -760,12 +750,19 @@ def get_plan( idle_duration=idle_duration, # type: ignore[assignment] ) config_manager = ConfigManager() - if not config_manager.dstack_key_path.exists(): - generate_rsa_key_pair(private_key_path=config_manager.dstack_key_path) - warn( - f"Using legacy [code]{config_manager.dstack_key_path.with_suffix('.pub')}[/code]." - " Future versions will use the user SSH key from the server.", - ) + key_manager = UserSSHKeyManager(self._api_client, config_manager.dstack_ssh_dir) + if key_manager.get_user_key(): + ssh_key_pub = None # using the server-managed user key + else: + if not config_manager.dstack_key_path.exists(): + generate_rsa_key_pair(private_key_path=config_manager.dstack_key_path) + logger.warning( + f"Using legacy [code]{config_manager.dstack_key_path.with_suffix('.pub')}[/code]." + " You will only be able to attach to the run from this client." + " Update the [code]dstack[/] server to [code]0.19.34[/]+ to switch to user keys" + " automatically replicated to all clients.", + ) + ssh_key_pub = config_manager.dstack_key_path.with_suffix(".pub").read_text() run_spec = RunSpec( run_name=run_name, repo_id=repo.repo_id, @@ -775,7 +772,7 @@ def get_plan( configuration_path=configuration_path, configuration=configuration, profile=profile, - ssh_key_pub=config_manager.dstack_key_path.with_suffix(".pub").read_text(), + ssh_key_pub=ssh_key_pub, ) logger.debug("Getting run plan") run_plan = self._api_client.runs.get_plan(self._project, run_spec) diff --git a/src/dstack/api/server/__init__.py b/src/dstack/api/server/__init__.py index ce0328c2ef..be0e586e62 100644 --- a/src/dstack/api/server/__init__.py +++ b/src/dstack/api/server/__init__.py @@ -1,3 +1,4 @@ +import hashlib import os import pprint import time @@ -121,6 +122,9 @@ def volumes(self) -> VolumesAPIClient: def files(self) -> FilesAPIClient: return FilesAPIClient(self._request, self._logger) + def get_token_hash(self) -> str: + return hashlib.sha1(self._token.encode()).hexdigest()[:8] + def _request( self, path: str, diff --git a/src/tests/_internal/core/services/ssh/test_key_manager.py b/src/tests/_internal/core/services/ssh/test_key_manager.py new file mode 100644 index 0000000000..d727ba6258 --- /dev/null +++ b/src/tests/_internal/core/services/ssh/test_key_manager.py @@ -0,0 +1,100 @@ +import os +import time +import uuid +from datetime import datetime +from pathlib import Path +from unittest.mock import Mock + +from dstack._internal.core.models.users import ( + GlobalRole, + User, + UserPermissions, + UserTokenCreds, + UserWithCreds, +) +from dstack._internal.core.services.ssh.key_manager import ( + KEY_REFRESH_RATE, + UserSSHKeyManager, +) + +SAMPLE_USER = UserWithCreds( + id=uuid.uuid4(), + username="test", + created_at=datetime.now(), + global_role=GlobalRole.USER, + active=True, + email="test@example.com", + permissions=UserPermissions(can_create_projects=False), + creds=UserTokenCreds(token="7f92121b-a1b9-4ff2-8c0e-39070ffcd964"), + ssh_public_key="ssh-rsa AAA.public", + ssh_private_key="-----BEGIN PRIVATE KEY-----\nPRIVATE\n-----END PRIVATE KEY-----", +) +SAMPLE_USER_TOKEN_HASH = "4f010545" # sha1(SAMPLE_USER.creds.token.encode()).hexdigest[:8] + + +def make_api_client(user: User, token_hash: str): + api_client = Mock() + api_client.get_token_hash.return_value = token_hash + api_client.users = Mock() + api_client.users.get_my_user.return_value = user + return api_client + + +def set_mtime(path: Path, ts: float): + os.utime(path, (ts, ts)) + + +def test_get_user_key_returns_none_when_no_user_creds(tmp_path: Path): + api_client = make_api_client( + user=User.__response__.parse_obj(SAMPLE_USER.dict()), token_hash=SAMPLE_USER_TOKEN_HASH + ) + manager = UserSSHKeyManager(api_client, tmp_path) + + assert manager.get_user_key() is None + assert not (tmp_path / SAMPLE_USER_TOKEN_HASH).exists() + assert not (tmp_path / f"{SAMPLE_USER_TOKEN_HASH}.pub").exists() + + +def test_get_user_key_downloads_keys(tmp_path: Path): + api_client = make_api_client(user=SAMPLE_USER, token_hash=SAMPLE_USER_TOKEN_HASH) + manager = UserSSHKeyManager(api_client, tmp_path) + + key = manager.get_user_key() + assert key is not None + assert key.public_key == SAMPLE_USER.ssh_public_key + assert key.private_key_path == tmp_path / SAMPLE_USER_TOKEN_HASH + assert (tmp_path / SAMPLE_USER_TOKEN_HASH).read_text() == SAMPLE_USER.ssh_private_key + assert (tmp_path / f"{SAMPLE_USER_TOKEN_HASH}.pub").read_text() == SAMPLE_USER.ssh_public_key + + +def test_get_user_key_uses_existing_key(tmp_path: Path): + api_client = make_api_client(user=SAMPLE_USER, token_hash=SAMPLE_USER_TOKEN_HASH) + (tmp_path / SAMPLE_USER_TOKEN_HASH).write_text("private-contents") + (tmp_path / f"{SAMPLE_USER_TOKEN_HASH}.pub").write_text("public-contents") + + manager = UserSSHKeyManager(api_client, tmp_path) + key = manager.get_user_key() + + assert api_client.users.get_my_user.call_count == 0 + assert key is not None + assert key.public_key == "public-contents" + assert key.private_key_path == (tmp_path / SAMPLE_USER_TOKEN_HASH) + + +def test_get_user_key_redownloads_expired_key(tmp_path: Path): + api_client = make_api_client(user=SAMPLE_USER, token_hash=SAMPLE_USER_TOKEN_HASH) + priv = tmp_path / SAMPLE_USER_TOKEN_HASH + pub = tmp_path / f"{SAMPLE_USER_TOKEN_HASH}.pub" + priv.write_text("old-private") + pub.write_text("old-public") + stale_ts = time.time() - KEY_REFRESH_RATE.total_seconds() - 10 + set_mtime(priv, stale_ts) + set_mtime(pub, stale_ts) + + manager = UserSSHKeyManager(api_client, tmp_path) + key = manager.get_user_key() + assert key is not None + assert key.public_key == SAMPLE_USER.ssh_public_key + assert key.private_key_path == priv + assert priv.read_text() == SAMPLE_USER.ssh_private_key + assert pub.read_text() == SAMPLE_USER.ssh_public_key diff --git a/src/tests/_internal/server/routers/test_users.py b/src/tests/_internal/server/routers/test_users.py index ec38c94f04..54da638039 100644 --- a/src/tests/_internal/server/routers/test_users.py +++ b/src/tests/_internal/server/routers/test_users.py @@ -131,6 +131,8 @@ async def test_returns_logged_in_user( user = await create_user( session=session, created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + ssh_public_key="public-key", + ssh_private_key="private-key", ) response = await client.post( "/api/users/get_my_user", headers=get_auth_headers(user.token) @@ -147,10 +149,33 @@ async def test_returns_logged_in_user( "permissions": { "can_create_projects": True, }, - "ssh_private_key": None, - "ssh_public_key": None, + "ssh_private_key": "private-key", + "ssh_public_key": "public-key", } + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_generates_ssh_key_if_missing( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user( + session=session, + ssh_public_key=None, + ssh_private_key=None, + ) + with patch("dstack._internal.utils.crypto.generate_rsa_key_pair_bytes") as gen_mock: + gen_mock.return_value = (b"private-key", b"ssh-rsa AAA.public-key user\n") + response = await client.post( + "/api/users/get_my_user", headers=get_auth_headers(user.token) + ) + assert response.status_code == 200 + data = response.json() + assert data["ssh_private_key"] == "private-key" + assert data["ssh_public_key"] == "ssh-rsa AAA.public-key user\n" + await session.refresh(user) + assert user.ssh_private_key == data["ssh_private_key"] + assert user.ssh_public_key == data["ssh_public_key"] + class TestGetUser: @pytest.mark.asyncio