Skip to content

Commit

Permalink
Do not create plates for observed sites in AutoGuide.
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann committed Feb 5, 2025
1 parent d6ba568 commit 681361a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
4 changes: 4 additions & 0 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ def _setup_prototype(self, *args, **kwargs):
# raise support errors early for discrete sites
with helpful_support_errors(site):
biject_to(site["fn"].support)
# Do not create plates for observed sites because they may be subsampled
# with a different size during prototype setup and training.
if site["is_observed"]:
continue
for frame in site["cond_indep_stack"]:
if frame.name in self._prototype_frames:
assert frame == self._prototype_frames[frame.name], (
Expand Down
35 changes: 35 additions & 0 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1345,3 +1345,38 @@ def model(x):
# Check delta distributions are fine if observed.
guide = AutoDiagonalNormal(lambda: model(3.0))
numpyro.handlers.seed(guide, 9)()


@pytest.mark.parametrize(
"guide_cls",
[
AutoBNAFNormal,
AutoDAIS,
AutoDelta,
AutoDiagonalNormal,
AutoLaplaceApproximation,
AutoLowRankMultivariateNormal,
AutoMultivariateNormal,
AutoNormal,
],
)
def test_subsample(guide_cls) -> None:
def model(n: int, x: jnp.ndarray):
mu = numpyro.sample("mu", dist.Normal(0, 1))
sigma = numpyro.sample("sigma", dist.HalfNormal(1))
with numpyro.plate("n", n, subsample_size=x.size):
numpyro.sample("x", dist.Normal(mu, sigma), obs=x)

n = 20
x = 5 + jax.random.normal(jax.random.key(1), (20,))
subset = x[: n // 2]

svi = numpyro.infer.SVI(
model,
guide_cls(model),
numpyro.optim.Adam(0.1),
numpyro.infer.Trace_ELBO(),
n=n,
)
state = svi.init(jax.random.key(2), x=x)
svi.update(state, x=subset)

0 comments on commit 681361a

Please sign in to comment.