diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl
index b24980ec9..9750d4047 100644
--- a/src/multivariate/dirichlet.jl
+++ b/src/multivariate/dirichlet.jl
@@ -156,18 +156,57 @@ end
 
 function _rand!(rng::AbstractRNG,
                 d::Union{Dirichlet,DirichletCanon},
-                x::AbstractVector{<:Real})
-    for (i, αi) in zip(eachindex(x), d.alpha)
-        @inbounds x[i] = rand(rng, Gamma(αi))
+                x::AbstractVector{E}) where {E<:Real}
+
+    if any(a -> a >= 0.5, d.alpha)
+        for (i, αi) in zip(eachindex(x), d.alpha)
+            @inbounds x[i] = rand(rng, Gamma(αi))
+        end
+
+        return lmul!(inv(sum(x)), x)
+    else
+        # Sample in log-space to lower underflow risk
+        for (i, αi) in zip(eachindex(x), d.alpha)
+            @inbounds x[i] = _logrand(rng, Gamma(αi))
+        end
+
+        if all(isinf, x)
+            # Final fallback, parameters likely deeply subnormal
+            # Distribution behavior approaches categorical as Σα -> 0
+            p = copy(d.alpha)
+            p .*= floatmax(eltype(p)) # rescale to non-subnormal
+            x .= zero(E)
+            x[rand(rng, Categorical(inv(sum(p)) .* p))] = one(E)
+            return x
+        end
+
+        return softmax!(x)
     end
-    lmul!(inv(sum(x)), x) # this returns x
 end
 
 function _rand!(rng::AbstractRNG,
                 d::Dirichlet{T,<:FillArrays.AbstractFill{T}},
-                x::AbstractVector{<:Real}) where {T<:Real}
-    rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x)
-    lmul!(inv(sum(x)), x) # this returns x
+                x::AbstractVector{E}) where {T<:Real, E<:Real}
+
+    if FillArrays.getindex_value(d.alpha) >= 0.5
+        rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x)
+        return lmul!(inv(sum(x)), x)
+    else
+        # Sample in log-space to lower underflow risk
+        _logrand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x)
+
+        if all(isinf, x)
+            # Final fallback, parameters likely deeply subnormal
+            # Distribution behavior approaches categorical as Σα -> 0
+            n = length(d.alpha)
+            p = FillArrays.Fill(inv(n), n)
+            x .= zero(E)
+            x[rand(rng, Categorical(p))] = one(E)
+            return x
+        end
+
+        return softmax!(x)
+    end
 end
 
 #######################################
diff --git a/src/samplers.jl b/src/samplers.jl
index 794f2bff4..988fe5cd2 100644
--- a/src/samplers.jl
+++ b/src/samplers.jl
@@ -20,6 +20,7 @@ for fname in ["aliastable.jl",
               "poisson.jl",
               "exponential.jl",
               "gamma.jl",
+              "expgamma.jl",
               "multinomial.jl",
               "vonmises.jl",
               "vonmisesfisher.jl",
diff --git a/src/samplers/expgamma.jl b/src/samplers/expgamma.jl
new file mode 100644
index 000000000..3fefbb3a4
--- /dev/null
+++ b/src/samplers/expgamma.jl
@@ -0,0 +1,86 @@
+# These samplers return log-transformed Gamma draws; they are used to bypass subnormals when sampling from Gamma distributions with very small shape parameters.
+
+# Inverse Power sampler
+# uses the x*u^(1/a) trick from Marsaglia and Tsang (2000) for when shape < 1
+struct ExpGammaIPSampler{S<:Sampleable{Univariate,Continuous},T<:Real} <: Sampleable{Univariate,Continuous}
+    s::S #sampler for Gamma(1+shape,scale)
+    nia::T #-1/shape
+end
+
+ExpGammaIPSampler(d::Gamma) = ExpGammaIPSampler(d, GammaMTSampler)
+function ExpGammaIPSampler(d::Gamma, ::Type{S}) where {S<:Sampleable}
+    shape_d = shape(d)
+    sampler = S(Gamma{partype(d)}(1 + shape_d, scale(d)))
+    return ExpGammaIPSampler(sampler, -inv(shape_d))
+end
+
+function rand(rng::AbstractRNG, s::ExpGammaIPSampler)
+    x = log(rand(rng, s.s))
+    e = randexp(rng, typeof(x))
+    return muladd(s.nia, e, x)
+end
+
+
+# Small Shape sampler
+# From Liu, C., Martin, R., and Syring, N. (2015) for when shape < 0.3
+struct ExpGammaSSSampler{T<:Real} <: Sampleable{Univariate,Continuous}
+    α::T
+    θ::T
+    λ::T
+    ω::T
+    ωω::T
+end
+
+function ExpGammaSSSampler(d::Gamma)
+    α = shape(d)
+    ω = α / MathConstants.e / (1 - α)
+    return ExpGammaSSSampler(promote(
+        α,
+        scale(d),
+        inv(α) - 1,
+        ω,
+        inv(ω + 1)
+    )...)
+end
+
+function rand(rng::AbstractRNG, s::ExpGammaSSSampler{T})::Float64 where T
+    flT = float(T)
+    while true
+        U = rand(rng, flT)
+        z = (U <= s.ωω) ? -log(U / s.ωω) : log(rand(rng, flT)) / s.λ
+        h = exp(-z - exp(-z / s.α))
+        η = z >= zero(T) ? exp(-z) : s.ω * s.λ * exp(s.λ * z)
+        if h / η > rand(rng, flT)
+            return log(s.θ) - z / s.α
+        end
+    end
+end
+
+
+function _logsampler(d::Gamma)
+    if shape(d) < 0.3
+        return ExpGammaSSSampler(d)
+    else
+        return ExpGammaIPSampler(d)
+    end
+end
+
+function _logrand(rng::AbstractRNG, d::Gamma)
+    if shape(d) < 0.3
+        return rand(rng, ExpGammaSSSampler(d))
+    else
+        return rand(rng, ExpGammaIPSampler(d))
+    end
+end
+
+function _logrand!(rng::AbstractRNG, d::Gamma, A::AbstractArray{<:Real})
+    if shape(d) < 0.3
+        @inbounds for i in eachindex(A)
+            A[i] = rand(rng, ExpGammaSSSampler(d))
+        end
+    else
+        @inbounds for i in eachindex(A)
+            A[i] = rand(rng, ExpGammaIPSampler(d))
+        end
+    end
+end
diff --git a/test/multivariate/dirichlet.jl b/test/multivariate/dirichlet.jl
index 78de162dc..6da49bf1a 100644
--- a/test/multivariate/dirichlet.jl
+++ b/test/multivariate/dirichlet.jl
@@ -158,3 +158,29 @@ end
         end
     end
 end
+
+@testset "Dirichlet rand Inf and NaN (#1702)" begin
+    for d in [
+            Dirichlet([8e-5, 1e-5, 2e-5]),
+            Dirichlet([8e-4, 1e-4, 2e-4]),
+            Dirichlet([4.5e-5, 8e-5]),
+            Dirichlet([6e-5, 2e-5, 3e-5, 4e-5, 5e-5]),
+            Dirichlet(FillArrays.Fill(1e-5, 5))
+        ]
+        x = rand(d, 10^6)
+        @test mean(x, dims = 2) ≈ mean(d) atol=0.01
+        @test var(x, dims = 2) ≈ var(d) atol=0.01
+    end
+
+    for (d, μ) in [ # Subnormal params cause mean(d) to error
+
+            (Dirichlet([5e-310, 5e-310, 5e-310]), [1/3, 1/3, 1/3]),
+            (Dirichlet(FillArrays.Fill(5e-310, 3)), [1/3, 1/3, 1/3]),
+            (Dirichlet([5e-321, 1e-321, 4e-321]), [.5, .1, .4]),
+            (Dirichlet([1e-321, 2e-321, 3e-321, 4e-321]), [.1, .2, .3, .4]),
+            (Dirichlet(FillArrays.Fill(1e-321, 4)), [.25, .25, .25, .25])
+        ]
+        x = rand(d, 10^6)
+        @test mean(x, dims = 2) ≈ μ atol=0.01
+    end
+end
\ No newline at end of file