This repository provides a benchmark of joint self-attention and selective state-space models for conditional image synthesis with latent diffusion models.
It contains implementations of the following models:
- DiT: Diffusion Transformer (baseline)
- DiM: Diffusion Mamba (ours)
- M-DiT: Mamba Diffusion Transformer (ours)
- DiMT: Diffusion Mamba-Transformer (ours)
- E-DiMT: Enhanced Diffusion Mamba-Transformer (ours)
This repository builds upon the following works:
- This repository is built upon fast-DiT, an improved PyTorch implementation of Scalable Diffusion Models with Transformers (DiT).
- Refer to Vim, the official implementation of the paper Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model.
- Refer to ADM's TensorFlow evaluation suite for details on how FID, Inception Score, and other metrics are computed.

The repository is organised as follows:
```
e-dimt-main
├── diffusion/            # Diffusion dir
├── docs/                 # Documentation figures
├── evaluator/            # Evaluator dir
│   ├── eval.py           # Evaluation script
│   └── ...
├── mamba/                # Mamba model dir
├── models/               # Backbones dir
│   ├── layers.py         # Layers and utility functions
│   ├── edimt.py          # E-DiMT backbone
│   └── ...
├── extract_features.py   # Feature extraction script
├── requirements.txt      # Requirements
├── sample.py             # Sampling script
├── sample_ddp.py         # DDP sampling script
└── train.py              # Training script
```
First, download and set up the repo:
```shell
git clone https://github.com/ahmedgh970/e-dimt.git
cd e-dimt
```
Then, create a Python 3.10 conda environment and install the requirements:
```shell
conda create --name edimt python=3.10
conda activate edimt
pip install -r requirements.txt
cd mamba
pip install -e .
```
To extract ImageNet-256 features:
```shell
torchrun --nnodes=1 --nproc_per_node=1 extract_features.py --image-size 256 --data-path /path/to/imagenet256 --features-path /path/to/store/features
```
We provide a training script for the E-DiMT model in train.py. To train a different model from the proposed benchmark, modify the relevant import and model definition/name.
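For reference, a 256x256 image passed through an 8x-downsampling VAE encoder yields a 4x32x32 latent. Below is a minimal sketch of loading such pre-extracted features for training; the per-sample `.npy` layout and the `features/`/`labels/` subdirectory names are assumptions for illustration, and the actual format written by extract_features.py may differ.

```python
import tempfile
from pathlib import Path

import numpy as np

# Hypothetical layout: one .npy latent (4x32x32 for a 256x256 image) and one
# .npy class label per sample. The real layout produced by extract_features.py
# may differ.
def load_feature_batch(features_root, indices):
    root = Path(features_root)
    latents = np.stack([np.load(root / "features" / f"{i}.npy") for i in indices])
    labels = np.array([np.load(root / "labels" / f"{i}.npy") for i in indices])
    return latents, labels

# Toy demo with randomly generated "latents".
tmp = Path(tempfile.mkdtemp())
(tmp / "features").mkdir()
(tmp / "labels").mkdir()
for i in range(4):
    np.save(tmp / "features" / f"{i}.npy", np.random.randn(4, 32, 32).astype(np.float32))
    np.save(tmp / "labels" / f"{i}.npy", np.int64(i % 1000))

x, y = load_feature_batch(tmp, range(4))
print(x.shape, y.shape)  # (4, 4, 32, 32) (4,)
```

Training on cached latents like this avoids re-running the VAE encoder every epoch, which is the main speed-up fast-DiT style pipelines rely on.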
To launch EDiMT-L/2 (256x256) training with N GPUs on one node:
```shell
accelerate launch --multi_gpu --num_processes N --mixed_precision fp16 train.py --model EDiMT-L/2 --image-size 256 --features-path /path/to/store/features
```
To sample from the EMA weights of a trained (256x256) EDiMT-L/2 model, run:
```shell
python sample.py --model EDiMT-L/2 --image-size 256 --ckpt /path/to/model.pt
```
The sampling results are saved inside the samples directory within the model's results directory.
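The EMA weights sampled from above are an exponential moving average of the training parameters, updated after every optimizer step as ema = decay * ema + (1 - decay) * current. A minimal sketch with numpy arrays standing in for model tensors (the decay of 0.9999 in the docstring is the common DiT choice; this repo's actual value and update logic may differ):

```python
import numpy as np

def update_ema(ema_params, model_params, decay=0.9999):
    """In-place EMA update: ema = decay * ema + (1 - decay) * current."""
    for name, p in model_params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * p

# Toy usage: numpy arrays stand in for model tensors; decay lowered so the
# averaging effect is visible after a few steps.
model = {"w": np.zeros(3)}
ema = {k: v.copy() for k, v in model.items()}
for _ in range(10):
    model["w"] += 1.0  # pretend optimizer step
    update_ema(ema, model, decay=0.5)
```

Sampling from the EMA copy rather than the raw weights typically gives noticeably better FID, which is why sample.py loads the EMA state from the checkpoint.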
To sample 50K images from the E-DiMT model over N GPUs, run:
```shell
torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py --model EDiMT-L/2 --num-fid-samples 50000
```
This script generates a folder of samples as well as a .npz file, which can be used directly with evaluator/eval.py to compute FID, Inception Score, and other metrics, as follows:
```shell
python eval.py /path/to/reference/batch.npz /path/to/generated/batch.npz
```
```
...
computing reference batch activations...
computing/reading reference batch statistics...
computing sample batch activations...
computing/reading sample batch statistics...
Computing evaluations...
Inception Score: 215.8370361328125
FID: 3.9425574129223264
sFID: 6.140433703346162
Precision: 0.8265
Recall: 0.5309
```
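For reference, the FID reported above is the Fréchet distance between two Gaussians fitted to Inception features of the reference and generated batches: FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2(Sigma_r Sigma_g)^(1/2)). The evaluator's actual implementation is TensorFlow-based; the numpy sketch below illustrates only the final formula, not the Inception feature extraction.

```python
import numpy as np

def sqrtm_psd(a):
    """Matrix square root of a symmetric PSD matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(a)
    vals = np.clip(vals, 0.0, None)  # clip tiny negative eigenvalues from numerics
    return (vecs * np.sqrt(vals)) @ vecs.T

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """FID between Gaussians N(mu1, sigma1) and N(mu2, sigma2).

    Uses Tr((S1 S2)^(1/2)) = Tr((S2^(1/2) S1 S2^(1/2))^(1/2)) so that only
    symmetric PSD matrices ever need square-rooting.
    """
    diff = mu1 - mu2
    s2_half = sqrtm_psd(sigma2)
    tr_covmean = np.trace(sqrtm_psd(s2_half @ sigma1 @ s2_half))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)

# Sanity check: identical statistics give a (numerically) zero distance.
rng = np.random.default_rng(0)
feats = rng.normal(size=(1000, 8))
mu, sigma = feats.mean(axis=0), np.cov(feats, rowvar=False)
print(frechet_distance(mu, sigma, mu, sigma))  # ~0.0
```

Because FID compares distribution statistics rather than individual images, it is only meaningful at large sample counts, which is why 50K generated samples are used above.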