Skip to content

Commit

Permalink
exposing in top level api
Browse files Browse the repository at this point in the history
  • Loading branch information
ciguaran committed Aug 27, 2024
1 parent 1304f9f commit aec2e51
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 6 deletions.
4 changes: 3 additions & 1 deletion blackjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from .smc import adaptive_tempered
from .smc import inner_kernel_tuning as _inner_kernel_tuning
from .smc import tempered
from .smc import partial_posteriors_path as _partial_posteriors_smc
from .vi import meanfield_vi as _meanfield_vi
from .vi import pathfinder as _pathfinder
from .vi import schrodinger_follmer as _schrodinger_follmer
Expand Down Expand Up @@ -119,8 +120,9 @@ def generate_top_level_api_from(module):
adaptive_tempered_smc = generate_top_level_api_from(adaptive_tempered)
tempered_smc = generate_top_level_api_from(tempered)
inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning)
partial_posteriors_smc = generate_top_level_api_from(_partial_posteriors_smc)

smc_family = [tempered_smc, adaptive_tempered_smc]
smc_family = [tempered_smc, adaptive_tempered_smc, partial_posteriors_smc]
"Step_fn returning state has a .particles attribute"

# stochastic gradient mcmc
Expand Down
1 change: 1 addition & 0 deletions blackjax/smc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
"tempered",
"inner_kernel_tuning",
"extend_params",
"partial_posteriors_path"
]
2 changes: 1 addition & 1 deletion blackjax/smc/partial_posteriors_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def as_top_level_api(
mcmc_init_fn: Callable,
mcmc_parameters: dict,
resampling_fn: Callable,
num_mcmc_steps,
partial_logposterior_factory: Callable,
num_mcmc_steps: Optional[int] = 10,
update_strategy=update_and_take_last,
) -> SamplingAlgorithm:
"""
Expand Down
7 changes: 3 additions & 4 deletions tests/smc/test_partial_posteriors_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import blackjax
import blackjax.smc.resampling as resampling
from blackjax.smc import extend_params
from blackjax.smc.partial_posteriors_path import build_kernel, init
from tests.smc import SMCLinearRegressionTestCase


Expand Down Expand Up @@ -49,13 +48,13 @@ def partial_logposterior(x):

return jax.jit(partial_logposterior)

kernel = build_kernel(
init, kernel = blackjax.partial_posteriors_smc(
hmc_kernel,
hmc_init,
hmc_parameters,
resampling.systematic,
30,
hmc_parameters,
partial_logposterior_factory=partial_logposterior_factory,
partial_logposterior_factory=partial_logposterior_factory
)

init_state = init(init_particles, 1000)
Expand Down

0 comments on commit aec2e51

Please sign in to comment.