From d8ad294d81401ceec272c4bf9338cdefdf7baf38 Mon Sep 17 00:00:00 2001 From: lijialin03 Date: Fri, 26 Sep 2025 07:03:52 +0000 Subject: [PATCH] feat:support user-defined alignment order when using cli & optimizing the single-step strategy --- README.md | 3 +- docs/CLIConfig.md | 11 +++++ docs/config_example.yaml | 1 + padiff/__init__.py | 2 +- padiff/abstracts/hooks/hook.py | 84 +++++++++++++++++++++++++--------- padiff/cli.py | 52 +++++++++++++-------- setup.py | 2 +- 7 files changed, 112 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index 637a0eb..7673720 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ PaDiff 是基于 PaddlePaddle 与 PyTorch 的模型精度对齐工具。传入 P

- 手动修改 Paddle 和 PyTorch 脚本,并在**分别**运行(不需要同时)这两个脚本,根据脚本 PaDiff 会分别监控运行过程并 dump 数据,最后手动调用 API 接口对比数据,得到对齐结果。两次运行可以在**两个环境**中,但每个环境中都需要同时存在 `paddlepaddle` 和 `torch` 包

-- 参考[旧版本特性(v0.2版本)](#旧版本特性(v0.2版本)) 使用 auto_diff 接口进行对齐,这种方法需要将 Paddle 和 PyTorch 放在同一个文件中,同时将 `paddlepaddle` 和 `torch` 包安装在同一环境中,不仅两个模型的运行过程耦合,代码修改量也比较大,因此**不推荐**此用法
+- 参考[旧版本特性(v0.2版本)](#readme-v02) 使用 auto_diff 接口进行对齐,这种方法需要将 Paddle 和 PyTorch 放在同一个文件中,同时将 `paddlepaddle` 和 `torch` 包安装在同一环境中,不仅两个模型的运行过程耦合,代码修改量也比较大,因此**不推荐**此用法

## 安装

@@ -79,6 +79,7 @@ if __name__ == "__main__":

为了保持控制台信息简洁,可以设置环境变量 `PADIFF_SILENT=1`,此模式下仅保存 log 文件,不在控制台输出 log 信息

+
## 旧版本特性(v0.2版本)

### 使用 auto_diff 接口和其它方法
diff --git a/docs/CLIConfig.md b/docs/CLIConfig.md index b96d4f0..e74a988 100644 --- a/docs/CLIConfig.md +++ b/docs/CLIConfig.md @@ -17,6 +17,7 @@ | `pd_model_name` | string | 否 | "model" | PaddlePaddle 模型实例的变量名 | | `pt_optim_name` | string | 否 | null | PyTorch 优化器实例的变量名 | | `pd_optim_name` | string | 否 | null | PaddlePaddle 优化器实例的变量名 | +| `base_framework`| string | 否 | "torch" | 设置作为 base 的框架 | | `log_dir` | string | 否 | "./padiff_log" | 日志和报告的输出目录 | ## PaDiffGuard 部分 @@ -116,6 +117,16 @@ trainer.train() # 由于 trainer.train() 中通常已经包含了完整的前反向过程,因此不需要传递此参数 ``` +#### base 框架 (base_framework) + +- 设置对齐方向,默认值 'torch',即 paddle 代码向 torch 代码对齐 +- 默认值: 
"torch",此时在逐层对齐时,会在运行 torch 代码时保存各层输出,并在运行 paddle 代码时读取并替换 +- 当 base 模型的输出个数少于 raw 模型时,会找不到输出而报错,此时可以修改此参数,交换二者的顺序 + +``` +base_framework: "torch" +``` + #### 日志目录 (log_dir) - 指定生成报告和日志的目录 diff --git a/docs/config_example.yaml b/docs/config_example.yaml index 5d35b7d..c5ef01a 100644 --- a/docs/config_example.yaml +++ b/docs/config_example.yaml @@ -8,6 +8,7 @@ CLI: pd_model_name: "pd_model" # PaddlePaddle 模型变量名 pt_optim_name: "pt_optimizer" # (可选) PyTorch 优化器变量名 pd_optim_name: "pd_optimizer" # (可选) PaddlePaddle 优化器变量名 + base_framework: "torch" # (可选) 设置作为 base 的框架 log_dir: "./padiff_log" # (可选) 日志目录 # --- COMPARE 部分 --- diff --git a/padiff/__init__.py b/padiff/__init__.py index e52c3bc..626c5f6 100644 --- a/padiff/__init__.py +++ b/padiff/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. -__version__ = "0.3.0" +__version__ = "0.4.0" # for api -> Layer diff --git a/padiff/abstracts/hooks/hook.py b/padiff/abstracts/hooks/hook.py index 6f90fa7..5b5167c 100644 --- a/padiff/abstracts/hooks/hook.py +++ b/padiff/abstracts/hooks/hook.py @@ -32,6 +32,9 @@ from .base import current_report, find_base_report_node, single_step_state +_seen_warnings = set() + + @contextlib.contextmanager def register_hooker(model): marker = model.marker @@ -266,27 +269,63 @@ def replace_forward_output(node, current_name=None): def inner(input_): nonlocal cur_idx - if isinstance(input_, (paddle.Tensor, torch.Tensor)): - if cur_idx >= len(numpy_file_list): - raise RuntimeError( - f"\n ⚠️ Single-step alignment FAILED: the {cur_idx + 1}st output is requested, " - f"but only {len(numpy_file_list)} pre-saved numpy files are available." 
+            if not isinstance(input_, (paddle.Tensor, torch.Tensor)):
+                return input_
+
+            if cur_idx >= len(numpy_file_list):
+                warning_key = ("single-step: output_count_mismatch", current_name)
+                if warning_key not in _seen_warnings:
+                    logger.warning(
+                        f"\n ⚠️ Single-step alignment SKIPPED: the {cur_idx + 1}st output is requested, "
+                        f"but only {len(numpy_file_list)} pre-saved from base model, skip the current output."
+                        "\n ⚠️ This warning will not repeat for this layer."
                         f"\n 📌 Layer Name: {current_name}(raw)"
                         "\n 💡 Possible Causes and Solutions:"
-                        "\n    - The number of outputs from the current layer in the raw model does not match "
+                        "\n    - The number of outputs from the current layer in the raw model is bigger than "
                         "that of its corresponding layer in the base model."
                         "\n    - Verify that both models have identical architectures for this layer."
                         "\n    - If the corresponding relationship of the current layer is correct, "
                         "please disable single step mode, or add the layer to blacklist to skip the check of this layer."
+                        "\n    - Or when you are sure that the extra output does not need to be compared, "
+                        "you can swap the execution order of the base model and the raw model."
                     )
-            value = np.load(numpy_file_list[cur_idx]["path"])
-            cur_idx += 1
-            if isinstance(input_, paddle.Tensor):
-                return paddle.to_tensor(value, dtype=input_.dtype)
-            else:
-                return torch.as_tensor(value, dtype=input_.dtype, device=input_.device)
-        else:
+                    _seen_warnings.add(warning_key)
+                return input_
+
+            value = np.load(numpy_file_list[cur_idx]["path"])
+            cur_idx += 1
+            base_shape = tuple(value.shape)
+            raw_shape = tuple(input_.shape)
+
+            if base_shape == raw_shape:
+                pass
+
+            elif np.prod(base_shape) != np.prod(raw_shape):
+                warning_key = ("single-step: shape_mismatch", current_name)
+                if warning_key not in _seen_warnings:
+                    logger.warning(
+                        f"\n ⚠️ Single-step alignment SKIPPED: shape mismatch."
+                        "\n ⚠️ This warning will not repeat for this layer." 
+ f"\n 📌 Layer Name: {current_name}(raw)" + f"\n 📌 Shape: {base_shape}(base) vs {raw_shape}(raw)" + ) + _seen_warnings.add(warning_key) return input_ + else: + value = value.reshape(input_.shape) + debug_key = ("single-step: reshape_used", current_name) + if debug_key not in _seen_warnings: + logger.debug( + f"\n ⚠️ Try to reshape loaded value to input's shape of layer {current_name}(raw). " + "This may lead to numerical errors even if reshape succeeds." + "\n ⚠️ This warning will not repeat for this layer." + ) + _seen_warnings.add(debug_key) + + if isinstance(input_, paddle.Tensor): + return paddle.to_tensor(value, dtype=input_.dtype) + else: + return torch.as_tensor(value, dtype=input_.dtype, device=input_.device) return inner @@ -296,14 +335,17 @@ def single_step_check(report, net_id, step_idx, current_name, node_type, bwd_ite try: base_report_node = find_base_report_node(net_id, step_idx) if base_report_node["name"] != current_name: - warning_msg = ( - f"\n ⚠️ Single-step alignment WARNING: {node_type} with net_id={net_id} mismatch!\n" - f" 📌 Mismatch {node_type.capitalize()}: {base_report_node['name']}(base) vs {current_name}(raw)\n" - " 💡 Suggestion: Models have different architectures or initialization order. " - "Please check the model implementation or decrease 'align_depth' to reduce the alignment " - "granularity, or add layers that do not require alignment to the blacklist." - ) - logger.warning(warning_msg) + warning_key = ("single-step: name_mismatch", current_name) + if warning_key not in _seen_warnings: + logger.warning( + f"\n ⚠️ Single-step alignment WARNING: {node_type} with net_id={net_id} mismatch!" + "\n ⚠️ This warning will not repeat for this layer." + f"\n 📌 Mismatch {node_type.capitalize()}: {base_report_node['name']}(base) vs {current_name}(raw)" + "\n 💡 Suggestion: Models have different architectures or class name or initialization order. 
" + "Please check the model implementation or decrease 'align_depth' to reduce the alignment " + "granularity, or add layers that do not require alignment to the blacklist." + ) + _seen_warnings.add(warning_key) else: logger.debug(f"Single Step: {current_name}(net_id={net_id})") diff --git a/padiff/cli.py b/padiff/cli.py index e71b74b..70d3123 100644 --- a/padiff/cli.py +++ b/padiff/cli.py @@ -143,7 +143,15 @@ def main(): --pt_cmd "python /path/to/your/torch_script.py" --pd_cmd "python /path/to/your/paddle_script.py" - 3. 日志目录参数 (--log_dir): + 3. base 框架设置参数 (--base_framework): + * 非必需 被包含在 config 文件中,或 通过命令行传入 + * 设置对齐方向,默认值 'troch',即 paddle 代码向 torch 代码对齐 + * 影响逐层对齐时,是读取哪个框架下保存的数据 + + 示例: + --base_framework "torch" + + 4. 日志目录参数 (--log_dir): * 可选参数 * 指定生成报告和日志的目录 * 默认值: ./padiff_log @@ -159,6 +167,7 @@ def main(): pd_model_name: "pd_model" pt_optim_name: "pt_optimizer" # not required pd_optim_name: "pd_optimizer" # not required + base_framework: "torch" # not required log_dir: "./padiff_log" # not required PaDiffGuard: @@ -192,6 +201,12 @@ def main(): type=str, help="Override 'pd_cmd' (paddle command) in config, e.g., 'python /paddle_dir/paddle_model.py'", ) + parser.add_argument( + "--base_framework", + type=str, + choices=["torch", "paddle"], + help="Which framework's output should be used as the base framework (base). The other will align to it. 
Options: 'torch' or 'paddle'.", + ) parser.add_argument( "--log_dir", type=str, @@ -213,6 +228,8 @@ def main(): cli_cfg["pd_cmd"] = args.pd_cmd if args.log_dir: cli_cfg["log_dir"] = args.log_dir + if args.base_framework: + cli_cfg["base_framework"] = args.base_framework log_dir = cli_cfg.pop("log_dir", "./padiff_log") logger.reset_dir(log_dir) @@ -229,26 +246,23 @@ def main(): pt_optim_name = cli_cfg.get("pt_optim_name") pd_optim_name = cli_cfg.get("pd_optim_name") + tasks = { + "torch": {"cmd": pt_cmd, "framework": "torch", "model_name": pt_model_name, "optim_name": pt_optim_name}, + "paddle": {"cmd": pd_cmd, "framework": "paddle", "model_name": pd_model_name, "optim_name": pd_optim_name}, + } + base_fw = cli_cfg.pop("base_framework", "torch") + align_fw = "paddle" if base_fw == "torch" else "torch" + logger.info("Code injection and script execution...") try: - pt_dump_path = run_with_padiff( - cmd=pt_cmd, - framework="torch", - model_name=pt_model_name, - optim_name=pt_optim_name, - mode="base", - alignment_dir=None, - guard_cfg=guard_cfg, - ) - pd_dump_path = run_with_padiff( - cmd=pd_cmd, - framework="paddle", - model_name=pd_model_name, - optim_name=pd_optim_name, - mode="align", - alignment_dir=pt_dump_path, - guard_cfg=guard_cfg, - ) + logger.info(f"Running {base_fw.upper()} as BASE") + base_path = run_with_padiff(mode="base", alignment_dir=None, guard_cfg=guard_cfg, **tasks[base_fw]) + + logger.info(f"Running {align_fw.upper()} in ALIGN mode") + align_path = run_with_padiff(mode="align", alignment_dir=base_path, guard_cfg=guard_cfg, **tasks[align_fw]) + + pt_dump_path = base_path if base_fw == "torch" else align_path + pd_dump_path = base_path if base_fw == "paddle" else align_path except Exception as e: logger.error(f"An error occurred during execution: {type(e).__name__}: {str(e)}") import traceback diff --git a/setup.py b/setup.py index 2e8cc29..24fdfd9 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ from setuptools import find_packages, setup 
-VERSION = "0.3.0" +VERSION = "0.4.0" def read_requirements_file(filepath):