Commit

update
goodman1204 committed Dec 10, 2020
1 parent b2f33ad commit 1c1b484
Showing 31 changed files with 4,049 additions and 91 deletions.
Binary file added data/ind.citeseer.allx
Binary file added data/ind.citeseer.ally
Binary file added data/ind.citeseer.graph
1,000 changes: 1,000 additions & 0 deletions data/ind.citeseer.test.index (large diff not shown)
Binary file added data/ind.citeseer.tx
Binary file added data/ind.citeseer.ty
Binary file added data/ind.citeseer.x
Binary file added data/ind.citeseer.y
Binary file added data/ind.cora.allx
Binary file added data/ind.cora.ally
Binary file added data/ind.cora.graph
1,000 changes: 1,000 additions & 0 deletions data/ind.cora.test.index (large diff not shown)
Binary file added data/ind.cora.tx
Binary file added data/ind.cora.ty
Binary file added data/ind.cora.x
Binary file added data/ind.cora.y
Binary file added data/ind.pubmed.allx
Binary file added data/ind.pubmed.ally
Binary file added data/ind.pubmed.graph
1,000 changes: 1,000 additions & 0 deletions data/ind.pubmed.test.index (large diff not shown)
Binary file added data/ind.pubmed.tx
Binary file added data/ind.pubmed.ty
Binary file added data/ind.pubmed.x
Binary file added data/ind.pubmed.y
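These ind.* files are the standard Planetoid splits for Citeseer, Cora and Pubmed: x/tx/allx are pickled (sparse) feature matrices for the labeled training, test and full training nodes, y/ty/ally the corresponding one-hot labels, graph a node-to-neighbours adjacency dict, and test.index the ids of the test nodes. A minimal loading sketch follows, assuming the usual pickle layout used by Planetoid/GCN codebases (the helper name is illustrative, not part of this commit):

import pickle as pkl
import numpy as np

def load_planetoid(dataset, data_dir='data'):
    # Load the pickled Planetoid objects for one dataset, e.g. 'cora'.
    names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
    objects = []
    for name in names:
        with open('{}/ind.{}.{}'.format(data_dir, dataset, name), 'rb') as f:
            objects.append(pkl.load(f, encoding='latin1'))
    x, y, tx, ty, allx, ally, graph = objects
    # test.index is a plain text file with one node id per line.
    test_idx = np.loadtxt('{}/ind.{}.test.index'.format(data_dir, dataset), dtype=int)
    return x, y, tx, ty, allx, ally, graph, test_idx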
164 changes: 164 additions & 0 deletions estimators.py
@@ -0,0 +1,164 @@
import numpy as np
import torch
import torch.nn.functional as F


def logmeanexp_diag(x, device='cpu'):
"""Compute logmeanexp over the diagonal elements of x."""
batch_size = x.size(0)

logsumexp = torch.logsumexp(x.diag(), dim=(0,))
num_elem = batch_size

return logsumexp - torch.log(torch.tensor(num_elem).float()).to(device)


def logmeanexp_nodiag(x, dim=None, device='cpu'):
batch_size = x.size(0)
if dim is None:
dim = (0, 1)

logsumexp = torch.logsumexp(
x - torch.diag(np.inf * torch.ones(batch_size).to(device)), dim=dim)

    try:
        if len(dim) == 1:
            num_elem = batch_size - 1.
        else:
            num_elem = batch_size * (batch_size - 1.)
    except TypeError:  # dim was passed as a single int rather than a tuple
        num_elem = batch_size - 1
return logsumexp - torch.log(torch.tensor(num_elem)).to(device)


def tuba_lower_bound(scores, log_baseline=None):
if log_baseline is not None:
scores -= log_baseline[:, None]

    # First term is an expectation over samples from the joint,
    # which are the diagonal elements of the scores matrix.
joint_term = scores.diag().mean()

# Second term is an expectation over samples from the marginal,
# which are the off-diagonal elements of the scores matrix.
marg_term = logmeanexp_nodiag(scores).exp()
return 1. + joint_term - marg_term


def nwj_lower_bound(scores):
return tuba_lower_bound(scores - 1.)


def infonce_lower_bound(scores):
nll = scores.diag().mean() - scores.logsumexp(dim=1)
# Alternative implementation:
# nll = -tf.nn.sparse_softmax_cross_entropy_with_logits(logits=scores, labels=tf.range(batch_size))
mi = torch.tensor(scores.size(0)).float().log() + nll
mi = mi.mean()
return mi


def js_fgan_lower_bound(f):
"""Lower bound on Jensen-Shannon divergence from Nowozin et al. (2016)."""
f_diag = f.diag()
first_term = -F.softplus(-f_diag).mean()
n = f.size(0)
second_term = (torch.sum(F.softplus(f)) -
torch.sum(F.softplus(f_diag))) / (n * (n - 1.))
return first_term - second_term


def js_lower_bound(f):
"""Obtain density ratio from JS lower bound then output MI estimate from NWJ bound."""
nwj = nwj_lower_bound(f)
js = js_fgan_lower_bound(f)

with torch.no_grad():
nwj_js = nwj - js

return js + nwj_js


def dv_upper_lower_bound(f):
"""
Donsker-Varadhan lower bound, but upper bounded by using log outside.
Similar to MINE, but did not involve the term for moving averages.
"""
first_term = f.diag().mean()
second_term = logmeanexp_nodiag(f)

return first_term - second_term


def mine_lower_bound(f, buffer=None, momentum=0.9):
"""
MINE lower bound based on DV inequality.
"""
if buffer is None:
buffer = torch.tensor(1.0)
first_term = f.diag().mean()

buffer_update = logmeanexp_nodiag(f).exp()
with torch.no_grad():
second_term = logmeanexp_nodiag(f)
buffer_new = buffer * momentum + buffer_update * (1 - momentum)
buffer_new = torch.clamp(buffer_new, min=1e-4)
third_term_no_grad = buffer_update / buffer_new

third_term_grad = buffer_update / buffer_new

return first_term - second_term - third_term_grad + third_term_no_grad, buffer_update


def smile_lower_bound(f, clip=None):
if clip is not None:
f_ = torch.clamp(f, -clip, clip)
else:
f_ = f
z = logmeanexp_nodiag(f_, dim=(0, 1))
dv = f.diag().mean() - z

js = js_fgan_lower_bound(f)

with torch.no_grad():
dv_js = dv - js

return js + dv_js


def estimate_mutual_information(estimator, x, y, critic_fn,
baseline_fn=None, alpha_logit=None, **kwargs):
"""Estimate variational lower bounds on mutual information.
Args:
      estimator: string specifying estimator, one of:
        'nwj', 'infonce', 'tuba', 'js', 'smile', 'dv'
      x: [batch_size, dim_x] Tensor
      y: [batch_size, dim_y] Tensor
      critic_fn: callable that takes x and y as input and outputs critic scores;
        output shape is a [batch_size, batch_size] matrix
      baseline_fn (optional): callable that takes y as input and
        outputs a [batch_size] or [batch_size, 1] vector
      alpha_logit (optional): logit(alpha) for the interpolated bound (not used by
        the estimators implemented here)
Returns:
scalar estimate of mutual information
"""
    scores = critic_fn(x, y)
    log_baseline = None
    if baseline_fn is not None:
        # Some baselines' output is (batch_size, 1), which we squeeze away here.
        log_baseline = torch.squeeze(baseline_fn(y))
if estimator == 'infonce':
mi = infonce_lower_bound(scores)
elif estimator == 'nwj':
mi = nwj_lower_bound(scores)
elif estimator == 'tuba':
mi = tuba_lower_bound(scores, log_baseline)
elif estimator == 'js':
mi = js_lower_bound(scores)
    elif estimator == 'smile':
        mi = smile_lower_bound(scores, **kwargs)
    elif estimator == 'dv':
        mi = dv_upper_lower_bound(scores)
    else:
        raise ValueError('Unknown estimator: {}'.format(estimator))
    return mi
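As a quick check of how these estimators are wired together, here is a minimal usage sketch (not part of the commit). The ToyConcatCritic below is a hypothetical stand-in for the ConcatCritic referenced in model.py; the only assumption the estimators make is that the critic returns a [batch_size, batch_size] score matrix whose diagonal corresponds to paired (joint) samples. With an untrained critic the returned values are not meaningful MI estimates; the sketch only exercises the code paths.

import torch
import torch.nn as nn
from estimators import estimate_mutual_information

class ToyConcatCritic(nn.Module):
    # Hypothetical critic: scores every (x_i, y_j) pair and returns an [n, n] matrix.
    def __init__(self, dim_x, dim_y, hidden=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim_x + dim_y, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def forward(self, x, y):
        n = x.size(0)
        x_tiled = x.unsqueeze(1).expand(-1, n, -1)     # [n, n, dim_x]
        y_tiled = y.unsqueeze(0).expand(n, -1, -1)     # [n, n, dim_y]
        pairs = torch.cat([x_tiled, y_tiled], dim=-1)  # [n, n, dim_x + dim_y]
        return self.net(pairs).squeeze(-1)             # scores[i, j] = f(x_i, y_j)

x = torch.randn(50, 16)
y = x + 0.1 * torch.randn(50, 16)  # correlated toy data
critic = ToyConcatCritic(16, 16)
mi_js = estimate_mutual_information('js', x, y, critic)
mi_dv = estimate_mutual_information('dv', x, y, critic)
print(mi_js.item(), mi_dv.item())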
10 changes: 7 additions & 3 deletions layers.py
@@ -79,13 +79,16 @@ def __init__(self,in_features,out_features,dropout=0.,act=torch.relu,bias=True,s
self.sparse_inputs = sparse_inputs

self.weight= Parameter(torch.FloatTensor(in_features,out_features))
self.reset_parameters()

if self.bias:
self.weight_bias = Parameter(torch.FloatTensor(torch.zeros(out_features)))
self.weight_bias = Parameter(torch.FloatTensor(1,out_features))

self.reset_parameters()

def reset_parameters(self):
torch.nn.init.xavier_uniform_(self.weight)
if self.bias:
torch.nn.init.xavier_uniform_(self.weight_bias)

def forward(self,input):
if self.sparse_inputs:
@@ -94,7 +97,8 @@ def forward(self,input):
output = torch.mm(input,self.weight)

if self.bias:
output += self.bias
            output += self.weight_bias  # bug fix: self.bias is only a boolean flag; the learnable bias is self.weight_bias
            # output += self.bias  # old line: this added the boolean flag instead of the bias parameter

return self.act(output)

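Putting the two hunks above together, the corrected layer reads roughly as follows (a sketch assuming the constructor signature shown in the hunk header; dropout and sparse-input handling are omitted). The point of the fix is that self.bias is only a boolean flag, while the learnable bias is the self.weight_bias Parameter:

import torch
from torch.nn import Parameter

class LinearSketch(torch.nn.Module):
    # Minimal sketch of the corrected bias handling; not the full layers.Linear class.
    def __init__(self, in_features, out_features, bias=True, act=torch.relu):
        super().__init__()
        self.bias = bias  # boolean flag only
        self.act = act
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if self.bias:
            self.weight_bias = Parameter(torch.FloatTensor(1, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias:
            torch.nn.init.xavier_uniform_(self.weight_bias)

    def forward(self, input):
        output = torch.mm(input, self.weight)
        if self.bias:
            output = output + self.weight_bias  # add the Parameter, not the boolean flag
        return self.act(output)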
89 changes: 59 additions & 30 deletions model.py
@@ -11,6 +11,9 @@
from layers import GraphConvolution, GraphConvolutionSparse, Linear, InnerDecoder, InnerProductDecoder
from utils import cluster_acc

from utils_smiles import *
from estimators import estimate_mutual_information

class GCNModelAE(nn.Module):
def __init__(self, input_feat_dim, n_nodes, hidden_dim1, hidden_dim2, dropout,args):
super(GCNModelAE, self).__init__()
@@ -191,32 +194,33 @@ def loss(self,x,adj,labels, n_nodes, n_features, norm, pos_weight,L=1):
# Loss=L_rec*x.size(1)


self.pi_.data = (self.pi_/self.pi_.sum()).data
# self.pi_.data = (self.pi_/self.pi_.sum()).data
# log_sigma2_c=self.log_sigma2_c
# mu_c=self.mu_c

# z = torch.randn_like(z_mu) * torch.exp(z_sigma2_log / 2) + z_mu
z = self.reparameterize(mu,logvar)

yita_c=torch.exp(torch.log(self.pi_.unsqueeze(0))+self.gaussian_pdfs_log(z,self.mu_c,self.log_sigma2_c))+det
# yita_c = F.softmax(yita_c) # is softmax a good way?
gamma_c=torch.exp(torch.log(self.pi_.unsqueeze(0))+self.gaussian_pdfs_log(z,self.mu_c,self.log_sigma2_c))+det
# gamma_c = F.softmax(gamma_c) # is softmax a good way?

yita_c=yita_c/(yita_c.sum(1).view(-1,1)) #shape: batch_size*Clusters
gamma_c=gamma_c/(gamma_c.sum(1).view(-1,1)) #shape: batch_size*Clusters
        self.pi_.data = gamma_c.mean(0).data  # does the prior need to be re-normalized? In a GMM the prior is updated from gamma_c: https://brilliant.org/wiki/gaussian-mixture-model/

# KLD_u_c=(0.5 / n_nodes)*torch.mean(torch.sum(yita_c*torch.sum(self.log_sigma2_c.unsqueeze(0)+\
# KLD_u_c=(0.5 / n_nodes)*torch.mean(torch.sum(gamma_c*torch.sum(self.log_sigma2_c.unsqueeze(0)+\
# torch.exp(2*logvar.unsqueeze(1)-self.log_sigma2_c.unsqueeze(0))+\
# (mu.unsqueeze(1)-self.mu_c.unsqueeze(0)).pow(2)/torch.exp(self.log_sigma2_c.unsqueeze(0)),2),1))

# KLD_u_c-= (0.5/n_nodes)*torch.mean(torch.sum(1+2*logvar,1))
# yita_loss = (1 / self.args.nClusters) * torch.mean(torch.sum(yita_c*torch.log(yita_c/self.pi_.unsqueeze(0)),1)) - (0.5 / hidden_dim2)*torch.mean(torch.sum(1+2*logvar,1))
# gamma_loss = (1 / self.args.nClusters) * torch.mean(torch.sum(gamma_c*torch.log(gamma_c/self.pi_.unsqueeze(0)),1)) - (0.5 / hidden_dim2)*torch.mean(torch.sum(1+2*logvar,1))

KLD_u_c=-(0.5/n_nodes)*torch.mean(torch.sum(yita_c*torch.sum(-1+self.log_sigma2_c.unsqueeze(0)-2*logvar.unsqueeze(1)+
KLD_u_c=-(0.5/n_nodes)*torch.mean(torch.sum(gamma_c*torch.sum(-1+self.log_sigma2_c.unsqueeze(0)-2*logvar.unsqueeze(1)+
torch.exp(2*logvar.unsqueeze(1)-self.log_sigma2_c.unsqueeze(0))+
(mu.unsqueeze(1)-self.mu_c.unsqueeze(0)).pow(2)/torch.exp(self.log_sigma2_c.unsqueeze(0)),2),1))

yita_loss = -(1 / self.args.nClusters) * torch.mean(torch.sum(yita_c*torch.log(yita_c/self.pi_.unsqueeze(0)),1))
gamma_loss = -(1 / self.args.nClusters) * torch.mean(torch.sum(gamma_c*torch.log(gamma_c/self.pi_.unsqueeze(0)),1))

return L_rec_u,-KLD_u_c,-yita_loss
return L_rec_u,-KLD_u_c,-gamma_loss

def pre_train(self,x,adj,Y,pre_epoch=50):
'''
@@ -285,11 +289,11 @@ def predict(self,mu, logvar):
pi = self.pi_
log_sigma2_c = self.log_sigma2_c
mu_c = self.mu_c
yita_c = torch.exp(torch.log(pi.unsqueeze(0))+self.gaussian_pdfs_log(z,mu_c,log_sigma2_c))
gamma_c = torch.exp(torch.log(pi.unsqueeze(0))+self.gaussian_pdfs_log(z,mu_c,log_sigma2_c))

yita=yita_c.detach().cpu().numpy()
gamma=gamma_c.detach().cpu().numpy()

return np.argmax(yita,axis=1)
return np.argmax(gamma,axis=1),gamma


def gaussian_pdfs_log(self,x,mus,log_sigma2s):
@@ -312,6 +316,7 @@ class GCNModelVAECE(nn.Module):
def __init__(self, input_feat_dim, n_nodes, hidden_dim1, hidden_dim2, dropout,args):
super(GCNModelVAECE, self).__init__()


self.args = args
self.gc1 = GraphConvolutionSparse(input_feat_dim, hidden_dim1, dropout, act=torch.relu)
self.gc2 = GraphConvolution(hidden_dim1, hidden_dim2, dropout, act=lambda x: x)
@@ -326,8 +331,14 @@ def __init__(self, input_feat_dim, n_nodes, hidden_dim1, hidden_dim2, dropout,ar


self.pi_=nn.Parameter(torch.FloatTensor(args.nClusters,).fill_(1)/args.nClusters,requires_grad=True)
self.mu_c=nn.Parameter(torch.randn(args.nClusters,hidden_dim2),requires_grad=True)
self.log_sigma2_c=nn.Parameter(torch.randn(args.nClusters,hidden_dim2),requires_grad=True)
self.mu_c=nn.Parameter(torch.FloatTensor(args.nClusters,hidden_dim2).fill_(0),requires_grad=True)
self.log_sigma2_c=nn.Parameter(torch.FloatTensor(args.nClusters,hidden_dim2).fill_(0),requires_grad=True)

# calculate mi

# critic_params = {'dim_x': x.shape[1],'dim_y':y.shape[1],'layers': 2,'embed_dim': 32,'hidden_dim': 64,'activation': 'relu',}
# self.critic_structure = ConcatCritic(hidden_dim2,n_nodes,256,3,'relu',rho=None,)
# self.critic_feature = ConcatCritic(hidden_dim2,input_feat_dim,256,3,'relu',rho=None,)

def encoder(self, x, adj):
hidden1 = self.gc1(x, adj)
@@ -362,6 +373,14 @@ def dist(self,x):
dn = (norm + norm.view(1, -1)) - 2.0 * (x @ x.t())
return torch.sum(torch.relu(dn).sqrt())

def mi_loss(self,z,x,a):
# critic_params = {'dim_x': x.shape[1],'dim_y':y.shape[1],'layers': 2,'embed_dim': 32,'hidden_dim': 64,'activation': 'relu',}
# critic = ConcatCritic(rho=None,**critic_params)
indice = torch.randperm(len(z))[0:50]
# mi_x = estimate_mutual_information('dv',z[indice],x[indice],self.critic_structure)
mi_a = estimate_mutual_information('js',z[indice],a[indice],self.critic_feature)
return mi_a

def loss(self,x,adj,labels, n_nodes, n_features, norm, pos_weight,L=1):

det=1e-10
@@ -372,8 +391,14 @@ def loss(self,x,adj,labels, n_nodes, n_features, norm, pos_weight,L=1):
L_rec_u=0
L_rec_a=0

mi=0

mu, logvar, mu_a, logvar_a = self.encoder(x, adj)

# mutual information loss

# z_mu, z_sigma2_log = self.encoder(x)
# mi_a = self.mi_loss(mu,adj.to_dense(),x.to_dense())
for l in range(L):

# z=torch.randn_like(z_mu)*torch.exp(z_sigma2_log/2)+z_mu
@@ -387,6 +412,7 @@ def loss(self,x,adj,labels, n_nodes, n_features, norm, pos_weight,L=1):
L_rec_u += cost_u
L_rec_a += cost_a


L_rec_u/=L
L_rec_a/=L

@@ -398,36 +424,39 @@ def loss(self,x,adj,labels, n_nodes, n_features, norm, pos_weight,L=1):
# Loss=L_rec*x.size(1)


self.pi_.data = (self.pi_/self.pi_.sum()).data
# log_sigma2_c=self.log_sigma2_c
# mu_c=self.mu_c

# z = torch.randn_like(z_mu) * torch.exp(z_sigma2_log / 2) + z_mu
z = self.reparameterize(mu,logvar)

yita_c=torch.exp(torch.log(self.pi_.unsqueeze(0))+self.gaussian_pdfs_log(z,self.mu_c,self.log_sigma2_c))+det
gamma_c=torch.exp(torch.log(self.pi_.unsqueeze(0))+self.gaussian_pdfs_log(z,self.mu_c,self.log_sigma2_c))+det

gamma_c=gamma_c/(gamma_c.sum(1).view(-1,1))#batch_size*Clusters
# gamma_c=F.softmax(gamma_c)

yita_c=yita_c/(yita_c.sum(1).view(-1,1))#batch_size*Clusters
# yita_c=F.softmax(yita_c)
        # self.pi_.data = (self.pi_/self.pi_.sum()).data  # does the prior need to be re-normalized? In a GMM the prior is updated from gamma_c: https://brilliant.org/wiki/gaussian-mixture-model/
        self.pi_.data = gamma_c.mean(0).data  # does the prior need to be re-normalized? In a GMM the prior is updated from gamma_c: https://brilliant.org/wiki/gaussian-mixture-model/

KLD_u_c=-(0.5/n_nodes)*torch.mean(torch.sum(yita_c*torch.sum(-1+self.log_sigma2_c.unsqueeze(0)-2*logvar.unsqueeze(1)+
KLD_u_c=-(0.5/n_nodes)*torch.mean(torch.sum(gamma_c*torch.sum(-1+self.log_sigma2_c.unsqueeze(0)-2*logvar.unsqueeze(1)+
torch.exp(2*logvar.unsqueeze(1)-self.log_sigma2_c.unsqueeze(0))+
(mu.unsqueeze(1)-self.mu_c.unsqueeze(0)).pow(2)/torch.exp(self.log_sigma2_c.unsqueeze(0)),2),1))

# KLD_u_c=(0.5 / n_nodes)*torch.mean(torch.sum(yita_c*torch.sum(self.log_sigma2_c.unsqueeze(0)+\
# KLD_u_c=(0.5 / n_nodes)*torch.mean(torch.sum(gamma_c*torch.sum(self.log_sigma2_c.unsqueeze(0)+\
# torch.exp(2*logvar.unsqueeze(1)-self.log_sigma2_c.unsqueeze(0))+\
# (mu.unsqueeze(1)-self.mu_c.unsqueeze(0)).pow(2)/torch.exp(self.log_sigma2_c.unsqueeze(0)),2),1))

# mutual_dist = (-1/(self.args.nClusters**2))*self.dist(self.mu_c)

# yita_loss=-(1/self.args.nClusters)*torch.mean(torch.sum(yita_c*torch.log(yita_c),1))
# yita_loss = (1 / self.args.nClusters) * torch.mean(torch.sum(yita_c*torch.log(yita_c),1)) - (0.5 / self.args.hid_dim)*torch.mean(torch.sum(1+2*logvar,1))
yita_loss = -(1 / self.args.nClusters) * torch.mean(torch.sum(yita_c*torch.log(yita_c/self.pi_.unsqueeze(0)),1))
# yita_loss = (1 / self.args.nClusters) * torch.mean(torch.sum(yita_c*torch.log(yita_c/self.pi_.unsqueeze(0)),1)) - (0.5 / self.args.hid_dim)*torch.mean(torch.sum(1+2*logvar,1))
# gamma_loss=-(1/self.args.nClusters)*torch.mean(torch.sum(gamma_c*torch.log(gamma_c),1))
# gamma_loss = (1 / self.args.nClusters) * torch.mean(torch.sum(gamma_c*torch.log(gamma_c),1)) - (0.5 / self.args.hid_dim)*torch.mean(torch.sum(1+2*logvar,1))
gamma_loss = -(1 / self.args.nClusters) * torch.mean(torch.sum(gamma_c*torch.log(gamma_c/self.pi_.unsqueeze(0)),1))
# gamma_loss = (1 / self.args.nClusters) * torch.mean(torch.sum(gamma_c*torch.log(gamma_c/self.pi_.unsqueeze(0)),1)) - (0.5 / self.args.hid_dim)*torch.mean(torch.sum(1+2*logvar,1))


return L_rec_u , L_rec_a , -KLD_u_c ,-KLD_a , -yita_loss
# return L_rec_u + L_rec_a + KLD_u_c + KLD_a + yita_loss
return L_rec_u , L_rec_a , -KLD_u_c ,-KLD_a , -gamma_loss
# return L_rec_u , L_rec_a , -KLD_u_c ,-KLD_a , -gamma_loss,-mi_a
# return L_rec_u + L_rec_a + KLD_u_c + KLD_a + gamma_loss


def pre_train(self,x,adj,Y,pre_epoch=50):
@@ -498,10 +527,10 @@ def predict(self,mu, logvar):
pi = self.pi_
log_sigma2_c = self.log_sigma2_c
mu_c = self.mu_c
yita_c = torch.exp(torch.log(pi.unsqueeze(0))+self.gaussian_pdfs_log(z,mu_c,log_sigma2_c))
gamma_c = torch.exp(torch.log(pi.unsqueeze(0))+self.gaussian_pdfs_log(z,mu_c,log_sigma2_c))

yita=yita_c.detach().cpu().numpy()
return np.argmax(yita,axis=1)+1
gamma=gamma_c.detach().cpu().numpy()
return np.argmax(gamma,axis=1),gamma


def gaussian_pdfs_log(self,x,mus,log_sigma2s):
@@ -513,7 +542,7 @@ def gaussian_pdfs_log(self,x,mus,log_sigma2s):

@staticmethod
def gaussian_pdf_log(x,mu,log_sigma2):
return -0.5*(torch.sum(np.log(np.pi*2)+log_sigma2+(x-mu).pow(2)/torch.exp(log_sigma2),1))
        return -0.5*(torch.sum(np.log(np.pi*2)+log_sigma2+(x-mu).pow(2)/torch.exp(log_sigma2),1))  # note: the constant is log(2*pi), not pi squared

def check_parameters(self):
for name, param in self.named_parameters():
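Across model.py the cluster responsibilities previously named yita_c are renamed gamma_c, and the mixture prior self.pi_ is now refreshed from their batch mean instead of being re-normalized directly. For reference, here is a minimal standalone sketch (not part of the commit) of that computation, with shapes mirroring GCNModelVAECE (K clusters, latent dimension hidden_dim2); the helper names are illustrative:

import numpy as np
import torch

def gaussian_pdf_log(x, mu, log_sigma2):
    # Log-density of a diagonal Gaussian, summed over the feature dimension.
    return -0.5 * torch.sum(np.log(2 * np.pi) + log_sigma2 + (x - mu).pow(2) / torch.exp(log_sigma2), dim=1)

def responsibilities(z, pi, mu_c, log_sigma2_c, det=1e-10):
    # z: [n_nodes, hidden_dim2]; pi: [K]; mu_c, log_sigma2_c: [K, hidden_dim2].
    # gamma_c[i, k] is proportional to pi_k * N(z_i | mu_k, sigma_k^2); each row sums to 1.
    log_pdfs = torch.stack([gaussian_pdf_log(z, mu_c[k], log_sigma2_c[k])
                            for k in range(mu_c.size(0))], dim=1)
    gamma_c = torch.exp(torch.log(pi.unsqueeze(0)) + log_pdfs) + det
    return gamma_c / gamma_c.sum(1, keepdim=True)

# Hypothetical usage with random tensors:
z = torch.randn(100, 16)
pi = torch.full((3,), 1.0 / 3)
mu_c = torch.zeros(3, 16)
log_sigma2_c = torch.zeros(3, 16)
gamma = responsibilities(z, pi, mu_c, log_sigma2_c)
clusters = gamma.argmax(dim=1)  # what predict() returns
pi_updated = gamma.mean(0)      # how self.pi_ is refreshed in loss()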