Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
odunbar committed Jan 30, 2025
1 parent 905e0a7 commit 7758e10
Showing 1 changed file with 23 additions and 21 deletions.
44 changes: 23 additions & 21 deletions src/MarkovChainMonteCarlo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ function to_decorrelated(data::AbstractVector{FT}, em::Emulator{FT}) where {FT <
return [vec(out_data)]
end

function to_decorrelated(data::AVV, em::Emulator{FT}) where {AVV <: AbstractVector, FT <: AbstractFloat}
function to_decorrelated(data::AVV, em::Emulator{FT}) where {AVV <: AbstractVector, FT <: AbstractFloat}
# method for vector of samples
if isa(data[1], AbstractVector)
return [vec(to_decorrelated(reshape(d, :, 1), em)) for d in data] # calls matrix version
Expand Down Expand Up @@ -159,27 +159,31 @@ $(DocStringExtensions.TYPEDSIGNATURES)

Defines the internal log-density function over a vector of observation samples using an assumed conditionally indepedent likelihood, that is with a log-likelihood of `ℓ(y,θ) = sum^n_i log( p(y_i|θ) )`.
"""
function emulator_log_density_model(θ, prior::ParameterDistribution, em::Emulator{FT}, obs_vec::AV) where {FT <: AbstractFloat, AV <: AbstractVector}

function emulator_log_density_model(
θ,
prior::ParameterDistribution,
em::Emulator{FT},
obs_vec::AV,
) where {FT <: AbstractFloat, AV <: AbstractVector}

# θ: model params we evaluate at; in original coords.
# transform_to_real = false means g, g_cov, obs_sample are in decorrelated coords.

# Recall predict() written to return multiple N_samples: expects input to be a
# Matrix with N_samples columns. Returned g is likewise a Matrix, and g_cov is a
# Vector of N_samples covariance matrices. For MH, N_samples is always 1, so we
# have to reshape()/re-cast input/output; simpler to do here than add a
# predict() method.
g, g_cov =
Emulators.predict(em, reshape(θ, :, 1), transform_to_real = false, vector_rf_unstandardize = false)
g, g_cov = Emulators.predict(em, reshape(θ, :, 1), transform_to_real = false, vector_rf_unstandardize = false)
#TODO vector_rf will always unstandardize, but other methods will not, so we require this additional flag.

if isa(g_cov[1], Real)
return 1.0/length(obs_vec)*sum([logpdf(MvNormal(obs, g_cov[1] * I), vec(g)) for obs in obs_vec]) + logpdf(prior, θ)

return 1.0 / length(obs_vec) * sum([logpdf(MvNormal(obs, g_cov[1] * I), vec(g)) for obs in obs_vec]) + logpdf(prior, θ)
else
return 1.0/length(obs_vec)*sum([logpdf(MvNormal(obs, g_cov[1]), vec(g)) for obs in obs_vec]) + logpdf(prior, θ)
return 1.0 / length(obs_vec) * sum([logpdf(MvNormal(obs, g_cov[1]), vec(g)) for obs in obs_vec]) + logpdf(prior, θ)
end

end

"""
Expand All @@ -196,9 +200,7 @@ function EmulatorPosteriorModel(
obs_vec::AV,
) where {FT <: AbstractFloat, AV <: AbstractVector}

return AdvancedMH.DensityModel(
x -> emulator_log_density_model(x, prior, em, obs_vec)
)
return AdvancedMH.DensityModel(x -> emulator_log_density_model(x, prior, em, obs_vec))
end

# ------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -228,7 +230,7 @@ end
MCMCState(model::AdvancedMH.DensityModel, params, accepted = true) =
MCMCState(params, logdensity(model, params), accepted)

AdvancedMH.logdensity(model::AdvancedMH.DensityModel, t::MCMCState) = t.log_density
AdvancedMH.logdensity(model::AdvancedMH.DensityModel, t::MCMCState) = t.log_density

# AdvancedMH.transition() is only called to create a new proposal, so create a MCMCState
# with accepted = true since that object will only be used if proposal is accepted.
Expand Down Expand Up @@ -278,13 +280,13 @@ function AbstractMCMC.step(
current_state::MCMCState;
stepsize::FT = 1.0,
kwargs...,
) where {FT <: AbstractFloat}
) where {FT <: AbstractFloat}
# Generate a new proposal.
new_params = AdvancedMH.propose(rng, sampler, model, current_state; stepsize = stepsize)
# Calculate the log acceptance probability and the log density of the candidate.
new_log_density = AdvancedMH.logdensity(model, new_params)
new_log_density = AdvancedMH.logdensity(model, new_params)
log_α =
new_log_density - AdvancedMH.logdensity(model, current_state) +
new_log_density - AdvancedMH.logdensity(model, current_state) +
AdvancedMH.logratio_proposal_density(sampler, current_state, new_params)

# Decide whether to return the previous params or the new one.
Expand Down Expand Up @@ -402,7 +404,7 @@ AbstractMCMC's terminology).
# Fields
$(DocStringExtensions.TYPEDFIELDS)
"""
struct MCMCWrapper{ AMorAV <: Union{AbstractVector,AbstractMatrix}, AV <: AbstractVector }
struct MCMCWrapper{AMorAV <: Union{AbstractVector, AbstractMatrix}, AV <: AbstractVector}
"[`ParameterDistribution`](https://clima.github.io/EnsembleKalmanProcesses.jl/dev/parameter_distributions/) object describing the prior distribution on parameter values."
prior::ParameterDistribution
"A vector, matrix, or vector or vectors describing the observation(s) provided by the user."
Expand Down Expand Up @@ -450,10 +452,10 @@ function MCMCWrapper(
init_params::AV,
burnin::Int = 0,
kwargs...,
) where {AV <: AbstractVector, AMorAV <: Union{AbstractVector,AbstractMatrix}}
) where {AV <: AbstractVector, AMorAV <: Union{AbstractVector, AbstractMatrix}}

# decorrelate observations into a vector
decorrelated_obs = to_decorrelated(observation, em)
decorrelated_obs = to_decorrelated(observation, em)

log_posterior_map = EmulatorPosteriorModel(prior, em, decorrelated_obs)
mh_proposal_sampler = MetropolisHastingsSampler(mcmc_alg, prior)
Expand Down

0 comments on commit 7758e10

Please sign in to comment.