diff --git a/blackjax/smc/partial_posteriors_path.py b/blackjax/smc/partial_posteriors_path.py index 2381152f4..1279ad245 100644 --- a/blackjax/smc/partial_posteriors_path.py +++ b/blackjax/smc/partial_posteriors_path.py @@ -16,19 +16,19 @@ class PartialPosteriorsSMCState(NamedTuple): The particles' positions. weights: Weights of the particles, so that they represent a probability distribution - selector: - Datapoints used to calculate the posterior the particles represent, a 1D boolean - array to indicate which datapoints to include in the computation of the observed likelihood. + data_mask: + A 1D boolean array to indicate which datapoints to include + in the computation of the observed likelihood. """ particles: ArrayTree weights: Array - selector: Array + data_mask: Array def init(particles: ArrayLikeTree, num_datapoints: int) -> PartialPosteriorsSMCState: """num_datapoints are the number of observations that could potentially be - used in a partial posterior. Since the initial selector is all 0s, it + used in a partial posterior. Since the initial data_mask is all 0s, it means that no likelihood term will be added (only prior). """ num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0] @@ -73,11 +73,11 @@ def build_kernel( delegate = smc_from_mcmc(mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy) def step( - key, state: PartialPosteriorsSMCState, selector: Array + key, state: PartialPosteriorsSMCState, data_mask: Array ) -> Tuple[PartialPosteriorsSMCState, smc.base.SMCInfo]: - logposterior_fn = partial_logposterior_factory(selector) + logposterior_fn = partial_logposterior_factory(data_mask) - previous_logposterior_fn = partial_logposterior_factory(state.selector) + previous_logposterior_fn = partial_logposterior_factory(state.data_mask) def log_weights_fn(x): return logposterior_fn(x) - previous_logposterior_fn(x) @@ -86,7 +86,7 @@ def log_weights_fn(x): key, state, num_mcmc_steps, mcmc_parameters, logposterior_fn, log_weights_fn ) - return PartialPosteriorsSMCState(state.particles, state.weights, selector), info + return PartialPosteriorsSMCState(state.particles, state.weights, data_mask), info return step @@ -118,7 +118,7 @@ def init_fn(position: ArrayLikeTree, num_observations, rng_key=None): del rng_key return init(position, num_observations) - def step(key: PRNGKey, state: PartialPosteriorsSMCState, selector: Array): - return kernel(key, state, selector) + def step(key: PRNGKey, state: PartialPosteriorsSMCState, data_mask: Array): + return kernel(key, state, data_mask) return SamplingAlgorithm(init_fn, step) # type: ignore[arg-type] diff --git a/tests/smc/test_partial_posteriors_smc.py b/tests/smc/test_partial_posteriors_smc.py index 0b12be8f1..5d5a5e0ed 100644 --- a/tests/smc/test_partial_posteriors_smc.py +++ b/tests/smc/test_partial_posteriors_smc.py @@ -38,12 +38,12 @@ def test_partial_posteriors(self): dataset_size = 1000 - def partial_logposterior_factory(selector): + def partial_logposterior_factory(data_mask): def partial_logposterior(x): lp = logprior_fn(x) return lp + jnp.sum( self.logdensity_by_observation(**x, **observations) - * selector.reshape(-1, 1) + * data_mask.reshape(-1, 1) ) return jax.jit(partial_logposterior) @@ -60,20 +60,20 @@ def partial_logposterior(x): init_state = init(init_particles, 1000) smc_kernel = self.variant(kernel) - selectors = jnp.array( + data_masks = jnp.array( [ - jnp.concat([jnp.ones(selector), jnp.zeros(dataset_size - selector)]) - for selector in np.arange(100, 1001, 50) + jnp.concat([jnp.ones(datapoints_chosen), jnp.zeros(dataset_size - datapoints_chosen)]) + for datapoints_chosen in np.arange(100, 1001, 50) ] ) - def body_fn(carry, selector): + def body_fn(carry, data_mask): i, state = carry subkey = jax.random.fold_in(self.key, i) - new_state, info = smc_kernel(subkey, state, selector) + new_state, info = smc_kernel(subkey, state, data_mask) return (i + 1, new_state), (new_state, info) - (steps, result), it = jax.lax.scan(body_fn, (0, init_state), selectors) + (steps, result), it = jax.lax.scan(body_fn, (0, init_state), data_masks) assert steps == 19 self.assert_linear_regression_test_case(result)