Skip to content

Latest commit

 

History

History
88 lines (61 loc) · 2.47 KB

File metadata and controls

88 lines (61 loc) · 2.47 KB

VAE Training

The Variational Autoencoder (VAE) is the first stage of the Chemeleon2 pipeline. It learns to encode crystal structures into a continuous latent space.

What VAE Does

The VAE is the first stage of Chemeleon2 that encodes crystal structures into continuous latent space representations. For architectural details, see VAE Module.

Quick Start

# Train VAE on MP-20 dataset (src/train_vae.py)
python src/train_vae.py experiment=mp_20/vae_dng

Training script: src/train_vae.py Example config: configs/experiment/mp_20/vae_dng.yaml

:::{tip} Example training run: See this W&B run for a successful VAE training example on MP-20 dataset. :::

Training Commands

Basic Training

# Use experiment config
python src/train_vae.py experiment=mp_20/vae_dng

# Override training parameters
python src/train_vae.py experiment=mp_20/vae_dng \
    trainer.max_epochs=3000 \
    data.batch_size=128

Resume from Checkpoint

python src/train_vae.py experiment=mp_20/vae_dng \
    ckpt_path=ckpts/vae_checkpoint.ckpt

Configuration

Key Hyperparameters

Parameter Default Description
latent_dim 8 Dimension of latent space
hidden_dim 512 Hidden dimension in encoder/decoder (d_model)
num_layers 8 Number of transformer layers
kl_weight 1e-5 KL divergence loss weight

Example Config Override

python src/train_vae.py experiment=mp_20/vae_dng \
    vae_module.latent_dim=16 \
    vae_module.kl_weight=1e-4

Available Experiments

Experiment Dataset Description
mp_20/vae_dng MP-20 VAE for de novo generation

Training Tips

Monitoring

Key metrics to watch in WandB:

  • train/recon_loss: Reconstruction loss (should decrease)
  • train/kl_loss: KL divergence (should stabilize)
  • val/recon_loss: Validation reconstruction (check for overfitting)

Typical Training

  • Duration: ~1000-5000 epochs
  • Batch size: 64-256 depending on GPU memory
  • Learning rate: 1e-4 (default)

Next Steps

After training VAE:

  1. Note the checkpoint path
  2. Proceed to LDM Training using your VAE checkpoint