
Redesign of MCSE #63

Merged · 49 commits · Jan 19, 2023
Changes from all commits
Commits (49)
2d82222
Add mcse_sbm
sethaxen Jan 16, 2023
e4b067b
Update description of `estimator`
sethaxen Jan 16, 2023
34e0221
Add specialized estimators for mean, std, and quantile
sethaxen Jan 16, 2023
6839cd7
Remove vector methods, defaulting to sbm
sethaxen Jan 16, 2023
7dbda2d
Update docstring
sethaxen Jan 16, 2023
b93ce5b
Fix bugs
sethaxen Jan 16, 2023
790a99b
Update docstrings
sethaxen Jan 16, 2023
9a1d3d5
Update docstring
sethaxen Jan 16, 2023
c0d5a94
Move helper functions to own file
sethaxen Jan 16, 2023
cf908af
Rearrange tests
sethaxen Jan 16, 2023
93d121e
Update mcse tests
sethaxen Jan 16, 2023
fe99356
Export mcse_sbm
sethaxen Jan 16, 2023
d12648b
Increment minor version number with DEV suffix
sethaxen Jan 16, 2023
465afac
Merge branch 'main' into mcseupdate
sethaxen Jan 16, 2023
7dac85e
Increment docs and tests version numbers
sethaxen Jan 16, 2023
9407369
Add additional citation
sethaxen Jan 16, 2023
f95a066
Merge branch 'main' into mcseupdate
sethaxen Jan 16, 2023
88b6c41
Update diagnostics to use new mcse
sethaxen Jan 16, 2023
899711e
Increase tolerance of mcse tests
sethaxen Jan 17, 2023
01b8dbc
Increase tolerance more
sethaxen Jan 17, 2023
60d6441
Add mcse_sbm to docs
sethaxen Jan 17, 2023
2441bcb
Skip high autocorrelation tests for mcse_sbm
sethaxen Jan 17, 2023
8e3b06a
Note underestimation for SBM
sethaxen Jan 17, 2023
e7ca85a
Merge branch 'main' into mcseupdate
sethaxen Jan 17, 2023
2af67e9
Update src/mcse.jl
sethaxen Jan 18, 2023
da4ed63
Merge branch 'main' into mcseupdate
sethaxen Jan 18, 2023
b7cd495
Don't enforce type
sethaxen Jan 18, 2023
4d55716
Document kwargs passed to mcse
sethaxen Jan 18, 2023
d9f6734
Cross-link mcse and ess_rhat docstrings
sethaxen Jan 18, 2023
1c48266
Document derivation of mcse for std
sethaxen Jan 18, 2023
f072b9e
Test type-inferrability of ess_rhat
sethaxen Jan 18, 2023
bb47887
Make sure ess_rhat for quantiles not promoted
sethaxen Jan 18, 2023
7f61907
Make sure ess_rhat for median type-inferrable
sethaxen Jan 18, 2023
a03cc2a
Implement specific method for median
sethaxen Jan 18, 2023
bac8a3c
Return missing if any are missing
sethaxen Jan 18, 2023
652b86f
Add mcse tests
sethaxen Jan 18, 2023
8dbae84
Decrease the number of checks
sethaxen Jan 18, 2023
d9aff61
Make ESS/MCSE for median work with Union{Missing,Real}
sethaxen Jan 18, 2023
cced4be
Make _fold_around_median type-inferrable
sethaxen Jan 18, 2023
ce9d427
Increase tolerance for exhaustive tests
sethaxen Jan 18, 2023
787a05f
Fix _fold_around_median
sethaxen Jan 18, 2023
34f3771
Fix count of checks
sethaxen Jan 18, 2023
d10740a
Increase the number of draws
sethaxen Jan 18, 2023
cf09e4e
Apply suggestions from code review
sethaxen Jan 18, 2023
39d74a9
Make sure heideldiag and gewekediag preserve input type
sethaxen Jan 18, 2023
a575fc8
Consistently use first and last for ess_rhat
sethaxen Jan 18, 2023
b4eea8d
Copy comment to _fold_around_median
sethaxen Jan 18, 2023
3738347
Make mcse_sbm an internal function
sethaxen Jan 18, 2023
69fe0a1
Update tests
sethaxen Jan 18, 2023
2 changes: 1 addition & 1 deletion src/MCMCDiagnosticTools.jl
@@ -7,7 +7,7 @@ using Distributions: Distributions
using MLJModelInterface: MLJModelInterface as MMI
using SpecialFunctions: SpecialFunctions
using StatsBase: StatsBase
using StatsFuns: StatsFuns
using StatsFuns: StatsFuns, sqrt2
using Tables: Tables

using LinearAlgebra: LinearAlgebra
19 changes: 13 additions & 6 deletions src/ess.jl
@@ -222,7 +222,7 @@ For a given estimand, it is recommended that the ESS is at least `100 * chains`
``\\widehat{R} < 1.01``.[^VehtariGelman2021]

See also: [`ESSMethod`](@ref), [`FFTESSMethod`](@ref), [`BDAESSMethod`](@ref),
[`ess_rhat_bulk`](@ref), [`ess_tail`](@ref), [`rhat_tail`](@ref)
[`ess_rhat_bulk`](@ref), [`ess_tail`](@ref), [`rhat_tail`](@ref), [`mcse`](@ref)

## Estimators

@@ -435,8 +435,8 @@ function ess_tail(
# workaround for https://github.com/JuliaStats/Statistics.jl/issues/136
T = Base.promote_eltype(x, tail_prob)
return min.(
ess_rhat(Base.Fix2(Statistics.quantile, T(tail_prob / 2)), x; kwargs...)[1],
ess_rhat(Base.Fix2(Statistics.quantile, T(1 - tail_prob / 2)), x; kwargs...)[1],
first(ess_rhat(Base.Fix2(Statistics.quantile, T(tail_prob / 2)), x; kwargs...)),
first(ess_rhat(Base.Fix2(Statistics.quantile, T(1 - tail_prob / 2)), x; kwargs...)),
)
end

@@ -464,13 +464,20 @@ See also: [`ess_tail`](@ref), [`ess_rhat_bulk`](@ref)
doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221)
arXiv: [1903.08008](https://arxiv.org/abs/1903.08008)
"""
rhat_tail(x; kwargs...) = ess_rhat_bulk(_fold_around_median(x); kwargs...)[2]
rhat_tail(x; kwargs...) = last(ess_rhat_bulk(_fold_around_median(x); kwargs...))

# Compute an expectand `z` such that ``\\textrm{mean-ESS}(z) ≈ \\textrm{f-ESS}(x)``.
# If no proxy expectand for `f` is known, `nothing` is returned.
_expectand_proxy(f, x) = nothing
function _expectand_proxy(::typeof(Statistics.median), x)
return x .≤ Statistics.median(x; dims=(1, 2))
y = similar(x)
# avoid using the `dims` keyword for median because it
# - can error for Union{Missing,Real} (https://github.com/JuliaStats/Statistics.jl/issues/8)
# - is type-unstable (https://github.com/JuliaStats/Statistics.jl/issues/39)
for (xi, yi) in zip(eachslice(x; dims=3), eachslice(y; dims=3))
yi .= xi .≤ Statistics.median(vec(xi))
end
return y
end
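
A hedged, standalone check of the proxy idea used by `_expectand_proxy(::typeof(Statistics.median), x)` above (assumes the post-merge `ess_rhat(estimator, samples)` API shown elsewhere in this diff; values are illustrative):

```julia
using MCMCDiagnosticTools, Statistics

x = randn(1000, 4, 1)                # (draws, chains, parameters)
z = similar(x)
z .= x .≤ median(vec(x))             # indicator expectand for the single parameter
first(ess_rhat(Statistics.mean, z))  # ≈ first(ess_rhat(Statistics.median, x))
```
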
function _expectand_proxy(::typeof(Statistics.std), x)
return (x .- Statistics.mean(x; dims=(1, 2))) .^ 2
@@ -480,7 +487,7 @@ function _expectand_proxy(::typeof(StatsBase.mad), x)
return _expectand_proxy(Statistics.median, x_folded)
end
function _expectand_proxy(f::Base.Fix2{typeof(Statistics.quantile),<:Real}, x)
y = similar(x, Bool)
y = similar(x)
@sethaxen (Member, Author) commented on Jan 18, 2023:

Necessary because otherwise our transformation discards the type of x, so a Float32 eltype x will get a Float64 ESS/R-hat/MCSE.
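
A hedged, standalone illustration of the eltype issue described in this comment (the array names here are made up for the example):

```julia
using Statistics

x = randn(Float32, 1000, 4, 1)
y_bool = x .≤ 0.5f0                # eltype Bool, as `similar(x, Bool)` would give
y_same = similar(x)
y_same .= x .≤ 0.5f0               # eltype Float32, as `similar(x)` preserves
eltype(mean(y_bool; dims=(1, 2)))  # Float64: statistics of Bool arrays promote
eltype(mean(y_same; dims=(1, 2)))  # Float32: the input eltype is carried downstream
```
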

# currently quantile does not support a dims keyword argument
for (xi, yi) in zip(eachslice(x; dims=3), eachslice(y; dims=3))
yi .= xi .≤ f(vec(xi))
12 changes: 8 additions & 4 deletions src/gewekediag.jl
@@ -12,6 +12,8 @@ samples are independent. A non-significant test p-value indicates convergence.
p-values indicate non-convergence and the possible need to discard initial samples as a
burn-in sequence or to simulate additional samples.

`kwargs` are forwarded to [`mcse`](@ref).

[^Geweke1991]: Geweke, J. F. (1991). Evaluating the accuracy of sampling-based approaches to the calculation of posterior moments (No. 148). Federal Reserve Bank of Minneapolis.
"""
function gewekediag(x::AbstractVector{<:Real}; first::Real=0.1, last::Real=0.5, kwargs...)
@@ -22,10 +24,12 @@ function gewekediag(x::AbstractVector{<:Real}; first::Real=0.1, last::Real=0.5,
n = length(x)
x1 = x[1:round(Int, first * n)]
x2 = x[round(Int, n - last * n + 1):n]
z =
(Statistics.mean(x1) - Statistics.mean(x2)) /
hypot(mcse(x1; kwargs...), mcse(x2; kwargs...))
p = SpecialFunctions.erfc(abs(z) / sqrt(2))
s = hypot(
Base.first(mcse(Statistics.mean, reshape(x1, :, 1, 1); split_chains=1, kwargs...)),
Base.first(mcse(Statistics.mean, reshape(x2, :, 1, 1); split_chains=1, kwargs...)),
)
z = (Statistics.mean(x1) - Statistics.mean(x2)) / s
p = SpecialFunctions.erfc(abs(z) / sqrt2)

return (zscore=z, pvalue=p)
end
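
A hedged usage sketch of the updated diagnostic (post-merge exports assumed; the keyword values are just examples):

```julia
using MCMCDiagnosticTools

x = randn(5_000)                    # a single chain of draws
gewekediag(x)                       # => (zscore = ..., pvalue = ...)
gewekediag(x; first=0.2, last=0.5)  # compare the first 20% of draws against the last 50%
```
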
15 changes: 10 additions & 5 deletions src/heideldiag.jl
@@ -9,31 +9,36 @@ means are within a target ratio. Stationarity is rejected (0) for significant test p-values.
Halfwidth tests are rejected (0) if observed ratios are greater than the target, as is the
case for `s2` and `beta[1]`.

`kwargs` are forwarded to [`mcse`](@ref).

[^Heidelberger1983]: Heidelberger, P., & Welch, P. D. (1983). Simulation run length control in the presence of an initial transient. Operations Research, 31(6), 1109-1144.
"""
function heideldiag(
x::AbstractVector{<:Real}; alpha::Real=0.05, eps::Real=0.1, start::Int=1, kwargs...
x::AbstractVector{<:Real}; alpha::Real=1//20, eps::Real=0.1, start::Int=1, kwargs...
)
n = length(x)
delta = trunc(Int, 0.10 * n)
y = x[trunc(Int, n / 2):end]
S0 = length(y) * mcse(y; kwargs...)^2
i, pvalue, converged, ybar = 1, 1.0, false, NaN
T = typeof(zero(eltype(x)) / 1)
s = first(mcse(Statistics.mean, reshape(y, :, 1, 1); split_chains=1, kwargs...))
S0 = length(y) * s^2
i, pvalue, converged, ybar = 1, one(T), false, T(NaN)
while i < n / 2
y = x[i:end]
m = length(y)
ybar = Statistics.mean(y)
B = cumsum(y) - ybar * collect(1:m)
Bsq = (B .* B) ./ (m * S0)
I = sum(Bsq) / m
pvalue = 1.0 - pcramer(I)
pvalue = 1 - T(pcramer(I))
converged = pvalue > alpha
if converged
break
end
i += delta
end
halfwidth = sqrt(2) * SpecialFunctions.erfcinv(alpha) * mcse(y; kwargs...)
s = first(mcse(Statistics.mean, reshape(y, :, 1, 1); split_chains=1, kwargs...))
halfwidth = sqrt2 * SpecialFunctions.erfcinv(T(alpha)) * s
passed = halfwidth / abs(ybar) <= eps
return (
burnin=i + start - 2,
171 changes: 114 additions & 57 deletions src/mcse.jl
@@ -1,72 +1,129 @@
const normcdf1 = 0.8413447460685429 # StatsFuns.normcdf(1)
const normcdfn1 = 0.15865525393145705 # StatsFuns.normcdf(-1)

"""
mcse(x::AbstractVector{<:Real}; method::Symbol=:imse, kwargs...)
mcse(estimator, samples::AbstractArray{<:Union{Missing,Real}}; kwargs...)

Estimate the Monte Carlo standard errors (MCSE) of the `estimator` applied to `samples` of
shape `(draws, chains, parameters)`.

See also: [`ess_rhat`](@ref)

## Estimators

`estimator` must accept a vector of the same `eltype` as `samples` and return a real estimate.

Compute the Monte Carlo standard error (MCSE) of samples `x`.
The optional argument `method` describes how the errors are estimated. Possible options are:
For the following estimators, the effective sample size [`ess_rhat`](@ref) and an estimate
of the asymptotic variance are used to compute the MCSE, and `kwargs` are forwarded to
`ess_rhat`:
- `Statistics.mean`
- `Statistics.median`
- `Statistics.std`
- `Base.Fix2(Statistics.quantile, p::Real)`

- `:bm` for batch means [^Glynn1991]
- `:imse` initial monotone sequence estimator [^Geyer1992]
- `:ipse` initial positive sequence estimator [^Geyer1992]
For other estimators, the subsampling bootstrap method (SBM)[^FlegalJones2011][^Flegal2012]
is used as a fallback, and the only accepted `kwargs` are `batch_size`, which indicates the
size of the overlapping batches used to estimate the MCSE, defaulting to
`floor(Int, sqrt(draws * chains))`. Note that SBM tends to underestimate the MCSE,
especially for highly autocorrelated chains. One should verify that autocorrelation is low
by checking the bulk- and tail-[`ess_rhat`](@ref) values.

[^Glynn1991]: Glynn, P. W., & Whitt, W. (1991). Estimating the asymptotic variance with batch means. Operations Research Letters, 10(8), 431-435.
[^FlegalJones2011]: Flegal JM, Jones GL. (2011) Implementing MCMC: estimating with confidence.
Handbook of Markov Chain Monte Carlo. pp. 175-97.
[pdf](http://faculty.ucr.edu/~jflegal/EstimatingWithConfidence.pdf)
[^Flegal2012]: Flegal JM. (2012) Applicability of subsampling bootstrap methods in Markov chain Monte Carlo.
Monte Carlo and Quasi-Monte Carlo Methods 2010. pp. 363-72.
doi: [10.1007/978-3-642-27440-4_18](https://doi.org/10.1007/978-3-642-27440-4_18)

[^Geyer1992]: Geyer, C. J. (1992). Practical Markov Chain Monte Carlo. Statistical Science, 473-483.
"""
function mcse(x::AbstractVector{<:Real}; method::Symbol=:imse, kwargs...)
return if method === :bm
mcse_bm(x; kwargs...)
elseif method === :imse
mcse_imse(x)
elseif method === :ipse
mcse_ipse(x)
else
throw(ArgumentError("unsupported MCSE method $method"))
mcse(f, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) = _mcse_sbm(f, x; kwargs...)
function mcse(
::typeof(Statistics.mean), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...
)
S = first(ess_rhat(Statistics.mean, samples; kwargs...))
return dropdims(Statistics.std(samples; dims=(1, 2)); dims=(1, 2)) ./ sqrt.(S)
end
function mcse(
::typeof(Statistics.std), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...
)
x = (samples .- Statistics.mean(samples; dims=(1, 2))) .^ 2 # expectand proxy
S = first(ess_rhat(Statistics.mean, x; kwargs...))
# asymptotic variance of sample variance estimate is Var[var] = E[μ₄] - E[var]²,
# where μ₄ is the 4th central moment
# by the delta method, Var[std] ≈ Var[var] / (4 E[var]) = (E[μ₄]/E[var] - E[var])/4,
# see e.g. Chapter 3 of Van der Vaart, AW. (2000). Asymptotic statistics. Vol. 3.
mean_var = dropdims(Statistics.mean(x; dims=(1, 2)); dims=(1, 2))
mean_moment4 = dropdims(Statistics.mean(abs2, x; dims=(1, 2)); dims=(1, 2))
return @. sqrt((mean_moment4 / mean_var - mean_var) / S) / 2
end
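
A hedged restatement of the delta-method step in the comment above, writing μ₄ for the fourth central moment, σ² for the variance, and S for the effective sample size:

```math
\operatorname{Var}[\widehat{\sigma^2}] \approx \frac{\mu_4 - \sigma^4}{S},
\qquad
\operatorname{Var}[\widehat{\sigma}] \approx \frac{\operatorname{Var}[\widehat{\sigma^2}]}{4\sigma^2}
= \frac{1}{4S}\left(\frac{\mu_4}{\sigma^2} - \sigma^2\right),
```

so the MCSE of the standard deviation is `sqrt((μ₄/σ² - σ²)/S) / 2`, which is the `@.` expression above with `mean_moment4` and `mean_var` as the sample estimates of μ₄ and σ².
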
function mcse(
f::Base.Fix2{typeof(Statistics.quantile),<:Real},
samples::AbstractArray{<:Union{Missing,Real},3};
kwargs...,
)
p = f.x
S = first(ess_rhat(f, samples; kwargs...))
T = eltype(S)
R = promote_type(eltype(samples), typeof(oneunit(eltype(samples)) / sqrt(oneunit(T))))
values = similar(S, R)
for (i, xi, Si) in zip(eachindex(values), eachslice(samples; dims=3), S)
values[i] = _mcse_quantile(vec(xi), p, Si)
end
return values
end
function mcse(
::typeof(Statistics.median), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...
)
S = first(ess_rhat(Statistics.median, samples; kwargs...))
T = eltype(S)
R = promote_type(eltype(samples), typeof(oneunit(eltype(samples)) / sqrt(oneunit(T))))
values = similar(S, R)
for (i, xi, Si) in zip(eachindex(values), eachslice(samples; dims=3), S)
values[i] = _mcse_quantile(vec(xi), 1//2, Si)
end
return values
end

function mcse_bm(x::AbstractVector{<:Real}; size::Int=floor(Int, sqrt(length(x))))
n = length(x)
m = min(div(n, 2), size)
m == size || @warn "batch size was reduced to $m"
mcse = StatsBase.sem(Statistics.mean(@view(x[(i + 1):(i + m)])) for i in 0:m:(n - m))
return mcse
function _mcse_quantile(x, p, Seff)
Seff === missing && return missing
S = length(x)
# quantile error distribution is asymptotically normal; estimate σ (mcse) with 2
# quadrature points: xl and xu, chosen as quantiles so that xu - xl = 2σ
# compute quantiles of error distribution in probability space (i.e. quantiles passed through CDF)
# Beta(α,β) is the approximate error distribution of quantile estimates
α = Seff * p + 1
β = Seff * (1 - p) + 1
prob_x_upper = StatsFuns.betainvcdf(α, β, normcdf1)
prob_x_lower = StatsFuns.betainvcdf(α, β, normcdfn1)
# use inverse ECDF to get quantiles in quantile (x) space
l = max(floor(Int, prob_x_lower * S), 1)
u = min(ceil(Int, prob_x_upper * S), S)
iperm = partialsortperm(x, l:u) # sort as little of x as possible
xl = x[first(iperm)]
xu = x[last(iperm)]
# estimate mcse from quantiles
return (xu - xl) / 2
end
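
A hedged, standalone sketch of the two-quadrature-point idea in `_mcse_quantile` (it uses `Statistics.quantile` in place of the partial-sort inverse ECDF, and assumes independent draws so that `Seff == length(x)`; values are illustrative):

```julia
using Statistics, StatsFuns

x = randn(10_000); p = 0.25; Seff = length(x)        # pretend the draws are independent
α, β = Seff * p + 1, Seff * (1 - p) + 1              # approximate Beta(α, β) error distribution
pu = StatsFuns.betainvcdf(α, β, 0.8413447460685429)  # normcdf(1)
pl = StatsFuns.betainvcdf(α, β, 0.15865525393145705) # normcdf(-1)
xu, xl = quantile(x, pu), quantile(x, pl)            # map back to x-space
(xu - xl) / 2                                        # ≈ MCSE of the p = 0.25 quantile estimate
```
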

function mcse_imse(x::AbstractVector{<:Real})
n = length(x)
lags = [0, 1]
ghat = StatsBase.autocov(x, lags)
Ghat = sum(ghat)
@inbounds value = Ghat + ghat[2]
@inbounds for i in 2:2:(n - 2)
lags[1] = i
lags[2] = i + 1
StatsBase.autocov!(ghat, x, lags)
Ghat = min(Ghat, sum(ghat))
Ghat > 0 || break
value += 2 * Ghat
function _mcse_sbm(
f,
x::AbstractArray{<:Union{Missing,Real},3};
batch_size::Int=floor(Int, sqrt(size(x, 1) * size(x, 2))),
)
T = promote_type(eltype(x), typeof(zero(eltype(x)) / 1))
values = similar(x, T, (axes(x, 3),))
for (i, xi) in zip(eachindex(values), eachslice(x; dims=3))
values[i] = _mcse_sbm(f, vec(xi), batch_size)
end

mcse = sqrt(value / n)

return mcse
return values
end

function mcse_ipse(x::AbstractVector{<:Real})
function _mcse_sbm(f, x, batch_size)
any(x -> x === missing, x) && return missing
n = length(x)
lags = [0, 1]
ghat = StatsBase.autocov(x, lags)
@inbounds value = ghat[1] + 2 * ghat[2]
@inbounds for i in 2:2:(n - 2)
lags[1] = i
lags[2] = i + 1
StatsBase.autocov!(ghat, x, lags)
Ghat = sum(ghat)
Ghat > 0 || break
value += 2 * Ghat
end

mcse = sqrt(value / n)

return mcse
i1 = firstindex(x)
v = Statistics.var(
f(view(x, i:(i + batch_size - 1))) for i in i1:(i1 + n - batch_size);
corrected=false,
)
return sqrt(v * (batch_size//n))
end
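
A hedged end-to-end sketch of the redesigned `mcse` API (post-merge exports assumed; `StatsBase.mad` is just one example of a generic estimator that hits the SBM fallback):

```julia
using MCMCDiagnosticTools, Statistics, StatsBase

x = randn(1_000, 4, 3)                        # (draws, chains, parameters)
mcse(Statistics.mean, x)                      # ESS-based MCSE of the mean, one value per parameter
mcse(Statistics.std, x)                       # delta-method specialization shown above
mcse(Statistics.median, x)                    # specialized quantile-based method
mcse(Base.Fix2(Statistics.quantile, 0.9), x)  # MCSE of the 90% quantile
mcse(StatsBase.mad, x; batch_size=100)        # generic estimator, subsampling bootstrap fallback
```
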
11 changes: 10 additions & 1 deletion src/utils.jl
@@ -145,7 +145,16 @@ end

Compute the absolute deviation of `x` from `Statistics.median(x)`.
"""
_fold_around_median(data) = abs.(data .- Statistics.median(data; dims=(1, 2)))
function _fold_around_median(x)
y = similar(x)
# avoid using the `dims` keyword for median because it
# - can error for Union{Missing,Real} (https://github.com/JuliaStats/Statistics.jl/issues/8)
# - is type-unstable (https://github.com/JuliaStats/Statistics.jl/issues/39)
for (xi, yi) in zip(eachslice(x; dims=3), eachslice(y; dims=3))
yi .= abs.(xi .- Statistics.median(vec(xi)))
end
return y
end

"""
_rank_normalize(x::AbstractArray{<:Any,3})