
Interchangeability of Callable and BasePotential #1223

Open
schroedk opened this issue Aug 19, 2024 · 7 comments
Labels
question Further information is requested

Comments

@schroedk
Contributor

schroedk commented Aug 19, 2024

I have a question regarding the interchangeability of the argument potential_fn of

class NeuralPosterior(ABC):
    r"""Posterior $p(\theta|x)$ with `log_prob()` and `sample()` methods.<br/><br/>
    All inference methods in sbi train a neural network which is then used to obtain
    the posterior distribution. The `NeuralPosterior` class wraps the trained network
    such that one can directly evaluate the (unnormalized) log probability and draw
    samples from the posterior.
    """

    def __init__(
        self,
        potential_fn: Union[Callable, BasePotential],

For the callable case, it must be something like:

def potential(theta=None, x_o=None):
    ...

in contrast to BasePotential, which is a callable with theta as a positional argument and track_gradients as a keyword argument, correct? Is this tested somewhere? I only found examples where the argument is of type BasePotential.
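
To make the contrast concrete, here is a minimal sketch of the two signatures as I understand them (MyPotential is just a hypothetical subclass for illustration, not sbi code):

# Plain Callable: theta and the observation x_o are both passed explicitly.
def potential(theta=None, x_o=None):
    ...

# BasePotential: x_o is not a call argument (it is read from self.x_o);
# track_gradients is the only keyword argument.
class MyPotential(BasePotential):
    def __call__(self, theta, track_gradients: bool = True):
        ...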

@schroedk schroedk added the question Further information is requested label Aug 19, 2024
@schroedk
Contributor Author

Related #1055

@janfb
Contributor

janfb commented Aug 19, 2024

Yes, correct.

The reason the custom potential_fn takes theta and x_o as arguments is that both quantities are required to calculate the "potential", i.e., the unnormalized posterior probability.

For a BasePotential, the __call__ method does not take x_o as an argument because it is set as a property at runtime.

If a user passes a custom potential, it is checked for the required arguments here:

for key in ["theta", "x_o"]:
assert key in kwargs_of_callable, (
"If you pass a `Callable` as `potential_fn` then it must have "
"`theta` and `x_o` as inputs, even if some of these keyword "
"arguments are unused."
)
# If the `potential_fn` is a Callable then we wrap it as a
# `CallablePotentialWrapper` which inherits from `BasePotential`.
potential_device = "cpu" if device is None else device
potential_fn = CallablePotentialWrapper(
potential_fn, prior=None, x_o=None, device=potential_device
)

and then wrapped as a BasePotential here:

class CallablePotentialWrapper(BasePotential):
    """If `potential_fn` is a callable it gets wrapped as this."""

    allow_iid_x = True  # type: ignore

    def __init__(
        self,
        callable_potential,
        prior: Optional[Distribution],
        x_o: Optional[Tensor] = None,
        device: str = "cpu",
    ):
        super().__init__(prior, x_o, device)
        self.callable_potential = callable_potential

    def __call__(self, theta, track_gradients: bool = True):
        with torch.set_grad_enabled(track_gradients):
            return self.callable_potential(theta=theta, x_o=self.x_o)
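
Once wrapped like this, a plain callable exposes the same call interface as any other BasePotential. A minimal usage sketch (assuming the wrapper can be constructed with the observation directly, as its signature above suggests; not verbatim library code):

import torch
from torch.distributions import MultivariateNormal

# A plain callable with the required keyword arguments theta and x_o.
def potential(theta, x_o):
    return MultivariateNormal(x_o, torch.eye(2)).log_prob(theta)

# After wrapping, the observation lives on the wrapper and the call takes
# only theta and track_gradients, just like any other BasePotential.
wrapped = CallablePotentialWrapper(
    potential, prior=None, x_o=torch.zeros(2), device="cpu"
)
log_probs = wrapped(torch.randn(10, 2), track_gradients=False)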

@janfb
Contributor

janfb commented Aug 19, 2024

Is this tested somewhere? I only found examples where the argument is of type BasePotential.

Yes, I had to dig a bit as well, but it's tested here:

def test_callable_potential(sampling_method, mcmc_params_accurate: dict):
    """Test whether callable potentials can be used to sample from a Gaussian."""
    dim = 2
    mean = 2.5
    cov = 2.0
    x_o = 1 * ones((dim,))
    target_density = MultivariateNormal(mean * ones((dim,)), cov * eye(dim))

    def potential(theta, x_o):
        return target_density.log_prob(theta + x_o)

Here, you can see how we define a custom potential that depends on the inputs theta and x_o.
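
For completeness, the test then passes this plain callable straight into a posterior and samples from it. Roughly (from memory, so treat the import path and exact arguments as assumptions rather than the verbatim test code):

from sbi.inference import MCMCPosterior
from torch import eye, zeros
from torch.distributions import MultivariateNormal

# `potential` and `x_o` are the objects defined in the excerpt above.
prior = MultivariateNormal(zeros((2,)), eye(2))
posterior = MCMCPosterior(
    potential_fn=potential,  # the plain callable, no BasePotential subclass needed
    proposal=prior,
    method="slice_np",
)
samples = posterior.sample((1000,), x=x_o)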

@michaeldeistler
Contributor

I think this can be closed, feel free to reopen if anything is still unclear!

@janfb
Contributor

janfb commented Aug 29, 2024

I think it's a good starting point for refactoring the Callable potential API.

@janfb janfb reopened this Aug 29, 2024
@michaeldeistler
Contributor

What would have to be done here? Just more docs?

@janfb
Contributor

janfb commented Aug 29, 2024

At the moment, if a user passes just a Callable as the potential, we check at runtime whether it has the required arguments, e.g., theta, x_o, and track_gradients. This is brittle. It would be nice to do this beforehand with types, e.g., define a Protocol to ensure that the passed Callable has the correct signature.
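
For illustration, a minimal sketch of what such a Protocol could look like (the name and annotations are just a suggestion, not existing sbi code):

from typing import Optional, Protocol

from torch import Tensor


class PotentialFunction(Protocol):
    """Structural type for a user-supplied potential_fn."""

    def __call__(self, theta: Tensor, x_o: Optional[Tensor]) -> Tensor: ...

NeuralPosterior.__init__ could then be annotated as potential_fn: Union[PotentialFunction, BasePotential], so a type checker would flag callables with the wrong signature before runtime.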
