Skip to content

Commit

Permalink
Introduce strict_logdensityof
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Nov 9, 2024
1 parent d65c1eb commit 85fde05
Show file tree
Hide file tree
Showing 17 changed files with 80 additions and 40 deletions.
7 changes: 6 additions & 1 deletion ext/MeasureBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,16 @@ ChainRulesCore.rrule(::typeof(checked_arg), ν, x) = checked_arg(ν, x), _checke

# = return type inference ====================================================

using MeasureBase: logdensityof_rt
using MeasureBase: logdensityof_rt, strict_logdensityof_rt

_logdensityof_rt_pullback(::Any) = (NoTangent(), NoTangent(), ZeroTangent())
function ChainRulesCore.rrule(::typeof(logdensityof_rt), target, v)
logdensityof_rt(target, v), _logdensityof_rt_pullback
end

_strict_logdensityof_rt_pullback(::Any) = (NoTangent(), NoTangent(), ZeroTangent())
function ChainRulesCore.rrule(::typeof(strict_logdensityof_rt), target, v)
strict_logdensityof_rt(target, v), _strict_logdensityof_rt_pullback

Check warning on line 56 in ext/MeasureBaseChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasureBaseChainRulesCoreExt.jl#L54-L56

Added lines #L54 - L56 were not covered by tests
end

end # module MeasureBaseChainRulesCoreExt
4 changes: 2 additions & 2 deletions src/combinators/half.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ function Base.rand(rng::AbstractRNG, ::Type{T}, μ::Half) where {T}
return abs(rand(rng, T, unhalf(μ)))
end

function logdensityof::Half, x)
ld = logdensityof(unhalf(μ), x) - loghalf
function strict_logdensityof::Half, x)
ld = strict_logdensityof(unhalf(μ), x) - loghalf
return x 0 ? ld : oftype(ld, -Inf)
end

Expand Down
2 changes: 1 addition & 1 deletion src/combinators/power.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ params(d::PowerMeasure) = params(first(marginals(d)))
basemeasure(d.parent)^d.axes
end

for func in [:logdensityof, :logdensity_def]
for func in [:strict_logdensityof, :logdensity_def]
@eval @inline function $func(d::PowerMeasure{M}, x) where {M}
parent = d.parent
sum(x) do xj
Expand Down
6 changes: 3 additions & 3 deletions src/combinators/product.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ function _rand_product(
end |> collect
end

for func in [:logdensityof, :logdensity_def]
for func in [:strict_logdensityof, :logdensity_def]
@eval @inline function $func(d::AbstractProductMeasure, x)
mapreduce($func, +, marginals(d), x)
end
Expand All @@ -82,7 +82,7 @@ struct ProductMeasure{M} <: AbstractProductMeasure
marginals::M
end

@inline function logdensity_rel::ProductMeasure, ν::ProductMeasure, x)
@inline function strict_logdensity_rel::ProductMeasure, ν::ProductMeasure, x)

Check warning on line 85 in src/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/combinators/product.jl#L85

Added line #L85 was not covered by tests
mapreduce(logdensity_rel, +, marginals(μ), marginals(ν), x)
end

Expand All @@ -109,7 +109,7 @@ end
return q
end

for func in [:logdensityof, :logdensity_def]
for func in [:strict_logdensityof, :logdensity_def]
# For tuples, `mapreduce` has trouble with type inference
@eval @inline function $func(d::ProductMeasure{T}, x) where {T<:Tuple}
ℓs = map($func, marginals(d), x)
Expand Down
2 changes: 1 addition & 1 deletion src/combinators/spikemixture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ end
SpikeMixture(basemeasure.m), static(1.0), static(1.0))
end

for func in [:logdensityof, :logdensity_def]
for func in [:strict_logdensityof, :logdensity_def]
@eval @inline function $func::SpikeMixture, x)
# NOTE: We could instead write this as
# R1 = typeof(log(one(μ.s)))
Expand Down
10 changes: 5 additions & 5 deletions src/combinators/transformedmeasure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,15 @@ function Pretty.tile(ν::PushforwardMeasure)
end

# TODO: THIS IS ALMOST CERTAINLY WRONG
# @inline function logdensity_rel(
# @inline function strict_logdensity_rel(
# ν::PushforwardMeasure{FF1,IF1,M1,<:AdaptRootMeasure},
# β::PushforwardMeasure{FF2,IF2,M2,<:AdaptRootMeasure},
# y,
# ) where {FF1,IF1,M1,FF2,IF2,M2}
# x = β.inv_f(y)
# f = ν.inv_f ∘ β.f
# inv_f = β.inv_f ∘ ν.f
# logdensity_rel(pushfwd(f, inv_f, ν.origin, AdaptRootMeasure()), β.origin, x)
# strict_logdensity_rel(pushfwd(f, inv_f, ν.origin, AdaptRootMeasure()), β.origin, x)
# end

# TODO: Would profit from custom pullback:
Expand All @@ -132,7 +132,7 @@ function _combine_logd_with_ladj(logd_orig::Real, ladj::Real)
end
end

function logdensityof(
function strict_logdensityof(
@nospecialize::_NonBijectivePusfwdMeasure{M,<:PushfwdRootMeasure}),
@nospecialize(v::Any)
) where {M}
Expand All @@ -143,7 +143,7 @@ function logdensityof(
)
end

function logdensityof(
function strict_logdensityof(
@nospecialize::_NonBijectivePusfwdMeasure{M,<:AdaptRootMeasure}),
@nospecialize(v::Any)
) where {M}
Expand All @@ -154,7 +154,7 @@ function logdensityof(
)
end

for func in [:logdensityof, :logdensity_def]
for func in [:strict_logdensityof, :logdensity_def]
@eval function $func::PushforwardMeasure{F,I,M,<:AdaptRootMeasure}, y) where {F,I,M}
f_inv = unwrap.finv)
x, inv_ladj = with_logabsdet_jacobian(f_inv, y)
Expand Down
53 changes: 44 additions & 9 deletions src/density-core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@ To compute log-density relative to `basemeasure(m)` or *define* a log-density
`logdensity_def`.
To compute a log-density relative to a specific base-measure, see
`logdensity_rel`.
`logdensity_rel`.
# Implementation
Do not specialize `logdensityof` directly for subtypes of `AbstractMeasure`,
specialize `MeasureBase.logdensity_def` and `MeasureBase.strict_logdensityof` instead.
"""
@inline function logdensityof::AbstractMeasure, x)
result = dynamic(unsafe_logdensityof(μ, x))
_checksupport(insupport(μ, x), result)
@inline function logdensityof::AbstractMeasure, x) #!!!!!!!!!!!!!!!!!
strict_logdensityof(μ, x)
end

@inline function logdensityof_rt(::T, ::U) where {T,U}
Expand All @@ -41,6 +45,24 @@ _checksupport(cond, result) = ifelse(cond == true, result, oftype(result, -Inf))

export unsafe_logdensityof

"""
MeasureBase.strict_logdensityof(μ, x)
Compute the log-density of the measure `μ` at `x` relative to `rootmeasure(m)`.
In contrast to [`logdensityof(μ, x)`](@ref), this will not take implicit pushforwards
of `μ` (depending on the type of `x`) into account.
"""
function strict_logdensityof end

@inline function strict_logdensityof(μ, x)
result = dynamic(unsafe_logdensityof(μ, x))
_checksupport(insupport(μ, x), result)
end

@inline function strict_logdensityof_rt(::T, ::U) where {T,U}
Core.Compiler.return_type(strict_logdensityof, Tuple{T,U})

Check warning on line 63 in src/density-core.jl

View check run for this annotation

Codecov / codecov/patch

src/density-core.jl#L62-L63

Added lines #L62 - L63 were not covered by tests
end

# https://discourse.julialang.org/t/counting-iterations-to-a-type-fixpoint/75876/10?u=cscherrer
"""
unsafe_logdensityof(m, x)
Expand Down Expand Up @@ -68,14 +90,27 @@ See also `logdensityof`.
end

"""
logdensity_rel(m1, m2, x)
logdensity_rel(μ, ν, x)
Compute the log-density of `m1` relative to `m2` at `x`. This function checks
whether `x` is in the support of `m1` or `m2` (or both, or neither). If `x` is
Compute the log-density of `μ` relative to `ν` at `x`. This function checks
whether `x` is in the support of `μ` or `ν` (or both, or neither). If `x` is
known to be in the support of both, it can be more efficient to call
`unsafe_logdensity_rel`.
`unsafe_logdensity_rel`.
"""
function logdensity_rel(μ, ν, x)
strict_logdensity_rel(μ, ν, x)
end

"""
@inline function logdensity_rel::M, ν::N, x::X) where {M,N,X}
MeasureBase.strict_logdensity_rel(μ, ν, x)
Compute the log-density of `μ` relative to `ν` at `x`. In contrast to
[`logdensity_rel(μ, ν, x)`](@ref), this will not take implicit pushforwards
of `μ` and `ν` (depending on the type of `x`) into account.
"""
function strict_logdensity_rel end

@inline function strict_logdensity_rel::M, ν::N, x::X) where {M,N,X}
T = unstatic(
promote_type(
return_type(logdensity_def, (μ, x)),
Expand Down
4 changes: 2 additions & 2 deletions src/density.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,10 @@ logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x)

density_def::DensityMeasure, x) = densityof.f, x)

function logdensityof::DensityMeasure, x::Any)
function strict_logdensityof::DensityMeasure, x::Any)
integrand, μ_base = μ.f, μ.base

base_logval = logdensityof(μ_base, x)
base_logval = strict_logdensityof(μ_base, x)

T = typeof(base_logval)
U = logdensityof_rt(integrand, x)
Expand Down
4 changes: 2 additions & 2 deletions src/primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ basemeasure(μ::PrimitiveMeasure) = μ

@inline basemeasure_depth(::PrimitiveMeasure) = static(0)

@inline logdensityof(::PrimitiveMeasure, x::Real) = zero(float(typeof(x)))
@inline logdensityof(::PrimitiveMeasure, x) = false
@inline strict_logdensityof(::PrimitiveMeasure, x::Real) = zero(float(typeof(x)))
@inline strict_logdensityof(::PrimitiveMeasure, x) = false

Check warning on line 23 in src/primitive.jl

View check run for this annotation

Codecov / codecov/patch

src/primitive.jl#L23

Added line #L23 was not covered by tests

logdensity_def(::PrimitiveMeasure, x) = static(0.0)

Expand Down
6 changes: 3 additions & 3 deletions src/primitives/counting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ struct Counting{T} <: AbstractMeasure
Counting(supp) = new{Core.Typeof(supp)}(supp)
end

@inline function logdensityof::Counting, x::Real)
@inline function strict_logdensityof::Counting, x::Real)
R = float(typeof(x))
insupport(μ, x) ? zero(R) : R(-Inf)
end

@inline logdensityof::Counting, x) = insupport(μ, x) ? 0.0 : -Inf
@inline strict_logdensityof::Counting, x) = insupport(μ, x) ? 0.0 : -Inf

Check warning on line 20 in src/primitives/counting.jl

View check run for this annotation

Codecov / codecov/patch

src/primitives/counting.jl#L20

Added line #L20 was not covered by tests

@inline logdensity_def::Counting, x) = logdensityof(μ, x)
@inline logdensity_def::Counting, x) = strict_logdensityof(μ, x)

basemeasure(::Counting) = CountingBase()

Expand Down
4 changes: 2 additions & 2 deletions src/primitives/dirac.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ basemeasure(d::Dirac) = CountingBase()

massof(::Dirac) = static(1.0)

function logdensityof::Dirac, x::Real)
function strict_logdensityof::Dirac, x::Real)
R = float(typeof(x))
insupport(μ, x) ? zero(R) : R(-Inf)
end

logdensityof::Dirac, x) = insupport(μ, x) ? 0.0 : -Inf
strict_logdensityof::Dirac, x) = insupport(μ, x) ? 0.0 : -Inf

Check warning on line 28 in src/primitives/dirac.jl

View check run for this annotation

Codecov / codecov/patch

src/primitives/dirac.jl#L28

Added line #L28 was not covered by tests

logdensity_def(::Dirac, x::Real) = zero(float(typeof(x)))
logdensity_def(::Dirac, x) = 0.0
Expand Down
4 changes: 2 additions & 2 deletions src/primitives/lebesgue.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ insupport(μ::Lebesgue, x) = x ∈ μ.support

insupport(::Lebesgue{RealNumbers}, ::Real) = true

@inline function logdensityof::Lebesgue, x::Real)
@inline function strict_logdensityof::Lebesgue, x::Real)
R = float(typeof(x))
insupport(μ, x) ? zero(R) : R(-Inf)
end

@inline logdensityof::Lebesgue, x) = insupport(μ, x) ? 0.0 : -Inf
@inline strict_logdensityof::Lebesgue, x) = insupport(μ, x) ? 0.0 : -Inf

Check warning on line 71 in src/primitives/lebesgue.jl

View check run for this annotation

Codecov / codecov/patch

src/primitives/lebesgue.jl#L71

Added line #L71 was not covered by tests

massof(::Lebesgue{RealNumbers}, s::Interval) = width(s)

Expand Down
2 changes: 1 addition & 1 deletion src/standard/stdexponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export StdExponential

insupport(::StdExponential, x) = x zero(x)

@inline function logdensityof(::StdExponential, x)
@inline function strict_logdensityof(::StdExponential, x)
R = float(typeof(x))
x zero(R) ? convert(R, -x) : R(-Inf)
end
Expand Down
4 changes: 2 additions & 2 deletions src/standard/stdlogistic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ export StdLogistic

@inline insupport(d::StdLogistic, x) = true

@inline logdensityof(::StdLogistic, x) = (u = -abs(x); u - 2 * log1pexp(u))
@inline strict_logdensityof(::StdLogistic, x) = (u = -abs(x); u - 2 * log1pexp(u))

@inline logdensity_def(::StdLogistic, x) = logdensityof(StdLogistic(), x)
@inline logdensity_def(::StdLogistic, x) = strict_logdensityof(StdLogistic(), x)
@inline basemeasure(::StdLogistic) = LebesgueBase()

@inline transport_def(::StdUniform, μ::StdLogistic, x) = logistic(x)
Expand Down
2 changes: 1 addition & 1 deletion src/standard/stdnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ export StdNormal

@inline insupport(::StdNormal, x) = true

@inline logdensityof(::StdNormal, x) = (-x^2 - log2π) / 2
@inline strict_logdensityof(::StdNormal, x) = (-x^2 - log2π) / 2

@inline logdensity_def(::StdNormal, x) = -x^2 / 2
@inline basemeasure(::StdNormal) = WeightedMeasure(static(-0.5 * log2π), LebesgueBase())
Expand Down
2 changes: 1 addition & 1 deletion src/standard/stduniform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export StdUniform

insupport(::StdUniform, x) = zero(x) x one(x)

@inline function logdensityof(::StdUniform, x)
@inline function strict_logdensityof(::StdUniform, x)
R = float(typeof(x))
zero(x) x one(x) ? zero(R) : R(-Inf)
end
Expand Down
4 changes: 2 additions & 2 deletions src/transport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ end

function ChangesOfVariables.with_logabsdet_jacobian(f::TransportFunction, x)
y = f(x)
logpdf_src = logdensityof(f.μ, x)
logpdf_trg = logdensityof(f.ν, y)
logpdf_src = strict_logdensityof(f.μ, x)
logpdf_trg = strict_logdensityof(f.ν, y)
ladj = logpdf_src - logpdf_trg
# If logpdf_src and logpdf_trg are -Inf setting lafj to zero is safe:
fixed_ladj = logpdf_src == logpdf_trg == -Inf ? zero(ladj) : ladj
Expand Down

0 comments on commit 85fde05

Please sign in to comment.