Presentation of the new sampler
Nested sampling [Skilling] is a particle Monte Carlo method, usually presented as primarily targeting estimation of the marginal likelihood. It has been justified as a form of SMC (arXiv:1805.03924); however, we think that in practice it is a rather singular form that requires some special implementation details, hence the PR (although we are open to debate here). The basic construct is to scan through increasing likelihood levels while sampling from the truncated prior region at each step. It is very well established in the physical sciences (arXiv:2205.15570), but sees limited usage outside of this context. There it is often noted that, as an “athermal” method, it sees phase transitions differently from typical methods based on a tempering scheme, and can be very profitable in those sorts of problems.
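To make the "scan in increasing likelihood levels" construct concrete, here is a minimal sketch in plain Python. It is not the proposed blackjax implementation: the toy 2D Gaussian likelihood, the uniform prior on the unit square, the function names and the brute-force rejection step for the truncated prior are all assumptions chosen for illustration. A practical sampler would replace the rejection step with an inner MCMC kernel.

```python
import math
import random

SIGMA = 0.1  # width of the toy likelihood (assumption for this example)


def log_likelihood(x):
    # Toy 2D Gaussian likelihood (unnormalised, peak value 1) centred in the unit square.
    return -((x[0] - 0.5) ** 2 + (x[1] - 0.5) ** 2) / (2.0 * SIGMA ** 2)


def sample_prior(rng):
    # Uniform prior on the unit square.
    return (rng.random(), rng.random())


def logaddexp(a, b):
    # Numerically stable log(exp(a) + exp(b)).
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def nested_sampling(n_live=100, n_iter=700, seed=0):
    rng = random.Random(seed)
    live = [sample_prior(rng) for _ in range(n_live)]
    live_logl = [log_likelihood(x) for x in live]
    log_z = -math.inf  # running log marginal likelihood
    log_x = 0.0        # log prior volume still enclosed, X_0 = 1
    for i in range(n_iter):
        # The lowest-likelihood live point defines the current level L*.
        worst = min(range(n_live), key=live_logl.__getitem__)
        logl_star = live_logl[worst]
        # Deterministic approximation of the shrinkage: X_i = exp(-i / n_live).
        log_x_new = -(i + 1) / n_live
        # Weight of the dead point: L* (X_{i-1} - X_i).
        log_w = log_x + math.log1p(-math.exp(log_x_new - log_x))
        log_z = logaddexp(log_z, logl_star + log_w)
        log_x = log_x_new
        # Replace it with a prior draw subject to L > L* (rejection sampling;
        # simple, but the cost grows as the enclosed volume shrinks).
        while True:
            x = sample_prior(rng)
            logl = log_likelihood(x)
            if logl > logl_star:
                break
        live[worst] = x
        live_logl[worst] = logl
    # Add the contribution of the remaining live points.
    for logl in live_logl:
        log_z = logaddexp(log_z, logl + log_x - math.log(n_live))
    return log_z


if __name__ == "__main__":
    print("estimated log Z:", nested_sampling())
    print("analytic  log Z:", math.log(2.0 * math.pi * SIGMA ** 2))
```

For this toy problem the evidence is approximately 2πσ², so the estimate can be checked against `math.log(2 * math.pi * SIGMA ** 2)`; the scatter of the estimate scales roughly as the square root of (information / `n_live`).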
There are a number of different code implementations of the same generic algorithm; see Table 2 of arXiv:2205.15570 for a relatively up-to-date list with some specific implementation choices noted. These are generally written in mixes of Fortran, Python and C, and there are few consistent working examples that are compatible with modern Python PPLs.
The successful implementation of nested sampling is, in our opinion, closely tied to slice sampling, so we have implemented that as well. It can be extracted in a more abstract way if that is generically useful elsewhere.
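For readers unfamiliar with slice sampling, the following is a minimal univariate sketch using the standard stepping-out and shrinkage procedure. It only illustrates the general technique; the function name `slice_sample_1d` and its arguments are hypothetical and do not reflect the implementation proposed here.

```python
import math
import random


def slice_sample_1d(log_density, x0, width, rng):
    """One slice-sampling update for an unnormalised 1D log density."""
    # Draw the slice height uniformly under the density at x0 (in log space).
    log_y = log_density(x0) + math.log(1.0 - rng.random())
    # Stepping out: place a bracket of size `width` randomly around x0 and
    # grow each end until it lies outside the slice.
    left = x0 - width * rng.random()
    right = left + width
    while log_density(left) > log_y:
        left -= width
    while log_density(right) > log_y:
        right += width
    # Shrinkage: propose uniformly in the bracket and shrink it towards x0
    # after every rejected proposal.
    while True:
        x1 = left + (right - left) * rng.random()
        if log_density(x1) > log_y:
            return x1
        if x1 < x0:
            left = x1
        else:
            right = x1


if __name__ == "__main__":
    rng = random.Random(0)
    x, chain = 0.0, []
    for _ in range(5000):
        x = slice_sample_1d(lambda v: -0.5 * v * v, x, 2.0, rng)
        chain.append(x)
    print("sample mean:", sum(chain) / len(chain))
```

The appeal for nested sampling is that the update needs no gradients and no hand-tuned acceptance rate; only an initial bracket width has to be chosen.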
How does it compare to other algorithms in blackjax?
This fits naturally alongside the existing SMC methods in blackjax, and part of our motivation for pushing this to blackjax is so that we can compare the methods in a like-for-like way. We already have a working implementation that gives consistent estimates of the marginal likelihood for problems with 10-100 parameters fairly rapidly. In theory the design of SMC in blackjax could be followed, allowing a flexible choice of inner kernel. However, generic application of many gradient-based MCMC chains in nested sampling is difficult, so we have currently tied the implementation quite closely to slice sampling. This design choice could be lifted.
The rather specific requirements of the inner kernel when sampling from truncated regions mean that, in our estimation, nested sampling cannot be efficiently adapted to use the existing SMC abstractions in blackjax directly (in nested sampling you “see” the likelihood and the prior log-probability separately, so the inner kernel construction is rather bespoke). However, we have tried to follow the SMC design patterns closely, and if unification is possible here it would be worthwhile.
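One way to picture why the inner kernel needs the prior and the likelihood separately is a slice update whose slice height is drawn under the prior density only, with the likelihood entering purely as a hard constraint L > L*. The sketch below is a 1D illustration under that assumption; `constrained_slice_step` and its signature are hypothetical and are not the blackjax API or the authors' kernel.

```python
import math
import random


def constrained_slice_step(x0, log_prior, log_likelihood, logl_star, width, rng):
    """One slice update targeting the prior truncated to {L > L*} (1D sketch).

    Assumes x0 already satisfies log_likelihood(x0) > logl_star.
    """
    # The slice height uses the prior density only.
    log_y = log_prior(x0) + math.log(1.0 - rng.random())

    def inside(x):
        # A point is in the slice if it is under the prior density at height
        # log_y AND satisfies the hard likelihood constraint.
        return log_prior(x) > log_y and log_likelihood(x) > logl_star

    # Stepping out from a randomly placed initial bracket.
    left = x0 - width * rng.random()
    right = left + width
    while inside(left):
        left -= width
    while inside(right):
        right += width
    # Shrinkage towards x0 until a point inside the slice is found.
    while True:
        x1 = left + (right - left) * rng.random()
        if inside(x1):
            return x1
        if x1 < x0:
            left = x1
        else:
            right = x1


if __name__ == "__main__":
    rng = random.Random(1)
    log_prior = lambda x: -0.5 * x * x               # standard normal prior
    log_like = lambda x: -0.5 * (x - 1.0) ** 2       # toy likelihood
    x = 1.0
    for _ in range(5):
        x = constrained_slice_step(x, log_prior, log_like, -0.5, 1.0, rng)
        print(x)
```

A tempering-based SMC kernel would instead receive a single combined log density, which is why this kind of step does not drop directly into that abstraction.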
Where does it fit in blackjax
We have already implemented a preliminary version of this and found the blackjax framework and abstractions incredibly useful. This would likely have a number of immediate applications in physics. We hope this would benefit blackjax through increased adoption there, as well as affording uncluttered comparison with other SoTA SMC methods.
Are you willing to open a PR?
Yes. We have a working implementation, although it needs some generalising and checking, and we would make the necessary refactors if this is deemed useful to the core library.
I'd be happy to review a PR on this. I think it's a valid contribution too. I'm not an expert on nested sampling but also not clueless so hopefully it can be painless.