Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit 4eb8a9f

Browse files
committed
share validating grpc options
Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com>
1 parent ed03c16 commit 4eb8a9f

5 files changed

Lines changed: 89 additions & 82 deletions

File tree

durabletask/aio/internal/shared.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
# Copyright (c) The Dapr Authors.
22
# Licensed under the MIT License.
33

4-
from typing import Any, Optional, Sequence, Union
4+
from typing import Any, Dict, Optional, Sequence, Union
55

66
import grpc
77
from grpc import aio as grpc_aio
8+
from grpc.aio import ChannelArgumentType
89

910
from durabletask.internal.shared import (
1011
INSECURE_PROTOCOLS,
1112
SECURE_PROTOCOLS,
1213
get_default_host_address,
14+
validate_grpc_options,
1315
)
1416

1517
ClientInterceptor = Union[
@@ -24,7 +26,7 @@ def get_grpc_aio_channel(
2426
host_address: Optional[str],
2527
secure_channel: bool = False,
2628
interceptors: Optional[Sequence[ClientInterceptor]] = None,
27-
options: Optional[Sequence[tuple[str, Any]]] = None,
29+
options: Optional[ChannelArgumentType] = None,
2830
) -> grpc_aio.Channel:
2931
"""create a grpc asyncio channel
3032
@@ -50,13 +52,9 @@ def get_grpc_aio_channel(
5052
break
5153

5254
# channel interceptors/options
53-
channel_kwargs = dict(interceptors=interceptors)
55+
channel_kwargs: Dict[str, ChannelArgumentType | Sequence[ClientInterceptor]] = dict(interceptors=interceptors)
5456
if options is not None:
55-
# validate all options keys prefix starts with `grpc.`
56-
if not all(key.startswith('grpc.') for key, _ in options):
57-
raise ValueError(
58-
f'All options keys must start with `grpc.`. Invalid options: {options}'
59-
)
57+
validate_grpc_options(options)
6058
channel_kwargs["options"] = options
6159

6260
if secure_channel:

durabletask/internal/shared.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Any, Optional, Sequence, Union
1010

1111
import grpc
12+
from grpc.aio import ChannelArgumentType
1213

1314
ClientInterceptor = Union[
1415
grpc.UnaryUnaryClientInterceptor,
@@ -50,6 +51,17 @@ def get_default_host_address() -> str:
5051
return "localhost:4001"
5152

5253

54+
def validate_grpc_options(options: ChannelArgumentType):
55+
"""Validate that all gRPC options are valid. Mainly checking keys. Values can be string, int, float, bool and pointer"""
56+
for key, value in options:
57+
if not isinstance(key, str):
58+
raise ValueError(f"gRPC option key must be a string. Invalid key: {key}")
59+
if not all(key.startswith("grpc.") for key, _ in options):
60+
raise ValueError(
61+
f"All options keys must start with `grpc.`. Invalid options: {options}"
62+
)
63+
64+
5365
def get_grpc_channel(
5466
host_address: Optional[str],
5567
secure_channel: bool = False,
@@ -84,10 +96,7 @@ def get_grpc_channel(
8496
# Create the base channel
8597
if options is not None:
8698
# validate all options keys prefix starts with `grpc.`
87-
if not all(key.startswith('grpc.') for key, _ in options):
88-
raise ValueError(
89-
f'All options keys must start with `grpc.`. Invalid options: {options}'
90-
)
99+
validate_grpc_options(options)
91100
if secure_channel:
92101
channel = grpc.secure_channel(
93102
host_address, grpc.ssl_channel_credentials(), options=options

tests/durabletask/test_grpc_aio_channel_options.py

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,34 +5,34 @@
55

66
from durabletask.aio.internal.shared import get_grpc_aio_channel
77

8-
HOST_ADDRESS = 'localhost:50051'
8+
HOST_ADDRESS = "localhost:50051"
99

1010

1111
def _find_option(options, key):
1212
for k, v in options:
1313
if k == key:
1414
return v
15-
raise AssertionError(f'Option with key {key} not found in options: {options}')
15+
raise AssertionError(f"Option with key {key} not found in options: {options}")
1616

1717

1818
def test_aio_channel_passes_base_options_and_max_lengths():
1919
base_options = [
20-
('grpc.max_send_message_length', 4321),
21-
('grpc.max_receive_message_length', 8765),
22-
('grpc.primary_user_agent', 'durabletask-aio-tests'),
20+
("grpc.max_send_message_length", 4321),
21+
("grpc.max_receive_message_length", 8765),
22+
("grpc.primary_user_agent", "durabletask-aio-tests"),
2323
]
24-
with patch('durabletask.aio.internal.shared.grpc_aio.insecure_channel') as mock_channel:
24+
with patch("durabletask.aio.internal.shared.grpc_aio.insecure_channel") as mock_channel:
2525
get_grpc_aio_channel(HOST_ADDRESS, False, options=base_options)
2626
# Ensure called with options kwarg
2727
assert mock_channel.call_count == 1
2828
args, kwargs = mock_channel.call_args
2929
assert args[0] == HOST_ADDRESS
30-
assert 'options' in kwargs
31-
opts = kwargs['options']
30+
assert "options" in kwargs
31+
opts = kwargs["options"]
3232
# Check our base options made it through
33-
assert ('grpc.max_send_message_length', 4321) in opts
34-
assert ('grpc.max_receive_message_length', 8765) in opts
35-
assert ('grpc.primary_user_agent', 'durabletask-aio-tests') in opts
33+
assert ("grpc.max_send_message_length", 4321) in opts
34+
assert ("grpc.max_receive_message_length", 8765) in opts
35+
assert ("grpc.primary_user_agent", "durabletask-aio-tests") in opts
3636

3737

3838
def test_aio_channel_merges_env_keepalive_and_retry(monkeypatch: pytest.MonkeyPatch):
@@ -42,53 +42,53 @@ def test_aio_channel_merges_env_keepalive_and_retry(monkeypatch: pytest.MonkeyPa
4242
initial_backoff_ms = 250
4343
max_backoff_ms = 2000
4444
backoff_multiplier = 1.5
45-
codes = ['RESOURCE_EXHAUSTED']
45+
codes = ["RESOURCE_EXHAUSTED"]
4646
service_config = {
47-
'methodConfig': [
47+
"methodConfig": [
4848
{
49-
'name': [{'service': ''}], # match all services/methods
50-
'retryPolicy': {
51-
'maxAttempts': max_attempts,
52-
'initialBackoff': f'{initial_backoff_ms / 1000.0}s',
53-
'maxBackoff': f'{max_backoff_ms / 1000.0}s',
54-
'backoffMultiplier': backoff_multiplier,
55-
'retryableStatusCodes': codes,
49+
"name": [{"service": ""}], # match all services/methods
50+
"retryPolicy": {
51+
"maxAttempts": max_attempts,
52+
"initialBackoff": f"{initial_backoff_ms / 1000.0}s",
53+
"maxBackoff": f"{max_backoff_ms / 1000.0}s",
54+
"backoffMultiplier": backoff_multiplier,
55+
"retryableStatusCodes": codes,
5656
},
5757
}
5858
]
5959
}
6060

61-
base_options = [('grpc.service_config', json.dumps(service_config))]
61+
base_options = [("grpc.service_config", json.dumps(service_config))]
6262

63-
with patch('durabletask.aio.internal.shared.grpc_aio.insecure_channel') as mock_channel:
63+
with patch("durabletask.aio.internal.shared.grpc_aio.insecure_channel") as mock_channel:
6464
get_grpc_aio_channel(HOST_ADDRESS, False, options=base_options)
6565

6666
args, kwargs = mock_channel.call_args
6767
assert args[0] == HOST_ADDRESS
68-
assert 'options' in kwargs
69-
opts = kwargs['options']
68+
assert "options" in kwargs
69+
opts = kwargs["options"]
7070

7171
# Retry service config present and parses correctly
72-
svc_cfg_str = _find_option(opts, 'grpc.service_config')
72+
svc_cfg_str = _find_option(opts, "grpc.service_config")
7373
svc_cfg = json.loads(svc_cfg_str)
74-
assert 'methodConfig' in svc_cfg and isinstance(svc_cfg['methodConfig'], list)
75-
retry_policy = svc_cfg['methodConfig'][0]['retryPolicy']
76-
assert retry_policy['maxAttempts'] == 4
77-
assert retry_policy['initialBackoff'] == f'{250 / 1000.0}s'
78-
assert retry_policy['maxBackoff'] == f'{2000 / 1000.0}s'
79-
assert retry_policy['backoffMultiplier'] == 1.5
74+
assert "methodConfig" in svc_cfg and isinstance(svc_cfg["methodConfig"], list)
75+
retry_policy = svc_cfg["methodConfig"][0]["retryPolicy"]
76+
assert retry_policy["maxAttempts"] == 4
77+
assert retry_policy["initialBackoff"] == f"{250 / 1000.0}s"
78+
assert retry_policy["maxBackoff"] == f"{2000 / 1000.0}s"
79+
assert retry_policy["backoffMultiplier"] == 1.5
8080
# Codes are upper-cased list
81-
assert 'RESOURCE_EXHAUSTED' in retry_policy['retryableStatusCodes']
81+
assert "RESOURCE_EXHAUSTED" in retry_policy["retryableStatusCodes"]
8282

8383

8484
def test_aio_secure_channel_receives_options_when_secure_true():
85-
base_options = [('grpc.max_receive_message_length', 999999)]
85+
base_options = [("grpc.max_receive_message_length", 999999)]
8686
with (
87-
patch('durabletask.aio.internal.shared.grpc_aio.secure_channel') as mock_channel,
88-
patch('grpc.ssl_channel_credentials') as mock_credentials,
87+
patch("durabletask.aio.internal.shared.grpc_aio.secure_channel") as mock_channel,
88+
patch("grpc.ssl_channel_credentials") as mock_credentials,
8989
):
9090
get_grpc_aio_channel(HOST_ADDRESS, True, options=base_options)
9191
args, kwargs = mock_channel.call_args
9292
assert args[0] == HOST_ADDRESS
9393
assert args[1] == mock_credentials.return_value
94-
assert ('grpc.max_receive_message_length', 999999) in kwargs.get('options', [])
94+
assert ("grpc.max_receive_message_length", 999999) in kwargs.get("options", [])
Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,38 @@
11
import json
2-
from unittest.mock import ANY, patch
2+
from unittest.mock import patch
33

44
import pytest
55

66
from durabletask.internal.shared import get_grpc_channel
77

8-
HOST_ADDRESS = 'localhost:50051'
8+
HOST_ADDRESS = "localhost:50051"
99

1010

1111
def _find_option(options, key):
1212
for k, v in options:
1313
if k == key:
1414
return v
15-
raise AssertionError(f'Option with key {key} not found in options: {options}')
15+
raise AssertionError(f"Option with key {key} not found in options: {options}")
1616

1717

1818
def test_sync_channel_passes_base_options_and_max_lengths():
1919
base_options = [
20-
('grpc.max_send_message_length', 1234),
21-
('grpc.max_receive_message_length', 5678),
22-
('grpc.primary_user_agent', 'durabletask-tests'),
20+
("grpc.max_send_message_length", 1234),
21+
("grpc.max_receive_message_length", 5678),
22+
("grpc.primary_user_agent", "durabletask-tests"),
2323
]
24-
with patch('grpc.insecure_channel') as mock_channel:
24+
with patch("grpc.insecure_channel") as mock_channel:
2525
get_grpc_channel(HOST_ADDRESS, False, options=base_options)
2626
# Ensure called with options kwarg
2727
assert mock_channel.call_count == 1
2828
args, kwargs = mock_channel.call_args
2929
assert args[0] == HOST_ADDRESS
30-
assert 'options' in kwargs
31-
opts = kwargs['options']
30+
assert "options" in kwargs
31+
opts = kwargs["options"]
3232
# Check our base options made it through
33-
assert ('grpc.max_send_message_length', 1234) in opts
34-
assert ('grpc.max_receive_message_length', 5678) in opts
35-
assert ('grpc.primary_user_agent', 'durabletask-tests') in opts
33+
assert ("grpc.max_send_message_length", 1234) in opts
34+
assert ("grpc.max_receive_message_length", 5678) in opts
35+
assert ("grpc.primary_user_agent", "durabletask-tests") in opts
3636

3737

3838
def test_sync_channel_merges_env_keepalive_and_retry(monkeypatch: pytest.MonkeyPatch):
@@ -42,40 +42,40 @@ def test_sync_channel_merges_env_keepalive_and_retry(monkeypatch: pytest.MonkeyP
4242
initial_backoff_ms = 250
4343
max_backoff_ms = 2000
4444
backoff_multiplier = 1.5
45-
codes = ['ABORTED']
45+
codes = ["ABORTED"]
4646
service_config = {
47-
'methodConfig': [
47+
"methodConfig": [
4848
{
49-
'name': [{'service': ''}], # match all services/methods
50-
'retryPolicy': {
51-
'maxAttempts': max_attempts,
52-
'initialBackoff': f'{initial_backoff_ms / 1000.0}s',
53-
'maxBackoff': f'{max_backoff_ms / 1000.0}s',
54-
'backoffMultiplier': backoff_multiplier,
55-
'retryableStatusCodes': codes,
49+
"name": [{"service": ""}], # match all services/methods
50+
"retryPolicy": {
51+
"maxAttempts": max_attempts,
52+
"initialBackoff": f"{initial_backoff_ms / 1000.0}s",
53+
"maxBackoff": f"{max_backoff_ms / 1000.0}s",
54+
"backoffMultiplier": backoff_multiplier,
55+
"retryableStatusCodes": codes,
5656
},
5757
}
5858
]
5959
}
6060

61-
base_options = [('grpc.service_config', json.dumps(service_config))]
61+
base_options = [("grpc.service_config", json.dumps(service_config))]
6262

63-
with patch('grpc.insecure_channel') as mock_channel:
63+
with patch("grpc.insecure_channel") as mock_channel:
6464
get_grpc_channel(HOST_ADDRESS, False, options=base_options)
6565

6666
args, kwargs = mock_channel.call_args
6767
assert args[0] == HOST_ADDRESS
68-
assert 'options' in kwargs
69-
opts = kwargs['options']
68+
assert "options" in kwargs
69+
opts = kwargs["options"]
7070

7171
# Retry service config present and parses correctly
72-
svc_cfg_str = _find_option(opts, 'grpc.service_config')
72+
svc_cfg_str = _find_option(opts, "grpc.service_config")
7373
svc_cfg = json.loads(svc_cfg_str)
74-
assert 'methodConfig' in svc_cfg and isinstance(svc_cfg['methodConfig'], list)
75-
retry_policy = svc_cfg['methodConfig'][0]['retryPolicy']
76-
assert retry_policy['maxAttempts'] == 4
77-
assert retry_policy['initialBackoff'] == f'{250 / 1000.0}s'
78-
assert retry_policy['maxBackoff'] == f'{2000 / 1000.0}s'
79-
assert retry_policy['backoffMultiplier'] == 1.5
74+
assert "methodConfig" in svc_cfg and isinstance(svc_cfg["methodConfig"], list)
75+
retry_policy = svc_cfg["methodConfig"][0]["retryPolicy"]
76+
assert retry_policy["maxAttempts"] == 4
77+
assert retry_policy["initialBackoff"] == f"{250 / 1000.0}s"
78+
assert retry_policy["maxBackoff"] == f"{2000 / 1000.0}s"
79+
assert retry_policy["backoffMultiplier"] == 1.5
8080
# Codes are upper-cased list
81-
assert 'ABORTED' in retry_policy['retryableStatusCodes']
81+
assert "ABORTED" in retry_policy["retryableStatusCodes"]

tests/durabletask/test_orchestration_e2e.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _):
3030
# set a custom max send length option
3131
c = client.TaskHubGrpcClient(
3232
channel_options=[
33-
('grpc.max_send_message_length', 1024 * 1024), # 1MB
33+
("grpc.max_send_message_length", 1024 * 1024), # 1MB
3434
]
3535
)
3636
id = c.schedule_new_orchestration(empty_orchestrator)

0 commit comments

Comments
 (0)