Learn white_vec parameter of AutoGaussian guide (#2946)
* Learn white_vec in AutoGaussian

* Fix bugs

* Link to issue

* Attempt to fix AutoGaussian dispatch

* Fix some tests

* Speed up test_median

* Support more temperatures

* Add xfailing tests

* Fix bug excluding obs sites from prototype_trace

* Fix more bugs

* Fix more tests

* Add failing test of elbo gradient

* lint

* Make test less trivial

* Strengthen tests, make AutoGaussian abstract

* Add has_rsample kwarg to pyro.factor

* Fix tests

* Add has_rsample kwarg to pyro.factor

* Require specification of has_rsample for pyro.factor in guides (see the sketch after this list)

* Remove debug statement

* Update AutoGaussian

* Fix scanvi example

* Fix tests

* Fix profiling test

* Remove experimental code

* Bump funsor version

* Pin to Funsor 0.4.2
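
A minimal sketch of the has_rsample usage that the two pyro.factor commits above introduce; the guide and site names here are hypothetical, only the kwarg itself comes from this PR:

import pyro
import pyro.distributions as dist

def guide():
    # z is drawn via a reparameterized sampler, so a factor depending on it
    # is differentiable through the sample path.
    z = pyro.sample("z", dist.Normal(0.0, 1.0))
    # In guides, pyro.factor must now declare has_rsample explicitly.
    pyro.factor("penalty", -z.pow(2), has_rsample=True)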
fritzo authored Dec 14, 2021
1 parent dbc59a0 commit 1877116
Showing 3 changed files with 124 additions and 62 deletions.
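
Before the diffs, a note on the parameterization: each Gaussian factor is stored as a pair (white_vec, prec_sqrt), following funsor's convention that such a factor has log-density -||x @ prec_sqrt - white_vec||^2 / 2. A small torch sketch of the equivalent information-form parameters, assuming that convention (sizes made up):

import torch

u_size, d_size = 5, 3
prec_sqrt = torch.randn(u_size, d_size)  # rectangular square root of the precision
white_vec = torch.randn(d_size)          # whitened location parameter

precision = prec_sqrt @ prec_sqrt.T      # [u_size, u_size], positive semidefinite
info_vec = prec_sqrt @ white_vec         # information vector, as in _dense_get_mvn below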
pyro/infer/autoguide/gaussian.py (108 additions, 50 deletions)
@@ -130,7 +130,8 @@ def _setup_prototype(self, *args, **kwargs) -> None:
 
         self.locs = PyroModule()
         self.scales = PyroModule()
-        self.factors = PyroModule()
+        self.white_vecs = PyroModule()
+        self.prec_sqrts = PyroModule()
         self._factors = OrderedDict()
         self._plates = OrderedDict()
         self._event_numel = OrderedDict()
@@ -211,18 +212,20 @@ def _setup_prototype(self, *args, **kwargs) -> None:
             d_size = min(d_size, u_size)  # just an optimization
             batch_shape = _plates_to_shape(self._plates[d])
 
-            # Create a square root parameter (full, not lower triangular).
+            # Create parameters of each Gaussian factor.
+            white_vec = init_loc.new_zeros(batch_shape + (d_size,))
             # We initialize with noise to avoid singular gradient.
-            sqrt = torch.rand(
+            prec_sqrt = torch.rand(
                 batch_shape + (u_size, d_size),
                 dtype=init_loc.dtype,
                 device=init_loc.device,
             )
-            sqrt.sub_(0.5).mul_(self._init_scale)
+            prec_sqrt.sub_(0.5).mul_(self._init_scale)
             if not site["is_observed"]:
                 # Initialize the [d,d] block to the identity matrix.
-                sqrt.diagonal(dim1=-2, dim2=-1).fill_(1)
-            deep_setattr(self.factors, d, PyroParam(sqrt, event_dim=2))
+                prec_sqrt.diagonal(dim1=-2, dim2=-1).fill_(1)
+            deep_setattr(self.white_vecs, d, PyroParam(white_vec, event_dim=1))
+            deep_setattr(self.prec_sqrts, d, PyroParam(prec_sqrt, event_dim=2))
 
     @staticmethod
     def _compress_site(site):
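
One reading of this initialization (an inference from the code, not a claim in the commit): with white_vec at zero and prec_sqrt equal to the identity plus small noise, each unobserved factor starts out close to a standard Gaussian:

import torch

init_scale, u_size, d_size = 0.1, 3, 3
white_vec = torch.zeros(d_size)
prec_sqrt = torch.rand(u_size, d_size).sub_(0.5).mul_(init_scale)
prec_sqrt.diagonal(dim1=-2, dim2=-1).fill_(1)

info_vec = prec_sqrt @ white_vec     # exactly zero
precision = prec_sqrt @ prec_sqrt.T  # approximately the identity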
@@ -243,7 +246,7 @@ def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
         if self.prototype_trace is None:
             self._setup_prototype(*args, **kwargs)
 
-        aux_values = self._sample_aux_values()
+        aux_values = self._sample_aux_values(temperature=1.0)
         values, log_densities = self._transform_values(aux_values)
 
         # Replay via Pyro primitives.
@@ -268,7 +271,7 @@ def median(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
         :rtype: dict
         """
         with torch.no_grad(), poutine.mask(mask=False):
-            aux_values = {name: 0.0 for name in self._factors}
+            aux_values = self._sample_aux_values(temperature=0.0)
             values, _ = self._transform_values(aux_values)
         return values
 
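The temperature kwarg can be read as a scale on injected noise (my gloss; the dense backend below implements the two endpoints directly): temperature 0 returns the mode, which for a Gaussian coincides with the mean, and temperature 1 draws an ordinary sample:

import torch

loc = torch.tensor([1.0, -2.0])
scale_tril = torch.eye(2)
for temperature in (0.0, 1.0):
    eps = torch.randn(2).mul_(temperature)  # zero noise at temperature 0
    x = loc + scale_tril @ eps              # the mode when temperature == 0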
@@ -299,7 +302,7 @@ def _transform_values(
         return values, log_densities
 
     @abstractmethod
-    def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
+    def _sample_aux_values(self, *, temperature: float) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
 
 
@@ -331,11 +334,13 @@ def _setup_prototype(self, *args, **kwargs):
         # Create sparse -> dense precision scatter indices.
         self._dense_scatter = {}
         for d, site in self._factors.items():
-            sqrt_shape = deep_getattr(self.factors, d).shape
-            precision_shape = sqrt_shape[:-1] + sqrt_shape[-2:-1]
-            index = torch.zeros(precision_shape, dtype=torch.long)
+            prec_sqrt_shape = deep_getattr(self.prec_sqrts, d).shape
+            info_vec_shape = prec_sqrt_shape[:-1]
+            precision_shape = prec_sqrt_shape[:-1] + prec_sqrt_shape[-2:-1]
+            index1 = torch.zeros(info_vec_shape, dtype=torch.long)
+            index2 = torch.zeros(precision_shape, dtype=torch.long)
 
-            # Collect local offsets.
+            # Collect local offsets and create index1 for info_vec blockwise.
             upstreams = [
                 u for u in self.dependencies[d] if not self._factors[u]["is_observed"]
             ]
@@ -345,8 +350,17 @@ def _setup_prototype(self, *args, **kwargs):
                 local_offsets[u] = pos
                 broken_plates = self._plates[u] - self._plates[d]
                 pos += self._event_numel[u] * _plates_to_shape(broken_plates).numel()
+                u_index = global_indices[u]
+
+                # Permute broken plates to the right of preserved plates.
+                u_index = _break_plates(u_index, self._plates[u], self._plates[d])
+
+                # Scatter global indices into the [u] block.
+                u_start = local_offsets[u]
+                u_stop = u_start + u_index.size(-1)
+                index1[..., u_start:u_stop] = u_index
 
-            # Create indices blockwise.
+            # Create index2 for precision blockwise.
             for u, v in itertools.product(upstreams, upstreams):
                 u_index = global_indices[u]
                 v_index = global_indices[v]
@@ -360,18 +374,24 @@ def _setup_prototype(self, *args, **kwargs):
                 u_stop = u_start + u_index.size(-1)
                 v_start = local_offsets[v]
                 v_stop = v_start + v_index.size(-1)
-                index[
+                index2[
                     ..., u_start:u_stop, v_start:v_stop
                 ] = self._dense_size * u_index.unsqueeze(-1) + v_index.unsqueeze(-2)
 
-            self._dense_scatter[d] = index.reshape(-1)
-
-    def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
-        flat_samples = pyro.sample(
-            f"_{self._pyro_name}_latent",
-            self._dense_get_mvn(),
-            infer={"is_auxiliary": True},
-        )
+            self._dense_scatter[d] = index1.reshape(-1), index2.reshape(-1)
+
+    def _sample_aux_values(self, *, temperature: float) -> Dict[str, torch.Tensor]:
+        mvn = self._dense_get_mvn()
+        if temperature == 0:
+            # Simply return the mode.
+            flat_samples = mvn.mean
+        elif temperature == 1:
+            # Sample from a dense joint Gaussian over flattened variables.
+            flat_samples = pyro.sample(
+                f"_{self._pyro_name}_latent", mvn, infer={"is_auxiliary": True}
+            )
+        else:
+            raise NotImplementedError(f"Invalid temperature: {temperature}")
        samples = self._dense_unflatten(flat_samples)
        return samples
 
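A toy version of the two scatters built above (mine, not from the commit): index1 drops a factor's info_vec block into a flat dense vector, and index2 flattens (row, column) pairs as dense_size * row + column so a precision block can be accumulated into a flat dense matrix:

import torch

dense_size = 4
flat_info_vec = torch.zeros(dense_size)
flat_precision = torch.zeros(dense_size ** 2)

u_index = torch.tensor([1, 2])           # global positions of one factor's variables
info_block = torch.tensor([10.0, 20.0])  # this factor's info_vec contribution
prec_block = torch.ones(2, 2)            # this factor's precision contribution

index1 = u_index
index2 = dense_size * u_index.unsqueeze(-1) + u_index.unsqueeze(-2)
flat_info_vec.scatter_add_(0, index1, info_block)
flat_precision.scatter_add_(0, index2.reshape(-1), prec_block.reshape(-1))
precision = flat_precision.reshape(dense_size, dense_size)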
@@ -401,14 +421,22 @@ def _dense_flatten(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:
 
     def _dense_get_mvn(self):
         # Create a dense joint Gaussian over flattened variables.
+        flat_info_vec = torch.zeros(self._dense_size)
         flat_precision = torch.zeros(self._dense_size ** 2)
-        for d, index in self._dense_scatter.items():
-            sqrt = deep_getattr(self.factors, d)
-            precision = sqrt @ sqrt.transpose(-1, -2)
-            flat_precision.scatter_add_(0, index, precision.reshape(-1))
+        for d, (index1, index2) in self._dense_scatter.items():
+            white_vec = deep_getattr(self.white_vecs, d)
+            prec_sqrt = deep_getattr(self.prec_sqrts, d)
+            info_vec = (prec_sqrt @ white_vec[..., None])[..., 0]
+            precision = prec_sqrt @ prec_sqrt.transpose(-1, -2)
+            flat_info_vec.scatter_add_(0, index1, info_vec.reshape(-1))
+            flat_precision.scatter_add_(0, index2, precision.reshape(-1))
+        info_vec = flat_info_vec
         precision = flat_precision.reshape(self._dense_size, self._dense_size)
-        loc = precision.new_zeros(self._dense_size)
-        return dist.MultivariateNormal(loc, precision_matrix=precision)
+        scale_tril = _precision_to_scale_tril(precision)
+        loc = (
+            scale_tril @ (scale_tril.transpose(-1, -2) @ info_vec.unsqueeze(-1))
+        ).squeeze(-1)
+        return dist.MultivariateNormal(loc, scale_tril=scale_tril)
 
 
 class AutoGaussianFunsor(AutoGaussian):
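
A quick check of the information-to-moment conversion used in _dense_get_mvn above: with covariance = L @ L.T, the mean is loc = covariance @ info_vec = L @ (L.T @ info_vec), i.e. loc solves precision @ loc = info_vec (illustrative numbers):

import torch

precision = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
info_vec = torch.tensor([1.0, -1.0])
L = torch.linalg.cholesky(torch.inverse(precision))  # scale_tril
loc = L @ (L.T @ info_vec)
assert torch.allclose(precision @ loc, info_vec, atol=1e-5)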
@@ -464,7 +492,7 @@ def _setup_prototype(self, *args, **kwargs):
         self._funsor_plate_to_dim = plate_to_dim
         self._funsor_plates = frozenset(plate_to_dim)
 
-    def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
+    def _sample_aux_values(self, *, temperature: float) -> Dict[str, torch.Tensor]:
         funsor = _import_funsor()
 
         # Convert torch to funsor.
@@ -473,38 +501,43 @@ def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
         plate_to_dim.update({f.name: f.dim for f in particle_plates})
         factors = {}
         for d, inputs in self._funsor_factor_inputs.items():
-            prec_sqrt = deep_getattr(self.factors, d)
             batch_shape = torch.Size(
                 p.size for p in sorted(self._plates[d], key=lambda p: p.dim)
             )
-            prec_sqrt = prec_sqrt.reshape(batch_shape + prec_sqrt.shape[-2:])
-            # TODO Make white_vec learnable once .median() can be computed via
-            # funsor.recipies.forward_filter_backward_precondition()
-            # https://github.com/pyro-ppl/funsor/pull/553
-            white_vec = prec_sqrt.new_zeros(()).expand(
-                prec_sqrt.shape[:-2] + prec_sqrt.shape[-1:]
-            )
+            white_vec = deep_getattr(self.white_vecs, d)
+            prec_sqrt = deep_getattr(self.prec_sqrts, d)
             factors[d] = funsor.gaussian.Gaussian(
-                white_vec=white_vec, prec_sqrt=prec_sqrt, inputs=inputs
+                white_vec=white_vec.reshape(batch_shape + white_vec.shape[-1:]),
+                prec_sqrt=prec_sqrt.reshape(batch_shape + prec_sqrt.shape[-2:]),
+                inputs=inputs,
             )
 
         # Perform Gaussian tensor variable elimination.
-        try:  # Convert ValueError into NotImplementedError.
-            samples, log_prob = funsor.recipes.forward_filter_backward_rsample(
-                factors=factors,
-                eliminate=self._funsor_eliminate,
-                plates=frozenset(plate_to_dim),
-                sample_inputs={f.name: funsor.Bint[f.size] for f in particle_plates},
-            )
-        except ValueError as e:
-            if str(e) != "intractable!":
-                raise e from None
-            raise NotImplementedError(
-                "Funsor backend found intractable plate nesting. "
-                'Consider using AutoGaussian(..., backend="dense"), '
-                "splitting into multiple guides via AutoGuideList, or "
-                "replacing some plates in the model by .to_event()."
-            ) from e
+        if temperature == 1:
+            samples, log_prob = _try_possibly_intractable(
+                funsor.recipes.forward_filter_backward_rsample,
+                factors=factors,
+                eliminate=self._funsor_eliminate,
+                plates=frozenset(plate_to_dim),
+                sample_inputs={f.name: funsor.Bint[f.size] for f in particle_plates},
+            )
+        else:
+            samples, log_prob = _try_possibly_intractable(
+                funsor.recipes.forward_filter_backward_precondition,
+                factors=factors,
+                eliminate=self._funsor_eliminate,
+                plates=frozenset(plate_to_dim),
+            )
+
+        # Substitute noise.
+        sample_shape = torch.Size(f.size for f in particle_plates)
+        noise = torch.randn(sample_shape + log_prob.inputs["aux"].shape)
+        noise.mul_(temperature)
+        aux = funsor.Tensor(noise)[tuple(f.name for f in particle_plates)]
+        with funsor.interpretations.memoize():
+            samples = {k: v(aux=aux) for k, v in samples.items()}
+            log_prob = log_prob(aux=aux)
 
         # Convert funsor to torch.
         if am_i_wrapped() and poutine.get_mask() is not False:
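
A torch-level analogue of the preconditioning pattern above (assumed; funsor details elided): forward_filter_backward_precondition returns samples as deferred functions of a white-noise input "aux", and the guide then substitutes noise scaled by temperature, so 0 recovers the mode and 1 an exact sample:

import torch

def preconditioned_sampler(loc, scale_tril):
    # Stands in for the lazy funsor samples: a function of white noise.
    return lambda aux: loc + scale_tril @ aux

sampler = preconditioned_sampler(torch.zeros(2), torch.eye(2))
noise = torch.randn(2)
noise.mul_(0.0)        # temperature 0: zero out the noise
mode = sampler(noise)  # equals loc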
@@ -516,6 +549,31 @@ def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
         return samples
 
 
+def _precision_to_scale_tril(P):
+    # Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
+    Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1)))
+    L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1)
+    L = torch.triangular_solve(
+        torch.eye(P.shape[-1], dtype=P.dtype, device=P.device), L_inv, upper=False
+    )[0]
+    return L
+
+
+def _try_possibly_intractable(fn, *args, **kwargs):
+    # Convert ValueError into NotImplementedError.
+    try:
+        return fn(*args, **kwargs)
+    except ValueError as e:
+        if str(e) != "intractable!":
+            raise e from None
+        raise NotImplementedError(
+            "Funsor backend found intractable plate nesting. "
+            'Consider using AutoGaussian(..., backend="dense"), '
+            "splitting into multiple guides via AutoGuideList, or "
+            "replacing some plates in the model by .to_event()."
+        ) from e
+
+
 def _plates_to_shape(plates):
     shape = [1] * max([0] + [-f.dim for f in plates])
     for f in plates:
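
A quick check (illustrative) that the flip-Cholesky trick in _precision_to_scale_tril returns a lower-triangular L with L @ L.T equal to the inverse of P:

import torch

P = torch.tensor([[4.0, 1.0], [1.0, 3.0]])
Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1)))
L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1)
L = torch.triangular_solve(torch.eye(2), L_inv, upper=False)[0]
assert torch.allclose(L @ L.T, torch.inverse(P), atol=1e-5)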
tests/infer/autoguide/test_gaussian.py (15 additions, 11 deletions)
@@ -17,7 +17,7 @@
     _break_plates,
 )
 from pyro.infer.reparam import LocScaleReparam
-from pyro.optim import Adam
+from pyro.optim import ClippedAdam
 from tests.common import assert_close, assert_equal, xfail_if_not_implemented
 
 BACKENDS = [
@@ -131,27 +131,27 @@ def check_backends_agree(model):
     params2 = dict(guide2.named_parameters())
     assert set(params1) == set(params2)
     for k, v in params1.items():
-        v.data.normal_()
+        v.data.add_(torch.zeros_like(v).normal_())
         params2[k].data.copy_(v.data)
     names = sorted(params1)
 
     # Check densities agree between backends.
     with torch.no_grad(), poutine.trace() as tr:
-        aux = guide2._sample_aux_values()
+        aux = guide2._sample_aux_values(temperature=1.0)
         flat = guide1._dense_flatten(aux)
         tr.trace.compute_log_prob()
     log_prob_funsor = tr.trace.nodes["_AutoGaussianFunsor_latent"]["log_prob"]
     with torch.no_grad(), poutine.trace() as tr:
         with poutine.condition(data={"_AutoGaussianDense_latent": flat}):
-            guide1._sample_aux_values()
+            guide1._sample_aux_values(temperature=1.0)
         tr.trace.compute_log_prob()
     log_prob_dense = tr.trace.nodes["_AutoGaussianDense_latent"]["log_prob"]
     assert_equal(log_prob_funsor, log_prob_dense)
 
     # Check Monte Carlo estimate of entropy.
     entropy1 = guide1._dense_get_mvn().entropy()
     with pyro.plate("particle", 100000, dim=-3), poutine.trace() as tr:
-        guide2._sample_aux_values()
+        guide2._sample_aux_values(temperature=1.0)
         tr.trace.compute_log_prob()
     entropy2 = -tr.trace.nodes["_AutoGaussianFunsor_latent"]["log_prob"].mean()
     assert_close(entropy1, entropy2, atol=1e-2)
@@ -163,10 +163,14 @@ def check_backends_agree(model):
     )
     for name, grad1, grad2 in zip(names, grads1, grads2):
         # Gradients should agree to very high precision.
+        if grad1 is None and grad2 is not None:
+            grad1 = torch.zeros_like(grad2)
+        elif grad2 is None and grad1 is not None:
+            grad2 = torch.zeros_like(grad1)
         assert_close(grad1, grad2, msg=f"{name}:\n{grad1} vs {grad2}")
 
     # Check elbos agree between backends.
-    elbo = Trace_ELBO(num_particles=100000, vectorize_particles=True)
+    elbo = Trace_ELBO(num_particles=1000000, vectorize_particles=True)
     loss1 = elbo.differentiable_loss(model, guide1)
     loss2 = elbo.differentiable_loss(model, guide2)
     assert_close(loss1, loss2, atol=1e-2, rtol=0.05)
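
Why a gradient can be None here (an illustration; the test presumably calls torch.autograd.grad with allow_unused=True): parameters that do not affect the loss get None instead of a zero tensor, and the added lines normalize them before comparison:

import torch

a = torch.randn(2, requires_grad=True)
b = torch.randn(2, requires_grad=True)  # unused in the loss below
loss = a.sum()
grad_a, grad_b = torch.autograd.grad(loss, [a, b], allow_unused=True)
assert grad_b is None
grad_b = torch.zeros_like(b) if grad_b is None else grad_b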
@@ -422,7 +426,7 @@ def model():
         pyro.sample("b", dist.Normal(a.mean(-1), 1), obs=torch.tensor(0.0))
 
     guide = AutoGaussian(model, backend=backend)
-    svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO())
+    svi = SVI(model, guide, ClippedAdam({"lr": 1e-8}), Trace_ELBO())
     for step in range(2):
         with xfail_if_not_implemented():
             svi.step()
@@ -445,7 +449,7 @@ def model():
             pyro.sample("d", dist.Normal(c, 1), obs=torch.zeros(3, 2))
 
     guide = AutoGaussian(model, backend=backend)
-    svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO())
+    svi = SVI(model, guide, ClippedAdam({"lr": 1e-8}), Trace_ELBO())
     for step in range(2):
         with xfail_if_not_implemented():
             svi.step()
@@ -674,7 +678,7 @@ def test_pyrocov_smoke(model, Guide, backend):
     }
 
     guide = Guide(model, backend=backend)
-    svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO())
+    svi = SVI(model, guide, ClippedAdam({"lr": 1e-8}), Trace_ELBO())
     for step in range(2):
         with xfail_if_not_implemented():
             svi.step(dataset)
@@ -703,7 +707,7 @@ def test_pyrocov_reparam(model, Guide, backend):
     }
     model = poutine.reparam(model, config)
     guide = Guide(model, backend=backend)
-    svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO())
+    svi = SVI(model, guide, ClippedAdam({"lr": 1e-8}), Trace_ELBO())
     for step in range(2):
         with xfail_if_not_implemented():
             svi.step(dataset)
@@ -825,7 +829,7 @@ def test_profile(backend, jit, n=1, num_steps=1, log_every=1):
     print("Training")
     Elbo = JitTrace_ELBO if jit else Trace_ELBO
     elbo = Elbo(max_plate_nesting=3, ignore_jit_warnings=True)
-    svi = SVI(model, guide, Adam({"lr": 1e-8}), elbo)
+    svi = SVI(model, guide, ClippedAdam({"lr": 1e-8}), elbo)
     for step in range(num_steps):
         loss = svi.step(dataset)
         if log_every and step % log_every == 0:
tests/infer/test_autoguide.py (1 addition, 1 deletion)
@@ -1571,7 +1571,7 @@ def model(data):
     guide.requires_grad_(False)
     with torch.no_grad():
         # Check moments.
-        vectorize = pyro.plate("particles", 10000, dim=-2)
+        vectorize = pyro.plate("particles", 50000, dim=-2)
         guide_trace = poutine.trace(vectorize(guide)).get_trace(data)
         samples = poutine.replay(vectorize(model), guide_trace)(data)
         for name in ["x", "y"]:
