Skip to content

Commit 4107182

Browse files
committed
More entity implementing
1 parent e5b6db2 commit 4107182

7 files changed

Lines changed: 169 additions & 66 deletions

File tree

durabletask/entities/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
"""Durable Task SDK for Python entities component"""
55

66
from durabletask.entities.entity_instance_id import EntityInstanceId
7+
from durabletask.entities.durable_entity import DurableEntity
8+
from durabletask.entities.entity_lock import EntityLock
79

8-
__all__ = ["EntityInstanceId"]
10+
__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock"]
911

1012
PACKAGE_NAME = "durabletask.entities"
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from typing import Any, Optional, Type, TypeVar, overload
2+
3+
TState = TypeVar("TState")
4+
5+
6+
class DurableEntity:
7+
def _initialize_entity_context(self, context):
8+
self.entity_context = context
9+
10+
@overload
11+
def get_state(self, intended_type: Type[TState]) -> Optional[TState]: ...
12+
13+
@overload
14+
def get_state(self, intended_type: None = None) -> Any: ...
15+
16+
def get_state(self, intended_type: Optional[Type[TState]] = None) -> Optional[TState] | Any:
17+
return self.entity_context.get_state(intended_type)
18+
19+
def set_state(self, state: Any):
20+
self.entity_context.set_state(state)
Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
1+
import durabletask.internal.helpers as ph
2+
13
from durabletask.entities.entity_instance_id import EntityInstanceId
2-
from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext
4+
import durabletask.internal.orchestrator_service_pb2 as pb
35

46

57
class EntityLock:
6-
def __init__(self, entity_context: OrchestrationEntityContext, entities: list[EntityInstanceId]):
7-
self.entity_context = entity_context
8-
self.entities = entities
8+
def __init__(self, context):
9+
self._context = context
910

1011
def __enter__(self):
11-
print(f"Locking entities: {self.entities}")
12+
return self
1213

13-
def __exit__(self, exc_type, exc_val, exc_tb):
14-
print(f"Unlocking entities: {self.entities}")
14+
def __exit__(self, exc_type, exc_val, exc_tb): # TODO: Handle exceptions?
15+
print(f"Unlocking entities: {self._context._entity_context.critical_section_locks}")
16+
for entity_unlock_message in self._context._entity_context.emit_lock_release_messages():
17+
task_id = self._context.next_sequence_number()
18+
action = pb.OrchestratorAction(task_id, sendEntityMessage=entity_unlock_message)
19+
self._context._pending_actions[task_id] = action

durabletask/internal/entity_state_shim.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,36 @@
1-
from typing import Optional, Type
1+
from ctypes import Union
2+
from typing import Any, TypeVar, runtime_checkable
3+
from typing import Optional, Type, overload
4+
from typing_extensions import Protocol
5+
6+
7+
TState = TypeVar("TState")
28

39

410
class StateShim:
511
def __init__(self, start_state):
6-
self._current_state = start_state
7-
self._checkpoint_state = start_state
12+
self._current_state: Any = start_state
13+
self._checkpoint_state: Any = start_state
14+
15+
@overload
16+
def get_state(self, intended_type: Type[TState]) -> Optional[TState]: ...
17+
18+
@overload
19+
def get_state(self, intended_type: None = None) -> Any: ...
820

9-
def get_state(self, intended_type: Optional[Type]):
10-
if not intended_type:
21+
def get_state(self, intended_type: Optional[Type[TState]] = None) -> Optional[TState] | Any:
22+
if intended_type is None:
1123
return self._current_state
24+
1225
if isinstance(self._current_state, intended_type) or self._current_state is None:
1326
return self._current_state
14-
return intended_type(self._current_state)
27+
28+
try:
29+
return intended_type(self._current_state) # type: ignore[call-arg]
30+
except Exception as ex:
31+
raise TypeError(
32+
f"Could not convert state of type '{type(self._current_state).__name__}' to '{intended_type.__name__}'"
33+
) from ex
1534

1635
def set_state(self, state):
1736
self._current_state = state

durabletask/internal/orchestration_entity_context.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ def __init__(self, instance_id: str):
1313
self.lock_acquisition_pending = False
1414

1515
self.critical_section_id = None
16-
self.critical_section_locks = []
17-
self.available_locks = []
16+
self.critical_section_locks: list[EntityInstanceId] = []
17+
self.available_locks: list[EntityInstanceId] = []
1818

1919
@property
2020
def is_inside_critical_section(self) -> bool:
2121
return self.critical_section_id is not None
2222

23-
def get_available_entities(self) -> Generator[str, None, None]:
23+
def get_available_entities(self) -> Generator[EntityInstanceId, None, None]:
2424
if self.is_inside_critical_section:
2525
for available_lock in self.available_locks:
2626
yield available_lock
@@ -58,16 +58,27 @@ def recover_lock_after_call(self, target_instance_id: EntityInstanceId):
5858
self.available_locks.append(target_instance_id)
5959

6060
def emit_lock_release_messages(self):
61-
raise NotImplementedError()
61+
if self.is_inside_critical_section:
62+
for entity_id in self.critical_section_locks:
63+
unlock_event = pb.SendEntityMessageAction(entityUnlockSent=pb.EntityUnlockSentEvent(
64+
criticalSectionId=self.critical_section_id,
65+
targetInstanceId=get_string_value(str(entity_id))
66+
))
67+
yield unlock_event
68+
69+
# TODO: Emit the actual release messages (?)
70+
self.critical_section_locks = []
71+
self.available_locks = []
72+
self.critical_section_id = None
6273

6374
def emit_request_message(self, target, operation_name: str, one_way: bool, operation_id: str,
6475
scheduled_time_utc: datetime, input: Optional[str],
6576
request_time: Optional[datetime] = None, create_trace: bool = False):
6677
raise NotImplementedError()
6778

68-
def emit_acquire_message(self, critical_section_id: str, entities: List[EntityInstanceId]) -> Union[Tuple[None, None, None], Tuple[str, pb.SendEntityMessageAction, pb.OrchestrationInstance]]:
79+
def emit_acquire_message(self, critical_section_id: str, entities: List[EntityInstanceId]) -> Union[Tuple[None, None], Tuple[pb.SendEntityMessageAction, pb.OrchestrationInstance]]:
6980
if not entities:
70-
return None, None, None
81+
return None, None
7182

7283
# Acquire the locks in a globally fixed order to avoid deadlocks
7384
# Also remove duplicates - this can be optimized for perf if necessary
@@ -81,12 +92,15 @@ def emit_acquire_message(self, critical_section_id: str, entities: List[EntityIn
8192
request = pb.SendEntityMessageAction(entityLockRequested=pb.EntityLockRequestedEvent(
8293
criticalSectionId=critical_section_id,
8394
parentInstanceId=get_string_value(self.instance_id),
84-
lockSet=entity_ids_dedup,
95+
lockSet=[str(eid) for eid in entity_ids_dedup],
8596
position=0,
8697
))
8798

88-
return "op", request, target
89-
99+
self.critical_section_id = critical_section_id
100+
self.critical_section_locks = entity_ids_dedup
101+
self.lock_acquisition_pending = True
102+
103+
return request, target
90104

91105
def complete_acquire(self, result, critical_section_id):
92106
# TODO: HashSet or equivalent

durabletask/task.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,17 @@
77
import math
88
from abc import ABC, abstractmethod
99
from datetime import datetime, timedelta
10-
from typing import Any, Callable, Generator, Generic, Optional, Type, TypeVar, Union
10+
from typing import Any, Callable, Generator, Generic, Optional, Type, TypeVar, Union, overload
1111

12-
from durabletask.entities.entity_instance_id import EntityInstanceId
13-
from durabletask.entities.entity_lock import EntityLock
12+
from durabletask.entities import DurableEntity, EntityInstanceId, EntityLock
1413
from durabletask.internal.entity_state_shim import StateShim
1514
import durabletask.internal.helpers as pbh
1615
import durabletask.internal.orchestrator_service_pb2 as pb
1716

1817
T = TypeVar('T')
1918
TInput = TypeVar('TInput')
2019
TOutput = TypeVar('TOutput')
20+
TState = TypeVar("TState")
2121

2222

2323
class OrchestrationContext(ABC):
@@ -545,11 +545,17 @@ def operation(self) -> str:
545545
The operation associated with this entity invocation.
546546
"""
547547
return self._operation
548-
549-
def get_state(self, intended_type: Optional[Type] = None):
548+
549+
@overload
550+
def get_state(self, intended_type: Type[TState]) -> Optional[TState]: ...
551+
552+
@overload
553+
def get_state(self, intended_type: None = None) -> Any: ...
554+
555+
def get_state(self, intended_type: Optional[Type[TState]] = None) -> Optional[TState] | Any:
550556
return self._state.get_state(intended_type)
551557

552-
def set_state(self, new_state):
558+
def set_state(self, new_state: Any):
553559
self._state.set_state(new_state)
554560

555561
@property
@@ -570,7 +576,7 @@ def entity_id(self) -> EntityInstanceId:
570576
# Activities are simple functions that can be scheduled by orchestrators
571577
Activity = Callable[[ActivityContext, TInput], TOutput]
572578

573-
Entity = Callable[[EntityContext, TInput], TOutput]
579+
Entity = Union[Callable[[EntityContext, TInput], TOutput], type[DurableEntity]]
574580

575581

576582
class RetryPolicy:

0 commit comments

Comments
 (0)