diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 4867c046..b96ee980 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -14,7 +14,7 @@ jobs: fail-fast: false matrix: version: - - '1.0' + - '1.3' - '1' - 'nightly' os: diff --git a/Project.toml b/Project.toml index c6298096..05009511 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MCMCDiagnosticTools" uuid = "be115224-59cd-429b-ad48-344e309966f0" authors = ["David Widmann"] -version = "0.1.5" +version = "0.2.0" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" @@ -23,7 +23,7 @@ MLJModelInterface = "1.6" SpecialFunctions = "0.8, 0.9, 0.10, 1, 2" StatsBase = "0.33" Tables = "1" -julia = "1" +julia = "1.3" [extras] Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/docs/Project.toml b/docs/Project.toml index 1ff9375c..ad15f1d3 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -8,7 +8,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] Documenter = "0.27" -MCMCDiagnosticTools = "0.1" +MCMCDiagnosticTools = "0.2" MLJBase = "0.19, 0.20, 0.21" MLJXGBoostInterface = "0.1, 0.2, 0.3" julia = "1.3" diff --git a/src/discretediag.jl b/src/discretediag.jl index ddf17248..72e9ca6e 100644 --- a/src/discretediag.jl +++ b/src/discretediag.jl @@ -372,7 +372,7 @@ function discretediag_sub( start_iter::Int, step_size::Int, ) - num_iters, num_vars, num_chains = size(c) + num_iters, num_chains, num_vars = size(c) ## Between-chain diagnostic length_results = length(start_iter:step_size:num_iters) @@ -384,7 +384,7 @@ function discretediag_sub( pvalue=Vector{Float64}(undef, num_vars), ) for j in 1:num_vars - X = convert(AbstractMatrix{Int}, c[:, j, :]) + X = convert(AbstractMatrix{Int}, c[:, :, j]) result = diag_all(X, method, nsim, start_iter, step_size) plot_vals_stat[:, j] .= result.stat ./ result.df @@ -403,7 +403,7 @@ function discretediag_sub( ) for k in 1:num_chains for j in 1:num_vars - x = convert(AbstractVector{Int}, c[:, j, k]) + x = convert(AbstractVector{Int}, c[:, k, j]) idx1 = 1:round(Int, frac * num_iters) idx2 = round(Int, num_iters - frac * num_iters + 1):num_iters @@ -423,14 +423,16 @@ function discretediag_sub( end """ - discretediag(chains::AbstractArray{<:Real,3}; frac=0.3, method=:weiss, nsim=1_000) + discretediag(samples::AbstractArray{<:Real,3}; frac=0.3, method=:weiss, nsim=1_000) -Compute discrete diagnostic where `method` can be one of `:weiss`, `:hangartner`, +Compute discrete diagnostic on `samples` with shape `(draws, chains, parameters)`. + +`method` can be one of `:weiss`, `:hangartner`, `:DARBOOT`, `:MCBOOT`, `:billinsgley`, and `:billingsleyBOOT`. # References -Benjamin E. Deonovic, & Brian J. Smith. (2017). Convergence diagnostics for MCMC draws of a categorical variable. +Benjamin E. Deonovic, & Brian J. Smith. (2017). Convergence diagnostics for MCMC draws of a categorical variable. """ function discretediag( chains::AbstractArray{<:Real,3}; frac::Real=0.3, method::Symbol=:weiss, nsim::Int=1000 diff --git a/src/ess.jl b/src/ess.jl index ee06b13c..d385ad8d 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -201,7 +201,7 @@ end ) Estimate the effective sample size and the potential scale reduction of the `samples` of -shape (draws, parameters, chains) with the `method` and a maximum lag of `maxlag`. +shape `(draws, chains, parameters)` with the `method` and a maximum lag of `maxlag`. See also: [`ESSMethod`](@ref), [`FFTESSMethod`](@ref), [`BDAESSMethod`](@ref) """ @@ -212,8 +212,8 @@ function ess_rhat( ) # compute size of matrices (each chain is split!) niter = size(chains, 1) ÷ 2 - nparams = size(chains, 2) - nchains = 2 * size(chains, 3) + nparams = size(chains, 3) + nchains = 2 * size(chains, 2) ntotal = niter * nchains # do not compute estimates if there is only one sample or lag @@ -238,7 +238,7 @@ function ess_rhat( rhat = Vector{T}(undef, nparams) # for each parameter - for (i, chains_slice) in enumerate((view(chains, :, i, :) for i in axes(chains, 2))) + for (i, chains_slice) in enumerate(eachslice(chains; dims=3)) # check that no values are missing if any(x -> x === missing, chains_slice) rhat[i] = missing diff --git a/src/gelmandiag.jl b/src/gelmandiag.jl index 87626672..001ab282 100644 --- a/src/gelmandiag.jl +++ b/src/gelmandiag.jl @@ -1,14 +1,15 @@ function _gelmandiag(psi::AbstractArray{<:Real,3}; alpha::Real=0.05) - niters, nparams, nchains = size(psi) + niters, nchains, nparams = size(psi) nchains > 1 || error("Gelman diagnostic requires at least 2 chains") rfixed = (niters - 1) / niters rrandomscale = (nchains + 1) / (nchains * niters) - S2 = map(Statistics.cov, (view(psi, :, :, i) for i in axes(psi, 3))) + # `eachslice(psi; dims=2)` breaks type inference + S2 = map(x -> Statistics.cov(x; dims=1), (view(psi, :, i, :) for i in axes(psi, 2))) W = Statistics.mean(S2) - psibar = dropdims(Statistics.mean(psi; dims=1); dims=1)' + psibar = dropdims(Statistics.mean(psi; dims=1); dims=1) B = niters .* Statistics.cov(psibar) w = LinearAlgebra.diag(W) @@ -52,9 +53,10 @@ function _gelmandiag(psi::AbstractArray{<:Real,3}; alpha::Real=0.05) end """ - gelmandiag(chains::AbstractArray{<:Real,3}; alpha::Real=0.95) + gelmandiag(samples::AbstractArray{<:Real,3}; alpha::Real=0.95) -Compute the Gelman, Rubin and Brooks diagnostics [^Gelman1992] [^Brooks1998]. Values of the +Compute the Gelman, Rubin and Brooks diagnostics [^Gelman1992] [^Brooks1998] on `samples` +with shape `(draws, chains, parameters)`. Values of the diagnostic’s potential scale reduction factor (PSRF) that are close to one suggest convergence. As a rule-of-thumb, convergence is rejected if the 97.5 percentile of a PSRF is greater than 1.2. @@ -70,12 +72,13 @@ function gelmandiag(chains::AbstractArray{<:Real,3}; kwargs...) end """ - gelmandiag_multivariate(chains::AbstractArray{<:Real,3}; alpha::Real=0.05) + gelmandiag_multivariate(samples::AbstractArray{<:Real,3}; alpha::Real=0.05) -Compute the multivariate Gelman, Rubin and Brooks diagnostics. +Compute the multivariate Gelman, Rubin and Brooks diagnostics on `samples` with shape +`(draws, chains, parameters)`. """ function gelmandiag_multivariate(chains::AbstractArray{<:Real,3}; kwargs...) - niters, nparams, nchains = size(chains) + niters, nchains, nparams = size(chains) if nparams < 2 error( "computation of the multivariate potential scale reduction factor requires ", diff --git a/src/rafterydiag.jl b/src/rafterydiag.jl index 8f34f0b8..a6960838 100644 --- a/src/rafterydiag.jl +++ b/src/rafterydiag.jl @@ -38,7 +38,7 @@ function rafterydiag( dichot = Int[(x .<= StatsBase.quantile(x, q))...] kthin = 0 bic = 1.0 - local test , ntest + local test, ntest while bic >= 0.0 kthin += 1 test = dichot[1:kthin:nx] diff --git a/src/rstar.jl b/src/rstar.jl index 2def08d4..7eb61c9b 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -1,15 +1,91 @@ """ rstar( - rng=Random.GLOBAL_RNG, - classifier, - samples::AbstractMatrix, + rng::Random.AbstractRNG=Random.default_rng(), + classifier::MLJModelInterface.Supervised, + samples, chain_indices::AbstractVector{Int}; subset::Real=0.8, verbosity::Int=0, ) -Compute the ``R^*`` convergence statistic of the `samples` with shape (draws, parameters) -and corresponding chains `chain_indices` with the `classifier`. +Compute the ``R^*`` convergence statistic of the table `samples` with the `classifier`. + +`samples` must be either an `AbstractMatrix`, an `AbstractVector`, or a table +(i.e. implements the Tables.jl interface) whose rows are draws and whose columns are +parameters. + +`chain_indices` indicates the chain ids of each row of `samples`. + +This method supports ragged chains, i.e. chains of nonequal lengths. +""" +function rstar( + rng::Random.AbstractRNG, + classifier::MLJModelInterface.Supervised, + x, + y::AbstractVector{Int}; + subset::Real=0.8, + verbosity::Int=0, +) + # checks + MLJModelInterface.nrows(x) != length(y) && throw(DimensionMismatch()) + 0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)")) + + # randomly sub-select training and testing set + N = length(y) + Ntrain = round(Int, N * subset) + 0 < Ntrain < N || + throw(ArgumentError("training and test data subsets must not be empty")) + ids = Random.randperm(rng, N) + train_ids = view(ids, 1:Ntrain) + test_ids = view(ids, (Ntrain + 1):N) + + xtable = _astable(x) + + # train classifier on training data + ycategorical = MLJModelInterface.categorical(y) + xtrain = MLJModelInterface.selectrows(xtable, train_ids) + fitresult, _ = MLJModelInterface.fit( + classifier, verbosity, xtrain, ycategorical[train_ids] + ) + + # compute predictions on test data + xtest = MLJModelInterface.selectrows(xtable, test_ids) + predictions = _predict(classifier, fitresult, xtest) + + # compute statistic + ytest = ycategorical[test_ids] + result = _rstar(predictions, ytest) + + return result +end + +_astable(x::AbstractVecOrMat) = Tables.table(x) +_astable(x) = Tables.istable(x) ? x : throw(ArgumentError("Argument is not a valid table")) + +# Workaround for https://github.com/JuliaAI/MLJBase.jl/issues/863 +# `MLJModelInterface.predict` sometimes returns predictions and sometimes predictions + additional information +# TODO: Remove once the upstream issue is fixed +function _predict(model::MLJModelInterface.Model, fitresult, x) + y = MLJModelInterface.predict(model, fitresult, x) + return if :predict in MLJModelInterface.reporting_operations(model) + first(y) + else + y + end +end + +""" + rstar( + rng::Random.AbstractRNG=Random.default_rng(), + classifier::MLJModelInterface.Supervised, + samples::AbstractArray{<:Real,3}; + subset::Real=0.8, + verbosity::Int=0, + ) + +Compute the ``R^*`` convergence statistic of the `samples` with the `classifier`. + +`samples` is an array of draws with the shape `(draws, chains, parameters)`.` This implementation is an adaption of algorithms 1 and 2 described by Lambert and Vehtari. @@ -29,19 +105,17 @@ is returned (algorithm 2). # Examples -```jldoctest rstar; setup = :(using Random; Random.seed!(100)) +```jldoctest rstar; setup = :(using Random; Random.seed!(101)) julia> using MLJBase, MLJXGBoostInterface, Statistics -julia> samples = fill(4.0, 300, 2); - -julia> chain_indices = repeat(1:3; outer=100); +julia> samples = fill(4.0, 100, 3, 2); ``` One can compute the distribution of the ``R^*`` statistic (algorithm 2) with the probabilistic classifier. ```jldoctest rstar -julia> distribution = rstar(XGBoostClassifier(), samples, chain_indices); +julia> distribution = rstar(XGBoostClassifier(), samples); julia> isapprox(mean(distribution), 1; atol=0.1) true @@ -54,7 +128,7 @@ predicting the mode. In MLJ this corresponds to a pipeline of models. ```jldoctest rstar julia> xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mode); -julia> value = rstar(xgboost_deterministic, samples, chain_indices); +julia> value = rstar(xgboost_deterministic, samples); julia> isapprox(value, 1; atol=0.2) true @@ -67,60 +141,20 @@ Lambert, B., & Vehtari, A. (2020). ``R^*``: A robust MCMC convergence diagnostic function rstar( rng::Random.AbstractRNG, classifier::MLJModelInterface.Supervised, - x::AbstractMatrix, - y::AbstractVector{Int}; - subset::Real=0.8, - verbosity::Int=0, + x::AbstractArray{<:Any,3}; + kwargs..., ) - # checks - size(x, 1) != length(y) && throw(DimensionMismatch()) - 0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)")) - - # randomly sub-select training and testing set - N = length(y) - Ntrain = round(Int, N * subset) - 0 < Ntrain < N || - throw(ArgumentError("training and test data subsets must not be empty")) - ids = Random.randperm(rng, N) - train_ids = view(ids, 1:Ntrain) - test_ids = view(ids, (Ntrain + 1):N) - - # train classifier on training data - ycategorical = MLJModelInterface.categorical(y) - fitresult, _ = MLJModelInterface.fit( - classifier, verbosity, Tables.table(x[train_ids, :]), ycategorical[train_ids] - ) - - # compute predictions on test data - xtest = Tables.table(x[test_ids, :]) - predictions = _predict(classifier, fitresult, xtest) - - # compute statistic - ytest = ycategorical[test_ids] - result = _rstar(predictions, ytest) - - return result + samples = reshape(x, :, size(x, 3)) + chain_inds = repeat(axes(x, 2); inner=size(x, 1)) + return rstar(rng, classifier, samples, chain_inds; kwargs...) end -# Workaround for https://github.com/JuliaAI/MLJBase.jl/issues/863 -# `MLJModelInterface.predict` sometimes returns predictions and sometimes predictions + additional information -# TODO: Remove once the upstream issue is fixed -function _predict(model::MLJModelInterface.Model, fitresult, x) - y = MLJModelInterface.predict(model, fitresult, x) - return if :predict in MLJModelInterface.reporting_operations(model) - first(y) - else - y - end +function rstar(classif::MLJModelInterface.Supervised, x, y::AbstractVector{Int}; kwargs...) + return rstar(Random.default_rng(), classif, x, y; kwargs...) end -function rstar( - classif::MLJModelInterface.Supervised, - x::AbstractMatrix, - y::AbstractVector{Int}; - kwargs..., -) - return rstar(Random.GLOBAL_RNG, classif, x, y; kwargs...) +function rstar(classif::MLJModelInterface.Supervised, x::AbstractArray{<:Any,3}; kwargs...) + return rstar(Random.default_rng(), classif, x; kwargs...) end # R⋆ for deterministic predictions (algorithm 1) diff --git a/test/Project.toml b/test/Project.toml index c3a16aaf..09fca5b2 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,10 +1,22 @@ [deps] +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" +MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52" +MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] +Distributions = "0.25" FFTW = "1.1" -julia = "1" +MCMCDiagnosticTools = "0.2" +MLJBase = "0.19, 0.20, 0.21" +MLJLIBSVMInterface = "0.1, 0.2" +MLJXGBoostInterface = "0.1, 0.2, 0.3" +Tables = "1" +julia = "1.3" diff --git a/test/discretediag.jl b/test/discretediag.jl index 8415614f..2796aa96 100644 --- a/test/discretediag.jl +++ b/test/discretediag.jl @@ -1,7 +1,8 @@ @testset "discretediag.jl" begin nparams = 4 + ndraws = 100 nchains = 2 - samples = rand(-100:100, 100, nparams, nchains) + samples = rand(-100:100, ndraws, nchains, nparams) @testset "results" begin for method in diff --git a/test/ess.jl b/test/ess.jl index 4528c735..c58c00b7 100644 --- a/test/ess.jl +++ b/test/ess.jl @@ -31,7 +31,7 @@ end @testset "ESS and R̂ (IID samples)" begin - rawx = randn(10_000, 40, 10) + rawx = randn(10_000, 10, 40) # Repeat tests with different scales for scale in (1, 50, 100) @@ -58,7 +58,7 @@ end @testset "ESS and R̂ (identical samples)" begin - x = ones(10_000, 40, 10) + x = ones(10_000, 10, 40) ess_standard, rhat_standard = ess_rhat(x) ess_standard2, rhat_standard2 = ess_rhat(x; method=ESSMethod()) @@ -75,15 +75,15 @@ end @testset "ESS and R̂ (single sample)" begin # check that issue #137 is fixed - x = rand(1, 5, 3) + x = rand(1, 3, 5) for method in (ESSMethod(), FFTESSMethod(), BDAESSMethod()) # analyze array ess_array, rhat_array = ess_rhat(x; method=method) - @test length(ess_array) == size(x, 2) + @test length(ess_array) == size(x, 3) @test all(ismissing, ess_array) # since min(maxlag, niter - 1) = 0 - @test length(rhat_array) == size(x, 2) + @test length(rhat_array) == size(x, 3) @test all(ismissing, rhat_array) end end diff --git a/test/gelmandiag.jl b/test/gelmandiag.jl index 045459e1..23e1a3dd 100644 --- a/test/gelmandiag.jl +++ b/test/gelmandiag.jl @@ -1,7 +1,8 @@ @testset "gelmandiag.jl" begin nparams = 4 + ndraws = 100 nchains = 2 - samples = randn(100, nparams, nchains) + samples = randn(ndraws, nchains, nparams) @testset "results" begin result = @inferred(gelmandiag(samples)) @@ -23,7 +24,7 @@ end @testset "exceptions" begin - @test_throws ErrorException gelmandiag(samples[:, :, 1:1]) - @test_throws ErrorException gelmandiag_multivariate(samples[:, 1:1, :]) + @test_throws ErrorException gelmandiag(samples[:, 1:1, :]) + @test_throws ErrorException gelmandiag_multivariate(samples[:, :, 1:1]) end end diff --git a/test/gewekediag.jl b/test/gewekediag.jl index f877ccde..5cd8f211 100644 --- a/test/gewekediag.jl +++ b/test/gewekediag.jl @@ -3,7 +3,7 @@ @testset "results" begin @test @inferred(gewekediag(samples)) isa - NamedTuple{(:zscore, :pvalue),Tuple{Float64,Float64}} + NamedTuple{(:zscore, :pvalue),Tuple{Float64,Float64}} end @testset "exceptions" begin diff --git a/test/rstar.jl b/test/rstar.jl new file mode 100644 index 00000000..928e50a4 --- /dev/null +++ b/test/rstar.jl @@ -0,0 +1,112 @@ +using MCMCDiagnosticTools + +using Distributions +using MLJBase +using MLJLIBSVMInterface +using MLJXGBoostInterface +using Tables + +using Random +using Test + +const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mode) + +@testset "rstar.jl" begin + classifiers = (XGBoostClassifier(), xgboost_deterministic, SVC()) + N = 1_000 + + @testset "samples input type: $wrapper" for wrapper in [Vector, Array, Tables.table] + @testset "examples (classifier = $classifier)" for classifier in classifiers + sz = wrapper === Vector ? N : (N, 2) + # Compute R⋆ statistic for a mixed chain. + samples = wrapper(randn(sz...)) + dist = rstar(classifier, samples, rand(1:3, N)) + + # Mean of the statistic should be focused around 1, i.e., the classifier does not + # perform better than random guessing. + if classifier isa MLJBase.Deterministic + @test dist isa Float64 + else + @test dist isa LocationScale + @test dist.ρ isa PoissonBinomial + @test minimum(dist) == 0 + @test maximum(dist) == 3 + end + @test mean(dist) ≈ 1 rtol = 0.2 + wrapper === Vector && break + + # Compute R⋆ statistic for a mixed chain. + samples = wrapper(randn(4 * N, 8)) + chain_indices = repeat(1:4, N) + dist = rstar(classifier, samples, chain_indices) + + # Mean of the statistic should be closte to 1, i.e., the classifier does not perform + # better than random guessing. + if classifier isa MLJBase.Deterministic + @test dist isa Float64 + else + @test dist isa LocationScale + @test dist.ρ isa PoissonBinomial + @test minimum(dist) == 0 + @test maximum(dist) == 4 + end + @test mean(dist) ≈ 1 rtol = 0.15 + + # Compute the R⋆ statistic for a non-mixed chain. + samples = wrapper([ + sin.(1:N) cos.(1:N) + 100 .* cos.(1:N) 100 .* sin.(1:N) + ]) + chain_indices = repeat(1:2; inner=N) + dist = rstar(classifier, samples, chain_indices) + + # Mean of the statistic should be close to 2, i.e., the classifier should be able to + # learn an almost perfect decision boundary between chains. + if classifier isa MLJBase.Deterministic + @test dist isa Float64 + else + @test dist isa LocationScale + @test dist.ρ isa PoissonBinomial + @test minimum(dist) == 0 + @test maximum(dist) == 2 + end + @test mean(dist) ≈ 2 rtol = 0.15 + end + wrapper === Vector && continue + + @testset "exceptions (classifier = $classifier)" for classifier in classifiers + samples = wrapper(randn(N - 1, 2)) + @test_throws DimensionMismatch rstar(classifier, samples, rand(1:3, N)) + for subset in (-0.3, 0, 1 / (3 * N), 1 - 1 / (3 * N), 1, 1.9) + samples = wrapper(randn(N, 2)) + @test_throws ArgumentError rstar( + classifier, samples, rand(1:3, N); subset=subset + ) + end + end + end + + @testset "table with chain_ids produces same result as 3d array" begin + nparams = 2 + nchains = 3 + samples = randn(N, nchains, nparams) + + # manually construct samples_mat and chain_inds for comparison + samples_mat = reshape(samples, N * nchains, nparams) + chain_inds = Vector{Int}(undef, N * nchains) + i = 1 + for chain in 1:nchains, draw in 1:N + chain_inds[i] = chain + i += 1 + end + + @testset "classifier = $classifier" for classifier in classifiers + rng = MersenneTwister(42) + dist1 = rstar(rng, classifier, samples_mat, chain_inds) + Random.seed!(rng, 42) + dist2 = rstar(rng, classifier, samples) + @test dist1 == dist2 + @test typeof(rstar(classifier, samples)) === typeof(dist2) + end + end +end diff --git a/test/rstar/Project.toml b/test/rstar/Project.toml deleted file mode 100644 index 91673a0c..00000000 --- a/test/rstar/Project.toml +++ /dev/null @@ -1,15 +0,0 @@ -[deps] -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" -MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52" -MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[compat] -Distributions = "0.25" -MCMCDiagnosticTools = "0.1" -MLJBase = "0.19, 0.20, 0.21" -MLJLIBSVMInterface = "0.1, 0.2" -MLJXGBoostInterface = "0.1, 0.2, 0.3" -julia = "1.3" diff --git a/test/rstar/runtests.jl b/test/rstar/runtests.jl deleted file mode 100644 index 00869743..00000000 --- a/test/rstar/runtests.jl +++ /dev/null @@ -1,79 +0,0 @@ -using MCMCDiagnosticTools - -using Distributions -using MLJBase -using MLJLIBSVMInterface -using MLJXGBoostInterface - -using Test - -const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mode) - -@testset "rstar.jl" begin - classifiers = (XGBoostClassifier(), xgboost_deterministic, SVC()) - N = 1_000 - - @testset "examples (classifier = $classifier)" for classifier in classifiers - # Compute R⋆ statistic for a mixed chain. - samples = randn(N, 2) - dist = rstar(classifier, randn(N, 2), rand(1:3, N)) - - # Mean of the statistic should be focused around 1, i.e., the classifier does not - # perform better than random guessing. - if classifier isa MLJBase.Deterministic - @test dist isa Float64 - else - @test dist isa LocationScale - @test dist.ρ isa PoissonBinomial - @test minimum(dist) == 0 - @test maximum(dist) == 3 - end - @test mean(dist) ≈ 1 rtol = 0.2 - - # Compute R⋆ statistic for a mixed chain. - samples = randn(4 * N, 8) - chain_indices = repeat(1:4, N) - dist = rstar(classifier, samples, chain_indices) - - # Mean of the statistic should be closte to 1, i.e., the classifier does not perform - # better than random guessing. - if classifier isa MLJBase.Deterministic - @test dist isa Float64 - else - @test dist isa LocationScale - @test dist.ρ isa PoissonBinomial - @test minimum(dist) == 0 - @test maximum(dist) == 4 - end - @test mean(dist) ≈ 1 rtol = 0.15 - - # Compute the R⋆ statistic for a non-mixed chain. - samples = [ - sin.(1:N) cos.(1:N) - 100 .* cos.(1:N) 100 .* sin.(1:N) - ] - chain_indices = repeat(1:2; inner=N) - dist = rstar(classifier, samples, chain_indices) - - # Mean of the statistic should be close to 2, i.e., the classifier should be able to - # learn an almost perfect decision boundary between chains. - if classifier isa MLJBase.Deterministic - @test dist isa Float64 - else - @test dist isa LocationScale - @test dist.ρ isa PoissonBinomial - @test minimum(dist) == 0 - @test maximum(dist) == 2 - end - @test mean(dist) ≈ 2 rtol = 0.15 - end - - @testset "exceptions (classifier = $classifier)" for classifier in classifiers - @test_throws DimensionMismatch rstar(classifier, randn(N - 1, 2), rand(1:3, N)) - for subset in (-0.3, 0, 1 / (3 * N), 1 - 1 / (3 * N), 1, 1.9) - @test_throws ArgumentError rstar( - classifier, randn(N, 2), rand(1:3, N); subset=subset - ) - end - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 54c058a6..63fdade0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,12 +1,5 @@ using Pkg -# Activate test environment on older Julia versions -@static if VERSION < v"1.2" - Pkg.activate(@__DIR__) - Pkg.develop(PackageSpec(; path=dirname(@__DIR__))) - Pkg.instantiate() -end - using MCMCDiagnosticTools using FFTW @@ -43,16 +36,11 @@ Random.seed!(1) include("rafterydiag.jl") end @testset "R⋆ diagnostic" begin - # MLJXGBoostInterface requires Julia >= 1.3 # XGBoost errors on 32bit systems: https://github.com/dmlc/XGBoost.jl/issues/92 - if VERSION >= v"1.3" && Sys.WORD_SIZE == 64 - # run tests related to rstar statistic - Pkg.activate("rstar") - Pkg.develop(; path=dirname(dirname(pathof(MCMCDiagnosticTools)))) - Pkg.instantiate() - include(joinpath("rstar", "runtests.jl")) + if Sys.WORD_SIZE == 64 + include("rstar.jl") else - @info "R⋆ not tested: requires Julia >= 1.3 and a 64bit architecture" + @info "R⋆ not tested: requires 64bit architecture" end end end