diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py
index 732174732..c98afefca 100644
--- a/temporalio/bridge/worker.py
+++ b/temporalio/bridge/worker.py
@@ -315,7 +315,6 @@ async def encode_completion(
     encode_headers: bool,
 ) -> None:
     """Encode all payloads in the completion."""
-    if data_converter._encode_payload_has_effect:
-        await CommandAwarePayloadVisitor(
-            skip_search_attributes=True, skip_headers=not encode_headers
-        ).visit(_Visitor(data_converter._encode_payload_sequence), completion)
+    await CommandAwarePayloadVisitor(
+        skip_search_attributes=True, skip_headers=not encode_headers
+    ).visit(_Visitor(data_converter._encode_payload_sequence), completion)
diff --git a/temporalio/client.py b/temporalio/client.py
index 765f662fb..8c6877ad1 100644
--- a/temporalio/client.py
+++ b/temporalio/client.py
@@ -6837,7 +6837,7 @@ async def _apply_headers(
 ) -> None:
     if source is None:
         return
-    if encode_headers and data_converter._encode_payload_has_effect:
+    if encode_headers:
         for payload in source.values():
             payload.CopyFrom(await data_converter._encode_payload(payload))
     temporalio.common._apply_headers(source, dest)
diff --git a/temporalio/converter.py b/temporalio/converter.py
index 66b31167c..8af621d2d 100644
--- a/temporalio/converter.py
+++ b/temporalio/converter.py
@@ -1207,6 +1207,43 @@ def __init__(self) -> None:
         super().__init__(encode_common_attributes=True)
 
 
+@dataclass(frozen=True)
+class PayloadLimitsConfig:
+    """Configuration for handling uploaded payload sizes that exceed the Temporal server's limits."""
+
+    memo_upload_error_disabled: bool = False
+    """Field indicating that memo size checks should be disabled in the SDK.
+
+    A value of False causes the SDK to fail tasks that attempt to upload memos
+    whose size exceeds the Temporal server memo limit. A value of True disables
+    memo size checks in the SDK, allowing it to attempt the upload even if the
+    memo size exceeds the server limit.
+
+    The default value is False."""
+
+    memo_upload_warning_limit: int = 2 * 1024
+    """The limit (in bytes) at which a memo size warning is logged."""
+
+    payload_upload_error_disabled: bool = False
+    """Field indicating that payload size checks should be disabled in the SDK.
+
+    A value of False causes the SDK to fail tasks that attempt to upload payloads
+    whose size exceeds the Temporal server payload limit. A value of True disables
+    payload size checks in the SDK, allowing it to attempt the upload even if the
+    payload size exceeds the server limit.
+
+    The default value is False."""
+
+    payload_upload_warning_limit: int = 512 * 1024
+    """The limit (in bytes) at which a payload size warning is logged."""
+
+
+@dataclass
+class _PayloadErrorLimits:
+    memo_upload_error_limit: int
+    payload_upload_error_limit: int
+
+
 @dataclass(frozen=True)
 class DataConverter(WithSerializationContext):
     """Data converter for converting and encoding payloads to/from Python values.
@@ -1230,9 +1267,16 @@ class DataConverter(WithSerializationContext):
     failure_converter: FailureConverter = dataclasses.field(init=False)
     """Failure converter created from the :py:attr:`failure_converter_class`."""
 
+    payload_limits: PayloadLimitsConfig = PayloadLimitsConfig()
+    """Settings for payload size limits."""
+
     default: ClassVar[DataConverter]
     """Singleton default data converter."""
 
+    _memo_upload_error_limit: int = 0
+
+    _payload_upload_error_limit: int = 0
+
     def __post_init__(self) -> None:  # noqa: D105
         object.__setattr__(self, "payload_converter", self.payload_converter_class())
         object.__setattr__(self, "failure_converter", self.failure_converter_class())
@@ -1334,6 +1378,17 @@ def with_context(self, context: SerializationContext) -> Self:
         object.__setattr__(cloned, "failure_converter", failure_converter)
         return cloned
 
+    def _with_payload_error_limits(self, options: _PayloadErrorLimits) -> DataConverter:
+        return dataclasses.replace(
+            self,
+            _memo_upload_error_limit=0
+            if self.payload_limits.memo_upload_error_disabled
+            else options.memo_upload_error_limit,
+            _payload_upload_error_limit=0
+            if self.payload_limits.payload_upload_error_disabled
+            else options.payload_upload_error_limit,
+        )
+
     async def _decode_memo(
         self,
         source: temporalio.api.common.v1.Memo,
@@ -1367,35 +1422,42 @@ async def _encode_memo(
     async def _encode_memo_existing(
         self, source: Mapping[str, Any], memo: temporalio.api.common.v1.Memo
     ):
+        payloads = []
         for k, v in source.items():
             payload = v
             if not isinstance(v, temporalio.api.common.v1.Payload):
                 payload = (await self.encode([v]))[0]
             memo.fields[k].CopyFrom(payload)
+            payloads.append(payload)
+        # Memos have their field payloads validated all together in one unit
+        self._validate_limits(
+            payloads,
+            self._memo_upload_error_limit,
+            self.payload_limits.memo_upload_warning_limit,
+            "Memo size exceeded the warning limit.",
+        )
 
     async def _encode_payload(
         self, payload: temporalio.api.common.v1.Payload
     ) -> temporalio.api.common.v1.Payload:
         if self.payload_codec:
             payload = (await self.payload_codec.encode([payload]))[0]
+        self._validate_payload_limits([payload])
         return payload
 
     async def _encode_payloads(self, payloads: temporalio.api.common.v1.Payloads):
         if self.payload_codec:
             await self.payload_codec.encode_wrapper(payloads)
+        self._validate_payload_limits(payloads.payloads)
 
     async def _encode_payload_sequence(
         self, payloads: Sequence[temporalio.api.common.v1.Payload]
     ) -> list[temporalio.api.common.v1.Payload]:
-        if not self.payload_codec:
-            return list(payloads)
-        return await self.payload_codec.encode(payloads)
-
-    # Temporary shortcircuit detection while the _encode_* methods may no-op if
-    # a payload codec is not configured. Remove once those paths have more to them.
-    @property
-    def _encode_payload_has_effect(self) -> bool:
-        return self.payload_codec is not None
+        encoded_payloads = list(payloads)
+        if self.payload_codec:
+            encoded_payloads = await self.payload_codec.encode(encoded_payloads)
+        self._validate_payload_limits(encoded_payloads)
+        return encoded_payloads
 
     async def _decode_payload(
         self, payload: temporalio.api.common.v1.Payload
@@ -1452,6 +1514,38 @@ async def _apply_to_failure_payloads(
         if failure.HasField("cause"):
             await DataConverter._apply_to_failure_payloads(failure.cause, cb)
 
+    def _validate_payload_limits(
+        self,
+        payloads: Sequence[temporalio.api.common.v1.Payload],
+    ):
+        self._validate_limits(
+            payloads,
+            self._payload_upload_error_limit,
+            self.payload_limits.payload_upload_warning_limit,
+            "Payloads size exceeded the warning limit.",
+        )
+
+    def _validate_limits(
+        self,
+        payloads: Sequence[temporalio.api.common.v1.Payload],
+        error_limit: int,
+        warning_limit: int,
+        warning_message: str,
+    ):
+        total_size = sum(payload.ByteSize() for payload in payloads)
+
+        if error_limit > 0 and total_size > error_limit:
+            raise temporalio.exceptions.PayloadSizeError(
+                size=total_size,
+                limit=error_limit,
+            )
+
+        if warning_limit > 0 and total_size > warning_limit:
+            # TODO: Use a context-aware logger to log extra information about workflow/activity/etc
+            warnings.warn(
+                f"{warning_message} Size: {total_size} bytes, Limit: {warning_limit} bytes"
+            )
+
 
 DefaultPayloadConverter.default_encoding_payload_converters = (
     BinaryNullPayloadConverter(),
diff --git a/temporalio/exceptions.py b/temporalio/exceptions.py
index f8f8ca20c..96a644a61 100644
--- a/temporalio/exceptions.py
+++ b/temporalio/exceptions.py
@@ -446,3 +446,28 @@ def is_cancelled_exception(exception: BaseException) -> bool:
             and isinstance(exception.cause, CancelledError)
         )
     )
+
+
+class PayloadSizeError(TemporalError):
+    """Error raised when the payloads size exceeds a payload size limit."""
+
+    def __init__(self, size: int, limit: int):
+        """Initialize a payload size error.
+
+        Args:
+            size: Actual payloads size in bytes.
+            limit: Payloads size limit in bytes.
+        """
+        super().__init__("Payloads size exceeded the error limit")
+        self._size = size
+        self._limit = limit
+
+    @property
+    def payloads_size(self) -> int:
+        """Actual payloads size in bytes."""
+        return self._size
+
+    @property
+    def payloads_limit(self) -> int:
+        """Payloads size limit in bytes."""
+        return self._limit
diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py
index 28b434e59..55e25680b 100644
--- a/temporalio/worker/_activity.py
+++ b/temporalio/worker/_activity.py
@@ -126,8 +126,15 @@ def __init__(
         else:
             self._dynamic_activity = defn
 
-    async def run(self) -> None:
+    async def run(
+        self,
+        payload_error_limits: temporalio.converter._PayloadErrorLimits | None,
+    ) -> None:
         """Continually poll for activity tasks and dispatch to handlers."""
+        if payload_error_limits:
+            self._data_converter = self._data_converter._with_payload_error_limits(
+                payload_error_limits
+            )
 
         async def raise_from_exception_queue() -> NoReturn:
             raise await self._fail_worker_exception_queue.get()
@@ -380,6 +387,19 @@ async def _handle_start_activity_task(
                     temporalio.exceptions.CancelledError("Cancelled"),
                     completion.result.cancelled.failure,
                 )
+            elif isinstance(
+                err,
+                temporalio.exceptions.PayloadSizeError,
+            ):
+                temporalio.activity.logger.warning(
+                    "Activity task failed: payloads size exceeded the error limit. Size: %d bytes, Limit: %d bytes",
+                    err.payloads_size,
+                    err.payloads_limit,
+                    extra={"__temporal_error_identifier": "ActivityFailure"},
+                )
+                await data_converter.encode_failure(
+                    err, completion.result.failed.failure
+                )
             else:
                 if (
                     isinstance(
diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py
index 9a32c2cd5..53a998816 100644
--- a/temporalio/worker/_nexus.py
+++ b/temporalio/worker/_nexus.py
@@ -93,8 +93,15 @@ def __init__(
         self._running_tasks: dict[bytes, _RunningNexusTask] = {}
         self._fail_worker_exception_queue: asyncio.Queue[Exception] = asyncio.Queue()
 
-    async def run(self) -> None:
+    async def run(
+        self,
+        payload_error_limits: temporalio.converter._PayloadErrorLimits | None,
+    ) -> None:
         """Continually poll for Nexus tasks and dispatch to handlers."""
+        if payload_error_limits:
+            self._data_converter = self._data_converter._with_payload_error_limits(
+                payload_error_limits
+            )
 
         async def raise_from_exception_queue() -> NoReturn:
             raise await self._fail_worker_exception_queue.get()
diff --git a/temporalio/worker/_replayer.py b/temporalio/worker/_replayer.py
index 7c93a453b..de55c4845 100644
--- a/temporalio/worker/_replayer.py
+++ b/temporalio/worker/_replayer.py
@@ -327,7 +327,7 @@ def on_eviction_hook(
             bridge_worker_scope = bridge_worker
 
             # Start worker
-            workflow_worker_task = asyncio.create_task(workflow_worker.run())
+            workflow_worker_task = asyncio.create_task(workflow_worker.run(None))
 
             # Yield iterator
             async def replay_iterator() -> AsyncIterator[WorkflowReplayResult]:
diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py
index f664c0877..e1a063b67 100644
--- a/temporalio/worker/_worker.py
+++ b/temporalio/worker/_worker.py
@@ -29,6 +29,7 @@
     VersioningBehavior,
     WorkerDeploymentVersion,
 )
+from temporalio.converter import _PayloadErrorLimits
 
 from ._activity import SharedStateManager, _ActivityWorker
 from ._interceptor import Interceptor
@@ -715,7 +716,15 @@ def make_lambda(plugin: Plugin, next: Callable[[Worker], Awaitable[None]]):
 
     async def _run(self):
         # Eagerly validate which will do a namespace check in Core
-        await self._bridge_worker.validate()
+        namespace_info = await self._bridge_worker.validate()
+        payload_error_limits = (
+            _PayloadErrorLimits(
+                memo_upload_error_limit=namespace_info.limits.memo_size_limit_error,
+                payload_upload_error_limit=namespace_info.limits.blob_size_limit_error,
+            )
+            if namespace_info.HasField("limits")
+            else None
+        )
 
         if self._started:
             raise RuntimeError("Already started")
@@ -735,14 +744,16 @@ async def raise_on_shutdown():
         # Create tasks for workers
         if self._activity_worker:
             tasks[self._activity_worker] = asyncio.create_task(
-                self._activity_worker.run()
+                self._activity_worker.run(payload_error_limits)
             )
         if self._workflow_worker:
             tasks[self._workflow_worker] = asyncio.create_task(
-                self._workflow_worker.run()
+                self._workflow_worker.run(payload_error_limits)
             )
         if self._nexus_worker:
-            tasks[self._nexus_worker] = asyncio.create_task(self._nexus_worker.run())
+            tasks[self._nexus_worker] = asyncio.create_task(
+                self._nexus_worker.run(payload_error_limits)
+            )
 
         # Wait for either worker or shutdown requested
         wait_task = asyncio.wait(tasks.values(), return_when=asyncio.FIRST_EXCEPTION)
diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py
index 40b72286e..04dc33638 100644
--- a/temporalio/worker/_workflow.py
+++ b/temporalio/worker/_workflow.py
@@ -23,6 +23,7 @@
 import temporalio.converter
 import temporalio.exceptions
 import temporalio.workflow
+from temporalio.api.enums.v1 import WorkflowTaskFailedCause
 from temporalio.bridge.worker import PollShutdownError
 
 from . import _command_aware_visitor
@@ -163,7 +164,15 @@ def __init__(
         else:
             self._dynamic_workflow = defn
 
-    async def run(self) -> None:
+    async def run(
+        self,
+        payload_error_limits: temporalio.converter._PayloadErrorLimits | None,
+    ) -> None:
+        if payload_error_limits:
+            self._data_converter = self._data_converter._with_payload_error_limits(
+                payload_error_limits
+            )
+
         # Continually poll for workflow work
         task_tag = object()
         try:
@@ -372,6 +381,21 @@ async def _handle_activation(
                 data_converter,
                 encode_headers=self._encode_headers,
             )
+        except temporalio.exceptions.PayloadSizeError as err:
+            # TODO: Would like to use temporalio.workflow.logger here, but
+            # that requires being in the workflow event loop. Possibly refactor
+            # the logger core functionality into a shareable class and update
+            # LoggerAdapter to be a decorator.
+            logger.warning(
+                "Workflow task failed: payloads size exceeded the error limit. Size: %d bytes, Limit: %d bytes",
+                err.payloads_size,
+                err.payloads_limit,
+            )
+            completion.failed.Clear()
+            await data_converter.encode_failure(err, completion.failed.failure)
+            completion.failed.force_cause = (
+                WorkflowTaskFailedCause.WORKFLOW_TASK_FAILED_CAUSE_PAYLOADS_TOO_LARGE
+            )
         except Exception as err:
             logger.exception(
                 "Failed encoding completion on workflow with run ID %s", act.run_id
diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py
index b597a85ab..866fdb753 100644
--- a/tests/worker/test_workflow.py
+++ b/tests/worker/test_workflow.py
@@ -15,6 +15,7 @@
 import time
 import typing
 import uuid
+import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Awaitable, Mapping, Sequence
 from dataclasses import dataclass
@@ -88,6 +89,8 @@
     DefaultPayloadConverter,
     PayloadCodec,
     PayloadConverter,
+    PayloadLimitsConfig,
+    _PayloadErrorLimits,
 )
 from temporalio.exceptions import (
     ActivityError,
@@ -97,16 +100,20 @@
     ChildWorkflowError,
     TemporalError,
     TimeoutError,
+    TimeoutType,
     WorkflowAlreadyStartedError,
 )
 from temporalio.runtime import (
     BUFFERED_METRIC_KIND_COUNTER,
     BUFFERED_METRIC_KIND_HISTOGRAM,
+    LogForwardingConfig,
+    LoggingConfig,
     MetricBuffer,
     MetricBufferDurationFormat,
     PrometheusConfig,
     Runtime,
     TelemetryConfig,
+    TelemetryFilter,
 )
 from temporalio.service import RPCError, RPCStatusCode, __version__
 from temporalio.testing import WorkflowEnvironment
@@ -8485,3 +8492,412 @@ async def test_disable_logger_sandbox(
         run_timeout=timedelta(seconds=1),
         retry_policy=RetryPolicy(maximum_attempts=1),
     )
+
+
+@dataclass
+class LargePayloadWorkflowInput:
+    activity_input_data_size: int
+    activity_output_data_size: int
+    workflow_output_data_size: int
+    data: list[int]
+
+
+@dataclass
+class LargePayloadWorkflowOutput:
+    data: list[int]
+
+
+@dataclass
+class LargePayloadActivityInput:
+    output_data_size: int
+    data: list[int]
+
+
+@dataclass
+class LargePayloadActivityOutput:
+    data: list[int]
+
+
+@activity.defn
+async def large_payload_activity(
+    input: LargePayloadActivityInput,
+) -> LargePayloadActivityOutput:
+    return LargePayloadActivityOutput(data=[0] * input.output_data_size)
+
+
+@workflow.defn
+class LargePayloadWorkflow:
+    @workflow.run
+    async def run(self, input: LargePayloadWorkflowInput) -> LargePayloadWorkflowOutput:
+        await workflow.execute_activity(
+            large_payload_activity,
+            LargePayloadActivityInput(
+                output_data_size=input.activity_output_data_size,
+                data=[0] * input.activity_input_data_size,
+            ),
+            schedule_to_close_timeout=timedelta(seconds=5),
+        )
+        return LargePayloadWorkflowOutput(data=[0] * input.workflow_output_data_size)
+
+
+async def test_large_payload_warning_workflow_input(client: Client):
+    config = client.config()
+    config["data_converter"] = dataclasses.replace(
+        temporalio.converter.default(),
+        payload_limits=PayloadLimitsConfig(
+            payload_upload_warning_limit=102,
+        ),
+    )
+    client = Client(**config)
+
+    with warnings.catch_warnings(record=True) as w:
+        async with new_worker(
+            client, LargePayloadWorkflow, activities=[large_payload_activity]
+        ) as worker:
+            await client.execute_workflow(
+                LargePayloadWorkflow.run,
+                LargePayloadWorkflowInput(
+                    activity_input_data_size=0,
+                    activity_output_data_size=0,
+                    workflow_output_data_size=0,
+                    data=[0] * 2 * 1024,
+                ),
+                id=f"workflow-{uuid.uuid4()}",
+                task_queue=worker.task_queue,
+            )
+
+    assert len(w) == 1
+    assert issubclass(w[-1].category, UserWarning)
+    assert "Payloads size exceeded the warning limit" in str(w[-1].message)
+
+
+async def test_large_payload_warning_workflow_memo(client: Client):
+    config = client.config()
+    config["data_converter"] = dataclasses.replace(
+        temporalio.converter.default(),
+        payload_limits=PayloadLimitsConfig(payload_upload_warning_limit=128),
+    )
+    client = Client(**config)
+
+    with warnings.catch_warnings(record=True) as w:
+        async with new_worker(
+            client, LargePayloadWorkflow, activities=[large_payload_activity]
+        ) as worker:
+            await client.execute_workflow(
+                LargePayloadWorkflow.run,
+                LargePayloadWorkflowInput(
+                    activity_input_data_size=0,
+                    activity_output_data_size=0,
+                    workflow_output_data_size=0,
+                    data=[],
+                ),
+                id=f"workflow-{uuid.uuid4()}",
+                task_queue=worker.task_queue,
+                memo={"key1": [0] * 256},
+            )
+
+    assert len(w) == 1
+    assert issubclass(w[-1].category, UserWarning)
+    assert "Payloads size exceeded the warning limit" in str(w[-1].message)
+
+
+async def test_large_payload_error_workflow_result(client: Client):
+    # Create worker runtime with forwarded logger
+    worker_logger = logging.getLogger(f"log-{uuid.uuid4()}")
+    worker_runtime = Runtime(
+        telemetry=TelemetryConfig(
+            logging=LoggingConfig(
+                filter=TelemetryFilter(core_level="WARN", other_level="ERROR"),
+                forwarding=LogForwardingConfig(logger=worker_logger),
+            )
+        )
+    )
+
+    # Create client for worker with custom payload limits
+    error_limit = 5 * 1024
+    worker_client = await Client.connect(
+        client.service_client.config.target_host,
+        namespace=client.namespace,
+        runtime=worker_runtime,
+        data_converter=temporalio.converter.default()._with_payload_error_limits(
+            _PayloadErrorLimits(
+                memo_upload_error_limit=0,
+                payload_upload_error_limit=error_limit,
+            )
+        ),
+    )
+
+    with (
+        LogCapturer().logs_captured(worker_logger) as worker_logger_capturer,
+        LogCapturer().logs_captured(logging.getLogger()) as root_logger_capturer,
+    ):
+        async with new_worker(
+            worker_client, LargePayloadWorkflow, activities=[large_payload_activity]
+        ) as worker:
+            with pytest.raises(WorkflowFailureError) as err:
+                await client.execute_workflow(
+                    LargePayloadWorkflow.run,
+                    LargePayloadWorkflowInput(
+                        activity_input_data_size=0,
+                        activity_output_data_size=0,
+                        workflow_output_data_size=6 * 1024,
+                        data=[],
+                    ),
+                    id=f"workflow-{uuid.uuid4()}",
+                    task_queue=worker.task_queue,
+                    execution_timeout=timedelta(seconds=3),
+                )
+
+        assert isinstance(err.value.cause, TimeoutError)
+        assert err.value.cause.type == TimeoutType.START_TO_CLOSE
+
+        def worker_logger_predicate(record: logging.LogRecord) -> bool:
+            return (
+                record.levelname == "WARNING"
+                and "Payloads size exceeded the error limit" in record.msg
+            )
+
+        assert worker_logger_capturer.find(worker_logger_predicate)
+
+        def root_logger_predicate(record: logging.LogRecord) -> bool:
+            return (
+                record.levelname == "WARNING"
+                and "Workflow task failed: payloads size exceeded the error limit."
+                in record.msg
+                and f"Limit: {error_limit} bytes" in record.msg
+            )
+
+        assert root_logger_capturer.find(root_logger_predicate)
+
+
+async def test_large_payload_warning_workflow_result(client: Client):
+    config = client.config()
+    config["data_converter"] = dataclasses.replace(
+        temporalio.converter.default(),
+        payload_limits=PayloadLimitsConfig(
+            payload_upload_warning_limit=1024,
+        ),
+    )
+    worker_client = Client(**config)
+
+    with warnings.catch_warnings(record=True) as w:
+        async with new_worker(
+            worker_client, LargePayloadWorkflow, activities=[large_payload_activity]
+        ) as worker:
+            await client.execute_workflow(
+                LargePayloadWorkflow.run,
+                LargePayloadWorkflowInput(
+                    activity_input_data_size=0,
+                    activity_output_data_size=0,
+                    workflow_output_data_size=2 * 1024,
+                    data=[],
+                ),
+                id=f"workflow-{uuid.uuid4()}",
+                task_queue=worker.task_queue,
+                execution_timeout=timedelta(seconds=3),
+            )
+
+    assert len(w) == 1
+    assert issubclass(w[-1].category, UserWarning)
+    assert "Payloads size exceeded the warning limit" in str(w[-1].message)
+
+
+async def test_large_payload_error_activity_input(client: Client):
+    # Create worker runtime with forwarded logger
+    worker_logger = logging.getLogger(f"log-{uuid.uuid4()}")
+    worker_runtime = Runtime(
+        telemetry=TelemetryConfig(
+            logging=LoggingConfig(
+                filter=TelemetryFilter(core_level="WARN", other_level="ERROR"),
+                forwarding=LogForwardingConfig(logger=worker_logger),
+            )
+        )
+    )
+
+    # Create client for worker with custom payload limits
+    error_limit = 5 * 1024
+    worker_client = await Client.connect(
+        client.service_client.config.target_host,
+        namespace=client.namespace,
+        runtime=worker_runtime,
+        data_converter=temporalio.converter.default()._with_payload_error_limits(
+            _PayloadErrorLimits(
+                memo_upload_error_limit=0,
+                payload_upload_error_limit=error_limit,
+            )
+        ),
+    )
+
+    with (
+        LogCapturer().logs_captured(worker_logger) as worker_logger_capturer,
+        LogCapturer().logs_captured(logging.getLogger()) as root_logger_capturer,
+    ):
+        async with new_worker(
+            worker_client, LargePayloadWorkflow, activities=[large_payload_activity]
+        ) as worker:
+            with pytest.raises(WorkflowFailureError) as err:
+                await client.execute_workflow(
+                    LargePayloadWorkflow.run,
+                    LargePayloadWorkflowInput(
+                        activity_input_data_size=6 * 1024,
+                        activity_output_data_size=0,
+                        workflow_output_data_size=0,
+                        data=[],
+                    ),
+                    id=f"workflow-{uuid.uuid4()}",
+                    task_queue=worker.task_queue,
+                    execution_timeout=timedelta(seconds=3),
+                )
+
+        assert isinstance(err.value.cause, TimeoutError)
+
+        def worker_logger_predicate(record: logging.LogRecord) -> bool:
+            return (
+                record.levelname == "WARNING"
+                and "Payloads size exceeded the error limit" in record.msg
+            )
+
+        assert worker_logger_capturer.find(worker_logger_predicate)
+
+        def root_logger_predicate(record: logging.LogRecord) -> bool:
+            return (
+                record.levelname == "WARNING"
+                and "Workflow task failed: payloads size exceeded the error limit."
+                in record.msg
+                and f"Limit: {error_limit} bytes" in record.msg
+            )
+
+        assert root_logger_capturer.find(root_logger_predicate)
+
+
+async def test_large_payload_warning_activity_input(client: Client):
+    config = client.config()
+    config["data_converter"] = dataclasses.replace(
+        temporalio.converter.default(),
+        payload_limits=PayloadLimitsConfig(
+            payload_upload_warning_limit=1024,
+        ),
+    )
+    worker_client = Client(**config)
+
+    with warnings.catch_warnings(record=True) as w:
+        async with new_worker(
+            worker_client, LargePayloadWorkflow, activities=[large_payload_activity]
+        ) as worker:
+            await client.execute_workflow(
+                LargePayloadWorkflow.run,
+                LargePayloadWorkflowInput(
+                    activity_input_data_size=2 * 1024,
+                    activity_output_data_size=0,
+                    workflow_output_data_size=0,
+                    data=[],
+                ),
+                id=f"workflow-{uuid.uuid4()}",
+                task_queue=worker.task_queue,
+            )
+
+    assert len(w) == 1
+    assert issubclass(w[-1].category, UserWarning)
+    assert "Payloads size exceeded the warning limit" in str(w[-1].message)
+
+
+async def test_large_payload_error_activity_result(client: Client):
+    # Create worker runtime with forwarded logger
+    worker_logger = logging.getLogger(f"log-{uuid.uuid4()}")
+    worker_runtime = Runtime(
+        telemetry=TelemetryConfig(
+            logging=LoggingConfig(
+                filter=TelemetryFilter(core_level="WARN", other_level="ERROR"),
+                forwarding=LogForwardingConfig(logger=worker_logger),
+            )
+        )
+    )
+
+    # Create client for worker with custom payload limits
+    error_limit = 5 * 1024
+    worker_client = await Client.connect(
+        client.service_client.config.target_host,
+        namespace=client.namespace,
+        runtime=worker_runtime,
+        data_converter=temporalio.converter.default()._with_payload_error_limits(
+            _PayloadErrorLimits(
+                memo_upload_error_limit=0,
+                payload_upload_error_limit=error_limit,
+            )
+        ),
+    )
+
+    with (
+        LogCapturer().logs_captured(
+            activity.logger.base_logger
+        ) as activity_logger_capturer,
+        # LogCapturer().logs_captured(worker_logger) as worker_logger_capturer,
+    ):
+        async with new_worker(
+            worker_client, LargePayloadWorkflow, activities=[large_payload_activity]
+        ) as worker:
+            with pytest.raises(WorkflowFailureError) as err:
+                await client.execute_workflow(
+                    LargePayloadWorkflow.run,
+                    LargePayloadWorkflowInput(
+                        activity_input_data_size=0,
+                        activity_output_data_size=6 * 1024,
+                        workflow_output_data_size=0,
+                        data=[],
+                    ),
+                    id=f"workflow-{uuid.uuid4()}",
+                    task_queue=worker.task_queue,
+                )
+
+        assert isinstance(err.value.cause, ActivityError)
+        assert isinstance(err.value.cause.cause, ApplicationError)
+
+        def activity_logger_predicate(record: logging.LogRecord) -> bool:
+            return (
+                hasattr(record, "__temporal_error_identifier")
+                and getattr(record, "__temporal_error_identifier") == "ActivityFailure"
+                and record.levelname == "WARNING"
+                and "Activity task failed: payloads size exceeded the error limit."
+                in record.msg
+                and f"Limit: {error_limit} bytes" in record.msg
+            )
+
+        assert activity_logger_capturer.find(activity_logger_predicate)
+
+        # The worker logger does not emit this failure message. Maybe activity
+        # completion failures are not routed through the log forwarder, whereas
+        # workflow completion failures are?
+        # def worker_logger_predicate(record: logging.LogRecord) -> bool:
+        #     return "Payloads size exceeded the error limit" in record.msg
+
+        # assert worker_logger_capturer.find(worker_logger_predicate)
+
+
+async def test_large_payload_warning_activity_result(client: Client):
+    config = client.config()
+    config["data_converter"] = dataclasses.replace(
+        temporalio.converter.default(),
+        payload_limits=PayloadLimitsConfig(
+            payload_upload_warning_limit=1024,
+        ),
+    )
+    worker_client = Client(**config)
+
+    with warnings.catch_warnings(record=True) as w:
+        async with new_worker(
+            worker_client, LargePayloadWorkflow, activities=[large_payload_activity]
+        ) as worker:
+            await client.execute_workflow(
+                LargePayloadWorkflow.run,
+                LargePayloadWorkflowInput(
+                    activity_input_data_size=0,
+                    activity_output_data_size=2 * 1024,
+                    workflow_output_data_size=0,
+                    data=[],
+                ),
+                id=f"workflow-{uuid.uuid4()}",
+                task_queue=worker.task_queue,
+            )
+
+    assert len(w) == 1
+    assert issubclass(w[-1].category, UserWarning)
+    assert "Payloads size exceeded the warning limit" in str(w[-1].message)
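
Usage sketch (minimal example, not exercised by the patch's tests): this assumes only the public surface introduced above — `PayloadLimitsConfig`, the `DataConverter.payload_limits` field, and the client-side warning check — while the error limits are still sourced from the namespace info the worker fetches during `validate()`. The server address, workflow name, and task queue below are placeholders.

import asyncio
import dataclasses
import uuid
import warnings

import temporalio.converter
from temporalio.client import Client
from temporalio.converter import PayloadLimitsConfig


async def main() -> None:
    # Lower the client-side warning threshold to 64 KiB. Hard error
    # enforcement is wired in by the worker from server namespace limits,
    # so user code only tunes warning thresholds or the disable flags.
    converter = dataclasses.replace(
        temporalio.converter.default(),
        payload_limits=PayloadLimitsConfig(
            payload_upload_warning_limit=64 * 1024,
            # payload_upload_error_disabled=True,  # opt out of hard failures
        ),
    )
    client = await Client.connect("localhost:7233", data_converter=converter)

    with warnings.catch_warnings(record=True) as caught:
        # "SomeWorkflow" and "some-task-queue" are hypothetical; a worker
        # for them is assumed to be running elsewhere.
        await client.start_workflow(
            "SomeWorkflow",
            "x" * (128 * 1024),  # input larger than the warning threshold
            id=f"wf-{uuid.uuid4()}",
            task_queue="some-task-queue",
        )
    # The oversized input should have tripped the client-side size check.
    assert any("warning limit" in str(w.message) for w in caught)


if __name__ == "__main__":
    asyncio.run(main())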