forked from dalab/hessian-rank
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinitializers.py
37 lines (28 loc) · 969 Bytes
/
initializers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from jax.nn.initializers import glorot_normal, normal, ones, zeros, uniform, he_uniform
from jax import random
import jax.numpy as jnp
def orthogonal_init():
    """Build an initializer sampling (semi-)orthogonal matrices from the Haar measure.

    A Gaussian matrix is orthonormalised with a reduced QR decomposition.
    QR is only unique up to the signs of ``diag(R)``. The columns of ``Q`` are
    therefore multiplied by ``sign(diag(R))``, so the result really is
    Haar-distributed (Mezzadri, 2007). Without this correction the distribution
    depends on the sign convention of the linear-algebra backend.

    Returns:
        ``init(key, shape, dtype=float)``, which returns an array of ``shape``.
        Only the first two entries of ``shape`` are inspected, so 2-D shapes are
        assumed. For ``shape = (m, n)`` with ``m < n`` the rows are orthonormal;
        otherwise the columns are orthonormal.
    """

    def _haar_q(a):
        # Orthonormal factor of ``a`` (tall or square), sign-corrected so R has a
        # positive diagonal.
        q, r = jnp.linalg.qr(a)
        # Map a (probability-zero) zero pivot to +1 instead of letting sign(0)
        # wipe out a column.
        signs = jnp.where(jnp.diag(r) < 0, -1.0, 1.0).astype(q.dtype)
        return q * signs

    def init(key, shape, dtype=float):
        # A positive rescaling of the Gaussian draw cannot change the
        # sign-corrected Q, so no variance scaling is applied here.
        w = random.normal(key, shape, dtype=dtype)
        if shape[0] < shape[1]:
            # Wide matrix: orthonormalise the transpose, giving orthonormal rows.
            return _haar_q(w.T).T
        return _haar_q(w)

    return init
def uniform_init():
    """Return an initializer with i.i.d. uniform entries on ``[0, 1/shape[0])``.

    The returned callable has the signature ``init(key, shape)``. It draws
    standard uniform samples of ``shape`` and shrinks them by the size of the
    leading dimension. All entries are therefore non-negative.
    """

    def init(key, shape):
        scale = 1 / shape[0]
        samples = random.uniform(key, shape)
        return scale * samples

    return init
def get_init(name):
"""Helper function returning the desired initialization scheme"""
if name == 'orthogonal':
return orthogonal_init
if name == 'uniform':
return uniform_init
if name == 'glorot':
return glorot_normal