Re-implementation of the paper "Mean Flows for One-step Generative Modeling".
- Training can be run on CPU or on hardware accelerators (e.g. CUDA GPUs).
- Install project dependencies:
pip install -r requirements.txt
- train_meanflow.py: Training script.
- Training Scripts: All training, dataset generation and FID evaluation scripts.
- generate_dataset.py: Generates a synthetic MNIST training dataset of 60,000 samples.
- loss.py: The MeanFlow loss, with or without CFG (classifier-free guidance), including the improved CFG variant.
- MNIST Evaluation: Code for CAS (Classification Accuracy Score) evaluation of generated MNIST samples, plus image generation helpers.
- evaluation_meanflow.py: Computes the FID score; can also be used to generate samples.
- trivial_baseline.py: Helper code that can replace the MeanFlow sampler and MeanFlow loss to train a trivial baseline (direct flow map matching).
- meanflow_sampler.py: Implementation of the MeanFlow sampler as described in the original paper.
- Networks: Directory for additional networks. Currently, a U-Net is the only available network.
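As a rough illustration of what loss.py and meanflow_sampler.py compute, the sketch below follows the paper's formulation: with z_t = (1 - t)x + tε and v = ε - x, the predicted average velocity u(z_t, r, t) is regressed onto the target v - (t - r)(v·∂_z u + ∂_t u), and a sample is drawn in one step as x = ε - u(ε, 0, 1). It is not the repository's code: it uses NumPy instead of PyTorch, the names meanflow_loss, sample_one_step and u_fn are hypothetical, the JVP is approximated by a central finite difference rather than autodiff, and CFG and any loss weighting are omitted.

```python
import numpy as np


def meanflow_loss(u_fn, x, eps, r, t, h=1e-4):
    """Unweighted MeanFlow loss without CFG (sketch).

    u_fn(z, r, t): hypothetical average-velocity model returning an array shaped like z.
    x, eps: (B, D) data and Gaussian noise; r, t: (B, 1) times with 0 <= r <= t <= 1.
    """
    z_t = (1.0 - t) * x + t * eps  # linear path: t = 0 is data, t = 1 is noise
    v = eps - x                    # conditional instantaneous velocity dz_t/dt
    u = u_fn(z_t, r, t)
    # d/dt u(z_t, r, t) = v . grad_z u + partial_t u, i.e. a JVP with tangent (v, 0, 1).
    # Approximated here by a central finite difference instead of autodiff.
    du_dt = (u_fn(z_t + h * v, r, t + h) - u_fn(z_t - h * v, r, t - h)) / (2.0 * h)
    u_tgt = v - (t - r) * du_dt    # regression target, treated as a constant
    return float(np.mean(np.sum((u - u_tgt) ** 2, axis=1)))


def sample_one_step(u_fn, eps):
    """One-step sampling: z_0 = z_1 - (1 - 0) * u(z_1, 0, 1), starting from z_1 = eps."""
    r = np.zeros((eps.shape[0], 1))
    t = np.ones((eps.shape[0], 1))
    return eps - (t - r) * u_fn(eps, r, t)


# Toy check: all data sits at one point x0, so trajectories are straight lines and
# the exact average velocity is u(z, r, t) = (z - x0) / t for every r.
rng = np.random.default_rng(0)
x0 = np.array([[1.0, -2.0]])
x = np.repeat(x0, 8, axis=0)
eps = rng.standard_normal((8, 2))
t = rng.uniform(0.2, 0.9, size=(8, 1))
r = t * rng.uniform(0.0, 1.0, size=(8, 1))


def exact_u(z, r, t):
    return (z - x0) / t


def zero_u(z, r, t):
    return np.zeros_like(z)


exact_loss = meanflow_loss(exact_u, x, eps, r, t)  # numerically zero
zero_loss = meanflow_loss(zero_u, x, eps, r, t)    # positive
samples = sample_one_step(exact_u, eps)            # every row equals x0
```

The toy case is chosen because its exact average velocity is known in closed form, so a correct loss must vanish for it and the one-step sampler must map any noise back to x0. In a PyTorch training loop the finite difference would normally be replaced by an exact JVP.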
To start training, run, for example, the following script (or any other training script in scripts/):
bash scripts/mnist/train_cfg_1.sh
The training parameters are set in train_cfg_1.sh and can be changed there:
accelerate launch train_meanflow.py \
--dataset mnist \
--export_name 1cfg.pth \
--batch_size 64 \
--epochs 10 \
--lr 0.001 \
--ema 0.9995 \
--time_sampler logit_normal \
--logit_sigma 2.0 \
--logit_mu -2.0 \
--ratio_r_not_equal_t 0.75 \
--scheduler linear \
--model unet \
--seed 99 \
--num_workers 0 \
--cfg_omega 3.0 \
--cfg_kappa 0.5 \
--cfg_drop_ratio 0.1
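The time-sampling and guidance flags above can be read roughly as in the following sketch. It assumes (t, r) are drawn as in the paper: two logit-normal samples with parameters --logit_mu and --logit_sigma, the larger used as t and the smaller as r, with --ratio_r_not_equal_t taken as the fraction of pairs that keep r ≠ t. It also assumes that --cfg_omega and --cfg_kappa enter the improved-CFG velocity ω·v + κ·u(z_t, t, t | c) + (1 - ω - κ)·u(z_t, t, t | ∅), and that --cfg_drop_ratio is the probability of replacing a label with a null class. The function names and the null-label convention are hypothetical; loss.py and train_meanflow.py define the actual behavior.

```python
import numpy as np


def sample_t_r(batch_size, mu=-2.0, sigma=2.0, ratio_r_not_equal_t=0.75, rng=None):
    """Sample time pairs (t, r) with r <= t (hypothetical helper, sketch only)."""
    rng = np.random.default_rng() if rng is None else rng
    # Logit-normal draws: sigmoid of N(mu, sigma^2), two per example.
    s = 1.0 / (1.0 + np.exp(-rng.normal(mu, sigma, size=(batch_size, 2))))
    t = s.max(axis=1)  # larger value becomes t
    r = s.min(axis=1)  # smaller value becomes r
    # Assumed meaning of the flag: fraction of pairs that keep r != t.
    keep = rng.uniform(size=batch_size) < ratio_r_not_equal_t
    r = np.where(keep, r, t)  # the remaining pairs use r = t
    return t, r


def improved_cfg_velocity(v, u_cond, u_uncond, omega=3.0, kappa=0.5):
    """Guided velocity that replaces v in the MeanFlow target (assumed formulation).

    u_cond and u_uncond are the model's predictions at (z_t, r=t, t) with and
    without the class label, treated as constants; kappa = 0 gives the basic
    CFG velocity omega * v + (1 - omega) * u_uncond.
    """
    return omega * v + kappa * u_cond + (1.0 - omega - kappa) * u_uncond


def drop_labels(labels, null_label, drop_ratio=0.1, rng=None):
    """Replace each label by a null class with probability drop_ratio."""
    rng = np.random.default_rng() if rng is None else rng
    drop = rng.uniform(size=labels.shape) < drop_ratio
    return np.where(drop, null_label, labels)


rng = np.random.default_rng(0)
t, r = sample_t_r(10_000, rng=rng)
print(np.mean(r == t))  # roughly 0.25 under the assumed flag meaning
```

Under this formulation, if the unconditional prediction matches the true unconditional velocity, the conditional prediction converges to a CFG velocity with effective scale ω / (1 - κ); for --cfg_omega 3.0 and --cfg_kappa 0.5 that would be 6.0.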