Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions capiscio_mcp/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,24 +98,24 @@ async def execute_query(sql: str) -> list[dict]:
# calls within a burst: first call ~3 ms, subsequent calls ~0.01 ms.
_DECISION_CACHE_TTL = 5.0 # seconds
_DECISION_CACHE_MAX_SIZE = 256 # max entries before eviction
_decision_cache: Dict[Tuple[str, str], Tuple["GuardResult", float]] = {}
_decision_cache: Dict[Tuple, Tuple["GuardResult", float]] = {} # key=(badge_jws, tool_name, ...)
_decision_cache_lock = threading.Lock()


def _cache_get(badge_jws: str, tool_name: str) -> Optional["GuardResult"]:
def _cache_get(badge_jws: str, tool_name: str, params_hash: str = "", server_origin: str = "", policy_version: str = "", capability_class: Optional[str] = None, deny_on_unknown_class: Optional[bool] = None) -> Optional["GuardResult"]:
"""Return cached decision if still valid, else None."""
with _decision_cache_lock:
entry = _decision_cache.get((badge_jws, tool_name))
entry = _decision_cache.get((badge_jws, tool_name, params_hash, server_origin, policy_version, capability_class or "", str(deny_on_unknown_class or "")))
if entry is None:
return None
result, expiry = entry
if time.monotonic() > expiry:
del _decision_cache[(badge_jws, tool_name)]
del _decision_cache[(badge_jws, tool_name, params_hash, server_origin, policy_version, capability_class or "", str(deny_on_unknown_class or ""))]
return None
return result


def _cache_put(badge_jws: str, tool_name: str, result: "GuardResult") -> None:
def _cache_put(badge_jws: str, tool_name: str, result: "GuardResult", params_hash: str = "", server_origin: str = "", policy_version: str = "", capability_class: Optional[str] = None, deny_on_unknown_class: Optional[bool] = None) -> None:
"""Store a decision in the cache."""
with _decision_cache_lock:
# Evict expired entries if cache is at capacity
Expand All @@ -129,7 +129,7 @@ def _cache_put(badge_jws: str, tool_name: str, result: "GuardResult") -> None:
oldest = sorted(_decision_cache, key=lambda k: _decision_cache[k][1])
for k in oldest[:len(_decision_cache) - _DECISION_CACHE_MAX_SIZE + 1]:
del _decision_cache[k]
_decision_cache[(badge_jws, tool_name)] = (
_decision_cache[(badge_jws, tool_name, params_hash, server_origin, policy_version, capability_class or "", str(deny_on_unknown_class or ""))] = (
result,
time.monotonic() + _DECISION_CACHE_TTL,
)
Expand Down Expand Up @@ -326,7 +326,7 @@ async def evaluate_tool_access(

# Check decision cache — same badge + same tool = same decision
cache_key_jws = effective_credential.badge_jws or ""
cached = _cache_get(cache_key_jws, tool_name)
cached = _cache_get(cache_key_jws, tool_name, params_hash, server_origin, effective_config.policy_version or "", capability_class, deny_on_unknown_class)
if cached is not None:
logger.debug("Decision cache hit: tool=%s decision=%s", tool_name, cached.decision.value)
return cached
Expand Down Expand Up @@ -415,7 +415,7 @@ async def evaluate_tool_access(
)

# Cache for subsequent calls with the same badge + tool
_cache_put(cache_key_jws, tool_name, result)
_cache_put(cache_key_jws, tool_name, result, params_hash, server_origin, effective_config.policy_version or "", capability_class, deny_on_unknown_class)

return result

Expand Down