Skip to content

Commit d8ad294

Browse files
committed
feat:support user-defined alignment order when using cli & optimizing the single-step strategy
1 parent afde4e5 commit d8ad294

7 files changed

Lines changed: 112 additions & 43 deletions

File tree

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ PaDiff 是基于 PaddlePaddle 与 PyTorch 的模型精度对齐工具。传入 P
1313

1414
- 手动修改 Paddle 和 PyTorch 脚本,并在**分别**运行(不需要同时)这两个脚本,根据脚本 PaDiff 会分别监控运行过程并 dump 数据,最后手动调用 API 接口对比数据,得到对齐结果。两次运行可以在**两个环境**中,但每个环境中都需要同时存在 `paddlepaddle``torch`
1515

16-
- 参考[旧版本特性(v0.2版本)](#旧版本特性(v0.2版本)) 使用 auto_diff 接口进行对齐,这种方法需要将 Paddle 和 PyTorch 放在同一个文件中,同时将 `paddlepaddle``torch` 包安装在同一环境中,不仅两个模型的运行过程耦合,代码修改量也比较大,因此**不推荐**此用法
16+
- 参考[旧版本特性(v0.2版本)](#readme-v02) 使用 auto_diff 接口进行对齐,这种方法需要将 Paddle 和 PyTorch 放在同一个文件中,同时将 `paddlepaddle``torch` 包安装在同一环境中,不仅两个模型的运行过程耦合,代码修改量也比较大,因此**不推荐**此用法
1717

1818

1919
## 安装
@@ -79,6 +79,7 @@ if __name__ == "__main__":
7979

8080
为了保持控制台信息简洁,可以设置环境变量 `PADIFF_SILENT=1`,此模式下仅保存 log 文件,不在控制台输出 log 信息
8181

82+
<a id="readme-v02"></a>
8283
## 旧版本特性(v0.2版本)
8384

8485
### 使用 auto_diff 接口和其它方法

docs/CLIConfig.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
| `pd_model_name` | string || "model" | PaddlePaddle 模型实例的变量名 |
1818
| `pt_optim_name` | string || null | PyTorch 优化器实例的变量名 |
1919
| `pd_optim_name` | string || null | PaddlePaddle 优化器实例的变量名 |
20+
| `base_framework`| string || "torch" | 设置作为 base 的框架 |
2021
| `log_dir` | string || "./padiff_log" | 日志和报告的输出目录 |
2122

2223
## PaDiffGuard 部分
@@ -116,6 +117,16 @@ trainer.train()
116117
# 由于 trainer.train() 中通常已经包含了完整的前反向过程,因此不需要传递此参数
117118
```
118119

120+
#### base 框架 (base_framework)
121+
122+
- 设置对齐方向,默认值 'troch',即 paddle 代码向 torch 代码对齐
123+
- 默认值: "torch",此时在逐层对齐时,会在运行 torch 代码时保存各层输出,并在运行 paddle 代码时读取并替换
124+
- 当 base 模型的输出个数少于 raw 模型时,会找不到输出而报错,此时可以修改此参数,交换二者的顺序
125+
126+
```
127+
base_framework: "torch"
128+
```
129+
119130
#### 日志目录 (log_dir)
120131

121132
- 指定生成报告和日志的目录

docs/config_example.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ CLI:
88
pd_model_name: "pd_model" # PaddlePaddle 模型变量名
99
pt_optim_name: "pt_optimizer" # (可选) PyTorch 优化器变量名
1010
pd_optim_name: "pd_optimizer" # (可选) PaddlePaddle 优化器变量名
11+
base_framework: "torch" # (可选) 设置作为 base 的框架
1112
log_dir: "./padiff_log" # (可选) 日志目录
1213

1314
# --- COMPARE 部分 ---

padiff/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515

16-
__version__ = "0.3.0"
16+
__version__ = "0.4.0"
1717

1818

1919
# for api -> Layer

padiff/abstracts/hooks/hook.py

Lines changed: 63 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
from .base import current_report, find_base_report_node, single_step_state
3333

3434

35+
_seen_warnings = set()
36+
37+
3538
@contextlib.contextmanager
3639
def register_hooker(model):
3740
marker = model.marker
@@ -266,27 +269,63 @@ def replace_forward_output(node, current_name=None):
266269

267270
def inner(input_):
268271
nonlocal cur_idx
269-
if isinstance(input_, (paddle.Tensor, torch.Tensor)):
270-
if cur_idx >= len(numpy_file_list):
271-
raise RuntimeError(
272-
f"\n ⚠️ Single-step alignment FAILED: the {cur_idx + 1}st output is requested, "
273-
f"but only {len(numpy_file_list)} pre-saved numpy files are available."
272+
if not isinstance(input_, (paddle.Tensor, torch.Tensor)):
273+
return input_
274+
275+
if cur_idx >= len(numpy_file_list):
276+
warning_key = ("single-step: output_count_mismatch", current_name)
277+
if warning_key not in _seen_warnings:
278+
logger.warning(
279+
f"\n ⚠️ Single-step alignment SKIPPED: the {cur_idx + 1}st output is requested, "
280+
f"but only {len(numpy_file_list)} pre-saved from base model, skip the current output."
281+
"\n ⚠️ This warning will not repeat for this layer."
274282
f"\n 📌 Layer Name: {current_name}(raw)"
275283
"\n 💡 Possible Causes and Solutions:"
276-
"\n - The number of outputs from the current layer in the raw model does not match "
284+
"\n - The number of outputs from the current layer in the raw model is bigger than"
277285
"that of its corresponding layer in the base model."
278286
"\n - Verify that both models have identical architectures for this layer."
279287
"\n - If the corresponding relationship of the current layer is correct, "
280288
"please disable single step mode, or add the layer to blacklist to skip the check of this layer."
289+
"\n - Or when you are sure that the extra output does not need to be compared, "
290+
"you can swap the execution order of the base model and the raw model."
281291
)
282-
value = np.load(numpy_file_list[cur_idx]["path"])
283-
cur_idx += 1
284-
if isinstance(input_, paddle.Tensor):
285-
return paddle.to_tensor(value, dtype=input_.dtype)
286-
else:
287-
return torch.as_tensor(value, dtype=input_.dtype, device=input_.device)
288-
else:
292+
_seen_warnings.add(warning_key)
293+
return input_
294+
295+
value = np.load(numpy_file_list[cur_idx]["path"])
296+
cur_idx += 1
297+
base_shape = tuple(value.shape)
298+
raw_shape = tuple(input_.shape)
299+
300+
if base_shape == raw_shape:
301+
pass
302+
303+
elif np.prod(base_shape) != np.prod(raw_shape):
304+
warning_key = ("single-step: shape_mismatch", current_name)
305+
if warning_key not in _seen_warnings:
306+
logger.warning(
307+
f"\n ⚠️ Single-step alignment SKIPPED: shape mismatch."
308+
"\n ⚠️ This warning will not repeat for this layer."
309+
f"\n 📌 Layer Name: {current_name}(raw)"
310+
f"\n 📌 Shape: {base_shape}(base) vs {raw_shape}(raw)"
311+
)
312+
_seen_warnings.add(warning_key)
289313
return input_
314+
else:
315+
value = value.reshape(input_.shape)
316+
debug_key = ("single-step: reshape_used", current_name)
317+
if debug_key not in _seen_warnings:
318+
logger.debug(
319+
f"\n ⚠️ Try to reshape loaded value to input's shape of layer {current_name}(raw). "
320+
"This may lead to numerical errors even if reshape succeeds."
321+
"\n ⚠️ This warning will not repeat for this layer."
322+
)
323+
_seen_warnings.add(debug_key)
324+
325+
if isinstance(input_, paddle.Tensor):
326+
return paddle.to_tensor(value, dtype=input_.dtype)
327+
else:
328+
return torch.as_tensor(value, dtype=input_.dtype, device=input_.device)
290329

291330
return inner
292331

@@ -296,14 +335,17 @@ def single_step_check(report, net_id, step_idx, current_name, node_type, bwd_ite
296335
try:
297336
base_report_node = find_base_report_node(net_id, step_idx)
298337
if base_report_node["name"] != current_name:
299-
warning_msg = (
300-
f"\n ⚠️ Single-step alignment WARNING: {node_type} with net_id={net_id} mismatch!\n"
301-
f" 📌 Mismatch {node_type.capitalize()}: {base_report_node['name']}(base) vs {current_name}(raw)\n"
302-
" 💡 Suggestion: Models have different architectures or initialization order. "
303-
"Please check the model implementation or decrease 'align_depth' to reduce the alignment "
304-
"granularity, or add layers that do not require alignment to the blacklist."
305-
)
306-
logger.warning(warning_msg)
338+
warning_key = ("single-step: name_mismatch", current_name)
339+
if warning_key not in _seen_warnings:
340+
logger.warning(
341+
f"\n ⚠️ Single-step alignment WARNING: {node_type} with net_id={net_id} mismatch!"
342+
"\n ⚠️ This warning will not repeat for this layer."
343+
f"\n 📌 Mismatch {node_type.capitalize()}: {base_report_node['name']}(base) vs {current_name}(raw)"
344+
"\n 💡 Suggestion: Models have different architectures or class name or initialization order. "
345+
"Please check the model implementation or decrease 'align_depth' to reduce the alignment "
346+
"granularity, or add layers that do not require alignment to the blacklist."
347+
)
348+
_seen_warnings.add(warning_key)
307349
else:
308350
logger.debug(f"Single Step: {current_name}(net_id={net_id})")
309351

padiff/cli.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,15 @@ def main():
143143
--pt_cmd "python /path/to/your/torch_script.py"
144144
--pd_cmd "python /path/to/your/paddle_script.py"
145145
146-
3. 日志目录参数 (--log_dir):
146+
3. base 框架设置参数 (--base_framework):
147+
* 非必需 被包含在 config 文件中,或 通过命令行传入
148+
* 设置对齐方向,默认值 'troch',即 paddle 代码向 torch 代码对齐
149+
* 影响逐层对齐时,是读取哪个框架下保存的数据
150+
151+
示例:
152+
--base_framework "torch"
153+
154+
4. 日志目录参数 (--log_dir):
147155
* 可选参数
148156
* 指定生成报告和日志的目录
149157
* 默认值: ./padiff_log
@@ -159,6 +167,7 @@ def main():
159167
pd_model_name: "pd_model"
160168
pt_optim_name: "pt_optimizer" # not required
161169
pd_optim_name: "pd_optimizer" # not required
170+
base_framework: "torch" # not required
162171
log_dir: "./padiff_log" # not required
163172
164173
PaDiffGuard:
@@ -192,6 +201,12 @@ def main():
192201
type=str,
193202
help="Override 'pd_cmd' (paddle command) in config, e.g., 'python /paddle_dir/paddle_model.py'",
194203
)
204+
parser.add_argument(
205+
"--base_framework",
206+
type=str,
207+
choices=["torch", "paddle"],
208+
help="Which framework's output should be used as the base framework (base). The other will align to it. Options: 'torch' or 'paddle'.",
209+
)
195210
parser.add_argument(
196211
"--log_dir",
197212
type=str,
@@ -213,6 +228,8 @@ def main():
213228
cli_cfg["pd_cmd"] = args.pd_cmd
214229
if args.log_dir:
215230
cli_cfg["log_dir"] = args.log_dir
231+
if args.base_framework:
232+
cli_cfg["base_framework"] = args.base_framework
216233

217234
log_dir = cli_cfg.pop("log_dir", "./padiff_log")
218235
logger.reset_dir(log_dir)
@@ -229,26 +246,23 @@ def main():
229246
pt_optim_name = cli_cfg.get("pt_optim_name")
230247
pd_optim_name = cli_cfg.get("pd_optim_name")
231248

249+
tasks = {
250+
"torch": {"cmd": pt_cmd, "framework": "torch", "model_name": pt_model_name, "optim_name": pt_optim_name},
251+
"paddle": {"cmd": pd_cmd, "framework": "paddle", "model_name": pd_model_name, "optim_name": pd_optim_name},
252+
}
253+
base_fw = cli_cfg.pop("base_framework", "torch")
254+
align_fw = "paddle" if base_fw == "torch" else "torch"
255+
232256
logger.info("Code injection and script execution...")
233257
try:
234-
pt_dump_path = run_with_padiff(
235-
cmd=pt_cmd,
236-
framework="torch",
237-
model_name=pt_model_name,
238-
optim_name=pt_optim_name,
239-
mode="base",
240-
alignment_dir=None,
241-
guard_cfg=guard_cfg,
242-
)
243-
pd_dump_path = run_with_padiff(
244-
cmd=pd_cmd,
245-
framework="paddle",
246-
model_name=pd_model_name,
247-
optim_name=pd_optim_name,
248-
mode="align",
249-
alignment_dir=pt_dump_path,
250-
guard_cfg=guard_cfg,
251-
)
258+
logger.info(f"Running {base_fw.upper()} as BASE")
259+
base_path = run_with_padiff(mode="base", alignment_dir=None, guard_cfg=guard_cfg, **tasks[base_fw])
260+
261+
logger.info(f"Running {align_fw.upper()} in ALIGN mode")
262+
align_path = run_with_padiff(mode="align", alignment_dir=base_path, guard_cfg=guard_cfg, **tasks[align_fw])
263+
264+
pt_dump_path = base_path if base_fw == "torch" else align_path
265+
pd_dump_path = base_path if base_fw == "paddle" else align_path
252266
except Exception as e:
253267
logger.error(f"An error occurred during execution: {type(e).__name__}: {str(e)}")
254268
import traceback

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from setuptools import find_packages, setup
1717

18-
VERSION = "0.3.0"
18+
VERSION = "0.4.0"
1919

2020

2121
def read_requirements_file(filepath):

0 commit comments

Comments
 (0)