
Lacking documentation on how to create the average latent file #19

Closed

ericchansen opened this issue Aug 2, 2021 · 13 comments

@ericchansen commented Aug 2, 2021

The example configuration files (https://github.com/nv-tlabs/datasetGAN_release/tree/master/datasetGAN/experiments), e.g. "cat_16.json" or "car_20.json", contain fields named "average_latent" and "annotation_image_latent_path". Both of these fields are paths to .npy files.

This repo does not describe how these files are generated. Even after training a StyleGAN model, one cannot use DatasetGAN without these files. Please provide documentation on how to create these files for a custom dataset.

I imagine that you'd want to provide examples using your up-to-date repos, e.g. https://github.com/NVlabs/stylegan2-ada-pytorch and/or https://github.com/NVlabs/stylegan2-ada.

@PoissonChasseur

You can take a look at issue #10, which explains in more detail the parts of these configuration files that are harder to understand. That said, I certainly agree it would be a good idea to have official documentation about them.

Note that you will need to write some code yourself to produce the "avg_latent_stylegan1.npy" and "latent_stylegan1.npy" files, as well as the converted weights file (TensorFlow to PyTorch) of the pre-trained StyleGAN if you have trained your own StyleGAN (see issue #1). At least for now, there are no files for these parts on this GitHub.

@ericchansen (Author) commented Aug 2, 2021

@PoissonChasseur, agreed that documentation on how to create these files, "avg_latent_stylegan1.npy" and "latent_stylegan1.npy", would be good.

I trained my own StyleGAN using https://github.com/NVlabs/stylegan2-ada-pytorch, and it's not clear to me how I would go about getting or creating "avg_latent_stylegan1.npy" and "latent_stylegan1.npy".

@PoissonChasseur, if you have any further advice or experience (I already read issues #10 and #1), please do share. Otherwise, I hope NVIDIA will be able to improve the docs.

@PoissonChasseur commented Aug 3, 2021

So, from what was described in issue #10 and what I also did in my case with StyleGAN V1 (this should also hold for StyleGAN V2, but I haven't yet read the paper or the code for StyleGAN2-ADA):

  • avg_latent_stylegan1.npy: You have to compute the average of (for example) 8,000 "W latent codes" (the latent code after the "mapping network", which converts the "Z latent code" into the "W latent code"). Note that we need the version BEFORE applying the "threshold" (the truncation trick), because the truncation cannot be used yet at that point; see the details below and the sketch after this list:

    • In the case of StyleGAN V1, as shown in the code of this GitHub (train_interpreter.py and the other files it uses), this "avg_latent" is only used when the model is initialized after loading the pre-trained weights, specifically by the "threshold" (truncation) applied to the "W latent code". So StyleGAN V1 can work without it, but all images presented in the StyleGAN V1 paper used it.
    • As I haven't read up on StyleGAN2-ADA yet, it's hard to say what this corresponds to there, but I think the "avg_latent" is probably also present in StyleGAN V2, since that architecture still contains a "mapping network" as well as the same truncation principle on the output of this sub-network.
  • latent_stylegan1.npy: This file simply contains all the "Z latent codes" (or "W" codes, if those were captured instead) linked to the DatasetGAN training images, which were previously generated by the pre-trained StyleGAN.

    • DatasetGAN takes as input the features (the activations of the StyleGAN layer outputs) linked to the images generated by StyleGAN. Each pixel_classifier (the 3-layer MLP) analyzes 1D vectors whose length depends on the StyleGAN feature depth. For this reason, we need to be able to re-generate the StyleGAN images that were labeled manually, in order to train DatasetGAN.
    • The "latent_stylegan1.npy" file therefore simply contains a 2D matrix of shape [Nb_img, 512], where 512 is the length of the latent code. Each row corresponds to the latent code (the input of StyleGAN V1) used to generate one DatasetGAN training image, which was also labeled manually (since we also need its corresponding mask).
    • It should be possible to do the same with StyleGAN2-ADA, since we simply need the inputs to give to the StyleGAN to reproduce each training image. However, the exact names may have changed, and it is also possible (since I do not know the StyleGAN2-ADA architecture) that more information is needed as input to reproduce the initially generated images exactly. Fixing PyTorch's random seed (see issue #15) will likely be needed as well.
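
For reference, the "threshold" referred to above is the truncation trick from the StyleGAN paper, w' = w_avg + psi * (w - w_avg). A minimal NumPy sketch (the function and variable names are mine, not from this repo):

import numpy as np

def truncate(w: np.ndarray, w_avg: np.ndarray, psi: float = 0.7) -> np.ndarray:
    # Truncation trick: pull w toward the average W latent code.
    # psi < 1 trades sample diversity for image quality; psi = 1 is a no-op.
    return w_avg + psi * (w - w_avg)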

Note that if your initial reason for hesitating about StyleGAN2-ADA is the conversion of the pre-trained StyleGAN weights file between TensorFlow and PyTorch, I had no issues with this via the link provided in issue #1.

But using a newer version of StyleGAN will certainly require adjusting other parts of the code (for example, since DatasetGAN works with the StyleGAN features rather than the output images, the names and positions of the layers to use will probably change). Anyway, I'll be interested to see the final code for all of this once it's done, as will a lot of other people, probably :).

I hope these explanations help you on certain points.

@ericchansen (Author) commented Aug 3, 2021

@PoissonChasseur, this is very helpful! I'm going to have to re-read what you wrote another 10 or 20 times, and then do more digging through the architecture. I'll probably reply later today with more questions.

Hopefully someone from NVIDIA will want to answer these questions.

Practically speaking, I still don't know how to generate "avg_latent_stylegan1.npy" and "latent_stylegan1.npy" from any of the StyleGAN repos. It'd be very helpful to have working code/examples/documentation that shows how to generate these files from at least one of the four StyleGAN repos.

@PoissonChasseur commented Aug 3, 2021

Overall, at least in the case of StyleGAN V1 (https://github.com/NVlabs/stylegan), which I used, and using the same notation as in the code of the "train_interpreter.py" file from this GitHub:

  • Step 0: Train StyleGAN V1 with your own dataset to get a pre-trained StyleGAN (or instead download a pre-trained StyleGAN provided by the StyleGAN V1 GitHub or any other source).

  • Step 1: Follow the instructions for the [TensorFlow -> PyTorch] conversion of the pre-trained StyleGAN weights file linked in issue #1 ("Confusion regarding checkpoints"). This will give you the ".pt" file you need to perform all of the following steps.

  • Step 2: avg_latent_stylegan1.npy (see the sketch after this list):

    • First, you need to load the pre-trained StyleGAN you got in Step 1. The initialization will, however, exclude the truncation for this step (as is also done in Step 1 if you want to test the pre-trained StyleGAN).
    • For this step, you only need the "G_mapping" part (the mapping network) of "g_all" (the complete StyleGAN V1).
    • Loop over "range(0, 8000)", where on each iteration you draw a random "Z latent code" with "np.random.randn(1, 512)" (Gaussian distribution N(mean=0, std=1)), send it as input to "G_mapping", and capture its output (the "W latent code").
    • Then simply compute the average of the 8,000 "W latent codes" you obtained and store the result in a NumPy file, which you can name "avg_latent_stylegan1.npy" (or any other name).
  • Step 3: latent_stylegan1.npy + generation of fake images (also covered in the sketch after this list):

    • First, you need to load the pre-trained StyleGAN you got in Step 1. This time, however, include the truncation in the architecture, since you now have the "avg_latent" needed for its initialization.
    • Loop over "range(0, nb_img_you_want)", where at each iteration you first create a random "Z latent code" ("latent = np.random.randn(1, 512)") and then pass it to the StyleGAN via the "latent_to_image" function (from the "utils.utils" file on this GitHub). During this loop, you must store the latent code linked to the creation of each image, since they will be saved in the "latent_stylegan1.npy" file. You also need to save the generated images, as you will then need to label them manually (using, for example, LabelMe: https://github.com/wkentaro/labelme).
    • If you choose to label only a subset of the generated images, then "latent_stylegan1.npy" should contain only the latent codes of the images you chose. This can be done later, after you have manually labeled some of the images generated by the pre-trained StyleGAN.
  • Step 4: Manually label some fake images generated by the StyleGAN (using, for example, LabelMe: https://github.com/wkentaro/labelme) in order to obtain their corresponding "masks".

  • Step 5: Now that you have "avg_latent_stylegan1.npy", "latent_stylegan1.npy", the pre-trained StyleGAN, and some (image, mask) pairs for the DatasetGAN training dataset, you can run DatasetGAN training via the "train_interpreter.py" file from this GitHub. Some small modifications might be necessary (e.g., the file extension used for the images, which could differ from ".jpg").
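
To make Steps 2 and 3 concrete, here is a minimal, untested sketch (my own outline, not code from this repo). It assumes "g_mapping" is the pre-trained mapping network loaded from the converted .pt file, and that "g_all" (the full model, with truncation) and "upsamplers" are set up as in "train_interpreter.py":

import numpy as np
import torch
from matplotlib import pyplot

from utils.utils import latent_to_image  # helper from this repo

# Step 2: average of 8,000 W latent codes from the mapping network only.
w_codes = []
with torch.no_grad():
    for _ in range(8000):
        z = torch.randn(1, 512, device="cuda")       # random Z latent code
        w = g_mapping(z)                             # W latent code, BEFORE truncation
        w_codes.append(w.cpu().numpy()[0])
np.save("avg_latent_stylegan1.npy", np.mean(w_codes, axis=0))

# Step 3: generate fake images with the full model (truncation now enabled,
# initialized with the average computed above) and keep each Z latent code
# so that the labeled images can be re-generated during DatasetGAN training.
latents = []
with torch.no_grad():
    for i in range(30):                              # nb_img_you_want
        z = torch.randn(1, 512, device="cuda")
        [img], _ = latent_to_image(g_all, upsamplers, z, dim=256)
        pyplot.imsave(f"image_{i}.jpg", img)         # label these manually later
        latents.append(z.cpu().numpy()[0])
np.save("latent_stylegan1.npy", np.stack(latents))   # shape [Nb_img, 512]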

I hope you are now in a better position to write your own implementation of all of these parts. If you haven't already, I strongly suggest that you first read the StyleGAN V1 paper, then read and understand all the code in the "train_interpreter.py" file on this GitHub, before implementing these steps.

@dcetin commented Aug 4, 2021

Hello, I am also trying my hand at training StyleGAN/DatasetGAN on a custom dataset and wanted to share something. I don't know if it's the correct way to do it, but it seems to have worked out quite fine for me. I think you can obtain the average latent code from a trained StyleGAN by doing a forward pass on any sample (i.e., a random latent code) and reading what's stored in the appropriate variable, Gs.get_var('dlatent_avg'), on the TensorFlow side. I did this once and saved the average latent code, converted the checkpoint to PyTorch, initialized the PyTorch checkpoint using the saved average latent code, sampled a bunch of latent codes, generated images on the PyTorch side, and then provided all of this to DatasetGAN. Let me know if you have any further questions.

Edit: I only worked with StyleGAN, so I don't know whether what I said translates to StyleGANv2 as well.
Edit: I found this line, which I hadn't seen before, in the StyleGAN readme:

The average w needed to manually perform the truncation trick can be looked up using Gs.get_var('dlatent_avg').
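
In code, that lookup might look like the following sketch on the TensorFlow side (assuming the NVlabs/stylegan repo is on the Python path; the .pkl file name is a placeholder). Note that DatasetGAN's cat example stores "average_latent" with shape (18, 512), so you may need to tile the (512,)-shaped dlatent_avg across the style layers:

import pickle

import numpy as np
import dnnlib.tflib as tflib  # from the NVlabs/stylegan repo

tflib.init_tf()
with open("your_stylegan_checkpoint.pkl", "rb") as f:  # placeholder path
    _G, _D, Gs = pickle.load(f)

dlatent_avg = Gs.get_var('dlatent_avg')  # the average w, shape (512,)
np.save("avg_latent.npy", np.tile(dlatent_avg, (18, 1)))  # if (18, 512) is expected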

@ericchansen (Author)

@PoissonChasseur, that's helpful. Can you please elaborate on bullets 1-3 of Step 2? Example code would be helpful.

@PoissonChasseur commented Aug 7, 2021

@ericchansen: For Step 2, bullets 1 and 3, as I said in Step 1, you can look at what is done in this notebook: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb (manually copy and paste the link if it does not go to the corresponding page).

If you compare the code in this reference with the code used in DatasetGAN, you will see that most of the code for the different layers is exactly the same. So you can simply start from cell 11 ("In [11]:") and call the StyleGAN layer classes defined in the DatasetGAN files (there is no need to copy and paste the ones defined in this reference, and in my opinion it's preferable not to at this point). This reference also shows directly how to use the pre-trained StyleGAN to generate images from random latent codes.

Extracting the part of the StyleGAN related to the "mapping network" (g_mapping) should also be fairly clear there (see cell 11), as they also build the three main parts of the StyleGAN, which you will need afterwards for DatasetGAN. A rough sketch is below.
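
As a rough illustration of bullets 1 and 3 of Step 2, here's a sketch under my own assumptions: the G_mapping class from this repo's models/stylegan1.py, a converted checkpoint with a placeholder name, and the mapping network's 8 FC layers (16 weight/bias entries) coming first in the state dict, as in the notebook above.

from collections import OrderedDict

import torch
import torch.nn as nn

from stylegan1 import G_mapping  # models/stylegan1.py in this repo

state_dict = torch.load("stylegan1_converted.pt")          # placeholder name
mapping_only = OrderedDict(list(state_dict.items())[:16])  # 8 FC layers x (weight, bias)
g_mapping = nn.Sequential(OrderedDict([("g_mapping", G_mapping())]))
g_mapping.load_state_dict(mapping_only)
g_mapping.eval().cuda()

The longer script later in this thread wraps this same idea in its load_stylegan_checkpoint function.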

@ericchansen (Author) commented Aug 10, 2021

To reiterate the issue: this repo is currently missing documentation for two required input files. In the configuration files (https://github.com/nv-tlabs/datasetGAN_release/tree/master/datasetGAN/experiments), they're called "average_latent" and "annotation_image_latent_path". There is no documentation or example code on how to create these files.

Here's my solution. I think that anyone should be able to use this code to create "average_latent" and "annotation_image_latent_path" for their own datasets.

For the imports in this piece of code to "play nice", put it into the "utils" directory.

"""
Generate the files that are needed to run DatasetGAN from a PyTorch
StyleGAN checkpoint:

    1. the average w latent code,
    2. images generated by StyleGAN and
    3. the z latent code used by StyleGAN to generate each image.

In the configuration files of NVIDIA's DatasetGAN repo,
https://github.com/nv-tlabs/datasetGAN_release/tree/master/datasetGAN/experiments,
there are two fields, "average_latent" and "annotation_image_latent_path".

    * "average_latent" is the path to a NumPy binary of the average w latent
      code. For NVIDIA's cat example, the shape is (18, 512).

    * "annotation_image_latent_path" is the path to a NumPy binary. The binary
      contains the upsampled latent code from StyleGAN that is used to
      generate each of the images in your DatasetGAN training (and testing?)
      set. For NVIDIA's cat example, the shape is (30, 512).

Images will be generated with the names "image_{i}.jpg". The matching mask
must be manually created and should get the name "image_mask{i}.jpg".

DatasetGAN repo:
https://github.com/nv-tlabs/datasetGAN_release

NVIDIA's PyTorch checkpoints, which this function uses, can be downloaded
here:
https://drive.google.com/drive/folders/1Hhu8aGxbnUtK-yHRD9BXZ3gn3bNNmsGi
"""

import argparse
import logging
import logging.config
import os
import sys
from collections import OrderedDict
from pathlib import Path
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import yaml
from matplotlib import pyplot

from utils import latent_to_image

DIR_MODELS = "../models"
sys.path.append(DIR_MODELS)
from stylegan1 import G_mapping, G_synthesis, Truncation

DIR_DATASETGAN = "../datasetGAN"
sys.path.append(DIR_DATASETGAN)
from train_interpreter import prepare_stylegan


logger = logging.getLogger(__name__)
with open("logging.yml", "r") as f:
    config = yaml.safe_load(f.read())
logging.config.dictConfig(config)


DEFAULT_RESOLUTION = 256
DEFAULT_MAX_LAYERS = 7
DEFAULT_NUM_IMAGES = 30  # nvidia used 30 images for cats
DEFAULT_UPSAMPLE_MODE = "bilinear"
DEFAULT_PATH_AVG_LATENT = Path("avg_latent.npy")
DEFAULT_PATH_LATENT_USED_TO_GENERATE_IMAGES = Path("latent_stylegan1.npy")


# from train_interpreter.py in nvidia's datasetgan repo. not sure if needed here
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
torch.manual_seed(0)
device_ids = [0]  # TODO: move this


def load_stylegan_checkpoint(
    stylegan_checkpoint_path,
    average_w_latent_code: Optional[np.ndarray] = None,  # or torch.Tensor
    which_layers_to_load: str = "all",
    resolution: int = DEFAULT_RESOLUTION,
    device: str = "cuda",
):
    """
    Load all of StyleGAN or load pieces of the network.

    NVIDIA's PyTorch checkpoints, which this function uses, can be downloaded
    here:
    https://drive.google.com/drive/folders/1Hhu8aGxbnUtK-yHRD9BXZ3gn3bNNmsGi
    """
    assert which_layers_to_load in ["all", "mapping", "mapping and truncation"]
    assert type(average_w_latent_code) in (type(None), np.ndarray, torch.Tensor)

    if type(average_w_latent_code) == np.ndarray:
        average_w_latent_code = torch.from_numpy(average_w_latent_code)
        if device != "cpu":
            average_w_latent_code = average_w_latent_code.to(device)

    state_dictionary = torch.load(stylegan_checkpoint_path)

    if which_layers_to_load == "all":
        model = nn.Sequential(
            OrderedDict(
                [
                    ("g_mapping", G_mapping()),
                    ("truncation", Truncation(average_w_latent_code, device=device,),),
                    ("g_synthesis", G_synthesis(resolution=resolution)),
                ]
            )
        )
        model.load_state_dict(state_dictionary)

    elif which_layers_to_load == "mapping and truncation":
        mapping_and_truncation_state_dictionary = OrderedDict(
            list(state_dictionary.items())[:16]
        )  # remove everything besides the mapping and truncation network
        model = nn.Sequential(
            OrderedDict(
                [
                    ("g_mapping", G_mapping()),
                    ("truncation", Truncation(average_w_latent_code, device=device),),
                ]
            )
        )
        model.load_state_dict(mapping_and_truncation_state_dictionary)

    elif which_layers_to_load == "mapping":
        mapping_state_dictionary = OrderedDict(
            list(state_dictionary.items())[:16]
        )  # remove everything besides mapping network
        model = nn.Sequential(OrderedDict([("g_mapping", G_mapping())]))
        model.load_state_dict(mapping_state_dictionary)

    model.eval()
    model.to(device)
    return model


def generate_random_z_latent_code(n: int = 1, device: str = "cuda"):
    return torch.randn(n, 512, device=device)


def generate_random_z_latent_code_w_numpy(n: int = 1, device: str = "cuda"):
    random_vector = np.random.randn(n, 512)
    random_vector = torch.from_numpy(random_vector).type(torch.FloatTensor).to(device)
    return random_vector


def generate_average_w_latent_code(
    model,
    path_to_save: Path = DEFAULT_PATH_AVG_LATENT,
    num_w_latent_codes_to_generate: int = 10000,
    device: str = "cuda",
):
    """
    Generate the average w latent code. This is a pre-requisite for DatasetGAN.

    Loops through generating random inputs for StyleGAN, gives the inputs to
    StyleGAN, collects the outputs, calculates the element-wise average and
    then save the average as a NumPy binary.
    """
    w_latent_code_list = []
    for i in range(0, num_w_latent_codes_to_generate):
        random_tensor = generate_random_z_latent_code(device=device)
        w_latent_code = model.g_mapping(random_tensor)
        [np_arr] = w_latent_code.cpu().detach().numpy()
        w_latent_code_list.append(np_arr)
    average_w_latent_code = np.mean(w_latent_code_list, axis=0)
    if path_to_save is not None:
        np.save(path_to_save, average_w_latent_code)
    return average_w_latent_code


def generate_latent_space_and_image_pairs_using_nvidia_code(
    model,
    upsamplers,
    output_image_dir: Path,
    num_images: int,
    resolution: int,
    path_latent_used_to_generate_images: Path = DEFAULT_PATH_LATENT_USED_TO_GENERATE_IMAGES,
    device: str = "cuda",
):
    """
    Generate image and latent space pairs.
    """
    logger.debug(f"num images to generate: {num_images}")
    z_latent_code_list = []  # the Z codes fed to StyleGAN, one per image
    with torch.no_grad():
        for i in range(0, num_images):
            torch.cuda.empty_cache()
            random_tensor = generate_random_z_latent_code(
                n=1, device=device
            )  # nvidia asserts that length of tensor == 1
            [numpy_permuted_img], affine_layer_upsamples = latent_to_image(
                model,
                upsamplers,
                random_tensor,
                dim=resolution,
            )
            pyplot.imsave(Path(output_image_dir, f"image_{i}.jpg"), numpy_permuted_img)
            [z_latent_code] = random_tensor.cpu().detach().numpy()
            z_latent_code_list.append(z_latent_code)
    np.save(path_latent_used_to_generate_images, z_latent_code_list)
    logger.info(f"generated {num_images} images: {output_image_dir}")
    logger.info(f"saved latent vectors: {path_latent_used_to_generate_images}")


def main(
    stylegan_checkpoint_path: Path,
    output_image_dir: Path,
    num_images: int,
    path_average_w_latent_code: Path = DEFAULT_PATH_AVG_LATENT,
    path_latent_used_to_generate_images: Path = DEFAULT_PATH_LATENT_USED_TO_GENERATE_IMAGES,
    resolution: int = DEFAULT_RESOLUTION,
    max_layers: int = DEFAULT_MAX_LAYERS,
    device: str = "cuda",
):
    if output_image_dir.exists():
        logger.debug(f"already exists: {output_image_dir}")
    else:
        output_image_dir.mkdir()
        logger.debug(f"mkdir: {output_image_dir}")

    # 1. generate average latent code (mapping network only; no truncation yet)
    mapping_network = load_stylegan_checkpoint(
        stylegan_checkpoint_path, which_layers_to_load="mapping", device=device
    )
    average_w_latent_code = generate_average_w_latent_code(
        mapping_network,
        path_to_save=path_average_w_latent_code,
        device=device,
    )

    # 2. load the network with average latent code
    args_for_nvidia = {
        "stylegan_ver": "1",
        "category": "cat",
        "average_latent": path_average_w_latent_code,
        "stylegan_checkpoint": stylegan_checkpoint_path,
        "dim": [resolution, resolution, 4992],  # TODO: what does the 4992 mean?
        "upsample_mode": "bilinear",
    }
    entire_network, other_average_w_latent_code, upsamplers = prepare_stylegan(
        args_for_nvidia
    )

    # 3. generate images and save associated z latent code
    generate_latent_space_and_image_pairs_using_nvidia_code(
        entire_network,
        upsamplers,
        output_image_dir,
        num_images,
        resolution=resolution,
        path_latent_used_to_generate_images=path_latent_used_to_generate_images,
        device=device,
    )
main.__doc__ = __doc__


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint", type=Path)
    parser.add_argument("output_image_dir", type=Path, metavar="output-image-dir")
    parser.add_argument(
        "num_images", type=int, nargs="?", metavar="num-images", default=DEFAULT_NUM_IMAGES
    )
    parser.add_argument(
        "--path-average-w-latent-code", type=Path, default=DEFAULT_PATH_AVG_LATENT
    )
    parser.add_argument(
        "--path-latent-used-to-generate-images",
        type=Path,
        default=DEFAULT_PATH_LATENT_USED_TO_GENERATE_IMAGES,
    )
    parser.add_argument("--resolution", type=int, default=DEFAULT_RESOLUTION)
    parser.add_argument("--max-layers", type=int, default=DEFAULT_MAX_LAYERS)
    parser.add_argument("--disable-cuda", action="store_true")
    opts = parser.parse_args()

    if opts.disable_cuda:
        device = "cpu"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    logger.debug(f"pytorch device: {device}")

    main(
        opts.checkpoint,
        opts.output_image_dir,
        opts.num_images,
        path_average_w_latent_code=opts.path_average_w_latent_code,
        path_latent_used_to_generate_images=opts.path_latent_used_to_generate_images,
        resolution=opts.resolution,
        device=device,
    )
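
Assuming the script above is saved as, say, utils/make_datasetgan_inputs.py (a file name I'm choosing here) so it sits next to this repo's utils module, a typical invocation might be:

python make_datasetgan_inputs.py ../checkpoints/stylegan1_cat.pt ./generated_images 30 --resolution 256

The checkpoint path and output directory are placeholders; substitute your own.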

Also, since my code uses a logging configuration file, here's an example of that file.

version: 1
formatters:
  simple:
    format: "[%(name)s][L%(lineno)d][%(levelname)s] %(message)s"
  complex:
    format: "[%(name)s][L%(lineno)d][%(levelname)s][%(asctime)s] %(message)s"
handlers:
  console:
    class: logging.StreamHandler
    level: DEBUG
    formatter: simple
  file:
    class: logging.handlers.TimedRotatingFileHandler
    when: midnight
    backupCount: 5
    level: DEBUG
    formatter: simple
    filename: debug.log
loggers:
  __main__:
    level: DEBUG
    handlers: [console]
    propagate: yes

@arieling or other folks from NVIDIA, is this the correct approach?

Someday, I'll open up a PR with this and updates to the README to include the missing steps.

@ericchansen (Author)

For those out there who are wondering how they can use DatasetGAN with their own datasets, PR #21 should help.

I still think that

  1. the code in the PR should be double checked before merging, and
  2. NVIDIA should add some documentation on these steps

before we close this issue.

@arieling (Collaborator)

Hi, thank you for pointing this out.

We updated the doc in the "Create your own model" section to explain the .npy cache files.
The script datasetGAN/make_training_data.py was also updated to create the caches from scratch.

@arieling (Collaborator) commented Aug 18, 2021

For those out there who are wondering how they can use DatasetGAN with their own datasets, PR #21 should help.

I still think that

  1. the code in the PR should be double checked before merging, and
  2. NVIDIA should add some documentation on these steps

before we close this issue.

Regarding your specific confusion about average_latent.npy: it is used for truncation when sampling. We dump the average latent code to a .npy file mainly for academic-reproducibility reasons.

For your convenience, we implemented a simple function in https://github.com/nv-tlabs/datasetGAN_release/blob/master/datasetGAN/make_training_data.py#L81 to dump the average latent.

Please refer to the StyleGAN V1 code (https://github.com/NVlabs/stylegan) and paper [1] for more details.

[1] A Style-Based Generator Architecture for Generative Adversarial Networks, Tero Karras, Samuli Laine, Timo Aila

@ericchansen (Author)

Good enough!
