This is a TensorFlow 2.3 implementation of "Analyzing and Improving the Image Quality of StyleGAN" (StyleGAN2) and "Differentiable Augmentation for Data-Efficient GAN Training" (DiffAugment).
Check the ./results folder to see more images.
Use main.py to train a StyleGAN2 model on a given dataset.
Training takes 80 s (CUDA ops) or 110 s (TensorFlow ops) per 100 steps (batch_size=4) on a GTX 1080 Ti.
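The timings above can be turned into rough throughput figures. The following Python sketch does that arithmetic, taking the reported wall-clock times at face value (it does not separate out data loading, logging, or checkpointing); `throughput` is a hypothetical helper, not part of this repository.

```python
def throughput(seconds_per_100_steps: float, batch_size: int = 4) -> dict:
    """Convert the wall-clock time for 100 training steps into per-step and per-image rates."""
    sec_per_step = seconds_per_100_steps / 100.0
    images_per_sec = batch_size / sec_per_step
    return {
        "sec_per_step": sec_per_step,
        "images_per_sec": images_per_sec,
        # StyleGAN-style progress is usually reported in kimg (thousands of real images shown).
        "kimg_per_hour": images_per_sec * 3600 / 1000,
    }

for impl, secs in [("CUDA ops", 80.0), ("TensorFlow ops", 110.0)]:
    r = throughput(secs)
    print(f"{impl}: {r['sec_per_step']:.2f} s/step, "
          f"{r['images_per_sec']:.2f} img/s, {r['kimg_per_hour']:.1f} kimg/h")
```

With batch_size=4 this works out to about 18 kimg per hour with the CUDA ops and about 13 kimg per hour with the plain TensorFlow ops.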
Example usage for training on the AFHQ dataset:
python main.py train --dataset_name afhq \
--dataset_path ./path/to/afhq_dataset_dir \
--batch_size 4 \
--res 512 \
--config e \
--impl ref
Use main.py to run inference in a chosen mode with a given label.
The inference mode must be one of: [example, gif, mixing].
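Assuming the mixing mode follows StyleGAN's style mixing, each synthesis layer receives its own copy of the intermediate latent w, and a mixed image takes the per-layer latents of one sample before a crossover layer and those of another sample from that layer on. In the official StyleGAN2 generator the number of per-layer latents is 2 * log2(res) - 2, i.e. 18 at 1024x1024. The Python sketch below uses plain lists in place of tensors; `num_ws` and `style_mix` are hypothetical names, not this repository's API.

```python
import math
from typing import List

Vector = List[float]

def num_ws(res: int) -> int:
    """Number of per-layer w vectors in a StyleGAN2 synthesis network at resolution res."""
    log2_res = int(math.log2(res))
    assert 2 ** log2_res == res, "resolution must be a power of two"
    return 2 * log2_res - 2

def style_mix(ws_a: List[Vector], ws_b: List[Vector], crossover: int) -> List[Vector]:
    """Layers [0, crossover) use A's latents (coarser attributes), the rest use B's (finer ones)."""
    assert len(ws_a) == len(ws_b), "both inputs must cover the same number of layers"
    assert 0 <= crossover <= len(ws_a)
    return ws_a[:crossover] + ws_b[crossover:]

n = num_ws(1024)
ws_a = [[1.0, 1.0]] * n  # toy 2-D w repeated for every layer
ws_b = [[0.0, 0.0]] * n
mixed = style_mix(ws_a, ws_b, crossover=4)
print(n, mixed[3], mixed[4])
```

Choosing a small crossover index keeps only the coarsest layers from the first sample; a large one keeps everything except the finest details.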
The pre-trained FFHQ/AFHQ weights are located here.
Example usage:
python main.py inference --ckpt ./weights-ffhq/official_1024x1024 \
--res 1024 \
--config f \
--truncation_psi 0.5 \
--mode example
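The --truncation_psi flag presumably controls StyleGAN's truncation trick, which pulls each sampled latent w toward the average latent w_avg: w' = w_avg + psi * (w - w_avg). With psi = 1 samples are unchanged; smaller values such as the 0.5 above trade diversity for image quality. A minimal Python sketch on plain lists (the name `truncate` is hypothetical, and the official generator can optionally restrict truncation to the first few layers, which is omitted here):

```python
from typing import List

def truncate(w: List[float], w_avg: List[float], psi: float) -> List[float]:
    """Truncation trick: move from the average latent w_avg toward w by a factor psi."""
    return [a + psi * (x - a) for x, a in zip(w, w_avg)]

w = [2.0, -4.0]
w_avg = [0.0, 0.0]
print(truncate(w, w_avg, 0.5))  # [1.0, -2.0]
```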
Use cal_metrics.py to calculate PPL and FID scores.
The pre-trained LPIPS weights used for PPL (LPIPS is a standard metric for estimating perceptual similarity) are downloaded automatically from here.
Evaluation time and results for the pre-trained FFHQ generator on a single GTX 1080 Ti:
| Metric | Time | Result | Description |
|---|---|---|---|
| fid50k | 1.5 hours | 3.096 | Fréchet Inception Distance using 50,000 images. |
| ppl_wend | 2.5 hours | 144.044 | Perceptual Path Length for endpoints in W. |
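As defined in the StyleGAN papers, PPL in W samples two latents w1 and w2, generates images at interpolation parameters t and t + epsilon (epsilon = 1e-4), and averages the perceptual distance between the two images scaled by 1/epsilon^2. In the official metric code the "end" variant fixes t at the endpoint t = 0, and values outside the 1st to 99th percentile range are discarded before averaging. The Python sketch below keeps that structure but swaps in a toy linear generator and squared Euclidean distance for the StyleGAN2 generator and LPIPS, and samples w directly from a Gaussian instead of passing z through the mapping network, so its output is not comparable to the table; `toy_generator`, `toy_distance`, and `ppl_w_end` are hypothetical names.

```python
import math
import random
from typing import Callable, List

Vector = List[float]

def lerp(a: Vector, b: Vector, t: float) -> Vector:
    """Linear interpolation, as used for PPL in W space."""
    return [x + t * (y - x) for x, y in zip(a, b)]

def percentile_value(sorted_vals: List[float], q: float, mode: str) -> float:
    """Percentile of pre-sorted data without interpolation (like numpy's 'lower'/'higher')."""
    pos = q / 100.0 * (len(sorted_vals) - 1)
    idx = math.floor(pos) if mode == "lower" else math.ceil(pos)
    return sorted_vals[idx]

def ppl_w_end(generator: Callable[[Vector], Vector],
              distance: Callable[[Vector, Vector], float],
              dim: int, num_samples: int,
              epsilon: float = 1e-4, seed: int = 0) -> float:
    rng = random.Random(seed)
    dists = []
    for _ in range(num_samples):
        # Simplification: draw w directly from N(0, I) instead of mapping z -> w.
        w1 = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        w2 = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        t = 0.0  # "end" sampling: t is fixed at the endpoint
        img_a = generator(lerp(w1, w2, t))
        img_b = generator(lerp(w1, w2, t + epsilon))
        dists.append(distance(img_a, img_b) / epsilon ** 2)
    # Drop outliers outside the [1st, 99th] percentile range, then average.
    s = sorted(dists)
    lo = percentile_value(s, 1, "lower")
    hi = percentile_value(s, 99, "higher")
    kept = [d for d in dists if lo <= d <= hi]
    return sum(kept) / len(kept)

def toy_generator(w: Vector) -> Vector:
    """Hypothetical stand-in for G: a fixed linear map from W to 'pixels'."""
    return [2.0 * w[0], 3.0 * w[1]]

def toy_distance(a: Vector, b: Vector) -> float:
    """Hypothetical stand-in for LPIPS: squared Euclidean distance."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

print(ppl_w_end(toy_generator, toy_distance, dim=2, num_samples=1000))
```

Because the distance is divided by epsilon^2, PPL behaves like a squared local derivative of image appearance along the interpolation path, which is why its values (such as 144.044 above) are on a very different scale from FID.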
Example usage for FID evaluation:
python cal_metrics.py --ckpt ./weights-ffhq/official_1024x1024 \
--res 1024 \
--config f \
--mode fid \
--dataset './datasets/ffhq'
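FID fits a Gaussian to Inception-v3 features of real and of generated images and measures the distance between the two: FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 (Sigma_r Sigma_g)^(1/2)). The NumPy sketch below computes only that final distance from two feature matrices; Inception feature extraction is assumed to happen elsewhere, and `fid_from_features` is a hypothetical helper. To avoid a SciPy dependency it uses the fact that, for covariance matrices, Tr((Sigma_r Sigma_g)^(1/2)) equals the sum of the square roots of the (real, non-negative) eigenvalues of Sigma_r Sigma_g.

```python
import numpy as np

def fid_from_features(feats_real: np.ndarray, feats_fake: np.ndarray) -> float:
    """Frechet distance between Gaussians fitted to two (N, D) feature matrices."""
    mu_r, mu_g = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    sigma_r = np.cov(feats_real, rowvar=False)
    sigma_g = np.cov(feats_fake, rowvar=False)
    # Tr(sqrt(sigma_r @ sigma_g)) = sum of sqrt of eigenvalues of sigma_r @ sigma_g;
    # tiny negative or imaginary parts from round-off are clipped away.
    eigvals = np.linalg.eigvals(sigma_r @ sigma_g)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    diff = mu_r - mu_g
    return float(diff @ diff + np.trace(sigma_r) + np.trace(sigma_g) - 2.0 * tr_sqrt)

rng = np.random.default_rng(0)
real = rng.normal(size=(5000, 8))
fake = rng.normal(loc=0.5, size=(5000, 8))
print(fid_from_features(real, real))  # approximately 0 for identical feature sets
print(fid_from_features(real, fake))  # roughly 8 * 0.5**2 = 2.0, mostly from the mean shift
```

The fid50k entry above corresponds to this distance computed over features of 50,000 generated images.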
- Add FFHQ official-weights inference feature.
- Add metrics.py to compute PPL and FID.
- Train a model on a custom dataset with the DiffAugment method.
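DiffAugment applies the same family of random, differentiable transforms to both real and generated images before they reach the discriminator, so the generator still receives gradients through the augmentation. The NumPy sketch below mirrors the "color" policy formulas (random brightness, saturation, and contrast) of the data-efficient-gans reference code for a batch in NHWC layout. It only illustrates the transforms, since NumPy does not backpropagate; the repository itself uses the TensorFlow version in modules/DiffAugment_tf.py. `augment_color` is a hypothetical name.

```python
import numpy as np

def augment_color(x: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Random brightness, saturation and contrast per sample; x has shape (N, H, W, C)."""
    n = x.shape[0]
    # Brightness: add a per-sample offset drawn from U(-0.5, 0.5).
    x = x + (rng.uniform(size=(n, 1, 1, 1)) - 0.5)
    # Saturation: scale deviations from the per-pixel channel mean by U(0, 2).
    mean_c = x.mean(axis=3, keepdims=True)
    x = (x - mean_c) * (rng.uniform(size=(n, 1, 1, 1)) * 2.0) + mean_c
    # Contrast: scale deviations from the per-sample mean by U(0.5, 1.5).
    mean_all = x.mean(axis=(1, 2, 3), keepdims=True)
    x = (x - mean_all) * (rng.uniform(size=(n, 1, 1, 1)) + 0.5) + mean_all
    return x

rng = np.random.default_rng(0)
reals = rng.uniform(-1.0, 1.0, size=(4, 8, 8, 3))
fakes = rng.uniform(-1.0, 1.0, size=(4, 8, 8, 3))
# The discriminator sees augmented versions of both the real and the generated batch.
d_in_real = augment_color(reals, rng)
d_in_fake = augment_color(fakes, rng)
print(d_in_real.shape, d_in_fake.shape)
```

Augmenting both batches matters: augmenting only the reals would teach the generator to reproduce the augmentations themselves.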
You will need the following to run the above:
- TensorFlow 2.3
- Python 3, Pillow 7.0.0, Numpy 1.18
- Most of the code and CUDA kernels are based on the official implementation.
- The code of modules/DiffAugment_tf.py is from data-efficient-gans.
- The AFHQ training dataset is from stargan-v2.
- The pre-trained FFHQ generator's weights are converted from stylegan2-ffhq-config-f.pkl.
- The pre-trained LPIPS weights used in PPL are converted from vgg16_zhang_perceptual.pkl.