From bea852402bbba108e703e4bb313a4999c4d5565d Mon Sep 17 00:00:00 2001 From: lijialin03 Date: Sun, 28 Sep 2025 09:51:30 +0000 Subject: [PATCH] feat: optimize check mode --- README.md | 3 + docs/CLIConfig.md | 21 ++-- padiff/comparison/checker/reports.py | 139 ++++++++++++++++----------- padiff/configs/__init__.py | 1 + padiff/utils/log.py | 65 +++++++++++-- 5 files changed, 153 insertions(+), 76 deletions(-) diff --git a/README.md b/README.md index 7673720..5f52d41 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,8 @@ python -m pip install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple ## 快速开始 +**注意:当前暂不支持分布式** + ### 使用单行命令对齐 将命令写入配置文件后,通过如下命令运行 @@ -61,6 +63,7 @@ if __name__ == "__main__": "rtol": 1e-4, "compare_mode": "abs_mean", "action_name": "loose_equal", + "check_mode": "fast", } pt_dump_path = "torch_proj/padiff_dump/model_torch" diff --git a/docs/CLIConfig.md b/docs/CLIConfig.md index e74a988..4e7e6ef 100644 --- a/docs/CLIConfig.md +++ b/docs/CLIConfig.md @@ -38,12 +38,13 @@ 定义结果对比的精度和逻辑。 -| 参数 | 类型 | 必需 | 默认值 | 说明 | -| -------------- | ------ | ---- | ------- | --------------------------------------- | -| `atol` | float | 否 | 1e-6 | 绝对误差容忍度 | -| `rtol` | float | 否 | 1e-6 | 相对误差容忍度 | -| `compare_mode` | string | 否 | "mean" | 对比模式 ("mean", "strict", "abs_mean") | -| `action_name` | string | 否 | "equal" | 对比动作 ("equal", "loose_equal") | +| 参数 | 类型 | 必需 | 默认值 | 说明 | +| -------------- | ------ | ---- | ------- | ---------------------------------------- | +| `atol` | float | 否 | 1e-6 | 绝对误差容忍度 | +| `rtol` | float | 否 | 1e-6 | 相对误差容忍度 | +| `compare_mode` | string | 否 | "mean" | 数值对比模式 ("mean", "strict", "abs_mean") | +| `action_name` | string | 否 | "equal" | 层对比策略 ("equal", "loose_equal") | +| `check_mode` | string | 否 | "fast" | 模型对比策略 ("fast", "deep") | ## 示例和详细说明 @@ -207,13 +208,16 @@ compare_dumps(pt_dump_path, pd_dump_path, cfg) - 控制模型输出结果的对比精度和模式 - atol: 绝对误差容忍度 (default: 1e-6) - rtol: 相对误差容忍度 (default: 1e-6) -- compare_mode: 对比模式,可选值: mean, strict, abs_mean, 默认值: "mean" +- compare_mode: 数值对比模式,可选值: mean, strict, abs_mean, 默认值: "mean" - mean: 比较传入数据的均值 - strict: 直接比较传入数据 - abs_mean: 比较传入数据的绝对值的均值 -- action_name: 对比行为,可选值: equal, loose_equal, 默认值: "equal" +- action_name: 层对比策略,可选值: equal, loose_equal, 默认值: "equal" - equal: 进行严格的对比,比如会对输出的个数、shape 都进行检查 - loose_equal: 较为宽松的对比,会尝试尽可能多的匹配数据,比如两个模型某层的输出数量分别为 1 和 3 ,会对第一个输出进行比较,同时在两个数据的形状不匹配时会尝试 transpose 或 reshape。仅在该层所有数据都无法匹配时报错。 +- check_mode: 模型对比策略,可选值: fast, deep, 默认值: "fast" + - fast: 进行自顶向下的对比,当上层模型精度检查成功时,不会再进入子层检查,一旦检查到错误就会报错退出 + - deep: 强制检查所有层,即使上层模型精度检查成功,也进入子层进行检查,所有检查结束后才会报错退出 ``` COMPARE: @@ -221,4 +225,5 @@ COMPARE: rtol: 1e-5 compare_mode: "mean" action_name: "loose_equal" + check_mode: "fast" ``` diff --git a/padiff/comparison/checker/reports.py b/padiff/comparison/checker/reports.py index 38cccbe..d594396 100644 --- a/padiff/comparison/checker/reports.py +++ b/padiff/comparison/checker/reports.py @@ -18,7 +18,7 @@ clone_dict_tree, get_all_valid_path, load_json, - print_report_info, + print_multi_report_info, reorder_and_match_sublayers, logger, struct_info_log, @@ -75,78 +75,101 @@ def _check_report_impl(report_path_0, report_path_1, cfg=None, diff_phase="both" def check_forward(nodes, reports, cfg): - logger.debug(f"Checking forward of {nodes[0]['name']}") - action_name = cfg.get("action_name", None) - act = get_action(reports[0], nodes[0], reports[1], nodes[1], name=action_name) - try: - act(nodes[0]["fwd_outputs"], nodes[1]["fwd_outputs"], cfg) - return True - except Exception as e: - compare_info = e - if len(nodes[0]["children"]) == 0 or len(nodes[1]["children"]) == 0: - print_report_info(nodes, reports, e, "Forward") - return False - - # reorder current level - try: - if not nodes[1]["reordered"]: - reorder_and_match_sublayers(nodes, reports) - except Exception as e: - msg = f"While checking forward, diff found at {nodes[0]['name']}(base) vs {nodes[1]['name']}(raw)\n" - msg += "Call `reorder_and_match_sublayers` for more detailed infos, but error occurs again:\n" - msg += f"{type(e).__name__}: {str(e)}" - logger.error(msg) - # print_report_info(nodes, reports, compare_info, "Forward", msg) - # return False - - for child_0, child_1 in zip(nodes[0]["children"], nodes[1]["children"]): - res = check_forward((child_0, child_1), reports, cfg) - if res == False: - return False - - # sublayers is compared ok, but diff found at father layer - msg = ( - f"\n ⚠️ Sublayers of {nodes[0]['name']} and {nodes[1]['name']} are corresponded, but diff found at their output! " - "\n 💡 This might be reasonable since errors accumulate if single_step mode is enabled." + failures = _check_node( + nodes=nodes, + reports=reports, + cfg=cfg, + data_key="fwd_outputs", + direction="forward", ) - print_report_info(nodes, reports, compare_info, "Forward", msg) - return False + success = print_multi_report_info(failures, stage="Forward") + return success def check_backward(nodes, reports, cfg): - logger.debug(f"Checking backward of {nodes[0]['name']}") + failures = _check_node(nodes=nodes, reports=reports, cfg=cfg, data_key="bwd_grads", direction="backward") + success = print_multi_report_info(failures, stage="Backward") + return success + + +def _check_node(nodes, reports, cfg, data_key, direction): + """ + Generic function to check forward or backward outputs. + + Args: + nodes: (node_base, node_raw) + reports: (report_base, report_raw) + cfg: config dict + data_key: "fwd_outputs" or "bwd_grads" + direction: "forward" or "backward" + """ + logger.debug(f"Checking {direction.lower()} of {nodes[0]['name']}") + + check_mode = cfg.get("check_mode", "fast").lower() action_name = cfg.get("action_name", None) act = get_action(reports[0], nodes[0], reports[1], nodes[1], name=action_name) + + parent_error = None + failures = [] + + # Step 1: Compare current layer try: - act(nodes[0]["bwd_grads"], nodes[1]["bwd_grads"], cfg) - return True + act(nodes[0][data_key], nodes[1][data_key], cfg) except Exception as e: - compare_info = e + parent_error = e if len(nodes[0]["children"]) == 0 or len(nodes[1]["children"]) == 0: - print_report_info(nodes, reports, e, "Backward") - return False + failures.append({"nodes": nodes, "reports": reports, "exc": e, "msg": None, "direction": direction}) + return failures + logger.debug(f"{direction} mismatch at {nodes[0]['name']} -> will check children (check_mode={check_mode})") - # reorder current level + # Step 2: Early return if should not check sublayer + if check_mode == "fast" and parent_error is None: + logger.debug(f"{direction} PASSED -> skip children (check_mode=fast)") + return [] + + # Step 3: Try to reorder current level try: if not nodes[1]["reordered"]: reorder_and_match_sublayers(nodes, reports) except Exception as e: - msg = f"While checking backward, diff found at {nodes[0]['name']}(base) vs {nodes[1]['name']}(raw)\n" - msg += "Call `reorder_and_match_sublayers` for more detailed infos, but error occurs again:\n" - msg += f"{type(e).__name__}: {str(e)}" - logger.error(msg) - # print_report_info(nodes, reports, compare_info, "Backward", msg) - # return False - - for child_0, child_1 in zip(reversed(nodes[0]["children"]), reversed(nodes[1]["children"])): - res = check_backward((child_0, child_1), reports, cfg) - if res == False: - return False + # some mismatches are allowed when action_name='loose_equal', so no error is returned here, only print it + logger.error( + f"While checking {direction.lower()}, diff found at {nodes[0]['name']}(base) vs {nodes[1]['name']}(raw)\n" + "Call `reorder_and_match_sublayers` for more detailed infos, but error occurs again:\n" + f"{type(e).__name__}: {str(e)}" + ) + + # Step 4: Recursively check all sublayers + # Note: Backward often checks in reverse order + children_zip = ( + zip(reversed(nodes[0]["children"]), reversed(nodes[1]["children"])) + if direction.lower() == "backward" + else zip(nodes[0]["children"], nodes[1]["children"]) + ) - # sublayers is compared ok, but diff found at father layer - msg = f"Grad of sublayer {nodes[0]['name']} and {nodes[1]['name']} are corresponded, but current grad found diff!" - print_report_info(nodes, reports, compare_info, "Backward", msg) - return False + for child_0, child_1 in children_zip: + child_failures = _check_node( + (child_0, child_1), + reports, + cfg, + data_key=data_key, + direction=direction, + ) + if child_failures: + failures.extend(child_failures) + if check_mode == "fast": + return failures + + # Step 5: All children passed, but parent failed + if parent_error is not None and not failures: + msg = ( + f"\n ⚠️ Grad of sublayer '{nodes[0]['name']}' and '{nodes[1]['name']}' are corresponded, " + f"but current {direction.lower()} output found diff!" + "\n 💡 This might be reasonable since errors accumulate if single_step mode is enabled." + ) + failures.append({"nodes": nodes, "reports": reports, "exc": parent_error, "msg": msg, "direction": direction}) + + return failures def check_layer_map(reports): diff --git a/padiff/configs/__init__.py b/padiff/configs/__init__.py index 2073fb5..a1f4603 100644 --- a/padiff/configs/__init__.py +++ b/padiff/configs/__init__.py @@ -21,6 +21,7 @@ "rtol": 1e-7, "compare_mode": "mean", "action_name": "equal", + "check_mode": "fast", } diff --git a/padiff/utils/log.py b/padiff/utils/log.py index d47226c..e469559 100644 --- a/padiff/utils/log.py +++ b/padiff/utils/log.py @@ -137,7 +137,12 @@ def print_report_info(nodes, reports, exc, stage, msg=None): logger.info(retstr) -def tree_print(node, mark=None, prefix=[]): +def tree_print(node, marks=None, prefix=[]): + if marks is None: + marks = [] + elif not isinstance(marks, (list, tuple)): + marks = [marks] + cur_str = "" for i, s in enumerate(prefix): if i == len(prefix) - 1: @@ -153,9 +158,9 @@ def tree_print(node, mark=None, prefix=[]): cur_str += node["name"] if "available" in node and node["available"] == False: cur_str += " (skip)" - if os.getenv("PADIFF_PATH_LOG") == "ON": + if os.getenv("PADIFF_PATH_LOG") == "ON" or os.getenv("PADIFF_LOG_LEVEL") == "DEBUG": cur_str += " (" + node["route"] + ")" - if mark is node: + if node in marks: cur_str += " <--- *** HERE ***" ret_strs = [cur_str] @@ -164,7 +169,7 @@ def tree_print(node, mark=None, prefix=[]): if i == len(node["children"]) - 1: pre = " +--- " prefix.append(pre) - retval = tree_print(child, mark, prefix) + retval = tree_print(child, marks, prefix) ret_strs.extend(retval) prefix.pop() @@ -179,23 +184,34 @@ def build_file_name(report, file_name): return file_name + ".log" -def struct_info(report, node, file_prefix): +def struct_info(report, mark, file_prefix): file_name = build_file_name(report, file_prefix + "_" + report["model_name"]) title = f"{report['model_name']}(without layers in blacklist)\n" + "=" * 40 + "\n" + + if not isinstance(mark, (list, tuple)): + marks = [mark] + else: + marks = mark + retval = [] for tree in report["tree"]: - retval.extend(tree_print(tree, mark=node, prefix=[" " * 4])) + retval.extend(tree_print(tree, marks=marks, prefix=[" " * 4])) + info = title + "\n".join(retval) logger.log_file(file_name, "w", info) return file_name def struct_info_log(reports, nodes, file_prefix): + if isinstance(nodes, (list, tuple)): + if len(nodes) == 0: + return "" + else: + nodes = [nodes] + file_names = [] - for idx in range(2): - node = nodes[idx] - report = reports[idx] - file_name = struct_info(report, node, file_prefix) + for idx, report in enumerate(reports): + file_name = struct_info(report, nodes[idx], file_prefix) file_names.append(file_name) retval = ( f"Model struct files saved in: '{logger.log_path}/{file_names[0]}' vs '{logger.log_path}/{file_names[1]}'\n" @@ -206,3 +222,32 @@ def struct_info_log(reports, nodes, file_prefix): def save_model_struct(report, file_prefix="arch"): file_name = struct_info(report, None, file_prefix) logger.info(f"Model struct saved in: '{logger.log_path}/{file_name}' without layers in blacklist\n") + + +def print_multi_report_info(failure_list, stage="Forward"): + if not failure_list: + return True + + logger.error(f"FAILED !!! '{stage}' Stage Mismatch! Found {len(failure_list)} mismatched layers.") + + marked_nodes = [[], []] + for failure in failure_list: + exc = failure["exc"] + nodes = failure["nodes"] + logger.error( + f"\n Layer: {nodes[0]['name']} vs {nodes[1]['name']}" + f"\n Route: {nodes[0]['route']} vs {nodes[1]['route']}" + f"\nError({type(exc).__name__}): {str(exc)} \n" + ) + if failure["msg"] is not None: + logger.warning("ADDITIONAL MESSAGE:") + logger.warning(failure["msg"] + " \n") + + marked_nodes[0].append(nodes[0]["origin_node"]) + marked_nodes[1].append(nodes[1]["origin_node"]) + + reports = failure_list[0]["reports"] + retstr = struct_info_log(reports, marked_nodes, "report") + logger.info(retstr) + + return False