Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ python -m pip install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple

## 快速开始

**注意:当前暂不支持分布式**

### 使用单行命令对齐

将命令写入配置文件后,通过如下命令运行
Expand Down Expand Up @@ -61,6 +63,7 @@ if __name__ == "__main__":
"rtol": 1e-4,
"compare_mode": "abs_mean",
"action_name": "loose_equal",
"check_mode": "fast",
}

pt_dump_path = "torch_proj/padiff_dump/model_torch"
Expand Down
21 changes: 13 additions & 8 deletions docs/CLIConfig.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,13 @@

定义结果对比的精度和逻辑。

| 参数 | 类型 | 必需 | 默认值 | 说明 |
| -------------- | ------ | ---- | ------- | --------------------------------------- |
| `atol` | float | 否 | 1e-6 | 绝对误差容忍度 |
| `rtol` | float | 否 | 1e-6 | 相对误差容忍度 |
| `compare_mode` | string | 否 | "mean" | 对比模式 ("mean", "strict", "abs_mean") |
| `action_name` | string | 否 | "equal" | 对比动作 ("equal", "loose_equal") |
| 参数 | 类型 | 必需 | 默认值 | 说明 |
| -------------- | ------ | ---- | ------- | ---------------------------------------- |
| `atol` | float | 否 | 1e-6 | 绝对误差容忍度 |
| `rtol` | float | 否 | 1e-7 | 相对误差容忍度 |
| `compare_mode` | string | 否 | "mean" | 数值对比模式 ("mean", "strict", "abs_mean") |
| `action_name` | string | 否 | "equal" | 层对比策略 ("equal", "loose_equal") |
| `check_mode` | string | 否 | "fast" | 模型对比策略 ("fast", "deep") |

## 示例和详细说明

Expand Down Expand Up @@ -207,18 +208,22 @@ compare_dumps(pt_dump_path, pd_dump_path, cfg)
- 控制模型输出结果的对比精度和模式
- atol: 绝对误差容忍度 (default: 1e-6)
- rtol: 相对误差容忍度 (default: 1e-7)
- compare_mode: 对比模式,可选值: mean, strict, abs_mean, 默认值: "mean"
- compare_mode: 数值对比模式,可选值: mean, strict, abs_mean, 默认值: "mean"
  - mean: 比较传入数据的均值
  - strict: 直接比较传入数据
  - abs_mean: 比较传入数据的绝对值的均值
- action_name: 对比行为,可选值: equal, loose_equal, 默认值: "equal"
- action_name: 层对比策略,可选值: equal, loose_equal, 默认值: "equal"
  - equal: 进行严格的对比,比如会对输出的个数、shape 都进行检查
  - loose_equal: 较为宽松的对比,会尝试尽可能多的匹配数据,比如两个模型某层的输出数量分别为 1 和 3 ,会对第一个输出进行比较,同时在两个数据的形状不匹配时会尝试 transpose 或 reshape。仅在该层所有数据都无法匹配时报错。
- check_mode: 模型对比策略,可选值: fast, deep, 默认值: "fast"
  - fast: 进行自顶向下的对比,当上层模型精度检查成功时,不会再进入子层检查,一旦检查到错误就会报错退出
  - deep: 强制检查所有层,即使上层模型精度检查成功,也进入子层进行检查,所有检查结束后才会报错退出

```
COMPARE:
  atol: 1e-4
  rtol: 1e-5
  compare_mode: "mean"
  action_name: "loose_equal"
  check_mode: "fast"
```
139 changes: 81 additions & 58 deletions padiff/comparison/checker/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
clone_dict_tree,
get_all_valid_path,
load_json,
print_report_info,
print_multi_report_info,
reorder_and_match_sublayers,
logger,
struct_info_log,
Expand Down Expand Up @@ -75,78 +75,101 @@ def _check_report_impl(report_path_0, report_path_1, cfg=None, diff_phase="both"


def check_forward(nodes, reports, cfg):
logger.debug(f"Checking forward of {nodes[0]['name']}")
action_name = cfg.get("action_name", None)
act = get_action(reports[0], nodes[0], reports[1], nodes[1], name=action_name)
try:
act(nodes[0]["fwd_outputs"], nodes[1]["fwd_outputs"], cfg)
return True
except Exception as e:
compare_info = e
if len(nodes[0]["children"]) == 0 or len(nodes[1]["children"]) == 0:
print_report_info(nodes, reports, e, "Forward")
return False

# reorder current level
try:
if not nodes[1]["reordered"]:
reorder_and_match_sublayers(nodes, reports)
except Exception as e:
msg = f"While checking forward, diff found at {nodes[0]['name']}(base) vs {nodes[1]['name']}(raw)\n"
msg += "Call `reorder_and_match_sublayers` for more detailed infos, but error occurs again:\n"
msg += f"{type(e).__name__}: {str(e)}"
logger.error(msg)
# print_report_info(nodes, reports, compare_info, "Forward", msg)
# return False

for child_0, child_1 in zip(nodes[0]["children"], nodes[1]["children"]):
res = check_forward((child_0, child_1), reports, cfg)
if res == False:
return False

# sublayers is compared ok, but diff found at father layer
msg = (
f"\n ⚠️ Sublayers of {nodes[0]['name']} and {nodes[1]['name']} are corresponded, but diff found at their output! "
"\n 💡 This might be reasonable since errors accumulate if single_step mode is enabled."
failures = _check_node(
nodes=nodes,
reports=reports,
cfg=cfg,
data_key="fwd_outputs",
direction="forward",
)
print_report_info(nodes, reports, compare_info, "Forward", msg)
return False
success = print_multi_report_info(failures, stage="Forward")
return success


def check_backward(nodes, reports, cfg):
logger.debug(f"Checking backward of {nodes[0]['name']}")
failures = _check_node(nodes=nodes, reports=reports, cfg=cfg, data_key="bwd_grads", direction="backward")
success = print_multi_report_info(failures, stage="Backward")
return success


def _check_node(nodes, reports, cfg, data_key, direction):
"""
Generic function to check forward or backward outputs.

Args:
nodes: (node_base, node_raw)
reports: (report_base, report_raw)
cfg: config dict
data_key: "fwd_outputs" or "bwd_grads"
direction: "forward" or "backward"
"""
logger.debug(f"Checking {direction.lower()} of {nodes[0]['name']}")

check_mode = cfg.get("check_mode", "fast").lower()
action_name = cfg.get("action_name", None)
act = get_action(reports[0], nodes[0], reports[1], nodes[1], name=action_name)

parent_error = None
failures = []

# Step 1: Compare current layer
try:
act(nodes[0]["bwd_grads"], nodes[1]["bwd_grads"], cfg)
return True
act(nodes[0][data_key], nodes[1][data_key], cfg)
except Exception as e:
compare_info = e
parent_error = e
if len(nodes[0]["children"]) == 0 or len(nodes[1]["children"]) == 0:
print_report_info(nodes, reports, e, "Backward")
return False
failures.append({"nodes": nodes, "reports": reports, "exc": e, "msg": None, "direction": direction})
return failures
logger.debug(f"{direction} mismatch at {nodes[0]['name']} -> will check children (check_mode={check_mode})")

# reorder current level
# Step 2: Early return if should not check sublayer
if check_mode == "fast" and parent_error is None:
logger.debug(f"{direction} PASSED -> skip children (check_mode=fast)")
return []

# Step 3: Try to reorder current level
try:
if not nodes[1]["reordered"]:
reorder_and_match_sublayers(nodes, reports)
except Exception as e:
msg = f"While checking backward, diff found at {nodes[0]['name']}(base) vs {nodes[1]['name']}(raw)\n"
msg += "Call `reorder_and_match_sublayers` for more detailed infos, but error occurs again:\n"
msg += f"{type(e).__name__}: {str(e)}"
logger.error(msg)
# print_report_info(nodes, reports, compare_info, "Backward", msg)
# return False

for child_0, child_1 in zip(reversed(nodes[0]["children"]), reversed(nodes[1]["children"])):
res = check_backward((child_0, child_1), reports, cfg)
if res == False:
return False
# some mismatches are allowed when action_name='loose_equal', so no error is returned here, only print it
logger.error(
f"While checking {direction.lower()}, diff found at {nodes[0]['name']}(base) vs {nodes[1]['name']}(raw)\n"
"Call `reorder_and_match_sublayers` for more detailed infos, but error occurs again:\n"
f"{type(e).__name__}: {str(e)}"
)

# Step 4: Recursively check all sublayers
# Note: Backward often checks in reverse order
children_zip = (
zip(reversed(nodes[0]["children"]), reversed(nodes[1]["children"]))
if direction.lower() == "backward"
else zip(nodes[0]["children"], nodes[1]["children"])
)

# sublayers is compared ok, but diff found at father layer
msg = f"Grad of sublayer {nodes[0]['name']} and {nodes[1]['name']} are corresponded, but current grad found diff!"
print_report_info(nodes, reports, compare_info, "Backward", msg)
return False
for child_0, child_1 in children_zip:
child_failures = _check_node(
(child_0, child_1),
reports,
cfg,
data_key=data_key,
direction=direction,
)
if child_failures:
failures.extend(child_failures)
if check_mode == "fast":
return failures

# Step 5: All children passed, but parent failed
if parent_error is not None and not failures:
msg = (
f"\n ⚠️ Grad of sublayer '{nodes[0]['name']}' and '{nodes[1]['name']}' are corresponded, "
f"but current {direction.lower()} output found diff!"
"\n 💡 This might be reasonable since errors accumulate if single_step mode is enabled."
)
failures.append({"nodes": nodes, "reports": reports, "exc": parent_error, "msg": msg, "direction": direction})

return failures


def check_layer_map(reports):
Expand Down
1 change: 1 addition & 0 deletions padiff/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"rtol": 1e-7,
"compare_mode": "mean",
"action_name": "equal",
"check_mode": "fast",
}


Expand Down
65 changes: 55 additions & 10 deletions padiff/utils/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,12 @@ def print_report_info(nodes, reports, exc, stage, msg=None):
logger.info(retstr)


def tree_print(node, mark=None, prefix=[]):
def tree_print(node, marks=None, prefix=[]):
if marks is None:
marks = []
elif not isinstance(marks, (list, tuple)):
marks = [marks]

cur_str = ""
for i, s in enumerate(prefix):
if i == len(prefix) - 1:
Expand All @@ -153,9 +158,9 @@ def tree_print(node, mark=None, prefix=[]):
cur_str += node["name"]
if "available" in node and node["available"] == False:
cur_str += " (skip)"
if os.getenv("PADIFF_PATH_LOG") == "ON":
if os.getenv("PADIFF_PATH_LOG") == "ON" or os.getenv("PADIFF_LOG_LEVEL") == "DEBUG":
cur_str += " (" + node["route"] + ")"
if mark is node:
if node in marks:
cur_str += " <--- *** HERE ***"

ret_strs = [cur_str]
Expand All @@ -164,7 +169,7 @@ def tree_print(node, mark=None, prefix=[]):
if i == len(node["children"]) - 1:
pre = " +--- "
prefix.append(pre)
retval = tree_print(child, mark, prefix)
retval = tree_print(child, marks, prefix)
ret_strs.extend(retval)
prefix.pop()

Expand All @@ -179,23 +184,34 @@ def build_file_name(report, file_name):
return file_name + ".log"


def struct_info(report, node, file_prefix):
def struct_info(report, mark, file_prefix):
file_name = build_file_name(report, file_prefix + "_" + report["model_name"])
title = f"{report['model_name']}(without layers in blacklist)\n" + "=" * 40 + "\n"

if not isinstance(mark, (list, tuple)):
marks = [mark]
else:
marks = mark

retval = []
for tree in report["tree"]:
retval.extend(tree_print(tree, mark=node, prefix=[" " * 4]))
retval.extend(tree_print(tree, marks=marks, prefix=[" " * 4]))

info = title + "\n".join(retval)
logger.log_file(file_name, "w", info)
return file_name


def struct_info_log(reports, nodes, file_prefix):
if isinstance(nodes, (list, tuple)):
if len(nodes) == 0:
return ""
else:
nodes = [nodes]

file_names = []
for idx in range(2):
node = nodes[idx]
report = reports[idx]
file_name = struct_info(report, node, file_prefix)
for idx, report in enumerate(reports):
file_name = struct_info(report, nodes[idx], file_prefix)
file_names.append(file_name)
retval = (
f"Model struct files saved in: '{logger.log_path}/{file_names[0]}' vs '{logger.log_path}/{file_names[1]}'\n"
Expand All @@ -206,3 +222,32 @@ def struct_info_log(reports, nodes, file_prefix):
def save_model_struct(report, file_prefix="arch"):
    """Dump the structure of ``report`` to a log file and announce where it went.

    The file is produced by ``struct_info`` with no node marked (``None`` is
    passed as the mark); the resulting file name is then reported through
    ``logger.info`` together with ``logger.log_path``.

    Args:
        report: report object forwarded unchanged to ``struct_info``.
        file_prefix: prefix forwarded to ``struct_info`` for building the
            output file name.
    """
    saved_name = struct_info(report, None, file_prefix)
    location = f"{logger.log_path}/{saved_name}"
    logger.info(f"Model struct saved in: '{location}' without layers in blacklist\n")


def print_multi_report_info(failure_list, stage="Forward"):
    """Log every recorded mismatch and return whether the stage passed.

    Args:
        failure_list: collection of failure records. Each record is indexed
            with the keys ``"exc"``, ``"nodes"``, ``"msg"`` and ``"reports"``;
            ``"nodes"`` is indexed at positions 0 and 1.
        stage: label inserted into the summary error line.

    Returns:
        ``True`` when ``failure_list`` is empty (or otherwise falsy), in which
        case nothing is logged; ``False`` after all failures were logged.
    """
    if failure_list:
        total = len(failure_list)
        logger.error(f"FAILED !!! '{stage}' Stage Mismatch! Found {total} mismatched layers.")

        # One list of marked origin nodes per side (index 0 / index 1 of each node pair).
        first_marks = []
        second_marks = []
        for record in failure_list:
            error = record["exc"]
            pair = record["nodes"]
            first_node = pair[0]
            second_node = pair[1]

            detail = (
                f"\n Layer: {first_node['name']} vs {second_node['name']}"
                f"\n Route: {first_node['route']} vs {second_node['route']}"
                f"\nError({type(error).__name__}): {error!s} \n"
            )
            logger.error(detail)

            extra = record["msg"]
            if extra is not None:
                logger.warning("ADDITIONAL MESSAGE:")
                logger.warning(extra + " \n")

            first_marks.append(first_node["origin_node"])
            second_marks.append(second_node["origin_node"])

        # The reports are taken from the first failure record only.
        summary = struct_info_log(failure_list[0]["reports"], [first_marks, second_marks], "report")
        logger.info(summary)
        return False

    return True