From 04a1c6d4334a4b5d618ae58ba6f84aca7ba5c635 Mon Sep 17 00:00:00 2001
From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com>
Date: Fri, 16 Jan 2026 20:46:03 -0800
Subject: [PATCH] Payload limit configuration and validation

---
 temporalio/bridge/worker.py    |   7 +-
 temporalio/client.py           |   2 +-
 temporalio/converter.py        |  74 +++++-
 temporalio/exceptions.py       |  25 ++
 temporalio/worker/_activity.py |  13 +
 temporalio/worker/_workflow.py |  16 ++
 tests/worker/test_workflow.py  | 473 +++++++++++++++++++++++++++++++++
 7 files changed, 596 insertions(+), 14 deletions(-)

diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py
index 732174732..c98afefca 100644
--- a/temporalio/bridge/worker.py
+++ b/temporalio/bridge/worker.py
@@ -315,7 +315,6 @@ async def encode_completion(
     encode_headers: bool,
 ) -> None:
     """Encode all payloads in the completion."""
-    if data_converter._encode_payload_has_effect:
-        await CommandAwarePayloadVisitor(
-            skip_search_attributes=True, skip_headers=not encode_headers
-        ).visit(_Visitor(data_converter._encode_payload_sequence), completion)
+    await CommandAwarePayloadVisitor(
+        skip_search_attributes=True, skip_headers=not encode_headers
+    ).visit(_Visitor(data_converter._encode_payload_sequence), completion)
diff --git a/temporalio/client.py b/temporalio/client.py
index 765f662fb..8c6877ad1 100644
--- a/temporalio/client.py
+++ b/temporalio/client.py
@@ -6837,7 +6837,7 @@ async def _apply_headers(
 ) -> None:
     if source is None:
         return
-    if encode_headers and data_converter._encode_payload_has_effect:
+    if encode_headers:
         for payload in source.values():
             payload.CopyFrom(await data_converter._encode_payload(payload))
     temporalio.common._apply_headers(source, dest)
diff --git a/temporalio/converter.py b/temporalio/converter.py
index 66b31167c..fafb76260 100644
--- a/temporalio/converter.py
+++ b/temporalio/converter.py
@@ -1207,6 +1207,20 @@ def __init__(self) -> None:
         super().__init__(encode_common_attributes=True)
 
 
+@dataclass(frozen=True)
+class PayloadLimitsConfig:
+    """Configuration for warning and error thresholds on payload sizes."""
+
+    memo_upload_error_limit: int | None = None
+    """Total memo size in bytes above which an error is raised. ``None`` disables the check."""
+    memo_upload_warning_limit: int = 2 * 1024
+    """Total memo size in bytes above which a warning is issued."""
+    payload_upload_error_limit: int | None = None
+    """Total payloads size in bytes above which an error is raised. ``None`` disables the check."""
+    payload_upload_warning_limit: int = 512 * 1024
+    """Total payloads size in bytes above which a warning is issued."""
+
+
 @dataclass(frozen=True)
 class DataConverter(WithSerializationContext):
     """Data converter for converting and encoding payloads to/from Python values.
@@ -1230,6 +1244,9 @@ class DataConverter(WithSerializationContext):
     failure_converter: FailureConverter = dataclasses.field(init=False)
     """Failure converter created from the :py:attr:`failure_converter_class`."""
 
+    payload_limits: PayloadLimitsConfig = PayloadLimitsConfig()
+    """Settings for payload size limits."""
+
     default: ClassVar[DataConverter]
     """Singleton default data converter."""
 
@@ -1367,35 +1384,42 @@ async def _encode_memo(
     async def _encode_memo_existing(
         self, source: Mapping[str, Any], memo: temporalio.api.common.v1.Memo
     ):
+        payloads = []
         for k, v in source.items():
             payload = v
             if not isinstance(v, temporalio.api.common.v1.Payload):
                 payload = (await self.encode([v]))[0]
             memo.fields[k].CopyFrom(payload)
+            payloads.append(payload)
+        # Memos have their field payloads validated all together in one unit
+        self._validate_limits(
+            payloads,
+            self.payload_limits.memo_upload_error_limit,
+            self.payload_limits.memo_upload_warning_limit,
+            "Memo size exceeded the warning limit.",
+        )
 
     async def _encode_payload(
         self, payload: temporalio.api.common.v1.Payload
     ) -> temporalio.api.common.v1.Payload:
         if self.payload_codec:
             payload = (await self.payload_codec.encode([payload]))[0]
+        self._validate_payload_limits([payload])
         return payload
 
     async def _encode_payloads(self, payloads: temporalio.api.common.v1.Payloads):
         if self.payload_codec:
             await self.payload_codec.encode_wrapper(payloads)
+        self._validate_payload_limits(payloads.payloads)
 
     async def _encode_payload_sequence(
         self, payloads: Sequence[temporalio.api.common.v1.Payload]
     ) -> list[temporalio.api.common.v1.Payload]:
-        if not self.payload_codec:
-            return list(payloads)
-        return await self.payload_codec.encode(payloads)
-
-    # Temporary shortcircuit detection while the _encode_* methods may no-op if
-    # a payload codec is not configured. Remove once those paths have more to them.
-    @property
-    def _encode_payload_has_effect(self) -> bool:
-        return self.payload_codec is not None
+        encoded_payloads = list(payloads)
+        if self.payload_codec:
+            encoded_payloads = await self.payload_codec.encode(encoded_payloads)
+        self._validate_payload_limits(encoded_payloads)
+        return encoded_payloads
 
     async def _decode_payload(
         self, payload: temporalio.api.common.v1.Payload
@@ -1452,6 +1476,38 @@ async def _apply_to_failure_payloads(
         if failure.HasField("cause"):
             await DataConverter._apply_to_failure_payloads(failure.cause, cb)
 
+    def _validate_payload_limits(
+        self,
+        payloads: Sequence[temporalio.api.common.v1.Payload],
+    ):
+        self._validate_limits(
+            payloads,
+            self.payload_limits.payload_upload_error_limit,
+            self.payload_limits.payload_upload_warning_limit,
+            "Payloads size exceeded the warning limit.",
+        )
+
+    def _validate_limits(
+        self,
+        payloads: Sequence[temporalio.api.common.v1.Payload],
+        error_limit: int | None,
+        warning_limit: int,
+        warning_message: str,
+    ):
+        total_size = sum(payload.ByteSize() for payload in payloads)
+
+        if error_limit and error_limit > 0 and total_size > error_limit:
+            raise temporalio.exceptions.PayloadSizeError(
+                size=total_size,
+                limit=error_limit,
+            )
+
+        if warning_limit and warning_limit > 0 and total_size > warning_limit:
+            # TODO: Use a context-aware logger to log extra information about the workflow/activity/etc.
+            warnings.warn(
+                f"{warning_message} Size: {total_size} bytes, Limit: {warning_limit} bytes"
+            )
+
+
 DefaultPayloadConverter.default_encoding_payload_converters = (
     BinaryNullPayloadConverter(),
diff --git a/temporalio/exceptions.py b/temporalio/exceptions.py
index f8f8ca20c..96a644a61 100644
--- a/temporalio/exceptions.py
+++ b/temporalio/exceptions.py
@@ -446,3 +446,28 @@ def is_cancelled_exception(exception: BaseException) -> bool:
             and isinstance(exception.cause, CancelledError)
         )
     )
+
+
+class PayloadSizeError(TemporalError):
+    """Error raised when the total payloads size exceeds the configured limit."""
+
+    def __init__(self, size: int, limit: int):
+        """Initialize a payload size error.
+
+        Args:
+            size: Actual payloads size in bytes.
+            limit: Payloads size limit in bytes.
+        """
+        super().__init__("Payloads size exceeded the error limit")
+        self._size = size
+        self._limit = limit
+
+    @property
+    def payloads_size(self) -> int:
+        """Actual payloads size in bytes."""
+        return self._size
+
+    @property
+    def payloads_limit(self) -> int:
+        """Payloads size limit in bytes."""
+        return self._limit
diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py
index 28b434e59..cf9b90d56 100644
--- a/temporalio/worker/_activity.py
+++ b/temporalio/worker/_activity.py
@@ -380,6 +380,19 @@ async def _handle_start_activity_task(
                     temporalio.exceptions.CancelledError("Cancelled"),
                     completion.result.cancelled.failure,
                 )
+            elif isinstance(
+                err,
+                temporalio.exceptions.PayloadSizeError,
+            ):
+                temporalio.activity.logger.warning(
+                    "Activity task failed: payloads size exceeded the error limit. Size: %d bytes, Limit: %d bytes",
+                    err.payloads_size,
+                    err.payloads_limit,
+                    extra={"__temporal_error_identifier": "ActivityFailure"},
+                )
+                await data_converter.encode_failure(
+                    err, completion.result.failed.failure
+                )
             else:
                 if (
                     isinstance(
diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py
index 40b72286e..74ffaba1d 100644
--- a/temporalio/worker/_workflow.py
+++ b/temporalio/worker/_workflow.py
@@ -23,6 +23,7 @@
 import temporalio.converter
 import temporalio.exceptions
 import temporalio.workflow
+from temporalio.api.enums.v1 import WorkflowTaskFailedCause
 from temporalio.bridge.worker import PollShutdownError
 
 from . import _command_aware_visitor
@@ -372,6 +373,21 @@ async def _handle_activation(
                 data_converter,
                 encode_headers=self._encode_headers,
             )
+        except temporalio.exceptions.PayloadSizeError as err:
+            # TODO: Would like to use temporalio.workflow.logger here, but
+            # that requires being in the workflow event loop. Possibly refactor
+            # the logger core functionality into a shareable class and update
+            # LoggerAdapter to be a decorator.
+            logger.warning(
+                "Workflow task failed: payloads size exceeded the error limit. Size: %d bytes, Limit: %d bytes",
+                err.payloads_size,
+                err.payloads_limit,
+            )
+            completion.failed.Clear()
+            await data_converter.encode_failure(err, completion.failed.failure)
+            completion.failed.force_cause = (
+                WorkflowTaskFailedCause.WORKFLOW_TASK_FAILED_CAUSE_PAYLOADS_TOO_LARGE
+            )
         except Exception as err:
             logger.exception(
                 "Failed encoding completion on workflow with run ID %s", act.run_id
diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py
index b597a85ab..7f86d81f2 100644
--- a/tests/worker/test_workflow.py
+++ b/tests/worker/test_workflow.py
@@ -15,6 +15,7 @@
 import time
 import typing
 import uuid
+import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Awaitable, Mapping, Sequence
 from dataclasses import dataclass
@@ -88,6 +89,7 @@
     DefaultPayloadConverter,
     PayloadCodec,
     PayloadConverter,
+    PayloadLimitsConfig,
 )
 from temporalio.exceptions import (
     ActivityError,
@@ -95,18 +97,23 @@
     ApplicationErrorCategory,
     CancelledError,
     ChildWorkflowError,
+    PayloadSizeError,
     TemporalError,
     TimeoutError,
+    TimeoutType,
     WorkflowAlreadyStartedError,
 )
 from temporalio.runtime import (
     BUFFERED_METRIC_KIND_COUNTER,
     BUFFERED_METRIC_KIND_HISTOGRAM,
+    LogForwardingConfig,
+    LoggingConfig,
     MetricBuffer,
     MetricBufferDurationFormat,
     PrometheusConfig,
     Runtime,
     TelemetryConfig,
+    TelemetryFilter,
 )
 from temporalio.service import RPCError, RPCStatusCode, __version__
 from temporalio.testing import WorkflowEnvironment
@@ -8485,3 +8492,469 @@ async def test_disable_logger_sandbox(
         run_timeout=timedelta(seconds=1),
         retry_policy=RetryPolicy(maximum_attempts=1),
     )
+
+
+@dataclass
+class LargePayloadWorkflowInput:
+    activity_input_data_size: int
+    activity_output_data_size: int
+    workflow_output_data_size: int
+    data: list[int]
+
+
+@dataclass
+class LargePayloadWorkflowOutput:
+    data: list[int]
+
+
+@dataclass
+class LargePayloadActivityInput:
+    output_data_size: int
+    data: list[int]
+
+
+@dataclass
+class LargePayloadActivityOutput:
+    data: list[int]
+
+
+@activity.defn
+async def large_payload_activity(
+    input: LargePayloadActivityInput,
+) -> LargePayloadActivityOutput:
+    return LargePayloadActivityOutput(data=[0] * input.output_data_size)
+
+
+@workflow.defn
+class LargePayloadWorkflow:
+    @workflow.run
+    async def run(self, input: LargePayloadWorkflowInput) -> LargePayloadWorkflowOutput:
+        await workflow.execute_activity(
+            large_payload_activity,
+            LargePayloadActivityInput(
+                output_data_size=input.activity_output_data_size,
+                data=[0] * input.activity_input_data_size,
+            ),
+            schedule_to_close_timeout=timedelta(seconds=5),
+        )
+        return LargePayloadWorkflowOutput(data=[0] * input.workflow_output_data_size)
+
+
+async def test_large_payload_error_workflow_input(client: Client):
+    config = client.config()
+    error_limit = 5 * 1024
+    config["data_converter"] = dataclasses.replace(
+        temporalio.converter.default(),
+        payload_limits=PayloadLimitsConfig(
+            payload_upload_error_limit=error_limit, payload_upload_warning_limit=1024
+        ),
+    )
+    client = Client(**config)
+
+    with pytest.raises(PayloadSizeError) as err:
+        await client.execute_workflow(
+            LargePayloadWorkflow.run,
+            LargePayloadWorkflowInput(
+                activity_input_data_size=0,
+                activity_output_data_size=0,
+                workflow_output_data_size=0,
+                data=[0] * 6 * 1024,
+            ),
+            id=f"workflow-{uuid.uuid4()}",
+            task_queue="test-queue",
+        )
+
+    assert error_limit == err.value.payloads_limit
+
+
+async def test_large_payload_error_workflow_memo(client: Client):
+    config = client.config()
+    error_limit = 128
+    config["data_converter"] = dataclasses.replace(
+        temporalio.converter.default(),
+        payload_limits=PayloadLimitsConfig(memo_upload_error_limit=error_limit),
+    )
+    client = Client(**config)
+
+    with pytest.raises(PayloadSizeError) as err:
+        await client.execute_workflow(
+            LargePayloadWorkflow.run,
+            LargePayloadWorkflowInput(
+                activity_input_data_size=0,
+                activity_output_data_size=0,
+                workflow_output_data_size=0,
+                data=[],
+            ),
+            id=f"workflow-{uuid.uuid4()}",
+            task_queue="test-queue",
+            memo={"key1": [0] * 256},
+        )
+
+    assert error_limit == err.value.payloads_limit
+
+
+async def test_large_payload_warning_workflow_input(client: Client):
+    config = client.config()
+    config["data_converter"] = dataclasses.replace(
+        temporalio.converter.default(),
+        payload_limits=PayloadLimitsConfig(
+            payload_upload_error_limit=5 * 1024, payload_upload_warning_limit=1024
+        ),
+    )
+    client = Client(**config)
+
+    with warnings.catch_warnings(record=True) as w:
+        async with new_worker(
+            client, LargePayloadWorkflow, activities=[large_payload_activity]
+        ) as worker:
+            await client.execute_workflow(
+                LargePayloadWorkflow.run,
+                LargePayloadWorkflowInput(
+                    activity_input_data_size=0,
+                    activity_output_data_size=0,
+                    workflow_output_data_size=0,
+                    data=[0] * 2 * 1024,
+                ),
+                id=f"workflow-{uuid.uuid4()}",
+                task_queue=worker.task_queue,
+            )
+
+    assert len(w) == 1
+    assert issubclass(w[-1].category, UserWarning)
+    assert "Payloads size exceeded the warning limit" in str(w[-1].message)
+
+
+async def test_large_payload_warning_workflow_memo(client: Client):
+    config = client.config()
+    config["data_converter"] = dataclasses.replace(
+        temporalio.converter.default(),
+        payload_limits=PayloadLimitsConfig(payload_upload_warning_limit=128),
+    )
+    client = Client(**config)
+
+    with warnings.catch_warnings(record=True) as w:
+        async with new_worker(
+            client, LargePayloadWorkflow, activities=[large_payload_activity]
+        ) as worker:
+            await client.execute_workflow(
+                LargePayloadWorkflow.run,
+                LargePayloadWorkflowInput(
+                    activity_input_data_size=0,
+                    activity_output_data_size=0,
+                    workflow_output_data_size=0,
+                    data=[],
+                ),
+                id=f"workflow-{uuid.uuid4()}",
+                task_queue=worker.task_queue,
+                memo={"key1": [0] * 256},
+            )
+
+    assert len(w) == 1
+    assert issubclass(w[-1].category, UserWarning)
+    assert "Payloads size exceeded the warning limit" in str(w[-1].message)
+
+
+async def test_large_payload_error_workflow_result(client: Client):
+    # Create worker runtime with forwarded logger
+    worker_logger = logging.getLogger(f"log-{uuid.uuid4()}")
+    worker_runtime = Runtime(
+        telemetry=TelemetryConfig(
+            logging=LoggingConfig(
+                filter=TelemetryFilter(core_level="WARN", other_level="ERROR"),
+                forwarding=LogForwardingConfig(logger=worker_logger),
+            )
+        )
+    )
+
+    # Create client for worker with custom payload limits
+    error_limit = 5 * 1024
+    worker_client = await Client.connect(
+        client.service_client.config.target_host,
+        namespace=client.namespace,
+        runtime=worker_runtime,
+        data_converter=dataclasses.replace(
+            temporalio.converter.default(),
+            payload_limits=PayloadLimitsConfig(
+                payload_upload_error_limit=error_limit,
+                payload_upload_warning_limit=1024,
+            ),
+        ),
+    )
+
+    with (
+        LogCapturer().logs_captured(worker_logger) as worker_logger_capturer,
+        LogCapturer().logs_captured(logging.getLogger()) as root_logger_capturer,
+    ):
+        async with new_worker(
+            worker_client, LargePayloadWorkflow, activities=[large_payload_activity]
+        ) as worker:
+            with pytest.raises(WorkflowFailureError) as err:
+                await client.execute_workflow(
+                    LargePayloadWorkflow.run,
+                    LargePayloadWorkflowInput(
+                        activity_input_data_size=0,
+                        activity_output_data_size=0,
+                        workflow_output_data_size=6 * 1024,
+                        data=[],
+                    ),
+                    id=f"workflow-{uuid.uuid4()}",
+                    task_queue=worker.task_queue,
+                    execution_timeout=timedelta(seconds=3),
+                )
+
+        assert isinstance(err.value.cause, TimeoutError)
+        assert err.value.cause.type == TimeoutType.START_TO_CLOSE
+
+        def worker_logger_predicate(record: logging.LogRecord) -> bool:
+            return (
+                record.levelname == "WARNING"
+                and "Payloads size exceeded the error limit" in record.msg
+            )
+
+        assert worker_logger_capturer.find(worker_logger_predicate)
+
+        def root_logger_predicate(record: logging.LogRecord) -> bool:
+            return (
+                record.levelname == "WARNING"
+                and "Workflow task failed: payloads size exceeded the error limit."
+                in record.msg
+                and f"Limit: {error_limit} bytes" in record.msg
+            )
+
+        assert root_logger_capturer.find(root_logger_predicate)
+
+
+async def test_large_payload_warning_workflow_result(client: Client):
+    config = client.config()
+    config["data_converter"] = dataclasses.replace(
+        temporalio.converter.default(),
+        payload_limits=PayloadLimitsConfig(
+            payload_upload_error_limit=5 * 1024, payload_upload_warning_limit=1024
+        ),
+    )
+    worker_client = Client(**config)
+
+    with warnings.catch_warnings(record=True) as w:
+        async with new_worker(
+            worker_client, LargePayloadWorkflow, activities=[large_payload_activity]
+        ) as worker:
+            await client.execute_workflow(
+                LargePayloadWorkflow.run,
+                LargePayloadWorkflowInput(
+                    activity_input_data_size=0,
+                    activity_output_data_size=0,
+                    workflow_output_data_size=2 * 1024,
+                    data=[],
+                ),
+                id=f"workflow-{uuid.uuid4()}",
+                task_queue=worker.task_queue,
+                execution_timeout=timedelta(seconds=3),
+            )
+
+    assert len(w) == 1
+    assert issubclass(w[-1].category, UserWarning)
+    assert "Payloads size exceeded the warning limit" in str(w[-1].message)
+
+
+async def test_large_payload_error_activity_input(client: Client):
+    # Create worker runtime with forwarded logger
+    worker_logger = logging.getLogger(f"log-{uuid.uuid4()}")
+    worker_runtime = Runtime(
+        telemetry=TelemetryConfig(
+            logging=LoggingConfig(
+                filter=TelemetryFilter(core_level="WARN", other_level="ERROR"),
+                forwarding=LogForwardingConfig(logger=worker_logger),
+            )
+        )
+    )
+
+    # Create client for worker with custom payload limits
+    error_limit = 5 * 1024
+    worker_client = await Client.connect(
+        client.service_client.config.target_host,
+        namespace=client.namespace,
+        runtime=worker_runtime,
+        data_converter=dataclasses.replace(
+            temporalio.converter.default(),
+            payload_limits=PayloadLimitsConfig(
+                payload_upload_error_limit=error_limit,
+                payload_upload_warning_limit=1024,
+            ),
+        ),
+    )
+
+    with (
+        LogCapturer().logs_captured(worker_logger) as worker_logger_capturer,
+        LogCapturer().logs_captured(logging.getLogger()) as root_logger_capturer,
+    ):
+        async with new_worker(
+            worker_client, LargePayloadWorkflow, activities=[large_payload_activity]
+        ) as worker:
+            with pytest.raises(WorkflowFailureError) as err:
+                await client.execute_workflow(
+                    LargePayloadWorkflow.run,
+                    LargePayloadWorkflowInput(
+                        activity_input_data_size=6 * 1024,
+                        activity_output_data_size=0,
+                        workflow_output_data_size=0,
+                        data=[],
+                    ),
+                    id=f"workflow-{uuid.uuid4()}",
+                    task_queue=worker.task_queue,
+                    execution_timeout=timedelta(seconds=3),
+                )
+
+        assert isinstance(err.value.cause, TimeoutError)
+
+        def worker_logger_predicate(record: logging.LogRecord) -> bool:
+            return (
+                record.levelname == "WARNING"
+                and "Payloads size exceeded the error limit" in record.msg
+            )
+
+        assert worker_logger_capturer.find(worker_logger_predicate)
+
+        def root_logger_predicate(record: logging.LogRecord) -> bool:
+            return (
+                record.levelname == "WARNING"
+                and "Workflow task failed: payloads size exceeded the error limit."
+                in record.msg
+                and f"Limit: {error_limit} bytes" in record.msg
+            )
+
+        assert root_logger_capturer.find(root_logger_predicate)
+
+
+async def test_large_payload_warning_activity_input(client: Client):
+    config = client.config()
+    config["data_converter"] = dataclasses.replace(
+        temporalio.converter.default(),
+        payload_limits=PayloadLimitsConfig(
+            payload_upload_error_limit=5 * 1024, payload_upload_warning_limit=1024
+        ),
+    )
+    worker_client = Client(**config)
+
+    with warnings.catch_warnings(record=True) as w:
+        async with new_worker(
+            worker_client, LargePayloadWorkflow, activities=[large_payload_activity]
+        ) as worker:
+            await client.execute_workflow(
+                LargePayloadWorkflow.run,
+                LargePayloadWorkflowInput(
+                    activity_input_data_size=2 * 1024,
+                    activity_output_data_size=0,
+                    workflow_output_data_size=0,
+                    data=[],
+                ),
+                id=f"workflow-{uuid.uuid4()}",
+                task_queue=worker.task_queue,
+            )
+
+    assert len(w) == 1
+    assert issubclass(w[-1].category, UserWarning)
+    assert "Payloads size exceeded the warning limit" in str(w[-1].message)
+
+
+async def test_large_payload_error_activity_result(client: Client):
+    # Create worker runtime with forwarded logger
+    worker_logger = logging.getLogger(f"log-{uuid.uuid4()}")
+    worker_runtime = Runtime(
+        telemetry=TelemetryConfig(
+            logging=LoggingConfig(
+                filter=TelemetryFilter(core_level="WARN", other_level="ERROR"),
+                forwarding=LogForwardingConfig(logger=worker_logger),
+            )
+        )
+    )
+
+    # Create client for worker with custom payload limits
+    error_limit = 5 * 1024
+    worker_client = await Client.connect(
+        client.service_client.config.target_host,
+        namespace=client.namespace,
+        runtime=worker_runtime,
+        data_converter=dataclasses.replace(
+            temporalio.converter.default(),
+            payload_limits=PayloadLimitsConfig(
+                payload_upload_error_limit=error_limit,
+                payload_upload_warning_limit=1024,
+            ),
+        ),
+    )
+
+    with (
+        LogCapturer().logs_captured(
+            activity.logger.base_logger
+        ) as activity_logger_capturer,
+        # LogCapturer().logs_captured(worker_logger) as worker_logger_capturer,
+    ):
+        async with new_worker(
+            worker_client, LargePayloadWorkflow, activities=[large_payload_activity]
+        ) as worker:
+            with pytest.raises(WorkflowFailureError) as err:
+                await client.execute_workflow(
+                    LargePayloadWorkflow.run,
+                    LargePayloadWorkflowInput(
+                        activity_input_data_size=0,
+                        activity_output_data_size=6 * 1024,
+                        workflow_output_data_size=0,
+                        data=[],
+                    ),
+                    id=f"workflow-{uuid.uuid4()}",
+                    task_queue=worker.task_queue,
+                )
+
+        assert isinstance(err.value.cause, ActivityError)
+        assert isinstance(err.value.cause.cause, ApplicationError)
+
+        def activity_logger_predicate(record: logging.LogRecord) -> bool:
+            return (
+                hasattr(record, "__temporal_error_identifier")
+                and getattr(record, "__temporal_error_identifier") == "ActivityFailure"
+                and record.levelname == "WARNING"
+                and "Activity task failed: payloads size exceeded the error limit."
+                in record.msg
+                and f"Limit: {error_limit} bytes" in record.msg
+            )
+
+        assert activity_logger_capturer.find(activity_logger_predicate)
+
+        # The worker logger is not emitting this follow-up message. Maybe
+        # activity completion failures are not routed through the log
+        # forwarder, whereas workflow completion failures are?
+        # def worker_logger_predicate(record: logging.LogRecord) -> bool:
+        #     return "Payloads size exceeded the error limit" in record.msg
+
+        # assert worker_logger_capturer.find(worker_logger_predicate)
+
+
+async def test_large_payload_warning_activity_result(client: Client):
+    config = client.config()
+    config["data_converter"] = dataclasses.replace(
+        temporalio.converter.default(),
+        payload_limits=PayloadLimitsConfig(
+            payload_upload_error_limit=5 * 1024, payload_upload_warning_limit=1024
+        ),
+    )
+    worker_client = Client(**config)
+
+    with warnings.catch_warnings(record=True) as w:
+        async with new_worker(
+            worker_client, LargePayloadWorkflow, activities=[large_payload_activity]
+        ) as worker:
+            await client.execute_workflow(
+                LargePayloadWorkflow.run,
+                LargePayloadWorkflowInput(
+                    activity_input_data_size=0,
+                    activity_output_data_size=2 * 1024,
+                    workflow_output_data_size=0,
+                    data=[],
+                ),
+                id=f"workflow-{uuid.uuid4()}",
+                task_queue=worker.task_queue,
+            )
+
+    assert len(w) == 1
+    assert issubclass(w[-1].category, UserWarning)
+    assert "Payloads size exceeded the warning limit" in str(w[-1].message)
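Usage sketch: the snippet below shows how a client might opt into the new limits, mirroring the configuration used in the tests above. The function name is illustrative, and the target host and namespace are placeholder connection values.

    import dataclasses

    import temporalio.converter
    from temporalio.client import Client
    from temporalio.converter import PayloadLimitsConfig

    async def connect_with_limits() -> Client:
        # Warn above 1 KiB and error above 5 KiB of total encoded payload
        # size; memo limits keep their defaults (warn at 2 KiB, no error).
        converter = dataclasses.replace(
            temporalio.converter.default(),
            payload_limits=PayloadLimitsConfig(
                payload_upload_error_limit=5 * 1024,
                payload_upload_warning_limit=1024,
            ),
        )
        # "localhost:7233" and "default" are placeholder connection values.
        return await Client.connect(
            "localhost:7233", namespace="default", data_converter=converter
        )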
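And a sketch of handling the new error on the client side, as exercised by test_large_payload_error_workflow_input above. The workflow, input, id, and task queue names here are hypothetical placeholders.

    from temporalio.exceptions import PayloadSizeError

    async def start_checked(client: Client) -> None:
        # When the encoded workflow input exceeds payload_upload_error_limit,
        # the start call raises PayloadSizeError (client-side, per the test
        # above) rather than submitting the request.
        try:
            await client.execute_workflow(
                LargePayloadWorkflow.run,  # hypothetical workflow, as in the tests
                some_large_input,  # placeholder for an oversized argument
                id="my-workflow-id",
                task_queue="my-task-queue",
            )
        except PayloadSizeError as err:
            print(
                f"Input too large: {err.payloads_size} bytes "
                f"(limit {err.payloads_limit} bytes)"
            )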