
[Hackathon 10th Spring No.13] GDI-NN model reproduction RFC #1254

Open
megemini wants to merge 4 commits into PaddlePaddle:master from megemini:gdinn


paddle-bot bot commented Mar 22, 2026

Your PR has been submitted. Thanks for your contribution to this open-source project!
Please check that its format and content are complete. For details, refer to the Template and Demo.


**Paddle implementation**:

A PaddlePaddle-based implementation that uses PGL for graph operations.

Please confirm whether this is feasible.

Contributor Author


A PR has already been submitted: PaddlePaddle/PaddleMaterials#252 ~

I added two new test scripts to the RFC directory:

  • quick_test.py tests the PaddleMaterials implementation
  • test_alignment.py tests precision alignment between PaddleMaterials and GDI-NN

I have verified this locally. Test steps:

  • Add a test_gdinn directory to PaddleMaterials
  • Put the data from GDI-NN into the test_gdinn/dataset directory
  • Put quick_test.py and test_alignment.py into test_gdinn
  • Run the tests with python test_gdinn/test_alignment.py; python test_gdinn/quick_test.py
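
Before running the two scripts, the layout described in the steps above can be checked with a small pre-flight helper. This is only a minimal sketch and not part of the PR: `missing_test_files` and `EXPECTED_FILES` are hypothetical names, and the two CSV file names are the ones that appear in the test log in this comment.

```python
from pathlib import Path

# File names come from the test steps and the log in this comment.
# The helper itself is hypothetical and only illustrates the expected layout.
EXPECTED_FILES = (
    "test_alignment.py",
    "quick_test.py",
    "dataset/solvent_list.csv",
    "dataset/output_binary_with_inf_all.csv",
)


def missing_test_files(test_dir="test_gdinn"):
    """Return the expected files that are absent under test_dir."""
    root = Path(test_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_test_files()
    if missing:
        print("Missing files:", ", ".join(missing))
    else:
        print("Layout looks complete; run the two scripts from the PaddleMaterials root.")
```

An empty list means the directory matches the steps; anything else names the files still to be copied.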

The test results are as follows:

(venv310) shun@shun-B660M-Pro-RS ~/workspace/Projects/megemini/PaddleMaterials (gdinn ±) $ python test_gdinn/test_alignment.py; python test_gdinn/quick_test.py
/home/shun/venv310/lib/python3.10/site-packages/paddle/utils/cpp_extension/extension_utils.py:715: UserWarning: No ccache found. Please be aware that recompiling all source files may be required. You can download and install ccache from: https://github.com/ccache/ccache/blob/master/doc/INSTALL.md
  warnings.warn(warning_message)
================================================================================
GDI-NN precision alignment test (using real data)
================================================================================

================================================================================
Test mean_nodes precision alignment (Paddle vs DGL)
================================================================================
W0330 17:47:39.440886 27416 gpu_resources.cc:114] Please NOTE: device: 0, GPU Compute Capability: 6.1, Driver API Version: 12.2, Runtime API Version: 11.8
  Paddle mean_nodes output shape: [4, 64]
  DGL mean_nodes output shape: torch.Size([4, 64])
  Max difference: 0.0000000596
  Mean difference: 0.0000000011
  ✓ Precision aligned (difference < 1e-06)

================================================================================
Test hydrogen-bond feature computation (aligned with the original GDI-NN)
================================================================================
Loaded 700 solvents from solvent list
✓ HB feature fields exist

Solvent 1: solvent_587
  SMILES: CN
  HBA: 1, HBD: 1
  intra_hb1: expected=1, actual=1

Solvent 2: solvent_604
  SMILES: CC(=O)CC(C)C
  HBA: 1, HBD: 0
  intra_hb2: expected=0, actual=0

Inter-molecular hydrogen bonds:
  inter_hb = min(1, 0) + min(1, 1)
           = 0 + 1
           = 1
  actual: 1

✓ HB feature computation matches the original GDI-NN

Testing the solvent cache mechanism...
  ✓ Solvent cache format is correct: [graph, hba=1, hbd=1, intra_hb=1]

================================================================================
Test GNN precision alignment (Paddle vs PyTorch)
================================================================================
Loaded 700 solvents from solvent list
  gamma1 max difference: 0.000000
  gamma1 mean difference: 0.000000
  gamma2 max difference: 0.000000
  gamma2 mean difference: 0.000000

================================================================================
Test MCM precision alignment (Paddle vs PyTorch)
================================================================================
Preparing test data...
✓ Loaded data: 101 samples, 700 solvents
  ln_gamma1 max difference: 0.000000
  ln_gamma1 mean difference: 0.000000
  ln_gamma2 max difference: 0.000000
  ln_gamma2 mean difference: 0.000000

================================================================================
Test summary
================================================================================
mean_nodes precision alignment: ✓ passed (precision aligned)
HB feature computation: ✓ passed (HB feature test passed)
GNN precision alignment: ✓ passed (precision aligned)
MCM precision alignment: ✓ passed (precision aligned)

Total: 4/4 tests passed
================================================================================
✓ All tests passed!
/home/shun/venv310/lib/python3.10/site-packages/paddle/utils/cpp_extension/extension_utils.py:715: UserWarning: No ccache found. Please be aware that recompiling all source files may be required. You can download and install ccache from: https://github.com/ccache/ccache/blob/master/doc/INSTALL.md
  warnings.warn(warning_message)
================================================================================
GDI-NN quick test
================================================================================
Preparing test data (using the GDI-NN format)...
Reading solvent list: ./test_gdinn/dataset/solvent_list.csv
✓ Number of solvents: 700
Reading data: ./test_gdinn/dataset/output_binary_with_inf_all.csv
✓ Number of records: 101
✓ Training set: 80 samples
✓ Validation set: 10 samples
✓ Test set: 11 samples
✓ Data saved to: ./test_gdinn/data/gdinn/

================================================================================
Test data loading
================================================================================
Loaded 700 solvents from solvent list
✓ Dataset created successfully
  Number of samples: 80
✓ Data loader created successfully
  Number of batches: 2
W0330 17:47:44.141553 27506 gpu_resources.cc:114] Please NOTE: device: 0, GPU Compute Capability: 6.1, Driver API Version: 12.2, Runtime API Version: 11.8

Batch 1:
  g1 nodes: 206
  g1 edges: 570
  g2 nodes: 259
  g2 edges: 729
  empty_solvsys nodes: 64
  empty_solvsys edges: 128
  x1 shape: [32]
  x2 shape: [32]
  gamma1 shape: [32]
  gamma2 shape: [32]

Batch 2:
  g1 nodes: 213
  g1 edges: 599
  g2 nodes: 216
  g2 edges: 606
  empty_solvsys nodes: 64
  empty_solvsys edges: 128
  x1 shape: [32]
  x2 shape: [32]
  gamma1 shape: [32]
  gamma2 shape: [32]

✓ Data loading test passed

================================================================================
Test model forward pass
================================================================================
✓ Model created successfully
  Number of parameters: 181825
Loaded 700 solvents from solvent list

Testing Batch 1...
✓ Forward pass succeeded
  loss_dict keys: ['loss', 'pred_loss', 'gd_loss']
  pred_dict keys: ['gamma1', 'gamma2', 'ln_gamma1', 'ln_gamma2']
  loss: 0.4298
  pred_loss: 0.4297
  gd_loss: 0.0001
  gamma1 shape: [32, 1]
  gamma2 shape: [32, 1]

✓ Model forward pass test passed

================================================================================
Test training step
================================================================================
✓ Model and loss function created successfully
Loaded 700 solvents from solvent list
✓ Optimizer created successfully
  Step 1: Loss = 0.9105
  Step 2: Loss = 0.3479

✓ Training step test passed

================================================================================
Test SolvGNNxMLP model forward pass
================================================================================
✓ SolvGNNxMLP model created successfully
  Number of parameters: 181825
Loaded 700 solvents from solvent list

Testing Batch 1...
✓ Forward pass succeeded
  loss_dict keys: ['loss', 'pred_loss', 'gd_loss']
  pred_dict keys: ['gamma1', 'gamma2', 'ln_gamma1', 'ln_gamma2']
  loss: 0.4696
  pred_loss: 0.4678
  gd_loss: 0.0018
  gamma1 shape: [32, 1]
  gamma2 shape: [32, 1]

✓ SolvGNNxMLP model forward pass test passed

================================================================================
Test SolvGNNxMLP model training step
================================================================================
✓ SolvGNNxMLP model created successfully
Loaded 700 solvents from solvent list
✓ Optimizer created successfully
  Step 1: Loss = 0.1254
  Step 2: Loss = 0.8354

✓ SolvGNNxMLP training step test passed

================================================================================
Test GEGNN model forward pass
================================================================================
✓ GEGNN model created successfully
  Number of parameters: 186115
Loaded 700 solvents from solvent list

Testing Batch 1...
✓ Forward pass succeeded
  loss_dict keys: ['loss', 'pred_loss', 'gd_loss']
  pred_dict keys: ['gamma1', 'gamma2', 'ln_gamma1', 'ln_gamma2', 'G_E']
  loss: 0.4366
  pred_loss: 0.4366
  gd_loss: 0.0000
  gamma1 shape: [32, 1]
  gamma2 shape: [32, 1]
  G_E shape: [32, 1]
  G_E mean: 0.0087

✓ GEGNN model forward pass test passed

================================================================================
Test GEGNN model training step
================================================================================
✓ GEGNN model created successfully
Loaded 700 solvents from solvent list
✓ Optimizer created successfully
  Step 1: Loss = 0.5936, G_E = -0.0045
  Step 2: Loss = 0.4706, G_E = 0.0407

✓ GEGNN training step test passed

================================================================================
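Test MCM model forward pass
================================================================================
Creating MCM test data...
✓ MCM training set: 400 samples
✓ MCM validation set: 50 samples
✓ MCM test set: 50 samples
✓ Data saved to: ./test_gdinn/data/gdinn/
✓ MCM model created successfully
  Number of parameters: 46658

Testing forward pass...
✓ Forward pass succeeded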
  loss_dict keys: ['loss', 'pred_loss', 'gd_loss']
  pred_dict keys: ['gamma1', 'gamma2', 'ln_gamma1', 'ln_gamma2']
  pred_loss: 0.0254
  gd_loss: 0.0013
  gamma1 shape: [32, 1]
  gamma2 shape: [32, 1]

✓ MCM model forward pass test passed

================================================================================
Test MCM model training step
================================================================================
✓ MCM model created successfully
✓ Optimizer created successfully
  Step 1: Loss = 0.0486
  Step 2: Loss = 0.0341
  Step 3: Loss = 0.0168

✓ MCM training step test passed

================================================================================
Test GibbsDuhemLoss loss function
================================================================================
✓ GibbsDuhemLoss created successfully
  lambda_gd: 1.0
  loss_type: mse

Test 1: synthetic data
  Loss when the constraint is satisfied: 0.000000

Test 2: constraint not satisfied
  Loss when the constraint is not satisfied: 6.315185

Test 3: different loss types
  mse loss: 0.000000
  mae loss: 0.000000
  huber loss: 0.000000

Test 5: gradient computation
  Loss value: 0.000000
  A_param gradient: 0.000000

✓ GibbsDuhemLoss test passed

================================================================================
Test GibbsDuhemLoss integration with the model
================================================================================
✓ Model and GibbsDuhemLoss created successfully
Loaded 700 solvents from solvent list
✓ Optimizer created successfully

Testing Batch 1...
✓ Training step succeeded
  pred_loss: 2.4384
  gd_loss (using criterion_gd): 0.0000
  total_loss: 2.4384
  gamma1 shape: [32, 1]
  gamma2 shape: [32, 1]

Testing Batch 2...
✓ Training step succeeded
  pred_loss: 2.4297
  gd_loss (using criterion_gd): 0.0000
  total_loss: 2.4297
  gamma1 shape: [32, 1]
  gamma2 shape: [32, 1]

✓ GibbsDuhemLoss and model integration test passed

================================================================================
Test prediction
================================================================================
Loaded 700 solvents from solvent list
✓ Prediction complete
  Number of predicted samples: 11
  Mean absolute error (MAE): 1.0218

First 5 predictions:
/home/shun/workspace/Projects/megemini/PaddleMaterials/test_gdinn/quick_test.py:1217: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)
  print(f"  {i+1}. gamma1: pred={float(pred['gamma1_pred']):.4f}, target={float(pred['gamma1_target']):.4f}, "
/home/shun/workspace/Projects/megemini/PaddleMaterials/test_gdinn/quick_test.py:1218: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)
  f"error={abs(float(pred['gamma1_pred']) - float(pred['gamma1_target'])):.4f}")
  1. gamma1: pred=1.0769, target=0.3857, error=0.6913
/home/shun/workspace/Projects/megemini/PaddleMaterials/test_gdinn/quick_test.py:1219: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)
  print(f"     gamma2: pred={float(pred['gamma2_pred']):.4f}, target={float(pred['gamma2_target']):.4f}, "
/home/shun/workspace/Projects/megemini/PaddleMaterials/test_gdinn/quick_test.py:1220: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)
  f"error={abs(float(pred['gamma2_pred']) - float(pred['gamma2_target'])):.4f}")
     gamma2: pred=1.1001, target=0.0020, error=1.0981
  2. gamma1: pred=1.0779, target=2.6893, error=1.6114
     gamma2: pred=1.0986, target=0.1162, error=0.9824
  3. gamma1: pred=1.0621, target=0.1561, error=0.9061
     gamma2: pred=1.1015, target=0.0038, error=1.0977
  4. gamma1: pred=1.0661, target=0.9194, error=0.1467
     gamma2: pred=1.0911, target=0.0195, error=1.0716
  5. gamma1: pred=1.0662, target=-0.3831, error=1.4493
     gamma2: pred=1.0984, target=-0.0047, error=1.1030

✓ Prediction test passed

================================================================================
Test summary
================================================================================
Data loading: ✓ passed
SolvGNN forward pass: ✓ passed
SolvGNN training step: ✓ passed
SolvGNNxMLP forward pass: ✓ passed
SolvGNNxMLP training step: ✓ passed
GEGNN forward pass: ✓ passed
GEGNN training step: ✓ passed
MCM forward pass: ✓ passed
MCM training step: ✓ passed
GibbsDuhemLoss: ✓ passed
GibbsDuhemLoss integration with the model: ✓ passed
Prediction: ✓ passed

Total: 12/12 tests passed
================================================================================
✓ All tests passed!

You can now run full training with the following command:
  python train_gdinn.py \
    --model_type SolvGNN \
    --batch_size 32 \
    --epochs 2 \
    --hidden_dim 64 \
    --lr 1e-3 \
    --pinn_lambda 1.0

After training completes, you can run prediction with the following command:
  python predict_gdinn.py \
    --model_type SolvGNN \
    --batch_size 32 \
    --hidden_dim 64 \
    --checkpoint ./checkpoints/best_model.pdparams
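
The hydrogen-bond arithmetic printed in the alignment log above can be reproduced in a few lines. This is a sketch under stated assumptions, not the PR's code: the function names are hypothetical, the acceptor/donor pairing in `inter_hb` is read off the logged expression `min(1, 0) + min(1, 1)`, and `intra_hb = min(HBA, HBD)` is an assumption that merely agrees with the two logged solvents.

```python
def intra_hb(hba, hbd):
    """Assumed intra-molecular HB count: limited by the scarcer of acceptors and donors."""
    return min(hba, hbd)


def inter_hb(hba1, hbd1, hba2, hbd2):
    """Assumed inter-molecular HB count: acceptors of one solvent paired with donors of the other."""
    return min(hba1, hbd2) + min(hba2, hbd1)


# Values taken from the log: solvent_587 has HBA=1, HBD=1; solvent_604 has HBA=1, HBD=0.
print(intra_hb(1, 1))        # 1, matches intra_hb1 in the log
print(intra_hb(1, 0))        # 0, matches intra_hb2 in the log
print(inter_hb(1, 1, 1, 0))  # min(1, 0) + min(1, 1) = 1, matches the logged inter_hb
```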

Please take a look, thanks! :)
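
The GibbsDuhemLoss checks in the log above (a near-zero loss when the constraint holds, a positive loss when it does not) follow from the binary Gibbs-Duhem relation x1·d ln γ1/dx1 + x2·d ln γ2/dx1 = 0 at constant temperature and pressure. The NumPy sketch below illustrates that behaviour with closed-form derivatives. It is not the PR's GibbsDuhemLoss: all function names are hypothetical, only the mse variant is shown, and the pair ln γ1 = a1·x2², ln γ2 = a2·x1² is an illustrative test case that is thermodynamically consistent only when a1 == a2.

```python
import numpy as np


def gibbs_duhem_residual(x1, dln_gamma1_dx1, dln_gamma2_dx1):
    """Binary Gibbs-Duhem residual: x1*dlnγ1/dx1 + x2*dlnγ2/dx1, with x2 = 1 - x1."""
    x1 = np.asarray(x1, dtype=float)
    return x1 * dln_gamma1_dx1 + (1.0 - x1) * dln_gamma2_dx1


def margules_derivatives(x1, a1, a2):
    """Closed-form derivatives of ln γ1 = a1*x2**2 and ln γ2 = a2*x1**2 with respect to x1."""
    x1 = np.asarray(x1, dtype=float)
    return -2.0 * a1 * (1.0 - x1), 2.0 * a2 * x1


def gd_mse(x1, a1, a2):
    """Mean squared Gibbs-Duhem residual over a grid of compositions."""
    d1, d2 = margules_derivatives(x1, a1, a2)
    return float(np.mean(gibbs_duhem_residual(x1, d1, d2) ** 2))


x1 = np.linspace(0.05, 0.95, 19)
print(gd_mse(x1, 1.5, 1.5))  # consistent pair: zero up to floating-point rounding
print(gd_mse(x1, 1.5, 3.0))  # inconsistent pair: strictly positive
```

In a trained network the derivatives would typically come from automatic differentiation rather than a closed form; the residual and its penalty are computed the same way.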
