diff --git a/capiscio_mcp/guard.py b/capiscio_mcp/guard.py index eb54054..eabfc6d 100644 --- a/capiscio_mcp/guard.py +++ b/capiscio_mcp/guard.py @@ -98,24 +98,24 @@ async def execute_query(sql: str) -> list[dict]: # calls within a burst: first call ~3 ms, subsequent calls ~0.01 ms. _DECISION_CACHE_TTL = 5.0 # seconds _DECISION_CACHE_MAX_SIZE = 256 # max entries before eviction -_decision_cache: Dict[Tuple[str, str], Tuple["GuardResult", float]] = {} +_decision_cache: Dict[Tuple, Tuple["GuardResult", float]] = {} # key=(badge_jws, tool_name, ...) _decision_cache_lock = threading.Lock() -def _cache_get(badge_jws: str, tool_name: str) -> Optional["GuardResult"]: +def _cache_get(badge_jws: str, tool_name: str, params_hash: str = "", server_origin: str = "", policy_version: str = "", capability_class: Optional[str] = None, deny_on_unknown_class: Optional[bool] = None) -> Optional["GuardResult"]: """Return cached decision if still valid, else None.""" with _decision_cache_lock: - entry = _decision_cache.get((badge_jws, tool_name)) + entry = _decision_cache.get((badge_jws, tool_name, params_hash, server_origin, policy_version, capability_class or "", str(deny_on_unknown_class or ""))) if entry is None: return None result, expiry = entry if time.monotonic() > expiry: - del _decision_cache[(badge_jws, tool_name)] + del _decision_cache[(badge_jws, tool_name, params_hash, server_origin, policy_version, capability_class or "", str(deny_on_unknown_class or ""))] return None return result -def _cache_put(badge_jws: str, tool_name: str, result: "GuardResult") -> None: +def _cache_put(badge_jws: str, tool_name: str, result: "GuardResult", params_hash: str = "", server_origin: str = "", policy_version: str = "", capability_class: Optional[str] = None, deny_on_unknown_class: Optional[bool] = None) -> None: """Store a decision in the cache.""" with _decision_cache_lock: # Evict expired entries if cache is at capacity @@ -129,7 +129,7 @@ def _cache_put(badge_jws: str, tool_name: str, result: "GuardResult") -> None: oldest = sorted(_decision_cache, key=lambda k: _decision_cache[k][1]) for k in oldest[:len(_decision_cache) - _DECISION_CACHE_MAX_SIZE + 1]: del _decision_cache[k] - _decision_cache[(badge_jws, tool_name)] = ( + _decision_cache[(badge_jws, tool_name, params_hash, server_origin, policy_version, capability_class or "", str(deny_on_unknown_class or ""))] = ( result, time.monotonic() + _DECISION_CACHE_TTL, ) @@ -326,7 +326,7 @@ async def evaluate_tool_access( # Check decision cache — same badge + same tool = same decision cache_key_jws = effective_credential.badge_jws or "" - cached = _cache_get(cache_key_jws, tool_name) + cached = _cache_get(cache_key_jws, tool_name, params_hash, server_origin, effective_config.policy_version or "", capability_class, deny_on_unknown_class) if cached is not None: logger.debug("Decision cache hit: tool=%s decision=%s", tool_name, cached.decision.value) return cached @@ -415,7 +415,7 @@ async def evaluate_tool_access( ) # Cache for subsequent calls with the same badge + tool - _cache_put(cache_key_jws, tool_name, result) + _cache_put(cache_key_jws, tool_name, result, params_hash, server_origin, effective_config.policy_version or "", capability_class, deny_on_unknown_class) return result