diff --git a/providers/common/ai/docs/operators/agent.rst b/providers/common/ai/docs/operators/agent.rst index b3805aa34b765..6854ca4e948ad 100644 --- a/providers/common/ai/docs/operators/agent.rst +++ b/providers/common/ai/docs/operators/agent.rst @@ -207,18 +207,30 @@ fails mid-run (network error, timeout, transient API failure), a plain retry re-executes every LLM call and tool call from scratch -- repeating work that already succeeded and incurring additional cost. -Setting ``durable=True`` caches each LLM response and tool result to -ObjectStorage as it completes. On retry, completed steps are replayed from the -cache and only the remaining steps run against the live model and tools. The -cache is deleted after successful completion. +Setting ``durable=True`` caches each LLM response and tool result as it +completes. On retry, completed steps are replayed from the cache and only the +remaining steps run against the live model and tools. The cache is deleted +after successful completion. Durable execution only helps when the task has retries configured. Without retries there is nothing to replay. **Configuration** -Set the cache location in ``airflow.cfg``. The task raises ``ValueError`` at -runtime if ``durable=True`` and the option is missing. +On **Airflow >= 3.3** the cache is stored in the +:doc:`task state store `, +scoped to the task instance. No configuration is required; the store handles +persistence across retries. + +By default each cached step is written to the Airflow metadata database. Model +responses and large tool results can be sizable, so for agents with large +payloads configure ``[workers] state_store_backend`` to offload step values to +external storage (e.g. object storage) instead of the metadata database; the +provider then stores only a reference in the database. + +On **Airflow < 3.3** the cache is persisted to ObjectStorage and the location +must be set in ``airflow.cfg``. The task raises ``ValueError`` at runtime if +``durable=True`` and the option is missing. .. code-block:: ini @@ -251,10 +263,10 @@ cache: **How it works** -1. On first execution, each LLM response and tool result is saved to a JSON - file as the agent progresses, together with a fingerprint of the request - that produced it (model, message history, settings, and tools for LLM - steps; tool name, arguments, and call id for tool steps). +1. On first execution, each LLM response and tool result is saved as the agent + progresses, together with a fingerprint of the request that produced it + (model, message history, settings, and tools for LLM steps; tool name, + arguments, and call id for tool steps). 2. If the task fails and Airflow retries it, completed steps are loaded from the cache and returned without calling the model or tool. Steps not yet in the cache proceed normally. @@ -266,29 +278,31 @@ cache: an LLM step produces fresh tool call ids, so tool results recorded under the old conversation no longer match. A changed agent costs a re-run; it never replays responses that belong to a different conversation. -4. After successful completion, the cache file is deleted. +4. After successful completion, the cached steps are deleted. Replay verification compares the **requests** sent to models and tools, not the code behind them. Editing a tool's implementation between attempts does not invalidate an already-cached result for an identical call, and pointing ``llm_conn_id`` at a different endpoint serving the same model name does not -invalidate cached responses -- delete the cache file to force a fully fresh -run. +invalidate cached responses -- clear the cache to force a fully fresh run. After the run, a single INFO summary line reports how many steps were replayed vs executed fresh. Per-step detail is available at DEBUG level. -The cache file is named ``{dag_id}_{task_id}_{run_id}.json`` (with -``_{map_index}`` appended for mapped tasks) and stored under the configured -``durable_cache_path``. To force a completely fresh run, delete the cache file -for that task. +The cache is scoped to a single task instance (DAG id, run id, task id, and +map index), so each run replays only its own steps. On Airflow >= 3.3 the cache +lives in the task state store and is removed when the DAG run is cleaned up; on +Airflow < 3.3 it is a JSON file named ``{dag_id}_{task_id}_{run_id}.json`` (with +``_{map_index}`` appended for mapped tasks) under the configured +``durable_cache_path``. .. note:: - Runs that fail permanently (exhaust all retries) leave their cache file - behind. These orphaned files do not affect future DAG runs (each run gets - its own file) but will consume storage. Clean them up periodically or add - a lifecycle policy to the storage backend. + Runs that fail permanently (exhaust all retries) leave their cached steps + behind. These do not affect future DAG runs (each run is scoped separately). + On Airflow >= 3.3 they are reclaimed when the DAG run is removed; on Airflow + < 3.3 the orphaned JSON files consume storage until cleaned up, so add a + lifecycle policy to the storage backend or remove them periodically. **Side effects and idempotency** @@ -443,9 +457,10 @@ Parameters prone to runaway tool loops, so ``tool_calls_limit`` is a useful guardrail. See :ref:`howto/operator:llm` for an example. Default ``None``. - ``durable``: When ``True``, enables step-level caching of model responses and - tool results via ObjectStorage. On retry, cached steps are replayed instead of - re-executing expensive LLM calls. Requires the ``[common.ai] durable_cache_path`` - config option to be set. Default ``False``. + tool results. On retry, cached steps are replayed instead of re-executing + expensive LLM calls. On Airflow >= 3.3 the cache uses the task state store (no + configuration needed); on older cores it requires the ``[common.ai] + durable_cache_path`` config option to be set. Default ``False``. - ``code_mode``: When ``True``, wraps the agent's tools in a single ``run_code`` tool that the model drives by writing Python, executed in the Monty sandbox. Requires the ``code-mode`` extra. Default ``False``. See :ref:`code-mode`. diff --git a/providers/common/ai/provider.yaml b/providers/common/ai/provider.yaml index 893f5d230f7b3..9437690b6c00f 100644 --- a/providers/common/ai/provider.yaml +++ b/providers/common/ai/provider.yaml @@ -89,15 +89,17 @@ config: durable_cache_path: description: | ObjectStorage URI used to persist per-step caches when running - ``AgentOperator`` / ``@task.agent`` with ``durable=True``. Each task - execution writes a single JSON file under this path containing its - cached model responses and tool results, so that on retry the agent - can replay completed steps instead of re-issuing LLM calls and tool - invocations. The file is deleted on successful task completion. + ``AgentOperator`` / ``@task.agent`` with ``durable=True`` on Airflow + **< 3.3**. Each task execution writes a single JSON file under this + path containing its cached model responses and tool results, so that + on retry the agent can replay completed steps instead of re-issuing + LLM calls and tool invocations. The file is deleted on successful task + completion. - Required when ``durable=True`` is used. Any scheme supported by - ``airflow.sdk.ObjectStoragePath`` is accepted (``file://``, ``s3://``, - ``gs://``, ``azure://``, ...). + Required for ``durable=True`` only on Airflow < 3.3. On Airflow >= 3.3 + the cache is stored in the AIP-103 task state store and this option is + ignored. Any scheme supported by ``airflow.sdk.ObjectStoragePath`` is + accepted (``file://``, ``s3://``, ``gs://``, ``azure://``, ...). version_added: 0.1.0 type: string example: "file:///tmp/airflow_durable_cache" diff --git a/providers/common/ai/src/airflow/providers/common/ai/durable/base.py b/providers/common/ai/src/airflow/providers/common/ai/durable/base.py new file mode 100644 index 0000000000000..9fff491012efb --- /dev/null +++ b/providers/common/ai/src/airflow/providers/common/ai/durable/base.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Shared interface for durable execution storage backends.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +if TYPE_CHECKING: + from pydantic_ai.messages import ModelResponse + +# Marks a stored entry as a cached tool result; lets ``load_tool_result`` +# tell a cached ``None`` apart from a missing entry. Single source of truth so +# the two backends cannot drift on the envelope shape. +TOOL_RESULT_SENTINEL = "__durable_cached__" + +# Prefix for durable cache keys. On the task state store backend (>= 3.3) the +# cache shares the task instance's key namespace with anything user code writes +# via ``context["task_state_store"]``; the reserved prefix keeps durable steps +# from colliding with user keys. No ``/`` -- task state store keys are a single, +# un-encoded URL path segment. +DURABLE_KEY_PREFIX = "__commonai_durable__" + + +@runtime_checkable +class DurableStorageProtocol(Protocol): + """ + Persistence contract shared by the durable execution storage backends. + + Implemented by both :class:`~airflow.providers.common.ai.durable.storage.DurableStorage` + (ObjectStorage, Airflow < 3.3) and + :class:`~airflow.providers.common.ai.durable.task_state_store.TaskStateStoreDurableStorage` + (AIP-103 task state store, Airflow >= 3.3). ``CachingModel`` and + ``CachingToolset`` depend on this interface, not a concrete backend. + """ + + def save_model_response(self, key: str, response: ModelResponse, *, fingerprint: str | None) -> None: ... + + def load_model_response(self, key: str) -> tuple[ModelResponse | None, str | None]: ... + + def save_tool_result(self, key: str, result: Any, *, fingerprint: str | None) -> None: ... + + def load_tool_result(self, key: str) -> tuple[bool, Any, str | None]: ... + + def cleanup(self) -> None: ... diff --git a/providers/common/ai/src/airflow/providers/common/ai/durable/caching_model.py b/providers/common/ai/src/airflow/providers/common/ai/durable/caching_model.py index 18f89439d82dd..6118311f70cf3 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/durable/caching_model.py +++ b/providers/common/ai/src/airflow/providers/common/ai/durable/caching_model.py @@ -24,6 +24,7 @@ import structlog from pydantic_ai.models.wrapper import WrapperModel +from airflow.providers.common.ai.durable.base import DURABLE_KEY_PREFIX from airflow.providers.common.ai.durable.fingerprint import fingerprint_model_request log = structlog.get_logger(logger_name="task") @@ -33,8 +34,8 @@ from pydantic_ai.models import ModelRequestParameters from pydantic_ai.settings import ModelSettings + from airflow.providers.common.ai.durable.base import DurableStorageProtocol from airflow.providers.common.ai.durable.step_counter import DurableStepCounter - from airflow.providers.common.ai.durable.storage import DurableStorage @dataclass(init=False) @@ -51,14 +52,14 @@ class CachingModel(WrapperModel): discarded and the step re-runs live. """ - storage: DurableStorage = field(repr=False) + storage: DurableStorageProtocol = field(repr=False) counter: DurableStepCounter = field(repr=False) def __init__( self, wrapped: Any, *, - storage: DurableStorage, + storage: DurableStorageProtocol, counter: DurableStepCounter, ) -> None: super().__init__(wrapped) @@ -72,7 +73,7 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: step = self.counter.next_step() - key = f"model_step_{step}" + key = f"{DURABLE_KEY_PREFIX}model_step_{step}" # Fingerprint the *prepared* request, not the raw arguments. Concrete # models call ``prepare_request()`` at the start of ``request()`` to merge # their model-level ``settings`` and apply profile-specific transforms diff --git a/providers/common/ai/src/airflow/providers/common/ai/durable/caching_toolset.py b/providers/common/ai/src/airflow/providers/common/ai/durable/caching_toolset.py index 045c98aea1f50..b411c57ef690f 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/durable/caching_toolset.py +++ b/providers/common/ai/src/airflow/providers/common/ai/durable/caching_toolset.py @@ -24,13 +24,14 @@ import structlog from pydantic_ai.toolsets.wrapper import WrapperToolset +from airflow.providers.common.ai.durable.base import DURABLE_KEY_PREFIX from airflow.providers.common.ai.durable.fingerprint import fingerprint_tool_call if TYPE_CHECKING: from pydantic_ai.toolsets.abstract import ToolsetTool + from airflow.providers.common.ai.durable.base import DurableStorageProtocol from airflow.providers.common.ai.durable.step_counter import DurableStepCounter - from airflow.providers.common.ai.durable.storage import DurableStorage log = structlog.get_logger(logger_name="task") @@ -53,7 +54,7 @@ class CachingToolset(WrapperToolset[Any]): executing their synchronous preamble in creation order). """ - storage: DurableStorage = field(repr=False) + storage: DurableStorageProtocol = field(repr=False) counter: DurableStepCounter = field(repr=False) async def call_tool( @@ -66,7 +67,7 @@ async def call_tool( # Grab step index BEFORE any await -- ensures deterministic ordering # even when multiple tool calls run concurrently via asyncio.gather. step = self.counter.next_step() - key = f"tool_step_{step}" + key = f"{DURABLE_KEY_PREFIX}tool_step_{step}" fingerprint = fingerprint_tool_call(name, tool_args, ctx.tool_call_id) found, cached, cached_fingerprint = self.storage.load_tool_result(key) diff --git a/providers/common/ai/src/airflow/providers/common/ai/durable/storage.py b/providers/common/ai/src/airflow/providers/common/ai/durable/storage.py index 88de494fe91df..d50107631a4a1 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/durable/storage.py +++ b/providers/common/ai/src/airflow/providers/common/ai/durable/storage.py @@ -26,10 +26,11 @@ import structlog from pydantic_ai.messages import ModelMessagesTypeAdapter, ModelResponse -log = structlog.get_logger(logger_name="task") - # Sentinel to distinguish "cached None" from "no cache entry" for tool results. -_SENTINEL = "__durable_cached__" +# Shared with the task state store backend so the envelope shape cannot drift. +from airflow.providers.common.ai.durable.base import TOOL_RESULT_SENTINEL as _SENTINEL + +log = structlog.get_logger(logger_name="task") SECTION = "common.ai" diff --git a/providers/common/ai/src/airflow/providers/common/ai/durable/task_state_store.py b/providers/common/ai/src/airflow/providers/common/ai/durable/task_state_store.py new file mode 100644 index 0000000000000..a81b6043eca4f --- /dev/null +++ b/providers/common/ai/src/airflow/providers/common/ai/durable/task_state_store.py @@ -0,0 +1,160 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Task-state-store-backed durable storage for pydantic-ai agent step caching. + +Available on Airflow >= 3.3, where the AIP-103 task state store provides a +per-task-instance key/value store that survives retries within a run and is +cleared when the run is removed. Each cached step is written under its own key +(``model_step_{N}`` / ``tool_step_{N}``); the store handles persistence and, +when ``[workers] state_store_backend`` is configured, transparently offloads +large values to external storage. No ``[common.ai] durable_cache_path`` is +needed. + +This module is imported only on Airflow >= 3.3 (see +``AgentOperator._build_durable_storage``); ``NEVER_EXPIRE`` does not exist on +older cores. +""" + +from __future__ import annotations + +import contextlib +import json +from typing import TYPE_CHECKING, Any + +import structlog +from pydantic_ai.messages import ModelMessagesTypeAdapter + +from airflow.providers.common.ai.durable.base import TOOL_RESULT_SENTINEL +from airflow.sdk.execution_time.context import NEVER_EXPIRE + +if TYPE_CHECKING: + from pydantic_ai.messages import ModelResponse + + from airflow.sdk.execution_time.context import TaskStateStoreAccessor + +log = structlog.get_logger(logger_name="task") + + +class TaskStateStoreDurableStorage: + """ + Stores step-level durable caches in the AIP-103 task state store. + + Each model response and tool result is written under its own key, scoped to + the current task instance. Entries are written with ``NEVER_EXPIRE`` so a + retry can replay them regardless of ``retry_delay`` or the global retention + config, and the keys this run touched are deleted on successful completion. + + A run that fails permanently leaves its keys behind (``NEVER_EXPIRE`` skips + garbage collection); they are removed when the DAG run is cleaned up, since + task state store rows cascade with the run. + + :param accessor: The task state store accessor for the current task + instance (``context["task_state_store"]``). + """ + + def __init__(self, accessor: TaskStateStoreAccessor) -> None: + self._store = accessor + # Keys written or replayed this run, deleted on cleanup. A divergent + # retry that takes fewer steps may orphan keys from a longer earlier + # attempt; those are reclaimed by the DAG-run cascade, not here. + self._keys: set[str] = set() + + def save_model_response(self, key: str, response: ModelResponse, *, fingerprint: str | None) -> None: + """Serialize and store a ModelResponse with the request fingerprint that produced it.""" + self._store.set( + key, + { + "fingerprint": fingerprint, + "data": ModelMessagesTypeAdapter.dump_python([response], mode="json"), + }, + retention=NEVER_EXPIRE, + ) + self._keys.add(key) + + def load_model_response(self, key: str) -> tuple[ModelResponse | None, str | None]: + """ + Load a cached ModelResponse and its stored request fingerprint. + + Returns ``(None, None)`` on a miss or a torn entry, so the step re-runs + live rather than crashing the task. + """ + raw = self._store.get(key) + if not isinstance(raw, dict): + return None, None + try: + messages = ModelMessagesTypeAdapter.validate_python(raw["data"]) + except (KeyError, IndexError, TypeError, ValueError): + log.warning("Durable: ignoring malformed cached model response", key=key) + return None, None + # A foreign/torn entry can validate as a ModelRequest; only a response is replayable. + if not messages or messages[0].kind != "response": + return None, None + self._keys.add(key) + fingerprint = raw.get("fingerprint") + return messages[0], fingerprint if isinstance(fingerprint, str) else None # type: ignore[return-value] + + def save_tool_result(self, key: str, result: Any, *, fingerprint: str | None) -> None: + """ + Store a tool call result with the call fingerprint that produced it. + + Non-serializable results (e.g. BinaryContent from MCP tools) are skipped + with a warning -- the tool call still succeeds, but won't be replayed on + retry. + """ + try: + # Probe serializability before writing: a non-serializable result + # must skip only this entry, not surface as an opaque comms error. + json.dumps(result) + except (TypeError, ValueError): + log.warning( + "Durable: skipping cache for non-serializable tool result", + key=key, + type=type(result).__name__, + ) + return + self._store.set( + key, + {TOOL_RESULT_SENTINEL: True, "value": result, "fingerprint": fingerprint}, + retention=NEVER_EXPIRE, + ) + self._keys.add(key) + + def load_tool_result(self, key: str) -> tuple[bool, Any, str | None]: + """ + Load a cached tool result and its stored call fingerprint. + + Returns a ``(found, value, fingerprint)`` tuple since the cached value + itself may be ``None``. + """ + raw = self._store.get(key) + if not isinstance(raw, dict) or TOOL_RESULT_SENTINEL not in raw: + return False, None, None + self._keys.add(key) + fingerprint = raw.get("fingerprint") + return True, raw.get("value"), fingerprint if isinstance(fingerprint, str) else None + + def cleanup(self) -> None: + """Delete the keys this run wrote or replayed after successful execution.""" + for key in self._keys: + # Runs only after the task has already succeeded, so it must never raise + # (that would fail a succeeded task). A key left behind by a failed delete + # is reclaimed by the DAG-run cascade -- hence the deliberately broad catch. + with contextlib.suppress(Exception): + self._store.delete(key) + self._keys.clear() + log.debug("Durable cache cleaned up") diff --git a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py index 8bc03c266cb14..4f31d62e84e7b 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py +++ b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py @@ -92,7 +92,7 @@ def get_provider_info(): "description": "Options for the ``apache-airflow-providers-common-ai`` provider.\n", "options": { "durable_cache_path": { - "description": "ObjectStorage URI used to persist per-step caches when running\n``AgentOperator`` / ``@task.agent`` with ``durable=True``. Each task\nexecution writes a single JSON file under this path containing its\ncached model responses and tool results, so that on retry the agent\ncan replay completed steps instead of re-issuing LLM calls and tool\ninvocations. The file is deleted on successful task completion.\n\nRequired when ``durable=True`` is used. Any scheme supported by\n``airflow.sdk.ObjectStoragePath`` is accepted (``file://``, ``s3://``,\n``gs://``, ``azure://``, ...).\n", + "description": "ObjectStorage URI used to persist per-step caches when running\n``AgentOperator`` / ``@task.agent`` with ``durable=True`` on Airflow\n**< 3.3**. Each task execution writes a single JSON file under this\npath containing its cached model responses and tool results, so that\non retry the agent can replay completed steps instead of re-issuing\nLLM calls and tool invocations. The file is deleted on successful task\ncompletion.\n\nRequired for ``durable=True`` only on Airflow < 3.3. On Airflow >= 3.3\nthe cache is stored in the AIP-103 task state store and this option is\nignored. Any scheme supported by ``airflow.sdk.ObjectStoragePath`` is\naccepted (``file://``, ``s3://``, ``gs://``, ``azure://``, ...).\n", "version_added": "0.1.0", "type": "string", "example": "file:///tmp/airflow_durable_cache", diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py index 56c9ec5bbb65a..7b3a7b54680ea 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py +++ b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py @@ -36,7 +36,7 @@ BaseOperatorLink, conf, ) -from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_1_PLUS +from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_3_PLUS try: # See LLMOperator: new enough cores register declared ``output_type`` classes @@ -52,8 +52,8 @@ from pydantic_ai.toolsets.abstract import AbstractToolset from pydantic_ai.usage import UsageLimits + from airflow.providers.common.ai.durable.base import DurableStorageProtocol from airflow.providers.common.ai.durable.step_counter import DurableStepCounter - from airflow.providers.common.ai.durable.storage import DurableStorage from airflow.providers.common.compat.sdk import TaskInstanceKey from airflow.sdk import Context @@ -154,7 +154,9 @@ class AgentOperator(BaseOperator, HITLReviewMixin): model, settings, tools, or message history changed since the failed attempt, the affected steps re-run live (with a warning) instead of replaying stale results. Default ``False``. - Requires ``[common.ai] durable_cache_path`` to be set. + On Airflow >= 3.3 the cache is kept in the AIP-103 task state store, so + no extra configuration is needed. On older cores it is persisted to + ObjectStorage and requires ``[common.ai] durable_cache_path`` to be set. :param code_mode: When ``True``, wraps the agent's tools in a single ``run_code`` tool powered by the Monty sandbox (pydantic-ai-harness ``CodeMode``). Instead of one model round-trip per tool call, the model @@ -324,13 +326,39 @@ def _build_agent(self) -> Agent[None, Any]: ) def _build_durable_toolsets( - self, toolsets: list[AbstractToolset], storage: DurableStorage, counter: DurableStepCounter + self, toolsets: list[AbstractToolset], storage: DurableStorageProtocol, counter: DurableStepCounter ) -> list[AbstractToolset]: """Wrap each toolset with CachingToolset for durable execution.""" from airflow.providers.common.ai.durable.caching_toolset import CachingToolset return [CachingToolset(wrapped=ts, storage=storage, counter=counter) for ts in toolsets] + def _build_durable_storage(self, context: Context) -> DurableStorageProtocol: + """ + Return the durable storage backend for the current task instance. + + On Airflow >= 3.3 durable steps are cached in the AIP-103 task state + store, which handles persistence and large-value offload natively, so no + ``[common.ai] durable_cache_path`` is required. On older cores, fall back + to the ObjectStorage backend configured via ``durable_cache_path``. + """ + if AIRFLOW_V_3_3_PLUS: + # Imported lazily: NEVER_EXPIRE and the task state store accessor do + # not exist on cores before 3.3. + from airflow.providers.common.ai.durable.task_state_store import TaskStateStoreDurableStorage + + return TaskStateStoreDurableStorage(context["task_state_store"]) + + from airflow.providers.common.ai.durable.storage import DurableStorage + + ti = context["task_instance"] + return DurableStorage( + dag_id=ti.dag_id, + task_id=ti.task_id, + run_id=ti.run_id, + map_index=ti.map_index if ti.map_index is not None else -1, + ) + def execute(self, context: Context) -> Any: if self.enable_hitl_review and not isinstance(self.prompt, str): raise TypeError( @@ -345,15 +373,8 @@ def execute(self, context: Context) -> Any: if self.durable: from airflow.providers.common.ai.durable.step_counter import DurableStepCounter - from airflow.providers.common.ai.durable.storage import DurableStorage - - ti = context["task_instance"] - self._durable_storage = DurableStorage( - dag_id=ti.dag_id, - task_id=ti.task_id, - run_id=ti.run_id, - map_index=ti.map_index if ti.map_index is not None else -1, - ) + + self._durable_storage = self._build_durable_storage(context) self._durable_counter = DurableStepCounter() agent = self._build_agent() diff --git a/providers/common/ai/tests/unit/common/ai/durable/test_base.py b/providers/common/ai/tests/unit/common/ai/durable/test_base.py new file mode 100644 index 0000000000000..afbf2e0d36c36 --- /dev/null +++ b/providers/common/ai/tests/unit/common/ai/durable/test_base.py @@ -0,0 +1,95 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any + +from airflow.providers.common.ai.durable.base import ( + DURABLE_KEY_PREFIX, + TOOL_RESULT_SENTINEL, + DurableStorageProtocol, +) + + +class _CompleteBackend: + """Implements the full ``DurableStorageProtocol`` surface.""" + + def save_model_response(self, key, response, *, fingerprint): + pass + + def load_model_response(self, key): + return None, None + + def save_tool_result(self, key, result, *, fingerprint): + pass + + def load_tool_result(self, key): + return False, None, None + + def cleanup(self): + pass + + +class _PartialBackend: + """Missing ``cleanup`` -- must not satisfy the protocol.""" + + def save_model_response(self, key, response, *, fingerprint): + pass + + def load_model_response(self, key): + return None, None + + def save_tool_result(self, key, result, *, fingerprint): + pass + + def load_tool_result(self, key): + return False, None, None + + +class TestDurableConstants: + def test_key_prefix_is_a_single_path_segment(self): + # Task state store keys are a single, un-encoded URL path segment, so the + # reserved prefix must not contain a separator that would split the key. + assert "/" not in DURABLE_KEY_PREFIX + + def test_sentinel_and_prefix_are_distinct(self): + # Both are reserved markers written into the same store; if they ever + # coincided a cached tool result could be mistaken for a key prefix. + assert TOOL_RESULT_SENTINEL != DURABLE_KEY_PREFIX + + def test_constants_are_non_empty_strings(self): + assert isinstance(TOOL_RESULT_SENTINEL, str) + assert TOOL_RESULT_SENTINEL + assert isinstance(DURABLE_KEY_PREFIX, str) + assert DURABLE_KEY_PREFIX + + +class TestDurableStorageProtocol: + def test_complete_backend_satisfies_protocol(self): + assert isinstance(_CompleteBackend(), DurableStorageProtocol) + + def test_partial_backend_does_not_satisfy_protocol(self): + assert not isinstance(_PartialBackend(), DurableStorageProtocol) + + def test_arbitrary_object_does_not_satisfy_protocol(self): + assert not isinstance(object(), DurableStorageProtocol) + + def test_protocol_is_runtime_checkable(self): + # ``runtime_checkable`` is what makes the ``isinstance`` checks above + # legal; guard against the decorator being dropped. + backend: Any = _CompleteBackend() + assert isinstance(backend, DurableStorageProtocol) diff --git a/providers/common/ai/tests/unit/common/ai/durable/test_caching_model.py b/providers/common/ai/tests/unit/common/ai/durable/test_caching_model.py index 2b00fa4a408cc..9c5eb288f40f0 100644 --- a/providers/common/ai/tests/unit/common/ai/durable/test_caching_model.py +++ b/providers/common/ai/tests/unit/common/ai/durable/test_caching_model.py @@ -22,6 +22,7 @@ from pydantic_ai.messages import ModelResponse, TextPart from pydantic_ai.models import ModelRequestParameters +from airflow.providers.common.ai.durable.base import DURABLE_KEY_PREFIX as P from airflow.providers.common.ai.durable.caching_model import CachingModel from airflow.providers.common.ai.durable.fingerprint import fingerprint_model_request from airflow.providers.common.ai.durable.step_counter import DurableStepCounter @@ -82,7 +83,7 @@ async def test_returns_cached_response_without_calling_model( assert result is sample_response mock_model.request.assert_not_called() - mock_storage.load_model_response.assert_called_once_with("model_step_0") + mock_storage.load_model_response.assert_called_once_with(f"{P}model_step_0") @pytest.mark.asyncio async def test_advances_counter_on_cache_hit(self, mock_model, mock_storage, counter, sample_response): @@ -105,7 +106,7 @@ async def test_calls_model_and_caches_on_miss(self, mock_model, mock_storage, co assert result is sample_response mock_model.request.assert_called_once() mock_storage.save_model_response.assert_called_once_with( - "model_step_0", sample_response, fingerprint=request_fingerprint() + f"{P}model_step_0", sample_response, fingerprint=request_fingerprint() ) @pytest.mark.asyncio @@ -119,7 +120,7 @@ async def test_sequential_calls_use_incrementing_keys(self, mock_model, mock_sto await caching.request([], None, ModelRequestParameters()) keys = [call[0][0] for call in mock_storage.save_model_response.call_args_list] - assert keys == ["model_step_0", "model_step_1"] + assert keys == [f"{P}model_step_0", f"{P}model_step_1"] class TestCachingModelReplayVerification: @@ -139,7 +140,7 @@ async def test_fingerprint_mismatch_treated_as_miss( mock_model.request.assert_called_once() assert counter.replayed_model == 0 mock_storage.save_model_response.assert_called_once_with( - "model_step_0", sample_response, fingerprint=request_fingerprint() + f"{P}model_step_0", sample_response, fingerprint=request_fingerprint() ) @pytest.mark.asyncio diff --git a/providers/common/ai/tests/unit/common/ai/durable/test_caching_toolset.py b/providers/common/ai/tests/unit/common/ai/durable/test_caching_toolset.py index d7104928a1496..c0570f388a10c 100644 --- a/providers/common/ai/tests/unit/common/ai/durable/test_caching_toolset.py +++ b/providers/common/ai/tests/unit/common/ai/durable/test_caching_toolset.py @@ -23,6 +23,7 @@ from pydantic_ai.messages import ModelResponse, TextPart from pydantic_ai.models import ModelRequestParameters +from airflow.providers.common.ai.durable.base import DURABLE_KEY_PREFIX as P from airflow.providers.common.ai.durable.caching_model import CachingModel from airflow.providers.common.ai.durable.caching_toolset import CachingToolset from airflow.providers.common.ai.durable.fingerprint import fingerprint_tool_call @@ -67,7 +68,7 @@ async def test_returns_cached_result_without_calling_tool(self, mock_toolset, mo assert result == "cached result" mock_toolset.call_tool.assert_not_called() - mock_storage.load_tool_result.assert_called_once_with("tool_step_0") + mock_storage.load_tool_result.assert_called_once_with(f"{P}tool_step_0") @pytest.mark.asyncio async def test_advances_counter_on_cache_hit(self, mock_toolset, mock_storage, counter): @@ -90,7 +91,9 @@ async def test_calls_tool_and_caches_on_miss(self, mock_toolset, mock_storage, c assert result == "fresh result" mock_toolset.call_tool.assert_called_once() mock_storage.save_tool_result.assert_called_once_with( - "tool_step_0", "fresh result", fingerprint=fingerprint_tool_call("search", {"q": "foo"}, "call_1") + f"{P}tool_step_0", + "fresh result", + fingerprint=fingerprint_tool_call("search", {"q": "foo"}, "call_1"), ) @pytest.mark.asyncio @@ -102,7 +105,7 @@ async def test_sequential_calls_use_incrementing_keys(self, mock_toolset, mock_s await caching.call_tool("tool_b", {}, ctx_for(), MagicMock()) keys = [call[0][0] for call in mock_storage.save_tool_result.call_args_list] - assert keys == ["tool_step_0", "tool_step_1"] + assert keys == [f"{P}tool_step_0", f"{P}tool_step_1"] class TestCachingToolsetReplayVerification: @@ -173,6 +176,6 @@ async def test_model_and_toolset_share_counter(self, mock_toolset, mock_storage) model_keys = [call[0][0] for call in mock_storage.save_model_response.call_args_list] tool_keys = [call[0][0] for call in mock_storage.save_tool_result.call_args_list] - assert model_keys == ["model_step_0", "model_step_2"] - assert tool_keys == ["tool_step_1"] + assert model_keys == [f"{P}model_step_0", f"{P}model_step_2"] + assert tool_keys == [f"{P}tool_step_1"] assert counter.total_steps == 3 diff --git a/providers/common/ai/tests/unit/common/ai/durable/test_task_state_store.py b/providers/common/ai/tests/unit/common/ai/durable/test_task_state_store.py new file mode 100644 index 0000000000000..8a5a657d4c429 --- /dev/null +++ b/providers/common/ai/tests/unit/common/ai/durable/test_task_state_store.py @@ -0,0 +1,224 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json + +import pytest + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS + +if not AIRFLOW_V_3_3_PLUS: + # ``airflow.sdk.execution_time.context`` exists on older cores, but ``NEVER_EXPIRE`` + # (imported transitively via ``task_state_store``) only lands in 3.3, so an + # ``importorskip`` on the module is not enough -- gate on the version instead. + pytest.skip("task state store needs Airflow >= 3.3", allow_module_level=True) + +from pydantic_ai.messages import ( + ModelMessagesTypeAdapter, + ModelResponse, + TextPart, +) +from pydantic_ai.usage import RequestUsage + +from airflow.providers.common.ai.durable.base import TOOL_RESULT_SENTINEL +from airflow.providers.common.ai.durable.task_state_store import TaskStateStoreDurableStorage +from airflow.sdk.execution_time.context import NEVER_EXPIRE + + +class FakeTaskStateStore: + """In-memory stand-in for the ``context['task_state_store']`` accessor.""" + + def __init__(self) -> None: + self.store: dict = {} + self.set_retentions: dict = {} + self.deleted: list[str] = [] + + def get(self, key, default=None): + return self.store.get(key, default) + + def set(self, key, value, *, retention=None): + # Mirror the real accessor, which serializes to JSON and persists in a + # Text column -- so the round-trip is exercised, not a by-reference stash. + self.store[key] = json.loads(json.dumps(value)) + self.set_retentions[key] = retention + + def delete(self, key): + self.deleted.append(key) + self.store.pop(key, None) + + +@pytest.fixture +def accessor(): + return FakeTaskStateStore() + + +@pytest.fixture +def storage(accessor): + return TaskStateStoreDurableStorage(accessor) + + +@pytest.fixture +def sample_response(): + return ModelResponse(parts=[TextPart(content="Hello!")]) + + +class TestSaveLoadModelResponse: + def test_save_and_load_roundtrips(self, storage, sample_response): + storage.save_model_response("model_step_0", sample_response, fingerprint="fp_abc") + + loaded, fingerprint = storage.load_model_response("model_step_0") + + assert loaded is not None + assert loaded.parts[0].content == "Hello!" + assert fingerprint == "fp_abc" + + def test_stored_with_never_expire(self, storage, accessor, sample_response): + """Cache entries must survive every retry regardless of retry_delay or retention config.""" + storage.save_model_response("model_step_0", sample_response, fingerprint="fp") + + assert accessor.set_retentions["model_step_0"] is NEVER_EXPIRE + + def test_stored_entry_is_native_json_not_a_string(self, storage, accessor, sample_response): + storage.save_model_response("model_step_0", sample_response, fingerprint="fp") + + entry = accessor.store["model_step_0"] + assert isinstance(entry, dict) + assert isinstance(entry["data"], list) + assert entry["fingerprint"] == "fp" + + def test_metadata_carrying_response_roundtrips_byte_identical(self, storage): + """A later step fingerprints earlier responses in history, metadata and all; + a store/load cycle that altered any of it would mismatch and re-run.""" + resp = ModelResponse( + parts=[TextPart(content="answer")], + usage=RequestUsage(input_tokens=11, output_tokens=22), + model_name="gpt-x", + provider_response_id="resp_xyz", + finish_reason="stop", + ) + before = ModelMessagesTypeAdapter.dump_python([resp], mode="json") + + storage.save_model_response("model_step_0", resp, fingerprint="fp") + loaded, _ = storage.load_model_response("model_step_0") + + after = ModelMessagesTypeAdapter.dump_python([loaded], mode="json") + assert after == before + + def test_load_returns_none_when_missing(self, storage): + assert storage.load_model_response("model_step_0") == (None, None) + + def test_empty_data_list_degrades_to_miss(self, storage, accessor): + accessor.store["model_step_0"] = {"fingerprint": "fp", "data": []} + assert storage.load_model_response("model_step_0") == (None, None) + + def test_entry_missing_data_key_degrades_to_miss(self, storage, accessor): + accessor.store["model_step_0"] = {"fingerprint": "fp"} + assert storage.load_model_response("model_step_0") == (None, None) + + +class TestSaveLoadToolResult: + def test_save_and_load_roundtrips(self, storage): + storage.save_tool_result("tool_step_0", {"rows": [1, 2, 3]}, fingerprint="fp_tool") + + found, value, fingerprint = storage.load_tool_result("tool_step_0") + + assert found is True + assert value == {"rows": [1, 2, 3]} + assert fingerprint == "fp_tool" + + def test_stored_with_never_expire(self, storage, accessor): + storage.save_tool_result("tool_step_0", "result", fingerprint="fp") + + assert accessor.set_retentions["tool_step_0"] is NEVER_EXPIRE + + def test_none_result_roundtrips(self, storage): + """A cached ``None`` is a hit, distinguished from a missing entry by the sentinel.""" + storage.save_tool_result("tool_step_0", None, fingerprint="fp") + + found, value, _ = storage.load_tool_result("tool_step_0") + assert found is True + assert value is None + + def test_load_returns_false_when_missing(self, storage): + assert storage.load_tool_result("tool_step_0") == (False, None, None) + + def test_non_dict_entry_is_a_miss(self, storage, accessor): + accessor.store["tool_step_0"] = "not a dict" + assert storage.load_tool_result("tool_step_0") == (False, None, None) + + def test_entry_without_sentinel_is_a_miss(self, storage, accessor): + accessor.store["tool_step_0"] = {"value": "x"} + assert storage.load_tool_result("tool_step_0") == (False, None, None) + assert TOOL_RESULT_SENTINEL not in accessor.store["tool_step_0"] + + def test_non_serializable_result_is_skipped_not_raised(self, storage, accessor): + """A non-serializable tool result skips caching with a warning; the tool step still succeeds.""" + storage.save_tool_result("tool_step_0", object(), fingerprint="fp") # must not raise + + assert "tool_step_0" not in accessor.store + assert storage.load_tool_result("tool_step_0") == (False, None, None) + + def test_circular_reference_result_is_skipped_not_raised(self, storage, accessor): + circular: dict = {} + circular["self"] = circular + + storage.save_tool_result("tool_step_0", circular, fingerprint="fp") # must not raise + + assert "tool_step_0" not in accessor.store + + +class TestCleanup: + def test_cleanup_deletes_keys_written_this_run(self, storage, accessor, sample_response): + storage.save_model_response("model_step_0", sample_response, fingerprint="fp") + storage.save_tool_result("tool_step_1", "result", fingerprint="fp") + + storage.cleanup() + + assert "model_step_0" not in accessor.store + assert "tool_step_1" not in accessor.store + + def test_cleanup_deletes_keys_only_replayed_this_run(self, accessor, sample_response): + """A retry that replays an earlier attempt's keys (cache hits) must still clean them up.""" + TaskStateStoreDurableStorage(accessor).save_model_response( + "model_step_0", sample_response, fingerprint="fp" + ) + + retry = TaskStateStoreDurableStorage(accessor) + loaded, _ = retry.load_model_response("model_step_0") # cache hit, no re-write + assert loaded is not None + + retry.cleanup() + assert "model_step_0" not in accessor.store + + def test_cleanup_leaves_untracked_keys_untouched(self, storage, accessor, sample_response): + """Cleanup deletes only the durable keys it touched, never the whole task instance namespace.""" + accessor.store["user_key"] = "kept" + storage.save_model_response("model_step_0", sample_response, fingerprint="fp") + + storage.cleanup() + + assert "model_step_0" not in accessor.store + assert accessor.store["user_key"] == "kept" + + def test_cleanup_is_best_effort_on_delete_failure(self, accessor, sample_response): + """A failing delete must not propagate out of cleanup.""" + storage = TaskStateStoreDurableStorage(accessor) + storage.save_model_response("model_step_0", sample_response, fingerprint="fp") + accessor.delete = lambda key: (_ for _ in ()).throw(RuntimeError("boom")) + + storage.cleanup() # must not raise diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py index 1288dbbe6525b..262836368baa0 100644 --- a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py +++ b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py @@ -557,17 +557,44 @@ def test_durable_default_false(self): op = AgentOperator(task_id="test", prompt="test", llm_conn_id="my_llm") assert op.durable is False + @patch("airflow.providers.common.ai.operators.agent.AIRFLOW_V_3_3_PLUS", True) + def test_build_durable_storage_uses_task_state_store_on_3_3(self): + """On Airflow >= 3.3 the cache lives in the task state store -- no durable_cache_path needed.""" + from airflow.providers.common.ai.durable.task_state_store import TaskStateStoreDurableStorage + + accessor = MagicMock() + op = AgentOperator(task_id="t", prompt="p", llm_conn_id="c", durable=True) + + storage = op._build_durable_storage({"task_state_store": accessor}) + + assert isinstance(storage, TaskStateStoreDurableStorage) + assert storage._store is accessor + + @patch("airflow.providers.common.ai.operators.agent.AIRFLOW_V_3_3_PLUS", False) + def test_build_durable_storage_falls_back_to_object_storage_below_3_3(self): + """On Airflow < 3.3 the cache falls back to the ObjectStorage backend.""" + from airflow.providers.common.ai.durable.storage import DurableStorage + + ti = MagicMock(dag_id="d", task_id="t", run_id="r", map_index=-1) + op = AgentOperator(task_id="t", prompt="p", llm_conn_id="c", durable=True) + + storage = op._build_durable_storage({"task_instance": ti}) + + assert isinstance(storage, DurableStorage) + assert storage._cache_id == "d_t_r" + @patch("pydantic_ai.models.wrapper.infer_model", side_effect=lambda m: m) @patch("pydantic_ai.models.infer_model", autospec=True) - @patch("airflow.providers.common.ai.durable.storage._get_base_path") + @patch("airflow.providers.common.ai.operators.agent.AgentOperator._build_durable_storage") @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True) def test_execute_durable_wraps_model_and_cleans_up( - self, mock_hook_cls, mock_base_path, mock_infer, _, tmp_path + self, mock_hook_cls, mock_build_storage, mock_infer, _ ): - """durable=True wraps model with CachingModel and cleans up on success.""" - from airflow.sdk import ObjectStoragePath + """durable=True wraps the model with CachingModel and cleans up the cache on success.""" + from airflow.providers.common.ai.durable.base import DurableStorageProtocol - mock_base_path.return_value = ObjectStoragePath(f"file://{tmp_path.as_posix()}") + storage = MagicMock(spec=DurableStorageProtocol) + mock_build_storage.return_value = storage mock_agent = MagicMock() mock_agent.run_sync.return_value = _make_mock_run_result("ok") @@ -577,21 +604,15 @@ def test_execute_durable_wraps_model_and_cleans_up( mock_agent.override.return_value.__exit__ = MagicMock(return_value=False) mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent - mock_resolved = MagicMock() - mock_infer.return_value = mock_resolved - - context = MagicMock() - context.__getitem__ = MagicMock( - return_value=MagicMock(dag_id="d", task_id="t", run_id="r", map_index=-1) - ) + mock_infer.return_value = MagicMock() op = AgentOperator(task_id="test", prompt="test", llm_conn_id="my_llm", durable=True) - result = op.execute(context=context) + result = op.execute(context=MagicMock()) assert result == "ok" mock_agent.override.assert_called_once() - override_kwargs = mock_agent.override.call_args[1] - assert "model" in override_kwargs + assert "model" in mock_agent.override.call_args[1] + storage.cleanup.assert_called_once() @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True) def test_execute_non_durable_does_not_wrap(self, mock_hook_cls): @@ -733,15 +754,13 @@ def test_message_history_with_hitl_review_raises(self): @patch("pydantic_ai.models.wrapper.infer_model", side_effect=lambda m: m) @patch("pydantic_ai.models.infer_model", autospec=True) - @patch("airflow.providers.common.ai.durable.storage._get_base_path") + @patch("airflow.providers.common.ai.operators.agent.AgentOperator._build_durable_storage") @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True) - def test_durable_path_also_seeds_message_history( - self, mock_hook_cls, mock_base_path, mock_infer, _, tmp_path - ): + def test_durable_path_also_seeds_message_history(self, mock_hook_cls, mock_build_storage, mock_infer, _): """The durable branch forwards message_history into the cached run too.""" - from airflow.sdk import ObjectStoragePath + from airflow.providers.common.ai.durable.base import DurableStorageProtocol - mock_base_path.return_value = ObjectStoragePath(f"file://{tmp_path.as_posix()}") + mock_build_storage.return_value = MagicMock(spec=DurableStorageProtocol) mock_agent = MagicMock(spec=["run_sync", "model", "override"]) mock_agent.run_sync.return_value = _make_mock_run_result("ok") @@ -751,16 +770,11 @@ def test_durable_path_also_seeds_message_history( mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent mock_infer.return_value = MagicMock() - context = MagicMock() - context.__getitem__ = MagicMock( - return_value=MagicMock(dag_id="d", task_id="t", run_id="r", map_index=-1) - ) - history_json = ModelMessagesTypeAdapter.dump_json(_sample_history()).decode() op = AgentOperator( task_id="test", prompt="test", llm_conn_id="my_llm", durable=True, message_history=history_json ) - op.execute(context=context) + op.execute(context=MagicMock()) passed = mock_agent.run_sync.call_args.kwargs["message_history"] assert len(passed) == 2