diff --git a/.agents/skills/custom-codereview-guide.md b/.agents/skills/custom-codereview-guide.md index 18d8dfba56..294c1acba8 100644 --- a/.agents/skills/custom-codereview-guide.md +++ b/.agents/skills/custom-codereview-guide.md @@ -109,12 +109,16 @@ If the updated package was uploaded **within the last 7 days**, treat it as a re ## What to Check - **Complexity**: Over-engineered solutions, unnecessary abstractions, complex logic that could be refactored -- **Testing**: Duplicate test coverage, tests for library features, missing edge case coverage +- **Testing**: Duplicate test coverage, tests for library features, missing edge case coverage. For code that writes to disk, verify that tests cover the **persistence round-trip** (write → close → reopen → verify), not just in-memory state - **Type Safety**: `# type: ignore` usage, missing type annotations, `getattr`/`hasattr` guards, mocking non-existent arguments - **Breaking Changes**: API changes affecting users, removed public fields/methods, changed defaults - **Code Quality**: Code duplication, missing comments for non-obvious decisions, inline imports (unless necessary for circular deps) - **Repository Conventions**: Use `pyright` not `mypy`, put fixtures in `conftest.py`, avoid `sys.path.insert` hacks - **Event Type Deprecation**: Changes to event types (Pydantic models used in serialization) must handle deprecated fields properly +- **Thread Safety**: New methods in `LocalConversation` that read or write `self._state` must use `with self._state:` — see the [Concurrency](#concurrency---localconversation-state-lock) section below +- **Persistence Paths**: Code that computes persistence directories must not double-append the conversation hex — see the [Persistence Paths](#persistence-path-construction) section below +- **Server-Side Cleanup**: Endpoints that create persistent state (directories, files) must have rollback logic for partial failures — see the [Server Error Handling](#server-side-error-handling) section below +- **Cross-File Data Flow**: When new code calls existing APIs (constructors, factory methods), trace 1–2 levels into those APIs to verify the caller uses them correctly. Bugs often hide at layer boundaries where the caller's assumptions don't match the callee's behavior ## Event Type Deprecation - Critical Review Checkpoint @@ -162,6 +166,50 @@ pydantic_core.ValidationError: Extra inputs are not permitted **This is a production-breaking change.** Do not approve PRs that modify event types without proper backward compatibility handling and tests. +## SDK Architecture Conventions + +These conventions codify patterns that are easy to violate when adding new features. Each was learned from a real bug. + +### Concurrency - LocalConversation State Lock + +`LocalConversation` protects mutable state with a FIFOLock accessed via `with self._state:`. **Every** method that reads or writes `self._state.events`, `self._state.stats`, `self._state.agent_state`, `self._state.activated_knowledge_skills`, or any other mutable field on `ConversationState` must hold this lock. There are currently ~13 call sites using this pattern. + +When reviewing a PR that adds a new method to `LocalConversation`: +1. Check whether it accesses any `self._state.*` field. +2. If yes, verify the access is inside a `with self._state:` block. +3. If not, flag it — the method is unsafe for concurrent use with `run()`. + +### Persistence Path Construction + +`BaseConversation.get_persistence_dir(base, conversation_id)` returns `str(Path(base) / conversation_id.hex)`. The `LocalConversation.__init__` constructor calls this automatically when `persistence_dir` is provided. + +**Rule:** Callers that pass `persistence_dir` to `LocalConversation()` must pass only the **base directory** (e.g., `/data/conversations/`). The constructor appends the conversation hex. Passing a pre-constructed full path (e.g., `/data/conversations/abc123`) causes double-appending: `/data/conversations/abc123/abc123`. + +When reviewing code that creates a new `LocalConversation` (fork, resume, migration): +1. Check what value is passed as `persistence_dir`. +2. Verify it does **not** already include the conversation ID hex. + +### Server-Side Error Handling + +Server endpoints in `conversation_service.py` that create persistent state (writing directories, files, or calling `fork()` which writes to disk) and then perform follow-up operations (like `_start_event_service`) must handle partial failure. + +**Pattern:** If the follow-up operation fails, clean up the already-written persistent state so it doesn't become an orphaned directory that confuses future startups. + +```python +# Good: rollback on failure +fork_dir = self.conversations_dir / fork_conv_id.hex +try: + fork_event_service = await self._start_event_service(fork_stored) +except Exception: + safe_rmtree(fork_dir) + raise +``` + +When reviewing server endpoints that create conversations or persistent artifacts: +1. Identify the "point of no return" where state is written to disk. +2. Check that subsequent operations are wrapped in try/except with cleanup. +3. For client-supplied IDs, verify there's a duplicate check before creating state (return 409 Conflict if taken). + ## What NOT to Comment On Do not leave comments for: diff --git a/examples/01_standalone_sdk/48_conversation_fork.py b/examples/01_standalone_sdk/48_conversation_fork.py new file mode 100644 index 0000000000..c5ddb04145 --- /dev/null +++ b/examples/01_standalone_sdk/48_conversation_fork.py @@ -0,0 +1,105 @@ +"""Fork a conversation to branch off for follow-up exploration. + +``Conversation.fork()`` deep-copies a conversation — events, agent config, +workspace metadata — into a new conversation with its own ID. The fork +starts in ``idle`` status and retains full event memory of the source, so +calling ``run()`` picks up right where the original left off. + +Use cases: + - CI agents that produced a wrong patch — engineer forks to debug + without losing the original run's audit trail + - A/B-testing prompts — fork at a given turn, change one variable, + compare downstream + - Swapping tools mid-conversation (fork-on-tool-change) +""" + +import os + +from openhands.sdk import LLM, Agent, Conversation, Tool +from openhands.tools.terminal import TerminalTool + + +# ----------------------------------------------------------------- +# Setup +# ----------------------------------------------------------------- +llm = LLM( + model=os.getenv("LLM_MODEL", "anthropic/claude-sonnet-4-5-20250929"), + api_key=os.getenv("LLM_API_KEY"), + base_url=os.getenv("LLM_BASE_URL", None), +) + +agent = Agent(llm=llm, tools=[Tool(name=TerminalTool.name)]) +cwd = os.getcwd() + +# ================================================================= +# 1. Run the source conversation +# ================================================================= +source = Conversation(agent=agent, workspace=cwd) +source.send_message("Run `echo hello-from-source` in the terminal.") +source.run() + +print("=" * 64) +print(" Conversation.fork() — SDK Example") +print("=" * 64) +print(f"\nSource conversation ID : {source.id}") +print(f"Source events count : {len(source.state.events)}") + +# ================================================================= +# 2. Fork and continue independently +# ================================================================= +fork = source.fork(title="Follow-up fork") +source_event_count = len(source.state.events) + +print("\n--- Fork created ---") +print(f"Fork ID : {fork.id}") +print(f"Fork events (copied) : {len(fork.state.events)}") +print(f"Fork title : {fork.state.tags.get('title')}") + +assert fork.id != source.id +assert len(fork.state.events) == source_event_count + +fork.send_message("Now run `echo hello-from-fork` in the terminal.") +fork.run() + +# Source is untouched +assert len(source.state.events) == source_event_count +print("\n--- After running fork ---") +print(f"Source events (unchanged): {source_event_count}") +print(f"Fork events (grew) : {len(fork.state.events)}") + +# ================================================================= +# 3. Fork with a different agent (tool-change / A/B testing) +# ================================================================= +alt_llm = LLM( + model=os.getenv("LLM_MODEL", "anthropic/claude-sonnet-4-5-20250929"), + api_key=os.getenv("LLM_API_KEY"), + base_url=os.getenv("LLM_BASE_URL", None), + usage_id="alt", +) +alt_agent = Agent(llm=alt_llm, tools=[Tool(name=TerminalTool.name)]) + +fork_alt = source.fork( + agent=alt_agent, + title="Tool-change experiment", + tags={"purpose": "a/b-test"}, +) + +print("\n--- Fork with alternate agent ---") +print(f"Fork ID : {fork_alt.id}") +print(f"Fork tags : {dict(fork_alt.state.tags)}") + +fork_alt.send_message("What command did you run earlier? Just tell me, no tools.") +fork_alt.run() + +print(f"Fork events : {len(fork_alt.state.events)}") + +# ================================================================= +# Summary +# ================================================================= +print(f"\n{'=' * 64}") +print("All done — fork() works end-to-end.") +print("=" * 64) + +# Report cost +cost = llm.metrics.accumulated_cost + alt_llm.metrics.accumulated_cost +print(f"EXAMPLE_COST: {cost}") diff --git a/examples/02_remote_agent_server/11_conversation_fork.py b/examples/02_remote_agent_server/11_conversation_fork.py new file mode 100644 index 0000000000..4f5536f02c --- /dev/null +++ b/examples/02_remote_agent_server/11_conversation_fork.py @@ -0,0 +1,201 @@ +"""Fork a conversation through the agent server REST API. + +Demonstrates ``RemoteConversation.fork()`` which delegates to the server's +``POST /api/conversations/{id}/fork`` endpoint. The fork deep-copies +events and state on the server side, then returns a new +``RemoteConversation`` pointing at the copy. + +Scenarios covered: + 1. Run a source conversation on the server + 2. Fork it — verify independent event histories + 3. Fork with a title and custom tags +""" + +import os +import subprocess +import sys +import tempfile +import threading +import time + +from pydantic import SecretStr + +from openhands.sdk import LLM, Agent, Conversation, RemoteConversation, Tool, Workspace +from openhands.tools.terminal import TerminalTool + + +# ----------------------------------------------------------------- +# Managed server helper (reused from example 01) +# ----------------------------------------------------------------- +def _stream_output(stream, prefix, target_stream): + try: + for line in iter(stream.readline, ""): + if line: + target_stream.write(f"[{prefix}] {line}") + target_stream.flush() + except Exception as e: + print(f"Error streaming {prefix}: {e}", file=sys.stderr) + finally: + stream.close() + + +class ManagedAPIServer: + """Context manager that starts and stops a local agent-server.""" + + def __init__(self, port: int = 8000, host: str = "127.0.0.1"): + self.port = port + self.host = host + self.process: subprocess.Popen[str] | None = None + self.base_url = f"http://{host}:{port}" + + def __enter__(self): + print(f"Starting agent-server on {self.base_url} ...") + self.process = subprocess.Popen( + [ + "python", + "-m", + "openhands.agent_server", + "--port", + str(self.port), + "--host", + self.host, + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + env={"LOG_JSON": "true", **os.environ}, + ) + assert self.process.stdout is not None + assert self.process.stderr is not None + threading.Thread( + target=_stream_output, + args=(self.process.stdout, "SERVER", sys.stdout), + daemon=True, + ).start() + threading.Thread( + target=_stream_output, + args=(self.process.stderr, "SERVER", sys.stderr), + daemon=True, + ).start() + + import httpx + + for _ in range(30): + try: + if httpx.get(f"{self.base_url}/health", timeout=1.0).status_code == 200: + print(f"Agent-server ready at {self.base_url}") + return self + except Exception: + pass + assert self.process.poll() is None, "Server exited unexpectedly" + time.sleep(1) + raise RuntimeError("Server failed to start in 30 s") + + def __exit__(self, *args): + if self.process: + self.process.terminate() + try: + self.process.wait(timeout=5) + except subprocess.TimeoutExpired: + self.process.kill() + self.process.wait() + time.sleep(0.5) + print("Agent-server stopped.") + + +# ----------------------------------------------------------------- +# Config +# ----------------------------------------------------------------- +api_key = os.getenv("LLM_API_KEY") +assert api_key, "LLM_API_KEY must be set" + +llm = LLM( + model=os.getenv("LLM_MODEL", "anthropic/claude-sonnet-4-5-20250929"), + api_key=SecretStr(api_key), + base_url=os.getenv("LLM_BASE_URL"), +) +agent = Agent(llm=llm, tools=[Tool(name=TerminalTool.name)]) + +# ----------------------------------------------------------------- +# Run +# ----------------------------------------------------------------- +with ManagedAPIServer(port=8002) as server: + workspace_dir = tempfile.mkdtemp(prefix="fork_demo_") + workspace = Workspace(host=server.base_url, working_dir=workspace_dir) + + # ============================================================= + # 1. Source conversation + # ============================================================= + source = Conversation(agent=agent, workspace=workspace) + assert isinstance(source, RemoteConversation) + + source.send_message("Run `echo hello-from-source` in the terminal.") + source.run() + + print("=" * 64) + print(" RemoteConversation.fork() — Agent-Server Example") + print("=" * 64) + print(f"\nSource conversation ID : {source.id}") + source_event_count = len(source.state.events) + print(f"Source events count : {source_event_count}") + + # ============================================================= + # 2. Fork and continue independently + # ============================================================= + fork = source.fork(title="Follow-up fork") + assert isinstance(fork, RemoteConversation) + + print("\n--- Fork created ---") + print(f"Fork ID : {fork.id}") + fork_event_count = len(fork.state.events) + print(f"Fork events (copied) : {fork_event_count}") + + assert fork.id != source.id + # The fork copies all persisted events from the server-side EventLog. + # The source's client-side list may additionally contain transient + # WebSocket-only events (e.g. full-state snapshots) that are never + # persisted, so we only assert the fork has a non-trivial number of + # events rather than exact parity. + assert fork_event_count > 0 + + fork.send_message("Now run `echo hello-from-fork` in the terminal.") + fork.run() + + print("\n--- After running fork ---") + print(f"Source events : {len(source.state.events)}") + print(f"Fork events (grew) : {len(fork.state.events)}") + assert len(fork.state.events) > fork_event_count + + # ============================================================= + # 3. Fork with tags + # ============================================================= + fork_tagged = source.fork( + title="Tagged experiment", + tags={"purpose": "a/b-test"}, + ) + assert isinstance(fork_tagged, RemoteConversation) + + print("\n--- Fork with tags ---") + print(f"Fork ID : {fork_tagged.id}") + + fork_tagged.send_message( + "What command did you run earlier? Just tell me, no tools." + ) + fork_tagged.run() + + print(f"Fork events : {len(fork_tagged.state.events)}") + + # ============================================================= + # Summary + # ============================================================= + print(f"\n{'=' * 64}") + print("All done — RemoteConversation.fork() works end-to-end.") + print("=" * 64) + + # Cleanup + fork.close() + fork_tagged.close() + source.close() + +cost = llm.metrics.accumulated_cost +print(f"EXAMPLE_COST: {cost}") diff --git a/openhands-agent-server/openhands/agent_server/conversation_router.py b/openhands-agent-server/openhands/agent_server/conversation_router.py index 33d6a3ea33..4b51e3595e 100644 --- a/openhands-agent-server/openhands/agent_server/conversation_router.py +++ b/openhands-agent-server/openhands/agent_server/conversation_router.py @@ -18,6 +18,7 @@ ConversationInfo, ConversationPage, ConversationSortOrder, + ForkConversationRequest, GenerateTitleRequest, GenerateTitleResponse, SendMessageRequest, @@ -392,3 +393,43 @@ async def condense_conversation( if not success: raise HTTPException(status.HTTP_404_NOT_FOUND, detail="Conversation not found") return Success() + + +@conversation_router.post( + "/{conversation_id}/fork", + responses={ + 201: {"description": "Forked conversation created"}, + 404: {"description": "Source conversation not found"}, + 409: {"description": "Fork ID already in use"}, + }, + status_code=status.HTTP_201_CREATED, +) +async def fork_conversation( + conversation_id: UUID, + request: Annotated[ForkConversationRequest, Body()] = ForkConversationRequest(), # noqa: B008 + conversation_service: ConversationService = Depends(get_conversation_service), +) -> ConversationInfo: + """Fork a conversation, deep-copying its event history. + + The fork starts in ``idle`` status with a fresh event loop. + Calling ``run`` on the fork resumes from the copied state, meaning + the agent has full event memory of the source conversation. + """ + try: + info = await conversation_service.fork_conversation( + conversation_id, + fork_id=request.id, + title=request.title, + tags=request.tags if request.tags is not None else None, + reset_metrics=request.reset_metrics, + ) + except ValueError as exc: + if "already exists" in str(exc): + raise HTTPException(status.HTTP_409_CONFLICT, detail=str(exc)) from exc + raise + if info is None: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + detail="Source conversation not found", + ) + return info diff --git a/openhands-agent-server/openhands/agent_server/conversation_service.py b/openhands-agent-server/openhands/agent_server/conversation_service.py index 3cc2bcbb68..fc581c3f1c 100644 --- a/openhands-agent-server/openhands/agent_server/conversation_service.py +++ b/openhands-agent-server/openhands/agent_server/conversation_service.py @@ -659,6 +659,73 @@ async def condense(self, conversation_id: UUID) -> bool: await event_service.condense() return True + async def fork_conversation( + self, + source_id: UUID, + *, + fork_id: UUID | None = None, + title: str | None = None, + tags: dict[str, str] | None = None, + reset_metrics: bool = True, + ) -> ConversationInfo | None: + """Fork an existing conversation, deep-copying its event history. + + The fork is persisted to disk and then loaded as a new EventService, + so the forked conversation is fully independent from the source. + + Returns ``None`` when *source_id* does not exist. + + Raises: + ValueError: If *fork_id* is already taken by an active + conversation. + """ + if self._event_services is None: + raise ValueError("inactive_service") + + # Reject duplicate fork IDs early to avoid clobbering an active + # conversation or leaking an EventService reference. + if fork_id is not None and fork_id in self._event_services: + raise ValueError(f"Conversation with id {fork_id} already exists") + + source_service = self._event_services.get(source_id) + if source_service is None: + return None + + source_conversation = source_service.get_conversation() + + # fork() deep-copies events, state, and writes to a new persistence dir. + fork_conv = await asyncio.to_thread( + source_conversation.fork, + conversation_id=fork_id, + title=title, + tags=tags, + reset_metrics=reset_metrics, + ) + # Extract the persisted data, then discard the temporary conversation. + fork_conv_id = fork_conv.id + fork_agent = cast(Agent, fork_conv.agent) + fork_workspace = fork_conv.workspace + fork_conv.delete_on_close = False + fork_conv.close() + + # _start_event_service will resume from the persisted fork directory. + fork_stored = StoredConversation( + id=fork_conv_id, + agent=fork_agent, + workspace=fork_workspace, + ) + # If the service fails to start, clean up the orphaned persistence + # directory so we don't leave stale state on disk. + fork_dir = self.conversations_dir / fork_conv_id.hex + try: + fork_event_service = await self._start_event_service(fork_stored) + except Exception: + safe_rmtree(fork_dir) + raise + + state = await fork_event_service.get_state() + return _compose_conversation_info_v1(fork_event_service.stored, state) + async def __aenter__(self): self.conversations_dir.mkdir(parents=True, exist_ok=True) self._event_services = {} diff --git a/openhands-agent-server/openhands/agent_server/models.py b/openhands-agent-server/openhands/agent_server/models.py index 7161c97bf8..f3290fdfe9 100644 --- a/openhands-agent-server/openhands/agent_server/models.py +++ b/openhands-agent-server/openhands/agent_server/models.py @@ -309,6 +309,34 @@ class UpdateConversationRequest(BaseModel): ) +class ForkConversationRequest(BaseModel): + """Payload to fork a conversation.""" + + id: UUID | None = Field( + default=None, + description="ID for the forked conversation (auto-generated if null)", + ) + title: str | None = Field( + default=None, + max_length=200, + description="Optional title for the forked conversation", + ) + tags: ConversationTags | None = Field( + default=None, + description=( + "Optional tags for the forked conversation. Keys must be " + "lowercase alphanumeric." + ), + ) + reset_metrics: bool = Field( + default=True, + description=( + "If true, cost/token stats start fresh on the fork. " + "If false, metrics are copied from the source." + ), + ) + + class GenerateTitleRequest(BaseModel): """Payload to generate a title for a conversation.""" diff --git a/openhands-sdk/openhands/sdk/conversation/base.py b/openhands-sdk/openhands/sdk/conversation/base.py index f131889d1e..be3db01d5c 100644 --- a/openhands-sdk/openhands/sdk/conversation/base.py +++ b/openhands-sdk/openhands/sdk/conversation/base.py @@ -304,6 +304,38 @@ def execute_tool(self, tool_name: str, action: Action) -> Observation: """ ... + @abstractmethod + def fork( + self, + *, + conversation_id: ConversationID | None = None, + agent: "AgentBase | None" = None, + title: str | None = None, + tags: dict[str, str] | None = None, + reset_metrics: bool = True, + ) -> "BaseConversation": + """Deep-copy this conversation with a new ID. + + Events are copied so the source remains immutable. The fork starts + in ``execution_status='idle'``; calling ``run()`` resumes from the + copied state — meaning the agent has full event memory of the source. + + Args: + conversation_id: ID for the forked conversation (auto-generated + if ``None``). + agent: Agent for the fork. Defaults to a deep-copy of the + source agent. + title: Optional title for the forked conversation. + tags: Optional tags for the forked conversation. + reset_metrics: If ``True`` (default), cost/token stats start + fresh on the fork. + + Returns: + A new conversation that shares the same event history but has + its own identity and independent state going forward. + """ + ... + @staticmethod def compose_callbacks(callbacks: Iterable[CallbackType]) -> CallbackType: """Compose multiple callbacks into a single callback function. diff --git a/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py b/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py index 99c794d1ff..33be67b60b 100644 --- a/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py +++ b/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py @@ -1,4 +1,5 @@ import atexit +import copy import uuid from collections.abc import Mapping from pathlib import Path @@ -307,6 +308,107 @@ def resolved_plugins(self) -> list[ResolvedPluginSource] | None: """ return self._resolved_plugins + def fork( + self, + *, + conversation_id: ConversationID | None = None, + agent: AgentBase | None = None, + title: str | None = None, + tags: dict[str, str] | None = None, + reset_metrics: bool = True, + ) -> "LocalConversation": + """Deep-copy this conversation with a new ID. + + Events are copied so the source remains immutable. The fork starts + in ``execution_status='idle'``; calling ``run()`` resumes from the + copied state — meaning the agent has full event memory of the source. + + Args: + conversation_id: ID for the forked conversation (auto-generated + if ``None``). + agent: Agent for the fork. Defaults to a deep-copy of the + source agent. + title: Optional title for the forked conversation. + tags: Optional tags for the forked conversation. + reset_metrics: If ``True`` (default), cost/token stats start + fresh on the fork. + + Returns: + A new ``LocalConversation`` that shares the same event history + but has its own identity and independent state going forward. + """ + fork_id = conversation_id or uuid.uuid4() + if agent is not None: + fork_agent = agent + else: + # Round-trip via JSON to produce a deep copy that avoids + # thread-lock pickling issues with model_copy(deep=True). + agent_cls = type(self.agent) + fork_agent = agent_cls.model_validate( + self.agent.model_dump(context={"expose_secrets": True}), + ) + + # Hold the state lock while reading mutable state from the source + # conversation to avoid torn reads if run() is executing concurrently. + with self._state: + # Determine persistence_dir for the fork. + # Pass the *base* directory only — __init__ calls + # get_persistence_dir() which appends the conversation ID hex, + # so we must not do that here. + source_persistence = self._state.persistence_dir + fork_persistence: str | None = None + if source_persistence is not None: + source_path = Path(source_persistence) + fork_persistence = str(source_path.parent) + + # Build the fork conversation (empty – no events yet) + fork_conv = LocalConversation( + agent=fork_agent, + workspace=self.workspace, + plugins=self._plugin_specs, + persistence_dir=fork_persistence, + conversation_id=fork_id, + max_iteration_per_run=self.max_iteration_per_run, + stuck_detection=self._stuck_detector is not None, + visualizer=type(self._visualizer) if self._visualizer else None, + delete_on_close=self.delete_on_close, + tags=tags, + ) + + # Deep-copy events from source → fork so the source stays + # immutable. + for event in self._state.events: + fork_conv._state.events.append(event.model_copy(deep=True)) + + # Copy runtime state that accumulated during the source + # conversation. activated_knowledge_skills is list[str] – strings + # are immutable so a shallow list copy is sufficient. + # agent_state can hold arbitrary mutable values, so deep-copy it. + fork_conv._state.activated_knowledge_skills = list( + self._state.activated_knowledge_skills + ) + fork_conv._state.agent_state = copy.deepcopy(self._state.agent_state) + + # Copy title via tags if provided + if title is not None: + fork_conv._state.tags = { + **fork_conv._state.tags, + "title": title, + } + + # Reset or copy metrics + if not reset_metrics: + fork_conv._state.stats = self._state.stats.model_copy(deep=True) + + event_count = len(self._state.events) + + logger.info( + f"Forked conversation {self.id} → {fork_id} " + f"({event_count} events copied, " + f"reset_metrics={reset_metrics})" + ) + return fork_conv + def _ensure_plugins_loaded(self) -> None: """Lazy load plugins and set up hooks on first use. diff --git a/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py b/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py index 2dec73d91d..b881d58849 100644 --- a/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py +++ b/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py @@ -1300,6 +1300,79 @@ def condense(self) -> None: f"{self._conversation_action_base_path}/{self._id}/condense", ) + def fork( + self, + *, + conversation_id: "ConversationID | None" = None, + agent: "AgentBase | None" = None, + title: str | None = None, + tags: dict[str, str] | None = None, + reset_metrics: bool = True, + ) -> "RemoteConversation": + """Fork this conversation on the remote agent server. + + Sends a fork request to the server which deep-copies events and + state. Returns a new ``RemoteConversation`` pointing at the fork. + + Args: + conversation_id: ID for the forked conversation (auto-generated + on the server if ``None``). + agent: **Not supported for remote conversations.** Passing a + non-``None`` value raises ``NotImplementedError``. Use + ``LocalConversation.fork(agent=...)`` for agent replacement. + title: Optional title for the forked conversation. + tags: Optional tags for the forked conversation. + reset_metrics: If ``True`` (default), cost/token stats start + fresh on the fork. + + Returns: + A new ``RemoteConversation`` backed by the forked server-side + conversation. + + Raises: + NotImplementedError: If ``agent`` is provided. + """ + if agent is not None: + raise NotImplementedError( + "Agent replacement is not supported for remote conversation " + "forks. Use LocalConversation.fork(agent=...) instead." + ) + + body: dict[str, object] = {"reset_metrics": reset_metrics} + if conversation_id is not None: + body["id"] = str(conversation_id) + if title is not None: + body["title"] = title + if tags is not None: + body["tags"] = tags + + resp = _send_request( + self._client, + "POST", + f"{self._conversation_action_base_path}/{self._id}/fork", + json=body, + ) + fork_info = resp.json() + fork_uuid = uuid.UUID(fork_info["id"]) + + agent_cls = type(self.agent) + fork_agent = agent_cls.model_validate( + self.agent.model_dump(context={"expose_secrets": True}), + ) + + # Use server-returned tags (which include merged title) rather than + # the input tags, so the client-side object stays consistent. + server_tags: dict[str, str] | None = fork_info.get("tags") or None + + return RemoteConversation( + agent=fork_agent, + workspace=self.workspace, + conversation_id=fork_uuid, + max_iteration_per_run=self.max_iteration_per_run, + delete_on_close=self.delete_on_close, + tags=server_tags, + ) + def execute_tool(self, tool_name: str, action: "Action") -> "Observation": """Execute a tool directly without going through the agent loop. diff --git a/tests/agent_server/test_conversation_router.py b/tests/agent_server/test_conversation_router.py index fe86ee8704..50e2c57c90 100644 --- a/tests/agent_server/test_conversation_router.py +++ b/tests/agent_server/test_conversation_router.py @@ -1729,3 +1729,71 @@ def test_switch_conversation_profile_corrupted_profile( mock_conversation.switch_profile.assert_called_once_with("corrupted") finally: client.app.dependency_overrides.clear() + + +def test_fork_conversation_success( + client, mock_conversation_service, sample_conversation_info, sample_conversation_id +): + """Test fork endpoint returns 201 with forked conversation info.""" + mock_conversation_service.fork_conversation.return_value = sample_conversation_info + + client.app.dependency_overrides[get_conversation_service] = ( + lambda: mock_conversation_service + ) + + try: + response = client.post( + f"/api/conversations/{sample_conversation_id}/fork", + json={"title": "Forked", "reset_metrics": True}, + ) + + assert response.status_code == 201 + data = response.json() + assert data["id"] == str(sample_conversation_info.id) + mock_conversation_service.fork_conversation.assert_called_once() + finally: + client.app.dependency_overrides.clear() + + +def test_fork_conversation_not_found( + client, mock_conversation_service, sample_conversation_id +): + """Test fork returns 404 when source conversation doesn't exist.""" + mock_conversation_service.fork_conversation.return_value = None + + client.app.dependency_overrides[get_conversation_service] = ( + lambda: mock_conversation_service + ) + + try: + response = client.post( + f"/api/conversations/{sample_conversation_id}/fork", + json={}, + ) + + assert response.status_code == 404 + finally: + client.app.dependency_overrides.clear() + + +def test_fork_conversation_duplicate_id_returns_409( + client, mock_conversation_service, sample_conversation_id +): + """Test fork returns 409 when the requested fork ID already exists.""" + mock_conversation_service.fork_conversation.side_effect = ValueError( + f"Conversation with id {sample_conversation_id} already exists" + ) + + client.app.dependency_overrides[get_conversation_service] = ( + lambda: mock_conversation_service + ) + + try: + response = client.post( + f"/api/conversations/{sample_conversation_id}/fork", + json={"id": str(sample_conversation_id)}, + ) + + assert response.status_code == 409 + finally: + client.app.dependency_overrides.clear() diff --git a/tests/sdk/conversation/local/test_fork.py b/tests/sdk/conversation/local/test_fork.py new file mode 100644 index 0000000000..34d379ae5a --- /dev/null +++ b/tests/sdk/conversation/local/test_fork.py @@ -0,0 +1,252 @@ +"""Tests for Conversation.fork() primitive.""" + +import tempfile +import uuid +from pathlib import Path + +import pytest +from pydantic import SecretStr + +from openhands.sdk.agent import Agent +from openhands.sdk.conversation import Conversation +from openhands.sdk.conversation.state import ConversationExecutionStatus +from openhands.sdk.event.llm_convertible import MessageEvent +from openhands.sdk.llm import LLM, Message, TextContent + + +def _agent() -> Agent: + return Agent( + llm=LLM(model="gpt-4o-mini", api_key=SecretStr("test-key"), usage_id="test"), + tools=[], + ) + + +def _msg(event_id: str, text: str = "hi") -> MessageEvent: + return MessageEvent( + id=event_id, + llm_message=Message(role="user", content=[TextContent(text=text)]), + source="user", + ) + + +def test_fork_creates_new_id(): + """Forked conversation must have a distinct ID.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = Conversation(agent=_agent(), persistence_dir=tmpdir, workspace=tmpdir) + fork = src.fork() + + assert fork.id != src.id + assert isinstance(fork.id, uuid.UUID) + + +def test_fork_with_explicit_id(): + """Explicit conversation_id is honoured.""" + custom_id = uuid.uuid4() + with tempfile.TemporaryDirectory() as tmpdir: + src = Conversation(agent=_agent(), persistence_dir=tmpdir, workspace=tmpdir) + fork = src.fork(conversation_id=custom_id) + + assert fork.id == custom_id + + +def test_fork_copies_events(): + """Events from the source must appear in the fork.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = Conversation(agent=_agent(), persistence_dir=tmpdir, workspace=tmpdir) + src.state.events.append(_msg("evt-1", "hello")) + src.state.events.append(_msg("evt-2", "world")) + + fork = src.fork() + + # The fork should have at least the events we added + fork_ids = [e.id for e in fork.state.events] + assert "evt-1" in fork_ids + assert "evt-2" in fork_ids + + +def test_fork_source_unmodified(): + """Appending to the fork must not affect the source.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = Conversation(agent=_agent(), persistence_dir=tmpdir, workspace=tmpdir) + src.state.events.append(_msg("src-evt")) + src_event_count = len(src.state.events) + + fork = src.fork() + fork.state.events.append(_msg("fork-only")) + + # Source should not grow + assert len(src.state.events) == src_event_count + + +def test_fork_execution_status_is_idle(): + """Forked conversation starts in idle status.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = Conversation(agent=_agent(), persistence_dir=tmpdir, workspace=tmpdir) + fork = src.fork() + + assert fork.state.execution_status == ConversationExecutionStatus.IDLE + + +def test_fork_resets_metrics_by_default(): + """By default, metrics on the fork should be fresh (empty).""" + with tempfile.TemporaryDirectory() as tmpdir: + src = Conversation(agent=_agent(), persistence_dir=tmpdir, workspace=tmpdir) + fork = src.fork() + + combined = fork.state.stats.get_combined_metrics() + assert combined.accumulated_cost == 0 + + +def test_fork_preserves_metrics_when_requested(): + """When reset_metrics=False the fork should carry over stats.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = Conversation(agent=_agent(), persistence_dir=tmpdir, workspace=tmpdir) + # Inject a non-zero metric + from openhands.sdk.llm.utils.metrics import Metrics + + m = Metrics() + m.accumulated_cost = 1.5 + src._state.stats.usage_to_metrics["test"] = m + + fork = src.fork(reset_metrics=False) + + combined = fork.state.stats.get_combined_metrics() + assert combined.accumulated_cost == pytest.approx(1.5) + + +def test_fork_copies_agent_state(): + """agent_state dict should be carried over to the fork.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = Conversation(agent=_agent(), persistence_dir=tmpdir, workspace=tmpdir) + src._state.agent_state = {"key": "value"} + + fork = src.fork() + + assert fork.state.agent_state == {"key": "value"} + # Mutation on fork should not affect source + fork._state.agent_state = {**fork._state.agent_state, "new": True} + assert "new" not in src._state.agent_state + + +def test_fork_accepts_replacement_agent(): + """Providing an agent kwarg replaces the source agent in the fork.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = Conversation(agent=_agent(), persistence_dir=tmpdir, workspace=tmpdir) + alt_agent = Agent( + llm=LLM( + model="gpt-4o", + api_key=SecretStr("other-key"), + usage_id="alt", + ), + tools=[], + ) + + fork = src.fork(agent=alt_agent) + + assert fork.agent.llm.model == "gpt-4o" + # Source should keep its original agent + assert src.agent.llm.model == "gpt-4o-mini" + + +def test_fork_with_tags(): + """Tags should be passed through to the fork.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = Conversation(agent=_agent(), persistence_dir=tmpdir, workspace=tmpdir) + fork = src.fork(tags={"env": "test"}) + + assert fork.state.tags.get("env") == "test" + + +def test_fork_with_title_sets_tag(): + """Title is stored as a 'title' tag.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = Conversation(agent=_agent(), persistence_dir=tmpdir, workspace=tmpdir) + fork = src.fork(title="My Fork") + + assert fork.state.tags.get("title") == "My Fork" + + +def test_fork_shares_workspace(): + """Fork should reuse the same workspace as the source.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = Conversation(agent=_agent(), persistence_dir=tmpdir, workspace=tmpdir) + fork = src.fork() + + assert fork.workspace.working_dir == src.workspace.working_dir + + +def test_fork_event_deep_copy_isolation(): + """Mutating an event object in the fork must not affect the source.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = Conversation(agent=_agent(), persistence_dir=tmpdir, workspace=tmpdir) + src.state.events.append(_msg("deep-evt", "original")) + + fork = src.fork() + + # The fork event is a different object + src_evt = src.state.events[0] + fork_evt = fork.state.events[0] + assert src_evt is not fork_evt + + # Mutating the fork event should not change the source + assert fork_evt.llm_message.content[0].text == "original" # type: ignore[union-attr] + fork_evt.llm_message.content[0].text = "mutated" # type: ignore[union-attr] + assert src_evt.llm_message.content[0].text == "original" # type: ignore[union-attr] + + +def test_fork_persistence_path_no_doubling(): + """Fork persistence dir must be a sibling of source, not nested inside it. + + Regression test: fork() previously computed the persistence path with + the conversation hex appended, but __init__ also appends it via + get_persistence_dir(), leading to /base/FORK_HEX/FORK_HEX. + """ + with tempfile.TemporaryDirectory() as tmpdir: + src = Conversation(agent=_agent(), persistence_dir=tmpdir, workspace=tmpdir) + fork = src.fork() + + assert src._state.persistence_dir is not None + assert fork._state.persistence_dir is not None + src_path = Path(src._state.persistence_dir) + fork_path = Path(fork._state.persistence_dir) + + # Both should live directly under the same base directory + assert src_path.parent == fork_path.parent + # The fork dir should be /, not doubled + assert fork_path.name == fork.id.hex + + +def test_fork_persisted_events_survive_reload(): + """Events persisted by fork() should be loadable from the fork dir. + + This validates the path-doubling fix end-to-end: if the fork wrote + events to the wrong directory, resuming from the correct path would + see zero events. + """ + # Event IDs must be hex+dash, ≥8 chars to match EVENT_NAME_RE. + evt_id_1 = uuid.uuid4().hex + evt_id_2 = uuid.uuid4().hex + + with tempfile.TemporaryDirectory() as tmpdir: + src = Conversation(agent=_agent(), persistence_dir=tmpdir, workspace=tmpdir) + src.state.events.append(_msg(evt_id_1, "hello")) + src.state.events.append(_msg(evt_id_2, "world")) + + fork = src.fork() + fork_id = fork.id + + # The fork should have the events in-memory + assert len(fork.state.events) == 2 + + # Close the fork to flush persistence, then reopen from disk + fork.close() + + resumed = Conversation( + agent=_agent(), + persistence_dir=tmpdir, + workspace=tmpdir, + conversation_id=fork_id, + ) + resumed_ids = [e.id for e in resumed.state.events] + assert evt_id_1 in resumed_ids + assert evt_id_2 in resumed_ids diff --git a/tests/sdk/conversation/remote/test_remote_fork.py b/tests/sdk/conversation/remote/test_remote_fork.py new file mode 100644 index 0000000000..3d91252fe2 --- /dev/null +++ b/tests/sdk/conversation/remote/test_remote_fork.py @@ -0,0 +1,166 @@ +"""Tests for RemoteConversation.fork().""" + +import uuid +from unittest.mock import Mock, patch + +import pytest +from pydantic import SecretStr + +from openhands.sdk.agent import Agent +from openhands.sdk.conversation.impl.remote_conversation import RemoteConversation +from openhands.sdk.llm import LLM +from openhands.sdk.workspace import RemoteWorkspace + + +def _agent() -> Agent: + return Agent( + llm=LLM(model="gpt-4o-mini", api_key=SecretStr("test-key"), usage_id="test"), + tools=[], + ) + + +def _setup_workspace_with_mock_client( + host: str = "http://localhost:8000", + conversation_id: str | None = None, + fork_id: str | None = None, + fork_tags: dict[str, str] | None = None, +) -> tuple[RemoteWorkspace, Mock]: + """Set up workspace with a mock client that handles create + fork.""" + workspace = RemoteWorkspace(host=host, working_dir="/tmp") + mock_client = Mock() + workspace._client = mock_client + + if conversation_id is None: + conversation_id = str(uuid.uuid4()) + if fork_id is None: + fork_id = str(uuid.uuid4()) + + def request_side_effect(method: str, url: str, **kwargs: object) -> Mock: + response = Mock() + response.status_code = 200 + response.raise_for_status.return_value = None + + if method == "POST" and url == "/api/conversations": + response.json.return_value = { + "id": conversation_id, + "conversation_id": conversation_id, + } + elif method == "POST" and url.endswith("/fork"): + response.status_code = 201 + fork_response: dict[str, object] = { + "id": fork_id, + "conversation_id": fork_id, + "tags": fork_tags or {}, + } + response.json.return_value = fork_response + elif method == "GET" and "/events" in url: + response.json.return_value = {"items": [], "next_page_id": None} + else: + response.json.return_value = {} + + return response + + mock_client.request.side_effect = request_side_effect + return workspace, mock_client + + +@patch("openhands.sdk.conversation.impl.remote_conversation.WebSocketCallbackClient") +def test_remote_fork_sends_post_request(mock_ws_cls: Mock) -> None: + """fork() must POST to /{id}/fork.""" + mock_ws_cls.return_value = Mock() + fork_uuid = str(uuid.uuid4()) + workspace, mock_client = _setup_workspace_with_mock_client( + fork_id=fork_uuid, + ) + + conv = RemoteConversation(agent=_agent(), workspace=workspace) + fork = conv.fork() + + assert fork.id == uuid.UUID(fork_uuid) + + # Verify a POST …/fork call was made + fork_calls = [ + c + for c in mock_client.request.call_args_list + if c[0][0] == "POST" and str(c[0][1]).endswith("/fork") + ] + assert len(fork_calls) == 1 + + +@patch("openhands.sdk.conversation.impl.remote_conversation.WebSocketCallbackClient") +def test_remote_fork_uses_server_returned_tags(mock_ws_cls: Mock) -> None: + """The forked RemoteConversation constructor must receive tags from the + server response (which merges title), not the raw input kwargs. + + We verify by monkeypatching RemoteConversation to capture the tags kwarg + that the fork method passes to the constructor. + """ + mock_ws_cls.return_value = Mock() + server_tags = {"env": "test", "title": "My Fork"} + workspace, _ = _setup_workspace_with_mock_client(fork_tags=server_tags) + + conv = RemoteConversation(agent=_agent(), workspace=workspace) + + # Capture the kwargs passed to the fork's RemoteConversation() + captured_kwargs: dict[str, object] = {} + _orig_cls = RemoteConversation + + class _Capture(_orig_cls): + def __init__(self, **kwargs: object) -> None: # type: ignore[override] + captured_kwargs.update(kwargs) + super().__init__(**kwargs) # type: ignore[arg-type] + + # Temporarily replace the class reference used by the fork method. + import openhands.sdk.conversation.impl.remote_conversation as _mod + + _mod.RemoteConversation = _Capture # type: ignore[misc] + try: + conv.fork(title="My Fork", tags={"env": "test"}) + finally: + _mod.RemoteConversation = _orig_cls # type: ignore[misc] + + assert captured_kwargs.get("tags") == server_tags + + +@patch("openhands.sdk.conversation.impl.remote_conversation.WebSocketCallbackClient") +def test_remote_fork_raises_on_agent_param(mock_ws_cls: Mock) -> None: + """Passing agent= must raise NotImplementedError for remote forks.""" + mock_ws_cls.return_value = Mock() + workspace, _ = _setup_workspace_with_mock_client() + + conv = RemoteConversation(agent=_agent(), workspace=workspace) + + with pytest.raises(NotImplementedError, match="not supported"): + conv.fork(agent=_agent()) + + +@patch("openhands.sdk.conversation.impl.remote_conversation.WebSocketCallbackClient") +def test_remote_fork_passes_body_fields(mock_ws_cls: Mock) -> None: + """Verify conversation_id, title, tags, reset_metrics are sent in body.""" + mock_ws_cls.return_value = Mock() + custom_id = uuid.uuid4() + workspace, mock_client = _setup_workspace_with_mock_client( + fork_id=str(custom_id), + fork_tags={"env": "prod"}, + ) + + conv = RemoteConversation(agent=_agent(), workspace=workspace) + conv.fork( + conversation_id=custom_id, + title="Test Fork", + tags={"env": "prod"}, + reset_metrics=False, + ) + + fork_calls = [ + c + for c in mock_client.request.call_args_list + if c[0][0] == "POST" and str(c[0][1]).endswith("/fork") + ] + assert len(fork_calls) == 1 + + body = fork_calls[0][1].get("json", {}) + assert body["id"] == str(custom_id) + assert body["title"] == "Test Fork" + assert body["tags"] == {"env": "prod"} + assert body["reset_metrics"] is False diff --git a/tests/sdk/conversation/test_base_span_management.py b/tests/sdk/conversation/test_base_span_management.py index 01c1ddb536..f54ee101b3 100644 --- a/tests/sdk/conversation/test_base_span_management.py +++ b/tests/sdk/conversation/test_base_span_management.py @@ -68,6 +68,10 @@ def execute_tool(self, tool_name: str, action: Action) -> Observation: """Mock implementation of execute_tool method.""" raise NotImplementedError("Mock execute_tool not implemented") + def fork(self, **kwargs: Any) -> "MockConversation": + """Mock implementation of fork method.""" + raise NotImplementedError("Mock fork not implemented") + def test_base_conversation_span_management(): """Test that BaseConversation properly manages span state to prevent double-ending.""" # noqa: E501