-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmatrix_utils.py
44 lines (40 loc) · 1.43 KB
/
matrix_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
'''
Acquired from the official implementation repo of "Invertible Residual Networks"
https://github.com/jhjacobsen/invertible-resnet
'''
import torch
def power_series_matrix_logarithm_trace(Fx, x, k, n):
    """
    Estimate Tr(ln(I + dF/dx)) with a truncated power series.

    Combines the Taylor expansion ln(I + J) = sum_{j>=1} (-1)^(j+1) J^j / j
    with Hutchinson's stochastic trace estimator Tr(A) ~= E[u^T A u],
    u ~ N(0, I). Biased (the series is truncated at k terms) but fast:
    each term reuses the previous vector-Jacobian product, so computing
    J^j u costs a single extra backward pass.

    :param Fx: output of F(x); must be connected to ``x`` in the autograd graph
    :param x: input tensor; dimension 0 is the batch dimension
    :param k: number of power-series terms to use (truncation order)
    :param n: number of Hutchinson probe vectors per batch element
    :return: tensor of shape ``(batch,)`` — per-sample estimate of Tr(ln(I + dF/dx))
    """
    # One Gaussian probe per (batch element, sample): (batch, n, *Fx.shape[1:]).
    # Allocate directly on the target device rather than randn(...).to(device),
    # which would allocate on CPU first and then copy.
    probe_shape = [x.size(0), n] + list(Fx.shape[1:])
    u = torch.randn(probe_shape, device=x.device)

    trLn = 0
    vectors = u
    for j in range(1, k + 1):
        # J^j u via repeated vector-Jacobian products: reuse the previous
        # iteration's result so each term needs only one more backward pass.
        vectors = [torch.autograd.grad(Fx, x, grad_outputs=vectors[:, i],
                                       retain_graph=True, create_graph=True)[0]
                   for i in range(n)]
        vectors = torch.stack(vectors, dim=1)
        # u^T (J^j u) for every (batch, sample) pair as a batched (1,d)x(d,1) matmul.
        vjp4D = vectors.view(x.size(0), n, 1, -1)
        u4D = u.view(x.size(0), n, -1, 1)
        summand = torch.matmul(vjp4D, u4D)
        # ln(1 + z) series sign is (-1)^(j+1): + for odd j, - for even j.
        # Divide by the plain Python scalar j instead of constructing a fresh
        # torch.Tensor([j]) and copying it to the device every iteration.
        if j % 2 == 1:
            trLn = trLn + summand / j
        else:
            trLn = trLn - summand / j
    # Average over the n Hutchinson samples; squeeze the trailing 1x1 matmul dims.
    trace = trLn.mean(dim=1).squeeze()
    return trace