Commit
* renamed ConvPLCA to SoftConvPLCA to clarify that it uses softmax.
* added ProjConvPLCA, which simply projects its parameters onto the simplex after every training step (see the sketch below).
* added a PLCA base class that all PLCA models now inherit from.
huangeddie committed Dec 12, 2020
1 parent ddce01f commit 934907d
Showing 4 changed files with 84 additions and 7 deletions.
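The gist of the change: SoftConvPLCA keeps its distributions on the probability simplex by reparameterizing through a softmax, while the new ProjConvPLCA stores the simplex-valued tensors directly and projects them back onto the simplex after every optimizer step. A minimal sketch of the two patterns, assuming the repository root is on the path; the 8-dimensional tensors, the SGD optimizer, and the toy loss are placeholders, and project_simplex_sort is the helper added to models/plca.py below:

import torch
from torch import nn
from torch.nn import functional as F
from models.plca import project_simplex_sort  # added in this commit

# Softmax parameterization: the stored parameter is unconstrained; the
# simplex-valued quantity is recomputed on every forward pass.
logits = nn.Parameter(torch.randn(8))
probs = F.softmax(logits, dim=0)  # nonnegative, sums to 1

# Projection parameterization: the stored parameter itself is the distribution;
# after each optimizer step it is pushed back onto the simplex.
raw = nn.Parameter(torch.rand(8))
opt = torch.optim.SGD([raw], lr=0.1)
loss = (raw * torch.randn(8)).sum()
loss.backward()
opt.step()
with torch.no_grad():
    raw.copy_(project_simplex_sort(raw.detach()))  # nonnegative, sums to 1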
2 changes: 1 addition & 1 deletion main.py
@@ -17,7 +17,7 @@
help='makes the images probability distributions')

# Model
parser.add_argument('--model', choices=['conv-plca', 'deep-plca', 'ae', 'al'],
parser.add_argument('--model', choices=['proj-conv-plca', 'soft-conv-plca', 'deep-plca', 'ae', 'al'],
help='conv-plca, deep plca, auto encoder, auto layer')
parser.add_argument('--nconvs', type=int, default=None,
help='number of convolutions to use per impulse and prior (only for deep-plca)')
6 changes: 4 additions & 2 deletions models/__init__.py
@@ -7,8 +7,10 @@


def make_model(args, channels):
if args.model == 'conv-plca':
model = plca.ConvPLCA(channels, args.imsize, args.nkern, args.kern_size)
if args.model == 'proj-conv-plca':
model = plca.ProjConvPLCA(channels, args.nkern, args.kern_size)
elif args.model == 'soft-conv-plca':
model = plca.SoftConvPLCA(channels, args.imsize, args.nkern, args.kern_size)
elif args.model == 'deep-plca':
model = plca.DeepPLCA(channels, args.imsize, args.nkern, args.kern_size, args.nconvs, args.hdim)
elif args.model == 'ae':
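For reference, a sketch of how the dispatcher is exercised with the new choice; the Namespace fields are illustrative, and it assumes make_model returns the constructed model:

from argparse import Namespace
from models import make_model

args = Namespace(model='proj-conv-plca', nkern=8, kern_size=5)
model = make_model(args, channels=1)  # builds plca.ProjConvPLCA(1, 8, 5)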
77 changes: 75 additions & 2 deletions models/plca.py
@@ -3,7 +3,11 @@
from torch.nn import functional as F


class DeepPLCA(nn.Module):
class PLCA(nn.Module):
pass


class DeepPLCA(PLCA):
"""
Priors and impulse are deep CNN functions of the image, while the features are global parameters
"""
@@ -88,7 +92,76 @@ def forward(self, imgs):
return recon, priors, impulse, feats


class ConvPLCA(nn.Module):
def project_simplex_sort(v, z=1):
    """
    Projects a vector onto the probability simplex (nonnegative entries summing to z).
    Takes O(d log d) time where d is the dimension of the vector.
    https://eng.ucmerced.edu/people/wwang5/papers/SimplexProj.pdf
    """
    n_features = v.shape[0]
    u = torch.sort(v, descending=True).values
    cssv = torch.cumsum(u, dim=0) - z
    ind = torch.arange(n_features, device=v.device) + 1
    cond = u - cssv / ind > 0
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / float(rho)
    w = torch.maximum(v - theta, torch.zeros_like(v))
    return w
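A quick sanity check for the projection (a sketch; the example vectors are illustrative): the result should be nonnegative and sum to z, and a vector that is already on the simplex should come back unchanged.

v = torch.tensor([0.5, 2.0, -1.0])
w = project_simplex_sort(v)
assert torch.all(w >= 0)
assert torch.isclose(w.sum(), torch.tensor(1.0))

p = torch.tensor([0.2, 0.3, 0.5])
assert torch.allclose(project_simplex_sort(p), p)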


class ProjConvPLCA(PLCA):
"""
Let params be the core nkern x channels x kern_size x kern_size parameters that influences everything
The impulse convolutional kernels are generated from a learnable per-kernel affine transformation from params
The feature logits are generated from a learnable per-kernel linear transformation from params
(it’s linear and not affine because the feature logits are then immediately fed into the soft max activation which is shift invariant
The priors are global
"""

    def __init__(self, channels, nkern, kern_size):
        super().__init__()
        self.nkern = nkern

        # Features (one distribution per kernel over channels x kern_size x kern_size)
        self.feats = nn.Parameter(torch.rand(nkern, channels, kern_size, kern_size))

        # Priors (a single distribution over the nkern kernels)
        self.priors = nn.Parameter(torch.rand(1, nkern, 1, 1))

    @torch.no_grad()  # no_grad so the in-place copy_ on the parameters is allowed
    def project_params_to_simplex(self):
        """
        Simple algorithm: https://eng.ucmerced.edu/people/wwang5/papers/SimplexProj.pdf
        Takes O(d log d) time where d is the number of dimensions.
        """
        # Priors form a single distribution over the nkern kernels
        simplex_priors = project_simplex_sort(self.priors.detach().flatten()).view_as(self.priors)
        self.priors.copy_(simplex_priors)

        # Each feature kernel is its own distribution over channels x kern_size x kern_size
        simplex_feats = self.feats.detach().clone()
        kern_shape = self.feats.shape[1:]
        for i in range(self.nkern):
            simplex_feats[i] = project_simplex_sort(simplex_feats[i].flatten()).view(kern_shape)
        self.feats.copy_(simplex_feats)

    def forward(self, imgs):
        # Impulse
        impulse_logits = F.conv2d(imgs, self.feats)
        impulse = impulse_logits / torch.sum(impulse_logits, dim=(2, 3), keepdim=True)

        # Convolutional transpose
        recon = F.conv_transpose2d(self.priors * impulse, self.feats)

        # For some reason, when run on CUDA, there can be negative values
        # recon.clamp_(min=0)

        return recon, self.priors.detach(), impulse, self.feats.detach()
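For a sense of the shapes ProjConvPLCA produces, a sketch with hypothetical sizes (batch of 4, 1 input channel, 28x28 images, 8 kernels of size 5): the impulse is a per-kernel map over the 24x24 valid positions, normalized over the spatial dimensions, and the transposed convolution brings the reconstruction back to 28x28.

model = ProjConvPLCA(channels=1, nkern=8, kern_size=5)
imgs = torch.rand(4, 1, 28, 28)
recon, priors, impulse, feats = model(imgs)
assert recon.shape == (4, 1, 28, 28)
assert impulse.shape == (4, 8, 24, 24)
assert feats.shape == (8, 1, 5, 5)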


class SoftConvPLCA(PLCA):
"""
Let params be the core nkern x channels x kern_size x kern_size parameters that influence everything
The impulse convolutional kernels are generated from a learnable per-kernel affine transformation from params
6 changes: 4 additions & 2 deletions train.py
@@ -73,7 +73,7 @@ def get_recon_loss(args, recon, imgs):
#### Steps

def plca_step(args, model, imgs):
# Deep PLCA
# PLCA
recon, priors, impulse, feat = model(imgs)

# Entropy loss
@@ -128,7 +128,7 @@ def loop_data(args, model, data_loader, opt=None):
opt.zero_grad()

# Train steps defined separately for each model
if isinstance(model, models.plca.ConvPLCA) or isinstance(model, models.plca.DeepPLCA):
if isinstance(model, models.plca.PLCA):
loss, recon_loss = plca_step(args, model, imgs)
elif isinstance(model, models.auto.AutoEncoder):
loss, recon_loss = ae_step(args, model, imgs)
@@ -141,6 +141,8 @@ def loop_data(args, model, data_loader, opt=None):
if training:
loss.backward()
opt.step()
if isinstance(model, models.plca.ProjConvPLCA):
model.project_params_to_simplex()

# Record
recon_losses.append(recon_loss.item())
