33import re
44import secrets
55import uuid
6+ from collections .abc import AsyncGenerator
7+ from contextlib import asynccontextmanager
68from typing import Awaitable , Callable , List , Optional , Tuple
79
8- from sqlalchemy import delete , select , update
10+ from sqlalchemy import delete , select
911from sqlalchemy import func as safunc
1012from sqlalchemy .ext .asyncio import AsyncSession
1113from sqlalchemy .orm import load_only
1214
13- from dstack ._internal .core .errors import ResourceExistsError , ServerClientError
15+ from dstack ._internal .core .errors import (
16+ ResourceExistsError ,
17+ ServerClientError ,
18+ )
1419from dstack ._internal .core .models .users import (
1520 GlobalRole ,
1621 User ,
1924 UserTokenCreds ,
2025 UserWithCreds ,
2126)
27+ from dstack ._internal .server .db import get_db
2228from dstack ._internal .server .models import DecryptedString , MemberModel , UserModel
2329from dstack ._internal .server .services import events
30+ from dstack ._internal .server .services .locking import get_locker
2431from dstack ._internal .server .services .permissions import get_default_permissions
2532from dstack ._internal .server .utils .routers import error_forbidden
2633from dstack ._internal .utils import crypto
@@ -123,114 +130,128 @@ async def create_user(
123130
124131async def update_user (
125132 session : AsyncSession ,
133+ actor : UserModel ,
126134 username : str ,
127135 global_role : GlobalRole ,
128136 email : Optional [str ] = None ,
129137 active : bool = True ,
130- ) -> UserModel :
131- await session .execute (
132- update (UserModel )
133- .where (
134- UserModel .name == username ,
135- UserModel .deleted == False ,
136- )
137- .values (
138- global_role = global_role ,
139- email = email ,
140- active = active ,
138+ ) -> Optional [UserModel ]:
139+ async with get_user_model_by_name_for_update (session , username ) as user :
140+ if user is None :
141+ return None
142+ updated_fields = []
143+ if global_role != user .global_role :
144+ user .global_role = global_role
145+ updated_fields .append (f"global_role={ global_role } " )
146+ if email != user .email :
147+ user .email = email
148+ updated_fields .append ("email" ) # do not include potentially sensitive new value
149+ if active != user .active :
150+ user .active = active
151+ updated_fields .append (f"active={ active } " )
152+ events .emit (
153+ session ,
154+ f"User updated. Updated fields: { ', ' .join (updated_fields ) or '<none>' } " ,
155+ actor = events .UserActor .from_user (actor ),
156+ targets = [events .Target .from_model (user )],
141157 )
142- )
143- await session .commit ()
144- return await get_user_model_by_name_or_error (session = session , username = username )
158+ await session .commit ()
159+ return user
145160
146161
147162async def refresh_ssh_key (
148163 session : AsyncSession ,
149- user : UserModel ,
164+ actor : UserModel ,
150165 username : Optional [str ] = None ,
151166) -> Optional [UserModel ]:
152167 if username is None :
153- username = user .name
154- logger .debug ("Refreshing SSH key for user [code]%s[/code]" , username )
155- if user .global_role != GlobalRole .ADMIN and user .name != username :
168+ username = actor .name
169+ if actor .global_role != GlobalRole .ADMIN and actor .name != username :
156170 raise error_forbidden ()
157- private_bytes , public_bytes = await run_async (crypto .generate_rsa_key_pair_bytes , username )
158- await session .execute (
159- update (UserModel )
160- .where (
161- UserModel .name == username ,
162- UserModel .deleted == False ,
163- )
164- .values (
165- ssh_private_key = private_bytes .decode (),
166- ssh_public_key = public_bytes .decode (),
171+ async with get_user_model_by_name_for_update (session , username ) as user :
172+ if user is None :
173+ return None
174+ private_bytes , public_bytes = await run_async (crypto .generate_rsa_key_pair_bytes , username )
175+ user .ssh_private_key = private_bytes .decode ()
176+ user .ssh_public_key = public_bytes .decode ()
177+ events .emit (
178+ session ,
179+ "User SSH key refreshed" ,
180+ actor = events .UserActor .from_user (actor ),
181+ targets = [events .Target .from_model (user )],
167182 )
168- )
169- await session .commit ()
170- return await get_user_model_by_name (session = session , username = username )
183+ await session .commit ()
184+ return user
171185
172186
173187async def refresh_user_token (
174188 session : AsyncSession ,
175- user : UserModel ,
189+ actor : UserModel ,
176190 username : str ,
177191) -> Optional [UserModel ]:
178- if user .global_role != GlobalRole .ADMIN and user .name != username :
192+ if actor .global_role != GlobalRole .ADMIN and actor .name != username :
179193 raise error_forbidden ()
180- new_token = str (uuid .uuid4 ())
181- await session .execute (
182- update (UserModel )
183- .where (
184- UserModel .name == username ,
185- UserModel .deleted == False ,
186- )
187- .values (
188- token = DecryptedString (plaintext = new_token ),
189- token_hash = get_token_hash (new_token ),
194+ async with get_user_model_by_name_for_update (session , username ) as user :
195+ if user is None :
196+ return None
197+ new_token = str (uuid .uuid4 ())
198+ user .token = DecryptedString (plaintext = new_token )
199+ user .token_hash = get_token_hash (new_token )
200+ events .emit (
201+ session ,
202+ "User token refreshed" ,
203+ actor = events .UserActor .from_user (actor ),
204+ targets = [events .Target .from_model (user )],
190205 )
191- )
192- await session .commit ()
193- return await get_user_model_by_name (session = session , username = username )
206+ await session .commit ()
207+ return user
194208
195209
196210async def delete_users (
197211 session : AsyncSession ,
198- user : UserModel ,
212+ actor : UserModel ,
199213 usernames : List [str ],
200214):
201215 if _ADMIN_USERNAME in usernames :
202- raise ServerClientError ("User 'admin' cannot be deleted" )
203-
204- res = await session .execute (
205- select (UserModel )
206- .where (
207- UserModel .name .in_ (usernames ),
208- UserModel .deleted == False ,
209- )
210- .options (load_only (UserModel .id , UserModel .name ))
211- )
212- users = res .scalars ().all ()
213- if len (users ) != len (usernames ):
214- raise ServerClientError ("Failed to delete non-existent users" )
215-
216- user_ids = [u .id for u in users ]
217- timestamp = str (int (get_current_datetime ().timestamp ()))
218- updates = []
219- for u in users :
220- updates .append (
221- {
222- "id" : u .id ,
223- "name" : f"_deleted_{ timestamp } _{ secrets .token_hex (8 )} " ,
224- "original_name" : u .name ,
225- "deleted" : True ,
226- "active" : False ,
227- }
216+ raise ServerClientError (f"User { _ADMIN_USERNAME !r} cannot be deleted" )
217+
218+ filters = [
219+ UserModel .name .in_ (usernames ),
220+ UserModel .deleted == False ,
221+ ]
222+ res = await session .execute (select (UserModel .id ).where (* filters ))
223+ user_ids = list (res .scalars ().all ())
224+ user_ids .sort ()
225+
226+ async with get_locker (get_db ().dialect_name ).lock_ctx (UserModel .__tablename__ , user_ids ):
227+ # Refetch after lock
228+ res = await session .execute (
229+ select (UserModel )
230+ .where (UserModel .id .in_ (user_ids ), * filters )
231+ .order_by (UserModel .id ) # take locks in order
232+ .options (load_only (UserModel .id , UserModel .name ))
233+ .with_for_update (key_share = True )
228234 )
229- await session .execute (update (UserModel ), updates )
230- await session .execute (delete (MemberModel ).where (MemberModel .user_id .in_ (user_ids )))
231- # Projects are not deleted automatically if owners are deleted.
232- await session .commit ()
233- logger .info ("Deleted users %s by user %s" , usernames , user .name )
235+ users = list (res .scalars ().all ())
236+ if len (users ) != len (usernames ):
237+ raise ServerClientError ("Failed to delete non-existent users" )
238+ user_ids = [u .id for u in users ]
239+ timestamp = str (int (get_current_datetime ().timestamp ()))
240+ for u in users :
241+ event_target = events .Target .from_model (u ) # build target before renaming the user
242+ u .deleted = True
243+ u .active = False
244+ u .original_name = u .name
245+ u .name = f"_deleted_{ timestamp } _{ secrets .token_hex (8 )} "
246+ events .emit (
247+ session ,
248+ "User deleted" ,
249+ actor = events .UserActor .from_user (actor ),
250+ targets = [event_target ],
251+ )
252+ await session .execute (delete (MemberModel ).where (MemberModel .user_id .in_ (user_ids )))
253+ # Projects are not deleted automatically if owners are deleted.
254+ await session .commit ()
234255
235256
236257async def get_user_model_by_name (
@@ -257,6 +278,36 @@ async def get_user_model_by_name_or_error(
257278 )
258279
259280
281+ @asynccontextmanager
282+ async def get_user_model_by_name_for_update (
283+ session : AsyncSession , username : str
284+ ) -> AsyncGenerator [Optional [UserModel ], None ]:
285+ """
286+ Fetch the user from the database and lock it for update.
287+
288+ **NOTE**: commit changes to the database before exiting from this context manager,
289+ so that in-memory locks are only released after commit.
290+ """
291+
292+ filters = [
293+ UserModel .name == username ,
294+ UserModel .deleted == False ,
295+ ]
296+ res = await session .execute (select (UserModel .id ).where (* filters ))
297+ user_id = res .scalar_one_or_none ()
298+ if user_id is None :
299+ yield None
300+ else :
301+ async with get_locker (get_db ().dialect_name ).lock_ctx (UserModel .__tablename__ , [user_id ]):
302+ # Refetch after lock
303+ res = await session .execute (
304+ select (UserModel )
305+ .where (UserModel .id .in_ ([user_id ]), * filters )
306+ .with_for_update (key_share = True )
307+ )
308+ yield res .scalar_one_or_none ()
309+
310+
260311async def log_in_with_token (session : AsyncSession , token : str ) -> Optional [UserModel ]:
261312 token_hash = get_token_hash (token )
262313 res = await session .execute (
0 commit comments