diff --git a/operator_use/config/__init__.py b/operator_use/config/__init__.py index f9ce15e..583459d 100644 --- a/operator_use/config/__init__.py +++ b/operator_use/config/__init__.py @@ -21,6 +21,7 @@ ACPAgentEntry, ACPServerSettings, HeartbeatConfig, + SessionConfig, ToolsConfig, RetryConfig, SubagentConfig, @@ -50,6 +51,7 @@ "ACPAgentEntry", "ACPServerSettings", "HeartbeatConfig", + "SessionConfig", "ToolsConfig", "RetryConfig", "SubagentConfig", diff --git a/operator_use/config/service.py b/operator_use/config/service.py index 0602aae..a35bdce 100644 --- a/operator_use/config/service.py +++ b/operator_use/config/service.py @@ -287,6 +287,12 @@ class HeartbeatConfig(Base): llm_config: Optional[LLMConfig] = None # Dedicated LLM for heartbeat tasks +class SessionConfig(Base): + """Session lifecycle configuration.""" + + ttl_hours: float = 24.0 # Session idle timeout in hours (default: 24h) + + class Config(BaseSettings): """Root configuration for Operator.""" @@ -298,6 +304,7 @@ class Config(BaseSettings): search: SearchConfig = Field(default_factory=SearchConfig) providers: ProvidersConfig = Field(default_factory=ProvidersConfig) heartbeat: HeartbeatConfig = Field(default_factory=HeartbeatConfig) + session: SessionConfig = Field(default_factory=SessionConfig) # Named registry of pre-approved remote ACP agents. # The LLM can only call agents listed here — it never supplies raw URLs. 
acp_agents: Dict[str, ACPAgentEntry] = Field(default_factory=dict) diff --git a/operator_use/session/service.py b/operator_use/session/service.py index 3a5f8ff..f3842f4 100644 --- a/operator_use/session/service.py +++ b/operator_use/session/service.py @@ -4,20 +4,30 @@ import uuid from datetime import datetime from pathlib import Path -from typing import Any +from typing import Any, Optional from operator_use.messages.service import BaseMessage from operator_use.utils.helper import ensure_directory -from operator_use.session.views import Session +from operator_use.session.views import Session, DEFAULT_SESSION_TTL class SessionStore: - """Store for sessions, keyed by session id. Persists to JSONL files.""" + """Store for sessions, keyed by session id. Persists to JSONL files. - def __init__(self, workspace: Path): + When *encryption_key* is provided (a URL-safe base-64 Fernet key), session + files are written as a single encrypted blob instead of plain JSONL lines. + The key can be generated with ``cryptography.fernet.Fernet.generate_key()``. + """ + + def __init__(self, workspace: Path, encryption_key: Optional[str] = None): self.workspace = Path(workspace) self.sessions_dir = ensure_directory(self.workspace / "sessions") self._sessions: dict[str, Session] = {} + self._fernet = None + if encryption_key: + from cryptography.fernet import Fernet + key_bytes = encryption_key.encode() if isinstance(encryption_key, str) else encryption_key + self._fernet = Fernet(key_bytes) def _session_id_to_filename(self, session_id: str) -> str: """Make session_id filesystem-safe (e.g. 
`:` invalid on Windows).""" @@ -26,10 +36,20 @@ def _session_id_to_filename(self, session_id: str) -> str: def _sessions_path(self, session_id: str) -> Path: return self.sessions_dir / f"{self._session_id_to_filename(session_id)}.jsonl" - def load(self, session_id: str) -> Session | None: + def load(self, session_id: str, ttl: float = DEFAULT_SESSION_TTL) -> Session | None: path = self._sessions_path(session_id) if not path.exists(): return None + + if self._fernet: + return self._load_encrypted(session_id, path, ttl) + + raw = path.read_bytes() + if raw.startswith(b"gAAAAA") and self._fernet is None: + raise ValueError( + f"Session file for '{session_id}' is Fernet-encrypted but no encryption_key was provided." + ) + messages: list[BaseMessage] = [] created_at = datetime.now() updated_at = datetime.now() @@ -49,16 +69,53 @@ def load(self, session_id: str) -> Session | None: continue if "role" in obj: messages.append(BaseMessage.from_dict(obj)) - return Session( + + return Session._from_persisted( id=session_id, messages=messages, created_at=created_at, updated_at=updated_at, metadata=metadata, + ttl=ttl, + ) + + def _load_encrypted(self, session_id: str, path: Path, ttl: float) -> Session | None: + """Load and decrypt a session file written by _save_encrypted().""" + from cryptography.fernet import InvalidToken + + if self._fernet is None: + raise ValueError( + f"Session {session_id!r} appears to be encrypted but no encryption_key was provided." + ) + raw = path.read_bytes() + try: + decrypted = self._fernet.decrypt(raw) + except InvalidToken as exc: + raise ValueError( + f"Failed to decrypt session '{session_id}': wrong key or corrupted data." 
+ ) from exc + + payload = json.loads(decrypted.decode()) + created_at = datetime.fromisoformat(payload.get("created_at", datetime.now().isoformat())) + updated_at = datetime.fromisoformat(payload.get("updated_at", datetime.now().isoformat())) + metadata = payload.get("metadata", {}) + messages = [BaseMessage.from_dict(m) for m in payload.get("messages", [])] + return Session._from_persisted( + id=session_id, + messages=messages, + created_at=created_at, + updated_at=updated_at, + metadata=metadata, + ttl=ttl, ) def save(self, session: Session) -> None: path = self._sessions_path(session.id) + + if self._fernet: + self._save_encrypted(session, path) + return + with open(path, "w", encoding="utf-8") as f: meta = { "type": "metadata", @@ -71,15 +128,45 @@ def save(self, session: Session) -> None: for msg in session.messages: f.write(json.dumps(msg.to_dict()) + "\n") - def get_or_create(self, session_id: str | None = None) -> Session: - """Get a session by id, or create and store a new one. Loads from JSONL if exists.""" + def _save_encrypted(self, session: Session, path: Path) -> None: + """Serialize the session to JSON and write as a Fernet-encrypted blob.""" + payload = { + "id": session.id, + "created_at": session.created_at.isoformat(), + "updated_at": session.updated_at.isoformat(), + "metadata": session.metadata, + "messages": [msg.to_dict() for msg in session.messages], + } + token = self._fernet.encrypt(json.dumps(payload).encode()) + path.write_bytes(token) + + def get_or_create( + self, + session_id: Optional[str] = None, + ttl: float = DEFAULT_SESSION_TTL, + ) -> Session: + """Get a session by id, or create and store a new one. + + Loads from JSONL if exists. If the loaded session is expired (based on + real idle time derived from *updated_at*), it is deleted and a fresh + session is returned instead. 
+ """ id = session_id or str(uuid.uuid4()) - if session := self._sessions.get(id): - return session - if session := self.load(id): - self._sessions[id] = session - return session - session = Session(id=id) + + if cached := self._sessions.get(id): + if not cached.is_expired(): + return cached + # In-memory session has expired — evict and fall through to create + del self._sessions[id] + + if session := self.load(id, ttl=ttl): + if session.is_expired(): + self.delete(id) + else: + self._sessions[id] = session + return session + + session = Session(id=id, ttl=ttl) self._sessions[id] = session return session @@ -108,6 +195,38 @@ def archive(self, session_id: str) -> bool: return True return False + def cleanup(self, ttl: float = DEFAULT_SESSION_TTL) -> list[str]: + """Delete all sessions whose idle time (since *updated_at*) exceeds *ttl*. + + Returns the list of session IDs that were removed. + Archived session files are skipped. + """ + # Build a reverse map: filesystem stem -> original session_id (in-memory key). + # Sessions with `:` in their IDs are stored under the original ID in + # self._sessions but their filename stem uses `_` as a replacement. + stem_to_original: dict[str, str] = { + self._session_id_to_filename(sid): sid for sid in self._sessions + } + + removed: list[str] = [] + for path in self.sessions_dir.glob("*.jsonl"): + # Skip archived sessions + if "_archived_" in path.stem: + continue + session_id_fs = path.stem + session = self.load(session_id_fs, ttl=ttl) + if session is None: + continue + if session.is_expired(): + path.unlink() + # Evict from in-memory cache using the original session ID if known, + # otherwise fall back to the filesystem-safe stem. + original_id = stem_to_original.get(session_id_fs, session_id_fs) + if original_id in self._sessions: + del self._sessions[original_id] + removed.append(original_id) + return removed + def list_sessions(self) -> list[dict[str, Any]]: """Load sessions from the sessions directory. 
Returns list of dicts with id, created_at, updated_at, path.""" result: list[dict[str, Any]] = [] diff --git a/operator_use/session/views.py b/operator_use/session/views.py index d52097b..d2f39e4 100644 --- a/operator_use/session/views.py +++ b/operator_use/session/views.py @@ -1,11 +1,17 @@ """Session views.""" +import time from dataclasses import dataclass, field from datetime import datetime -from typing import Any +from typing import TYPE_CHECKING, Any from operator_use.messages.service import BaseMessage +if TYPE_CHECKING: + from operator_use.config.service import Config + +DEFAULT_SESSION_TTL = 86400.0 # 24 hours (config-driven default) + @dataclass class Session: @@ -16,17 +22,65 @@ class Session: created_at: datetime = field(default_factory=datetime.now) updated_at: datetime = field(default_factory=datetime.now) metadata: dict[str, Any] = field(default_factory=dict) + ttl: float = DEFAULT_SESSION_TTL + # _last_activity is set in __post_init__ so that tests can monkeypatch + # time.monotonic before instantiation and get a consistent starting value. 
+ _last_activity: float = field(init=False, default=0.0) + + def __post_init__(self) -> None: + self._last_activity = time.monotonic() def add_message(self, message: BaseMessage) -> None: """Add a message and update updated_at.""" self.messages.append(message) self.updated_at = datetime.now() + self.touch() def get_history(self) -> list[BaseMessage]: """Return the message history.""" return list(self.messages) def clear(self) -> None: - """Clear all messages.""" + """Clear all messages and refresh the TTL window.""" self.messages.clear() self.updated_at = datetime.now() + self.touch() + + def touch(self) -> None: + """Refresh last_activity timestamp, extending the session TTL window.""" + self._last_activity = time.monotonic() + + def is_expired(self) -> bool: + """Return True if idle time since last activity exceeds the TTL.""" + return (time.monotonic() - self._last_activity) > self.ttl + + @classmethod + def from_config(cls, id: str, config: "Config") -> "Session": + """Construct a Session using TTL from config.session.ttl_hours.""" + ttl = config.session.ttl_hours * 3600 + return cls(id=id, ttl=ttl) + + @classmethod + def _from_persisted( + cls, + id: str, + messages: list[BaseMessage], + created_at: datetime, + updated_at: datetime, + metadata: dict[str, Any], + ttl: float = DEFAULT_SESSION_TTL, + ) -> "Session": + """Reconstruct a Session from disk, anchoring _last_activity to the + real idle time derived from updated_at so that loaded sessions expire + correctly rather than resetting to 'now'.""" + session = cls( + id=id, + messages=messages, + created_at=created_at, + updated_at=updated_at, + metadata=metadata, + ttl=ttl, + ) + idle_seconds = max(0.0, (datetime.now() - updated_at).total_seconds()) + session._last_activity = time.monotonic() - idle_seconds + return session diff --git a/operator_use/tools/control_center.py b/operator_use/tools/control_center.py index 537ed69..f918f51 100644 --- a/operator_use/tools/control_center.py +++ 
b/operator_use/tools/control_center.py @@ -12,6 +12,7 @@ import json import logging import os +import subprocess import sys from typing import Optional @@ -125,7 +126,8 @@ async def _do_restart(graceful_fn=None) -> None: ``os._exit(75)`` which skips cleanup but guarantees the process terminates. """ global _requested_exit_code - os.system("cls" if os.name == "nt" else "clear") + # `cls` is a cmd.exe builtin, not an executable: a bare ["cls"] raises FileNotFoundError on Windows, so launch it via cmd /c + subprocess.run(["cmd", "/c", "cls"] if os.name == "nt" else ["clear"], check=False) frames = ["↑", "↗", "→", "↘", "↓", "↙", "←", "↖"] for i in range(20): sys.stdout.write(f"\r {frames[i % len(frames)]} Restarting Operator...") diff --git a/pyproject.toml b/pyproject.toml index 111b2a5..2c1c103 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "platformdirs>=4.0.0", "psutil>=7.0.0", "pynacl>=1.6.2", + "cryptography>=41.0", "comtypes>=1.4.15; sys_platform == 'win32'", "pywin32>=311; sys_platform == 'win32'", "pyobjc-framework-Cocoa>=10.0; sys_platform == 'darwin'", diff --git a/tests/test_agent.py b/tests/test_agent.py index 4fb6c3f..13db174 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -186,7 +186,7 @@ async def test_agent_run_with_tool_call_then_text(tmp_path): # Register a simple echo tool from pydantic import BaseModel - from operator_use.tools.service import Tool + from operator_use.agent.tools.service import Tool class EchoParams(BaseModel): message: str diff --git a/tests/test_browser_plugin.py b/tests/test_browser_plugin.py index 9089d95..9f062ed 100644 --- a/tests/test_browser_plugin.py +++ b/tests/test_browser_plugin.py @@ -24,9 +24,9 @@ def test_enabled_plugin_returns_system_prompt(): prompt = plugin.get_system_prompt() assert prompt is SYSTEM_PROMPT assert "browser" in prompt.lower() - assert "" in prompt - assert "" in prompt - assert "" in prompt + # Prompt is plain Markdown — assert on actual content, not old XML tags + assert "browser_task" in prompt + assert "Chrome" in prompt # 
--------------------------------------------------------------------------- @@ -38,7 +38,7 @@ def test_disabled_plugin_registers_no_tools(): plugin = BrowserPlugin(enabled=False) registry = ToolRegistry() plugin.register_tools(registry) - assert registry.get("browser") is None + assert registry.get("browser_task") is None def test_enabled_plugin_registers_browser_tool(): @@ -47,7 +47,7 @@ def test_enabled_plugin_registers_browser_tool(): plugin.browser = MagicMock() registry = ToolRegistry() plugin.register_tools(registry) - assert registry.get("browser") is not None + assert registry.get("browser_task") is not None def test_unregister_tools_removes_browser_tool(): @@ -57,11 +57,11 @@ def test_unregister_tools_removes_browser_tool(): registry = ToolRegistry() plugin.register_tools(registry) plugin.unregister_tools(registry) - assert registry.get("browser") is None + assert registry.get("browser_task") is None # --------------------------------------------------------------------------- -# register_hooks — BEFORE_LLM_CALL gated on _enabled +# register_hooks — hooks NOT registered to main agent (subagent arch) # --------------------------------------------------------------------------- @@ -72,20 +72,20 @@ def test_disabled_plugin_registers_no_hooks(): assert plugin._state_hook not in hooks._handlers[HookEvent.BEFORE_LLM_CALL] -def test_enabled_plugin_registers_state_hook(): +def test_enabled_plugin_does_not_register_state_hook_to_main_agent(): + """Hooks are intentionally not wired to main agent — subagent manages its own state.""" plugin = BrowserPlugin(enabled=False) plugin._enabled = True hooks = Hooks() plugin.register_hooks(hooks) - assert plugin._state_hook in hooks._handlers[HookEvent.BEFORE_LLM_CALL] + assert plugin._state_hook not in hooks._handlers[HookEvent.BEFORE_LLM_CALL] -def test_unregister_hooks_removes_state_hook(): +def test_unregister_hooks_is_safe_noop(): plugin = BrowserPlugin(enabled=False) - plugin._enabled = True hooks = Hooks() 
plugin.register_hooks(hooks) - plugin.unregister_hooks(hooks) + plugin.unregister_hooks(hooks) # must not raise assert plugin._state_hook not in hooks._handlers[HookEvent.BEFORE_LLM_CALL] @@ -99,7 +99,7 @@ def test_disabled_plugin_does_not_inject_prompt(): context = MagicMock() plugin.attach_prompt(context) context.register_plugin_prompt.assert_not_called() - assert plugin._context is context # reference still stored + assert plugin._context is context def test_enabled_plugin_injects_prompt(): @@ -125,7 +125,8 @@ def test_detach_prompt_removes_injected_prompt(): @pytest.mark.asyncio -async def test_enable_registers_hooks_and_injects_prompt(): +async def test_enable_injects_prompt_no_hooks(): + """enable() registers tools and injects prompt — hooks NOT wired to main agent.""" plugin = BrowserPlugin(enabled=False) hooks = Hooks() plugin.register_hooks(hooks) @@ -135,12 +136,12 @@ async def test_enable_registers_hooks_and_injects_prompt(): await plugin.enable() assert plugin._enabled is True - assert plugin._state_hook in hooks._handlers[HookEvent.BEFORE_LLM_CALL] + assert plugin._state_hook not in hooks._handlers[HookEvent.BEFORE_LLM_CALL] context.register_plugin_prompt.assert_called_once_with(SYSTEM_PROMPT) @pytest.mark.asyncio -async def test_disable_unregisters_hooks_and_removes_prompt(): +async def test_disable_removes_prompt(): plugin = BrowserPlugin(enabled=False) plugin._enabled = True hooks = Hooks() @@ -178,7 +179,7 @@ async def test_enable_then_disable_leaves_no_hooks(): async def test_state_hook_skips_when_no_browser_client(): plugin = BrowserPlugin(enabled=False) plugin.browser = MagicMock() - plugin.browser._client = None # no active session + plugin.browser._client = None ctx = MagicMock() ctx.messages = [] diff --git a/tests/test_computer_plugin.py b/tests/test_computer_plugin.py index 47e4480..4fea153 100644 --- a/tests/test_computer_plugin.py +++ b/tests/test_computer_plugin.py @@ -28,13 +28,12 @@ def test_enabled_plugin_returns_system_prompt(): 
prompt = plugin.get_system_prompt() assert prompt is SYSTEM_PROMPT assert "desktop" in prompt.lower() - assert "" in prompt - assert "" in prompt - assert "" in prompt + # Prompt is plain Markdown — assert on actual content, not old XML tags + assert "computer_task" in prompt # --------------------------------------------------------------------------- -# register_hooks — BEFORE_LLM_CALL + AFTER_TOOL_CALL, gated on _enabled +# register_hooks — hooks NOT registered to main agent (subagent arch) # --------------------------------------------------------------------------- @@ -46,21 +45,21 @@ def test_disabled_plugin_registers_no_hooks(): assert plugin._wait_for_ui_hook not in hooks._handlers[HookEvent.AFTER_TOOL_CALL] -def test_enabled_plugin_registers_both_hooks(): +def test_enabled_plugin_does_not_register_hooks_to_main_agent(): + """Hooks are intentionally not wired to main agent — subagent manages its own state.""" plugin = ComputerPlugin(enabled=False) plugin._enabled = True hooks = Hooks() plugin.register_hooks(hooks) - assert plugin._state_hook in hooks._handlers[HookEvent.BEFORE_LLM_CALL] - assert plugin._wait_for_ui_hook in hooks._handlers[HookEvent.AFTER_TOOL_CALL] + assert plugin._state_hook not in hooks._handlers[HookEvent.BEFORE_LLM_CALL] + assert plugin._wait_for_ui_hook not in hooks._handlers[HookEvent.AFTER_TOOL_CALL] -def test_unregister_hooks_removes_both(): +def test_unregister_hooks_is_safe_noop(): plugin = ComputerPlugin(enabled=False) - plugin._enabled = True hooks = Hooks() plugin.register_hooks(hooks) - plugin.unregister_hooks(hooks) + plugin.unregister_hooks(hooks) # must not raise assert plugin._state_hook not in hooks._handlers[HookEvent.BEFORE_LLM_CALL] assert plugin._wait_for_ui_hook not in hooks._handlers[HookEvent.AFTER_TOOL_CALL] @@ -101,7 +100,8 @@ def test_detach_prompt_removes_injected_prompt(): @pytest.mark.asyncio -async def test_enable_registers_both_hooks_and_prompt(): +async def test_enable_injects_prompt_no_hooks(): + 
"""enable() registers tools and injects prompt — hooks NOT wired to main agent.""" plugin = ComputerPlugin(enabled=False) hooks = Hooks() plugin.register_hooks(hooks) @@ -111,13 +111,13 @@ async def test_enable_registers_both_hooks_and_prompt(): await plugin.enable() assert plugin._enabled is True - assert plugin._state_hook in hooks._handlers[HookEvent.BEFORE_LLM_CALL] - assert plugin._wait_for_ui_hook in hooks._handlers[HookEvent.AFTER_TOOL_CALL] + assert plugin._state_hook not in hooks._handlers[HookEvent.BEFORE_LLM_CALL] + assert plugin._wait_for_ui_hook not in hooks._handlers[HookEvent.AFTER_TOOL_CALL] context.register_plugin_prompt.assert_called_once_with(SYSTEM_PROMPT) @pytest.mark.asyncio -async def test_disable_unregisters_both_hooks_and_removes_prompt(): +async def test_disable_removes_prompt(): plugin = ComputerPlugin(enabled=False) plugin._enabled = True hooks = Hooks() @@ -160,27 +160,12 @@ async def test_state_hook_appends_desktop_state(): mock_state.to_string.return_value = "Active: Notepad | Elements: [button 'Save']" plugin.desktop = MagicMock() - import asyncio - - loop = asyncio.get_event_loop() - - async def _fake_executor(exc, fn): - return fn() - - plugin.desktop.get_state = MagicMock(return_value=mock_state) - - ctx = MagicMock() - ctx.messages = [] - - with pytest.MonkeyPatch().context() as mp: - mp.setattr(loop, "run_in_executor", lambda exc, fn: asyncio.coroutine(lambda: fn())()) - # Simpler: just patch run_in_executor at the asyncio level - - # Direct call with mocked executor from unittest.mock import patch with patch("asyncio.get_event_loop") as mock_loop: mock_loop.return_value.run_in_executor = AsyncMock(return_value=mock_state) + ctx = MagicMock() + ctx.messages = [] await plugin._state_hook(ctx) assert len(ctx.messages) == 1 @@ -192,19 +177,16 @@ async def test_state_hook_handles_exception_gracefully(): plugin = ComputerPlugin(enabled=False) plugin.desktop = MagicMock() - ctx = MagicMock() - ctx.messages = [] - from unittest.mock 
import patch with patch("asyncio.get_event_loop") as mock_loop: - mock_loop.return_value.run_in_executor = AsyncMock( - side_effect=RuntimeError("accessibility error") - ) + mock_loop.return_value.run_in_executor = AsyncMock(side_effect=RuntimeError("accessibility error")) + ctx = MagicMock() + ctx.messages = [] result = await plugin._state_hook(ctx) assert result is ctx - assert ctx.messages == [] # no message appended on error + assert ctx.messages == [] # --------------------------------------------------------------------------- diff --git a/tests/test_control_center.py b/tests/test_control_center.py index f3a2e5b..0efe749 100644 --- a/tests/test_control_center.py +++ b/tests/test_control_center.py @@ -4,7 +4,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock, patch -from operator_use.agent.tools.builtin.control_center import ( +from operator_use.tools.control_center import ( control_center, _set_plugin_enabled, _get_plugin_enabled, diff --git a/tests/test_local_agents.py b/tests/test_local_agents.py index 8fd831b..a1b5168 100644 --- a/tests/test_local_agents.py +++ b/tests/test_local_agents.py @@ -2,7 +2,7 @@ import pytest -from operator_use.agent.tools.builtin.local_agents import LOCAL_AGENT_DELEGATION_CHAIN, localagents +from operator_use.tools.local_agents import LOCAL_AGENT_DELEGATION_CHAIN, localagents from operator_use.messages.service import AIMessage diff --git a/tests/test_plugins.py b/tests/test_plugins.py index f6ba6d4..5d9f8b9 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -7,7 +7,7 @@ from operator_use.agent.tools.registry import ToolRegistry from operator_use.agent.hooks.service import Hooks from operator_use.agent.hooks.events import HookEvent -from operator_use.tools.service import Tool +from operator_use.agent.tools.service import Tool from pydantic import BaseModel diff --git a/tests/test_session_ttl.py b/tests/test_session_ttl.py new file mode 100644 index 0000000..d78f3a8 --- /dev/null +++ 
b/tests/test_session_ttl.py @@ -0,0 +1,385 @@ +"""Tests for session TTL and auto-expiry. + +Validates that Session tracks last_activity, expires after its +configurable TTL, and that touch() extends the session lifetime. + +Covers all qodo findings for PR #32: +- Req Gap 1: TTL must be config-driven (24h default from SessionConfig) +- Req Gap 2: Loaded sessions must expire based on real age (updated_at) +- Req Gap 3: Cleanup and encryption round-trip coverage +- Bug 4: clear() must call touch() to refresh _last_activity +- Bug 5: Timing-sensitive tests replaced with monkeypatch +""" + +from __future__ import annotations + +import json +from datetime import datetime, timedelta +from pathlib import Path + +import pytest + +import operator_use.session.views as views_module +from operator_use.session.views import Session, DEFAULT_SESSION_TTL +from operator_use.session.service import SessionStore +from operator_use.messages.service import HumanMessage + + +# --------------------------------------------------------------------------- +# Existing TTL expiry / touch behaviour (Bug 5 fix: use monkeypatch clock) +# --------------------------------------------------------------------------- + +class TestSessionTTL: + def test_new_session_not_expired(self) -> None: + session = Session(id="test-1") + assert not session.is_expired() + + def test_default_ttl_is_24_hours(self) -> None: + """After Req Gap 1 fix: default TTL must be 24 hours (86400s), not 1 hour.""" + session = Session(id="test-2") + assert session.ttl == DEFAULT_SESSION_TTL + assert session.ttl == 86400.0 + + def test_custom_ttl(self) -> None: + session = Session(id="test-3", ttl=120.0) + assert session.ttl == 120.0 + + def test_session_expires_after_ttl(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Use monkeypatch clock — no real sleep, no CI flakiness.""" + fake_time = [0.0] + + def fake_monotonic() -> float: + return fake_time[0] + + monkeypatch.setattr(views_module.time, "monotonic", fake_monotonic) + + 
session = Session(id="test-4", ttl=100.0) + assert not session.is_expired() + + fake_time[0] = 101.0 # advance past TTL + assert session.is_expired() + + def test_touch_resets_expiry_clock(self, monkeypatch: pytest.MonkeyPatch) -> None: + fake_time = [0.0] + monkeypatch.setattr(views_module.time, "monotonic", lambda: fake_time[0]) + + session = Session(id="test-5", ttl=100.0) + fake_time[0] = 60.0 # 60s elapsed — not expired + session.touch() + fake_time[0] = 120.0 # 60s after touch — within TTL + assert not session.is_expired() + + def test_session_expires_after_touch_if_ttl_passes(self, monkeypatch: pytest.MonkeyPatch) -> None: + fake_time = [0.0] + monkeypatch.setattr(views_module.time, "monotonic", lambda: fake_time[0]) + + session = Session(id="test-6", ttl=100.0) + fake_time[0] = 50.0 + session.touch() + fake_time[0] = 160.0 # 110s since touch — past TTL + assert session.is_expired() + + def test_zero_ttl_immediately_expired(self, monkeypatch: pytest.MonkeyPatch) -> None: + fake_time = [0.0] + monkeypatch.setattr(views_module.time, "monotonic", lambda: fake_time[0]) + + session = Session(id="test-7", ttl=0.0) + fake_time[0] = 0.001 # any elapsed time exceeds 0s TTL + assert session.is_expired() + + def test_negative_ttl_immediately_expired(self, monkeypatch: pytest.MonkeyPatch) -> None: + fake_time = [0.0] + monkeypatch.setattr(views_module.time, "monotonic", lambda: fake_time[0]) + + session = Session(id="test-8", ttl=-1.0) + assert session.is_expired() # negative TTL: 0 > -1 is always true + + def test_very_large_ttl_does_not_expire(self, monkeypatch: pytest.MonkeyPatch) -> None: + fake_time = [0.0] + monkeypatch.setattr(views_module.time, "monotonic", lambda: fake_time[0]) + + session = Session(id="test-9", ttl=1e9) + fake_time[0] = 1_000_000.0 + assert not session.is_expired() + + def test_multiple_touches_keep_session_alive(self, monkeypatch: pytest.MonkeyPatch) -> None: + fake_time = [0.0] + monkeypatch.setattr(views_module.time, "monotonic", lambda: 
fake_time[0]) + + session = Session(id="test-10", ttl=100.0) + for i in range(1, 6): + fake_time[0] = i * 50.0 # 50s increments — each would expire without touch + session.touch() + assert not session.is_expired() + + +# --------------------------------------------------------------------------- +# Bug 4: clear() must call touch() +# --------------------------------------------------------------------------- + +class TestClearCallsTouch: + def test_clear_refreshes_last_activity(self, monkeypatch: pytest.MonkeyPatch) -> None: + """clear() is a mutating operation; it must extend the TTL window.""" + fake_time = [0.0] + monkeypatch.setattr(views_module.time, "monotonic", lambda: fake_time[0]) + + session = Session(id="clear-1", ttl=100.0) + # Advance to just before expiry + fake_time[0] = 95.0 + # Clearing the session is activity — it must reset _last_activity + session.clear() + # Now advance another 95s (total 190s, but only 95s since clear) + fake_time[0] = 190.0 + assert not session.is_expired() + + def test_clear_without_touch_would_expire(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Verify the test logic: without calling touch() the session expires.""" + fake_time = [0.0] + monkeypatch.setattr(views_module.time, "monotonic", lambda: fake_time[0]) + + session = Session(id="clear-2", ttl=100.0) + fake_time[0] = 101.0 # past TTL — would expire if clear doesn't touch + # clear() must touch, so session should NOT be expired after clear + session.clear() + assert not session.is_expired() + + +# --------------------------------------------------------------------------- +# Req Gap 1: TTL from config (SessionConfig in Config) +# --------------------------------------------------------------------------- + +class TestConfigDrivenTTL: + def test_session_config_has_ttl_hours_field(self) -> None: + """SessionConfig must exist with a ttl_hours field defaulting to 24.0.""" + from operator_use.config.service import SessionConfig + sc = SessionConfig() + assert 
sc.ttl_hours == 24.0 + + def test_config_has_session_block(self) -> None: + """Root Config must have a session: SessionConfig field.""" + from operator_use.config.service import Config + c = Config() + assert hasattr(c, "session") + assert c.session.ttl_hours == 24.0 + + def test_from_config_uses_config_ttl(self) -> None: + """Session.from_config() must derive ttl from config.session.ttl_hours.""" + from operator_use.config.service import Config, SessionConfig + config = Config() + # Patch ttl_hours to a known value + config.session = SessionConfig(ttl_hours=2.0) + session = Session.from_config(id="cfg-1", config=config) + assert session.ttl == 2.0 * 3600 # 7200s + + def test_from_config_default_24h(self) -> None: + """from_config() with default config must produce 86400s TTL.""" + from operator_use.config.service import Config + config = Config() + session = Session.from_config(id="cfg-2", config=config) + assert session.ttl == 86400.0 + + +# --------------------------------------------------------------------------- +# Req Gap 2: Loaded sessions expire based on real age (updated_at) +# --------------------------------------------------------------------------- + +class TestLoadedSessionExpiry: + def test_loaded_session_last_activity_reflects_updated_at( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + """A session loaded from disk must base _last_activity on updated_at, + not on the current monotonic time at load time.""" + fake_time = [1000.0] # monotonic at "load time" + monkeypatch.setattr(views_module.time, "monotonic", lambda: fake_time[0]) + + store = SessionStore(tmp_path) + session_id = "loaded-expiry-1" + + # Write a session that was last updated 2 hours ago + two_hours_ago = datetime.now() - timedelta(hours=2) + path = store._sessions_path(session_id) + meta = { + "type": "metadata", + "id": session_id, + "created_at": (datetime.now() - timedelta(hours=4)).isoformat(), + "updated_at": two_hours_ago.isoformat(), + "metadata": {}, + 
} + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as f: + f.write(json.dumps(meta) + "\n") + + # Load the session with a 1-hour TTL — it should appear expired + # because updated_at is 2 hours ago + loaded = store.load(session_id) + assert loaded is not None + loaded.ttl = 3600.0 # 1 hour TTL + assert loaded.is_expired(), ( + "Session updated 2 hours ago with 1h TTL must appear expired on load" + ) + + def test_loaded_recent_session_not_expired( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + """A session updated 5 minutes ago with 1h TTL must NOT be expired.""" + fake_time = [1000.0] + monkeypatch.setattr(views_module.time, "monotonic", lambda: fake_time[0]) + + store = SessionStore(tmp_path) + session_id = "loaded-fresh-1" + + five_min_ago = datetime.now() - timedelta(minutes=5) + path = store._sessions_path(session_id) + meta = { + "type": "metadata", + "id": session_id, + "created_at": (datetime.now() - timedelta(hours=1)).isoformat(), + "updated_at": five_min_ago.isoformat(), + "metadata": {}, + } + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as f: + f.write(json.dumps(meta) + "\n") + + loaded = store.load(session_id) + assert loaded is not None + loaded.ttl = 3600.0 # 1 hour TTL + assert not loaded.is_expired(), ( + "Session updated 5 minutes ago with 1h TTL must NOT be expired" + ) + + def test_get_or_create_deletes_expired_session(self, tmp_path: Path) -> None: + """get_or_create() must invalidate/delete expired sessions on access.""" + store = SessionStore(tmp_path) + session_id = "expired-cleanup-1" + + # Write a session that was last updated 48 hours ago + old_time = datetime.now() - timedelta(hours=48) + path = store._sessions_path(session_id) + meta = { + "type": "metadata", + "id": session_id, + "created_at": old_time.isoformat(), + "updated_at": old_time.isoformat(), + "metadata": {}, + } + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as f: + 
f.write(json.dumps(meta) + "\n") + + # get_or_create with 1h TTL must return a fresh session, not the expired one + session = store.get_or_create(session_id=session_id, ttl=3600.0) + assert session.messages == [], "Expired session must be replaced with fresh one" + assert not session.is_expired(), "Newly created replacement session must not be expired" + + +# --------------------------------------------------------------------------- +# Req Gap 3: Cleanup method +# --------------------------------------------------------------------------- + +class TestSessionCleanup: + def test_cleanup_removes_expired_sessions(self, tmp_path: Path) -> None: + """SessionStore.cleanup() must delete expired sessions from disk.""" + store = SessionStore(tmp_path) + + # Create an expired session file (48h old) + old_time = datetime.now() - timedelta(hours=48) + expired_id = "cleanup-expired-1" + path = store._sessions_path(expired_id) + meta = { + "type": "metadata", + "id": expired_id, + "created_at": old_time.isoformat(), + "updated_at": old_time.isoformat(), + "metadata": {}, + } + with open(path, "w") as f: + f.write(json.dumps(meta) + "\n") + + # Create a fresh session file (5 min old) + fresh_id = "cleanup-fresh-1" + fresh_path = store._sessions_path(fresh_id) + fresh_time = datetime.now() - timedelta(minutes=5) + fresh_meta = { + "type": "metadata", + "id": fresh_id, + "created_at": fresh_time.isoformat(), + "updated_at": fresh_time.isoformat(), + "metadata": {}, + } + with open(fresh_path, "w") as f: + f.write(json.dumps(fresh_meta) + "\n") + + removed = store.cleanup(ttl=3600.0) # 1h TTL + assert expired_id in removed, "Expired session must be in removed list" + assert fresh_id not in removed, "Fresh session must NOT be removed" + assert not path.exists(), "Expired session file must be deleted from disk" + assert fresh_path.exists(), "Fresh session file must survive cleanup" + + def test_cleanup_returns_empty_list_when_nothing_expired(self, tmp_path: Path) -> None: + 
"""cleanup() returns an empty list when no sessions are expired.""" + store = SessionStore(tmp_path) + removed = store.cleanup(ttl=86400.0) + assert removed == [] + + +# --------------------------------------------------------------------------- +# Req Gap 3: Encryption round-trip +# --------------------------------------------------------------------------- + +class TestSessionEncryption: + def test_save_and_load_with_encryption(self, tmp_path: Path) -> None: + """Encrypted-at-rest sessions must survive a save→load round-trip.""" + from cryptography.fernet import Fernet + key = Fernet.generate_key().decode() + + store = SessionStore(tmp_path, encryption_key=key) + session = Session(id="enc-1", ttl=86400.0) + session.add_message(HumanMessage(content="secret message")) + store.save(session) + + loaded = store.load("enc-1") + assert loaded is not None + assert len(loaded.messages) == 1 + assert loaded.messages[0].content == "secret message" + + def test_encrypted_file_is_not_plaintext(self, tmp_path: Path) -> None: + """When encryption is enabled, the raw .jsonl file must not contain + plaintext message content.""" + from cryptography.fernet import Fernet + key = Fernet.generate_key().decode() + + store = SessionStore(tmp_path, encryption_key=key) + session = Session(id="enc-2", ttl=86400.0) + session.add_message(HumanMessage(content="top secret")) + store.save(session) + + raw = store._sessions_path("enc-2").read_bytes() + assert b"top secret" not in raw, ( + "Plaintext message content must not appear in the encrypted file" + ) + + def test_load_without_key_when_saved_with_key_raises(self, tmp_path: Path) -> None: + """Loading an encrypted session without a key must raise an error.""" + from cryptography.fernet import Fernet + key = Fernet.generate_key().decode() + + store_with_key = SessionStore(tmp_path, encryption_key=key) + session = Session(id="enc-3", ttl=86400.0) + session.add_message(HumanMessage(content="confidential")) + store_with_key.save(session) + + 
store_no_key = SessionStore(tmp_path) + with pytest.raises(ValueError): + store_no_key.load("enc-3") + + def test_unencrypted_save_load_round_trip(self, tmp_path: Path) -> None: + """Without encryption, save→load must still work correctly (regression guard).""" + store = SessionStore(tmp_path) + session = Session(id="plain-1", ttl=86400.0) + session.add_message(HumanMessage(content="hello")) + store.save(session) + + loaded = store.load("plain-1") + assert loaded is not None + assert loaded.messages[0].content == "hello" diff --git a/tests/test_tool_registry.py b/tests/test_tool_registry.py index ca6ed75..77c70b9 100644 --- a/tests/test_tool_registry.py +++ b/tests/test_tool_registry.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from operator_use.agent.tools.registry import ToolRegistry -from operator_use.tools.service import Tool +from operator_use.agent.tools.service import Tool # --- Helpers --- diff --git a/tests/test_tools.py b/tests/test_tools.py index 8cbf913..de572ab 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from typing import Literal -from operator_use.tools.service import Tool, ToolResult +from operator_use.agent.tools.service import Tool, ToolResult # --- ToolResult --- diff --git a/uv.lock b/uv.lock index 6449804..79bd8f6 100644 --- a/uv.lock +++ b/uv.lock @@ -1802,6 +1802,7 @@ dependencies = [ { name = "cerebras-cloud-sdk" }, { name = "comtypes", marker = "sys_platform == 'win32'" }, { name = "croniter" }, + { name = "cryptography" }, { name = "ddgs" }, { name = "discord-py" }, { name = "fastmcp" }, @@ -1861,6 +1862,7 @@ requires-dist = [ { name = "cerebras-cloud-sdk", specifier = ">=1.50.1" }, { name = "comtypes", marker = "sys_platform == 'win32'", specifier = ">=1.4.15" }, { name = "croniter", specifier = ">=2.0.0" }, + { name = "cryptography", specifier = ">=41.0" }, { name = "ddgs", specifier = ">=9.11.1" }, { name = "discord-py", specifier = ">=2.0.0" }, { name = "exa-py", marker 
= "extra == 'exa'", specifier = ">=1.0.0" },