This repository contains the official implementation for the paper:
S. Ahamed, E. Haber, DAWN-FM: Data-Aware and Noise-Informed Flow Matching for Solving Inverse Problems, Foundations of Data Science (AIMS), Jan 2026, doi: 10.3934/fods.2026005.
This repository provides clean, consolidated implementations of:
- DAW-FM: Data-Aware Flow Matching (with data embedding only)
- DAWN-FM: Data-Aware Noise-Informed Flow Matching (with data and noise embedding)
Both methods are applied to two inverse problems:
- Image Deblurring: on MNIST, CIFAR10, and STL10 datasets
- Tomography Reconstruction: on OrganCMNIST, OrganAMNIST, and OrganSMNIST datasets
Inverse problems, which involve estimating parameters from incomplete or noisy observations, arise in various fields such as medical imaging, geophysics, and signal processing. These problems are often ill-posed, requiring regularization techniques to stabilize the solution. In this work, we employ Flow Matching (FM), a generative framework that integrates a deterministic process to map a simple reference distribution, such as a Gaussian, to the target distribution. Our method, DAWN-FM (Data-AWare and Noise-Informed Flow Matching), incorporates data and noise embedding, allowing the model to access representations of the measured data explicitly and to account for noise in the observations, making it particularly robust in scenarios where data is noisy or incomplete. By learning a time-dependent velocity field, FM not only provides accurate solutions but also enables uncertainty quantification by generating multiple plausible outcomes. Unlike pretrained diffusion models, which may struggle in highly ill-posed settings, our approach is trained specifically for each inverse problem and adapts to varying noise levels. We validate the effectiveness and robustness of our method through extensive numerical experiments on tasks such as image deblurring and tomography.
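As a rough illustration of how a velocity field is integrated at inference time, the sketch below runs fixed-step forward Euler from Gaussian starting points. The function `toy_velocity` is a stand-in assumption (a closed-form field that transports every sample to one fixed point), not the repository's U-Net, and `euler_sample` and its arguments are hypothetical names chosen for this example.

```python
import numpy as np

def toy_velocity(x, t, target):
    # Stand-in for a trained network (assumption): closed-form velocity of the
    # straight path x_t = (1 - t) * x0 + t * target, expressed through x and t.
    return (target - x) / (1.0 - t)

def euler_sample(velocity, x0, target, nsteps=100):
    # Integrate dx/dt = v(x, t) from t = 0 to t = 1 with forward Euler.
    x = x0.copy()
    h = 1.0 / nsteps
    for k in range(nsteps):
        t = k * h  # t stays strictly below 1, so the division above is safe
        x = x + h * velocity(x, t, target)
    return x

rng = np.random.default_rng(0)
target = np.array([1.0, -2.0])
# Several Gaussian starting points play the role of multiple runs.
x0 = rng.standard_normal((8, 2))
samples = euler_sample(toy_velocity, x0, target, nsteps=100)
```

With a trained model, the spread of such samples across runs is what provides an uncertainty estimate. Here every sample ends at `target` because the toy field has a single endpoint.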
Figure 1: Schematics for standard flow matching (FM) (left) and flow matching for solving inverse problem characterized by the forward problem
- ✅ Unified training scripts for deblurring and tomography tasks
- ✅ Unified inference scripts with comprehensive metric computation
- ✅ Support for both DAW-FM and DAWN-FM training modes
- ✅ Command-line interface for easy experimentation
- ✅ Efficient FFT-based blur operators and matrix-based 2D tomography operators
- ✅ Multi-GPU inference support for faster processing
- ✅ Comprehensive evaluation metrics (PSNR, SSIM, MSE, data misfit)
- ✅ Support for MNIST, CIFAR10, STL10, and MedMNIST datasets
- Python 3.8+
- PyTorch 2.0+
- CUDA 12.8
# Clone the repository
git clone https://github.com/yourusername/DAWN-FM.git
cd DAWN-FM
# Install dependencies
pip install -r requirements.txt

By default, all scripts use ./data as the data directory. Datasets will be automatically downloaded on first use.
You can organize your data directory as follows:
DAWN-FM/
├── data/                      # Default data directory
│   ├── MNIST/                 # Auto-downloaded MNIST dataset
│   ├── CIFAR10/               # Auto-downloaded CIFAR10 dataset
│   ├── STL10/                 # Auto-downloaded STL10 dataset
│   └── medmnist/              # Auto-downloaded MedMNIST datasets
│       ├── organcmnist.npz
│       ├── organamnist.npz
│       └── organsmnist.npz
├── train_deblurring.py
└── ...
Note: To use a different data directory, specify --data_dir /your/custom/path when running scripts.
DAWN-FM/
├── train_deblurring.py        # Training script for deblurring
├── inference_deblurring.py    # Inference script for deblurring
├── train_tomography.py        # Training script for tomography
├── inference_tomography.py    # Inference script for tomography
├── dawnfm/                    # Core package
│   ├── models.py              # Network architectures (UNetFMG_DE, UNetFMG_DE_NE)
│   ├── forward_problems.py    # Forward operators (blur, tomography)
│   ├── load_datasets.py       # Dataset loading utilities
│   ├── config.py              # Configuration utilities
│   └── utils.py               # Utility functions
├── requirements.txt           # Python dependencies
├── README.md                  # This file
└── LICENSE                    # License file
Train on MNIST:
python train_deblurring.py \
--dataset mnist \
--batch_size 512 \
--max_epochs 1000 \
--save_dir ./experiments

Train on CIFAR10:
python train_deblurring.py \
--dataset cifar10 \
--batch_size 512 \
--max_epochs 1000 \
--save_dir ./experiments

Train on MNIST with noise embedding:
python train_deblurring.py \
--dataset mnist \
--use_noise_embed \
--noise_range 0.0 0.1 \
--batch_size 512 \
--max_epochs 1000 \
--save_dir ./experiments

Train on CIFAR10 with noise embedding:
python train_deblurring.py \
--dataset cifar10 \
--use_noise_embed \
--noise_range 0.0 0.1 \
--batch_size 512 \
--max_epochs 1000 \
--save_dir ./experiments

Train on OrganCMNIST:
python train_tomography.py \
--dataset organcmnist \
--img_size 64 \
--batch_size 128 \
--max_epochs 1000 \
--num_angles 180 \
--save_dir ./experiments

Train on OrganAMNIST:
python train_tomography.py \
--dataset organamnist \
--img_size 64 \
--batch_size 128 \
--max_epochs 1000 \
--num_angles 180 \
--save_dir ./experiments

Train on OrganCMNIST with noise embedding:
python train_tomography.py \
--dataset organcmnist \
--use_noise_embed \
--img_size 64 \
--noise_range 0.0 0.1 \
--batch_size 128 \
--max_epochs 1000 \
--num_angles 180 \
--save_dir ./experiments

Run inference on test set:
python inference_deblurring.py \
--dataset mnist \
--model_path ./experiments/models/mnist_daw-fm_arch-1x16x32/model_ep=1000.pth \
--noise_level 0.0 \
--nsteps 100 \
--num_runs 32 \
--save_dir ./results

For DAWN-FM models, add the --use_noise_embed flag:
python inference_deblurring.py \
--dataset mnist \
--use_noise_embed \
--model_path ./experiments/models/mnist_dawn-fm_arch-1x16x32/model_ep=1000.pth \
--noise_level 0.05 \
--nsteps 100 \
--num_runs 32 \
--save_dir ./results

For faster inference, distribute the dataset across multiple GPUs:
# Use GPUs 0, 1, and 2 for parallel inference
python inference_deblurring.py \
--dataset cifar10 \
--model_path ./experiments/models/cifar10_daw-fm_arch-3x16x32/model_ep=1000.pth \
--gpus 0 1 2 \
--noise_level 0.0 \
--nsteps 100 \
--num_runs 32 \
--save_dir ./results

The --gpus flag automatically splits the test dataset across the specified GPUs and runs inference in parallel. Each GPU processes its assigned subset independently, which reduces total inference time.
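To make the dataset splitting concrete, here is a minimal sketch of one way test indices could be distributed over GPU IDs. The round-robin assignment and the helper name `shard_indices` are assumptions for illustration; the inference scripts may split the data differently (for example, into contiguous chunks).

```python
def shard_indices(num_items, gpu_ids):
    # Hypothetical helper: assign item i to gpu_ids[i % len(gpu_ids)] (round-robin).
    shards = {gpu: [] for gpu in gpu_ids}
    for i in range(num_items):
        shards[gpu_ids[i % len(gpu_ids)]].append(i)
    return shards

# Ten test images spread over GPUs 0, 1, and 2.
shards = shard_indices(10, [0, 1, 2])
```

Each shard can then be processed by an independent worker pinned to its GPU, since reconstructions of different test images do not depend on one another.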
Run inference on test set:
python inference_tomography.py \
--dataset organcmnist \
--img_size 64 \
--model_path ./experiments/models/tomography_organcmnist_daw-fm_arch-1x16x32x64x128/model_ep=1000.pth \
--num_angles 180 \
--noise_level 0.0 \
--nsteps 100 \
--num_runs 32 \
--save_dir ./results

For DAWN-FM models, add the --use_noise_embed flag:
python inference_tomography.py \
--dataset organcmnist \
--use_noise_embed \
--img_size 64 \
--model_path ./experiments/models/tomography_organcmnist_dawn-fm_arch-1x16x32x64x128/model_ep=1000.pth \
--num_angles 180 \
--noise_level 0.07 \
--nsteps 100 \
--num_runs 32 \
--save_dir ./results

Multi-GPU inference for tomography:
# Use GPUs 1 and 2 for parallel tomography inference
python inference_tomography.py \
--dataset organcmnist \
--use_noise_embed \
--img_size 64 \
--model_path ./experiments/models/tomography_organcmnist_dawn-fm_arch-1x16x32x64x128/model_ep=1000.pth \
--gpus 1 2 \
--num_angles 180 \
--noise_level 0.07 \
--nsteps 100 \
--num_runs 32 \
--save_dir ./results

Arguments for train_deblurring.py:
- --dataset: Dataset to use (mnist, cifar10, stl10)
- --data_dir: Directory containing datasets (default: ./data/)
- --use_noise_embed: Use noise embedding (DAWN-FM mode)
- --arch: Network architecture (auto-determined if not specified)
  - Default: [1, 16, 32] for MNIST, [3, 16, 32] for RGB
- --batch_size: Batch size (default: 512)
- --max_epochs: Maximum number of epochs (default: 1000)
- --lr: Learning rate (default: 1e-4)
- --lr_min: Minimum learning rate for scheduler (default: 0)
- --num_workers: Number of data loading workers (default: 4)
- --blur_sigma: Gaussian blur kernel sigma [sigma_x, sigma_y] (default: [3.0, 3.0])
- --noise_range: Noise level range [min, max] as fraction of data range (default: [0.0, 0.1])
- --interpolation_sigma: Sigma for stochastic interpolation (default: 0.01)
- --save_dir: Directory to save experiments (default: ./experiments)
- --save_every: Save checkpoint every N epochs (default: 25)
- --device: Device to use (auto-detected if not specified)

Arguments for train_tomography.py:
- --dataset: Dataset to use (organcmnist, organamnist, organsmnist)
- --data_dir: Directory containing datasets (default: ./data/)
- --img_size: Image size (default: 64)
- --use_noise_embed: Use noise embedding (DAWN-FM mode)
- --arch: Network architecture (default: [1, 16, 32, 64, 128])
- --batch_size: Batch size (default: 128)
- --max_epochs: Maximum number of epochs (default: 1000)
- --lr: Learning rate (default: 1e-4)
- --lr_min: Minimum learning rate for scheduler (default: 0)
- --num_workers: Number of data loading workers (default: 4)
- --num_angles: Number of projection angles for tomography (default: 180)
- --noise_range: Noise level range [min, max] as fraction of data range (default: [0.0, 0.1])
- --interpolation_sigma: Sigma for stochastic interpolation (default: 0.01)
- --save_dir: Directory to save experiments (default: ./experiments)
- --save_every: Save checkpoint every N epochs (default: 1)
- --device: Device to use (auto-detected if not specified)

Arguments for inference_deblurring.py:
- --dataset: Dataset to use (mnist, cifar10, stl10)
- --data_dir: Directory containing datasets (default: ./data/)
- --model_path: Path to trained model checkpoint (required)
- --use_noise_embed: Use noise embedding (must match training)
- --arch: Network architecture (auto-determined if not specified)
- --noise_level: Noise level as fraction of data range (default: 0.0)
- --nsteps: Number of ODE solver steps (default: 100)
- --num_runs: Number of reconstructions per image (default: 2)
- --batch_size: Batch size for inference (default: 1)
- --blur_sigma: Blur kernel sigma (default: [3.0, 3.0])
- --save_dir: Directory to save results (default: ./results)
- --device: Device to use (e.g., cuda:0, cpu). Ignored if --gpus is specified
- --gpus: GPU IDs for parallel inference (e.g., --gpus 0 1 2). Automatically splits dataset across GPUs
- --n_jobs: Number of parallel jobs for metrics (default: 8)

Arguments for inference_tomography.py:
- --dataset: Dataset to use (organcmnist, organamnist, organsmnist)
- --data_dir: Directory containing datasets (default: ./data/)
- --img_size: Image size (default: 28)
- --model_path: Path to trained model checkpoint (required)
- --use_noise_embed: Use noise embedding (must match training)
- --arch: Network architecture (default: [1, 16, 32, 64, 128])
- --noise_level: Noise level as fraction of data range (default: 0.0)
- --nsteps: Number of ODE solver steps (default: 100)
- --num_runs: Number of reconstructions per image (default: 32)
- --batch_size: Batch size for inference (default: 1)
- --num_angles: Number of projection angles (default: 360)
- --save_dir: Directory to save results (default: ./results)
- --device: Device to use (e.g., cuda:0, cpu). Ignored if --gpus is specified
- --gpus: GPU IDs for parallel inference (e.g., --gpus 1 2). Automatically splits dataset across GPUs
- --n_jobs: Number of parallel jobs for metrics (default: 8)
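
The noise arguments above are described as a fraction of the data range. One plausible reading, sketched below, scales standard Gaussian noise by `noise_level * (max - min)` of the clean measurement. This convention and the helper name `add_relative_noise` are assumptions for illustration; the scripts may normalize differently.

```python
import numpy as np

def add_relative_noise(data, noise_level, rng):
    # Assumed convention: noise std = noise_level * (max - min) of the clean data.
    data_range = data.max() - data.min()
    noise = rng.standard_normal(data.shape)
    return data + noise_level * data_range * noise

rng = np.random.default_rng(0)
clean = np.linspace(0.0, 2.0, 10000).reshape(100, 100)  # data range is 2.0
noisy = add_relative_noise(clean, 0.05, rng)             # noise std is about 0.1
unchanged = add_relative_noise(clean, 0.0, rng)          # zero level leaves data as-is
```

Under this reading, a training run with --noise_range 0.0 0.1 would draw a level between 0 and 0.1 for each sample, and inference with --noise_level 0.05 would sit in the middle of that range.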
experiments/
├── logs/
│   ├── deblurring_{dataset}_{mode}_arch-{arch}/
│   │   └── train_logs.csv     # Training loss logs
│   └── tomography_{dataset}_{mode}_arch-{arch}/
│       └── train_logs.csv
└── models/
    ├── deblurring_{dataset}_{mode}_arch-{arch}/
    │   ├── model_ep=0025.pth  # Checkpoints
    │   └── ...
    └── tomography_{dataset}_{mode}_arch-{arch}/
        ├── model_ep=0001.pth
        └── ...
results/
├── generated_images/
│   ├── deblurring_{dataset}_{mode}_{noise}/
│   │   ├── mnist_000000.npy   # Saved reconstructions
│   │   └── ...
│   └── tomography_{dataset}_{mode}_{noise}/
│       ├── organcmnist_test_000000.npy
│       └── ...
└── metrics/
    ├── deblurring_{dataset}_{mode}_{noise}/
    │   └── metrics.csv        # Evaluation metrics
    └── tomography_{dataset}_{mode}_{noise}/
        └── metrics.csv
For deblurring, each .npy file contains a stack of images with shape:
- [2 + num_runs, H, W] for grayscale (MNIST)
- [2 + num_runs, H, W, 3] for RGB (CIFAR10, STL10)
Stack order: [original, blurred_data, recon_1, recon_2, ..., recon_N]
For tomography, each .npy file contains a stack of images with shape:
- [1 + num_runs, H, W] for grayscale (all MedMNIST datasets)
Stack order: [original, recon_1, recon_2, ..., recon_N]
Note: Sinogram data is not saved due to different dimensions.
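The stack layouts above can be unpacked with plain NumPy slicing. The sketch below uses a synthetic array in place of a saved file (in practice the stack would come from `np.load` on one of the .npy files); `split_stack` is a hypothetical helper, not part of the repository.

```python
import numpy as np

def split_stack(stack, has_data):
    # has_data=True  (deblurring): [original, blurred_data, recon_1, ..., recon_N]
    # has_data=False (tomography): [original, recon_1, ..., recon_N]
    original = stack[0]
    data = stack[1] if has_data else None
    recons = stack[2:] if has_data else stack[1:]
    return original, data, recons

# Synthetic stand-in for a grayscale deblurring file with num_runs = 4.
num_runs, H, W = 4, 28, 28
stack = np.random.default_rng(0).random((2 + num_runs, H, W))

original, blurred, recons = split_stack(stack, has_data=True)
mean_recon = recons.mean(axis=0)  # point estimate across runs
std_recon = recons.std(axis=0)    # per-pixel spread, a simple uncertainty map
```

The per-pixel standard deviation across runs is one simple way to visualize where the reconstructions disagree.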
The inference scripts compute the following metrics:
- MSE: Mean Squared Error
- PSNR: Peak Signal-to-Noise Ratio (dB)
- SSIM: Structural Similarity Index
- MISFIT: Data consistency metric (relative norm of residual)
Metrics are computed for:
- Each individual reconstruction (2-32 runs depending on task)
- The mean reconstruction across all runs
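For reference, the simpler metrics can be written in a few lines of NumPy. This is a minimal sketch under stated assumptions: images are scaled to a data range of 1.0, the misfit is taken as ||A(x) - b|| / ||b||, and the function names are hypothetical. The inference scripts may use different normalizations, and SSIM is omitted here because it needs a windowed implementation.

```python
import numpy as np

def mse(x, y):
    # Mean squared error between two arrays of the same shape.
    return float(np.mean((x - y) ** 2))

def psnr(x, y, data_range=1.0):
    # Peak signal-to-noise ratio in dB; data_range is an assumed image scale.
    err = mse(x, y)
    return float("inf") if err == 0 else float(10.0 * np.log10(data_range ** 2 / err))

def relative_misfit(forward, x, data):
    # Assumed definition: ||A(x) - b|| / ||b|| for forward operator A and data b.
    return float(np.linalg.norm(forward(x) - data) / np.linalg.norm(data))

truth = np.ones((8, 8))
recon = np.full((8, 8), 1.1)
forward = lambda x: 2.0 * x  # toy linear operator standing in for blur or tomography
data = forward(truth)

err = mse(truth, recon)
quality = psnr(truth, recon)
misfit = relative_misfit(forward, recon, data)
```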
Figure 2: Schematic of the training process for the FM model for solving inverse problem, where the forward model is given by
- U-Net with data embedding
- Inputs: current image, time, adjoint of measurement data
- Deblurring: Adjoint is deconvolved/backprojected blurred image
- Tomography: Adjoint is backprojected sinogram
- U-Net with data and noise embedding
- Inputs: current image, time, adjoint of measurement data, noise level
- Same architecture as DAW-FM but with additional noise embedding layers
- Deblurring: Noise level encoded per-image based on blur + additive noise
- Tomography: Noise level encoded per-image based on sinogram noise
- FFT-based Gaussian blur convolution
- Efficient computation in Fourier domain
- Configurable blur kernel sigma
- Matrix-based Radon transform approximation
- Configurable number of projection angles (default: 180 for training, 360 for inference)
- Optional Hamming filter for filtered backprojection
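The FFT-based Gaussian blur listed above can be sketched as a circular convolution computed in the Fourier domain. The periodic boundary handling, the kernel normalization, the mapping of sigma_x and sigma_y to array axes, and the name `gaussian_blur_fft` are assumptions for this illustration and may not match `forward_problems.py`.

```python
import numpy as np

def gaussian_blur_fft(image, sigma_x=3.0, sigma_y=3.0):
    h, w = image.shape
    # Signed pixel offsets in FFT ordering (0, 1, ..., -2, -1), so the kernel
    # is centred at index (0, 0) and the blur does not shift the image.
    dy = np.fft.fftfreq(h) * h
    dx = np.fft.fftfreq(w) * w
    kernel = np.exp(-0.5 * (dy[:, None] / sigma_y) ** 2
                    - 0.5 * (dx[None, :] / sigma_x) ** 2)
    kernel /= kernel.sum()  # unit mass keeps the mean intensity unchanged
    # Circular convolution is a pointwise product in the Fourier domain.
    return np.real(np.fft.ifft2(np.fft.fft2(image) * np.fft.fft2(kernel)))

impulse = np.zeros((32, 32))
impulse[16, 16] = 1.0
blurred = gaussian_blur_fft(impulse)
```

Because the kernel is real and symmetric, this operator is self-adjoint, so applying it to the measured data is one way to form the adjoint embedding mentioned for deblurring.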
- MNIST: 28×28 grayscale handwritten digits
- CIFAR10: 32×32 RGB natural images
- STL10: 64×64 RGB natural images
- OrganCMNIST: 64×64 grayscale abdominal CT organ scans (coronal plane)
- OrganAMNIST: 64×64 grayscale abdominal CT organ scans (axial plane)
- OrganSMNIST: 64×64 grayscale abdominal CT organ scans (sagittal plane)
All datasets are automatically downloaded on first use.
If you use this code in your research, please cite:
@article{ahamed2026dawn,
title={Data-Aware and Noise-Informed Flow Matching for Solving Inverse Problems},
author={Ahamed, S. and Haber, E.},
journal={Foundations of Data Science},
year={2026},
doi={10.3934/fods.2026005}
}

This project is licensed under the MIT License. See the LICENSE file for details.

