diff --git a/padiff/abstracts/hooks/base.py b/padiff/abstracts/hooks/base.py index afe4371..57dd820 100644 --- a/padiff/abstracts/hooks/base.py +++ b/padiff/abstracts/hooks/base.py @@ -60,42 +60,52 @@ def _traversal(node, bucket): class _CallsContext: """ - A global context for managing forward call counts across multiple PaDiffGuard invocations. - This ensures that max_calls is respected even when PaDiffGuard is re-entered. + A context for managing forward call counts for PaDiffGuard invocations. + Each model instance will have its own independent call count state. + Different files always use independent counters. """ - _state = contextvars.ContextVar("_calls_context_state", default=None) - def __init__(self): - self._state.set({"count": 0, "limit": 0, "active": False}) - - @property - def state(self) -> Dict: - s = self._state.get() - if s is None: - s = {"count": 0, "limit": 0, "active": False} - self._state.set(s) - return s - - def set_limit(self, limit: int): - self.state["limit"] = limit - self.state["active"] = True - - def increment(self) -> int: - if not self.state["active"]: + self._model_states = {} # model_id -> state dict + + def _get_model_id(self, model): + """Get a unique identifier for the model in current context""" + # Simple object identity is sufficient since we don't need cross-file compatibility + return str(id(model)) + + def get_state(self, model) -> Dict: + """Get the state for a specific model""" + model_id = self._get_model_id(model) + if model_id not in self._model_states: + self._model_states[model_id] = {"count": 0, "limit": 0, "active": False} + return self._model_states[model_id] + + def set_limit(self, model, limit: int): + """Set the call limit for a specific model""" + state = self.get_state(model) + state["limit"] = limit + state["active"] = True + + def increment(self, model) -> int: + """Increment the call count for a specific model""" + state = self.get_state(model) + if not state["active"]: return 0 - self.state["count"] += 1 - return self.state["count"] + state["count"] += 1 + return state["count"] - def is_exceeded(self) -> bool: - if not self.state["active"]: + def is_exceeded(self, model) -> bool: + """Check if the call limit is exceeded for a specific model""" + state = self.get_state(model) + if not state["active"]: return False - return self.state["count"] >= self.state["limit"] + return state["count"] >= state["limit"] - def reset(self): - self.state["count"] = 0 - self.state["limit"] = 0 - self.state["active"] = False + def reset(self, model): + """Reset the state for a specific model""" + model_id = self._get_model_id(model) + if model_id in self._model_states: + del self._model_states[model_id] @classmethod def get_current(cls) -> "_CallsContext": diff --git a/padiff/abstracts/hooks/guard.py b/padiff/abstracts/hooks/guard.py index a584453..f8d23cc 100644 --- a/padiff/abstracts/hooks/guard.py +++ b/padiff/abstracts/hooks/guard.py @@ -211,10 +211,10 @@ def MaxCallsGuard(max_calls: int, model): calls_context = get_calls_context() def pre_hook(m, input): - if calls_context.is_exceeded(): + if calls_context.is_exceeded(model): logger.warning(f"PaDiffGuard: max_calls={max_calls} reached, raising _CallsComplete") raise _CallsComplete() - count = calls_context.increment() + count = calls_context.increment(model) logger.info(f"MaxCallsGuard: forward start calling #{count}") handle = model.register_forward_pre_hook(pre_hook) @@ -240,13 +240,14 @@ def PaDiffGuard( black_list=None, keys_mapping=None, ): - # moniter number of calls + # get the global calls context calls_context = get_calls_context() - reset_flag = calls_context.state["count"] == 0 + # check if this is the first call for this specific model + reset_flag = calls_context.get_state(model)["count"] == 0 if reset_flag: - # set max calls - calls_context.set_limit(max_calls) + # set max calls for this model + calls_context.set_limit(model, max_calls) proxy_model = create_model(model, name=name, reset_dir=reset_flag) model._padiff_proxy = proxy_model