Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -341,15 +341,17 @@ def fork(
but has its own identity and independent state going forward.
"""
fork_id = conversation_id or uuid.uuid4()
if agent is not None:
fork_agent = agent
else:
# Round-trip via JSON to produce a deep copy that avoids
# thread-lock pickling issues with model_copy(deep=True).
agent_cls = type(self.agent)
fork_agent = agent_cls.model_validate(
self.agent.model_dump(context={"expose_secrets": True}),
)
# Always deep-copy the agent (supplied or source) so the fork owns
# its own object graph. Required because __init__ mutates
# agent.llm._prompt_cache_key in place (#2917): a shared/aliased
# agent would clobber the source conversation's cache key.
# Round-trip via JSON avoids thread-lock pickling issues with
# model_copy(deep=True).
source_agent = agent if agent is not None else self.agent
agent_cls = type(source_agent)
fork_agent = agent_cls.model_validate(
source_agent.model_dump(context={"expose_secrets": True}),
Comment thread
VascoSch92 marked this conversation as resolved.
)
Comment thread
xingyaoww marked this conversation as resolved.

# Hold the state lock while reading mutable state from the source
# conversation to avoid torn reads if run() is executing concurrently.
Expand Down
26 changes: 26 additions & 0 deletions tests/sdk/conversation/local/test_fork.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,29 @@ def test_fork_persisted_events_survive_reload():
resumed_ids = [e.id for e in resumed.state.events]
assert evt_id_1 in resumed_ids
assert evt_id_2 in resumed_ids


def test_fork_default_does_not_clobber_source_cache_key():
"""Default fork() must leave the source's prompt_cache_key intact (#2917)."""
with tempfile.TemporaryDirectory() as tmpdir:
src = Conversation(agent=_agent(), persistence_dir=tmpdir, workspace=tmpdir)
src_key_before = src.agent.llm._prompt_cache_key

fork = src.fork()

assert src.agent.llm._prompt_cache_key == src_key_before == str(src.id)
assert fork.agent.llm._prompt_cache_key == str(fork.id)
assert fork.agent.llm._prompt_cache_key != src.agent.llm._prompt_cache_key


def test_fork_with_aliased_agent_does_not_clobber_source_cache_key():
"""fork(agent=source.agent) must not repin the source LLM's cache key (#2917)."""
with tempfile.TemporaryDirectory() as tmpdir:
src = Conversation(agent=_agent(), persistence_dir=tmpdir, workspace=tmpdir)
src_key_before = src.agent.llm._prompt_cache_key

fork = src.fork(agent=src.agent)

assert src.agent.llm._prompt_cache_key == src_key_before == str(src.id)
assert fork.agent.llm._prompt_cache_key == str(fork.id)
assert fork.agent.llm is not src.agent.llm
Loading