AISTATS 2026
This repository contains the official implementation of SFBD Flow, a framework for training diffusion models from corrupted data supplemented with a small fraction of clean samples — a setting motivated by data privacy constraints.
Key contributions:
- We reinterpret the existing SFBD (Stochastic Forward-Backward Deconvolution) approach as an alternating projection algorithm, providing a new theoretical lens.
- We introduce SFBD Flow, a continuous-time variant that eliminates manual coordination between denoising and fine-tuning steps.
- We develop Online SFBD, a practical instantiation of SFBD Flow that outperforms strong baselines across CIFAR-10 and CelebA benchmarks.
- We establish theoretical connections between SFBD Flow and consistency-constraint-based methods.
Training proceeds in two stages:
The model is first trained on the small clean dataset only. This stage runs for `involve_denoise_after_niter` ticks. It provides a stable starting point before the noisy data is introduced.
After pre-training, the model is fine-tuned iteratively. At each update step the denoised dataset is refreshed using the latest model checkpoint, and the model is re-trained on the union of clean and denoised images.
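As a toy sketch of this two-stage schedule (the helpers `train_tick` and `denoise` are illustrative stand-ins, not the repo's API):

```python
def denoise(model, noisy):
    # Toy stand-in: "denoising" just tags each noisy sample with the
    # model version that produced it.
    return [(x, model) for x in noisy]

def train_tick(model, data):
    # Toy stand-in: one tick of training bumps the model version.
    return model + 1

def online_sfbd(clean, noisy, pretrain_ticks, total_ticks, refresh_every):
    """Stage 1: train on clean data only; Stage 2: periodically refresh the
    denoised set with the latest model, then train on the union."""
    model, denoised = 0, []
    for tick in range(total_ticks):
        if tick >= pretrain_ticks and (tick - pretrain_ticks) % refresh_every == 0:
            denoised = denoise(model, noisy)  # refresh with latest checkpoint
        data = clean if tick < pretrain_ticks else clean + denoised
        model = train_tick(model, data)
    return model, denoised

model, denoised = online_sfbd(clean=["c0"], noisy=["n0", "n1"],
                              pretrain_ticks=3, total_ticks=10, refresh_every=2)
```

Here the denoised set is rebuilt at ticks 3, 5, 7, and 9, so the final `denoised` carries the tick-9 model version.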
We support two update-size strategies:
| Mode | Flag | Behaviour |
|---|---|---|
| Fixed | `--adaptive_update=False` | Updates exactly `num_samples_to_update` images per step (set in the sampling config). |
| Adaptive (KID-guided) | `--adaptive_update=True` (default) | Uses the KID difference between the newly denoised images and the current denoised dataset to decide how many samples to commit. More improvement → more samples updated. |
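The adaptive rule can be sketched as follows; this is a simplified controller written for illustration only (the floor `min_frac` and the exact scaling are assumptions, not the repo's formula):

```python
def adaptive_update_size(kid_old, kid_new, budget, min_frac=0.1):
    # Larger KID improvement -> commit a larger share of the budget;
    # if KID got worse, fall back to the floor fraction.
    improvement = max(kid_old - kid_new, 0.0)
    frac = min(1.0, min_frac + improvement / max(kid_old, 1e-8))
    return round(frac * budget)

adaptive_update_size(0.10, 0.05, budget=2048)  # KID halved: commit most of the budget
adaptive_update_size(0.05, 0.10, budget=2048)  # KID got worse: commit only the floor
```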
The weight of denoised images in training increases linearly from 0 to 1 between `involve_denoise_after_niter` and `fully_involve_denoise_after_niter`.
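The ramp itself is simple; a minimal sketch (function name hypothetical):

```python
def denoised_weight(tick, start, end):
    """Linear weight of denoised images in the training mix: 0 before
    `start`, 1 after `end`. `start` and `end` correspond to
    involve_denoise_after_niter and fully_involve_denoise_after_niter."""
    if tick <= start:
        return 0.0
    if tick >= end:
        return 1.0
    return (tick - start) / (end - start)

denoised_weight(350, start=100, end=600)  # halfway through the ramp -> 0.5
```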
- Python 3.9
- CUDA 12.1 (tested; other CUDA versions may work)
- 1–4 GPUs (examples below use 4)
```shell
pip install torch==2.2.2 torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt
```

The framework expects two separate dataset directories:
| Argument | Description |
|---|---|
| `--data_clean` | Directory of clean images (the small trusted set, e.g. 1% or 10% of the full dataset) |
| `--data_noisy` | Directory of corrupted images (the larger private/noisy set) |
Both directories should contain images as individual .png or .JPEG files. We follow the same folder structure as EDM.
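For illustration, a flat per-image layout along these lines is what the two flags expect (directory and file names here are hypothetical; see the EDM repository for the exact conventions):

```
cifar10_clean_1pct/        # --data_clean
    000001.png
    000002.png
    ...
cifar10_train/             # --data_noisy
    000001.png
    000002.png
    ...
```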
This is the standard Online SFBD run. The denoised dataset is refreshed every `update_denoise_freq` ticks, and a fixed number of images, `num_samples_to_update`, is re-generated each time (controlled in `sampling_configs/cifar10.yaml`).
```shell
torchrun --standalone --nproc_per_node=4 train.py \
    --outdir=./experiments/cifar10_fixed \
    --data_clean=/path/to/cifar10_clean_1pct \
    --data_noisy=/path/to/cifar10_train \
    --sigma=0.59 \
    --corruption_probability=0.1 \
    --dataset_keep_percentage=1.0 \
    --consistency_coeff=0.0 \
    --enable_online_sfbd=True \
    --adaptive_update=False \
    --update_denoise_freq=20 \
    --involve_denoise_after_niter=100 \
    --fully_involve_denoise_after_niter=600 \
    --fid_ref=/path/to/cifar10_train \
    --tick=20 \
    --sample_config=sampling_configs/cifar10.yaml \
    --expr_id=sfbd_cifar10_fixed
```

The adaptive variant measures the KID gap between newly generated samples and the current denoised dataset, and uses it to scale the number of samples committed per step. Use `sampling_configs/cifar10c.yaml`, which sets a larger `num_samples_to_update` budget for the adaptive controller.
```shell
torchrun --standalone --nproc_per_node=4 train.py \
    --outdir=./experiments/cifar10_adaptive \
    --data_clean=/path/to/cifar10_clean_1pct \
    --data_noisy=/path/to/cifar10_train \
    --sigma=0.59 \
    --corruption_probability=0.1 \
    --dataset_keep_percentage=1.0 \
    --consistency_coeff=0.0 \
    --enable_online_sfbd=True \
    --adaptive_update=True \
    --update_denoise_freq=20 \
    --involve_denoise_after_niter=200 \
    --fully_involve_denoise_after_niter=201 \
    --fid_ref=/path/to/cifar10_train \
    --tick=20 \
    --sample_config=sampling_configs/cifar10c.yaml \
    --expr_id=sfbd_cifar10_adaptive
```

The CelebA run follows the same fixed-update recipe with a higher noise level (`--sigma=1.38`):

```shell
torchrun --standalone --nproc_per_node=4 train.py \
    --outdir=./experiments/celeba_fixed \
    --data_clean=/path/to/celeba_clean_1pct \
    --data_noisy=/path/to/celeba_train \
    --sigma=1.38 \
    --corruption_probability=0.1 \
    --dataset_keep_percentage=1.0 \
    --consistency_coeff=0.0 \
    --enable_online_sfbd=True \
    --adaptive_update=False \
    --update_denoise_freq=20 \
    --involve_denoise_after_niter=100 \
    --fully_involve_denoise_after_niter=600 \
    --fid_ref=/path/to/celeba_train \
    --tick=20 \
    --sample_config=sampling_configs/celeba.yaml \
    --expr_id=sfbd_celeba_fixed
```

You can skip Stage 1 by loading a pre-trained checkpoint with `--transfer`:
```shell
torchrun --standalone --nproc_per_node=4 train.py \
    --outdir=./experiments/cifar10_finetune \
    --data_clean=/path/to/cifar10_clean_1pct \
    --data_noisy=/path/to/cifar10_train \
    --sigma=0.59 \
    --corruption_probability=0.1 \
    --dataset_keep_percentage=1.0 \
    --consistency_coeff=0.0 \
    --enable_online_sfbd=True \
    --adaptive_update=False \
    --update_denoise_freq=20 \
    --involve_denoise_after_niter=1 \
    --fully_involve_denoise_after_niter=100 \
    --transfer=/path/to/pretrained_network.pkl \
    --fid_ref=/path/to/cifar10_train \
    --tick=20 \
    --sample_config=sampling_configs/cifar10.yaml \
    --expr_id=sfbd_cifar10_finetune
```

| Argument | Default | Description |
|---|---|---|
| `--data_clean` | required | Path to the clean (trusted) dataset directory |
| `--data_noisy` | required | Path to the corrupted dataset directory |
| `--sigma` | `0.0` | Noise level added to corrupted images |
| `--corruption_probability` | `0.0` | Fraction of pixels corrupted per image |
| `--dataset_keep_percentage` | `1.0` | Fraction of the noisy dataset to use |
| `--enable_online_sfbd` | `False` | Enable Online SFBD iterative fine-tuning |
| `--adaptive_update` | `True` | Use KID-guided adaptive update size; set `False` for fixed |
| `--update_denoise_freq` | — | Re-denoise every N ticks (required when `enable_online_sfbd=True`) |
| `--involve_denoise_after_niter` | `-1` | Tick at which Stage 1 (pre-training) ends and Stage 2 begins |
| `--fully_involve_denoise_after_niter` | `500` | Tick at which denoised images reach full weight in the training mix |
| `--sample_config` | required | Path to YAML file controlling the denoiser sampler (see `sampling_configs/`) |
| `--transfer` | — | Load weights from a pre-trained network pickle to skip Stage 1 |
| `--outdir` | required | Directory to save checkpoints and logs |
| `--expr_id` | `test` | W&B project/run name |
Sampling configs (in `sampling_configs/`) control the denoising process used to refresh the denoised dataset during fine-tuning.
| File | Dataset | Steps | `num_samples_to_update` |
|---|---|---|---|
| `cifar10.yaml` | CIFAR-10 (fixed mode) | 18 | 1024 |
| `cifar10c.yaml` | CIFAR-10 (adaptive mode) | 21 | 2048 |
| `celeba.yaml` | CelebA | 40 | 2048 |
The key fields are:
- `num_samples_to_update`: budget of images considered for re-generation per update step (fixed mode uses this exactly; adaptive mode uses it as a scaling reference).
- `init_samples_per_task`: images dispatched per sub-task.
- `rho`: EMA decay for the running KID estimate (adaptive mode only).
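For illustration, a fixed-mode CIFAR-10 config in this shape might look as follows. The `num_samples_to_update` value and step count come from the table above; the key for the step count and the values of `init_samples_per_task` and `rho` are assumptions, so check the shipped `sampling_configs/cifar10.yaml` for the real fields:

```yaml
# Illustrative sketch of a sampling config, not the shipped file.
num_steps: 18                # denoiser sampler steps (key name assumed)
num_samples_to_update: 1024  # per-step re-generation budget (fixed mode uses exactly this)
init_samples_per_task: 256   # images dispatched per sub-task (value illustrative)
rho: 0.9                     # EMA decay for the running KID estimate (adaptive mode only)
```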
Checkpoints are saved to `--outdir` every `--ckpt_ticks` ticks. To resume a run:

```shell
torchrun --standalone --nproc_per_node=4 train.py \
    --outdir=./experiments/cifar10_fixed \
    ... \
    --resume=./experiments/cifar10_fixed/training-state-last.pt
```

Training metrics (loss, FID, KID, denoised image weights) are logged to Weights & Biases. Set `TESTMODE=1` to disable W&B logging during local debugging:

```shell
TESTMODE=1 torchrun --standalone --nproc_per_node=1 train.py ...
```

If you find this work useful, please cite:
```bibtex
@inproceedings{lu2026sfbdflow,
  title     = {SFBD Flow: A Continuous-Optimization Framework for Training Diffusion Models with Noisy Samples},
  author    = {Lu, Haoye and Lo, Darren and Yu, Yaoliang},
  booktitle = {Proceedings of the 29th International Conference on Artificial Intelligence and Statistics (AISTATS)},
  year      = {2026},
  url       = {https://arxiv.org/abs/2506.02371}
}
```

This codebase builds on the excellent EDM framework by Karras et al. (NeurIPS 2022) and the Ambient Diffusion codebase by Giannis Daras et al.
This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.