-
Notifications
You must be signed in to change notification settings - Fork 4
/
ApeGNN_HT.py
233 lines (195 loc) · 10.4 KB
/
ApeGNN_HT.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import torch
import torch.nn as nn
from scipy.special import factorial, comb
import scipy.sparse as sp
class GraphConv(nn.Module):
    """Graph convolution whose hop outputs are scaled by per-node weights.

    Every user and every item owns one learnable scalar ``t``.  The result of
    propagating the embeddings over ``k`` hops is scaled row-wise by the
    heat-kernel coefficient ``exp(-t) * t**k / k!`` before it is returned; the
    unscaled propagation result is what gets fed into the next hop.
    """

    def __init__(self, n_hops, n_users, n_items, interact_mat, deg,
                 e=1e-7, edge_dropout_rate=0.5, mess_dropout_rate=0.1,
                 t_u=2, t_i=2, device=6):
        """Store the graph and initialise the per-node ``t`` parameters.

        Args:
            n_hops: Number of propagation steps performed in ``forward``.
            n_users: Number of user rows (first block of every stacked tensor).
            n_items: Number of item rows (second block).
            interact_mat: Sparse torch tensor that is multiplied with the
                stacked ``[users; items]`` embeddings in ``forward``.
            deg: NumPy array with ``n_users + n_items`` entries along dim 0.
                Assumed to have a trailing dimension of size 1 so that the
                weights broadcast over the embedding dimension -- TODO confirm
                against the caller.
            e: Small constant added to ``deg`` before taking the logarithm.
            edge_dropout_rate: Probability of dropping an edge in ``forward``.
            mess_dropout_rate: Dropout probability applied to the propagated
                embeddings between hops.
            t_u: Unused; kept only for backward compatibility.
            t_i: Unused; kept only for backward compatibility.
            device: Target handed to ``Tensor.to`` for the hop weights.
        """
        super().__init__()
        self.interact_mat = interact_mat
        self.n_users = n_users
        self.n_items = n_items
        self.n_hops = n_hops
        self.edge_dropout_rate = edge_dropout_rate
        self.mess_dropout_rate = mess_dropout_rate
        self.dropout = nn.Dropout(p=mess_dropout_rate)  # mess dropout
        self.dev = device

        # t = sigmoid(log(deg + e)).  Cast to float32 so that a float64 `deg`
        # does not silently promote every hop output to double precision.
        self.degree = torch.sigmoid(torch.log(torch.from_numpy(deg) + e)).float()
        user_init, item_init = torch.split(
            self.degree, [self.n_users, self.n_items], dim=0)
        # `torch.split` returns views; clone them so the two parameters do not
        # share storage with each other or with `self.degree`.
        self.user_t = nn.Parameter(user_init.clone(), requires_grad=True)
        self.item_t = nn.Parameter(item_init.clone(), requires_grad=True)

    def _sparse_dropout(self, x, rate=0.5):
        """Drop each stored entry of sparse ``x`` with probability ``rate``.

        Surviving values are scaled by ``1 / (1 - rate)`` so the expected
        value of every entry is unchanged.

        Args:
            x: Sparse COO tensor.
            rate: Drop probability, ``0 <= rate < 1``.

        Returns:
            A new sparse COO tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``rate`` is outside ``[0, 1)``.
        """
        if not 0.0 <= rate < 1.0:
            raise ValueError(
                "dropout rate must be in [0, 1), got {}".format(rate))
        # torch.rand samples from [0, 1), so an entry survives with
        # probability 1 - rate (and always survives when rate == 0).
        keep_mask = torch.rand(x._nnz(), device=x.device) >= rate
        kept_indices = x._indices()[:, keep_mask]
        kept_values = x._values()[keep_mask] * (1. / (1 - rate))
        return torch.sparse_coo_tensor(
            kept_indices, kept_values, x.shape, device=x.device)

    def _hop_weights(self, k):
        """Return the ``(user, item)`` weights ``exp(-t) * t**k / k!`` for hop ``k``."""
        k_factorial = float(factorial(k))
        user_weight = torch.exp(-self.user_t) * torch.pow(self.user_t, k) / k_factorial
        item_weight = torch.exp(-self.item_t) * torch.pow(self.item_t, k) / k_factorial
        return user_weight.to(self.dev), item_weight.to(self.dev)

    def forward(self, user_embed, item_embed, mess_dropout=True, edge_dropout=True):
        """Propagate the embeddings and return every hop's weighted output.

        Args:
            user_embed: User embeddings, one row per user.
            item_embed: Item embeddings, one row per item.
            mess_dropout: If true, apply ``self.dropout`` to the propagated
                embeddings before they enter the next hop.
            edge_dropout: If true, a freshly edge-dropped copy of
                ``interact_mat`` is sampled for every hop.

        Returns:
            Tuple ``(user_embs, item_embs)``; the hop outputs (hop 0 first)
            are stacked along dim 1.
        """
        agg_embed = torch.cat([user_embed, item_embed], dim=0)

        user_weight, item_weight = self._hop_weights(0)
        embs = [torch.cat([user_weight * user_embed,
                           item_weight * item_embed], dim=0)]

        for k in range(1, self.n_hops + 1):
            if edge_dropout:
                interact_mat = self._sparse_dropout(self.interact_mat,
                                                    self.edge_dropout_rate)
            else:
                interact_mat = self.interact_mat

            side_embeddings = torch.sparse.mm(interact_mat, agg_embed)
            user_side, item_side = torch.split(
                side_embeddings, [self.n_users, self.n_items], dim=0)

            user_weight, item_weight = self._hop_weights(k)
            embs.append(torch.cat([user_side * user_weight,
                                   item_side * item_weight], dim=0))

            # The next hop consumes the *unweighted* propagation result.
            agg_embed = self.dropout(side_embeddings) if mess_dropout else side_embeddings

        embs = torch.stack(embs, dim=1)  # [n_entity, n_hops+1, emb_size]
        return embs[:self.n_users, :], embs[self.n_users:, :]
class HeatKernel(nn.Module):
    """Recommender that trains user/item embeddings through ``GraphConv``.

    ``forward`` returns a BPR-style loss for a batch of (user, positive item,
    negative candidates); ``generate`` returns the pooled embeddings used for
    scoring with ``rating``.
    """

    def __init__(self, data_config, args_config, adj_mat, deg):
        """Read the configuration, create the embeddings and the GCN.

        Args:
            data_config: Mapping providing ``'n_users'`` and ``'n_items'``.
            args_config: Object exposing the hyper-parameters read below as
                attributes (``e``, ``l2``, ``dim``, ``context_hops``, ...).
            adj_mat: SciPy sparse matrix; converted to a torch sparse tensor
                and handed to ``GraphConv`` as its interaction matrix.
            deg: Forwarded unchanged to ``GraphConv``.
        """
        super().__init__()

        self.n_users = data_config['n_users']
        self.n_items = data_config['n_items']
        self.adj_mat = adj_mat
        self.deg = deg

        self.e = args_config.e
        self.decay = args_config.l2
        self.emb_size = args_config.dim
        self.context_hops = args_config.context_hops
        self.mess_dropout = args_config.mess_dropout
        self.mess_dropout_rate = args_config.mess_dropout_rate
        self.edge_dropout = args_config.edge_dropout
        self.edge_dropout_rate = args_config.edge_dropout_rate
        self.pool = args_config.pool
        self.n_negs = args_config.n_negs
        self.ns = args_config.ns
        self.K = args_config.K
        self.heatkernel = args_config.heatkernel
        self.device = torch.device("cuda:0") if args_config.cuda else torch.device("cpu")
        self.t_u = args_config.t_u
        self.t_i = args_config.t_i

        # _init_weight stores plain tensors; wrap them so they are trained.
        self._init_weight()
        self.user_embed = nn.Parameter(self.user_embed)
        self.item_embed = nn.Parameter(self.item_embed)

        self.gcn = self._init_model()

    def _init_weight(self):
        """Xavier-initialise both embedding tables and build the sparse adjacency."""
        initializer = nn.init.xavier_uniform_
        self.user_embed = initializer(torch.empty(self.n_users, self.emb_size))
        self.item_embed = initializer(torch.empty(self.n_items, self.emb_size))

        # [n_users+n_items, n_users+n_items]
        self.sparse_norm_adj = self._convert_sp_mat_to_sp_tensor(self.adj_mat).to(self.device)

    def _init_model(self):
        """Construct the ``GraphConv`` layer from the stored configuration."""
        return GraphConv(n_hops=self.context_hops,
                         n_users=self.n_users,
                         n_items=self.n_items,
                         interact_mat=self.sparse_norm_adj,
                         deg=self.deg,
                         e=self.e,
                         edge_dropout_rate=self.edge_dropout_rate,
                         mess_dropout_rate=self.mess_dropout_rate,
                         t_u=self.t_u,
                         t_i=self.t_i,
                         device=self.device)

    def _convert_sp_mat_to_sp_tensor(self, X):
        """Convert a SciPy sparse matrix into a float32 torch sparse COO tensor."""
        coo = X.tocoo()
        # Stack the index arrays as tensors; building a tensor from a Python
        # list of ndarrays is slow and triggers a warning.
        indices = torch.stack([torch.as_tensor(coo.row, dtype=torch.long),
                               torch.as_tensor(coo.col, dtype=torch.long)])
        values = torch.from_numpy(coo.data).float()
        return torch.sparse_coo_tensor(indices, values, coo.shape)

    def forward(self, batch=None):
        """Compute the training loss for one batch.

        Args:
            batch: Mapping with index tensors under ``'users'``,
                ``'pos_items'`` and ``'neg_items'``.

        Returns:
            Tuple ``(total_loss, mf_loss, emb_loss)`` from ``create_bpr_loss``.
        """
        user = batch['users']
        pos_item = batch['pos_items']
        neg_item = batch['neg_items']  # [batch_size, n_negs * K]

        user_gcn_emb, item_gcn_emb = self.gcn(self.user_embed,
                                              self.item_embed,
                                              edge_dropout=self.edge_dropout,
                                              mess_dropout=self.mess_dropout)

        if self.ns == 'rns':  # n_negs = 1
            neg_gcn_embs = item_gcn_emb[neg_item[:, :self.K]]
        else:
            # One synthesised negative per group of n_negs candidates.
            neg_gcn_embs = torch.stack(
                [self.negative_sampling(user_gcn_emb, item_gcn_emb, user,
                                        neg_item[:, k * self.n_negs: (k + 1) * self.n_negs],
                                        pos_item)
                 for k in range(self.K)],
                dim=1)

        return self.create_bpr_loss(user_gcn_emb[user], item_gcn_emb[pos_item], neg_gcn_embs)

    def negative_sampling(self, user_gcn_emb, item_gcn_emb, user, neg_candidates, pos_item):
        """Synthesise one hard negative per user from the given candidates.

        Each candidate is first mixed with the positive item using a random
        coefficient; then, independently for every hop, the mixed candidate
        with the highest inner product with the user is selected.

        Returns:
            Tensor of shape ``[batch_size, n_hops+1, channel]``.
        """
        batch_size = user.shape[0]
        s_e, p_e = user_gcn_emb[user], item_gcn_emb[pos_item]  # [batch_size, n_hops+1, channel]
        if self.pool != 'concat':
            s_e = self.pooling(s_e).unsqueeze(dim=1)

        # positive mixing
        seed = torch.rand(batch_size, 1, p_e.shape[1], 1).to(p_e.device)  # (0, 1)
        n_e = item_gcn_emb[neg_candidates]  # [batch_size, n_negs, n_hops, channel]
        n_e_ = seed * p_e.unsqueeze(dim=1) + (1 - seed) * n_e  # mixing

        # hop mixing
        scores = (s_e.unsqueeze(dim=1) * n_e_).sum(dim=-1)  # [batch_size, n_negs, n_hops+1]
        indices = torch.max(scores, dim=1)[1].detach()
        neg_items_emb_ = n_e_.permute([0, 2, 1, 3])  # [batch_size, n_hops+1, n_negs, channel]

        # Index tensors broadcast to [batch_size, n_hops+1]; built with
        # torch.arange instead of Python lists.
        batch_idx = torch.arange(batch_size, device=indices.device).unsqueeze(dim=1)
        hop_idx = torch.arange(neg_items_emb_.shape[1], device=indices.device)
        # [batch_size, n_hops+1, channel]
        return neg_items_emb_[batch_idx, hop_idx, indices, :]

    def pooling(self, embeddings):
        """Reduce the hop dimension (dim 1) according to ``self.pool``.

        ``'mean'`` and ``'sum'`` reduce over hops, ``'concat'`` flattens hops
        into the channel dimension, and any other value keeps the last hop.
        """
        # [-1, n_hops, channel]
        if self.pool == 'mean':
            return embeddings.mean(dim=1)
        elif self.pool == 'sum':
            return embeddings.sum(dim=1)
        elif self.pool == 'concat':
            return embeddings.view(embeddings.shape[0], -1)
        else:  # final
            return embeddings[:, -1, :]

    def generate(self, split=True):
        """Return pooled embeddings computed without edge or message dropout.

        Args:
            split: If true return ``(user_emb, item_emb)``; otherwise return
                both concatenated along dim 0 (users first).
        """
        user_gcn_emb, item_gcn_emb = self.gcn(self.user_embed,
                                              self.item_embed,
                                              edge_dropout=False,
                                              mess_dropout=False)
        user_gcn_emb, item_gcn_emb = self.pooling(user_gcn_emb), self.pooling(item_gcn_emb)
        if split:
            return user_gcn_emb, item_gcn_emb
        else:
            return torch.cat([user_gcn_emb, item_gcn_emb], dim=0)

    def rating(self, u_g_embeddings=None, i_g_embeddings=None):
        """Return the user-by-item score matrix ``U @ I^T``."""
        return torch.matmul(u_g_embeddings, i_g_embeddings.t())

    def create_bpr_loss(self, user_gcn_emb, pos_gcn_embs, neg_gcn_embs):
        """Compute the ranking loss plus L2 regularisation on the hop-0 slices.

        Args:
            user_gcn_emb: [batch_size, n_hops+1, channel]
            pos_gcn_embs: [batch_size, n_hops+1, channel]
            neg_gcn_embs: [batch_size, K, n_hops+1, channel]

        Returns:
            Tuple ``(mf_loss + emb_loss, mf_loss, emb_loss)``.
        """
        batch_size = user_gcn_emb.shape[0]

        u_e = self.pooling(user_gcn_emb)
        pos_e = self.pooling(pos_gcn_embs)
        neg_e = self.pooling(
            neg_gcn_embs.reshape(-1, neg_gcn_embs.shape[2], neg_gcn_embs.shape[3])
        ).view(batch_size, self.K, -1)

        pos_scores = torch.sum(torch.mul(u_e, pos_e), dim=1)
        neg_scores = torch.sum(torch.mul(u_e.unsqueeze(dim=1), neg_e), dim=-1)  # [batch_size, K]

        # log(1 + sum_k exp(d_k)) == logsumexp([0, d_1, ..., d_K]).  The
        # logsumexp form does not overflow when a score gap is large.
        score_gap = neg_scores - pos_scores.unsqueeze(dim=1)
        zero_col = torch.zeros_like(score_gap[:, :1])
        mf_loss = torch.mean(torch.logsumexp(torch.cat([zero_col, score_gap], dim=1), dim=1))

        # Squared L2 norm of the hop-0 slices only.
        regularize = (user_gcn_emb[:, 0, :].pow(2).sum()
                      + pos_gcn_embs[:, 0, :].pow(2).sum()
                      + neg_gcn_embs[:, :, 0, :].pow(2).sum()) / 2  # take hop=0
        emb_loss = self.decay * regularize / batch_size

        return mf_loss + emb_loss, mf_loss, emb_loss