StyleGAN2

This is an implementation of "Analyzing and Improving the Image Quality of StyleGAN" and "Differentiable Augmentation for Data-Efficient GAN Training" in TensorFlow 2.3.

Style mixing examples

Check the ./results folder to see more images.

Training

Use main.py to train StyleGAN2 on a given dataset. On a GTX 1080 Ti with batch_size=4, 100 training steps take about 80 s with the CUDA op and about 110 s with the TensorFlow op.

Example usage for training on afhq-dataset:

python main.py train --dataset_name afhq                       \
                     --dataset_path ./path/to/afhq_dataset_dir \
                     --batch_size 4                            \
                     --res 512                                 \
                     --config e                                \
                     --impl ref
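
The timings quoted above can be turned into a rough throughput estimate. The sketch below is only an illustration of that arithmetic: the helper names and the `total_images` budget are hypothetical and are not part of main.py.

```python
def images_per_second(seconds: float, steps: int, batch_size: int) -> float:
    """Throughput implied by a measured time for a number of training steps."""
    return steps * batch_size / seconds


def hours_for_images(total_images: int, seconds: float, steps: int, batch_size: int) -> float:
    """Rough wall-clock hours needed to show `total_images` images to the model."""
    return total_images / images_per_second(seconds, steps, batch_size) / 3600.0


if __name__ == "__main__":
    # Timings from the README: 100 steps at batch_size=4 on a GTX 1080 Ti.
    cuda_rate = images_per_second(80.0, 100, 4)   # CUDA op
    tf_rate = images_per_second(110.0, 100, 4)    # TensorFlow op
    print(f"CUDA op:       {cuda_rate:.2f} images/s")
    print(f"TensorFlow op: {tf_rate:.2f} images/s")

    # Hypothetical budget of one million images, chosen only for illustration.
    total_images = 1_000_000
    print(f"CUDA op, {total_images} images: "
          f"{hours_for_images(total_images, 80.0, 100, 4):.1f} hours")
```

At these rates the CUDA op processes about 5 images per second and the TensorFlow op about 3.6, so the choice of op implementation matters for long runs.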

Inference

Generate image_example/transition_gif/style_mixing_example

Use main.py to run inference in a chosen mode with a given label. The inference mode must be one of: [example, gif, mixing]. The pre-trained ffhq/afhq weights are located here.

Example usage:

python main.py inference --ckpt ./weights-ffhq/official_1024x1024  \
                         --res 1024                                \
                         --config f                                \
                         --truncation_psi 0.5                      \
                         --mode example
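
For readers unfamiliar with `--truncation_psi` and the `mixing` mode, the following is a minimal sketch of the two ideas as they are usually described for StyleGAN: truncation pulls an intermediate latent toward the average latent, and style mixing feeds one latent to the early synthesis layers and another to the later ones. The function names, the plain-list representation, and the layer count of 18 (the usual number for a 1024x1024 generator) are assumptions for illustration; this is not the code used by main.py.

```python
from typing import List

Vector = List[float]


def truncate(w: Vector, w_avg: Vector, psi: float) -> Vector:
    """Truncation trick: w' = w_avg + psi * (w - w_avg).

    psi = 1.0 leaves w unchanged; psi = 0.0 collapses it to the average latent.
    """
    return [avg + psi * (x - avg) for x, avg in zip(w, w_avg)]


def mix_styles(w_a: Vector, w_b: Vector, num_layers: int, crossover: int) -> List[Vector]:
    """Build per-layer latents: w_a before `crossover`, w_b from `crossover` on."""
    if not 0 <= crossover <= num_layers:
        raise ValueError("crossover must lie between 0 and num_layers")
    return [list(w_a) if layer < crossover else list(w_b)
            for layer in range(num_layers)]


if __name__ == "__main__":
    w_avg = [0.0, 1.0]   # toy 2-dimensional latents, for illustration only
    w = [2.0, 3.0]
    print(truncate(w, w_avg, 0.5))  # [1.0, 2.0]

    # Assumed layer count for a 1024x1024 generator; crossover chosen arbitrarily.
    per_layer = mix_styles([1.0, 1.0], [2.0, 2.0], num_layers=18, crossover=8)
    print(len(per_layer), per_layer[7], per_layer[8])
```

Lower values of psi trade diversity for sample quality, which is why example commands often use a value such as 0.5.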

Metric

Calculate quality metric for StyleGAN2

Use cal_metrics.py to calculate the PPL and FID scores. The pre-trained weights for LPIPS (a standard metric for estimating perceptual similarity), which PPL relies on, will be downloaded automatically from here.

Evaluation time and results for the pre-trained FFHQ generator using one GTX 1080ti.

| Metric | Time | Result | Description |
| --- | --- | --- | --- |
| fid50k | 1.5 hours | 3.096 | Fréchet Inception Distance using 50,000 images. |
| ppl_wend | 2.5 hours | 144.044 | Perceptual Path Length for endpoints in W. |
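
For reference, the two metrics are commonly defined as follows. This is the standard formulation from the literature, not a description of how cal_metrics.py computes them. The symbols are introduced here for illustration: $\mu_r, \Sigma_r$ and $\mu_g, \Sigma_g$ are the mean and covariance of Inception features for real and generated images, $f$ is the mapping network, $g$ the synthesis network, $d$ the LPIPS distance, and $\epsilon$ a small step (typically $10^{-4}$).

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)

\mathrm{PPL}_{W} = \mathbb{E}\!\left[ \frac{1}{\epsilon^{2}}\,
  d\Big( g\big(\mathrm{lerp}(f(z_1), f(z_2);\, t)\big),\;
         g\big(\mathrm{lerp}(f(z_1), f(z_2);\, t + \epsilon)\big) \Big) \right]
```

In the endpoint variant (ppl_wend), $t$ is restricted to the path endpoints, $t \in \{0, 1\}$, rather than sampled uniformly from $[0, 1]$. Lower is better for both metrics.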

Example usage for FID evaluation:

python cal_metrics.py --ckpt ./weights-ffhq/official_1024x1024  \
                      --res 1024                                \
                      --config f                                \
                      --mode fid                                \
                      --dataset './datasets/ffhq'

Todo

  • Add FFHQ official-weights inference feature.
  • Add metrics.py to compute PPL and FID.
  • Train a model based on custom dataset with DiffAugment method.

Requirements

You will need the following to run the above:

  • TensorFlow = 2.3
  • Python 3, Pillow 7.0.0, NumPy 1.18

Acknowledgements