Skip to content

Commit e5b6db2

Browse files
committed
Entity lock incremental change
1 parent 9a74b6d commit e5b6db2

6 files changed

Lines changed: 58 additions & 37 deletions

File tree

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from durabletask.entities.entity_instance_id import EntityInstanceId
2+
from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext
3+
4+
5+
class EntityLock:
6+
def __init__(self, entity_context: OrchestrationEntityContext, entities: list[EntityInstanceId]):
7+
self.entity_context = entity_context
8+
self.entities = entities
9+
10+
def __enter__(self):
11+
print(f"Locking entities: {self.entities}")
12+
13+
def __exit__(self, exc_type, exc_val, exc_tb):
14+
print(f"Unlocking entities: {self.entities}")

durabletask/internal/entity_lock_releaser.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

durabletask/internal/helpers.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -215,13 +215,8 @@ def new_signal_entity_action(id: int, entity_id: EntityInstanceId, operation: st
215215
)))
216216

217217

218-
def new_lock_entities_action(id: int, parent_instance_id: str, critical_section_id: str, entity_ids: list[EntityInstanceId]):
219-
return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityLockRequested=pb.EntityLockRequestedEvent(
220-
parentInstanceId=get_string_value(parent_instance_id),
221-
criticalSectionId=critical_section_id,
222-
lockSet=[str(eid) for eid in entity_ids],
223-
position=0
224-
)))
218+
def new_lock_entities_action(id: int, entity_message: pb.SendEntityMessageAction):
219+
return pb.OrchestratorAction(id=id, sendEntityMessage=entity_message)
225220

226221

227222
def convert_to_entity_batch_request(req: pb.EntityRequest) -> tuple[pb.EntityBatchRequest, list[pb.OperationInfo]]:

durabletask/internal/orchestration_entity_context.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from datetime import datetime
2-
from typing import Generator, List, Optional, Tuple
2+
from typing import Generator, List, Optional, Tuple, Union
33

4+
from durabletask.internal.helpers import get_string_value
5+
import durabletask.internal.orchestrator_service_pb2 as pb
46
from durabletask.entities.entity_instance_id import EntityInstanceId
57

68

@@ -63,8 +65,28 @@ def emit_request_message(self, target, operation_name: str, one_way: bool, opera
6365
request_time: Optional[datetime] = None, create_trace: bool = False):
6466
raise NotImplementedError()
6567

66-
def emit_acquire_message(self, lock_request_id: str, entities: List[str]):
67-
raise NotImplementedError()
68+
def emit_acquire_message(self, critical_section_id: str, entities: List[EntityInstanceId]) -> Union[Tuple[None, None, None], Tuple[str, pb.SendEntityMessageAction, pb.OrchestrationInstance]]:
69+
if not entities:
70+
return None, None, None
71+
72+
# Acquire the locks in a globally fixed order to avoid deadlocks
73+
# Also remove duplicates - this can be optimized for perf if necessary
74+
entity_ids = sorted(entities)
75+
entity_ids_dedup = []
76+
for i, entity_id in enumerate(entity_ids):
77+
if entity_id != entity_ids[i - 1] if i > 0 else True:
78+
entity_ids_dedup.append(entity_id)
79+
80+
target = pb.OrchestrationInstance(instanceId=str(entity_ids_dedup[0]))
81+
request = pb.SendEntityMessageAction(entityLockRequested=pb.EntityLockRequestedEvent(
82+
criticalSectionId=critical_section_id,
83+
parentInstanceId=get_string_value(self.instance_id),
84+
lockSet=entity_ids_dedup,
85+
position=0,
86+
))
87+
88+
return "op", request, target
89+
6890

6991
def complete_acquire(self, result, critical_section_id):
7092
# TODO: HashSet or equivalent

durabletask/task.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from typing import Any, Callable, Generator, Generic, Optional, Type, TypeVar, Union
1111

1212
from durabletask.entities.entity_instance_id import EntityInstanceId
13-
from durabletask.internal.entity_lock_releaser import EntityLockReleaser
13+
from durabletask.entities.entity_lock import EntityLock
1414
from durabletask.internal.entity_state_shim import StateShim
1515
import durabletask.internal.helpers as pbh
1616
import durabletask.internal.orchestrator_service_pb2 as pb
@@ -183,7 +183,7 @@ def signal_entity(
183183
pass
184184

185185
@abstractmethod
186-
def lock_entities(self, entities: list[EntityInstanceId]) -> EntityLockReleaser:
186+
def lock_entities(self, entities: list[EntityInstanceId]) -> EntityLock:
187187
"""Lock the specified entity instances for the duration of the orchestration.
188188
189189
Parameters
@@ -193,8 +193,8 @@ def lock_entities(self, entities: list[EntityInstanceId]) -> EntityLockReleaser:
193193
194194
Returns
195195
-------
196-
EntityLockReleaser
197-
A context manager that releases the locks when disposed.
196+
EntityLock
197+
A disposable object that acquires and releases the locks when initialized or disposed.
198198
"""
199199
pass
200200

durabletask/worker.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from enum import Enum
1515
from typing import Any, Generator, List, Optional, Sequence, TypeVar, Union
1616
import uuid
17+
from durabletask.entities.entity_lock import EntityLock
1718
from packaging.version import InvalidVersion, parse
1819

1920
import grpc
@@ -23,7 +24,6 @@
2324
from durabletask.internal.entity_state_shim import StateShim
2425
from durabletask.internal.helpers import new_timestamp
2526
from durabletask.entities.entity_instance_id import EntityInstanceId
26-
from durabletask.internal.entity_lock_releaser import EntityLockReleaser
2727
from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext
2828
import durabletask.internal.helpers as ph
2929
import durabletask.internal.exceptions as pe
@@ -994,7 +994,7 @@ def signal_entity(
994994
id, entity_id, operation, input
995995
)
996996

997-
def lock_entities(self, entities: list[EntityInstanceId]) -> EntityLockReleaser:
997+
def lock_entities(self, entities: list[EntityInstanceId]) -> EntityLock:
998998
id = self.next_sequence_number()
999999

10001000
self.lock_entities_function_helper(
@@ -1003,7 +1003,7 @@ def lock_entities(self, entities: list[EntityInstanceId]) -> EntityLockReleaser:
10031003

10041004
# Todo: EntityLockReleaser should be a disposable that uses python's using statement
10051005
# and should release the locks when disposed
1006-
return EntityLockReleaser(entities)
1006+
return EntityLock(self._entity_context, entities)
10071007

10081008
def call_sub_orchestrator(
10091009
self,
@@ -1129,26 +1129,22 @@ def lock_entities_function_helper(
11291129
id: Optional[int],
11301130
entity_ids: List[EntityInstanceId]
11311131
):
1132+
valid, message = self._entity_context.validate_acquire_transition()
1133+
if not valid:
1134+
raise RuntimeError(message)
1135+
11321136
if id is None:
11331137
id = self.next_sequence_number()
11341138

1135-
transition_valid, error_message = self._entity_context.validate_acquire_transition()
1136-
if not transition_valid:
1137-
raise RuntimeError(error_message)
1139+
# Use a deterministically replayable unique ID for this lock request
1140+
critical_section_id = f"{self.instance_id}:{id}"
11381141

1139-
# Acquire the locks in a globally fixed order to avoid deadlocks
1140-
# Also remove duplicates - this can be optimized for perf if necessary
1141-
entity_ids = sorted(entity_ids)
1142-
entity_ids_dedup = []
1143-
for i, entity_id in enumerate(entity_ids):
1144-
if entity_id != entity_ids[i - 1] if i > 0 else None:
1145-
entity_ids_dedup.append(entity_id)
1142+
event_name, request, target = self._entity_context.emit_acquire_message(critical_section_id, entity_ids)
11461143

1147-
# Use a deterministically replayable unique ID for this lock request
1148-
# TODO: Implement deterministically replayable IDs
1149-
critical_section_id = str(uuid.uuid4())
1144+
if not event_name or not request or not target:
1145+
raise RuntimeError("Failed to create entity lock request.")
11501146

1151-
action = ph.new_lock_entities_action(id, self.instance_id, critical_section_id, entity_ids_dedup)
1147+
action = ph.new_lock_entities_action(id, request)
11521148
self._pending_actions[id] = action
11531149

11541150
def wait_for_external_event(self, name: str) -> task.Task:

0 commit comments

Comments
 (0)