@@ -13,14 +13,14 @@ def __init__(self, instance_id: str):
1313 self .lock_acquisition_pending = False
1414
1515 self .critical_section_id = None
16- self .critical_section_locks = []
17- self .available_locks = []
16+ self .critical_section_locks : list [ EntityInstanceId ] = []
17+ self .available_locks : list [ EntityInstanceId ] = []
1818
1919 @property
2020 def is_inside_critical_section (self ) -> bool :
2121 return self .critical_section_id is not None
2222
23- def get_available_entities (self ) -> Generator [str , None , None ]:
23+ def get_available_entities (self ) -> Generator [EntityInstanceId , None , None ]:
2424 if self .is_inside_critical_section :
2525 for available_lock in self .available_locks :
2626 yield available_lock
@@ -58,16 +58,27 @@ def recover_lock_after_call(self, target_instance_id: EntityInstanceId):
5858 self .available_locks .append (target_instance_id )
5959
6060 def emit_lock_release_messages (self ):
61- raise NotImplementedError ()
61+ if self .is_inside_critical_section :
62+ for entity_id in self .critical_section_locks :
63+ unlock_event = pb .SendEntityMessageAction (entityUnlockSent = pb .EntityUnlockSentEvent (
64+ criticalSectionId = self .critical_section_id ,
65+ targetInstanceId = get_string_value (str (entity_id ))
66+ ))
67+ yield unlock_event
68+
69+ # TODO: Emit the actual release messages (?)
70+ self .critical_section_locks = []
71+ self .available_locks = []
72+ self .critical_section_id = None
6273
6374 def emit_request_message (self , target , operation_name : str , one_way : bool , operation_id : str ,
6475 scheduled_time_utc : datetime , input : Optional [str ],
6576 request_time : Optional [datetime ] = None , create_trace : bool = False ):
6677 raise NotImplementedError ()
6778
68- def emit_acquire_message (self , critical_section_id : str , entities : List [EntityInstanceId ]) -> Union [Tuple [None , None , None ], Tuple [str , pb .SendEntityMessageAction , pb .OrchestrationInstance ]]:
79+ def emit_acquire_message (self , critical_section_id : str , entities : List [EntityInstanceId ]) -> Union [Tuple [None , None ], Tuple [pb .SendEntityMessageAction , pb .OrchestrationInstance ]]:
6980 if not entities :
70- return None , None , None
81+ return None , None
7182
7283 # Acquire the locks in a globally fixed order to avoid deadlocks
7384 # Also remove duplicates - this can be optimized for perf if necessary
@@ -81,12 +92,15 @@ def emit_acquire_message(self, critical_section_id: str, entities: List[EntityIn
8192 request = pb .SendEntityMessageAction (entityLockRequested = pb .EntityLockRequestedEvent (
8293 criticalSectionId = critical_section_id ,
8394 parentInstanceId = get_string_value (self .instance_id ),
84- lockSet = entity_ids_dedup ,
95+ lockSet = [ str ( eid ) for eid in entity_ids_dedup ] ,
8596 position = 0 ,
8697 ))
8798
88- return "op" , request , target
89-
99+ self .critical_section_id = critical_section_id
100+ self .critical_section_locks = entity_ids_dedup
101+ self .lock_acquisition_pending = True
102+
103+ return request , target
90104
91105 def complete_acquire (self , result , critical_section_id ):
92106 # TODO: HashSet or equivalent
0 commit comments