Skip to content

Commit

Permalink
start on observation series in Sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
odunbar committed Jan 24, 2025
1 parent d5a079b commit 056c970
Showing 1 changed file with 77 additions and 35 deletions.
112 changes: 77 additions & 35 deletions src/MarkovChainMonteCarlo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,42 @@ MetropolisHastingsSampler(::pCNMHSampling, prior::ParameterDistribution) = pCNMe
# ------------------------------------------------------------------------------------------
# Use emulated model in sampler

"""
$(DocStringExtensions.TYPEDEF)
Stores an observation series and a log-density function as ℓ(params, observation).
This can handle minibatched likelihoods, unlike the AdvancedMH.DensityModel that only
stores a log-density as a function ℓ(params).
"""
struct ObservationSeriesLogDensityModel{F, OS <: ObservationSeries} <: AdvancedMH.DensityModel
"Function to compute logdensity ℓ(θ, observation)"
logdensity::F
"ObservationSeries object storing observation vectors and batching. `get_obs(observation_series)` gets current minibatch observation"
observation_series::OS
end

function emulator_log_density_model(θ, observation_series)
obs = to_decorrelated(get_obs(observation_series), em)

# θ: model params we evaluate at; in original coords.
# transform_to_real = false means g, g_cov, obs_sample are in decorrelated coords.

# Recall predict() written to return multiple N_samples: expects input to be a
# Matrix with N_samples columns. Returned g is likewise a Matrix, and g_cov is a
# Vector of N_samples covariance matrices. For MH, N_samples is always 1, so we
# have to reshape()/re-cast input/output; simpler to do here than add a
# predict() method.
g, g_cov =
Emulators.predict(em, reshape(θ, :, 1), transform_to_real = false, vector_rf_unstandardize = false)
#TODO vector_rf will always unstandardize, but other methods will not, so we require this additional flag.

if isa(g_cov[1], Real)
return logpdf(MvNormal(obs, g_cov[1] * I), vec(g)) + logpdf(prior, θ)
else
return logpdf(MvNormal(obs, g_cov[1]), vec(g)) + logpdf(prior, θ)
end

end
"""
$(DocStringExtensions.TYPEDSIGNATURES)
Expand All @@ -153,29 +189,12 @@ with the MCMC, which is the role of the `DensityModel` class in the `AbstractMCM
function EmulatorPosteriorModel(
prior::ParameterDistribution,
em::Emulator{FT},
obs_sample::AbstractVector{FT},
) where {FT <: AbstractFloat}
return AdvancedMH.DensityModel(
function (θ)
# θ: model params we evaluate at; in original coords.
# transform_to_real = false means g, g_cov, obs_sample are in decorrelated coords.
#
# Recall predict() written to return multiple N_samples: expects input to be a
# Matrix with N_samples columns. Returned g is likewise a Matrix, and g_cov is a
# Vector of N_samples covariance matrices. For MH, N_samples is always 1, so we
# have to reshape()/re-cast input/output; simpler to do here than add a
# predict() method.
g, g_cov =
Emulators.predict(em, reshape(θ, :, 1), transform_to_real = false, vector_rf_unstandardize = false)
#TODO vector_rf will always unstandardize, but other methods will not, so we require this additional flag.

if isa(g_cov[1], Real)
return logpdf(MvNormal(obs_sample, g_cov[1] * I), vec(g)) + logpdf(prior, θ)
else
return logpdf(MvNormal(obs_sample, g_cov[1]), vec(g)) + logpdf(prior, θ)
end
observation_series::OS,
) where {FT <: AbstractFloat, OS <: ObservationSeries}

end,
return ObservationSeriesLogDensityModel(
emulator_log_density_model, # needs to be a function?
observation_series,
)
end

Expand All @@ -192,7 +211,7 @@ Metropolis-Hastings proposal) or old (from rejecting a proposal).
# Fields
$(DocStringExtensions.TYPEDFIELDS)
"""
struct MCMCState{T, L <: Real} <: AdvancedMH.AbstractTransition
struct MCMCState{T, OS <: ObservationSeries, L <: Real} <: AdvancedMH.AbstractTransition
"Sampled value of the parameters at the current state of the MCMC chain."
params::T
"Log probability of `params`, as computed by the model using the prior."
Expand All @@ -207,7 +226,10 @@ MCMCState(model::AdvancedMH.DensityModel, params, accepted = true) =
MCMCState(params, logdensity(model, params), accepted)

# Calculate the log density of the model given some parameterization.
AdvancedMH.logdensity(model::AdvancedMH.DensityModel, t::MCMCState) = t.log_density
function AdvancedMH.logdensity(model::AdvancedMH.DensityModel, params, observation_series::OS)
return model.logdensity(params, obs)
end
AdvancedMH.logdensity(model::AdvancedMH.DensityModel, t::MCMCState) = t.log_density

# AdvancedMH.transition() is only called to create a new proposal, so create a MCMCState
# with accepted = true since that object will only be used if proposal is accepted.
Expand Down Expand Up @@ -260,11 +282,11 @@ function AbstractMCMC.step(
) where {FT <: AbstractFloat}
# Generate a new proposal.
new_params = AdvancedMH.propose(rng, sampler, model, current_state; stepsize = stepsize)

update_minibatch(observation
# Calculate the log acceptance probability and the log density of the candidate.
new_log_density = AdvancedMH.logdensity(model, new_params)
new_log_density = AdvancedMH.logdensity(model, new_params)
log_α =
new_log_density - AdvancedMH.logdensity(model, current_state) +
new_log_density - AdvancedMH.logdensity(model, current_state) +
AdvancedMH.logratio_proposal_density(sampler, current_state, new_params)

# Decide whether to return the previous params or the new one.
Expand Down Expand Up @@ -382,9 +404,11 @@ AbstractMCMC's terminology).
# Fields
$(DocStringExtensions.TYPEDFIELDS)
"""
struct MCMCWrapper
struct MCMCWrapper{OS <: ObservationSeries}
"[`ParameterDistribution`](https://clima.github.io/EnsembleKalmanProcesses.jl/dev/parameter_distributions/) object describing the prior distribution on parameter values."
prior::ParameterDistribution
"[`ObservationSeries`](https://clima.github.io/EnsembleKalmanProcesses.jl/dev/observations/) object describing the observations, noise and batching from the inverse problem."
observation_series::OS
"`AdvancedMH.DensityModel` object, used to evaluate the posterior density being sampled from."
log_posterior_map::AbstractMCMC.AbstractModel
"Object describing a MCMC sampling algorithm and its settings."
Expand Down Expand Up @@ -420,15 +444,14 @@ decorrelation) that was applied in the Emulator. It creates and wraps an instanc
"""
function MCMCWrapper(
mcmc_alg::MCMCProtocol,
obs_sample::AbstractVector{FT},
observation_series::OS
prior::ParameterDistribution,
em::Emulator;
init_params::AbstractVector{FT},
burnin::IT = 0,
init_params::AV,
burnin::Int = 0,
kwargs...,
) where {FT <: AbstractFloat, IT <: Integer}
obs_sample = to_decorrelated(obs_sample, em)
log_posterior_map = EmulatorPosteriorModel(prior, em, obs_sample)
) where {AV <: AbstractVector, OS <: ObservationSeries}
log_posterior_map = EmulatorPosteriorModel(prior, em, observation_series)
mh_proposal_sampler = MetropolisHastingsSampler(mcmc_alg, prior)

# parameter names are needed in every dimension in a MCMCChains object needed for diagnostics
Expand All @@ -448,9 +471,28 @@ function MCMCWrapper(
:chain_type => MCMCChains.Chains,
)
sample_kwargs = merge(sample_kwargs, kwargs) # override defaults with any explicit values
return MCMCWrapper(prior, log_posterior_map, mh_proposal_sampler, sample_kwargs)
return MCMCWrapper(prior, observation_series, log_posterior_map, mh_proposal_sampler, sample_kwargs)
end


function MCMCWrapper(
mcmc_alg::MCMCProtocol,
obs::OB,
args...; kwargs...
) where {OB <: Observation}
observation_series = ObservationSeries(observation)
return MCMCWrapper(mcmc_alg, observation_series, args...; kwargs...)
end

function MCMCWrapper(
mcmc_alg::MCMCProtocol,
obs::AV,
obs_noise_cov::UorM,
args...; kwargs...
) where {AV <: AbstractVector, UorM <: Union{UniformScaling,AbstractMatrix}}
observation = Observation(Dict("samples" => obs, "covariances" => obs_noise_cov, "names" => "observation"))
return MCMCWrapper(mcmc_alg, observation, args...; kwargs...)
end
"""
$(DocStringExtensions.FUNCTIONNAME)([rng,] mcmc::MCMCWrapper, args...; kwargs...)
Expand Down

0 comments on commit 056c970

Please sign in to comment.