Implementation of a mutual learning model between a VAE and a GMM.
The idea of integrating probabilistic models is based on the paper Neuro-SERKET: Development of Integrative Cognitive System through the Composition of Deep Probabilistic Generative Models.
Symbol Emergence in Robotics tool KIT (SERKET) is a framework that allows integration and partitioning of probabilistic generative models.
Graphical model of the VAE+GMM model:
The VAE and the GMM share the latent variable x.
x follows a multivariate normal distribution and is estimated by the VAE.
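As a rough sketch (the notation below is mine, not copied from the paper), the generative model shared by the two modules can be written as

$$p(o, x, z) = p_\theta(o \mid x)\,\mathcal{N}(x \mid \mu_z, \Sigma_z)\,\mathrm{Cat}(z \mid \pi)$$

where o is the observed image, z is the GMM cluster assignment, p_θ(o | x) is the VAE decoder, and the Gaussian component N(μ_z, Σ_z) plays the role of the VAE's prior over x.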
Training proceeds in the following sequence (a minimal sketch of this loop follows the list):
1. The VAE estimates the latent variables (x) and sends them to the GMM.
2. The GMM clusters the latent variables (x) sent from the VAE and sends the mean and variance parameters of each Gaussian component back to the VAE.
3. Return to step 1.
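A minimal sketch of this loop, assuming hypothetical `fit` / `encode` / `set_prior` interfaces (the actual APIs of vae_module.py and gmm_module.py may differ):

```python
# Hypothetical sketch of the mutual learning loop; the vae/gmm method
# names below are illustrative assumptions, not the repo's actual API.
def mutual_learning(vae, gmm, train_loader, all_loader, mutual_iteration=2):
    for it in range(mutual_iteration):
        vae.fit(train_loader)        # 1. train the VAE and ...
        x = vae.encode(all_loader)   #    ... infer latent variables x for all data points
        mu, cov, z = gmm.fit(x)      # 2. cluster x with the GMM and ...
        vae.set_prior(mu, cov, z)    #    ... send the Gaussian parameters back to the VAE
        # 3. the next call to vae.fit() trains against the updated prior
```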
What this repo contains:
- main.py: Main code for training the model.
- vae_module.py: The training program for the VAE, called from main.py.
- gmm_module.py: The training program for the GMM, called from main.py.
- tool.py: Various utility functions used by the program.
You can train the VAE+GMM model by running main.py.
train_model() trains the VAE+GMM model, and decode_from_gmm_param() reconstructs images from the parameters of the posterior distribution estimated by the GMM.
```python
def main():
    # train the VAE+GMM model
    train_model(mutual_iteration=2,        # the number of mutual learning iterations
                dir_name=dir_name,
                train_loader=train_loader, # DataLoader for training
                all_loader=all_loader)     # DataLoader for inferring latent variables for all data points

    # reconstruct images
    load_iteration = 1  # which iteration of the mutual learning model to load
    decode_from_gmm_param(iteration=load_iteration,
                          decode_k=1,      # index of the Gaussian cluster used as input to the decoder
                          sample_num=16,   # number of samples drawn from the random variable
                          model_dir=dir_name)
```
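The effect of step 2 above is that the VAE's KL term is no longer computed against a fixed N(0, I) prior but against the Gaussian component assigned to each data point. A sketch of what such a KL term could look like, assuming diagonal covariances (my illustration, not necessarily the exact code in vae_module.py):

```python
import torch

# KL( N(mu, exp(logvar)) || N(prior_mu, prior_var) ) with diagonal
# covariances; prior_mu / prior_var are the parameters of the GMM
# component assigned to each data point. Illustrative sketch only.
def kl_to_gmm_prior(mu, logvar, prior_mu, prior_var):
    var = logvar.exp()
    kl = 0.5 * (prior_var.log() - logvar
                + (var + (mu - prior_mu) ** 2) / prior_var
                - 1.0)
    return kl.sum(dim=1)  # sum over the latent dimensions
```

With prior_mu = 0 and prior_var = 1 this reduces to the standard VAE KL term.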
You need PyTorch >= 0.4.1 and CUDA drivers installed.
My environment: PyTorch==1.5.1+cu101, CUDA==10.1.
Left: without mutual learning / Right: with mutual learning
Plotted using t-SNE
The red line is the ELBO before mutual learning; the blue line is the ELBO after mutual learning.
The vertical axis is the training iteration of the VAE; the horizontal axis is the ELBO of the VAE.
(In general, the higher the ELBO, the better.)
Clustering performance measured by accuracy (this evaluates the clustering performance of the GMM within VAE+GMM).
Left: without mutual learning / Right: with mutual learning
The vertical axis is the training iteration of the GMM; the horizontal axis is accuracy.
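Because GMM cluster indices are arbitrary, clustering accuracy is usually computed after matching clusters to the ground-truth labels. One common way to do this (an assumption about the metric, not necessarily what this repo uses) is Hungarian matching:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Best-match accuracy between predicted cluster ids and true labels,
# using Hungarian matching over the confusion matrix. Illustrative only.
def clustering_accuracy(y_true, y_pred):
    k = int(max(y_true.max(), y_pred.max())) + 1
    count = np.zeros((k, k), dtype=np.int64)
    for t, p in zip(y_true, y_pred):
        count[p, t] += 1
    row, col = linear_sum_assignment(-count)  # maximize matched counts
    return count[row, col].sum() / len(y_true)
```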
The GMM clusters the latent variables of the VAE. By sampling random variables from the posterior distribution estimated by the GMM and feeding them to the VAE decoder, images can be reconstructed.
"x" marks the mean parameter of the normal distribution for each cluster.
In this example, random variables are sampled from the Gaussian distribution of cluster K=1.
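A sketch of this decoding step, assuming the saved GMM parameters and the VAE decoder are already loaded (the function and argument names here are illustrative, not the actual signature of decode_from_gmm_param()):

```python
import torch
from torch.distributions import MultivariateNormal

# mu_k, cov_k: mean and covariance of the chosen GMM cluster (e.g. K=1),
# loaded from the saved parameters of a given mutual-learning iteration.
def decode_from_cluster(decoder, mu_k, cov_k, sample_num=16):
    dist = MultivariateNormal(mu_k, covariance_matrix=cov_k)
    x = dist.sample((sample_num,))  # sample latent variables from N(mu_k, cov_k)
    with torch.no_grad():
        images = decoder(x)         # run the VAE decoder on the samples
    return images
```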
The implementation of the GMM is based on the Japanese article 【Python】4.4.2:ガウス混合モデルにおける推論:ギブスサンプリング【緑ベイズ入門のノート】 (roughly: "[Python] 4.4.2: Inference in a Gaussian mixture model: Gibbs sampling [notes on the 'green' introductory Bayesian book]").