cvsp-lab/Stable-cGAN

Enhancing Stability in Training Conditional Generative Adversarial Networks via Selective Data Matching

(Figure: distribution overview)

TODO

⬜️ Release pre-trained models for ImageNet 64x64 and 128x128, CIFAR-10, and CIFAR-100

⬜️ Release training code for Exact Selective-Matching

About Enhancing Stability in Training Conditional Generative Adversarial Networks via Selective Data Matching

Conditional generative adversarial networks (cGANs) have demonstrated remarkable success thanks to their class-wise controllability and superior quality on complex generation tasks. Typical cGANs solve the joint distribution matching problem by decomposing it into two easier sub-problems: marginal matching and conditional matching. In this paper, we propose a simple but effective training methodology, selective focusing learning, which lets the discriminator and generator learn easy samples of each class rapidly while maintaining diversity. Our key idea is to selectively apply conditional and joint matching to the data in each mini-batch. Specifically, we first select the samples with the highest scores when sorted by the conditional term of the discriminator outputs (for both real and generated samples). We then optimize the model using the selected samples with conditional matching only and the remaining samples with joint matching. From our toy experiments, we found that applying only conditional matching to certain samples works best because of the content-aware optimization of the discriminator. We conducted experiments on ImageNet (64 × 64 and 128 × 128), CIFAR-10, and CIFAR-100, as well as Mixture-of-Gaussians and noisy-label settings, to demonstrate that the proposed method substantially improves all indicators (by up to 35.18% in terms of FID) over 10 independent trials.
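Concretely, the per-batch selection can be sketched as follows (a minimal illustration with hypothetical names, not the repository's code; `out_c` and `out_u` stand for the conditional and unconditional terms of the discriminator output):

```python
import torch

def selective_matching_losses(out_c, out_u, k):
    """Partition a mini-batch by conditional score: the top-k samples get
    conditional matching only; the rest get joint matching (sketch)."""
    order = torch.argsort(out_c, descending=True)
    top, rest = order[:k], order[k:]
    cond_only = out_c[top]             # conditional term only
    joint = out_c[rest] + out_u[rest]  # conditional + unconditional terms
    return cond_only, joint

out_c = torch.tensor([0.9, -0.2, 0.5, 0.1])  # conditional scores
out_u = torch.tensor([0.3, 0.4, -0.1, 0.2])  # unconditional scores
cond_only, joint = selective_matching_losses(out_c, out_u, k=2)
print(cond_only)  # tensor([0.9000, 0.5000])
```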

Requirements

To install requirements:

pip install -r requirements.txt

Training BigGAN with Selective Focusing Learning on ImageNet

To train BigGAN models we build on the BigGAN-PyTorch and Instance Selection for GANs repositories, with minimal changes to the code. The main change is to the conditional term of the projection discriminator in BigGAN.py (L391-L402, L415-L447). In addition, the focusing-rate update is implemented in train.py (L66-L71, L146-L155, L185-L209).
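For reference, a projection discriminator's output decomposes into an unconditional term plus a conditional term given by the inner product of the class embedding and the feature vector. The toy module below illustrates that decomposition (all names are hypothetical; the actual change lives in BigGAN.py as noted above):

```python
import torch
import torch.nn as nn

class ProjectionHead(nn.Module):
    """Toy projection-discriminator head: out = out_u + out_c, where
    out_c = <embed(y), h> is the conditional term that SFL sorts on."""
    def __init__(self, feat_dim, n_classes):
        super().__init__()
        self.linear = nn.Linear(feat_dim, 1)            # unconditional term
        self.embed = nn.Embedding(n_classes, feat_dim)  # class embedding

    def forward(self, h, y):
        out_u = self.linear(h)                                # [B, 1]
        out_c = (self.embed(y) * h).sum(dim=1, keepdim=True)  # [B, 1]
        return out_u + out_c, out_c, out_u

head = ProjectionHead(feat_dim=16, n_classes=10)
h = torch.randn(4, 16)           # discriminator features
y = torch.randint(0, 10, (4,))   # class labels
out, out_c, out_u = head(h, y)
print(out.shape)  # torch.Size([4, 1])
```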

Preparing Data (Same as Instance Selection for GANs)

To train a BigGAN on ImageNet you will first need to construct an HDF5 dataset file for ImageNet (optional), compute Inception moments for calculating FID, and construct the image manifold for calculating Precision, Recall, Density, and Coverage. All can be done by modifying and running

bash scripts/utils/prepare_data_imagenet_[res].sh

where [res] is substituted with the desired resolution (options are 64, 128, or 256). These scripts will assume that ImageNet is in a folder called data in the instance_selection_for_gans directory. Replace this with the filepath to your copy of ImageNet.

64x64 ImageNet

To replicate our best 64x64 model, run bash scripts/launch_SAGAN_res64_ch32_bs128_dstep_1_rr40.sh. A single GPU with at least 12GB of memory should be sufficient to train this model. Training is expected to take about 2-3 days on a high-end GPU.

We added only two configuration options: Training_type and maximum_focusing_rate.

parser.add_argument(
  '--Training_type', type=str, default='without_SFL',
  choices=['without_SFL', 'SFL', 'SFL+'],
  help='Training type of SFL (default: %(default)s)')

parser.add_argument(
  '--maximum_focusing_rate', type=float, default=1,
  help='The percentage of maximum focusing rate (default: %(default)s)')
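For example, the two flags parse as follows (a standalone snippet mirroring the arguments above; interpreting the rate as a fraction of the mini-batch is our reading of the help string):

```python
import argparse

# Mirror of the two SFL arguments shown above.
parser = argparse.ArgumentParser()
parser.add_argument('--Training_type', type=str, default='without_SFL',
                    choices=['without_SFL', 'SFL', 'SFL+'])
parser.add_argument('--maximum_focusing_rate', type=float, default=1)

# e.g. approximate SFL+ with the focusing rate capped at 0.4
cfg = parser.parse_args(['--Training_type', 'SFL+',
                         '--maximum_focusing_rate', '0.4'])
print(cfg.Training_type, cfg.maximum_focusing_rate)  # SFL+ 0.4
```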

Pre-trained weights

SFL+ and SFL weights: to be released.

Results

Our models achieve the following performance on ImageNet 64x64:

| Model name  | IS ↑  | FID ↓ | P ↑  | R ↑  | D ↑  | C ↑  |
|-------------|-------|-------|------|------|------|------|
| SA-GAN      | 17.77 | 17.23 | 0.68 | 0.66 | 0.72 | 0.71 |
| Approx SFL  | 19.11 | 16.20 | 0.69 | 0.67 | 0.76 | 0.76 |
| Approx SFL+ | 21.50 | 14.20 | 0.72 | 0.68 | 0.84 | 0.80 |
| Exact SFL+  | 21.98 | 13.55 | 0.73 | 0.66 | 0.85 | 0.81 |

Applying Selective Focusing Learning to Your Own Dataset or Any cGAN Variant

Selective Focusing Learning can be applied to any class-labeled PyTorch dataset using the SFL and SFL_plus functions, which require only a few lines of code.

  def SFL(self, out_c, out_u, Focusing_rate):
    # Sort the mini-batch by the conditional term of D's output (descending).
    out_c, idx_c = torch.sort(out_c, dim=0, descending=True)
    out_u = out_u[idx_c[:, 0]]
    # Top Focusing_rate samples: conditional matching only;
    # the rest: joint matching (conditional + unconditional terms).
    out = torch.cat([out_c[Focusing_rate:] + out_u[Focusing_rate:], out_c[:Focusing_rate]], 0)
    return out

  def SFL_plus(self, out_c, out_u, Focusing_rate, scores):
    # Sort by an auxiliary per-sample score instead of the conditional term.
    _, idx_c = torch.sort(scores, dim=0)
    out_c = out_c[idx_c]
    out_u = out_u[idx_c]
    out = torch.cat([out_c[Focusing_rate:] + out_u[Focusing_rate:], out_c[:Focusing_rate]], 0)
    return out
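As a sanity check, the same logic can be exercised standalone (self dropped; out_c/out_u are synthetic [B, 1] discriminator outputs, and the focusing rate is passed as a sample count — both are our assumptions):

```python
import torch

# Standalone adaptation of SFL (the repository version is a method).
def sfl(out_c, out_u, focusing_rate):
    out_c, idx_c = torch.sort(out_c, dim=0, descending=True)
    out_u = out_u[idx_c[:, 0]]
    return torch.cat([out_c[focusing_rate:] + out_u[focusing_rate:],
                      out_c[:focusing_rate]], 0)

batch_size, k = 8, 3                # k: number of focused samples this step
out_c = torch.randn(batch_size, 1)  # conditional term of D's output
out_u = torch.randn(batch_size, 1)  # unconditional (marginal) term
out = sfl(out_c, out_u, k)
print(out.shape)  # torch.Size([8, 1])
```

The last k entries of the result are the top-k conditional scores passed through unchanged; the first entries carry the joint (conditional + unconditional) scores for the remaining samples.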

References

[1] Andrew Brock and Alex Andonian. "BigGAN-PyTorch". https://github.com/ajbrock/BigGAN-PyTorch

[2] Terrance DeVries, Michal Drozdzal, and Graham W. Taylor. "Instance Selection for GANs". https://github.com/uoguelph-mlrg/instance_selection_for_gans

License

Two license files are included: LICENSE (MIT) and LICENSE.md (unspecified).
