[ICCV 2021] A Pytorch implementation of "Manifold Matching via Deep Metric Learning for Generative Modeling"

dzld00/pytorch-manifold-matching

Manifold Matching via Deep Metric Learning for Generative Modeling

A PyTorch implementation of "Manifold Matching via Deep Metric Learning for Generative Modeling" (ICCV 2021).

Paper: https://arxiv.org/abs/2106.10777

Objective functions

Objective for metric learning:

triplet_loss = triplet_(ml_real_out, ml_real_out_shuffle, ml_fake_out_shuffle)

Objective for manifold matching with learned metric:

g_loss = p_dist + c_dist 

where

ml_real_out = netML(real_img)  # real data
ml_fake_out = netML(fake_img)  # generated data

# shuffle in batch
r1 = torch.randperm(batch_size)
r2 = torch.randperm(batch_size)
ml_real_out_shuffle = ml_real_out[r1[:, None]].view(ml_real_out.shape[0], ml_real_out.shape[-1])
ml_fake_out_shuffle = ml_fake_out[r2[:, None]].view(ml_fake_out.shape[0], ml_fake_out.shape[-1])

# pairwise distances
pd_r = pairwise_distances(ml_real_out, ml_real_out)
pd_f = pairwise_distances(ml_fake_out, ml_fake_out)

# matching terms
p_dist = torch.dist(pd_r, pd_f, 2)  # matching 2-diameters
c_dist = torch.dist(ml_real_out.mean(0), ml_fake_out.mean(0), 2)  # matching centroids
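The snippet above relies on `netML`, `triplet_`, and `pairwise_distances`, which are defined elsewhere in the repository. As a rough, self-contained illustration of how the terms fit together, the sketch below swaps in assumptions: random tensors stand in for the `netML` outputs, `pairwise_distances` is assumed to be a plain Euclidean distance matrix (via `torch.cdist`), and `triplet_` is assumed to behave like `torch.nn.TripletMarginLoss`. The repository's actual definitions may differ.

```python
import torch
import torch.nn as nn


def pairwise_distances(x, y):
    # Assumed helper: Euclidean distance between every row of x and every row of y.
    return torch.cdist(x, y, p=2)


torch.manual_seed(0)
batch_size, embed_dim = 8, 4

# Stand-ins for netML(real_img) and netML(fake_img): 2-D embeddings (batch, embed_dim).
ml_real_out = torch.randn(batch_size, embed_dim)
ml_fake_out = torch.randn(batch_size, embed_dim) + 1.0

# Shuffle within the batch. For 2-D embeddings, plain row indexing gives the
# same result as the index-then-view form used in the README snippet.
r1 = torch.randperm(batch_size)
r2 = torch.randperm(batch_size)
ml_real_out_shuffle = ml_real_out[r1]
ml_fake_out_shuffle = ml_fake_out[r2]

# Metric-learning objective (assumption: a standard triplet margin loss with
# real embeddings as anchor and positive, generated embeddings as negative).
triplet_ = nn.TripletMarginLoss(margin=1.0, p=2)
triplet_loss = triplet_(ml_real_out, ml_real_out_shuffle, ml_fake_out_shuffle)

# Manifold-matching objective under the current metric.
pd_r = pairwise_distances(ml_real_out, ml_real_out)
pd_f = pairwise_distances(ml_fake_out, ml_fake_out)
p_dist = torch.dist(pd_r, pd_f, 2)                                # distance-matrix term
c_dist = torch.dist(ml_real_out.mean(0), ml_fake_out.mean(0), 2)  # centroid term
g_loss = p_dist + c_dist

print(f"triplet_loss={triplet_loss.item():.4f}  g_loss={g_loss.item():.4f}")
```

Both losses are scalars, so each can be passed directly to `backward()` once real network outputs replace the random stand-ins.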

Training

To train a model for unconditional generation, run:

python train.py
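The contents of `train.py` are not reproduced in this README. As a hedged illustration of how the two objectives could be alternated, here is a toy sketch: the small `netG` and `netML` networks, the Adam settings, the margin, and the `train_step` function are all hypothetical and chosen only so the example runs on 2-D points. It should not be read as the repository's actual training loop.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, noise_dim, data_dim, embed_dim = 16, 4, 2, 8

# Hypothetical toy networks standing in for the generator and the metric network.
netG = nn.Sequential(nn.Linear(noise_dim, 16), nn.ReLU(), nn.Linear(16, data_dim))
netML = nn.Sequential(nn.Linear(data_dim, 16), nn.ReLU(), nn.Linear(16, embed_dim))
opt_g = torch.optim.Adam(netG.parameters(), lr=1e-3)
opt_ml = torch.optim.Adam(netML.parameters(), lr=1e-3)
triplet_ = nn.TripletMarginLoss(margin=1.0, p=2)  # assumed form of the triplet loss


def train_step(real_img):
    # 1) Metric-learning update: the generator output is detached so that
    #    only netML receives gradients from the triplet loss.
    fake_img = netG(torch.randn(batch_size, noise_dim)).detach()
    ml_real_out = netML(real_img)
    ml_fake_out = netML(fake_img)
    r1 = torch.randperm(batch_size)
    r2 = torch.randperm(batch_size)
    triplet_loss = triplet_(ml_real_out, ml_real_out[r1], ml_fake_out[r2])
    opt_ml.zero_grad()
    triplet_loss.backward()
    opt_ml.step()

    # 2) Generator update: match distance matrices and centroids in the
    #    embedding space; only opt_g takes a step here.
    fake_img = netG(torch.randn(batch_size, noise_dim))
    ml_real_out = netML(real_img).detach()
    ml_fake_out = netML(fake_img)
    pd_r = torch.cdist(ml_real_out, ml_real_out)
    pd_f = torch.cdist(ml_fake_out, ml_fake_out)
    p_dist = torch.dist(pd_r, pd_f, 2)
    c_dist = torch.dist(ml_real_out.mean(0), ml_fake_out.mean(0), 2)
    g_loss = p_dist + c_dist
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
    return triplet_loss.item(), g_loss.item()


# Toy "real" data: a shifted Gaussian blob in 2-D.
real_batch = torch.randn(batch_size, data_dim) + 3.0
t_loss, g_loss_value = train_step(real_batch)
print(f"triplet_loss={t_loss:.4f}  g_loss={g_loss_value:.4f}")
```

Gradients that the generator step leaves on `netML` are cleared by `opt_ml.zero_grad()` at the start of the next metric-learning update, so they do not leak into it.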

       

We also tried our objective on generating higher-resolution images using a StyleGAN2 data generator and a simple metric generator. Implementation details can be found here. Below are randomly generated 512x512 samples on the FFHQ dataset at ~150K iterations:

Citation

@InProceedings{Dai_2021_ICCV,
    author    = {Dai, Mengyu and Hang, Haibin},
    title     = {Manifold Matching via Deep Metric Learning for Generative Modeling},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {6587-6597}
}
