watml/SFBD-flow


SFBD Flow: A Continuous-Optimization Framework for Training Diffusion Models with Noisy Samples

AISTATS 2026



Overview

This repository contains the official implementation of SFBD Flow, a framework for training diffusion models from corrupted data supplemented with a small fraction of clean samples — a setting motivated by data privacy constraints.

Key contributions:

  • We reinterpret the existing SFBD (Stochastic Forward-Backward Deconvolution) approach as an alternating projection algorithm, providing a new theoretical lens.
  • We introduce SFBD Flow, a continuous-time variant that eliminates manual coordination between denoising and fine-tuning steps.
  • We develop Online SFBD, a practical instantiation of SFBD Flow that outperforms strong baselines across CIFAR-10 and CelebA benchmarks.
  • We establish theoretical connections between SFBD Flow and consistency-constraint-based methods.

Method

Training proceeds in two stages:

Stage 1 — Pre-training

The model is first trained on the small clean dataset only. This stage runs for involve_denoise_after_niter ticks. It provides a stable starting point before the noisy data is introduced.

Stage 2 — Iterative Fine-tuning (Online SFBD)

After pre-training, the model is fine-tuned iteratively. At each update step the denoised dataset is refreshed using the latest model checkpoint, and the model is re-trained on the union of clean and denoised images.
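In outline, one fine-tuning round can be sketched as below. The helpers here are toy scalar stand-ins for illustration only, not the repository's API:

```python
# Toy sketch of one Online SFBD fine-tuning round.
# `denoise` and `train_step` are illustrative stand-ins, not the repo's API.

def denoise(model, x):
    # Stand-in for sampling a clean image with the current checkpoint.
    return x * 0.5

def train_step(model, x, loss_weight):
    # Stand-in for one weighted gradient update; here the "model" is a scalar.
    return model + loss_weight * x

def online_sfbd_round(model, clean_data, noisy_data, weight):
    # 1. Refresh the denoised pool using the latest checkpoint.
    denoised = [denoise(model, x) for x in noisy_data]
    # 2. Train on the union of clean and denoised data; denoised images
    #    enter with weight `weight` (ramped from 0 to 1 during training).
    for x in clean_data:
        model = train_step(model, x, 1.0)
    for x in denoised:
        model = train_step(model, x, weight)
    return model, denoised
```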

We support two update-size strategies:

| Mode | Flag | Behaviour |
| --- | --- | --- |
| Fixed | `--adaptive_update=False` | Updates exactly `num_samples_to_update` images per step (set in the sampling config). |
| Adaptive (KID-guided) | `--adaptive_update=True` (default) | Uses the KID difference between the newly denoised images and the current denoised dataset to decide how many samples to commit; more improvement means more samples updated. |

The weight of denoised images in training increases linearly from 0 to 1 between involve_denoise_after_niter and fully_involve_denoise_after_niter.
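The ramp is plain linear interpolation in ticks; a minimal helper (the function name is ours, not a flag in the repo) would be:

```python
def denoised_weight(tick, start, end):
    """Weight of denoised images at a given tick: 0 before `start`
    (involve_denoise_after_niter), 1 after `end`
    (fully_involve_denoise_after_niter), linear in between."""
    if tick <= start:
        return 0.0
    if tick >= end:
        return 1.0
    return (tick - start) / (end - start)
```

With the fixed-mode settings of start 100 and end 600, for example, the weight reaches 0.5 at tick 350.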


Installation

Requirements

  • Python 3.9
  • CUDA 12.1 (tested; other CUDA versions may work)
  • 1–4 GPUs (examples below use 4)

Step 1 — Install PyTorch

pip install torch==2.2.2 torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cu121

Step 2 — Install remaining dependencies

pip install -r requirements.txt

Data Preparation

The framework expects two separate dataset directories:

| Argument | Description |
| --- | --- |
| `--data_clean` | Directory of clean images (the small trusted set, e.g. 1% or 10% of the full dataset) |
| `--data_noisy` | Directory of corrupted images (the larger private/noisy set) |

Both directories should contain images as individual .png or .JPEG files. We follow the same folder structure as EDM.
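Before launching a run, it can help to sanity-check both directories. The helper below is ours (not part of the repo) and simply counts image files recursively:

```python
import os

IMAGE_EXTS = {".png", ".jpeg", ".jpg"}

def count_images(root):
    """Count .png/.JPEG files under `root`, recursing into subfolders
    (an EDM-style layout may nest images in per-class directories)."""
    n = 0
    for _, _, filenames in os.walk(root):
        n += sum(1 for f in filenames
                 if os.path.splitext(f)[1].lower() in IMAGE_EXTS)
    return n
```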


Quick Start

CIFAR-10 — Fixed update size

This is the standard Online SFBD run. The denoised dataset is refreshed every update_denoise_freq ticks, and a fixed number of images (num_samples_to_update, set in sampling_configs/cifar10.yaml) is re-generated each time.

torchrun --standalone --nproc_per_node=4 train.py \
  --outdir=./experiments/cifar10_fixed \
  --data_clean=/path/to/cifar10_clean_1pct \
  --data_noisy=/path/to/cifar10_train \
  --sigma=0.59 \
  --corruption_probability=0.1 \
  --dataset_keep_percentage=1.0 \
  --consistency_coeff=0.0 \
  --enable_online_sfbd=True \
  --adaptive_update=False \
  --update_denoise_freq=20 \
  --involve_denoise_after_niter=100 \
  --fully_involve_denoise_after_niter=600 \
  --fid_ref=/path/to/cifar10_train \
  --tick=20 \
  --sample_config=sampling_configs/cifar10.yaml \
  --expr_id=sfbd_cifar10_fixed

CIFAR-10 — KID-adaptive update size

The adaptive variant measures the KID gap between newly generated samples and the current denoised dataset, and uses it to scale the number of samples committed per step. Use sampling_configs/cifar10c.yaml, which sets a larger num_samples_to_update budget for the adaptive controller.
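Schematically, the controller maps KID improvement to a fraction of the sample budget. The exact rule in the code may differ, so treat the following as an assumption-laden sketch (all names and the scaling formula are ours):

```python
def adaptive_update_size(kid_old, kid_new, budget, min_frac=0.1):
    """Commit more of `budget` when KID improves more.
    Schematic only; the repository's actual rule may differ."""
    improvement = max(kid_old - kid_new, 0.0)  # positive KID drop = progress
    frac = min(1.0, min_frac + improvement / max(kid_old, 1e-8))
    return int(frac * budget)
```

Under this sketch, a 50% relative KID improvement commits 60% of the budget, while no improvement falls back to the minimum fraction.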

torchrun --standalone --nproc_per_node=4 train.py \
  --outdir=./experiments/cifar10_adaptive \
  --data_clean=/path/to/cifar10_clean_1pct \
  --data_noisy=/path/to/cifar10_train \
  --sigma=0.59 \
  --corruption_probability=0.1 \
  --dataset_keep_percentage=1.0 \
  --consistency_coeff=0.0 \
  --enable_online_sfbd=True \
  --adaptive_update=True \
  --update_denoise_freq=20 \
  --involve_denoise_after_niter=200 \
  --fully_involve_denoise_after_niter=201 \
  --fid_ref=/path/to/cifar10_train \
  --tick=20 \
  --sample_config=sampling_configs/cifar10c.yaml \
  --expr_id=sfbd_cifar10_adaptive

CelebA — Fixed update size

torchrun --standalone --nproc_per_node=4 train.py \
  --outdir=./experiments/celeba_fixed \
  --data_clean=/path/to/celeba_clean_1pct \
  --data_noisy=/path/to/celeba_train \
  --sigma=1.38 \
  --corruption_probability=0.1 \
  --dataset_keep_percentage=1.0 \
  --consistency_coeff=0.0 \
  --enable_online_sfbd=True \
  --adaptive_update=False \
  --update_denoise_freq=20 \
  --involve_denoise_after_niter=100 \
  --fully_involve_denoise_after_niter=600 \
  --fid_ref=/path/to/celeba_train \
  --tick=20 \
  --sample_config=sampling_configs/celeba.yaml \
  --expr_id=sfbd_celeba_fixed

Fine-tuning from a pre-trained checkpoint

You can skip Stage 1 by loading a pre-trained checkpoint with --transfer:

torchrun --standalone --nproc_per_node=4 train.py \
  --outdir=./experiments/cifar10_finetune \
  --data_clean=/path/to/cifar10_clean_1pct \
  --data_noisy=/path/to/cifar10_train \
  --sigma=0.59 \
  --corruption_probability=0.1 \
  --dataset_keep_percentage=1.0 \
  --consistency_coeff=0.0 \
  --enable_online_sfbd=True \
  --adaptive_update=False \
  --update_denoise_freq=20 \
  --involve_denoise_after_niter=1 \
  --fully_involve_denoise_after_niter=100 \
  --transfer=/path/to/pretrained_network.pkl \
  --fid_ref=/path/to/cifar10_train \
  --tick=20 \
  --sample_config=sampling_configs/cifar10.yaml \
  --expr_id=sfbd_cifar10_finetune

Key Arguments

| Argument | Default | Description |
| --- | --- | --- |
| `--data_clean` | required | Path to the clean (trusted) dataset directory |
| `--data_noisy` | required | Path to the corrupted dataset directory |
| `--sigma` | 0.0 | Noise level added to corrupted images |
| `--corruption_probability` | 0.0 | Fraction of pixels corrupted per image |
| `--dataset_keep_percentage` | 1.0 | Fraction of the noisy dataset to use |
| `--enable_online_sfbd` | False | Enable Online SFBD iterative fine-tuning |
| `--adaptive_update` | True | Use KID-guided adaptive update size; set False for fixed |
| `--update_denoise_freq` | (none) | Re-denoise every N ticks (required when `enable_online_sfbd=True`) |
| `--involve_denoise_after_niter` | -1 | Tick at which Stage 1 (pre-training) ends and Stage 2 begins |
| `--fully_involve_denoise_after_niter` | 500 | Tick at which denoised images reach full weight in the training mix |
| `--sample_config` | required | Path to YAML file controlling the denoiser sampler (see `sampling_configs/`) |
| `--transfer` | (none) | Load weights from a pre-trained network pickle to skip Stage 1 |
| `--outdir` | required | Directory to save checkpoints and logs |
| `--expr_id` | test | W&B project/run name |

Sampling Configs

Sampling configs (in sampling_configs/) control the denoising process used to refresh the denoised dataset during fine-tuning.

| File | Dataset | Steps | `num_samples_to_update` |
| --- | --- | --- | --- |
| `cifar10.yaml` | CIFAR-10 (fixed mode) | 18 | 1024 |
| `cifar10c.yaml` | CIFAR-10 (adaptive mode) | 21 | 2048 |
| `celeba.yaml` | CelebA | 40 | 2048 |

The key fields are:

  • num_samples_to_update: budget of images considered for re-generation per update step (fixed mode uses this exactly; adaptive mode uses it as a scaling reference).
  • init_samples_per_task: images dispatched per sub-task.
  • rho: EMA decay for the running KID estimate (adaptive mode only).
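Putting the fields together, a config in this shape might look like the sketch below. The steps and num_samples_to_update values follow the table above, while the init_samples_per_task and rho values are illustrative assumptions, not the file's actual contents:

```yaml
# Hypothetical sketch of sampling_configs/cifar10c.yaml (adaptive mode).
num_samples_to_update: 2048   # per-step budget (from the table above)
init_samples_per_task: 256    # assumed value: images per dispatched sub-task
steps: 21                     # denoiser sampling steps (from the table above)
rho: 0.9                      # assumed value: EMA decay for the running KID
```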

Checkpointing and Resuming

Checkpoints are saved to --outdir every --ckpt_ticks ticks. To resume a run:

torchrun --standalone --nproc_per_node=4 train.py \
  --outdir=./experiments/cifar10_fixed \
  ... \
  --resume=./experiments/cifar10_fixed/training-state-last.pt

Monitoring

Training metrics (loss, FID, KID, denoised image weights) are logged to Weights & Biases. Set TESTMODE=1 to disable W&B logging during local debugging:

TESTMODE=1 torchrun --standalone --nproc_per_node=1 train.py ...

Citation

If you find this work useful, please cite:

@inproceedings{lu2026sfbdflow,
  title     = {SFBD Flow: A Continuous-Optimization Framework for Training Diffusion Models with Noisy Samples},
  author    = {Lu, Haoye and Lo, Darren and Yu, Yaoliang},
  booktitle = {Proceedings of the 29th International Conference on Artificial Intelligence and Statistics (AISTATS)},
  year      = {2026},
  url       = {https://arxiv.org/abs/2506.02371}
}

Acknowledgements

This codebase builds on the excellent EDM framework by Karras et al. (NeurIPS 2022) and the Ambient Diffusion codebase by Giannis Daras et al.


License

This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
