From 056c97008238cf0556555b7bf8f707384d4ed677 Mon Sep 17 00:00:00 2001 From: odunbar Date: Fri, 24 Jan 2025 15:08:51 -0800 Subject: [PATCH] start on observation series in Sampler --- src/MarkovChainMonteCarlo.jl | 112 ++++++++++++++++++++++++----------- 1 file changed, 77 insertions(+), 35 deletions(-) diff --git a/src/MarkovChainMonteCarlo.jl b/src/MarkovChainMonteCarlo.jl index dd84f8be..0cee3783 100644 --- a/src/MarkovChainMonteCarlo.jl +++ b/src/MarkovChainMonteCarlo.jl @@ -142,6 +142,42 @@ MetropolisHastingsSampler(::pCNMHSampling, prior::ParameterDistribution) = pCNMe # ------------------------------------------------------------------------------------------ # Use emulated model in sampler +""" +$(DocStringExtensions.TYPEDEF) + +Stores an observation series and a log-density function as ℓ(params, observation). +This can handle minibatched likelihoods, unlike the AdvancedMH.DensityModel that only +stores a log-density as a function ℓ(params). +""" +struct ObservationSeriesLogDensityModel{F, OS <: ObservationSeries} <: AdvancedMH.DensityModel + "Function to compute logdensity ℓ(θ, observation)" + logdensity::F + "ObservationSeries object storing observation vectors and batching. `get_obs(observation_series)` gets current minibatch observation" + observation_series::OS +end + +function emulator_log_density_model(θ, observation_series) + obs = to_decorrelated(get_obs(observation_series), em) + + # θ: model params we evaluate at; in original coords. + # transform_to_real = false means g, g_cov, obs_sample are in decorrelated coords. + + # Recall predict() written to return multiple N_samples: expects input to be a + # Matrix with N_samples columns. Returned g is likewise a Matrix, and g_cov is a + # Vector of N_samples covariance matrices. For MH, N_samples is always 1, so we + # have to reshape()/re-cast input/output; simpler to do here than add a + # predict() method. + g, g_cov = + Emulators.predict(em, reshape(θ, :, 1), transform_to_real = false, vector_rf_unstandardize = false) + #TODO vector_rf will always unstandardize, but other methods will not, so we require this additional flag. + + if isa(g_cov[1], Real) + return logpdf(MvNormal(obs, g_cov[1] * I), vec(g)) + logpdf(prior, θ) + else + return logpdf(MvNormal(obs, g_cov[1]), vec(g)) + logpdf(prior, θ) + end + +end """ $(DocStringExtensions.TYPEDSIGNATURES) @@ -153,29 +189,12 @@ with the MCMC, which is the role of the `DensityModel` class in the `AbstractMCM function EmulatorPosteriorModel( prior::ParameterDistribution, em::Emulator{FT}, - obs_sample::AbstractVector{FT}, -) where {FT <: AbstractFloat} - return AdvancedMH.DensityModel( - function (θ) - # θ: model params we evaluate at; in original coords. - # transform_to_real = false means g, g_cov, obs_sample are in decorrelated coords. - # - # Recall predict() written to return multiple N_samples: expects input to be a - # Matrix with N_samples columns. Returned g is likewise a Matrix, and g_cov is a - # Vector of N_samples covariance matrices. For MH, N_samples is always 1, so we - # have to reshape()/re-cast input/output; simpler to do here than add a - # predict() method. - g, g_cov = - Emulators.predict(em, reshape(θ, :, 1), transform_to_real = false, vector_rf_unstandardize = false) - #TODO vector_rf will always unstandardize, but other methods will not, so we require this additional flag. - - if isa(g_cov[1], Real) - return logpdf(MvNormal(obs_sample, g_cov[1] * I), vec(g)) + logpdf(prior, θ) - else - return logpdf(MvNormal(obs_sample, g_cov[1]), vec(g)) + logpdf(prior, θ) - end + observation_series::OS, +) where {FT <: AbstractFloat, OS <: ObservationSeries} - end, + return ObservationSeriesLogDensityModel( + emulator_log_density_model, # needs to be a function? + observation_series, ) end @@ -192,7 +211,7 @@ Metropolis-Hastings proposal) or old (from rejecting a proposal). # Fields $(DocStringExtensions.TYPEDFIELDS) """ -struct MCMCState{T, L <: Real} <: AdvancedMH.AbstractTransition +struct MCMCState{T, OS <: ObservationSeries, L <: Real} <: AdvancedMH.AbstractTransition "Sampled value of the parameters at the current state of the MCMC chain." params::T "Log probability of `params`, as computed by the model using the prior." @@ -207,7 +226,10 @@ MCMCState(model::AdvancedMH.DensityModel, params, accepted = true) = MCMCState(params, logdensity(model, params), accepted) # Calculate the log density of the model given some parameterization. -AdvancedMH.logdensity(model::AdvancedMH.DensityModel, t::MCMCState) = t.log_density +function AdvancedMH.logdensity(model::AdvancedMH.DensityModel, params, observation_series::OS) + return model.logdensity(params, obs) +end +AdvancedMH.logdensity(model::AdvancedMH.DensityModel, t::MCMCState) = t.log_density # AdvancedMH.transition() is only called to create a new proposal, so create a MCMCState # with accepted = true since that object will only be used if proposal is accepted. @@ -260,11 +282,11 @@ function AbstractMCMC.step( ) where {FT <: AbstractFloat} # Generate a new proposal. new_params = AdvancedMH.propose(rng, sampler, model, current_state; stepsize = stepsize) - + update_minibatch(observation # Calculate the log acceptance probability and the log density of the candidate. - new_log_density = AdvancedMH.logdensity(model, new_params) + new_log_density = AdvancedMH.logdensity(model, new_params) log_α = - new_log_density - AdvancedMH.logdensity(model, current_state) + + new_log_density - AdvancedMH.logdensity(model, current_state) + AdvancedMH.logratio_proposal_density(sampler, current_state, new_params) # Decide whether to return the previous params or the new one. @@ -382,9 +404,11 @@ AbstractMCMC's terminology). # Fields $(DocStringExtensions.TYPEDFIELDS) """ -struct MCMCWrapper +struct MCMCWrapper{OS <: ObservationSeries} "[`ParameterDistribution`](https://clima.github.io/EnsembleKalmanProcesses.jl/dev/parameter_distributions/) object describing the prior distribution on parameter values." prior::ParameterDistribution + "[`ObservationSeries`](https://clima.github.io/EnsembleKalmanProcesses.jl/dev/observations/) object describing the observations, noise and batching from the inverse problem." + observation_series::OS "`AdvancedMH.DensityModel` object, used to evaluate the posterior density being sampled from." log_posterior_map::AbstractMCMC.AbstractModel "Object describing a MCMC sampling algorithm and its settings." @@ -420,15 +444,14 @@ decorrelation) that was applied in the Emulator. It creates and wraps an instanc """ function MCMCWrapper( mcmc_alg::MCMCProtocol, - obs_sample::AbstractVector{FT}, + observation_series::OS prior::ParameterDistribution, em::Emulator; - init_params::AbstractVector{FT}, - burnin::IT = 0, + init_params::AV, + burnin::Int = 0, kwargs..., -) where {FT <: AbstractFloat, IT <: Integer} - obs_sample = to_decorrelated(obs_sample, em) - log_posterior_map = EmulatorPosteriorModel(prior, em, obs_sample) +) where {AV <: AbstractVector, OS <: ObservationSeries} + log_posterior_map = EmulatorPosteriorModel(prior, em, observation_series) mh_proposal_sampler = MetropolisHastingsSampler(mcmc_alg, prior) # parameter names are needed in every dimension in a MCMCChains object needed for diagnostics @@ -448,9 +471,28 @@ function MCMCWrapper( :chain_type => MCMCChains.Chains, ) sample_kwargs = merge(sample_kwargs, kwargs) # override defaults with any explicit values - return MCMCWrapper(prior, log_posterior_map, mh_proposal_sampler, sample_kwargs) + return MCMCWrapper(prior, observation_series, log_posterior_map, mh_proposal_sampler, sample_kwargs) +end + + +function MCMCWrapper( + mcmc_alg::MCMCProtocol, + obs::OB, + args...; kwargs... +) where {OB <: Observation} + observation_series = ObservationSeries(observation) + return MCMCWrapper(mcmc_alg, observation_series, args...; kwargs...) end +function MCMCWrapper( + mcmc_alg::MCMCProtocol, + obs::AV, + obs_noise_cov::UorM, + args...; kwargs... +) where {AV <: AbstractVector, UorM <: Union{UniformScaling,AbstractMatrix}} + observation = Observation(Dict("samples" => obs, "covariances" => obs_noise_cov, "names" => "observation")) + return MCMCWrapper(mcmc_alg, observation, args...; kwargs...) +end """ $(DocStringExtensions.FUNCTIONNAME)([rng,] mcmc::MCMCWrapper, args...; kwargs...)