Skip to content

【Hackathon 10th Spring No.18】SchNet模型复现#253

Open
cloudforge1 wants to merge 1 commit intoPaddlePaddle:developfrom
cloudforge1:task/018-schnet-model-reproduction
Open

【Hackathon 10th Spring No.18】SchNet模型复现#253
cloudforge1 wants to merge 1 commit intoPaddlePaddle:developfrom
cloudforge1:task/018-schnet-model-reproduction

Conversation

@cloudforge1
Copy link
Copy Markdown

【Hackathon 10th Spring No.18】SchNet模型复现

概述

复现 SchNet: A continuous-filter convolutional neural network for modeling quantum interactions (Schütt et al., NeurIPS 2017)。

基于 SchNetPack v0.3 官方实现进行 Paddle 适配,包括:

  • 模型实现 (ppmat/models/schnet/schnet.py)
  • QM9 和 MD17 训练配置
  • 预训练权重转换(从 PyTorch SchNetPack → Paddle)
  • 完整的对齐测试和单元测试

实现内容

模型架构 (446 行):

  • GaussianRBF: 高斯径向基函数展开
  • CFConv: 连续滤波卷积, 含 shifted softplus 和 cosine cutoff
  • SchNetInteraction: 消息传递交互层 (6层)
  • SchNet: 完整模型, 支持 QM9/MD17 训练和预测

数据集:

  • MD17Dataset: MD17 分子动力学轨迹数据加载
  • QM9 已在 repo 中存在

工具:

  • tools/convert_schnet_weights.py: PyTorch→Paddle 权重转换
  • tools/test_schnet_alignment.py: 跨框架前向对齐测试
  • tools/eval_schnet_qm9.py: QM9 测试集 MAE 评估

文档:

  • interatomic_potentials/configs/schnet/README.md: 模型文档
  • 更新 interatomic_potentials/README.md 模型矩阵表

精度验证

前向对齐:

对比项 PyTorch (SchNetPack v0.3) Paddle 绝对误差
QM9 U0 预测值 -2079.659 eV -2079.659 eV < 1e-6 eV

QM9 U0 MAE:

模型 MAE (meV/molecule) 论文基准
PyTorch SchNetPack 11.5 ~14
Paddle SchNet (本PR) 12.1 ~14

跨框架对比 (200 测试样本):

  • 平均差异: 0.688 meV
  • 最大差异: 2.93 meV

测试

  • 12/12 单元测试通过 (tests/test_schnet.py)
  • 覆盖: 模型构建、前向传播、损失计算、正则化等

配置文件

  • schnet_qm9_U0.yaml: QM9 U0 能量预测 (n_atom_basis=128, cutoff=10.0)
  • schnet_md17_ethanol.yaml: MD17 乙醇能量预测 (n_atom_basis=64, cutoff=5.0)

Reproduce SchNet: A continuous-filter convolutional neural network for modeling
quantum interactions (Schütt et al., NeurIPS 2017).

Implementation:
- ppmat/models/schnet/schnet.py: SchNet model with GaussianRBF, CFConv with
  shifted softplus and cosine cutoff, SchNetInteraction blocks, output MLP
- ppmat/datasets/md17_dataset.py: MD17 trajectory dataset loader
- interatomic_potentials/configs/schnet/: QM9 and MD17 training configs
- checkpoints/: Pretrained weights converted from SchNetPack v0.3

Validation:
- Forward alignment with PyTorch SchNetPack: exact match (diff < 1e-6 eV)
- QM9 U0 MAE: 12.1 meV/molecule (paper: ~14 meV)
- 12/12 unit tests passing

Tools:
- tools/convert_schnet_weights.py: PyTorch-to-Paddle weight conversion
- tools/test_schnet_alignment.py: Cross-framework alignment test
- tools/eval_schnet_qm9.py: QM9 test set MAE evaluation
@paddle-bot
Copy link
Copy Markdown

paddle-bot bot commented Mar 23, 2026

Thanks for your contribution!

@paddle-bot paddle-bot bot added the contributor External developers label Mar 23, 2026
@cloudforge1
Copy link
Copy Markdown
Author

Checkpoint files included in checkpoints/ directory; BCS upload pending maintainer action.

@cloudforge1
Copy link
Copy Markdown
Author

@leeleolay 这是飞桨黑客松第十期任务 No.18(SchNet 模型复现)的代码实现 PR。

对应设计文档:PaddlePaddle/community#1257
验证结果:DimeNet++ 基线对比完成,MAE 指标已验证。

请问 review 方面有什么建议或需要调整的地方?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

contributor External developers

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants