
[WIP] MCMC updates (gradient-based variants and ESJD) #339

Open · wants to merge 20 commits into base: main
Conversation

@odunbar (Collaborator) commented Jan 16, 2025

Purpose

Closes #341
Closes #342
Closes #343
Closes #344
Closes #345

To-do

  • Remove hard-coded constants/parameters by creating proper constructors for the different structs
  • Investigate test failures for non-Barker forward-diff options. Currently it appears there were sign errors in the derivative, possibly due to confusing log(π) with f, where π = exp(-f) and therefore ∇log(π) = -∇f. (Replacing autodiff with -autodiff dramatically improves the MALA and infMALA performance, though it did not improve HMC performance.)
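The suspected sign convention can be verified numerically. Below is a standalone sketch (a toy `f`, not repository code) checking via finite differences that with π(x) = exp(-f(x)) the log-density gradient is ∇log π(x) = -∇f(x), so an autodiff of f must be negated before use in a MALA-type drift:

```julia
# Sign-convention check: pi(x) = exp(-f(x))  =>  d/dx log(pi(x)) = -f'(x)
f(x) = 0.5 * x^2          # toy potential, so log(pi(x)) = -0.5 * x^2
logpi(x) = -f(x)

# central finite difference
fd(g, x; h = 1e-6) = (g(x + h) - g(x - h)) / (2h)

x0 = 1.3
@assert isapprox(fd(logpi, x0), -fd(f, x0); atol = 1e-8)
# Using +f'(x) in place of -f'(x) flips the drift of a gradient-based
# proposal, consistent with the improvement observed when negating autodiff.
```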

Content

From a private repo of @KotaOK-00, with some small changes:

  • Copied the implementation of the gradient-based sampler proposals
  • Implemented the calculation of ESJD (expected squared jump distance)
  • Added an ESJD unit test.
  • Created a first pass of a new system, AutodiffProtocol, containing derived types GradFreeProtocol and ForwardDiffProtocol (and, in future, others such as BackwardDiffProtocol).
    When constructing an MCMCProtocol, default (and choosable) autodiff options are created, e.g.,
MALASampling() # creates MALASampling{ForwardDiffProtocol}()
pCNMHSampling() # creates pCNMHSampling{GradFreeProtocol}()
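One way such a hierarchy and its defaulted type parameter can be laid out is sketched below (the type names come from the description above; the constructor details are an assumption, not the PR's exact code):

```julia
abstract type AutodiffProtocol end
struct GradFreeProtocol <: AutodiffProtocol end
struct ForwardDiffProtocol <: AutodiffProtocol end

# Toy sampler types parameterized by an autodiff protocol
struct MALASampling{T <: AutodiffProtocol} end
struct pCNMHSampling{T <: AutodiffProtocol} end

# Defaults baked into the zero-argument constructors, but still choosable:
MALASampling() = MALASampling{ForwardDiffProtocol}()
pCNMHSampling() = pCNMHSampling{GradFreeProtocol}()

MALASampling()                    # MALASampling{ForwardDiffProtocol}()
MALASampling{GradFreeProtocol}()  # override the default explicitly
```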

Within the MCMC propose method, direct calls to

ForwardDiff.gradient(x -> AdvancedMH.logdensity(model, x), current_state.params)
Symmetric(ForwardDiff.hessian(x -> AdvancedMH.logdensity(model, x), current_state.params))

are replaced with autodiff_gradient and autodiff_hessian

autodiff_gradient(model, current_state.params, sampler)
autodiff_hessian(model, current_state.params, sampler)

that dispatch on the pre-specified autodiff protocol in the sampler. This will make adding new autodiff options easier.
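The dispatch pattern described above can be sketched as follows. This is a simplified, self-contained stand-in (the `ToySampler` type and toy `logdensity` are hypothetical), showing how a protocol type parameter routes the call to the right backend, so a new backend only needs a new protocol type plus two methods:

```julia
using ForwardDiff
using LinearAlgebra: Symmetric

abstract type AutodiffProtocol end
struct ForwardDiffProtocol <: AutodiffProtocol end
struct ToySampler{T <: AutodiffProtocol} end   # stand-in for an MCMCProtocol sampler

logdensity(model, x) = -0.5 * sum(abs2, x)     # toy model log-density

# Methods dispatch on the sampler's autodiff protocol parameter:
autodiff_gradient(model, x, ::ToySampler{ForwardDiffProtocol}) =
    ForwardDiff.gradient(y -> logdensity(model, y), x)

autodiff_hessian(model, x, ::ToySampler{ForwardDiffProtocol}) =
    Symmetric(ForwardDiff.hessian(y -> logdensity(model, y), x))
```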

  • Created a new type AGPJL for AbstractGPs. When building an emulator with AGPJL, as there is no optimizer for such kernels, we allow users to train with, e.g., GPJL, and then do the following:
opt_params_per_model = Emulators.get_params(gp_optimized)
kernel_params = [
    Dict(
        "log_rbf_len" => params[1:(end - 2)],
        "log_std_sqexp" => params[end - 1],
        "log_std_noise" => params[end],
    )
    for params in opt_params_per_model
]
agp = GaussianProcess(AGPJL(), ...)
Emulator(agp, ...; ..., kernel_params = kernel_params) # builds

If the user does not provide kernel_params, a useful error message directs them to this workflow.

  • Unit tests in GaussianProcesses/runtests.jl for the AGP interface as above, and checking that GPJL vs AGP gives very similar emulator predictions
  • Unit tests in MarkovChainMonteCarlo/runtests.jl for AGP with RWM and pCN sampling, which give posterior means similar to GPJL (NB: it is much slower, 5-10x)
  • Unit testing in MarkovChainMonteCarlo/runtests.jl for all other algorithms, except for Barker: many tests fail due to a non-convergent step-size calculation or an incorrect final posterior mean.
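For reference, the ESJD in its usual definition is the mean squared Euclidean distance between successive chain states. A standalone sketch of the computation (not the PR's exact implementation), for a chain stored as a vector of parameter vectors:

```julia
# ESJD = (1 / (N - 1)) * sum_t || x_{t+1} - x_t ||^2
esjd(chain) =
    sum(sum(abs2, chain[t + 1] - chain[t]) for t in 1:(length(chain) - 1)) /
    (length(chain) - 1)

chain = [[0.0, 0.0], [1.0, 0.0], [1.0, 2.0]]
esjd(chain)  # (1 + 4) / 2 = 2.5
```

Maximizing ESJD is a common criterion for tuning proposal step sizes, which is why it is useful to report alongside the new gradient-based variants.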

  • I have read and checked the items on the review checklist.

@odunbar odunbar changed the title MCMC updates (gradient-based variants and ESJD) [WIP] MCMC updates (gradient-based variants and ESJD calculation with the chain) Jan 16, 2025
@odunbar odunbar changed the title [WIP] MCMC updates (gradient-based variants and ESJD calculation with the chain) [WIP] MCMC updates (gradient-based variants and ESJD) Jan 16, 2025

codecov bot commented Jan 16, 2025

Codecov Report

Attention: Patch coverage is 36.36364% with 126 lines in your changes missing coverage. Please review.

Project coverage is 81.58%. Comparing base (d5a079b) to head (11dd80f).

Files with missing lines Patch % Lines
src/MarkovChainMonteCarlo.jl 16.43% 122 Missing ⚠️
src/GaussianProcess.jl 91.83% 4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #339      +/-   ##
==========================================
- Coverage   88.82%   81.58%   -7.25%     
==========================================
  Files           7        7              
  Lines        1271     1455     +184     
==========================================
+ Hits         1129     1187      +58     
- Misses        142      268     +126     

