utils_smiles.py
import numpy as np
import torch
import torch.nn as nn


def sample_correlated_gaussian(rho=0.5, dim=20, batch_size=128, cubic=None):
    """Generate samples from a correlated Gaussian distribution."""
    x, eps = torch.chunk(torch.randn(batch_size, 2 * dim), 2, dim=1)
    y = rho * x + torch.sqrt(torch.tensor(1. - rho**2).float()) * eps
    if cubic is not None:
        y = y ** 3
    return x, y


def rho_to_mi(dim, rho):
    """Obtain the ground-truth mutual information from rho."""
    return -0.5 * np.log(1 - rho**2) * dim


def mi_to_rho(dim, mi):
    """Obtain the correlation rho for a Gaussian given the ground-truth mutual information."""
    return np.sqrt(1 - np.exp(-2.0 / dim * mi))
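
# Sanity check (added for illustration; dim=20 and rho=0.5 are arbitrary):
# the two conversions above are inverses of each other, since
# mi = -0.5 * dim * log(1 - rho**2)  <=>  rho = sqrt(1 - exp(-2 * mi / dim)).
#   >>> mi = rho_to_mi(dim=20, rho=0.5)   # ~2.877 nats
#   >>> np.isclose(mi_to_rho(dim=20, mi=mi), 0.5)
#   True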


def mi_schedule(n_iter):
    """Generate schedule for increasing correlation over time."""
    mis = np.round(np.linspace(0.5, 5.5 - 1e-9, n_iter)) * 2.0
    return mis.astype(np.float32)


def mlp(dim, hidden_dim, output_dim, layers, activation):
    """Create an MLP from the configuration."""
    activation = {
        'relu': nn.ReLU
    }[activation]
    seq = [nn.Linear(dim, hidden_dim), activation()]
    for _ in range(layers):
        seq += [nn.Linear(hidden_dim, hidden_dim), activation()]
    seq += [nn.Linear(hidden_dim, output_dim)]
    return nn.Sequential(*seq)


class SeparableCritic(nn.Module):
    """Separable critic, where the score for a pair is the inner product g(x)^T h(y)."""

    def __init__(self, dim, hidden_dim, embed_dim, layers, activation, **extra_kwargs):
        super(SeparableCritic, self).__init__()
        self._g = mlp(dim, hidden_dim, embed_dim, layers, activation)
        self._h = mlp(dim, hidden_dim, embed_dim, layers, activation)

    def forward(self, x, y):
        scores = torch.matmul(self._h(y), self._g(x).t())
        return scores


class ConcatCritic(nn.Module):
    """Concat critic, where we concatenate the inputs and use one MLP to output the score."""

    def __init__(self, dim_x, dim_y, hidden_dim, layers, activation, **extra_kwargs):
        super(ConcatCritic, self).__init__()
        # The output is a scalar score per (x, y) pair.
        self._f = mlp(dim_x + dim_y, hidden_dim, 1, layers, activation)

    def forward(self, x, y):
        batch_size = x.size(0)
        # Tile all possible combinations of x and y.
        x_tiled = torch.stack([x] * batch_size, dim=0)
        y_tiled = torch.stack([y] * batch_size, dim=1)
        # xy_pairs is [batch_size * batch_size, x_dim + y_dim].
        xy_pairs = torch.reshape(torch.cat((x_tiled, y_tiled), dim=2), [
            batch_size * batch_size, -1])
        # Compute scores for each x_i, y_j pair.
        scores = self._f(xy_pairs)
        return torch.reshape(scores, [batch_size, batch_size]).t()
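
# Illustrative usage (an assumption about the intended call pattern, not part
# of the original file): both critics take a batch of paired samples and return
# a [batch, batch] matrix of scores over all cross-pairings, with the diagonal
# holding the scores of the true pairs (x_i, y_i).
#   critic = ConcatCritic(dim_x=20, dim_y=20, hidden_dim=256, layers=2,
#                         activation='relu')
#   scores = critic(torch.randn(64, 20), torch.randn(64, 20))  # -> [64, 64]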


def log_prob_gaussian(x):
    """Log-probability of x under a standard Gaussian, summed over the last dimension."""
    return torch.sum(torch.distributions.Normal(0., 1.).log_prob(x), -1)
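

# A minimal smoke test, added for illustration (the hyperparameters below are
# arbitrary assumptions, not values from the original file). It draws a batch
# with a known ground-truth MI and runs one critic forward pass.
if __name__ == "__main__":
    dim, batch_size, target_mi = 20, 128, 4.0
    rho = mi_to_rho(dim, target_mi)  # correlation that yields target_mi nats
    x, y = sample_correlated_gaussian(rho=rho, dim=dim, batch_size=batch_size)
    print("true MI:", rho_to_mi(dim, rho))  # ~4.0
    critic = SeparableCritic(dim, hidden_dim=256, embed_dim=32, layers=2,
                             activation='relu')
    scores = critic(x, y)  # [batch_size, batch_size] pairwise score matrix
    print("scores shape:", tuple(scores.shape))  # (128, 128)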