【Hackathon 10th Spring No.6】CrystalLLM Model Reproduction #255
Open
cloudforge1 wants to merge 9 commits into PaddlePaddle:develop from
Conversation
Reproduce CrystalLLM (Nature Communications 2024) in PaddleMaterials/ppmat.

New files:
- `ppmat/models/crystalllm/`: GPT model, CIF tokenizer, space groups
- `ppmat/datasets/cif_token_dataset.py`: memory-mapped CIF token dataset
- `ppmat/metrics/crystal_metrics.py`: validity, bond-length, space-group metrics
- `structure_generation/configs/crystalllm/`: 8 configs (4 datasets × 2 sizes)
- `structure_generation/convert_weights.py`: PyTorch→Paddle weight converter
- `test/test_crystalllm_forward.py`: 7-test validation suite (all passing)

Architecture: nanoGPT-based causal LM for CIF text generation.
- Small: 8L/8H/512D/1024ctx (~33M params)
- Large: 16L/16H/1024D/2048ctx (~250M params)
- Weight tying via `matmul(x, wte.weight^T)` (Paddle-idiomatic)
- Vocabulary: 371 tokens (89 atoms + 10 digits + 31 keywords + 13 symbols + 227 space groups + 1 UNK)

RFC: PaddlePaddle/community#1256
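The weight-tying point above (the LM head reuses the token-embedding matrix via a transposed matmul rather than a separate linear layer) can be illustrated in plain NumPy; the array names and shapes here are illustrative, not the PR's actual identifiers:

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, d_model = 371, 512

# One shared parameter: the token-embedding table (vocab_size, d_model).
wte = rng.standard_normal((vocab_size, d_model)).astype(np.float32)

# Input side: embedding lookup for a token sequence.
tokens = np.array([5, 17, 42])
x = wte[tokens]                # (seq_len, d_model)

# Output side: the LM head reuses the same table transposed,
# i.e. matmul(x, wte.weight, transpose_y=True) in Paddle terms.
logits = x @ wte.T             # (seq_len, vocab_size)

assert logits.shape == (3, vocab_size)
```

Tying halves the parameter count of the input/output vocabulary projections and keeps the embedding and unembedding spaces consistent, which is the standard GPT-2/nanoGPT convention the description refers to.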
- Fix generate() stop_token: changed from the hardcoded 10 (the sodium atom token) to Optional[int] = None. The original stop condition checked for token ID 10, which never matched newline (ID 142), making it non-functional. The default None means generating the full max_new_tokens (original behavior).
- Fix convert_weights.py: strip the _orig_mod.transformer. prefix from torch.compile checkpoints; handle lm_head.weight filtering with cleaned key names.
- Fix crystal_metrics.py: wrap pymatgen radii in float() to avoid UnitError when comparing atomic/ionic radii.
- Add test_pipeline.py: end-to-end pipeline test with two phases. Phase 1 (monkeypatch) validates the full pipeline without a real checkpoint; Phase 2 (real) downloads the Zenodo checkpoint and runs forward alignment, generation, and metrics evaluation. Forward alignment: max_diff 7.63e-06.
- Add test_unit.py: 39 pytest tests covering all CrystalLLM components (config, tokenizer, attention, MLP, transformer block, model, generate, train forward, save/load, crop_block_size, etc.).
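The stop_token fix can be shown with a minimal autoregressive loop. This is a toy sketch, not the PR's actual generate(); `step` stands in for sampling the next token from the model:

```python
from typing import Callable, List, Optional

def generate(step: Callable[[List[int]], int],
             prompt: List[int],
             max_new_tokens: int,
             stop_token: Optional[int] = None) -> List[int]:
    """Toy autoregressive loop. With stop_token=None (the fixed default),
    the loop always emits max_new_tokens tokens; passing a real token id
    (e.g. newline, 142) stops early as soon as it is generated. The
    original bug hardcoded an id that never appeared, so the check
    silently never fired."""
    seq = list(prompt)
    for _ in range(max_new_tokens):
        tok = step(seq)
        seq.append(tok)
        if stop_token is not None and tok == stop_token:
            break
    return seq

# Deterministic stand-in "model" that cycles through a few token ids.
cycle = [7, 142, 9]
step = lambda seq: cycle[len(seq) % len(cycle)]

full = generate(step, [0], max_new_tokens=6)               # no stop: 6 new tokens
early = generate(step, [0], max_new_tokens=6, stop_token=142)
assert len(full) == 7
assert early == [0, 142]
```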
Thanks for your contribution!
Standalone script that:
- Downloads the perov-5-small checkpoint from Zenodo
- Converts PyTorch weights to Paddle format
- Generates N samples with autoregressive sampling
- Evaluates with CrystalMetrics (validity, bond score, SG consistency)
- Saves results to JSON

Usage: `python eval_crystalllm_gpu.py --device gpu --num-samples 10000`
Author
@leeleolay This PR implements PaddlePaddle Hackathon 10th task No.6 (CrystalLLM model reproduction). Corresponding design doc: PaddlePaddle/community#1256. Are there any review suggestions or changes you'd like made?
Key changes:
- crystal_metrics.py: rewrite to match the upstream _metrics.py + evaluate_cifs.py pipeline
- bond_length_reasonableness_score: electronegativity-based ionic/covalent radii selection, tolerance=0.32, H-bond factor=2.5
- Add is_sensible(), is_formula_consistent(), is_atom_site_multiplicity_consistent()
- Add replace_symmetry_operators() from upstream _utils.py
- is_valid(): now requires all 4 criteria (formula + multiplicity + bonds >= 1.0 + SG)
- CrystalMetrics: applies symop replacement before validation
- eval_crystalllm_gpu.py: use the 'data_' start token (token 124) instead of newline, default top_k=10 (matches the paper), correct paper target annotations (94% is v1_small, not perov-5-small), save raw CIFs
100-sample eval results (perov-5-small, top_k=10):
Validity: 58% | Bond score: 0.860 | SG consistency: 100%
Sensible: 100% | Formula consistency: 100%
Bond scoring fixes (3 bugs in bond_length_reasonableness_score):
- Use directed ionic radii (cationic/anionic by electronegativity) instead of the undirected average_ionic_radius
- H-bond check: upper-bound-only (bond_ratio < h_factor) instead of a symmetric tolerance band
- Regular bond check: exclusive ratio comparison (min_ratio < bond_ratio < max_ratio)
- Skip self-neighbor pairs (i == j)

50-sample validation results (GTX 1060):
- Validity: 92.0% (paper: 94.0%, within the 95% CI for n=50)
- Bond score: 0.986 (paper: 0.988)
- SG consistency: 100.0% (paper: 98.9%)
- Formula consistency: 100.0%

A 500-sample run is in progress for final confirmation.
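The per-bond ratio checks described above can be sketched as a standalone predicate. The constants mirror the ones quoted in this PR (tolerance 0.32, H-bond factor 2.5), but the function is an illustration, not the actual bond_length_reasonableness_score:

```python
def bond_is_reasonable(bond_len: float, expected_len: float,
                       involves_h: bool,
                       tol: float = 0.32, h_factor: float = 2.5) -> bool:
    """Toy per-bond check: compare the observed bond length to the
    expected radii sum as a ratio."""
    ratio = bond_len / expected_len
    if involves_h:
        # H bonds: upper-bound-only check, not a symmetric band.
        return ratio < h_factor
    # Regular bonds: exclusive comparison against a symmetric band.
    return (1.0 - tol) < ratio < (1.0 + tol)

assert bond_is_reasonable(2.0, 2.0, involves_h=False)       # ratio 1.0, inside band
assert not bond_is_reasonable(2.8, 2.0, involves_h=False)   # ratio 1.4 > 1.32
assert bond_is_reasonable(1.8, 1.0, involves_h=True)        # 1.8 < 2.5
assert not bond_is_reasonable(2.6, 1.0, involves_h=True)    # 2.6 > 2.5
```

The expected length itself would come from the directed cationic/anionic radii selection described above (via pymatgen), which is omitted here.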
- ppmat/sampler/crystalllm_sampler.py: MCTS-guided + standard autoregressive sampler ported from upstream (Paddle). Includes MCTSEvaluator, MCTSLanguageModel, MCTSNode, ContextSensitiveTreeBuilder, PUCT/UCT/Greedy selectors, and a unified CrystalLLMSampler interface.
- ppmat/metrics/crystal_metrics.py: add get_unit_cell_volume() and remove_atom_props_block() helpers for the MCTS evaluator.
- test/test_backward_alignment.py: trains identical models in Paddle and PyTorch for 5 steps and compares losses. Max diff 4.77e-07.
- eval_multi_dataset.py: evaluates all 4 paper datasets (perov-5, carbon-24, mp-20, mpts-52) with Zenodo checkpoint download + conversion.
…s in prepare_netdisk.sh
- Fix checkpoint names: perov_5 not perov-5, carbon_24 not carbon-24, etc.
- Replace the non-existent v1_large_full with v1_minus_mpts_52_small
- Remove the wget --show-progress flag (it fails in non-interactive shells)
- Add a -s check so only non-empty files are skipped (avoids skipping partial downloads)
## Overview
Implements a reproduction of CrystalLLM, based on the paper Crystal Structure Generation with Autoregressive Large Language Modeling (Antunes et al., Nature Communications, 2024), with lantunes/CrystaLLM (MIT License) as the reference implementation.
CrystalLLM uses a GPT-2-style autoregressive Transformer that treats CIF (Crystallographic Information File) text as a token sequence and generates crystal structures directly. The model is trained on 2.3 million structures (MP + OQMD + NOMAD) and supports the four standard benchmark datasets Perov-5, MP-20, Carbon-24, and MPTS-52.
## What's new
### Model (`ppmat/models/crystalllm/`)
- `CrystalLLM`: GPT-2-style Transformer with Small (8 layers / 8 heads / 512 dim) and Large (16 layers / 16 heads / 1024 dim) configurations
- `CIFTokenizer`: custom CIF tokenizer, vocab_size=371 (89 atoms + 10 digits + 31 keywords + 13 symbols + 227 space groups + 1 UNK)
- `paddle.matmul(x, wte.weight, transpose_y=True)` shares the LM head with the token embedding
### MCTS sampler (`ppmat/sampler/crystalllm_sampler.py`)
- `CrystalLLMSampler`: unified sampling interface supporting standard autoregressive sampling (`sample()`) and MCTS-guided sampling (`sample_mcts()`)
- `MCTSSampler`: PUCT/UCT-based Monte Carlo tree search using crystal-structure validity as the reward
- `MCTSEvaluator`: pymatgen-based structure evaluator (bond-length reasonableness + space-group consistency + formula consistency)
- `PUCTSelector`, `UCTSelector`, `GreedySelector`
- `ContextSensitiveTreeBuilder`
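The selection rule behind a PUCT selector is, in its standard AlphaZero-style form, an exploitation term plus a prior-weighted exploration bonus. This is a generic sketch of that formula, not necessarily the PR's exact `PUCTSelector` implementation:

```python
import math

def puct_score(q: float, prior: float, parent_visits: int,
               child_visits: int, c_puct: float = 1.0) -> float:
    """Standard PUCT score: mean reward Q of the child plus an
    exploration bonus proportional to the model's prior for that token,
    shrinking as the child accumulates visits."""
    return q + c_puct * prior * math.sqrt(parent_visits) / (1 + child_visits)

# An unvisited child with a decent prior can outrank a well-visited one,
# which is what drives exploration of new CIF continuations.
visited = puct_score(q=0.6, prior=0.2, parent_visits=100, child_visits=20)
fresh   = puct_score(q=0.0, prior=0.5, parent_visits=100, child_visits=0)
assert fresh > visited
```

In this setting Q would be the validity-based reward from the evaluator, and the prior would come from the language model's next-token distribution.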
### Dataset (`ppmat/datasets/cif_token_dataset.py`)
- `CIFTokenDataset`: loads pre-tokenized CIF binary data (uint16 memmap); `starts.pkl` indexes the start offset of each CIF
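The uint16-memmap plus `starts.pkl` layout can be sketched end to end. File names, the index format, and the token values below are assumptions for illustration, not the dataset's actual on-disk schema:

```python
import os
import pickle
import tempfile

import numpy as np

tmp = tempfile.mkdtemp()
bin_path = os.path.join(tmp, "tokens.bin")
starts_path = os.path.join(tmp, "starts.pkl")

# Write side: concatenate tokenized CIFs into one flat uint16 array
# and record where each CIF begins.
cifs = [[12, 99, 142], [7, 7, 7, 142], [301, 5]]
flat = np.array([t for cif in cifs for t in cif], dtype=np.uint16)
flat.tofile(bin_path)

starts = np.cumsum([0] + [len(c) for c in cifs[:-1]]).tolist()
with open(starts_path, "wb") as f:
    pickle.dump(starts, f)

# Read side: memory-map the token file (no full load into RAM) and
# slice out one CIF by its start offset.
tokens = np.memmap(bin_path, dtype=np.uint16, mode="r")
with open(starts_path, "rb") as f:
    starts = pickle.load(f)
ends = starts[1:] + [len(tokens)]
second = tokens[starts[1]:ends[1]].tolist()
assert second == [7, 7, 7, 142]
```

The memmap keeps random access cheap even for corpora far larger than memory, which is the point of this storage format.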
### Metrics (`ppmat/metrics/crystal_metrics.py`)
- `is_valid()`: combined check (CIF parseable + reasonable bond lengths + space-group consistent + formula consistent + atom-site multiplicity consistent)
- `bond_length_reasonableness_score()`: bond-length reasonableness score based on pymatgen CrystalNN
- `is_space_group_consistent()`, `is_formula_consistent()`, `is_atom_site_multiplicity_consistent()`
- `get_unit_cell_volume()`, `remove_atom_props_block()`: helper functions for MCTS sampling
### Weight conversion (`structure_generation/convert_weights.py`)
- Strips the `_orig_mod.transformer.` prefix that `torch.compile` training adds to checkpoint keys
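The prefix-stripping step is a pure dictionary transform over checkpoint keys. `clean_state_dict` is a hypothetical name for illustration, not the converter's actual function:

```python
def clean_state_dict(state_dict: dict,
                     prefix: str = "_orig_mod.transformer.") -> dict:
    """Strip a torch.compile-style wrapper prefix from checkpoint keys
    so they match the uncompiled module's parameter names. Keys without
    the prefix (e.g. lm_head.weight) pass through unchanged."""
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

# Toy checkpoint with compiled-module key names.
ckpt = {
    "_orig_mod.transformer.wte.weight": 1,
    "_orig_mod.transformer.h.0.attn.c_attn.weight": 2,
    "lm_head.weight": 3,
}
cleaned = clean_state_dict(ckpt)
assert "wte.weight" in cleaned
assert "lm_head.weight" in cleaned
```

Filtering decisions such as dropping a tied `lm_head.weight` (mentioned in the commit notes above) would then operate on these cleaned names.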
### Configs (`structure_generation/configs/crystalllm/`)
- 8 configs (4 datasets × 2 model sizes)
### Tests and evaluation
- `test/test_unit.py`: 39 unit tests (tokenizer, dataset, metrics, model forward)
- `test/test_pipeline.py`: end-to-end pipeline test (monkeypatch mode + real-checkpoint mode)
- `test/test_crystalllm_forward.py`: forward-alignment test (Paddle vs PyTorch, max_diff ≤ 1e-5)
- `test/test_backward_alignment.py`: backward-alignment test (Paddle vs PyTorch, 5-step training loss comparison)
- `eval_v1_small.py`: prompt-based sampling evaluation for v1_small
- `eval_multi_dataset.py`: unified evaluation across all 4 datasets (perov-5, carbon-24, mp-20, mpts-52)
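The alignment tests above reduce to a max-absolute-difference check between the two frameworks' outputs. A minimal sketch with stand-in arrays (the numbers below are illustrative, not the PR's measured values):

```python
import numpy as np

def max_abs_diff(a: np.ndarray, b: np.ndarray) -> float:
    """Elementwise maximum absolute difference, the quantity compared
    against thresholds like 1e-5 (forward) in alignment tests."""
    return float(np.max(np.abs(a - b)))

# Stand-ins for logits produced by the Paddle and PyTorch models
# from identical weights and inputs.
paddle_logits = np.array([1.000000, -2.500003, 0.125000])
torch_logits  = np.array([1.000004, -2.500000, 0.124998])

diff = max_abs_diff(paddle_logits, torch_logits)
assert diff <= 1e-5
```

The backward test applies the same comparison to per-step training losses after several identical optimizer steps.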
### Tools
- `tools/prepare_netdisk.sh`: one-command download of Zenodo weights, conversion to Paddle format, and organization of the upload directory

## Acceptance results
1. Forward alignment (acceptance criterion: diff ≤ 1e-6)
2. Backward alignment (acceptance criterion: train for 2+ epochs with matching loss)
3. Sampling metrics (acceptance criterion: error ≤ 5%)
v1_small, 500 samples (paper prompt protocol):
4. Dataset coverage (acceptance criterion: all datasets from the original paper)
- `eval_multi_dataset.py` evaluates all 4 datasets in one command
5. New task-type documentation
- `structure_generation/` directory: text-based crystal structure generation task
6. Pretrained models / datasets
- (`tools/prepare_netdisk.sh`)

## Usage
## Related issues
Closes part of #194 (CrystalLLM, Task #1)
RFC: PaddlePaddle/community#1256