@@ -370,94 +370,57 @@ async def apply_plan(
370370 spec = spec ,
371371 )
372372
373- current_fleet_model = await get_project_fleet_model_by_name (
373+ fleet_model = await get_project_fleet_model_by_name (
374374 session = session ,
375375 project = project ,
376376 name = configuration .name ,
377377 )
378- if current_fleet_model is None :
378+ if fleet_model is None :
379379 return await _create_fleet (
380380 session = session ,
381381 project = project ,
382382 user = user ,
383383 spec = spec ,
384384 )
385385
386- current_fleet = fleet_model_to_fleet (current_fleet_model )
387- _set_fleet_spec_defaults (current_fleet .spec )
388- current_fleet_sensitive = fleet_model_to_fleet (current_fleet_model , include_sensitive = True )
389- _set_fleet_spec_defaults (current_fleet_sensitive .spec )
390-
391- if not force :
392- if plan .current_resource is not None :
393- _set_fleet_spec_defaults (plan .current_resource .spec )
394- if (
395- plan .current_resource is None
396- or plan .current_resource .id != current_fleet .id
397- or plan .current_resource .spec != current_fleet .spec
398- ):
399- raise ServerClientError (
400- "Failed to apply plan. Resource has been changed. Try again or use force apply."
401- )
402-
403- _check_can_update_fleet_spec (current_fleet_sensitive .spec , spec )
404-
405- spec_json = spec .json ()
406- current_fleet_model .spec = spec_json
407-
408- if (
409- current_fleet_sensitive .spec .configuration .ssh_config is not None
410- and spec .configuration .ssh_config is not None
386+ instances_ids = sorted (i .id for i in fleet_model .instances if not i .deleted )
387+ await session .commit ()
388+ async with (
389+ get_locker (get_db ().dialect_name ).lock_ctx (FleetModel .__tablename__ , [fleet_model .id ]),
390+ get_locker (get_db ().dialect_name ).lock_ctx (InstanceModel .__tablename__ , instances_ids ),
411391 ):
412- added_hosts , removed_hosts , changed_hosts = _calculate_ssh_hosts_changes (
413- current = current_fleet_sensitive .spec .configuration .ssh_config .hosts ,
414- new = spec .configuration .ssh_config .hosts ,
392+ # Refetch after lock
393+ # TODO: Lock instances with FOR UPDATE?
394+ res = await session .execute (
395+ select (FleetModel )
396+ .where (
397+ FleetModel .project_id == project .id ,
398+ FleetModel .id == fleet_model .id ,
399+ FleetModel .deleted == False ,
400+ )
401+ .options (selectinload (FleetModel .instances ))
402+ .options (selectinload (FleetModel .runs ))
403+ .execution_options (populate_existing = True )
404+ .order_by (FleetModel .id ) # take locks in order
405+ .with_for_update (key_share = True )
415406 )
416- # `_check_can_update_fleet_spec` ensures hosts are not changed
417- assert not changed_hosts , changed_hosts
418- active_instance_nums : set [int ] = set ()
419- removed_instance_nums : list [int ] = []
420- if removed_hosts or added_hosts :
421- for instance_model in current_fleet_model .instances :
422- if instance_model .deleted :
423- continue
424- active_instance_nums .add (instance_model .instance_num )
425- rci = get_instance_remote_connection_info (instance_model )
426- if rci is None :
427- logger .error (
428- "Cloud instance %s in SSH fleet %s" ,
429- instance_model .id ,
430- current_fleet_model .id ,
431- )
432- continue
433- if rci .host in removed_hosts :
434- removed_instance_nums .append (instance_model .instance_num )
435- if added_hosts :
436- await _check_ssh_hosts_not_yet_added (session , spec , current_fleet .id )
437- for host in added_hosts .values ():
438- instance_num = _get_next_instance_num (active_instance_nums )
439- instance_model = await create_fleet_ssh_instance_model (
440- project = project ,
441- spec = spec ,
442- ssh_params = spec .configuration .ssh_config ,
443- env = spec .configuration .env ,
444- instance_num = instance_num ,
445- host = host ,
446- )
447- current_fleet_model .instances .append (instance_model )
448- active_instance_nums .add (instance_num )
449- if removed_instance_nums :
450- # Calls `session.commit()`, must be called last
451- await delete_fleets (
407+ fleet_model = res .scalars ().unique ().one_or_none ()
408+ if fleet_model is not None :
409+ return await _update_fleet (
452410 session = session ,
453411 project = project ,
454- user = user ,
455- names = [current_fleet .name ],
456- instance_nums = removed_instance_nums ,
412+ spec = spec ,
413+ current_resource = plan .current_resource ,
414+ force = force ,
415+ fleet_model = fleet_model ,
457416 )
458417
459- await session .commit ()
460- return fleet_model_to_fleet (current_fleet_model )
418+ return await _create_fleet (
419+ session = session ,
420+ project = project ,
421+ user = user ,
422+ spec = spec ,
423+ )
461424
462425
463426async def create_fleet (
@@ -734,6 +697,84 @@ async def _create_fleet(
734697 return fleet_model_to_fleet (fleet_model )
735698
736699
async def _update_fleet(
    session: AsyncSession,
    project: ProjectModel,
    spec: FleetSpec,
    current_resource: Optional[Fleet],
    force: bool,
    fleet_model: FleetModel,
) -> Fleet:
    """Apply `spec` to an already existing fleet and commit the session.

    Args:
        session: DB session used for the SSH host check and the final commit.
        project: Project passed on when creating instance models for new SSH hosts.
        spec: The new fleet spec; its JSON is stored on `fleet_model.spec`.
        current_resource: The fleet as seen when the plan was made, if any.
            Only consulted when `force` is false.
        force: When true, skip the "resource has been changed" check.
        fleet_model: The fleet row to update; mutated in place.

    Returns:
        The fleet converted from `fleet_model` after the commit.

    Raises:
        ServerClientError: If `force` is false and `current_resource` is
            missing or its id/spec differ from the stored fleet.
    """
    stored = fleet_model_to_fleet(fleet_model)
    _set_fleet_spec_defaults(stored.spec)
    stored_sensitive = fleet_model_to_fleet(fleet_model, include_sensitive=True)
    _set_fleet_spec_defaults(stored_sensitive.spec)

    if not force:
        # Without a plan snapshot there is nothing to compare against,
        # so the plan is treated as stale.
        plan_is_stale = current_resource is None
        if not plan_is_stale:
            # Normalize the snapshot the same way as the stored spec before comparing.
            _set_fleet_spec_defaults(current_resource.spec)
            plan_is_stale = (
                current_resource.id != stored.id or current_resource.spec != stored.spec
            )
        if plan_is_stale:
            raise ServerClientError(
                "Failed to apply plan. Resource has been changed. Try again or use force apply."
            )

    _check_can_update_fleet_spec(stored_sensitive.spec, spec)

    fleet_model.spec = spec.json()

    previous_ssh = stored_sensitive.spec.configuration.ssh_config
    if previous_ssh is not None and spec.configuration.ssh_config is not None:
        requested_ssh = spec.configuration.ssh_config
        hosts_to_add, hosts_to_remove, hosts_modified = _calculate_ssh_hosts_changes(
            current=previous_ssh.hosts,
            new=requested_ssh.hosts,
        )
        # `_check_can_update_fleet_spec` ensures hosts are not changed
        assert not hosts_modified, hosts_modified

        nums_in_use: set[int] = set()
        nums_to_terminate: list[int] = []
        if hosts_to_remove or hosts_to_add:
            live_instances = [m for m in fleet_model.instances if not m.deleted]
            for existing in live_instances:
                nums_in_use.add(existing.instance_num)
                conn_info = get_instance_remote_connection_info(existing)
                if conn_info is None:
                    # No remote connection info on an instance of an SSH fleet:
                    # report it and leave the instance alone.
                    logger.error(
                        "Cloud instance %s in SSH fleet %s",
                        existing.id,
                        fleet_model.id,
                    )
                elif conn_info.host in hosts_to_remove:
                    nums_to_terminate.append(existing.instance_num)

        if hosts_to_add:
            await _check_ssh_hosts_not_yet_added(session, spec, stored.id)
            for new_host in hosts_to_add.values():
                # Reserve each number right after use so that consecutive
                # new hosts do not receive the same instance_num.
                next_num = _get_next_instance_num(nums_in_use)
                created = await create_fleet_ssh_instance_model(
                    project=project,
                    spec=spec,
                    ssh_params=requested_ssh,
                    env=spec.configuration.env,
                    instance_num=next_num,
                    host=new_host,
                )
                fleet_model.instances.append(created)
                nums_in_use.add(next_num)

        if nums_to_terminate:
            _terminate_fleet_instances(fleet_model, nums_to_terminate)

    await session.commit()
    return fleet_model_to_fleet(fleet_model)
776+
777+
737778def _can_update_fleet_spec (current_fleet_spec : FleetSpec , new_fleet_spec : FleetSpec ) -> bool :
738779 try :
739780 _check_can_update_fleet_spec (current_fleet_spec , new_fleet_spec )
0 commit comments