From aec2e51d6e6bf83d148151e7b6f9b0e90ea9d366 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Tue, 27 Aug 2024 14:55:52 -0300 Subject: [PATCH] exposing in top level api --- blackjax/__init__.py | 4 +++- blackjax/smc/__init__.py | 1 + blackjax/smc/partial_posteriors_path.py | 2 +- tests/smc/test_partial_posteriors_smc.py | 7 +++---- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/blackjax/__init__.py b/blackjax/__init__.py index dfdcfc545..6d4258eed 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -36,6 +36,7 @@ from .smc import adaptive_tempered from .smc import inner_kernel_tuning as _inner_kernel_tuning from .smc import tempered +from .smc import partial_posteriors_path as _partial_posteriors_smc from .vi import meanfield_vi as _meanfield_vi from .vi import pathfinder as _pathfinder from .vi import schrodinger_follmer as _schrodinger_follmer @@ -119,8 +120,9 @@ def generate_top_level_api_from(module): adaptive_tempered_smc = generate_top_level_api_from(adaptive_tempered) tempered_smc = generate_top_level_api_from(tempered) inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning) +partial_posteriors_smc = generate_top_level_api_from(_partial_posteriors_smc) -smc_family = [tempered_smc, adaptive_tempered_smc] +smc_family = [tempered_smc, adaptive_tempered_smc, partial_posteriors_smc] "Step_fn returning state has a .particles attribute" # stochastic gradient mcmc diff --git a/blackjax/smc/__init__.py b/blackjax/smc/__init__.py index ef10b10e6..2c09aa67b 100644 --- a/blackjax/smc/__init__.py +++ b/blackjax/smc/__init__.py @@ -6,4 +6,5 @@ "tempered", "inner_kernel_tuning", "extend_params", + "partial_posteriors_path" ] diff --git a/blackjax/smc/partial_posteriors_path.py b/blackjax/smc/partial_posteriors_path.py index 49244b4ae..753d00247 100644 --- a/blackjax/smc/partial_posteriors_path.py +++ b/blackjax/smc/partial_posteriors_path.py @@ -98,8 +98,8 @@ def as_top_level_api( mcmc_init_fn: Callable, mcmc_parameters: dict, resampling_fn: Callable, + num_mcmc_steps, partial_logposterior_factory: Callable, - num_mcmc_steps: Optional[int] = 10, update_strategy=update_and_take_last, ) -> SamplingAlgorithm: """ diff --git a/tests/smc/test_partial_posteriors_smc.py b/tests/smc/test_partial_posteriors_smc.py index 4abbb7c92..d6bad6146 100644 --- a/tests/smc/test_partial_posteriors_smc.py +++ b/tests/smc/test_partial_posteriors_smc.py @@ -7,7 +7,6 @@ import blackjax import blackjax.smc.resampling as resampling from blackjax.smc import extend_params -from blackjax.smc.partial_posteriors_path import build_kernel, init from tests.smc import SMCLinearRegressionTestCase @@ -49,13 +48,13 @@ def partial_logposterior(x): return jax.jit(partial_logposterior) - kernel = build_kernel( + init, kernel = blackjax.partial_posteriors_smc( hmc_kernel, hmc_init, + hmc_parameters, resampling.systematic, 30, - hmc_parameters, - partial_logposterior_factory=partial_logposterior_factory, + partial_logposterior_factory=partial_logposterior_factory ) init_state = init(init_particles, 1000)