diff --git a/operator_use/guardrails/__init__.py b/operator_use/guardrails/__init__.py
new file mode 100644
index 0000000..61cd4a8
--- /dev/null
+++ b/operator_use/guardrails/__init__.py
@@ -0,0 +1,68 @@
+"""operator_use.guardrails — base classes and registry for the guardrails framework.
+
+Public API
+----------
+
+Base classes & result types:
+    Guardrail, GuardrailResult, GuardrailAction, RiskLevel,
+    ActionValidator, ContentFilter, PolicyEngine
+
+Concrete helpers:
+    AllowAllValidator, BlockListValidator, CompositeActionValidator,
+    PassthroughFilter, KeywordBlockFilter, RegexFilter, CompositeContentFilter,
+    RuleBasedPolicyEngine, CompositePolicyEngine
+
+Registry:
+    GuardrailRegistry
+"""
+
+from operator_use.guardrails.base import (
+    Guardrail,
+    GuardrailAction,
+    GuardrailResult,
+    RiskLevel,
+    ActionValidator,
+    ContentFilter,
+    PolicyEngine,
+)
+from operator_use.guardrails.action_validator import (
+    AllowAllValidator,
+    BlockListValidator,
+    CompositeActionValidator,
+)
+from operator_use.guardrails.content_filter import (
+    PassthroughFilter,
+    KeywordBlockFilter,
+    RegexFilter,
+    CompositeContentFilter,
+)
+from operator_use.guardrails.policy_engine import (
+    RuleBasedPolicyEngine,
+    CompositePolicyEngine,
+)
+from operator_use.guardrails.registry import GuardrailRegistry
+
+__all__ = [
+    # base
+    "Guardrail",
+    "GuardrailAction",
+    "GuardrailResult",
+    "RiskLevel",
+    "ActionValidator",
+    "ContentFilter",
+    "PolicyEngine",
+    # action validators
+    "AllowAllValidator",
+    "BlockListValidator",
+    "CompositeActionValidator",
+    # content filters
+    "PassthroughFilter",
+    "KeywordBlockFilter",
+    "RegexFilter",
+    "CompositeContentFilter",
+    # policy engines
+    "RuleBasedPolicyEngine",
+    "CompositePolicyEngine",
+    # registry
+    "GuardrailRegistry",
+]
diff --git a/operator_use/guardrails/action_validator.py b/operator_use/guardrails/action_validator.py
new file mode 100644
index 0000000..2a2e40a
--- /dev/null
+++ b/operator_use/guardrails/action_validator.py
@@ -0,0 +1,108 @@
+"""ActionValidator: pre-execution tool-call validation.
+
+Concrete validators extend :class:`~operator_use.guardrails.base.ActionValidator`
+and are registered with :class:`~operator_use.guardrails.registry.GuardrailRegistry`.
+
+This module also ships a :class:`CompositeActionValidator` that runs a sequence
+of validators and returns the most restrictive result.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+from operator_use.guardrails.base import (
+    ActionValidator,
+    GuardrailAction,
+    GuardrailResult,
+    RiskLevel,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class AllowAllValidator(ActionValidator):
+    """Passthrough validator — allows every tool call.
+
+    Useful as a no-op default or for testing.
+    """
+
+    def __init__(self) -> None:
+        super().__init__(name="allow_all")
+
+    def validate(
+        self,
+        tool_name: str,
+        args: dict[str, Any],
+        context: dict[str, Any],
+    ) -> GuardrailResult:
+        return GuardrailResult.allow(f"AllowAllValidator: {tool_name!r} permitted")
+
+
+class BlockListValidator(ActionValidator):
+    """Blocks any tool whose name appears in a configurable deny-list.
+
+    Example::
+
+        validator = BlockListValidator(blocked_tools={"shell", "delete_file"})
+        result = validator.validate("shell", {}, {})
+        assert result.is_blocked
+    """
+
+    def __init__(self, blocked_tools: set[str] | None = None) -> None:
+        super().__init__(name="block_list")
+        self.blocked_tools: set[str] = blocked_tools or set()
+
+    def validate(
+        self,
+        tool_name: str,
+        args: dict[str, Any],
+        context: dict[str, Any],
+    ) -> GuardrailResult:
+        if tool_name in self.blocked_tools:
+            return GuardrailResult.block(
+                f"Tool {tool_name!r} is on the deny-list",
+                severity=RiskLevel.DANGEROUS,
+            )
+        return GuardrailResult.allow(f"Tool {tool_name!r} is not blocked")
+
+
+class CompositeActionValidator(ActionValidator):
+    """Runs multiple validators in order and returns the strictest result.
+
+    Priority: BLOCK > CONFIRM > ALLOW. The first BLOCK short-circuits.
+    """
+
+    def __init__(self, validators: list[ActionValidator] | None = None) -> None:
+        super().__init__(name="composite_action_validator")
+        self.validators: list[ActionValidator] = validators or []
+
+    def add(self, validator: ActionValidator) -> None:
+        """Append a validator to the chain."""
+        self.validators.append(validator)
+
+    def validate(
+        self,
+        tool_name: str,
+        args: dict[str, Any],
+        context: dict[str, Any],
+    ) -> GuardrailResult:
+        result: GuardrailResult = GuardrailResult.allow("No validators configured")
+
+        for validator in self.validators:
+            if not validator.enabled:
+                continue
+            current = validator.validate(tool_name, args, context)
+            logger.debug(
+                "ActionValidator %r: %s — %s",
+                validator.name,
+                current.action,
+                current.reason,
+            )
+            if current.action == GuardrailAction.BLOCK:
+                return current
+            if current.action == GuardrailAction.CONFIRM:
+                result = current  # keep going in case something blocks later
+
+        return result
diff --git a/operator_use/guardrails/base.py b/operator_use/guardrails/base.py
new file mode 100644
index 0000000..c0e0172
--- /dev/null
+++ b/operator_use/guardrails/base.py
@@ -0,0 +1,188 @@
+"""Base classes for the guardrails module.
+
+Defines the abstract foundation for action validation, content filtering,
+and policy enforcement used throughout the guardrails framework.
+"""
+
+from __future__ import annotations
+
+import logging
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+
+class GuardrailAction(str, Enum):
+    """Possible outcomes from a guardrail check."""
+
+    ALLOW = "allow"
+    BLOCK = "block"
+    CONFIRM = "confirm"
+
+
+class RiskLevel(str, Enum):
+    """Risk classification for an action."""
+
+    SAFE = "safe"
+    REVIEW = "review"
+    DANGEROUS = "dangerous"
+
+
+@dataclass
+class GuardrailResult:
+    """Result object returned by every guardrail check.
+
+    Attributes:
+        action: Disposition — allow, block, or require human confirmation.
+        reason: Human-readable explanation for the decision.
+        severity: Risk level associated with the decision.
+        metadata: Optional extra data produced by the guardrail.
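+
+    Example (using the factory methods defined below)::
+
+        result = GuardrailResult.confirm("manual review required")
+        assert result.needs_confirmation
+        assert result.severity == RiskLevel.REVIEW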
+ """ + + action: GuardrailAction + reason: str + severity: RiskLevel = RiskLevel.SAFE + metadata: dict[str, Any] = field(default_factory=dict) + + @classmethod + def allow(cls, reason: str = "OK", severity: RiskLevel = RiskLevel.SAFE) -> "GuardrailResult": + return cls(action=GuardrailAction.ALLOW, reason=reason, severity=severity) + + @classmethod + def block(cls, reason: str, severity: RiskLevel = RiskLevel.DANGEROUS) -> "GuardrailResult": + return cls(action=GuardrailAction.BLOCK, reason=reason, severity=severity) + + @classmethod + def confirm(cls, reason: str, severity: RiskLevel = RiskLevel.REVIEW) -> "GuardrailResult": + return cls(action=GuardrailAction.CONFIRM, reason=reason, severity=severity) + + @property + def is_allowed(self) -> bool: + return self.action == GuardrailAction.ALLOW + + @property + def is_blocked(self) -> bool: + return self.action == GuardrailAction.BLOCK + + @property + def needs_confirmation(self) -> bool: + return self.action == GuardrailAction.CONFIRM + + +class Guardrail(ABC): + """Abstract base for all guardrails. + + Every guardrail receives a context dict and returns a GuardrailResult. + Subclasses implement :meth:`check` with their specific logic. + """ + + def __init__(self, name: str, enabled: bool = True) -> None: + self.name = name + self.enabled = enabled + + @abstractmethod + def check(self, context: dict[str, Any]) -> GuardrailResult: + """Evaluate the guardrail against the given context. + + Args: + context: Arbitrary context dictionary provided by the caller. + Exact keys depend on the guardrail type. + + Returns: + A :class:`GuardrailResult` describing the decision. + """ + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(name={self.name!r}, enabled={self.enabled})" + + +class ActionValidator(Guardrail, ABC): + """Validates tool calls *before* execution. + + Receives the tool name, arguments, and caller context, and decides + whether to allow, block, or require confirmation. + """ + + @abstractmethod + def validate( + self, + tool_name: str, + args: dict[str, Any], + context: dict[str, Any], + ) -> GuardrailResult: + """Validate a pending tool call. + + Args: + tool_name: Name of the tool about to be called. + args: Arguments that will be passed to the tool. + context: Caller context (agent id, session, etc.). + + Returns: + A :class:`GuardrailResult` describing the decision. + """ + + def check(self, context: dict[str, Any]) -> GuardrailResult: + """Delegate to :meth:`validate` using ``context`` keys.""" + return self.validate( + tool_name=context.get("tool_name", ""), + args=context.get("args", {}), + context=context, + ) + + +class ContentFilter(Guardrail, ABC): + """Filters LLM output *before* it is forwarded to the user. + + Receives the raw content string and caller context, and returns a + result indicating whether the content should pass through, be + blocked, or be sent for human review. + """ + + @abstractmethod + def filter(self, content: str, context: dict[str, Any]) -> GuardrailResult: + """Evaluate and potentially filter a content string. + + Args: + content: The LLM-generated text to evaluate. + context: Caller context (agent id, session, etc.). + + Returns: + A :class:`GuardrailResult` describing the decision. 
+ """ + + def check(self, context: dict[str, Any]) -> GuardrailResult: + """Delegate to :meth:`filter` using ``context`` keys.""" + return self.filter( + content=context.get("content", ""), + context=context, + ) + + +class PolicyEngine(Guardrail, ABC): + """Evaluates the risk level of an action given context. + + Concrete implementations encode organisation-specific policies. + """ + + @abstractmethod + def classify_risk(self, action: dict[str, Any]) -> RiskLevel: + """Classify the risk of ``action``. + + Args: + action: A description of the action (tool name, args, agent, …). + + Returns: + ``safe``, ``review``, or ``dangerous``. + """ + + def check(self, context: dict[str, Any]) -> GuardrailResult: + """Run :meth:`classify_risk` and wrap the result.""" + risk = self.classify_risk(context) + if risk == RiskLevel.SAFE: + return GuardrailResult.allow("Policy: safe", severity=risk) + if risk == RiskLevel.REVIEW: + return GuardrailResult.confirm("Policy: requires review", severity=risk) + return GuardrailResult.block("Policy: dangerous action", severity=risk) diff --git a/operator_use/guardrails/content_filter.py b/operator_use/guardrails/content_filter.py new file mode 100644 index 0000000..4a92ede --- /dev/null +++ b/operator_use/guardrails/content_filter.py @@ -0,0 +1,128 @@ +"""ContentFilter: post-execution output filtering. + +Concrete filters extend :class:`~operator_use.guardrails.base.ContentFilter` +and plug into the response pipeline before content is forwarded to the user. + +This module ships a :class:`CompositeContentFilter` that aggregates multiple +filters with the same BLOCK > CONFIRM > ALLOW precedence used elsewhere. +""" + +from __future__ import annotations + +import logging +import re +from typing import Any + +from operator_use.guardrails.base import ( + ContentFilter, + GuardrailAction, + GuardrailResult, + RiskLevel, +) + +logger = logging.getLogger(__name__) + + +class PassthroughFilter(ContentFilter): + """No-op filter — passes all content through. + + Useful as a default or in test environments. + """ + + def __init__(self) -> None: + super().__init__(name="passthrough") + + def filter(self, content: str, context: dict[str, Any]) -> GuardrailResult: + return GuardrailResult.allow("PassthroughFilter: content accepted") + + +class KeywordBlockFilter(ContentFilter): + """Blocks content that contains any phrase from a configurable deny-list. + + Matching is case-insensitive by default. + + Example:: + + f = KeywordBlockFilter(blocked_phrases={"drop table", "rm -rf"}) + result = f.filter("please run rm -rf /", {}) + assert result.is_blocked + """ + + def __init__( + self, + blocked_phrases: set[str] | None = None, + case_sensitive: bool = False, + ) -> None: + super().__init__(name="keyword_block") + self.blocked_phrases: set[str] = blocked_phrases or set() + self.case_sensitive = case_sensitive + + def filter(self, content: str, context: dict[str, Any]) -> GuardrailResult: + haystack = content if self.case_sensitive else content.lower() + for phrase in self.blocked_phrases: + needle = phrase if self.case_sensitive else phrase.lower() + if needle in haystack: + return GuardrailResult.block( + f"Content contains blocked phrase: {phrase!r}", + severity=RiskLevel.DANGEROUS, + ) + return GuardrailResult.allow("No blocked phrases found") + + +class RegexFilter(ContentFilter): + """Blocks content matching any of a set of regular expressions. 
+
+    Example::
+
+        f = RegexFilter(patterns=[r"\\bpassword\\s*=\\s*\\S+"])
+        result = f.filter("password = hunter2", {})
+        assert result.is_blocked
+    """
+
+    def __init__(self, patterns: list[str] | None = None) -> None:
+        super().__init__(name="regex_filter")
+        self._compiled = [re.compile(p) for p in (patterns or [])]
+
+    def filter(self, content: str, context: dict[str, Any]) -> GuardrailResult:
+        for pattern in self._compiled:
+            if pattern.search(content):
+                return GuardrailResult.block(
+                    f"Content matched blocked pattern: {pattern.pattern!r}",
+                    severity=RiskLevel.DANGEROUS,
+                )
+        return GuardrailResult.allow("No patterns matched")
+
+
+class CompositeContentFilter(ContentFilter):
+    """Runs multiple filters in order and returns the strictest result.
+
+    Priority: BLOCK > CONFIRM > ALLOW. The first BLOCK short-circuits.
+    """
+
+    def __init__(self, filters: list[ContentFilter] | None = None) -> None:
+        super().__init__(name="composite_content_filter")
+        self.filters: list[ContentFilter] = filters or []
+
+    def add(self, content_filter: ContentFilter) -> None:
+        """Append a filter to the chain."""
+        self.filters.append(content_filter)
+
+    def filter(self, content: str, context: dict[str, Any]) -> GuardrailResult:
+        result: GuardrailResult = GuardrailResult.allow("No filters configured")
+
+        for f in self.filters:
+            if not f.enabled:
+                continue
+            current = f.filter(content, context)
+            logger.debug(
+                "ContentFilter %r: %s — %s",
+                f.name,
+                current.action,
+                current.reason,
+            )
+            if current.action == GuardrailAction.BLOCK:
+                return current
+            if current.action == GuardrailAction.CONFIRM:
+                result = current
+
+        return result
diff --git a/operator_use/guardrails/policy_engine.py b/operator_use/guardrails/policy_engine.py
new file mode 100644
index 0000000..24e7ae3
--- /dev/null
+++ b/operator_use/guardrails/policy_engine.py
@@ -0,0 +1,98 @@
+"""PolicyEngine: risk classification and policy decisions.
+
+Concrete engines extend :class:`~operator_use.guardrails.base.PolicyEngine`
+and encode organisation-specific risk policies.
+
+This module ships a :class:`RuleBasedPolicyEngine` that classifies actions
+against configurable dangerous/review tool lists, and a
+:class:`CompositePolicyEngine` that aggregates multiple engines with
+DANGEROUS > REVIEW > SAFE precedence.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+from operator_use.guardrails.base import (
+    GuardrailResult,
+    PolicyEngine,
+    RiskLevel,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class RuleBasedPolicyEngine(PolicyEngine):
+    """Classifies risk by comparing ``action["tool_name"]`` against lists.
+
+    Classification priority:
+        1. If the tool name is in *dangerous_tools* → ``dangerous``
+        2. If the tool name is in *review_tools* → ``review``
+        3. Otherwise → ``safe``
+
+    Example::
+
+        engine = RuleBasedPolicyEngine(
+            dangerous_tools={"shell", "delete_file"},
+            review_tools={"write_file"},
+        )
+        assert engine.classify_risk({"tool_name": "shell"}) == RiskLevel.DANGEROUS
+        assert engine.classify_risk({"tool_name": "write_file"}) == RiskLevel.REVIEW
+        assert engine.classify_risk({"tool_name": "read_file"}) == RiskLevel.SAFE
+    """
+
+    def __init__(
+        self,
+        dangerous_tools: set[str] | None = None,
+        review_tools: set[str] | None = None,
+    ) -> None:
+        super().__init__(name="rule_based_policy")
+        self.dangerous_tools: set[str] = dangerous_tools or set()
+        self.review_tools: set[str] = review_tools or set()
+
+    def classify_risk(self, action: dict[str, Any]) -> RiskLevel:
+        tool_name = action.get("tool_name", "")
+        if tool_name in self.dangerous_tools:
+            return RiskLevel.DANGEROUS
+        if tool_name in self.review_tools:
+            return RiskLevel.REVIEW
+        return RiskLevel.SAFE
+
+
+class CompositePolicyEngine(PolicyEngine):
+    """Runs multiple engines and returns the highest risk level found.
+
+    Priority: DANGEROUS > REVIEW > SAFE.
+    """
+
+    def __init__(self, engines: list[PolicyEngine] | None = None) -> None:
+        super().__init__(name="composite_policy")
+        self.engines: list[PolicyEngine] = engines or []
+
+    def add(self, engine: PolicyEngine) -> None:
+        """Append an engine to the chain."""
+        self.engines.append(engine)
+
+    def classify_risk(self, action: dict[str, Any]) -> RiskLevel:
+        highest = RiskLevel.SAFE
+
+        for engine in self.engines:
+            if not engine.enabled:
+                continue
+            risk = engine.classify_risk(action)
+            logger.debug("PolicyEngine %r classified as %s", engine.name, risk)
+            if risk == RiskLevel.DANGEROUS:
+                return RiskLevel.DANGEROUS
+            if risk == RiskLevel.REVIEW:
+                highest = RiskLevel.REVIEW
+
+        return highest
+
+    def check(self, context: dict[str, Any]) -> GuardrailResult:
+        risk = self.classify_risk(context)
+        if risk == RiskLevel.SAFE:
+            return GuardrailResult.allow("CompositePolicyEngine: safe", severity=risk)
+        if risk == RiskLevel.REVIEW:
+            return GuardrailResult.confirm("CompositePolicyEngine: requires review", severity=risk)
+        return GuardrailResult.block("CompositePolicyEngine: dangerous action", severity=risk)
diff --git a/operator_use/guardrails/registry.py b/operator_use/guardrails/registry.py
new file mode 100644
index 0000000..7f58e98
--- /dev/null
+++ b/operator_use/guardrails/registry.py
@@ -0,0 +1,111 @@
+"""GuardrailRegistry: registration and lookup of guardrail instances.
+
+The registry is the single source of truth for all active guardrails in a
+running system. Guardrails are registered explicitly via
+:meth:`GuardrailRegistry.register` and looked up by name or by type.
+
+Typical usage::
+
+    registry = GuardrailRegistry()
+    registry.register(BlockListValidator(blocked_tools={"shell"}))
+    registry.register(KeywordBlockFilter(blocked_phrases={"drop table"}))
+
+    validators = registry.get_all(ActionValidator)
+    filters = registry.get_all(ContentFilter)
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Any, TypeVar, Type
+
+from operator_use.guardrails.base import Guardrail, GuardrailResult
+
+logger = logging.getLogger(__name__)
+
+G = TypeVar("G", bound=Guardrail)
+
+
+class GuardrailRegistry:
+    """Central registry for guardrail instances.
+
+    Guardrails are stored by name; registering a second guardrail with the
+    same name replaces the first (last-write wins).
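+
+    Example::
+
+        registry = GuardrailRegistry()
+        registry.register(AllowAllValidator())  # name="allow_all"
+        registry.register(AllowAllValidator())  # same name: replaces the first
+        assert len(registry) == 1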
+ """ + + def __init__(self) -> None: + self._guardrails: dict[str, Guardrail] = {} + + # ------------------------------------------------------------------ + # Registration + # ------------------------------------------------------------------ + + def register(self, guardrail: Guardrail) -> None: + """Register ``guardrail`` under its :attr:`~Guardrail.name`. + + If a guardrail with the same name already exists it is replaced. + """ + if guardrail.name in self._guardrails: + logger.debug("GuardrailRegistry: replacing %r", guardrail.name) + self._guardrails[guardrail.name] = guardrail + logger.debug("GuardrailRegistry: registered %r", guardrail.name) + + def unregister(self, name: str) -> None: + """Remove the guardrail identified by ``name``. + + Silently does nothing if the name is not registered. + """ + self._guardrails.pop(name, None) + + # ------------------------------------------------------------------ + # Lookup + # ------------------------------------------------------------------ + + def get(self, name: str) -> Guardrail | None: + """Return the guardrail with ``name``, or ``None`` if absent.""" + return self._guardrails.get(name) + + def get_all(self, guardrail_type: Type[G] | None = None) -> list[G]: + """Return all registered guardrails, optionally filtered by type. + + Args: + guardrail_type: If provided, only instances of this type are + returned. Pass ``None`` to get everything. + + Returns: + List of matching guardrail instances (enabled or not). + """ + if guardrail_type is None: + return list(self._guardrails.values()) # type: ignore[return-value] + return [g for g in self._guardrails.values() if isinstance(g, guardrail_type)] + + def get_enabled(self, guardrail_type: Type[G] | None = None) -> list[G]: + """Like :meth:`get_all` but only returns enabled guardrails.""" + return [g for g in self.get_all(guardrail_type) if g.enabled] + + # ------------------------------------------------------------------ + # Bulk operations + # ------------------------------------------------------------------ + + def run_all( + self, + guardrail_type: Type[G] | None = None, + context: dict[str, Any] | None = None, + ) -> list[GuardrailResult]: + """Run :meth:`~Guardrail.check` on every enabled guardrail of ``guardrail_type``. + + Returns all results (callers decide how to interpret them). 
+ """ + ctx = context or {} + return [g.check(ctx) for g in self.get_enabled(guardrail_type)] + + def clear(self) -> None: + """Remove all registered guardrails.""" + self._guardrails.clear() + + def __len__(self) -> int: + return len(self._guardrails) + + def __repr__(self) -> str: + names = list(self._guardrails.keys()) + return f"GuardrailRegistry(count={len(names)}, names={names})" diff --git a/tests/security/__init__.py b/tests/security/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/security/test_guardrails_base.py b/tests/security/test_guardrails_base.py new file mode 100644 index 0000000..a8a26eb --- /dev/null +++ b/tests/security/test_guardrails_base.py @@ -0,0 +1,432 @@ +"""Unit tests for operator_use/guardrails/ base classes and concrete helpers.""" + +import pytest + +from operator_use.guardrails.base import ( + ActionValidator, + ContentFilter, + GuardrailAction, + GuardrailResult, + PolicyEngine, + RiskLevel, +) +from operator_use.guardrails.action_validator import ( + AllowAllValidator, + BlockListValidator, + CompositeActionValidator, +) +from operator_use.guardrails.content_filter import ( + CompositeContentFilter, + KeywordBlockFilter, + PassthroughFilter, + RegexFilter, +) +from operator_use.guardrails.policy_engine import ( + CompositePolicyEngine, + RuleBasedPolicyEngine, +) +from operator_use.guardrails.registry import GuardrailRegistry + + +# --------------------------------------------------------------------------- +# GuardrailResult +# --------------------------------------------------------------------------- + + +class TestGuardrailResult: + def test_allow_factory(self): + r = GuardrailResult.allow("OK") + assert r.action == GuardrailAction.ALLOW + assert r.is_allowed + assert not r.is_blocked + assert not r.needs_confirmation + assert r.severity == RiskLevel.SAFE + + def test_block_factory(self): + r = GuardrailResult.block("dangerous") + assert r.action == GuardrailAction.BLOCK + assert r.is_blocked + assert r.severity == RiskLevel.DANGEROUS + + def test_confirm_factory(self): + r = GuardrailResult.confirm("needs review") + assert r.action == GuardrailAction.CONFIRM + assert r.needs_confirmation + assert r.severity == RiskLevel.REVIEW + + def test_metadata_defaults_to_empty_dict(self): + r = GuardrailResult.allow() + assert r.metadata == {} + + def test_custom_severity_on_allow(self): + r = GuardrailResult.allow("ok", severity=RiskLevel.REVIEW) + assert r.severity == RiskLevel.REVIEW + + +# --------------------------------------------------------------------------- +# ActionValidator — abstract interface enforced +# --------------------------------------------------------------------------- + + +class TestActionValidatorAbstract: + def test_cannot_instantiate_abstract(self): + with pytest.raises(TypeError): + ActionValidator(name="test") # type: ignore[abstract] + + def test_concrete_subclass_requires_validate(self): + class Broken(ActionValidator): + pass # missing validate + + with pytest.raises(TypeError): + Broken(name="broken") # type: ignore[abstract] + + +# --------------------------------------------------------------------------- +# ContentFilter — abstract interface enforced +# --------------------------------------------------------------------------- + + +class TestContentFilterAbstract: + def test_cannot_instantiate_abstract(self): + with pytest.raises(TypeError): + ContentFilter(name="test") # type: ignore[abstract] + + +# --------------------------------------------------------------------------- +# PolicyEngine — 
abstract interface enforced +# --------------------------------------------------------------------------- + + +class TestPolicyEngineAbstract: + def test_cannot_instantiate_abstract(self): + with pytest.raises(TypeError): + PolicyEngine(name="test") # type: ignore[abstract] + + +# --------------------------------------------------------------------------- +# AllowAllValidator +# --------------------------------------------------------------------------- + + +class TestAllowAllValidator: + def setup_method(self): + self.v = AllowAllValidator() + + def test_allows_any_tool(self): + for tool in ("shell", "delete_file", "read_file", ""): + r = self.v.validate(tool, {}, {}) + assert r.is_allowed, f"Expected allowed for {tool!r}" + + def test_check_delegates_to_validate(self): + r = self.v.check({"tool_name": "shell", "args": {}}) + assert r.is_allowed + + def test_enabled_by_default(self): + assert self.v.enabled is True + + +# --------------------------------------------------------------------------- +# BlockListValidator +# --------------------------------------------------------------------------- + + +class TestBlockListValidator: + def setup_method(self): + self.v = BlockListValidator(blocked_tools={"shell", "delete_file"}) + + def test_blocks_listed_tool(self): + r = self.v.validate("shell", {}, {}) + assert r.is_blocked + assert r.severity == RiskLevel.DANGEROUS + + def test_allows_unlisted_tool(self): + r = self.v.validate("read_file", {}, {}) + assert r.is_allowed + + def test_empty_blocklist_allows_all(self): + v = BlockListValidator() + assert v.validate("shell", {}, {}).is_allowed + + def test_check_delegates(self): + r = self.v.check({"tool_name": "delete_file", "args": {}}) + assert r.is_blocked + + +# --------------------------------------------------------------------------- +# CompositeActionValidator +# --------------------------------------------------------------------------- + + +class TestCompositeActionValidator: + def test_empty_composite_allows(self): + c = CompositeActionValidator() + r = c.validate("anything", {}, {}) + assert r.is_allowed + + def test_block_short_circuits(self): + c = CompositeActionValidator( + validators=[ + BlockListValidator(blocked_tools={"shell"}), + AllowAllValidator(), + ] + ) + r = c.validate("shell", {}, {}) + assert r.is_blocked + + def test_confirm_preserved_when_no_block(self): + class ConfirmAll(ActionValidator): + def __init__(self): + super().__init__(name="confirm_all") + + def validate(self, tool_name, args, context): + return GuardrailResult.confirm("needs review") + + c = CompositeActionValidator(validators=[ConfirmAll(), AllowAllValidator()]) + r = c.validate("anything", {}, {}) + assert r.needs_confirmation + + def test_disabled_validator_skipped(self): + blocker = BlockListValidator(blocked_tools={"shell"}) + blocker.enabled = False + c = CompositeActionValidator(validators=[blocker]) + assert c.validate("shell", {}, {}).is_allowed + + def test_add_validator(self): + c = CompositeActionValidator() + c.add(BlockListValidator(blocked_tools={"shell"})) + assert c.validate("shell", {}, {}).is_blocked + + +# --------------------------------------------------------------------------- +# PassthroughFilter +# --------------------------------------------------------------------------- + + +class TestPassthroughFilter: + def test_passes_any_content(self): + f = PassthroughFilter() + for content in ("hello", "", "DROP TABLE users;", "rm -rf /"): + assert f.filter(content, {}).is_allowed + + +# 
--------------------------------------------------------------------------- +# KeywordBlockFilter +# --------------------------------------------------------------------------- + + +class TestKeywordBlockFilter: + def setup_method(self): + self.f = KeywordBlockFilter(blocked_phrases={"drop table", "rm -rf"}) + + def test_blocks_matching_phrase(self): + r = self.f.filter("please run rm -rf /", {}) + assert r.is_blocked + assert r.severity == RiskLevel.DANGEROUS + + def test_case_insensitive_by_default(self): + r = self.f.filter("DROP TABLE users;", {}) + assert r.is_blocked + + def test_allows_clean_content(self): + r = self.f.filter("list all files", {}) + assert r.is_allowed + + def test_case_sensitive_mode(self): + f = KeywordBlockFilter(blocked_phrases={"DROP TABLE"}, case_sensitive=True) + assert f.filter("DROP TABLE users;", {}).is_blocked + assert f.filter("drop table users;", {}).is_allowed + + +# --------------------------------------------------------------------------- +# RegexFilter +# --------------------------------------------------------------------------- + + +class TestRegexFilter: + def test_blocks_matching_pattern(self): + f = RegexFilter(patterns=[r"\bpassword\s*=\s*\S+"]) + r = f.filter("password = hunter2", {}) + assert r.is_blocked + + def test_allows_non_matching(self): + f = RegexFilter(patterns=[r"\bpassword\s*=\s*\S+"]) + r = f.filter("no secrets here", {}) + assert r.is_allowed + + def test_empty_patterns_allows_all(self): + f = RegexFilter() + assert f.filter("anything", {}).is_allowed + + +# --------------------------------------------------------------------------- +# CompositeContentFilter +# --------------------------------------------------------------------------- + + +class TestCompositeContentFilter: + def test_empty_composite_allows(self): + c = CompositeContentFilter() + assert c.filter("anything", {}).is_allowed + + def test_block_short_circuits(self): + c = CompositeContentFilter( + filters=[ + KeywordBlockFilter(blocked_phrases={"bad"}), + PassthroughFilter(), + ] + ) + r = c.filter("bad content", {}) + assert r.is_blocked + + def test_disabled_filter_skipped(self): + blocker = KeywordBlockFilter(blocked_phrases={"bad"}) + blocker.enabled = False + c = CompositeContentFilter(filters=[blocker]) + assert c.filter("bad content", {}).is_allowed + + def test_add_filter(self): + c = CompositeContentFilter() + c.add(KeywordBlockFilter(blocked_phrases={"bad"})) + assert c.filter("bad content", {}).is_blocked + + +# --------------------------------------------------------------------------- +# RuleBasedPolicyEngine +# --------------------------------------------------------------------------- + + +class TestRuleBasedPolicyEngine: + def setup_method(self): + self.engine = RuleBasedPolicyEngine( + dangerous_tools={"shell", "delete_file"}, + review_tools={"write_file"}, + ) + + def test_classify_dangerous(self): + assert self.engine.classify_risk({"tool_name": "shell"}) == RiskLevel.DANGEROUS + + def test_classify_review(self): + assert self.engine.classify_risk({"tool_name": "write_file"}) == RiskLevel.REVIEW + + def test_classify_safe(self): + assert self.engine.classify_risk({"tool_name": "read_file"}) == RiskLevel.SAFE + + def test_check_blocks_dangerous(self): + r = self.engine.check({"tool_name": "shell"}) + assert r.is_blocked + + def test_check_confirms_review(self): + r = self.engine.check({"tool_name": "write_file"}) + assert r.needs_confirmation + + def test_check_allows_safe(self): + r = self.engine.check({"tool_name": "read_file"}) + assert 
r.is_allowed + + def test_unknown_tool_is_safe(self): + assert self.engine.classify_risk({"tool_name": ""}) == RiskLevel.SAFE + + +# --------------------------------------------------------------------------- +# CompositePolicyEngine +# --------------------------------------------------------------------------- + + +class TestCompositePolicyEngine: + def test_empty_composite_is_safe(self): + c = CompositePolicyEngine() + assert c.classify_risk({}) == RiskLevel.SAFE + + def test_dangerous_wins_over_review(self): + e1 = RuleBasedPolicyEngine(review_tools={"write_file"}) + e2 = RuleBasedPolicyEngine(dangerous_tools={"write_file"}) + c = CompositePolicyEngine(engines=[e1, e2]) + assert c.classify_risk({"tool_name": "write_file"}) == RiskLevel.DANGEROUS + + def test_review_preserved_when_no_dangerous(self): + e1 = RuleBasedPolicyEngine(review_tools={"write_file"}) + e2 = RuleBasedPolicyEngine() + c = CompositePolicyEngine(engines=[e1, e2]) + assert c.classify_risk({"tool_name": "write_file"}) == RiskLevel.REVIEW + + def test_disabled_engine_skipped(self): + e = RuleBasedPolicyEngine(dangerous_tools={"shell"}) + e.enabled = False + c = CompositePolicyEngine(engines=[e]) + assert c.classify_risk({"tool_name": "shell"}) == RiskLevel.SAFE + + def test_add_engine(self): + c = CompositePolicyEngine() + c.add(RuleBasedPolicyEngine(dangerous_tools={"shell"})) + assert c.classify_risk({"tool_name": "shell"}) == RiskLevel.DANGEROUS + + +# --------------------------------------------------------------------------- +# GuardrailRegistry +# --------------------------------------------------------------------------- + + +class TestGuardrailRegistry: + def setup_method(self): + self.registry = GuardrailRegistry() + + def test_register_and_get(self): + v = AllowAllValidator() + self.registry.register(v) + assert self.registry.get("allow_all") is v + + def test_get_unknown_returns_none(self): + assert self.registry.get("nonexistent") is None + + def test_register_replaces_same_name(self): + v1 = AllowAllValidator() + v2 = AllowAllValidator() + self.registry.register(v1) + self.registry.register(v2) + assert self.registry.get("allow_all") is v2 + + def test_unregister(self): + self.registry.register(AllowAllValidator()) + self.registry.unregister("allow_all") + assert self.registry.get("allow_all") is None + + def test_unregister_missing_is_noop(self): + self.registry.unregister("does_not_exist") # should not raise + + def test_get_all_returns_all(self): + self.registry.register(AllowAllValidator()) + self.registry.register(PassthroughFilter()) + assert len(self.registry.get_all()) == 2 + + def test_get_all_filtered_by_type(self): + self.registry.register(AllowAllValidator()) + self.registry.register(PassthroughFilter()) + validators = self.registry.get_all(ActionValidator) + assert len(validators) == 1 + assert isinstance(validators[0], ActionValidator) + + def test_get_enabled_excludes_disabled(self): + v = AllowAllValidator() + v.enabled = False + self.registry.register(v) + assert self.registry.get_enabled(ActionValidator) == [] + + def test_run_all_returns_results(self): + self.registry.register(AllowAllValidator()) + results = self.registry.run_all(ActionValidator, context={"tool_name": "x", "args": {}}) + assert len(results) == 1 + assert results[0].is_allowed + + def test_clear(self): + self.registry.register(AllowAllValidator()) + self.registry.clear() + assert len(self.registry) == 0 + + def test_len(self): + self.registry.register(AllowAllValidator()) + self.registry.register(PassthroughFilter()) + 
assert len(self.registry) == 2 + + def test_repr(self): + self.registry.register(AllowAllValidator()) + assert "allow_all" in repr(self.registry) diff --git a/tests/test_agent.py b/tests/test_agent.py index 4fb6c3f..13db174 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -186,7 +186,7 @@ async def test_agent_run_with_tool_call_then_text(tmp_path): # Register a simple echo tool from pydantic import BaseModel - from operator_use.tools.service import Tool + from operator_use.agent.tools.service import Tool class EchoParams(BaseModel): message: str diff --git a/tests/test_control_center.py b/tests/test_control_center.py index f3a2e5b..0efe749 100644 --- a/tests/test_control_center.py +++ b/tests/test_control_center.py @@ -4,7 +4,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock, patch -from operator_use.agent.tools.builtin.control_center import ( +from operator_use.tools.control_center import ( control_center, _set_plugin_enabled, _get_plugin_enabled, diff --git a/tests/test_local_agents.py b/tests/test_local_agents.py index 8fd831b..a1b5168 100644 --- a/tests/test_local_agents.py +++ b/tests/test_local_agents.py @@ -2,7 +2,7 @@ import pytest -from operator_use.agent.tools.builtin.local_agents import LOCAL_AGENT_DELEGATION_CHAIN, localagents +from operator_use.tools.local_agents import LOCAL_AGENT_DELEGATION_CHAIN, localagents from operator_use.messages.service import AIMessage diff --git a/tests/test_plugins.py b/tests/test_plugins.py index f6ba6d4..5d9f8b9 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -7,7 +7,7 @@ from operator_use.agent.tools.registry import ToolRegistry from operator_use.agent.hooks.service import Hooks from operator_use.agent.hooks.events import HookEvent -from operator_use.tools.service import Tool +from operator_use.agent.tools.service import Tool from pydantic import BaseModel diff --git a/tests/test_tool_registry.py b/tests/test_tool_registry.py index ca6ed75..77c70b9 100644 --- a/tests/test_tool_registry.py +++ b/tests/test_tool_registry.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from operator_use.agent.tools.registry import ToolRegistry -from operator_use.tools.service import Tool +from operator_use.agent.tools.service import Tool # --- Helpers --- diff --git a/tests/test_tools.py b/tests/test_tools.py index 8cbf913..de572ab 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from typing import Literal -from operator_use.tools.service import Tool, ToolResult +from operator_use.agent.tools.service import Tool, ToolResult # --- ToolResult ---
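
Usage sketch: a minimal end-to-end wiring of the pieces added above, not part of the patch itself. It assumes only the public names exported by operator_use/guardrails/__init__.py in this diff::

    from operator_use.guardrails import (
        ActionValidator,
        BlockListValidator,
        ContentFilter,
        GuardrailRegistry,
        KeywordBlockFilter,
        PolicyEngine,
        RuleBasedPolicyEngine,
    )

    # One guardrail of each kind, all held by a single registry.
    registry = GuardrailRegistry()
    registry.register(BlockListValidator(blocked_tools={"shell"}))
    registry.register(KeywordBlockFilter(blocked_phrases={"drop table"}))
    registry.register(RuleBasedPolicyEngine(dangerous_tools={"delete_file"}))

    # Pre-execution: validate a pending tool call.
    [result] = registry.run_all(ActionValidator, context={"tool_name": "shell", "args": {}})
    assert result.is_blocked

    # Post-execution: filter model output before it reaches the user.
    [result] = registry.run_all(ContentFilter, context={"content": "DROP TABLE users;"})
    assert result.is_blocked

    # Policy: classify the risk of an action described by the context dict.
    [result] = registry.run_all(PolicyEngine, context={"tool_name": "delete_file"})
    assert result.is_blocked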