Skip to content

Commit c62edb7

Browse files
committed
Signalling working, some other logic
1 parent 48e6ac1 commit c62edb7

9 files changed

Lines changed: 643 additions & 6 deletions

File tree

durabletask/client.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
import logging
55
import uuid
66
from dataclasses import dataclass
7-
from datetime import datetime
7+
from datetime import datetime, timezone
88
from enum import Enum
99
from typing import Any, Optional, Sequence, TypeVar, Union
1010

1111
import grpc
1212
from google.protobuf import wrappers_pb2
1313

14+
from durabletask.entities.entity_instance_id import EntityInstanceId
1415
import durabletask.internal.helpers as helpers
1516
import durabletask.internal.orchestrator_service_pb2 as pb
1617
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
@@ -227,3 +228,16 @@ def purge_orchestration(self, instance_id: str, recursive: bool = True):
227228
req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
228229
self._logger.info(f"Purging instance '{instance_id}'.")
229230
self._stub.PurgeInstances(req)
231+
232+
def signal_entity(self, entity_instance_id: EntityInstanceId, operation_name: str, input: Optional[Any] = None, signal_entity_options=None, cancellation=None):
233+
scheduled_time = signal_entity_options.scheduled_time if signal_entity_options and signal_entity_options.scheduled_time else None
234+
req = pb.SignalEntityRequest(
235+
instanceId=str(entity_instance_id),
236+
requestId=str(uuid.uuid4()),
237+
name=operation_name,
238+
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input else None,
239+
scheduledTime=scheduled_time,
240+
requestTime=helpers.new_timestamp(datetime.now(timezone.utc))
241+
)
242+
self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.")
243+
self._stub.SignalEntity(req, timeout=cancellation.timeout if cancellation else None)

durabletask/entities/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
"""Durable Task SDK for Python entities component"""
5+
6+
from durabletask.entities.entity_instance_id import EntityInstanceId
7+
8+
__all__ = ["EntityInstanceId"]
9+
10+
PACKAGE_NAME = "durabletask.entities"
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from typing import Optional
2+
3+
4+
class EntityInstanceId:
5+
def __init__(self, entity: str, key: str):
6+
self.entity = entity
7+
self.key = key
8+
9+
def __str__(self) -> str:
10+
return f"@{self.entity}@{self.key}"
11+
12+
def __eq__(self, other):
13+
if not isinstance(other, EntityInstanceId):
14+
return False
15+
return self.entity == other.entity and self.key == other.key
16+
17+
def __lt__(self, other):
18+
if not isinstance(other, EntityInstanceId):
19+
return self < other
20+
return str(self) < str(other)
21+
22+
@staticmethod
23+
def parse(entity_id: str) -> Optional["EntityInstanceId"]:
24+
try:
25+
_, entity, key = entity_id.split("@", 2)
26+
return EntityInstanceId(entity=entity, key=key)
27+
except ValueError as ex:
28+
raise ValueError("Invalid entity ID format", ex)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from durabletask.entities.entity_instance_id import EntityInstanceId
2+
3+
4+
class EntityLockReleaser:
5+
def __init__(self, entities: list[EntityInstanceId]):
6+
self.entities = entities
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from typing import Optional, Type
2+
3+
4+
class StateShim:
5+
def __init__(self, start_state):
6+
self._current_state = start_state
7+
self._checkpoint_state = start_state
8+
9+
def get_state(self, intended_type: Optional[Type]):
10+
if not intended_type:
11+
return self._current_state
12+
if isinstance(self._current_state, intended_type) or self._current_state is None:
13+
return self._current_state
14+
return intended_type(self._current_state)
15+
16+
def set_state(self, state):
17+
self._current_state = state
18+
19+
def commit(self):
20+
self._checkpoint_state = self._current_state
21+
22+
def rollback(self):
23+
self._current_state = self._checkpoint_state
24+
25+
def reset(self):
26+
self._current_state = None

durabletask/internal/helpers.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,12 @@ def get_string_value(val: Optional[str]) -> Optional[wrappers_pb2.StringValue]:
159159
return wrappers_pb2.StringValue(value=val)
160160

161161

162+
def get_string_value_or_empty(val: Optional[str]) -> wrappers_pb2.StringValue:
163+
if val is None:
164+
return wrappers_pb2.StringValue(value="")
165+
return wrappers_pb2.StringValue(value=val)
166+
167+
162168
def new_complete_orchestration_action(
163169
id: int,
164170
status: pb.OrchestrationStatus,
@@ -189,6 +195,52 @@ def new_schedule_task_action(id: int, name: str, encoded_input: Optional[str],
189195
))
190196

191197

198+
def new_call_entity_action(id: int, name: str, encoded_input: Optional[str]):
199+
return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityOperationCalled=pb.EntityOperationCalledEvent(
200+
requestId=None,
201+
targetInstanceId=get_string_value(name),
202+
input=get_string_value(encoded_input)
203+
)))
204+
205+
206+
def new_signal_entity_action(id: int, name: str):
207+
return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityOperationSignaled=pb.EntityOperationSignaledEvent(
208+
requestId=None,
209+
targetInstanceId=get_string_value(name)
210+
)))
211+
212+
213+
def new_lock_entities_action(id: int, instance_id: str, critical_section_id: str, entity_ids: list[str]):
214+
return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityLockRequested=pb.EntityLockRequestedEvent(
215+
parentInstanceId=get_string_value(instance_id),
216+
criticalSectionId=critical_section_id,
217+
lockSet=entity_ids,
218+
position=0
219+
)))
220+
221+
222+
def convert_to_entity_batch_request(req: pb.EntityRequest) -> tuple[pb.EntityBatchRequest, list[pb.OperationInfo]]:
223+
batch_request = pb.EntityBatchRequest(entityState=req.entityState, instanceId=req.instanceId, operations=[])
224+
225+
operation_infos: list[pb.OperationInfo] = []
226+
227+
for op in req.operationRequests:
228+
if op.HasField("entityOperationSignaled"):
229+
batch_request.operations.append(pb.OperationRequest(requestId=op.entityOperationSignaled.requestId,
230+
operation=op.entityOperationSignaled.operation,
231+
input=op.entityOperationSignaled.input))
232+
operation_infos.append(pb.OperationInfo(requestId=op.entityOperationSignaled.requestId,
233+
responseDestination=None))
234+
elif op.HasField("entityOperationCalled"):
235+
batch_request.operations.append(pb.OperationRequest(requestId=op.entityOperationCalled.requestId,
236+
operation=op.entityOperationCalled.operation,
237+
input=op.entityOperationCalled.input))
238+
operation_infos.append(pb.OperationInfo(requestId=op.entityOperationCalled.requestId,
239+
responseDestination=None))
240+
241+
return batch_request, operation_infos
242+
243+
192244
def new_timestamp(dt: datetime) -> timestamp_pb2.Timestamp:
193245
ts = timestamp_pb2.Timestamp()
194246
ts.FromDatetime(dt)
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from datetime import datetime
2+
from typing import Generator, List, Optional, Tuple
3+
4+
from durabletask.entities.entity_instance_id import EntityInstanceId
5+
6+
7+
class OrchestrationEntityContext:
8+
def __init__(self, instance_id: str):
9+
self.instance_id = instance_id
10+
11+
self.lock_acquisition_pending = False
12+
13+
self.critical_section_id = None
14+
self.critical_section_locks = []
15+
self.available_locks = []
16+
17+
@property
18+
def is_inside_critical_section(self) -> bool:
19+
return self.critical_section_id is not None
20+
21+
def get_available_entities(self) -> Generator[str, None, None]:
22+
if self.is_inside_critical_section:
23+
for available_lock in self.available_locks:
24+
yield available_lock
25+
26+
def validate_suborchestration_transition(self) -> Tuple[bool, str]:
27+
if self.is_inside_critical_section:
28+
return False, "While holding locks, cannot call suborchestrators."
29+
return True, ""
30+
31+
def validate_operation_transition(self, target_instance_id: EntityInstanceId, one_way: bool) -> Tuple[bool, str]:
32+
if self.is_inside_critical_section:
33+
lock_to_use = target_instance_id
34+
if one_way:
35+
if target_instance_id in self.critical_section_locks:
36+
return False, "Must not signal a locked entity from a critical section."
37+
else:
38+
try:
39+
self.available_locks.remove(lock_to_use)
40+
except ValueError:
41+
if self.lock_acquisition_pending:
42+
return False, "Must await the completion of the lock request prior to calling any entity."
43+
if lock_to_use in self.critical_section_locks:
44+
return False, "Must not call an entity from a critical section while a prior call to the same entity is still pending."
45+
else:
46+
return False, "Must not call an entity from a critical section if it is not one of the locked entities."
47+
return True, ""
48+
49+
def validate_acquire_transition(self) -> Tuple[bool, str]:
50+
if self.is_inside_critical_section:
51+
return False, "Must not enter another critical section from within a critical section."
52+
return True, ""
53+
54+
def recover_lock_after_call(self, target_instance_id: EntityInstanceId):
55+
if self.is_inside_critical_section:
56+
self.available_locks.append(target_instance_id)
57+
58+
def emit_lock_release_messages(self):
59+
raise NotImplementedError()
60+
61+
def emit_request_message(self, target, operation_name: str, one_way: bool, operation_id: str,
62+
scheduled_time_utc: datetime, input: Optional[str],
63+
request_time: Optional[datetime] = None, create_trace: bool = False):
64+
raise NotImplementedError()
65+
66+
def emit_acquire_message(self, lock_request_id: str, entities: List[str]):
67+
raise NotImplementedError()
68+
69+
def complete_acquire(self, result, critical_section_id):
70+
# TODO: HashSet or equivalent
71+
self.available_locks = self.critical_section_locks
72+
self.lock_acquisition_pending = False
73+
74+
def adjust_outgoing_message(self, instance_id: str, request_message, capped_time: datetime) -> str:
75+
raise NotImplementedError()
76+
77+
def deserialize_entity_response_event(self, event_content: str):
78+
raise NotImplementedError()

durabletask/task.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@
77
import math
88
from abc import ABC, abstractmethod
99
from datetime import datetime, timedelta
10-
from typing import Any, Callable, Generator, Generic, Optional, TypeVar, Union
10+
from typing import Any, Callable, Generator, Generic, Optional, Type, TypeVar, Union
1111

12+
from durabletask.entities.entity_instance_id import EntityInstanceId
13+
from durabletask.internal.entity_lock_releaser import EntityLockReleaser
14+
from durabletask.internal.entity_state_shim import StateShim
1215
import durabletask.internal.helpers as pbh
1316
import durabletask.internal.orchestrator_service_pb2 as pb
1417

@@ -137,6 +140,55 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *,
137140
"""
138141
pass
139142

143+
@abstractmethod
144+
def call_entity(self, entity: EntityInstanceId, *,
145+
input: Optional[TInput] = None):
146+
"""Schedule entity function for execution.
147+
148+
Parameters
149+
----------
150+
entity: EntityInstanceId
151+
The ID of the entity instance to call.
152+
input: Optional[TInput]
153+
The optional JSON-serializable input to pass to the entity function.
154+
155+
Returns
156+
-------
157+
Task
158+
A Durable Task that completes when the called entity function completes or fails.
159+
"""
160+
pass
161+
162+
@abstractmethod
163+
def signal_entity(
164+
self,
165+
entity_id: EntityInstanceId
166+
) -> None:
167+
"""Signal an entity function for execution.
168+
169+
Parameters
170+
----------
171+
entity_id: EntityInstanceId
172+
The ID of the entity instance to signal.
173+
"""
174+
pass
175+
176+
@abstractmethod
177+
def lock_entities(self, entities: list[EntityInstanceId]) -> EntityLockReleaser:
178+
"""Lock the specified entity instances for the duration of the orchestration.
179+
180+
Parameters
181+
----------
182+
entities: list[EntityInstanceId]
183+
The list of entity instance IDs to lock.
184+
185+
Returns
186+
-------
187+
EntityLockReleaser
188+
A context manager that releases the locks when disposed.
189+
"""
190+
pass
191+
140192
@abstractmethod
141193
def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *,
142194
input: Optional[TInput] = None,
@@ -452,12 +504,65 @@ def task_id(self) -> int:
452504
return self._task_id
453505

454506

507+
class EntityContext:
508+
def __init__(self, orchestration_id: str, operation: str, state: StateShim, entity_id: EntityInstanceId):
509+
self._orchestration_id = orchestration_id
510+
self._operation = operation
511+
self._state = state
512+
self._entity_id = entity_id
513+
514+
@property
515+
def orchestration_id(self) -> str:
516+
"""Get the ID of the orchestration instance that scheduled this entity.
517+
518+
Returns
519+
-------
520+
str
521+
The ID of the current orchestration instance.
522+
"""
523+
return self._orchestration_id
524+
525+
@property
526+
def operation(self) -> str:
527+
"""Get the operation associated with this entity invocation.
528+
529+
The operation is a string that identifies the specific action being
530+
performed on the entity. It can be used to distinguish between
531+
multiple operations that are part of the same entity invocation.
532+
533+
Returns
534+
-------
535+
str
536+
The operation associated with this entity invocation.
537+
"""
538+
return self._operation
539+
540+
def get_state(self, intended_type: Optional[Type] = None):
541+
return self._state.get_state(intended_type)
542+
543+
def set_state(self, new_state):
544+
self._state.set_state(new_state)
545+
546+
@property
547+
def entity_id(self) -> EntityInstanceId:
548+
"""Get the ID of the entity instance.
549+
550+
Returns
551+
-------
552+
str
553+
The ID of the current entity instance.
554+
"""
555+
return self._entity_id
556+
557+
455558
# Orchestrators are generators that yield tasks and receive/return any type
456559
Orchestrator = Callable[[OrchestrationContext, TInput], Union[Generator[Task, Any, Any], TOutput]]
457560

458561
# Activities are simple functions that can be scheduled by orchestrators
459562
Activity = Callable[[ActivityContext, TInput], TOutput]
460563

564+
Entity = Callable[[EntityContext, TInput], TOutput]
565+
461566

462567
class RetryPolicy:
463568
"""Represents the retry policy for an orchestration or activity function."""

0 commit comments

Comments
 (0)