-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathconfig_example.yaml
More file actions
34 lines (31 loc) · 2.05 KB
/
config_example.yaml
File metadata and controls
34 lines (31 loc) · 2.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# --- CLI 部分 ---
# 定义脚本执行相关参数
CLI:
pt_cmd: "python torch_project/run.py" # PyTorch 脚本运行命令
pd_cmd: "python paddle_project/run.py" # PaddlePaddle 脚本运行命令
pt_model_name: "pt_model" # PyTorch 模型变量名
pd_model_name: "pd_model" # PaddlePaddle 模型变量名
pt_optim_name: "pt_optimizer" # (可选) PyTorch 优化器变量名
pd_optim_name: "pd_optimizer" # (可选) PaddlePaddle 优化器变量名
base_framework: "torch" # (可选) 设置作为基准的框架名
log_dir: "./padiff_log" # (可选) 日志目录
# --- COMPARE 部分 ---
# 定义结果对比逻辑
COMPARE:
atol: 1.0e-06 # (可选) 绝对误差
rtol: 1.0e-06 # (可选) 相对误差
compare_mode: "mean" # (可选) 数值对比模式
action_name: "equal" # (可选) 层对比策略
check_mode: "fast" # (可选) 模型对比策略
# --- PaDiffGuard 部分 ---
# 定义模型对齐行为
PaDiffGuard:
align_depth: 1 # (可选) 对齐深度
single_step_mode: "forward" # (可选) 单步模式
load_init_weights: false # (可选) 加载初始权重
load_first_inputs: false # (可选) 加载首次输入
max_calls: 1 # (可选) 最大调用次数
black_list: ["TorcheFakeLayer", "PaddleFakeLayer"] # (可选) 不参与对齐的层的黑名单
keys_mapping: # (可选) 参数名映射 key(paddle模型参数名): value(torch模型参数名)
"parm_name_of_paddle_model": "parm_name_of_torch_model"
"model_pd.layers.0.input_layernorm.weight": "model_pt.layers.0.input_layernorm.weight"