Skip to content

Commit 5ccb82f

Browse files
Add Middleware Support for OperationHandlers (#33)
* Add abstract base class and a no op implementation to enable task cancellation in operation handlers * fix some linter errors * Some PR feedback. Up min python version to 3.10 * Update some docs to more clearly highlight expected behavior of operation handlers and the potential race condition if but wait_until_.. and is_cancelled are used at the same time * Simple logging interceptor working with an InterceptedOperationHandler concept * Update test to confirm interceptors are applied in the order provided. Add test to confirm interceptors work for sync operation handlers * Do some renaming. Add some doc strings. remove type aliases that wound up not being very useful. Update sync test to force use of the executor. * Remove request_deadline as that's part of a different PR * remove some unused imports * Use public export in tests * Fix some linter errors * use cancellation in tests after rebasing to support new python * fix docstring errors * rename interceptor to middleware. Expose operation context to middleware * fix formatting and linter errors * Remove return repetitive types in OperationHandler.start. Make OperationHandlerMiddleware.intercept an abstract method. * Move deploy-docs to it's own workflow that runs on push to main * Fix workflow name in deploy-docs * export LazyValueT and Serializer from _serializer.py * remove the work 'docs' from the 'lint-test' job * Rename AwaitableOperationHandler to MiddlewareSafeOperationHandler * Run formatter * remove generic args in MiddlewareSafeOperationHandler since it by definition, must always be OperationHandler[Any,Any] * Finish removing generic args from MiddlewareSafeOperationHandler * Update old reference from 'interceptors' -> 'middleware' * Remove _all_ reference to interceptors
1 parent a55af1e commit 5ccb82f

8 files changed

Lines changed: 383 additions & 66 deletions

File tree

.github/workflows/ci.yml

Lines changed: 4 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@ name: CI
33
on:
44
pull_request:
55
push:
6-
branches: [ main ]
6+
branches:
7+
- main
78

89
jobs:
9-
lint-test-docs:
10+
lint-test:
1011
runs-on: ${{ matrix.os }}
1112
strategy:
1213
matrix:
13-
python-version: ['3.10', '3.13', '3.14']
14+
python-version: ['3.10', '3.14']
1415
os: [ubuntu-latest, macos-latest, windows-latest]
1516

1617
steps:
@@ -38,35 +39,3 @@ jobs:
3839
with:
3940
name: coverage-html-report-${{ matrix.os }}-${{ matrix.python-version }}
4041
path: coverage_html_report/
41-
42-
deploy-docs:
43-
runs-on: ubuntu-latest
44-
needs: lint-test-docs
45-
# TODO(preview): deploy on releases only
46-
permissions:
47-
contents: read
48-
pages: write
49-
id-token: write
50-
51-
steps:
52-
- name: Checkout repository
53-
uses: actions/checkout@v4
54-
55-
- name: Install uv
56-
uses: astral-sh/setup-uv@v6
57-
with:
58-
python-version: '3.10'
59-
60-
- name: Install dependencies
61-
run: uv sync
62-
63-
- name: Build API docs
64-
run: uv run poe docs
65-
66-
- name: Upload docs to GitHub Pages
67-
uses: actions/upload-pages-artifact@v3
68-
with:
69-
path: apidocs
70-
71-
- name: Deploy to GitHub Pages
72-
uses: actions/deploy-pages@v4

.github/workflows/deploy-docs.yml

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
name: Deploy Docs
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
8+
jobs:
9+
deploy-docs:
10+
runs-on: ubuntu-latest
11+
permissions:
12+
contents: read
13+
pages: write
14+
id-token: write
15+
16+
steps:
17+
- name: Checkout repository
18+
uses: actions/checkout@v4
19+
20+
- name: Install uv
21+
uses: astral-sh/setup-uv@v6
22+
with:
23+
python-version: '3.10'
24+
25+
- name: Install dependencies
26+
run: uv sync
27+
28+
- name: Build API docs
29+
run: uv run poe docs
30+
31+
- name: Upload docs to GitHub Pages
32+
uses: actions/upload-pages-artifact@v3
33+
with:
34+
path: apidocs
35+
36+
- name: Deploy to GitHub Pages
37+
uses: actions/deploy-pages@v4

src/nexusrpc/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
OperationErrorState,
2626
OutputT,
2727
)
28-
from ._serializer import Content, LazyValue
28+
from ._serializer import Content, LazyValue, LazyValueT, Serializer
2929
from ._service import Operation, OperationDefinition, ServiceDefinition, service
3030
from ._util import (
3131
get_operation,
@@ -42,12 +42,14 @@
4242
"HandlerErrorType",
4343
"InputT",
4444
"LazyValue",
45+
"LazyValueT",
4546
"Link",
4647
"Operation",
4748
"OperationDefinition",
4849
"OperationError",
4950
"OperationErrorState",
5051
"OutputT",
52+
"Serializer",
5153
"service",
5254
"ServiceDefinition",
5355
"set_operation",

src/nexusrpc/handler/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,19 @@
1818
StartOperationResultAsync,
1919
StartOperationResultSync,
2020
)
21-
from ._core import Handler as Handler
21+
from ._core import Handler, OperationHandlerMiddleware
2222
from ._decorators import operation_handler, service_handler, sync_operation
23-
from ._operation_handler import OperationHandler as OperationHandler
23+
from ._operation_handler import MiddlewareSafeOperationHandler, OperationHandler
2424

2525
__all__ = [
26+
"MiddlewareSafeOperationHandler",
2627
"CancelOperationContext",
2728
"Handler",
2829
"OperationContext",
2930
"OperationHandler",
3031
"OperationTaskCancellation",
32+
"OperationHandlerMiddleware",
33+
"operation_handler",
3134
"service_handler",
3235
"StartOperationContext",
3336
"StartOperationResultAsync",

src/nexusrpc/handler/_core.py

Lines changed: 93 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@
102102
from abc import ABC, abstractmethod
103103
from collections.abc import Awaitable, Mapping, Sequence
104104
from dataclasses import dataclass
105-
from typing import Any, Callable, Optional, Union
105+
from typing import Any, Callable, Optional, Union, cast
106106

107107
from typing_extensions import Self, TypeGuard
108108

@@ -113,11 +113,13 @@
113113

114114
from ._common import (
115115
CancelOperationContext,
116+
OperationContext,
116117
StartOperationContext,
117118
StartOperationResultAsync,
118119
StartOperationResultSync,
119120
)
120121
from ._operation_handler import (
122+
MiddlewareSafeOperationHandler,
121123
OperationHandler,
122124
collect_operation_handler_factories_by_method_name,
123125
)
@@ -248,7 +250,9 @@ def __init__(
248250
self,
249251
user_service_handlers: Sequence[Any],
250252
executor: Optional[concurrent.futures.Executor] = None,
253+
middleware: Sequence[OperationHandlerMiddleware] | None = None,
251254
):
255+
self._middleware = cast(Sequence[OperationHandlerMiddleware], middleware or [])
252256
super().__init__(user_service_handlers, executor=executor)
253257
if not self.executor:
254258
self._validate_all_operation_handlers_are_async()
@@ -268,17 +272,11 @@ async def start_operation(
268272
input: The input to the operation, as a LazyValue.
269273
"""
270274
service_handler = self._get_service_handler(ctx.service)
271-
op_handler = service_handler._get_operation_handler(ctx.operation) # pyright: ignore[reportPrivateUsage]
275+
op_handler = self._get_operation_handler(ctx, service_handler, ctx.operation)
276+
272277
op_defn = service_handler.service.operation_definitions[ctx.operation]
273278
deserialized_input = await input.consume(as_type=op_defn.input_type)
274-
# TODO(preview): apply middleware stack
275-
if is_async_callable(op_handler.start):
276-
return await op_handler.start(ctx, deserialized_input)
277-
else:
278-
assert self.executor
279-
return await self.executor.submit_to_event_loop(
280-
op_handler.start, ctx, deserialized_input
281-
)
279+
return await op_handler.start(ctx, deserialized_input)
282280

283281
async def cancel_operation(self, ctx: CancelOperationContext, token: str) -> None:
284282
"""Handle a Cancel Operation request.
@@ -288,12 +286,23 @@ async def cancel_operation(self, ctx: CancelOperationContext, token: str) -> Non
288286
token: The operation token.
289287
"""
290288
service_handler = self._get_service_handler(ctx.service)
291-
op_handler = service_handler._get_operation_handler(ctx.operation) # pyright: ignore[reportPrivateUsage]
292-
if is_async_callable(op_handler.cancel):
293-
return await op_handler.cancel(ctx, token)
294-
else:
295-
assert self.executor
296-
return self.executor.submit(op_handler.cancel, ctx, token).result()
289+
op_handler = self._get_operation_handler(ctx, service_handler, ctx.operation)
290+
return await op_handler.cancel(ctx, token)
291+
292+
def _get_operation_handler(
293+
self, ctx: OperationContext, service_handler: ServiceHandler, operation: str
294+
) -> MiddlewareSafeOperationHandler:
295+
"""
296+
Get the specified handler for the specified operation from the given service_handler and apply all middleware.
297+
"""
298+
op_handler: MiddlewareSafeOperationHandler = _EnsuredAwaitableOperationHandler(
299+
self.executor, service_handler.get_operation_handler(operation)
300+
)
301+
302+
for middleware in reversed(self._middleware):
303+
op_handler = middleware.intercept(ctx, op_handler)
304+
305+
return op_handler
297306

298307
def _validate_all_operation_handlers_are_async(self) -> None:
299308
for service_handler in self.service_handlers.values():
@@ -360,7 +369,7 @@ def from_user_instance(cls, user_instance: Any) -> Self:
360369
operation_handlers=op_handlers,
361370
)
362371

363-
def _get_operation_handler(self, operation_name: str) -> OperationHandler[Any, Any]:
372+
def get_operation_handler(self, operation_name: str) -> OperationHandler[Any, Any]:
364373
"""Return an operation handler, given the operation name."""
365374
if operation_name not in self.service.operation_definitions:
366375
raise HandlerError(
@@ -401,3 +410,70 @@ def submit(
401410
self, fn: Callable[..., Any], *args: Any
402411
) -> concurrent.futures.Future[Any]:
403412
return self._executor.submit(fn, *args)
413+
414+
415+
class OperationHandlerMiddleware(ABC):
416+
"""
417+
Middleware for operation handlers.
418+
419+
This should be extended by any operation handler middelware.
420+
"""
421+
422+
@abstractmethod
423+
def intercept(
424+
self,
425+
ctx: OperationContext, # type: ignore[reportUnusedParameter]
426+
next: MiddlewareSafeOperationHandler,
427+
) -> MiddlewareSafeOperationHandler:
428+
"""
429+
Method called for intercepting operation handlers.
430+
431+
Args:
432+
ctx: The :py:class:`OperationContext` that will be passed to the operation handler.
433+
next: The underlying operation handler that this middleware
434+
should delegate to.
435+
436+
Returns:
437+
The new middleware that will be used to invoke
438+
:py:attr:`OperationHandler.start` or :py:attr:`OperationHandler.cancel`.
439+
"""
440+
...
441+
442+
443+
class _EnsuredAwaitableOperationHandler(MiddlewareSafeOperationHandler):
444+
"""
445+
An :py:class:`AwaitableOperationHandler` that wraps an :py:class:`OperationHandler` and uses an :py:class:`_Executor` to ensure
446+
that the :py:attr:`start` and :py:attr:`cancel` methods are awaitable.
447+
"""
448+
449+
def __init__(
450+
self,
451+
executor: _Executor | None,
452+
op_handler: OperationHandler[Any, Any],
453+
):
454+
self._executor = executor
455+
self._op_handler = op_handler
456+
457+
async def start(
458+
self, ctx: StartOperationContext, input: Any
459+
) -> StartOperationResultSync[Any] | StartOperationResultAsync:
460+
"""
461+
Start the operation using the wrapped :py:class:`OperationHandler`.
462+
"""
463+
if is_async_callable(self._op_handler.start):
464+
return await self._op_handler.start(ctx, input)
465+
else:
466+
assert self._executor
467+
return await self._executor.submit_to_event_loop(
468+
self._op_handler.start, ctx, input
469+
)
470+
471+
async def cancel(self, ctx: CancelOperationContext, token: str) -> None:
472+
"""
473+
Cancel an operation using the wrapped :py:class:`OperationHandler`.
474+
"""
475+
if is_async_callable(self._op_handler.cancel):
476+
return await self._op_handler.cancel(ctx, token)
477+
else:
478+
assert self._executor
479+
return self._executor.submit(self._op_handler.cancel, ctx, token).result()

src/nexusrpc/handler/_operation_handler.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import inspect
44
from abc import ABC, abstractmethod
55
from collections.abc import Awaitable
6-
from typing import Any, Callable, Generic, Optional, Union
6+
from typing import Any, Callable, Generic, Optional
77

88
from nexusrpc._common import InputT, OutputT, ServiceHandlerT
99
from nexusrpc._service import Operation, OperationDefinition, ServiceDefinition
@@ -39,12 +39,11 @@ class OperationHandler(ABC, Generic[InputT, OutputT]):
3939
@abstractmethod
4040
def start(
4141
self, ctx: StartOperationContext, input: InputT
42-
) -> Union[
43-
StartOperationResultSync[OutputT],
44-
Awaitable[StartOperationResultSync[OutputT]],
45-
StartOperationResultAsync,
46-
Awaitable[StartOperationResultAsync],
47-
]:
42+
) -> (
43+
StartOperationResultSync[OutputT]
44+
| StartOperationResultAsync
45+
| Awaitable[StartOperationResultSync[OutputT] | StartOperationResultAsync]
46+
):
4847
"""
4948
Start the operation, completing either synchronously or asynchronously.
5049
@@ -54,9 +53,7 @@ def start(
5453
...
5554

5655
@abstractmethod
57-
def cancel(
58-
self, ctx: CancelOperationContext, token: str
59-
) -> Union[None, Awaitable[None]]:
56+
def cancel(self, ctx: CancelOperationContext, token: str) -> None | Awaitable[None]:
6057
"""
6158
Cancel the operation.
6259
"""
@@ -104,6 +101,31 @@ async def cancel(self, ctx: CancelOperationContext, token: str) -> None:
104101
)
105102

106103

104+
class MiddlewareSafeOperationHandler(OperationHandler[Any, Any], ABC):
105+
"""
106+
An :py:class:`OperationHandler` where :py:attr:`start` and :py:attr:`cancel`
107+
can be awaited by an async runtime. It can produce a result synchronously by returning
108+
:py:class:`StartOperationResultSync` or asynchronously by returning :py:class:`StartOperationResultAsync`
109+
in the same fashion that :py:class:`OperationHandler` does.
110+
"""
111+
112+
@abstractmethod
113+
async def start(
114+
self, ctx: StartOperationContext, input: Any
115+
) -> StartOperationResultSync[Any] | StartOperationResultAsync:
116+
"""
117+
Start the operation and return it's result or an async token.
118+
"""
119+
...
120+
121+
@abstractmethod
122+
async def cancel(self, ctx: CancelOperationContext, token: str) -> None:
123+
"""
124+
Cancel an in progress operation identified by the given token.
125+
"""
126+
...
127+
128+
107129
def collect_operation_handler_factories_by_method_name(
108130
user_service_cls: type[ServiceHandlerT],
109131
service: Optional[ServiceDefinition],

tests/handler/test_async_operation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
OperationHandler,
1010
StartOperationContext,
1111
StartOperationResultAsync,
12+
operation_handler,
1213
service_handler,
1314
)
14-
from nexusrpc.handler._decorators import operation_handler
1515
from tests.helpers import DummySerializer, TestOperationTaskCancellation
1616

1717
_operation_results: dict[str, int] = {}

0 commit comments

Comments
 (0)