
# Generating synthetic drawings with Google QuickDraw

**Members:** Mads Høgenhaug, Marcus Friis, Morten Pedersen

This project generates synthetic drawings using a Generative Adversarial Network (GAN) trained on the Google QuickDraw dataset.

## The dataset

This project uses the Google QuickDraw dataset. It consists of 50 million 28x28 greyscale images across 345 different classes. Google has preprocessed the data by centering and scaling the drawings appropriately. Due to the complexity of the problem, we currently train on and generate only cats; for future development, we would like to include more classes.
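For reference, a minimal sketch of loading the preprocessed bitmaps, assuming the numpy bitmap file for the cat class (`cat.npy`) from the public QuickDraw dataset has been downloaded locally:

```python
import numpy as np
import torch

# Each row of the .npy file is one 28x28 drawing flattened to 784 bytes.
cats = np.load("cat.npy")                      # shape: (n_drawings, 784), dtype uint8
cats = cats.reshape(-1, 1, 28, 28)             # (N, C, H, W) layout for a conv net
cats = torch.from_numpy(cats).float() / 255.0  # scale pixel values to [0, 1]
```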

## Methodology

Throughout the process, we trained many models with several architecture types. The final and best-performing generator and discriminator use the SuperDeepGenerator and SuperDeepConvDiscriminator architectures, respectively.

### Generator - SuperDeepGenerator

The generator is a convolutional model that generates a 28x28 synthetic drawing from a latent vector. It contains 4 transposed convolutional layers, each halving the number of channels. Between layers, we apply batch normalization and a ReLU activation. For added randomness, and to combat mode collapse, we use dropout on all layers except the last. The final layer convolves down to a single channel and applies a sigmoid to scale values between 0 and 1.
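A minimal sketch of such a generator, assuming a 100-dimensional latent vector and illustrative channel counts (the exact layer sizes in `SuperDeepGenerator` may differ):

```python
import torch
import torch.nn as nn

class SuperDeepGenerator(nn.Module):
    """Upsamples a latent vector to a 28x28 grayscale drawing."""
    def __init__(self, latent_dim=100):
        super().__init__()
        self.net = nn.Sequential(
            # 4 transposed convolutions upsample 1x1 -> 4x4 -> 7x7 -> 14x14 -> 28x28,
            # halving the channel count along the way (256 -> 128 -> 64).
            nn.ConvTranspose2d(latent_dim, 256, kernel_size=4, stride=1, padding=0),  # -> 4x4
            nn.BatchNorm2d(256), nn.ReLU(), nn.Dropout2d(0.2),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1),         # -> 7x7
            nn.BatchNorm2d(128), nn.ReLU(), nn.Dropout2d(0.2),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),          # -> 14x14
            nn.BatchNorm2d(64), nn.ReLU(), nn.Dropout2d(0.2),
            # Final layer: convolve to 1 channel, no dropout, sigmoid to [0, 1].
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),            # -> 28x28
            nn.Sigmoid(),
        )

    def forward(self, z):
        # Reshape the latent vector to a (latent_dim, 1, 1) feature map and upsample.
        return self.net(z.view(z.size(0), -1, 1, 1))
```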

### Discriminator - SuperDeepConvDiscriminator

The discriminator consists of 4 convolutions, each followed by batch normalization and leaky ReLU, then a flattening operation and two linear layers, with a final sigmoid activation for binary classification.
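A corresponding sketch, again with illustrative layer sizes:

```python
class SuperDeepConvDiscriminator(nn.Module):
    """Classifies a 28x28 drawing as real (close to 1) or fake (close to 0)."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            # 4 strided convolutions, each with batch norm and leaky ReLU.
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),    # -> 14x14
            nn.BatchNorm2d(32), nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),   # -> 7x7
            nn.BatchNorm2d(64), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # -> 4x4
            nn.BatchNorm2d(128), nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # -> 2x2
            nn.BatchNorm2d(256), nn.LeakyReLU(0.2),
            # Flatten, two linear layers, sigmoid for binary classification.
            nn.Flatten(),
            nn.Linear(256 * 2 * 2, 128), nn.LeakyReLU(0.2),
            nn.Linear(128, 1), nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x)
```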

## Training

To train the model, we instantiate a generator and a discriminator, then repeat the following steps for `n_epochs`.

### Step 1

We train the discriminator: we generate fake samples with the generator, and the discriminator predicts whether each drawing is real or fake on both the fake and the real data. We compute the loss of this process, backpropagate, and take an optimizer step.

### Step 2

The generator generates a new batch. The discriminator predicts on it, and we compute a binary cross-entropy loss such that the generator is rewarded for fooling the discriminator. We backpropagate and take an optimizer step.
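A minimal sketch of the full loop, assuming the two models sketched above and a `dataloader` yielding batches of real drawings scaled to [0, 1] (hyperparameter values are placeholders):

```python
import torch
import torch.nn as nn

n_epochs, latent_dim = 50, 100
device = "cuda" if torch.cuda.is_available() else "cpu"
G = SuperDeepGenerator(latent_dim).to(device)
D = SuperDeepConvDiscriminator().to(device)
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4)
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4)
bce = nn.BCELoss()

for epoch in range(n_epochs):
    for real in dataloader:
        real = real.to(device)
        b = real.size(0)
        ones = torch.ones(b, 1, device=device)    # label for real drawings
        zeros = torch.zeros(b, 1, device=device)  # label for fakes

        # Step 1: train the discriminator on real and generated batches.
        fake = G(torch.randn(b, latent_dim, device=device)).detach()  # no G gradients here
        d_loss = bce(D(real), ones) + bce(D(fake), zeros)
        opt_d.zero_grad(); d_loss.backward(); opt_d.step()

        # Step 2: train the generator to make D predict "real" on a fresh batch of fakes.
        g_loss = bce(D(G(torch.randn(b, latent_dim, device=device))), ones)
        opt_g.zero_grad(); g_loss.backward(); opt_g.step()
```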

## Results

Once a model is trained, it can generate drawings. The discriminator is usually better than the generator, as seen in the loss curves.

While these results are not perfect, they are a good start. Randomness seemed to help the model generalize better, and we observed that deeper models performed better than shallower ones.

## Replicate our results

To replicate our results, run

```bash
python train.py
```

This script can be modified to change hyperparameters, models, data, etc. We use MLflow for tracking experiment results, so an MLflow tracking server must be running for the script to execute correctly. We hosted it locally with the command

```bash
mlflow ui
```
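The training script can then point its tracking client at that server; a minimal sketch (the experiment name and logged values are illustrative):

```python
import mlflow

mlflow.set_tracking_uri("http://127.0.0.1:5000")  # default address served by `mlflow ui`
mlflow.set_experiment("quickdraw-gan")

with mlflow.start_run():
    mlflow.log_params({"n_epochs": 50, "latent_dim": 100, "lr": 2e-4})
    # Inside the training loop, log both losses once per epoch:
    # mlflow.log_metric("d_loss", d_loss.item(), step=epoch)
    # mlflow.log_metric("g_loss", g_loss.item(), step=epoch)
```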

More models can be added in `src/models.py`.
