Skip to content

Commit 6f64743

Browse files
authored
Add more events about users and projects (#3390)
- User updated - User token refreshed - User SSH key refreshed - User deleted - Project updated - Project deleted Also refactor the implementation of the relevant operations on users to enable more detailed event messages and to avoid race conditions and longer write transactions.
1 parent e74332a commit 6f64743

6 files changed

Lines changed: 177 additions & 89 deletions

File tree

src/dstack/_internal/server/routers/runs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ async def get_plan(
118118
"""
119119
user, project = user_project
120120
if not user.ssh_public_key and not body.run_spec.ssh_key_pub:
121-
await users.refresh_ssh_key(session=session, user=user)
121+
await users.refresh_ssh_key(session=session, actor=user)
122122
run_plan = await runs.get_plan(
123123
session=session,
124124
project=project,
@@ -148,7 +148,7 @@ async def apply_plan(
148148
"""
149149
user, project = user_project
150150
if not user.ssh_public_key and not body.plan.run_spec.ssh_key_pub:
151-
await users.refresh_ssh_key(session=session, user=user)
151+
await users.refresh_ssh_key(session=session, actor=user)
152152
return CustomORJSONResponse(
153153
await runs.apply_plan(
154154
session=session,

src/dstack/_internal/server/routers/users.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ async def get_my_user(
4343
):
4444
if user.ssh_private_key is None or user.ssh_public_key is None:
4545
# Generate keys for pre-0.19.33 users
46-
await users.refresh_ssh_key(session=session, user=user)
46+
await users.refresh_ssh_key(session=session, actor=user)
4747
return CustomORJSONResponse(users.user_model_to_user_with_creds(user))
4848

4949

@@ -86,6 +86,7 @@ async def update_user(
8686
):
8787
res = await users.update_user(
8888
session=session,
89+
actor=user,
8990
username=body.username,
9091
global_role=body.global_role,
9192
email=body.email,
@@ -102,7 +103,7 @@ async def refresh_ssh_key(
102103
session: AsyncSession = Depends(get_session),
103104
user: UserModel = Depends(Authenticated()),
104105
):
105-
res = await users.refresh_ssh_key(session=session, user=user, username=body.username)
106+
res = await users.refresh_ssh_key(session=session, actor=user, username=body.username)
106107
if res is None:
107108
raise ResourceNotExistsError()
108109
return CustomORJSONResponse(users.user_model_to_user_with_creds(res))
@@ -114,7 +115,7 @@ async def refresh_token(
114115
session: AsyncSession = Depends(get_session),
115116
user: UserModel = Depends(Authenticated()),
116117
):
117-
res = await users.refresh_user_token(session=session, user=user, username=body.username)
118+
res = await users.refresh_user_token(session=session, actor=user, username=body.username)
118119
if res is None:
119120
raise ResourceNotExistsError()
120121
return CustomORJSONResponse(users.user_model_to_user_with_creds(res))
@@ -128,6 +129,6 @@ async def delete_users(
128129
):
129130
await users.delete_users(
130131
session=session,
131-
user=user,
132+
actor=user,
132133
usernames=body.users,
133134
)

src/dstack/_internal/server/services/projects.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,16 @@ async def update_project(
169169
project: ProjectModel,
170170
is_public: bool,
171171
):
172-
"""Update project visibility (public/private)."""
173-
project.is_public = is_public
172+
updated_fields = []
173+
if is_public != project.is_public:
174+
project.is_public = is_public
175+
updated_fields.append(f"is_public={is_public}")
176+
events.emit(
177+
session,
178+
f"Project updated. Updated fields: {', '.join(updated_fields) or '<none>'}",
179+
actor=events.UserActor.from_user(user),
180+
targets=[events.Target.from_model(project)],
181+
)
174182
await session.commit()
175183

176184

@@ -222,9 +230,14 @@ async def delete_projects(
222230
"deleted": True,
223231
}
224232
)
233+
events.emit(
234+
session,
235+
"Project deleted",
236+
actor=events.UserActor.from_user(user),
237+
targets=[events.Target.from_model(p)],
238+
)
225239
await session.execute(update(ProjectModel), updates)
226240
await session.commit()
227-
logger.info("Deleted projects %s by user %s", projects_names, user.name)
228241

229242

230243
async def set_project_members(

src/dstack/_internal/server/services/users.py

Lines changed: 131 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,19 @@
33
import re
44
import secrets
55
import uuid
6+
from collections.abc import AsyncGenerator
7+
from contextlib import asynccontextmanager
68
from typing import Awaitable, Callable, List, Optional, Tuple
79

8-
from sqlalchemy import delete, select, update
10+
from sqlalchemy import delete, select
911
from sqlalchemy import func as safunc
1012
from sqlalchemy.ext.asyncio import AsyncSession
1113
from sqlalchemy.orm import load_only
1214

13-
from dstack._internal.core.errors import ResourceExistsError, ServerClientError
15+
from dstack._internal.core.errors import (
16+
ResourceExistsError,
17+
ServerClientError,
18+
)
1419
from dstack._internal.core.models.users import (
1520
GlobalRole,
1621
User,
@@ -19,8 +24,10 @@
1924
UserTokenCreds,
2025
UserWithCreds,
2126
)
27+
from dstack._internal.server.db import get_db
2228
from dstack._internal.server.models import DecryptedString, MemberModel, UserModel
2329
from dstack._internal.server.services import events
30+
from dstack._internal.server.services.locking import get_locker
2431
from dstack._internal.server.services.permissions import get_default_permissions
2532
from dstack._internal.server.utils.routers import error_forbidden
2633
from dstack._internal.utils import crypto
@@ -123,114 +130,128 @@ async def create_user(
123130

124131
async def update_user(
125132
session: AsyncSession,
133+
actor: UserModel,
126134
username: str,
127135
global_role: GlobalRole,
128136
email: Optional[str] = None,
129137
active: bool = True,
130-
) -> UserModel:
131-
await session.execute(
132-
update(UserModel)
133-
.where(
134-
UserModel.name == username,
135-
UserModel.deleted == False,
136-
)
137-
.values(
138-
global_role=global_role,
139-
email=email,
140-
active=active,
138+
) -> Optional[UserModel]:
139+
async with get_user_model_by_name_for_update(session, username) as user:
140+
if user is None:
141+
return None
142+
updated_fields = []
143+
if global_role != user.global_role:
144+
user.global_role = global_role
145+
updated_fields.append(f"global_role={global_role}")
146+
if email != user.email:
147+
user.email = email
148+
updated_fields.append("email") # do not include potentially sensitive new value
149+
if active != user.active:
150+
user.active = active
151+
updated_fields.append(f"active={active}")
152+
events.emit(
153+
session,
154+
f"User updated. Updated fields: {', '.join(updated_fields) or '<none>'}",
155+
actor=events.UserActor.from_user(actor),
156+
targets=[events.Target.from_model(user)],
141157
)
142-
)
143-
await session.commit()
144-
return await get_user_model_by_name_or_error(session=session, username=username)
158+
await session.commit()
159+
return user
145160

146161

147162
async def refresh_ssh_key(
148163
session: AsyncSession,
149-
user: UserModel,
164+
actor: UserModel,
150165
username: Optional[str] = None,
151166
) -> Optional[UserModel]:
152167
if username is None:
153-
username = user.name
154-
logger.debug("Refreshing SSH key for user [code]%s[/code]", username)
155-
if user.global_role != GlobalRole.ADMIN and user.name != username:
168+
username = actor.name
169+
if actor.global_role != GlobalRole.ADMIN and actor.name != username:
156170
raise error_forbidden()
157-
private_bytes, public_bytes = await run_async(crypto.generate_rsa_key_pair_bytes, username)
158-
await session.execute(
159-
update(UserModel)
160-
.where(
161-
UserModel.name == username,
162-
UserModel.deleted == False,
163-
)
164-
.values(
165-
ssh_private_key=private_bytes.decode(),
166-
ssh_public_key=public_bytes.decode(),
171+
async with get_user_model_by_name_for_update(session, username) as user:
172+
if user is None:
173+
return None
174+
private_bytes, public_bytes = await run_async(crypto.generate_rsa_key_pair_bytes, username)
175+
user.ssh_private_key = private_bytes.decode()
176+
user.ssh_public_key = public_bytes.decode()
177+
events.emit(
178+
session,
179+
"User SSH key refreshed",
180+
actor=events.UserActor.from_user(actor),
181+
targets=[events.Target.from_model(user)],
167182
)
168-
)
169-
await session.commit()
170-
return await get_user_model_by_name(session=session, username=username)
183+
await session.commit()
184+
return user
171185

172186

173187
async def refresh_user_token(
174188
session: AsyncSession,
175-
user: UserModel,
189+
actor: UserModel,
176190
username: str,
177191
) -> Optional[UserModel]:
178-
if user.global_role != GlobalRole.ADMIN and user.name != username:
192+
if actor.global_role != GlobalRole.ADMIN and actor.name != username:
179193
raise error_forbidden()
180-
new_token = str(uuid.uuid4())
181-
await session.execute(
182-
update(UserModel)
183-
.where(
184-
UserModel.name == username,
185-
UserModel.deleted == False,
186-
)
187-
.values(
188-
token=DecryptedString(plaintext=new_token),
189-
token_hash=get_token_hash(new_token),
194+
async with get_user_model_by_name_for_update(session, username) as user:
195+
if user is None:
196+
return None
197+
new_token = str(uuid.uuid4())
198+
user.token = DecryptedString(plaintext=new_token)
199+
user.token_hash = get_token_hash(new_token)
200+
events.emit(
201+
session,
202+
"User token refreshed",
203+
actor=events.UserActor.from_user(actor),
204+
targets=[events.Target.from_model(user)],
190205
)
191-
)
192-
await session.commit()
193-
return await get_user_model_by_name(session=session, username=username)
206+
await session.commit()
207+
return user
194208

195209

196210
async def delete_users(
197211
session: AsyncSession,
198-
user: UserModel,
212+
actor: UserModel,
199213
usernames: List[str],
200214
):
201215
if _ADMIN_USERNAME in usernames:
202-
raise ServerClientError("User 'admin' cannot be deleted")
203-
204-
res = await session.execute(
205-
select(UserModel)
206-
.where(
207-
UserModel.name.in_(usernames),
208-
UserModel.deleted == False,
209-
)
210-
.options(load_only(UserModel.id, UserModel.name))
211-
)
212-
users = res.scalars().all()
213-
if len(users) != len(usernames):
214-
raise ServerClientError("Failed to delete non-existent users")
215-
216-
user_ids = [u.id for u in users]
217-
timestamp = str(int(get_current_datetime().timestamp()))
218-
updates = []
219-
for u in users:
220-
updates.append(
221-
{
222-
"id": u.id,
223-
"name": f"_deleted_{timestamp}_{secrets.token_hex(8)}",
224-
"original_name": u.name,
225-
"deleted": True,
226-
"active": False,
227-
}
216+
raise ServerClientError(f"User {_ADMIN_USERNAME!r} cannot be deleted")
217+
218+
filters = [
219+
UserModel.name.in_(usernames),
220+
UserModel.deleted == False,
221+
]
222+
res = await session.execute(select(UserModel.id).where(*filters))
223+
user_ids = list(res.scalars().all())
224+
user_ids.sort()
225+
226+
async with get_locker(get_db().dialect_name).lock_ctx(UserModel.__tablename__, user_ids):
227+
# Refetch after lock
228+
res = await session.execute(
229+
select(UserModel)
230+
.where(UserModel.id.in_(user_ids), *filters)
231+
.order_by(UserModel.id) # take locks in order
232+
.options(load_only(UserModel.id, UserModel.name))
233+
.with_for_update(key_share=True)
228234
)
229-
await session.execute(update(UserModel), updates)
230-
await session.execute(delete(MemberModel).where(MemberModel.user_id.in_(user_ids)))
231-
# Projects are not deleted automatically if owners are deleted.
232-
await session.commit()
233-
logger.info("Deleted users %s by user %s", usernames, user.name)
235+
users = list(res.scalars().all())
236+
if len(users) != len(usernames):
237+
raise ServerClientError("Failed to delete non-existent users")
238+
user_ids = [u.id for u in users]
239+
timestamp = str(int(get_current_datetime().timestamp()))
240+
for u in users:
241+
event_target = events.Target.from_model(u) # build target before renaming the user
242+
u.deleted = True
243+
u.active = False
244+
u.original_name = u.name
245+
u.name = f"_deleted_{timestamp}_{secrets.token_hex(8)}"
246+
events.emit(
247+
session,
248+
"User deleted",
249+
actor=events.UserActor.from_user(actor),
250+
targets=[event_target],
251+
)
252+
await session.execute(delete(MemberModel).where(MemberModel.user_id.in_(user_ids)))
253+
# Projects are not deleted automatically if owners are deleted.
254+
await session.commit()
234255

235256

236257
async def get_user_model_by_name(
@@ -257,6 +278,36 @@ async def get_user_model_by_name_or_error(
257278
)
258279

259280

281+
@asynccontextmanager
282+
async def get_user_model_by_name_for_update(
283+
session: AsyncSession, username: str
284+
) -> AsyncGenerator[Optional[UserModel], None]:
285+
"""
286+
Fetch the user from the database and lock it for update.
287+
288+
**NOTE**: commit changes to the database before exiting from this context manager,
289+
so that in-memory locks are only released after commit.
290+
"""
291+
292+
filters = [
293+
UserModel.name == username,
294+
UserModel.deleted == False,
295+
]
296+
res = await session.execute(select(UserModel.id).where(*filters))
297+
user_id = res.scalar_one_or_none()
298+
if user_id is None:
299+
yield None
300+
else:
301+
async with get_locker(get_db().dialect_name).lock_ctx(UserModel.__tablename__, [user_id]):
302+
# Refetch after lock
303+
res = await session.execute(
304+
select(UserModel)
305+
.where(UserModel.id.in_([user_id]), *filters)
306+
.with_for_update(key_share=True)
307+
)
308+
yield res.scalar_one_or_none()
309+
310+
260311
async def log_in_with_token(session: AsyncSession, token: str) -> Optional[UserModel]:
261312
token_hash = get_token_hash(token)
262313
res = await session.execute(

src/tests/_internal/server/routers/test_projects.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,16 @@ async def test_deletes_projects(
495495
await session.refresh(project2)
496496
assert project1.deleted
497497
assert not project2.deleted
498+
# Validate an event is emitted
499+
response = await client.post(
500+
"/api/events/list", headers=get_auth_headers(user.token), json={}
501+
)
502+
assert response.status_code == 200
503+
assert len(response.json()) == 1
504+
assert response.json()[0]["message"] == "Project deleted"
505+
assert len(response.json()[0]["targets"]) == 1
506+
assert response.json()[0]["targets"][0]["id"] == str(project1.id)
507+
assert response.json()[0]["targets"][0]["name"] == project_name
498508

499509
@pytest.mark.asyncio
500510
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)

0 commit comments

Comments
 (0)