Redesign of MCSE #63

Merged
merged 49 commits into from
Jan 19, 2023
Changes from 23 commits
Commits
2d82222
Add mcse_sbm
sethaxen Jan 16, 2023
e4b067b
Update description of `estimator`
sethaxen Jan 16, 2023
34e0221
Add specialized estimators for mean, std, and quantile
sethaxen Jan 16, 2023
6839cd7
Remove vector methods, defaulting to sbm
sethaxen Jan 16, 2023
7dbda2d
Update docstring
sethaxen Jan 16, 2023
b93ce5b
Fix bugs
sethaxen Jan 16, 2023
790a99b
Update docstrings
sethaxen Jan 16, 2023
9a1d3d5
Update docstring
sethaxen Jan 16, 2023
c0d5a94
Move helper functions to own file
sethaxen Jan 16, 2023
cf908af
Rearrange tests
sethaxen Jan 16, 2023
93d121e
Update mcse tests
sethaxen Jan 16, 2023
fe99356
Export mcse_sbm
sethaxen Jan 16, 2023
d12648b
Increment minor version number with DEV suffix
sethaxen Jan 16, 2023
465afac
Merge branch 'main' into mcseupdate
sethaxen Jan 16, 2023
7dac85e
Increment docs and tests version numbers
sethaxen Jan 16, 2023
9407369
Add additional citation
sethaxen Jan 16, 2023
f95a066
Merge branch 'main' into mcseupdate
sethaxen Jan 16, 2023
88b6c41
Update diagnostics to use new mcse
sethaxen Jan 16, 2023
899711e
Increase tolerance of mcse tests
sethaxen Jan 17, 2023
01b8dbc
Increase tolerance more
sethaxen Jan 17, 2023
60d6441
Add mcse_sbm to docs
sethaxen Jan 17, 2023
2441bcb
Skip high autocorrelation tests for mcse_sbm
sethaxen Jan 17, 2023
8e3b06a
Note underestimation for SBM
sethaxen Jan 17, 2023
e7ca85a
Merge branch 'main' into mcseupdate
sethaxen Jan 17, 2023
2af67e9
Update src/mcse.jl
sethaxen Jan 18, 2023
da4ed63
Merge branch 'main' into mcseupdate
sethaxen Jan 18, 2023
b7cd495
Don't enforce type
sethaxen Jan 18, 2023
4d55716
Document kwargs passed to mcse
sethaxen Jan 18, 2023
d9f6734
Cross-link mcse and ess_rhat docstrings
sethaxen Jan 18, 2023
1c48266
Document derivation of mcse for std
sethaxen Jan 18, 2023
f072b9e
Test type-inferrability of ess_rhat
sethaxen Jan 18, 2023
bb47887
Make sure ess_rhat for quantiles not promoted
sethaxen Jan 18, 2023
7f61907
Make sure ess_rhat for median type-inferrable
sethaxen Jan 18, 2023
a03cc2a
Implement specific method for median
sethaxen Jan 18, 2023
bac8a3c
Return missing if any are missing
sethaxen Jan 18, 2023
652b86f
Add mcse tests
sethaxen Jan 18, 2023
8dbae84
Decrease the number of checks
sethaxen Jan 18, 2023
d9aff61
Make ESS/MCSE for median work with Union{Missing,Real}
sethaxen Jan 18, 2023
cced4be
Make _fold_around_median type-inferrable
sethaxen Jan 18, 2023
ce9d427
Increase tolerance for exhaustive tests
sethaxen Jan 18, 2023
787a05f
Fix _fold_around_median
sethaxen Jan 18, 2023
34f3771
Fix count of checks
sethaxen Jan 18, 2023
d10740a
Increase the number of draws
sethaxen Jan 18, 2023
cf09e4e
Apply suggestions from code review
sethaxen Jan 18, 2023
39d74a9
Make sure heideldiag and gewekediag preserve input type
sethaxen Jan 18, 2023
a575fc8
Consistently use first and last for ess_rhat
sethaxen Jan 18, 2023
b4eea8d
Copy comment to _fold_around_median
sethaxen Jan 18, 2023
3738347
Make mcse_sbm an internal function
sethaxen Jan 18, 2023
69fe0a1
Update tests
sethaxen Jan 18, 2023
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "MCMCDiagnosticTools"
uuid = "be115224-59cd-429b-ad48-344e309966f0"
authors = ["David Widmann"]
version = "0.2.5"
version = "0.3.0-DEV"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
2 changes: 1 addition & 1 deletion docs/Project.toml
@@ -8,7 +8,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Documenter = "0.27"
MCMCDiagnosticTools = "0.2"
MCMCDiagnosticTools = "0.3"
MLJBase = "0.19, 0.20, 0.21"
MLJXGBoostInterface = "0.1, 0.2, 0.3"
julia = "1.3"
1 change: 1 addition & 0 deletions docs/src/index.md
@@ -27,6 +27,7 @@ BDAESSMethod

```@docs
mcse
mcse_sbm
```

## R⋆ diagnostic
2 changes: 1 addition & 1 deletion src/MCMCDiagnosticTools.jl
@@ -20,7 +20,7 @@ export ess_rhat, ess_rhat_bulk, ess_tail, rhat_tail, ESSMethod, FFTESSMethod, BD
export gelmandiag, gelmandiag_multivariate
export gewekediag
export heideldiag
export mcse
export mcse, mcse_sbm
export rafterydiag
export rstar

9 changes: 6 additions & 3 deletions src/gewekediag.jl
@@ -22,9 +22,12 @@ function gewekediag(x::AbstractVector{<:Real}; first::Real=0.1, last::Real=0.5,
n = length(x)
x1 = x[1:round(Int, first * n)]
x2 = x[round(Int, n - last * n + 1):n]
z =
(Statistics.mean(x1) - Statistics.mean(x2)) /
hypot(mcse(x1; kwargs...), mcse(x2; kwargs...))
T = float(eltype(x))
s = hypot(
Base.first(mcse(Statistics.mean, reshape(x1, :, 1, 1); split_chains=1, kwargs...)),
Base.first(mcse(Statistics.mean, reshape(x2, :, 1, 1); split_chains=1, kwargs...)),
)::T
z = (Statistics.mean(x1) - Statistics.mean(x2)) / s
p = SpecialFunctions.erfc(abs(z) / sqrt(2))

return (zscore=z, pvalue=p)
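
A note on the calling convention used in this hunk: the redesigned `mcse` takes the estimator as its first argument and a 3-d array of shape `(draws, chains, parameters)`, so a vector is reshaped with `reshape(x, :, 1, 1)` and the single result is extracted with `first`. The sketch below illustrates only that shape handling. `naive_mcse_mean` is a hypothetical stand-in that assumes independent draws; it is not the package's `mcse` and ignores autocorrelation.

```julia
using Statistics

# Hypothetical stand-in for the array-based `mcse(Statistics.mean, samples)`.
# Assumption: draws are independent, so the effective sample size equals the
# total number of draws. The real method uses an ESS estimate instead.
function naive_mcse_mean(samples::AbstractArray{<:Real,3})
    ndraws = size(samples, 1) * size(samples, 2)
    # standard deviation per parameter, pooled over draws and chains
    s = dropdims(std(samples; dims=(1, 2)); dims=(1, 2))
    return s ./ sqrt(ndraws)
end

x = collect(1.0:10.0)                 # a single chain stored as a vector
samples = reshape(x, :, 1, 1)         # shape (draws, chains, parameters) = (10, 1, 1)
se = first(naive_mcse_mean(samples))  # one parameter, so take the only entry
```

The result is a vector with one entry per parameter, which is why the diagnostics above call `first` (or `Base.first` inside `gewekediag`, where `first` is shadowed by a keyword argument).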
7 changes: 5 additions & 2 deletions src/heideldiag.jl
@@ -17,7 +17,9 @@ function heideldiag(
n = length(x)
delta = trunc(Int, 0.10 * n)
y = x[trunc(Int, n / 2):end]
S0 = length(y) * mcse(y; kwargs...)^2
T = float(eltype(x))
s = first(mcse(Statistics.mean, reshape(y, :, 1, 1); split_chains=1, kwargs...))::T
S0 = length(y) * s^2
i, pvalue, converged, ybar = 1, 1.0, false, NaN
while i < n / 2
y = x[i:end]
@@ -33,7 +35,8 @@ function heideldiag(
end
i += delta
end
halfwidth = sqrt(2) * SpecialFunctions.erfcinv(alpha) * mcse(y; kwargs...)
s = first(mcse(Statistics.mean, reshape(y, :, 1, 1); split_chains=1, kwargs...))::T
halfwidth = sqrt(2) * SpecialFunctions.erfcinv(alpha) * s
passed = halfwidth / abs(ybar) <= eps
return (
burnin=i + start - 2,
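
For readers of this hunk: the factor `sqrt(2) * SpecialFunctions.erfcinv(alpha)` in the half-width is the standard normal quantile at level $1 - \alpha/2$. This follows from the identity relating the complementary error function to the normal CDF $\Phi$. The symbol $s$ below stands for the MCSE of the mean computed just above; the notation is introduced here for illustration only.

```latex
\operatorname{erfc}\!\left(\frac{x}{\sqrt{2}}\right) = 2\bigl(1 - \Phi(x)\bigr)
\quad\Longrightarrow\quad
\sqrt{2}\,\operatorname{erfc}^{-1}(\alpha) = \Phi^{-1}\!\left(1 - \frac{\alpha}{2}\right),
\qquad
\text{halfwidth} = \Phi^{-1}\!\left(1 - \frac{\alpha}{2}\right) s .
```

With $\alpha = 0.05$ the factor is approximately 1.96, so the half-width is that of the usual 95% normal interval around the mean.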
167 changes: 111 additions & 56 deletions src/mcse.jl
@@ -1,72 +1,127 @@
Base.@irrational normcdf1 0.8413447460685429486 StatsFuns.normcdf(big(1))
Base.@irrational normcdfn1 0.1586552539314570514 StatsFuns.normcdf(big(-1))

"""
mcse(x::AbstractVector{<:Real}; method::Symbol=:imse, kwargs...)
mcse(estimator, samples::AbstractArray{<:Union{Missing,Real}}; kwargs...)

Estimate the Monte Carlo standard errors (MCSE) of the `estimator` applied to `samples` of
shape `(draws, chains, parameters)`

Compute the Monte Carlo standard error (MCSE) of samples `x`.
The optional argument `method` describes how the errors are estimated. Possible options are:
## Estimators

- `:bm` for batch means [^Glynn1991]
- `:imse` initial monotone sequence estimator [^Geyer1992]
- `:ipse` initial positive sequence estimator [^Geyer1992]
`estimator` must accept a vector of the same eltype as `samples` and return a real estimate.

[^Glynn1991]: Glynn, P. W., & Whitt, W. (1991). Estimating the asymptotic variance with batch means. Operations Research Letters, 10(8), 431-435.
For the following estimators, the effective sample size [`ess_rhat`](@ref) and an estimate
of the asymptotic variance are used to compute the MCSE, and `kwargs` are forwarded to
`ess_rhat`:
- `Statistics.mean`
- `Statistics.median`
- `Statistics.std`
- `Base.Fix2(Statistics.quantile, p::Real)`

[^Geyer1992]: Geyer, C. J. (1992). Practical Markov Chain Monte Carlo. Statistical Science, 473-483.
For an arbitrary estimator, the subsampling bootstrap method [`mcse_sbm`](@ref) is used, and
`kwargs` are forwarded to that function.
"""
function mcse(x::AbstractVector{<:Real}; method::Symbol=:imse, kwargs...)
return if method === :bm
mcse_bm(x; kwargs...)
elseif method === :imse
mcse_imse(x)
elseif method === :ipse
mcse_ipse(x)
else
throw(ArgumentError("unsupported MCSE method $method"))
mcse(f, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) = mcse_sbm(f, x; kwargs...)
function mcse(
::typeof(Statistics.mean), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...
)
S = ess_rhat(Statistics.mean, samples; kwargs...)[1]
return dropdims(Statistics.std(samples; dims=(1, 2)); dims=(1, 2)) ./ sqrt.(S)
end
function mcse(
::typeof(Statistics.std), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...
)
x = (samples .- Statistics.mean(samples; dims=(1, 2))) .^ 2
S = ess_rhat(Statistics.mean, x; kwargs...)[1]
mean_var = dropdims(Statistics.mean(x; dims=(1, 2)); dims=(1, 2))
mean_moment4 = dropdims(Statistics.mean(abs2, x; dims=(1, 2)); dims=(1, 2))
return @. sqrt((mean_moment4 / mean_var - mean_var) / S) / 2
end
function mcse(
f::Base.Fix2{typeof(Statistics.quantile),<:Real},
samples::AbstractArray{<:Union{Missing,Real},3};
kwargs...,
)
p = f.x
S = ess_rhat(f, samples; kwargs...)[1]
T = eltype(S)
R = promote_type(eltype(samples), typeof(oneunit(eltype(samples)) / sqrt(oneunit(T))))
values = similar(S, R)
for (i, xi, Si) in zip(eachindex(values), eachslice(samples; dims=3), S)
values[i] = _mcse_quantile(vec(xi), p, Si)
end
return values
end

function mcse_bm(x::AbstractVector{<:Real}; size::Int=floor(Int, sqrt(length(x))))
n = length(x)
m = min(div(n, 2), size)
m == size || @warn "batch size was reduced to $m"
mcse = StatsBase.sem(Statistics.mean(@view(x[(i + 1):(i + m)])) for i in 0:m:(n - m))
return mcse
function _mcse_quantile(x, p, Seff)
Seff === missing && return missing
S = length(x)
# quantile error distribution is asymptotically normal; estimate σ (mcse) with 2
# quadrature points: xl and xu, chosen as quantiles so that xu - xl = 2σ
# compute quantiles of error distribution in probability space (i.e. quantiles passed through CDF)
# Beta(α,β) is the approximate error distribution of quantile estimates
α = Seff * p + 1
β = Seff * (1 - p) + 1
prob_x_upper = StatsFuns.betainvcdf(α, β, normcdf1)
prob_x_lower = StatsFuns.betainvcdf(α, β, normcdfn1)
# use inverse ECDF to get quantiles in quantile (x) space
l = max(floor(Int, prob_x_lower * S), 1)
u = min(ceil(Int, prob_x_upper * S), S)
iperm = partialsortperm(x, l:u) # sort as little of x as possible
xl = x[first(iperm)]
xu = x[last(iperm)]
# estimate mcse from quantiles
return (xu - xl) / 2
end
function mcse(
::typeof(Statistics.median), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...
)
return mcse(Base.Fix2(Statistics.quantile, 1//2), samples; kwargs...)
end

function mcse_imse(x::AbstractVector{<:Real})
n = length(x)
lags = [0, 1]
ghat = StatsBase.autocov(x, lags)
Ghat = sum(ghat)
@inbounds value = Ghat + ghat[2]
@inbounds for i in 2:2:(n - 2)
lags[1] = i
lags[2] = i + 1
StatsBase.autocov!(ghat, x, lags)
Ghat = min(Ghat, sum(ghat))
Ghat > 0 || break
value += 2 * Ghat
end
"""
mcse_sbm(estimator, samples::AbstractArray{<:Union{Missing,Real},3}; batch_size)

mcse = sqrt(value / n)
Estimate the Monte Carlo standard errors (MCSE) of the `estimator` applied to `samples`
using the subsampling bootstrap method (SBM).[^FlegalJones2011][^Flegal2012]

return mcse
end
`samples` has shape `(draws, chains, parameters)`, and `estimator` must accept a vector of
the same eltype as `samples` and return a real estimate.

function mcse_ipse(x::AbstractVector{<:Real})
n = length(x)
lags = [0, 1]
ghat = StatsBase.autocov(x, lags)
@inbounds value = ghat[1] + 2 * ghat[2]
@inbounds for i in 2:2:(n - 2)
lags[1] = i
lags[2] = i + 1
StatsBase.autocov!(ghat, x, lags)
Ghat = sum(ghat)
Ghat > 0 || break
value += 2 * Ghat
end
`batch_size` indicates the size of the overlapping batches used to estimate the MCSE,
defaulting to `floor(Int, sqrt(draws * chains))`.

mcse = sqrt(value / n)
!!! note
SBM tends to underestimate the MCSE, especially for highly autocorrelated chains.
SBM should only be used as a fallback when a specific [`mcse`](@ref) method for
Member Author

This raises the question of whether we should include mcse_sbm at all. On the one hand it's useful to have a fallback, and I don't know of any other general-purpose methods for estimating MCSE. On the other hand, we would prefer a fallback that errs on the side of overestimating MCSE instead of underestimating.

I lean towards including it with this caveat clearly documented.

Member

Since we only want it to be used as a fallback, maybe we do not even want to add another function for it but also call it mcse? And just document there that it is used as a fallback but one should be aware of its limitations.

Member Author

Good idea

Member Author

On second thought, it's useful to have it as a standalone method so if we want, we can compare its results with those of the specific estimators. But we can make the function internal and copy the documentation to mcse.

`estimator` is not available and when the bulk- and tail- [`ess_rhat`](@ref) values
indicate low autocorrelation.

return mcse
[^FlegalJones2011]: Flegal JM, Jones GL. (2011) Implementing MCMC: estimating with confidence.
Handbook of Markov Chain Monte Carlo. pp. 175-97.
[pdf](http://faculty.ucr.edu/~jflegal/EstimatingWithConfidence.pdf)
[^Flegal2012]: Flegal JM. (2012) Applicability of subsampling bootstrap methods in Markov chain Monte Carlo.
Monte Carlo and Quasi-Monte Carlo Methods 2010. pp. 363-72.
doi: [10.1007/978-3-642-27440-4_18](https://doi.org/10.1007/978-3-642-27440-4_18)
"""
function mcse_sbm(
f,
x::AbstractArray{<:Union{Missing,Real},3};
batch_size::Int=floor(Int, sqrt(size(x, 1) * size(x, 2))),
)
T = promote_type(eltype(x), typeof(zero(eltype(x)) / 1))
values = similar(x, T, (axes(x, 3),))
for (i, xi) in zip(eachindex(values), eachslice(x; dims=3))
values[i] = _mcse_sbm(f, vec(xi); batch_size=batch_size)
Member Author

I haven't seen in the literature a study of how best to combine MCMC chains when estimating MCSE with SBM. I benchmarked a few alternatives and found this one underestimated the MCSE the least: arviz-devs/arviz#1974 (comment)

end
return values
end
function _mcse_sbm(f, x; batch_size)
n = length(x)
i1 = firstindex(x)
v = Statistics.var(
f(view(x, i:(i + batch_size - 1))) for i in i1:(i1 + n - batch_size);
corrected=false,
)
return sqrt(v * (batch_size//n))
end
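
To make the overlapping-batch idea in `_mcse_sbm` concrete, here is a minimal standalone sketch on a single vector. `sbm_mcse` is a hypothetical name used only for this illustration, and the comparison against the textbook standard error assumes independent draws, which is the easy case for SBM. It is not the package's implementation and does not handle `missing` values or multiple parameters.

```julia
using Random, Statistics

# Hypothetical helper mirroring the overlapping-batch computation for one vector.
function sbm_mcse(f, x::AbstractVector{<:Real}; batch_size::Int=floor(Int, sqrt(length(x))))
    n = length(x)
    i1 = firstindex(x)
    # evaluate the estimator on every overlapping window of length `batch_size`
    estimates = [f(view(x, i:(i + batch_size - 1))) for i in i1:(i1 + n - batch_size)]
    # rescale the spread of the batch estimates from batch size to full sample size
    return sqrt(var(estimates; corrected=false) * batch_size / n)
end

rng = MersenneTwister(42)
x = randn(rng, 10_000)              # independent draws, so no autocorrelation
se_sbm = sbm_mcse(mean, x)          # SBM estimate of the MCSE of the mean
se_iid = std(x) / sqrt(length(x))   # textbook standard error for independent draws
```

For independent draws the two values should be close. For strongly autocorrelated chains the docstring's warning applies, and the SBM value is expected to be too small.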
2 changes: 1 addition & 1 deletion test/Project.toml
@@ -22,7 +22,7 @@ DynamicHMC = "3"
FFTW = "1.1"
LogDensityProblems = "0.12, 1, 2"
LogExpFunctions = "0.3"
MCMCDiagnosticTools = "0.2"
MCMCDiagnosticTools = "0.3"
MLJBase = "0.19, 0.20, 0.21"
MLJLIBSVMInterface = "0.1, 0.2"
MLJXGBoostInterface = "0.1, 0.2, 0.3"
38 changes: 0 additions & 38 deletions test/ess.jl
@@ -32,44 +32,6 @@ function LogDensityProblems.capabilities(p::CauchyProblem)
return LogDensityProblems.LogDensityOrder{1}()
end

# AR(1) process
function ar1(φ::Real, σ::Real, n::Int...)
T = float(Base.promote_eltype(φ, σ))
x = randn(T, n...)
x .*= σ
accumulate!(x, x; dims=1) do xi, ϵ
return muladd(φ, xi, ϵ)
end
return x
end

asymptotic_dist(::typeof(mean), dist) = Normal(mean(dist), std(dist))
function asymptotic_dist(::typeof(var), dist)
μ = var(dist)
σ = μ * sqrt(kurtosis(dist) + 2)
return Normal(μ, σ)
end
function asymptotic_dist(::typeof(std), dist)
μ = std(dist)
σ = μ * sqrt(kurtosis(dist) + 2) / 2
return Normal(μ, σ)
end
asymptotic_dist(::typeof(median), dist) = asymptotic_dist(Base.Fix2(quantile, 1//2), dist)
function asymptotic_dist(f::Base.Fix2{typeof(quantile),<:Real}, dist)
p = f.x
μ = quantile(dist, p)
σ = sqrt(p * (1 - p)) / pdf(dist, μ)
return Normal(μ, σ)
end
function asymptotic_dist(::typeof(mad), dist::Normal)
# Example 21.10 of Asymptotic Statistics. Van der Vaart
d = Normal(zero(dist.μ), dist.σ)
dtrunc = truncated(d; lower=0)
μ = median(dtrunc)
σ = 1 / (4 * pdf(d, quantile(d, 3//4)))
return Normal(μ, σ) / quantile(Normal(), 3//4)
end

@testset "ess.jl" begin
@testset "ESS and R̂ (IID samples)" begin
# Repeat tests with different scales
39 changes: 39 additions & 0 deletions test/helpers.jl
@@ -0,0 +1,39 @@
using Distributions, Statistics, StatsBase

# AR(1) process
function ar1(φ::Real, σ::Real, n::Int...)
T = float(Base.promote_eltype(φ, σ))
x = randn(T, n...)
x .*= σ
accumulate!(x, x; dims=1) do xi, ϵ
return muladd(φ, xi, ϵ)
end
return x
end

asymptotic_dist(::typeof(mean), dist) = Normal(mean(dist), std(dist))
function asymptotic_dist(::typeof(var), dist)
μ = var(dist)
σ = μ * sqrt(kurtosis(dist) + 2)
return Normal(μ, σ)
end
function asymptotic_dist(::typeof(std), dist)
μ = std(dist)
σ = μ * sqrt(kurtosis(dist) + 2) / 2
return Normal(μ, σ)
end
asymptotic_dist(::typeof(median), dist) = asymptotic_dist(Base.Fix2(quantile, 1//2), dist)
function asymptotic_dist(f::Base.Fix2{typeof(quantile),<:Real}, dist)
p = f.x
μ = quantile(dist, p)
σ = sqrt(p * (1 - p)) / pdf(dist, μ)
return Normal(μ, σ)
end
function asymptotic_dist(::typeof(mad), dist::Normal)
# Example 21.10 of Asymptotic Statistics. Van der Vaart
d = Normal(zero(dist.μ), dist.σ)
dtrunc = truncated(d; lower=0)
μ = median(dtrunc)
σ = 1 / (4 * pdf(d, quantile(d, 3//4)))
return Normal(μ, σ) / quantile(Normal(), 3//4)
end
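
The `std` case of `asymptotic_dist` and the `Statistics.std` method of `mcse` in `src/mcse.jl` appear to rest on the same delta-method argument. A sketch of that argument, assuming a finite fourth central moment $\mu_4$, with $S$ the (effective) sample size, $s$ the sample standard deviation, and $\kappa$ the excess kurtosis as returned by `kurtosis(dist)`; these symbols are introduced here for illustration:

```latex
\operatorname{Var}(s^2) \approx \frac{\mu_4 - \sigma^4}{S},
\qquad
\operatorname{Var}(s) \approx \frac{\operatorname{Var}(s^2)}{4\sigma^2}
  = \frac{1}{4S}\left(\frac{\mu_4}{\sigma^2} - \sigma^2\right),
\qquad
\mu_4 = (\kappa + 3)\,\sigma^4
\;\Longrightarrow\;
\sqrt{S\,\operatorname{Var}(s)} \approx \frac{\sigma\sqrt{\kappa + 2}}{2}.
```

The middle expression matches `sqrt((mean_moment4 / mean_var - mean_var) / S) / 2` in the `mcse` method, and the right-hand side matches `μ * sqrt(kurtosis(dist) + 2) / 2` in the helper.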