joarp/meanflow-reproducibility

meanflow-reproducibility

Re-implementation of the paper *Mean Flows for One-step Generative Modeling*.

Setup

  1. Training can be run on CPU or on hardware accelerators (e.g. CUDA GPUs).
  2. Install project dependencies: pip install -r requirements.txt

Files

  • train_meanflow.py: Training script.

  • scripts/: All training, dataset generation, and FID evaluation scripts.

  • generate_dataset.py: Generates a synthetic MNIST training dataset of 60,000 samples.

  • loss.py: The MeanFlow loss, with optional CFG and improved-CFG variants.

  • MNIST Evaluation: Code for CAS evaluation of generated MNIST samples, plus image generation helpers.

  • evaluation_meanflow.py: Computes the FID score; can also be used to generate samples.

  • trivial_baseline.py: Helper code that can replace the MeanFlow sampler and MeanFlow loss to train a trivial baseline (direct flow map matching).

  • meanflow_sampler.py: Implementation of the MeanFlow sampler as described in the original paper.

  • Networks: Directory for additional networks. Currently a U-Net is the only available network.
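The MeanFlow sampler enables one-step generation: the network predicts an average velocity over a time interval, so a single jump from noise at t = 1 to r = 0 yields a sample. Below is a minimal sketch of that one-step update, not the project's actual code; `mean_velocity_fn` is a hypothetical stand-in for the trained network.

```python
import numpy as np

def one_step_sample(mean_velocity_fn, z, r=0.0, t=1.0):
    """One-step MeanFlow sampling: jump from noise z at time t directly
    to time r using the predicted average velocity u(z, r, t)."""
    return z - (t - r) * mean_velocity_fn(z, r, t)

# Toy check: if the average velocity field happens to equal z itself,
# a full jump from t=1 to r=0 maps every input to zero.
z = np.ones((2, 3))
x = one_step_sample(lambda z, r, t: z, z)
```

In the real sampler, `mean_velocity_fn` would be the trained U-Net conditioned on (r, t), and `z` would be drawn from a standard Gaussian.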

Training

To start training, run one of the scripts in scripts/, for example:

bash scripts/mnist/train_cfg_1.sh

Training parameters can be changed in train_cfg_1.sh:

accelerate launch train_meanflow.py \
    --dataset mnist \
    --export_name 1cfg.pth \
    --batch_size 64 \
    --epochs 10 \
    --lr 0.001 \
    --ema 0.9995 \
    --time_sampler logit_normal \
    --logit_sigma 2.0 \
    --logit_mu -2.0 \
    --ratio_r_not_equal_t 0.75 \
    --scheduler linear \
    --model unet \
    --seed 99 \
    --num_workers 0 \
    --cfg_omega 3.0 \
    --cfg_kappa 0.5 \
    --cfg_drop_ratio 0.1
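The `--time_sampler logit_normal` flag with `--logit_mu` and `--logit_sigma` suggests that training timesteps are drawn from a logit-normal distribution, i.e. a Gaussian pushed through a sigmoid. A minimal sketch under that assumption (the function name is illustrative, not the project's API):

```python
import numpy as np

def sample_logit_normal(n, mu=-2.0, sigma=2.0, rng=None):
    """Draw n timesteps in (0, 1): t = sigmoid(mu + sigma * eps),
    eps ~ N(0, 1). Defaults mirror --logit_mu -2.0 / --logit_sigma 2.0."""
    rng = np.random.default_rng() if rng is None else rng
    eps = rng.standard_normal(n)
    return 1.0 / (1.0 + np.exp(-(mu + sigma * eps)))

t = sample_logit_normal(10_000, rng=np.random.default_rng(0))
```

With a negative `mu`, the distribution concentrates mass near t = 0, biasing training toward low-noise timesteps.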
