diff --git a/padiff/abstracts/hooks/guard.py b/padiff/abstracts/hooks/guard.py index c6f9e5e..a584453 100644 --- a/padiff/abstracts/hooks/guard.py +++ b/padiff/abstracts/hooks/guard.py @@ -191,13 +191,15 @@ def serialize(x): return original_forward(*corrected_args, **kwargs) - model.forward = tracked_forward + # model.forward = tracked_forward + model.__dict__["forward"] = tracked_forward model._padiff_input_captured = True try: yield finally: - model.forward = original_forward + # model.forward = original_forward + model.__dict__["forward"] = original_forward @contextlib.contextmanager diff --git a/padiff/cli.py b/padiff/cli.py index 70d3123..22c538c 100644 --- a/padiff/cli.py +++ b/padiff/cli.py @@ -42,13 +42,8 @@ def load_yaml_config(config_path): for k, v in config["CLI"].items(): cli_cfg[k] = v - guard_cfg = {} - for k, v in config.get("PaDiffGuard", {}).items(): - guard_cfg[k] = v - - compare_cfg = {} - for k, v in config.get("COMPARE", {}).items(): - compare_cfg[k] = v + guard_cfg = dict(config.get("PaDiffGuard") or {}) + compare_cfg = dict(config.get("COMPARE") or {}) return cli_cfg, guard_cfg, compare_cfg @@ -186,6 +181,7 @@ def main(): rtol: 1.0e-06 compare_mode: "mean" action_name: "equal" + check_mode: "fast" """, formatter_class=argparse.RawDescriptionHelpFormatter, ) diff --git a/padiff/utils/utils.py b/padiff/utils/utils.py index 1c5460e..413ab13 100644 --- a/padiff/utils/utils.py +++ b/padiff/utils/utils.py @@ -144,6 +144,13 @@ def traverse(structure, on_leaf, on_container=None): if isinstance(structure, (paddle.Tensor, torch.Tensor)): return on_leaf(structure) + # namedtuple + if hasattr(structure, "_fields"): + result = type(structure)( + *[traverse(getattr(structure, field), on_leaf, on_container) for field in structure._fields] + ) + return on_container(result) if on_container else result + # dict if isinstance(structure, dict): new_dict = type(structure)() @@ -154,7 +161,7 @@ def traverse(structure, on_leaf, on_container=None): # list or tuple if isinstance(structure, (list, tuple)): result = [traverse(item, on_leaf, on_container) for item in structure] - return on_container(result) if on_container else result + return on_container(result) if on_container else type(structure)(result) # Sequence-like objects (e.g., DynamicCache, ModelOutput) if hasattr(structure, "__getitem__") and hasattr(structure, "__len__"): @@ -169,13 +176,6 @@ def traverse(structure, on_leaf, on_container=None): except Exception: pass - # namedtuple - if hasattr(structure, "_fields"): - result = type(structure)( - *[traverse(getattr(structure, field), on_leaf, on_container) for field in structure._fields] - ) - return on_container(result) if on_container else result - # object with __dict__ if hasattr(structure, "__dict__"): try: