Nested Sampling implementation #753

Open
yallup opened this issue Nov 1, 2024 · 1 comment

@yallup

yallup commented Nov 1, 2024

Presentation of the new sampler

Nested sampling [Skilling] is a particle Monte Carlo method, usually framed as primarily targeting estimation of the marginal likelihood. It has been justified as a form of SMC (arXiv:1805.03924). However, we think that in practice it is a rather singular form that requires some special implementation details, hence this proposal for a dedicated implementation (although this is open to debate). The basic construct is to move through increasing likelihood levels, sampling at each step from the prior truncated to the region above the current likelihood level. It is very well established in the physical sciences (arXiv:2205.15570) but sees limited use outside that context. In that community it is often noted that, as an “athermal” method, it handles phase transitions differently from typical methods based on a tempering scheme, and it can be very effective on those sorts of problems.
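For reference, the standard quadrature behind this construct (Skilling's formulation; the notation here is ours, not taken from the proposal) maps the evidence onto a one-dimensional integral over the enclosed prior volume X, where L(X) denotes the likelihood level that encloses prior volume X:

```latex
X(\lambda) = \int_{L(\theta) > \lambda} \pi(\theta)\,\mathrm{d}\theta,
\qquad
\mathcal{Z} = \int L(\theta)\,\pi(\theta)\,\mathrm{d}\theta = \int_0^1 L(X)\,\mathrm{d}X .

% With N live points, each iteration removes the lowest-likelihood point L_i.
% The volume shrinks by t_i = X_i / X_{i-1} \sim \mathrm{Beta}(N, 1), so
% \mathbb{E}[\log t_i] = -1/N and \log X_i \approx -i/N.
\mathcal{Z} \approx \sum_{i=1}^{K} L_i \,(X_{i-1} - X_i)
  + \frac{X_K}{N} \sum_{j=1}^{N} L_j^{\text{live}},
\qquad X_0 = 1,\quad X_i = e^{-i/N}.
```

The final term assigns the remaining volume X_K in equal shares to the N live points left after K iterations.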

There are a number of implementations of the same generic algorithm; see Table 2 of arXiv:2205.15570 for a relatively up-to-date list, including some of their specific implementation choices. They are generally written in a mix of Fortran, Python and C, and there are few consistent, working examples that are compatible with modern Python PPLs.

In our opinion, a successful implementation of nested sampling is closely tied to slice sampling, so we have implemented that as well; it could be extracted into a more abstract component if that is useful elsewhere.
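Since the proposal relies on slice sampling as the inner kernel, here is a minimal sketch of one common variant: a hit-and-run slice step (pick a random direction, then apply Neal's stepping-out and shrinkage procedures along it) that targets the prior restricted to the current likelihood-constrained region. This is our illustration, not the authors' implementation. The function name and its `width`/`max_steps_out` parameters are hypothetical, and the sketch assumes `x0` already satisfies the constraint, which holds for any surviving live point.

```python
import jax
import jax.numpy as jnp


def hit_and_run_slice_step(key, x0, logprior_fn, loglikelihood_fn, logl_star,
                           width=1.0, max_steps_out=10):
    """One slice-sampling move targeting the prior restricted to
    {x : log L(x) > logl_star}, along a random direction (stepping out +
    shrinkage, Neal 2003). Hypothetical helper; x0 must satisfy the constraint."""
    k_dir, k_level, k_pos, k_split, k_shrink = jax.random.split(key, 5)
    d = jax.random.normal(k_dir, x0.shape)
    d = d / jnp.linalg.norm(d)  # isotropic unit direction

    def logp(t):
        # Constrained target evaluated along the line x0 + t * d.
        x = x0 + t * d
        return jnp.where(loglikelihood_fn(x) > logl_star, logprior_fn(x), -jnp.inf)

    # Slice level: log y = log p(x0) + log U, U ~ Uniform(0, 1), i.e. minus an Exp(1) draw.
    log_y = logp(0.0) - jax.random.exponential(k_level)

    # Initial interval of length `width`, randomly positioned around t = 0.
    left = -width * jax.random.uniform(k_pos)
    right = left + width
    # Randomly split the step-out budget between the two sides (keeps the move reversible).
    j = jnp.floor(max_steps_out * jax.random.uniform(k_split)).astype(jnp.int32)
    k = (max_steps_out - 1) - j

    def step_out(bound, budget, direction):
        def cond(state):
            b, n = state
            return (n > 0) & (logp(b) >= log_y)

        def body(state):
            b, n = state
            return b + direction * width, n - 1

        return jax.lax.while_loop(cond, body, (bound, budget))[0]

    left = step_out(left, j, -1.0)
    right = step_out(right, k, 1.0)

    def shrink_cond(state):
        return ~state[4]  # loop until a point inside the slice is accepted

    def shrink_body(state):
        key, lo, hi, _, _ = state
        key, sub = jax.random.split(key)
        t = jax.random.uniform(sub, minval=lo, maxval=hi)
        accepted = logp(t) >= log_y
        # On rejection, shrink the interval towards the current point t = 0.
        lo = jnp.where(~accepted & (t < 0), t, lo)
        hi = jnp.where(~accepted & (t >= 0), t, hi)
        return key, lo, hi, t, accepted

    init = (k_shrink, left, right, jnp.zeros_like(left), jnp.array(False))
    _, _, _, t, _ = jax.lax.while_loop(shrink_cond, shrink_body, init)
    return x0 + t * d
```

Because the constrained prior density is zero outside the likelihood contour, the stepping-out phase stops at the contour and the slice never leaves the allowed region. In a nested sampling loop, a move like this would be repeated several times from a surviving live point to produce the replacement for the dead point; how many repeats are needed is a tuning choice this sketch does not address.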

How does it compare to other algorithms in blackjax?

This fits naturally alongside the existing SMC methods in blackjax, and part of our motivation for contributing it to blackjax is to be able to compare the methods like for like. We already have a working implementation that gives consistent, fairly rapid estimates of the marginal likelihood for problems with 10-100 parameters. In principle, the design of SMC in blackjax could be followed, allowing a flexible choice of inner kernel. However, applying many gradient-based MCMC chains generically within nested sampling is difficult, so for now we have tied the implementation quite closely to slice sampling. This design choice could be lifted.

Because of the rather specific requirements on the inner kernel when sampling from truncated regions, we estimate that nested sampling cannot be efficiently adapted to use the existing SMC abstractions in blackjax directly: in nested sampling the likelihood and the prior log-density are “seen” separately, so the inner kernel construction is rather bespoke. Nevertheless, we have tried to follow the SMC design patterns closely, and unifying the two would be worthwhile if it turns out to be possible.
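To make the structural difference from tempered SMC concrete, here is a hedged sketch of a nested sampling loop written in an init/step style loosely modelled on blackjax kernels. All names (`NSState`, `init`, `build_kernel`, `constrained_kernel`, `log_evidence`) are hypothetical and are not blackjax's API. It assumes one live point is replaced per iteration and uses the expected shrinkage log X_i = -i/N. The point it illustrates is that the inner kernel receives the current likelihood threshold `logl_star` on every call, rather than a fixed target log-density, and starts from another live point. For the toy usage it plugs in plain rejection sampling from the prior, which is valid but becomes expensive as the prior volume shrinks; it stands in for the slice-based kernel the proposal actually uses.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


class NSState(NamedTuple):
    particles: jax.Array      # (num_live, dim) live points
    loglikelihood: jax.Array  # (num_live,) log-likelihood of each live point
    logz_dead: jax.Array      # log-evidence accumulated from dead points so far
    log_x: jax.Array          # log of the expected remaining prior volume


def init(particles, loglikelihood_fn):
    logl = jax.vmap(loglikelihood_fn)(particles)
    minus_inf = jnp.full((), -jnp.inf, dtype=logl.dtype)
    return NSState(particles, logl, minus_inf, jnp.zeros((), dtype=logl.dtype))


def build_kernel(constrained_kernel: Callable):
    """constrained_kernel(key, x_start, logl_star) -> (x_new, logl_new) must leave the
    prior restricted to {x : log L(x) > logl_star} invariant (hypothetical interface)."""

    def step(key, state):
        num_live = state.loglikelihood.shape[0]
        worst = jnp.argmin(state.loglikelihood)
        logl_star = state.loglikelihood[worst]
        # Dead-point weight X_i - X_{i+1}, with X_{i+1} = X_i * exp(-1/N) on average.
        log_w = state.log_x + jnp.log1p(-jnp.exp(-1.0 / num_live))
        logz_dead = jnp.logaddexp(state.logz_dead, logl_star + log_w)
        # Start the inner kernel from a uniformly chosen *other* live point,
        # which already satisfies the new likelihood constraint.
        key_pick, key_move = jax.random.split(key)
        idx = jax.random.randint(key_pick, (), 0, num_live - 1)
        idx = jnp.where(idx >= worst, idx + 1, idx)
        x_new, logl_new = constrained_kernel(key_move, state.particles[idx], logl_star)
        new_state = NSState(
            state.particles.at[worst].set(x_new),
            state.loglikelihood.at[worst].set(logl_new),
            logz_dead,
            state.log_x - 1.0 / num_live,
        )
        return new_state, logl_star

    return step


def log_evidence(state):
    # Dead-point sum plus the live points' equal share of the remaining volume X_K.
    num_live = state.loglikelihood.shape[0]
    log_live = state.log_x + logsumexp(state.loglikelihood) - jnp.log(num_live)
    return jnp.logaddexp(state.logz_dead, log_live)


# --- Toy usage: uniform prior on [-5, 5]^2, normalised Gaussian likelihood. ---
DIM, HALF_WIDTH = 2, 5.0


def loglikelihood_fn(x):
    return -0.5 * jnp.sum(x**2) - 0.5 * DIM * jnp.log(2 * jnp.pi)


def sample_prior(key):
    return jax.random.uniform(key, (DIM,), minval=-HALF_WIDTH, maxval=HALF_WIDTH)


def rejection_kernel(key, x_start, logl_star):
    # Simplest valid inner kernel: redraw from the prior until log L > logl_star.
    # It ignores x_start, and its cost grows like 1 / X as the volume shrinks.
    key, sub = jax.random.split(key)
    x0 = sample_prior(sub)

    def body(carry):
        key, _, _ = carry
        key, sub = jax.random.split(key)
        x = sample_prior(sub)
        return key, x, loglikelihood_fn(x)

    carry = (key, x0, loglikelihood_fn(x0))
    _, x, logl = jax.lax.while_loop(lambda c: c[2] <= logl_star, body, carry)
    return x, logl


def run(key, num_live=200, num_steps=1000):
    key_init, key_run = jax.random.split(key)
    particles = jax.vmap(sample_prior)(jax.random.split(key_init, num_live))
    step = build_kernel(rejection_kernel)
    state, _ = jax.lax.scan(
        lambda s, k: step(k, s), init(particles, loglikelihood_fn),
        jax.random.split(key_run, num_steps),
    )
    return log_evidence(state)


print(run(jax.random.PRNGKey(0)))  # estimate; the analytic value is log(1/100) ≈ -4.605
```

Swapping `rejection_kernel` for a slice-based move changes nothing else in the loop, which is the kind of inner-kernel flexibility discussed above; what stays bespoke relative to SMC is that the kernel must be rebuilt around a moving likelihood threshold.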

Where does it fit in blackjax

We have already implemented a preliminary version and found the blackjax framework and abstractions very useful. The sampler would likely have a number of immediate applications in physics; we hope this would benefit blackjax through increased use in that field, as well as by enabling clean comparisons with other state-of-the-art SMC methods.

Are you willing to open a PR?

Yes. We have a working implementation, although it still needs some generalising and checking, and we would make the necessary refactors if it is deemed useful for the core library.

@AdrienCorenflos
Contributor

I'd be happy to review a PR on this. I think it's a valid contribution too. I'm not an expert on nested sampling but also not clueless so hopefully it can be painless.
