diff --git a/operator_use/cli/start.py b/operator_use/cli/start.py index e9bc773..6cd2837 100644 --- a/operator_use/cli/start.py +++ b/operator_use/cli/start.py @@ -58,6 +58,9 @@ def setup_logging(userdata_dir: Path, verbose: bool = False) -> None: logging.basicConfig(level=logging.WARNING, format=fmt, datefmt=datefmt, handlers=handlers) logging.getLogger("operator_use").setLevel(logging.INFO) + # Install credential masking so no secrets leak into log files or console + from operator_use.utils.log_filter import install_credential_masking + install_credential_masking() import operator_use from operator_use.agent import Agent diff --git a/operator_use/tools/control_center.py b/operator_use/tools/control_center.py index 537ed69..f918f51 100644 --- a/operator_use/tools/control_center.py +++ b/operator_use/tools/control_center.py @@ -12,6 +12,7 @@ import json import logging import os +import subprocess import sys from typing import Optional @@ -125,7 +126,7 @@ async def _do_restart(graceful_fn=None) -> None: ``os._exit(75)`` which skips cleanup but guarantees the process terminates. 
""" global _requested_exit_code - os.system("cls" if os.name == "nt" else "clear") + subprocess.run(["cls"] if os.name == "nt" else ["clear"], check=False) frames = ["↑", "↗", "→", "↘", "↓", "↙", "←", "↖"] for i in range(20): sys.stdout.write(f"\r {frames[i % len(frames)]} Restarting Operator...") diff --git a/operator_use/utils/__init__.py b/operator_use/utils/__init__.py index aca65b0..8863ce7 100644 --- a/operator_use/utils/__init__.py +++ b/operator_use/utils/__init__.py @@ -1,5 +1,15 @@ """Utils module.""" from operator_use.utils.helper import ensure_directory +from operator_use.utils.log_filter import ( + CredentialMaskingFilter, + install_credential_masking, + mask_credentials, +) -__all__ = ["ensure_directory"] +__all__ = [ + "CredentialMaskingFilter", + "ensure_directory", + "install_credential_masking", + "mask_credentials", +] diff --git a/operator_use/utils/log_filter.py b/operator_use/utils/log_filter.py new file mode 100644 index 0000000..fccd7c1 --- /dev/null +++ b/operator_use/utils/log_filter.py @@ -0,0 +1,109 @@ +"""Credential masking for log output -- prevents secrets leaking into logs.""" + +import logging +import re + + +# Patterns that match common credential formats in log strings. +# Order matters: more specific patterns should come before general ones. 
# Ordered (pattern, replacement) pairs applied sequentially by mask_credentials().
# Each later pattern sees the output of the earlier ones, so specific formats
# (DSNs, JWTs, provider keys) are listed before the broad key/value catch-alls.
_MASK_PATTERNS: list[tuple[re.Pattern[str], str]] = [
    # URL DSN credentials: scheme://user:password@host or scheme://:password@host
    (
        re.compile(r"(://[^:@/\s]*:)[^@\s]+(@)"),
        r"\1***REDACTED***\2",
    ),
    # JWT-like strings (three base64url segments separated by dots)
    (
        re.compile(r"eyJ[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_]+"),
        "***JWT_REDACTED***",
    ),
    # Bearer token header values
    (
        re.compile(r"(Bearer\s+)[A-Za-z0-9\-._~+/]+=*", re.IGNORECASE),
        r"\1***REDACTED***",
    ),
    # Provider-specific credential patterns
    (re.compile(r"gsk_[A-Za-z0-9]{8,}", re.IGNORECASE), "gsk_***REDACTED***"),
    (re.compile(r"AIza[A-Za-z0-9\-_]{8,}"), "AIza***REDACTED***"),
    (re.compile(r"nvapi-[A-Za-z0-9\-_]{8,}", re.IGNORECASE), "nvapi-***REDACTED***"),
    # API keys / tokens with common prefixes (sk-, pk-, api-, token-, key-)
    # Allows multi-segment keys like sk-proj-abc12345678
    # \b guards the word start; (?=...\d) requires at least one digit in the suffix
    # to avoid matching infrastructure words like "api-gateway-endpoint".
    # The prefix *and* its separator are captured so "api_..." is not rewritten
    # as "api-..." in the masked output.
    (
        re.compile(
            r"\b((?:sk|pk|api|token|key)[-_])(?=[A-Za-z0-9\-_]*\d)[A-Za-z0-9\-_]{8,}",
            re.IGNORECASE,
        ),
        r"\1***REDACTED***",
    ),
    # Authorization / x-api-key / x-auth-token headers.
    # An optional auth-scheme word (Basic, Bearer, ...) is consumed together with
    # the credential that follows it; otherwise "Authorization: Basic <base64>"
    # would only redact the word "Basic" and leave the credential in the log.
    # [ \t]+ (not \s+) keeps the match from running onto the next line.
    (
        re.compile(
            r"(authorization|x-api-key|x-auth-token)\s*[:=]\s*"
            r"(?:(?:basic|bearer|digest|negotiate|ntlm|token)[ \t]+)?\S+",
            re.IGNORECASE,
        ),
        r"\1: ***REDACTED***",
    ),
    # Quoted keys as they appear in JSON bodies and dict reprs:
    #   "password": "hunter2"   /   'token': 'abc'
    # The unquoted pattern below cannot match these because a quote sits between
    # the key name and the separator. The value may be a quoted string (with
    # backslash escapes) or a bare token ending at whitespace, ',', '}' or ']'.
    (
        re.compile(
            r"""((["'])(?:password|secret|passwd|pwd|token|api_key|apikey)\2\s*[:=]\s*)"""
            r"""(?:"(?:[^"\\]|\\.)*"|'(?:[^'\\]|\\.)*'|[^\s,}\]]+)""",
            re.IGNORECASE,
        ),
        r'\1"***REDACTED***"',
    ),
    # password= / secret= / token= / api_key= patterns in query strings or log lines
    (
        re.compile(
            r"(password|secret|passwd|pwd|token|api_key|apikey)\s*[=:]\s*\S+",
            re.IGNORECASE,
        ),
        r"\1=***REDACTED***",
    ),
    # Generic high-entropy secrets: key=value or key: value where value is 32+ alphanum chars
    (
        re.compile(r"(\b\w+\b\s*[=:]\s*)([A-Za-z0-9_\-]{32,})"),
        r"\1***REDACTED***",
    ),
]


def mask_credentials(text: str) -> str:
    """Return *text* with every recognised credential replaced by a placeholder.

    The patterns in ``_MASK_PATTERNS`` are applied one after another, each to
    the output of the previous one, so a single secret may be rewritten by
    more than one rule (e.g. a JWT inside ``token=...``).

    Args:
        text: The already-rendered log message (or any other string).

    Returns:
        The masked string. Text that matches no pattern is returned unchanged.
    """
    for pattern, replacement in _MASK_PATTERNS:
        text = pattern.sub(replacement, text)
    return text
+class CredentialMaskingFilter(logging.Filter): + """Logging filter that redacts credential patterns from all log records. + + Uses record.getMessage() to render the final formatted message before masking, + then clears record.args so the formatter does not re-apply %-style substitution. + This avoids TypeError when log args include numeric placeholders (%d, %.2f). + """ + + def filter(self, record: logging.LogRecord) -> bool: + # Render the message with its args first to preserve type semantics, + # then mask the rendered string. Clear args so the handler formatter + # does not re-format (which would re-expose the original values). + rendered = record.getMessage() + record.msg = mask_credentials(rendered) + record.args = () + return True + + +def install_credential_masking() -> None: + """Install credential masking on the root logger and all current handlers. + + Attaches CredentialMaskingFilter both to the root logger and to every + handler on the root logger, ensuring records emitted via named loggers + (logging.getLogger(__name__)) are masked regardless of propagation path. + + Must be called *after* all handlers have been added to the root logger + (e.g. at the end of setup_logging()). Handlers added after this call + will not automatically receive the filter. 
+ """ + root_logger = logging.getLogger() + filter_instance = CredentialMaskingFilter() + + # Add to root logger filters (catches records at the logger level) + if not any(isinstance(f, CredentialMaskingFilter) for f in root_logger.filters): + root_logger.addFilter(filter_instance) + + # Also add to every handler on the root logger for belt-and-suspenders coverage + for handler in root_logger.handlers: + if not any(isinstance(f, CredentialMaskingFilter) for f in handler.filters): + handler.addFilter(CredentialMaskingFilter()) diff --git a/tests/test_agent.py b/tests/test_agent.py index 4fb6c3f..13db174 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -186,7 +186,7 @@ async def test_agent_run_with_tool_call_then_text(tmp_path): # Register a simple echo tool from pydantic import BaseModel - from operator_use.tools.service import Tool + from operator_use.agent.tools.service import Tool class EchoParams(BaseModel): message: str diff --git a/tests/test_browser_plugin.py b/tests/test_browser_plugin.py index 9089d95..9f062ed 100644 --- a/tests/test_browser_plugin.py +++ b/tests/test_browser_plugin.py @@ -24,9 +24,9 @@ def test_enabled_plugin_returns_system_prompt(): prompt = plugin.get_system_prompt() assert prompt is SYSTEM_PROMPT assert "browser" in prompt.lower() - assert "" in prompt - assert "" in prompt - assert "" in prompt + # Prompt is plain Markdown — assert on actual content, not old XML tags + assert "browser_task" in prompt + assert "Chrome" in prompt # --------------------------------------------------------------------------- @@ -38,7 +38,7 @@ def test_disabled_plugin_registers_no_tools(): plugin = BrowserPlugin(enabled=False) registry = ToolRegistry() plugin.register_tools(registry) - assert registry.get("browser") is None + assert registry.get("browser_task") is None def test_enabled_plugin_registers_browser_tool(): @@ -47,7 +47,7 @@ def test_enabled_plugin_registers_browser_tool(): plugin.browser = MagicMock() registry = ToolRegistry() 
plugin.register_tools(registry) - assert registry.get("browser") is not None + assert registry.get("browser_task") is not None def test_unregister_tools_removes_browser_tool(): @@ -57,11 +57,11 @@ def test_unregister_tools_removes_browser_tool(): registry = ToolRegistry() plugin.register_tools(registry) plugin.unregister_tools(registry) - assert registry.get("browser") is None + assert registry.get("browser_task") is None # --------------------------------------------------------------------------- -# register_hooks — BEFORE_LLM_CALL gated on _enabled +# register_hooks — hooks NOT registered to main agent (subagent arch) # --------------------------------------------------------------------------- @@ -72,20 +72,20 @@ def test_disabled_plugin_registers_no_hooks(): assert plugin._state_hook not in hooks._handlers[HookEvent.BEFORE_LLM_CALL] -def test_enabled_plugin_registers_state_hook(): +def test_enabled_plugin_does_not_register_state_hook_to_main_agent(): + """Hooks are intentionally not wired to main agent — subagent manages its own state.""" plugin = BrowserPlugin(enabled=False) plugin._enabled = True hooks = Hooks() plugin.register_hooks(hooks) - assert plugin._state_hook in hooks._handlers[HookEvent.BEFORE_LLM_CALL] + assert plugin._state_hook not in hooks._handlers[HookEvent.BEFORE_LLM_CALL] -def test_unregister_hooks_removes_state_hook(): +def test_unregister_hooks_is_safe_noop(): plugin = BrowserPlugin(enabled=False) - plugin._enabled = True hooks = Hooks() plugin.register_hooks(hooks) - plugin.unregister_hooks(hooks) + plugin.unregister_hooks(hooks) # must not raise assert plugin._state_hook not in hooks._handlers[HookEvent.BEFORE_LLM_CALL] @@ -99,7 +99,7 @@ def test_disabled_plugin_does_not_inject_prompt(): context = MagicMock() plugin.attach_prompt(context) context.register_plugin_prompt.assert_not_called() - assert plugin._context is context # reference still stored + assert plugin._context is context def test_enabled_plugin_injects_prompt(): @@ -125,7 
+125,8 @@ def test_detach_prompt_removes_injected_prompt(): @pytest.mark.asyncio -async def test_enable_registers_hooks_and_injects_prompt(): +async def test_enable_injects_prompt_no_hooks(): + """enable() registers tools and injects prompt — hooks NOT wired to main agent.""" plugin = BrowserPlugin(enabled=False) hooks = Hooks() plugin.register_hooks(hooks) @@ -135,12 +136,12 @@ async def test_enable_registers_hooks_and_injects_prompt(): await plugin.enable() assert plugin._enabled is True - assert plugin._state_hook in hooks._handlers[HookEvent.BEFORE_LLM_CALL] + assert plugin._state_hook not in hooks._handlers[HookEvent.BEFORE_LLM_CALL] context.register_plugin_prompt.assert_called_once_with(SYSTEM_PROMPT) @pytest.mark.asyncio -async def test_disable_unregisters_hooks_and_removes_prompt(): +async def test_disable_removes_prompt(): plugin = BrowserPlugin(enabled=False) plugin._enabled = True hooks = Hooks() @@ -178,7 +179,7 @@ async def test_enable_then_disable_leaves_no_hooks(): async def test_state_hook_skips_when_no_browser_client(): plugin = BrowserPlugin(enabled=False) plugin.browser = MagicMock() - plugin.browser._client = None # no active session + plugin.browser._client = None ctx = MagicMock() ctx.messages = [] diff --git a/tests/test_computer_plugin.py b/tests/test_computer_plugin.py index 47e4480..4fea153 100644 --- a/tests/test_computer_plugin.py +++ b/tests/test_computer_plugin.py @@ -28,13 +28,12 @@ def test_enabled_plugin_returns_system_prompt(): prompt = plugin.get_system_prompt() assert prompt is SYSTEM_PROMPT assert "desktop" in prompt.lower() - assert "" in prompt - assert "" in prompt - assert "" in prompt + # Prompt is plain Markdown — assert on actual content, not old XML tags + assert "computer_task" in prompt # --------------------------------------------------------------------------- -# register_hooks — BEFORE_LLM_CALL + AFTER_TOOL_CALL, gated on _enabled +# register_hooks — hooks NOT registered to main agent (subagent arch) # 
--------------------------------------------------------------------------- @@ -46,21 +45,21 @@ def test_disabled_plugin_registers_no_hooks(): assert plugin._wait_for_ui_hook not in hooks._handlers[HookEvent.AFTER_TOOL_CALL] -def test_enabled_plugin_registers_both_hooks(): +def test_enabled_plugin_does_not_register_hooks_to_main_agent(): + """Hooks are intentionally not wired to main agent — subagent manages its own state.""" plugin = ComputerPlugin(enabled=False) plugin._enabled = True hooks = Hooks() plugin.register_hooks(hooks) - assert plugin._state_hook in hooks._handlers[HookEvent.BEFORE_LLM_CALL] - assert plugin._wait_for_ui_hook in hooks._handlers[HookEvent.AFTER_TOOL_CALL] + assert plugin._state_hook not in hooks._handlers[HookEvent.BEFORE_LLM_CALL] + assert plugin._wait_for_ui_hook not in hooks._handlers[HookEvent.AFTER_TOOL_CALL] -def test_unregister_hooks_removes_both(): +def test_unregister_hooks_is_safe_noop(): plugin = ComputerPlugin(enabled=False) - plugin._enabled = True hooks = Hooks() plugin.register_hooks(hooks) - plugin.unregister_hooks(hooks) + plugin.unregister_hooks(hooks) # must not raise assert plugin._state_hook not in hooks._handlers[HookEvent.BEFORE_LLM_CALL] assert plugin._wait_for_ui_hook not in hooks._handlers[HookEvent.AFTER_TOOL_CALL] @@ -101,7 +100,8 @@ def test_detach_prompt_removes_injected_prompt(): @pytest.mark.asyncio -async def test_enable_registers_both_hooks_and_prompt(): +async def test_enable_injects_prompt_no_hooks(): + """enable() registers tools and injects prompt — hooks NOT wired to main agent.""" plugin = ComputerPlugin(enabled=False) hooks = Hooks() plugin.register_hooks(hooks) @@ -111,13 +111,13 @@ async def test_enable_registers_both_hooks_and_prompt(): await plugin.enable() assert plugin._enabled is True - assert plugin._state_hook in hooks._handlers[HookEvent.BEFORE_LLM_CALL] - assert plugin._wait_for_ui_hook in hooks._handlers[HookEvent.AFTER_TOOL_CALL] + assert plugin._state_hook not in 
hooks._handlers[HookEvent.BEFORE_LLM_CALL] + assert plugin._wait_for_ui_hook not in hooks._handlers[HookEvent.AFTER_TOOL_CALL] context.register_plugin_prompt.assert_called_once_with(SYSTEM_PROMPT) @pytest.mark.asyncio -async def test_disable_unregisters_both_hooks_and_removes_prompt(): +async def test_disable_removes_prompt(): plugin = ComputerPlugin(enabled=False) plugin._enabled = True hooks = Hooks() @@ -160,27 +160,12 @@ async def test_state_hook_appends_desktop_state(): mock_state.to_string.return_value = "Active: Notepad | Elements: [button 'Save']" plugin.desktop = MagicMock() - import asyncio - - loop = asyncio.get_event_loop() - - async def _fake_executor(exc, fn): - return fn() - - plugin.desktop.get_state = MagicMock(return_value=mock_state) - - ctx = MagicMock() - ctx.messages = [] - - with pytest.MonkeyPatch().context() as mp: - mp.setattr(loop, "run_in_executor", lambda exc, fn: asyncio.coroutine(lambda: fn())()) - # Simpler: just patch run_in_executor at the asyncio level - - # Direct call with mocked executor from unittest.mock import patch with patch("asyncio.get_event_loop") as mock_loop: mock_loop.return_value.run_in_executor = AsyncMock(return_value=mock_state) + ctx = MagicMock() + ctx.messages = [] await plugin._state_hook(ctx) assert len(ctx.messages) == 1 @@ -192,19 +177,16 @@ async def test_state_hook_handles_exception_gracefully(): plugin = ComputerPlugin(enabled=False) plugin.desktop = MagicMock() - ctx = MagicMock() - ctx.messages = [] - from unittest.mock import patch with patch("asyncio.get_event_loop") as mock_loop: - mock_loop.return_value.run_in_executor = AsyncMock( - side_effect=RuntimeError("accessibility error") - ) + mock_loop.return_value.run_in_executor = AsyncMock(side_effect=RuntimeError("accessibility error")) + ctx = MagicMock() + ctx.messages = [] result = await plugin._state_hook(ctx) assert result is ctx - assert ctx.messages == [] # no message appended on error + assert ctx.messages == [] # 
--------------------------------------------------------------------------- diff --git a/tests/test_control_center.py b/tests/test_control_center.py index f3a2e5b..0efe749 100644 --- a/tests/test_control_center.py +++ b/tests/test_control_center.py @@ -4,7 +4,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock, patch -from operator_use.agent.tools.builtin.control_center import ( +from operator_use.tools.control_center import ( control_center, _set_plugin_enabled, _get_plugin_enabled, diff --git a/tests/test_local_agents.py b/tests/test_local_agents.py index 8fd831b..a1b5168 100644 --- a/tests/test_local_agents.py +++ b/tests/test_local_agents.py @@ -2,7 +2,7 @@ import pytest -from operator_use.agent.tools.builtin.local_agents import LOCAL_AGENT_DELEGATION_CHAIN, localagents +from operator_use.tools.local_agents import LOCAL_AGENT_DELEGATION_CHAIN, localagents from operator_use.messages.service import AIMessage diff --git a/tests/test_log_filter.py b/tests/test_log_filter.py new file mode 100644 index 0000000..646514d --- /dev/null +++ b/tests/test_log_filter.py @@ -0,0 +1,312 @@ +"""Tests for credential masking in log output.""" + +import io +import logging + +from operator_use.utils.log_filter import ( + CredentialMaskingFilter, + install_credential_masking, + mask_credentials, +) + + +class TestMaskCredentials: + def test_masks_bearer_token(self): + text = "Authorization: Bearer eyJhbGciOiJIUzI1NiJ9.abc.def" + result = mask_credentials(text) + assert "eyJhbGciOiJIUzI1NiJ9" not in result + assert "REDACTED" in result + + def test_masks_api_key_pattern(self): + result = mask_credentials("Using api_key=sk-abcdefghijklmnop123456") + assert "sk-abcdefghijklmnop123456" not in result + assert "REDACTED" in result + + def test_masks_sk_prefix_key(self): + result = mask_credentials("key is sk-proj-abc12345678") + assert "abc12345678" not in result + assert "REDACTED" in result + + def test_masks_password_in_connection_string(self): + result = 
mask_credentials("Connecting to db with password=mysecretpassword123") + assert "mysecretpassword123" not in result + assert "REDACTED" in result + + def test_masks_secret_value(self): + result = mask_credentials("secret=superSecretValue99") + assert "superSecretValue99" not in result + + def test_masks_jwt(self): + jwt = ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." + "eyJzdWIiOiIxMjM0NTY3ODkwIn0." + "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" + ) + result = mask_credentials(f"token={jwt}") + assert jwt not in result + assert "REDACTED" in result + + def test_masks_standalone_jwt(self): + """JWT not preceded by a credential keyword should use JWT_REDACTED.""" + jwt = ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." + "eyJzdWIiOiIxMjM0NTY3ODkwIn0." + "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" + ) + result = mask_credentials(f"Received {jwt} from upstream") + assert jwt not in result + assert "JWT_REDACTED" in result + + def test_masks_authorization_header(self): + result = mask_credentials("authorization: Bearer mytoken123") + assert "mytoken123" not in result + + def test_masks_x_api_key_header(self): + result = mask_credentials("x-api-key: abc123secret") + assert "abc123secret" not in result + + def test_passthrough_safe_text(self): + safe = "Starting server on port 8080" + assert mask_credentials(safe) == safe + + def test_passthrough_normal_log_line(self): + safe = "Agent loop iteration 3 of 10 completed in 1.2s" + assert mask_credentials(safe) == safe + + # --- Provider-specific patterns (Req Gap 2) --- + + def test_masks_groq_gsk_key(self): + assert "gsk_abc123def456" not in mask_credentials("key=gsk_abc123def456ghi") + + def test_masks_google_aiza_key(self): + assert "AIzaSyD123" not in mask_credentials("api_key=AIzaSyD123abc456def") + + def test_masks_nvidia_nvapi_key(self): + assert "nvapi-abc123" not in mask_credentials( + "Authorization: nvapi-abc123def456" + ) + + # --- Generic high-entropy pattern (Req Gap 3) --- + + def 
test_masks_generic_high_entropy_equals(self): + """Key=value where value is 32+ alphanumeric chars should be masked.""" + long_token = "A" * 32 + result = mask_credentials(f"db_token={long_token}") + assert long_token not in result + assert "REDACTED" in result + + def test_masks_generic_high_entropy_colon(self): + """Key: value where value is 32+ alphanumeric chars should be masked.""" + long_token = "b" * 40 + result = mask_credentials(f"session_id: {long_token}") + assert long_token not in result + assert "REDACTED" in result + + def test_does_not_mask_short_values(self): + """Values shorter than 32 chars in generic key-value context are not masked.""" + result = mask_credentials("count=12345678901234") + assert "count" in result # key preserved + assert "12345678901234" in result # short value not masked + + # --- URL DSN credential patterns (Issue #22) --- + + def test_masks_dsn_password(self): + raw = "connecting to postgresql://admin:s3cr3tpassword@prod-db:5432/users" + result = mask_credentials(raw) + assert "s3cr3tpassword" not in result + assert "***REDACTED***" in result + + def test_masks_mongodb_dsn(self): + raw = "mongodb://root:hunter2@mongo:27017/mydb" + result = mask_credentials(raw) + assert "hunter2" not in result + + def test_masks_redis_dsn(self): + raw = "redis://:mypassword@redis-host:6379" + result = mask_credentials(raw) + assert "mypassword" not in result + + # --- Word-boundary false-positive fixes (Issue #22) --- + + def test_no_false_positive_on_api_hyphen_word(self): + result = mask_credentials("request to api-gateway-endpoint/v1/health") + assert result == "request to api-gateway-endpoint/v1/health" + + def test_no_false_positive_on_key_hyphen_word(self): + result = mask_credentials("hotkey-sequence pressed") + assert result == "hotkey-sequence pressed" + + def test_real_api_key_still_masked(self): + """Ensure the word-boundary fix doesn't break masking of real API keys.""" + result = mask_credentials("key=api-abcdef12345678") + 
assert "abcdef12345678" not in result + assert "REDACTED" in result + + +class TestCredentialMaskingFilter: + def test_filter_masks_record_msg(self): + f = CredentialMaskingFilter() + record = logging.LogRecord( + "test", logging.INFO, "", 0, "password=secret123abc", (), None + ) + f.filter(record) + assert "secret123abc" not in record.msg + assert "REDACTED" in record.msg + + def test_filter_returns_true(self): + """Filter must return True to keep the record (masking, not suppressing).""" + f = CredentialMaskingFilter() + record = logging.LogRecord( + "test", logging.INFO, "", 0, "hello world", (), None + ) + assert f.filter(record) is True + + def test_filter_masks_tuple_args(self): + """Credential in %-formatted string arg is masked in rendered output.""" + f = CredentialMaskingFilter() + record = logging.LogRecord( + "test", logging.INFO, "", 0, "key=%s", ("api_key=supersecret",), None + ) + f.filter(record) + assert "supersecret" not in record.msg + + def test_filter_masks_dict_args(self): + """Credential in %(name)s-formatted dict arg is masked in rendered output.""" + f = CredentialMaskingFilter() + record = logging.LogRecord( + "test", logging.INFO, "", 0, "%(cred)s", None, None + ) + record.args = {"cred": "token=abc123xyz"} + f.filter(record) + assert "abc123xyz" not in record.msg + + def test_filter_handles_none_args(self): + f = CredentialMaskingFilter() + record = logging.LogRecord( + "test", logging.INFO, "", 0, "no args here", None, None + ) + assert f.filter(record) is True + + def test_filter_no_error_on_numeric_args(self): + """filter() must not raise TypeError for %d or %.2f numeric placeholders.""" + f = CredentialMaskingFilter() + record = logging.LogRecord( + "test", logging.INFO, "", 0, "iteration=%d", (3,), None + ) + # Should not raise — numeric arg rendered without coercion issues + result = f.filter(record) + assert result is True + assert "3" in record.msg + + def test_filter_no_error_on_float_args(self): + """filter() must not raise 
TypeError for %.2f float placeholders.""" + f = CredentialMaskingFilter() + record = logging.LogRecord( + "test", logging.INFO, "", 0, "took %.2f seconds", (1.23,), None + ) + result = f.filter(record) + assert result is True + assert "1.23" in record.msg + + +class TestInstallCredentialMasking: + def test_install_adds_filter_to_root_logger(self): + root = logging.getLogger() + # Remove any existing masking filters first + root.filters = [ + f for f in root.filters if not isinstance(f, CredentialMaskingFilter) + ] + install_credential_masking() + masking_filters = [ + f for f in root.filters if isinstance(f, CredentialMaskingFilter) + ] + assert len(masking_filters) == 1 + + def test_install_idempotent(self): + root = logging.getLogger() + # Clean slate + root.filters = [ + f for f in root.filters if not isinstance(f, CredentialMaskingFilter) + ] + install_credential_masking() + install_credential_masking() # second call should not add duplicate + masking_filters = [ + f for f in root.filters if isinstance(f, CredentialMaskingFilter) + ] + assert len(masking_filters) == 1 + + def test_install_adds_filter_to_handlers(self): + """Filter must be installed on root logger handlers for global enforcement.""" + stream = io.StringIO() + root = logging.getLogger() + # Clean slate + root.filters = [ + f for f in root.filters if not isinstance(f, CredentialMaskingFilter) + ] + handler = logging.StreamHandler(stream) + handler.filters = [] + root.addHandler(handler) + try: + install_credential_masking() + masking_on_handler = [ + f for f in handler.filters if isinstance(f, CredentialMaskingFilter) + ] + assert len(masking_on_handler) >= 1 + finally: + root.removeHandler(handler) + + def test_named_logger_output_is_masked(self): + """Records from named loggers must have credentials masked in handler output.""" + stream = io.StringIO() + root = logging.getLogger() + # Clean slate + root.filters = [ + f for f in root.filters if not isinstance(f, CredentialMaskingFilter) + ] + 
handler = logging.StreamHandler(stream) + handler.filters = [] + handler.setLevel(logging.DEBUG) + root.addHandler(handler) + root.setLevel(logging.DEBUG) + try: + install_credential_masking() + named_logger = logging.getLogger("operator_use.test.masking") + named_logger.info("Connecting with password=topsecretpassword99") + output = stream.getvalue() + assert "topsecretpassword99" not in output + assert "REDACTED" in output + finally: + root.removeHandler(handler) + + def test_handler_added_after_install_documents_known_limitation(self): + """Post-install handlers are NOT automatically protected — document the contract. + + In operator_use, setup_logging() adds all handlers before calling + install_credential_masking(), so this scenario doesn't occur in prod. + This test documents the known limitation: post-install handlers bypass + masking. Callers must ensure install_credential_masking() is called last, + after all handlers have been attached. + """ + root = logging.getLogger() + # Clean slate + root.filters = [ + f for f in root.filters if not isinstance(f, CredentialMaskingFilter) + ] + install_credential_masking() # install BEFORE adding the late handler + + buf = io.StringIO() + late_handler = logging.StreamHandler(buf) + late_handler.setLevel(logging.DEBUG) + root.addHandler(late_handler) + root.setLevel(logging.DEBUG) + try: + logging.getLogger("test.late").warning("token=sk-abc123def456ghi789") + output = buf.getvalue() + # The root logger filter (added by install) still fires for named loggers. + # Named-logger records propagate to root where the logger-level filter masks + # the record before it reaches any handler — including late handlers. + # So in practice, masking IS applied via the root logger filter. + # This is the safe production path: setup_logging() always installs last. 
+ assert "sk-abc123def456ghi789" not in output or "REDACTED" in output + finally: + root.removeHandler(late_handler) diff --git a/tests/test_plugins.py b/tests/test_plugins.py index f6ba6d4..5d9f8b9 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -7,7 +7,7 @@ from operator_use.agent.tools.registry import ToolRegistry from operator_use.agent.hooks.service import Hooks from operator_use.agent.hooks.events import HookEvent -from operator_use.tools.service import Tool +from operator_use.agent.tools.service import Tool from pydantic import BaseModel diff --git a/tests/test_tool_registry.py b/tests/test_tool_registry.py index ca6ed75..77c70b9 100644 --- a/tests/test_tool_registry.py +++ b/tests/test_tool_registry.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from operator_use.agent.tools.registry import ToolRegistry -from operator_use.tools.service import Tool +from operator_use.agent.tools.service import Tool # --- Helpers --- diff --git a/tests/test_tools.py b/tests/test_tools.py index 8cbf913..de572ab 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from typing import Literal -from operator_use.tools.service import Tool, ToolResult +from operator_use.agent.tools.service import Tool, ToolResult # --- ToolResult ---