Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions padiff/abstracts/hooks/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,15 @@ def serialize(x):

return original_forward(*corrected_args, **kwargs)

model.forward = tracked_forward
# model.forward = tracked_forward
model.__dict__["forward"] = tracked_forward
model._padiff_input_captured = True

try:
yield
finally:
model.forward = original_forward
# model.forward = original_forward
model.__dict__["forward"] = original_forward


@contextlib.contextmanager
Expand Down
10 changes: 3 additions & 7 deletions padiff/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,8 @@ def load_yaml_config(config_path):
for k, v in config["CLI"].items():
cli_cfg[k] = v

guard_cfg = {}
for k, v in config.get("PaDiffGuard", {}).items():
guard_cfg[k] = v

compare_cfg = {}
for k, v in config.get("COMPARE", {}).items():
compare_cfg[k] = v
guard_cfg = dict(config.get("PaDiffGuard") or {})
compare_cfg = dict(config.get("COMPARE") or {})

return cli_cfg, guard_cfg, compare_cfg

Expand Down Expand Up @@ -186,6 +181,7 @@ def main():
rtol: 1.0e-06
compare_mode: "mean"
action_name: "equal"
check_mode: "fast"
""",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
Expand Down
16 changes: 8 additions & 8 deletions padiff/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ def traverse(structure, on_leaf, on_container=None):
if isinstance(structure, (paddle.Tensor, torch.Tensor)):
return on_leaf(structure)

# namedtuple
if hasattr(structure, "_fields"):
result = type(structure)(
*[traverse(getattr(structure, field), on_leaf, on_container) for field in structure._fields]
)
return on_container(result) if on_container else result

# dict
if isinstance(structure, dict):
new_dict = type(structure)()
Expand All @@ -154,7 +161,7 @@ def traverse(structure, on_leaf, on_container=None):
# list or tuple
if isinstance(structure, (list, tuple)):
result = [traverse(item, on_leaf, on_container) for item in structure]
return on_container(result) if on_container else result
return on_container(result) if on_container else type(structure)(result)

# Sequence-like objects (e.g., DynamicCache, ModelOutput)
if hasattr(structure, "__getitem__") and hasattr(structure, "__len__"):
Expand All @@ -169,13 +176,6 @@ def traverse(structure, on_leaf, on_container=None):
except Exception:
pass

# namedtuple
if hasattr(structure, "_fields"):
result = type(structure)(
*[traverse(getattr(structure, field), on_leaf, on_container) for field in structure._fields]
)
return on_container(result) if on_container else result

# object with __dict__
if hasattr(structure, "__dict__"):
try:
Expand Down