Skip to content

Commit 329e993

Browse files
committed
Remove redundant get_user_key()
1 parent 4694dcc commit 329e993

File tree

6 files changed

+66
-12
lines changed

6 files changed

+66
-12
lines changed

src/dstack/_internal/server/routers/runs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ async def get_plan(
118118
"""
119119
user, project = user_project
120120
if not user.ssh_public_key and not body.run_spec.ssh_key_pub:
121-
await users.refresh_ssh_key(session=session, user=user, username=user.name)
121+
await users.refresh_ssh_key(session=session, user=user)
122122
run_plan = await runs.get_plan(
123123
session=session,
124124
project=project,
@@ -148,7 +148,7 @@ async def apply_plan(
148148
"""
149149
user, project = user_project
150150
if not user.ssh_public_key and not body.plan.run_spec.ssh_key_pub:
151-
await users.refresh_ssh_key(session=session, user=user, username=user.name)
151+
await users.refresh_ssh_key(session=session, user=user)
152152
return CustomORJSONResponse(
153153
await runs.apply_plan(
154154
session=session,

src/dstack/_internal/server/routers/users.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,7 @@ async def get_my_user(
4343
):
4444
if user.ssh_private_key is None or user.ssh_public_key is None:
4545
# Generate keys for pre-0.19.33 users
46-
updated_user = await users.refresh_ssh_key(session=session, user=user, username=user.name)
47-
if updated_user is None:
48-
raise ResourceNotExistsError()
49-
user = updated_user
46+
await users.refresh_ssh_key(session=session, user=user)
5047
return CustomORJSONResponse(users.user_model_to_user_with_creds(user))
5148

5249

src/dstack/_internal/server/services/users.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,10 @@ async def update_user(
147147
async def refresh_ssh_key(
148148
session: AsyncSession,
149149
user: UserModel,
150-
username: str,
150+
username: Optional[str] = None,
151151
) -> Optional[UserModel]:
152+
if username is None:
153+
username = user.name
152154
logger.debug("Refreshing SSH key for user [code]%s[/code]", username)
153155
if user.global_role != GlobalRole.ADMIN and user.name != username:
154156
raise error_forbidden()

src/dstack/_internal/server/testing/common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ def get_run_spec(
277277
configuration_path: str = "dstack.yaml",
278278
profile: Union[Profile, Callable[[], Profile], None] = lambda: Profile(name="default"),
279279
configuration: Optional[AnyRunConfiguration] = None,
280+
ssh_key_pub: Optional[str] = "user_ssh_key",
280281
) -> RunSpec:
281282
if callable(profile):
282283
profile = profile()
@@ -288,7 +289,7 @@ def get_run_spec(
288289
configuration_path=configuration_path,
289290
configuration=configuration or DevEnvironmentConfiguration(ide="vscode"),
290291
profile=profile,
291-
ssh_key_pub="user_ssh_key",
292+
ssh_key_pub=ssh_key_pub,
292293
)
293294

294295

src/dstack/api/_public/runs.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -504,10 +504,6 @@ def get_run_plan(
504504
ssh_key_pub = Path(ssh_identity_file).with_suffix(".pub").read_text()
505505
else:
506506
ssh_key_pub = None # using the server-managed user key
507-
config_manager = ConfigManager()
508-
key_manager = UserSSHKeyManager(self._api_client, config_manager.dstack_ssh_dir)
509-
# Ensure we have a fresh key locally
510-
key_manager.get_user_key()
511507
run_spec = RunSpec(
512508
run_name=configuration.name,
513509
repo_id=repo.repo_id,

src/tests/_internal/server/routers/test_runs.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,6 +1364,32 @@ async def test_returns_update_or_create_action_on_conf_change(
13641364
assert response_json["action"] == action
13651365
assert response_json["current_resource"] == json.loads(run.json())
13661366

1367+
@pytest.mark.asyncio
1368+
@pytest.mark.usefixtures("test_db")
1369+
async def test_generates_user_ssh_key(self, session: AsyncSession, client: AsyncClient):
1370+
user = await create_user(
1371+
session=session, global_role=GlobalRole.USER, ssh_public_key=None, ssh_private_key=None
1372+
)
1373+
project = await create_project(session=session, owner=user)
1374+
await add_project_member(
1375+
session=session, project=project, user=user, project_role=ProjectRole.USER
1376+
)
1377+
repo = await create_repo(session=session, project_id=project.id)
1378+
run_spec = get_run_spec(run_name="test-run", repo_id=repo.name, ssh_key_pub=None)
1379+
1380+
response = await client.post(
1381+
f"/api/project/{project.name}/runs/get_plan",
1382+
headers=get_auth_headers(user.token),
1383+
json={"run_spec": run_spec.dict()},
1384+
)
1385+
1386+
assert response.status_code == 200, response.json()
1387+
run_spec_ssh_public_key = response.json()["effective_run_spec"]["ssh_key_pub"]
1388+
assert run_spec_ssh_public_key is not None
1389+
await session.refresh(user)
1390+
assert user.ssh_public_key == run_spec_ssh_public_key
1391+
assert user.ssh_private_key is not None
1392+
13671393

13681394
class TestApplyPlan:
13691395
@pytest.mark.asyncio
@@ -1517,6 +1543,38 @@ async def test_creates_pending_run_if_run_is_scheduled(
15171543
assert run.status == RunStatus.PENDING
15181544
assert run.next_triggered_at == datetime(2023, 1, 2, 3, 10, tzinfo=timezone.utc)
15191545

1546+
@pytest.mark.asyncio
1547+
@pytest.mark.usefixtures("test_db")
1548+
async def test_generates_user_ssh_key(self, session: AsyncSession, client: AsyncClient):
1549+
user = await create_user(
1550+
session=session, global_role=GlobalRole.USER, ssh_public_key=None, ssh_private_key=None
1551+
)
1552+
project = await create_project(session=session, owner=user)
1553+
await add_project_member(
1554+
session=session, project=project, user=user, project_role=ProjectRole.USER
1555+
)
1556+
repo = await create_repo(session=session, project_id=project.id)
1557+
run_spec = get_run_spec(run_name="test-run", repo_id=repo.name, ssh_key_pub=None)
1558+
1559+
response = await client.post(
1560+
f"/api/project/{project.name}/runs/apply",
1561+
headers=get_auth_headers(user.token),
1562+
json={
1563+
"plan": {
1564+
"run_spec": run_spec.dict(),
1565+
"current_resource": None,
1566+
},
1567+
"force": False,
1568+
},
1569+
)
1570+
1571+
assert response.status_code == 200, response.json()
1572+
run_spec_ssh_public_key = response.json()["run_spec"]["ssh_key_pub"]
1573+
assert run_spec_ssh_public_key is not None
1574+
await session.refresh(user)
1575+
assert user.ssh_public_key == run_spec_ssh_public_key
1576+
assert user.ssh_private_key is not None
1577+
15201578

15211579
class TestSubmitRun:
15221580
@pytest.mark.asyncio

0 commit comments

Comments
 (0)