MvNormal is unintuitive and has performance issues relative to Distributions.jl #246

Open
ptiede opened this issue Nov 20, 2022 · 0 comments

Hi,

I just noticed that d = MvNormal{(:μ, :Σ)} seems at least unintuitive, or possibly broken. When I tried to create an MvNormal using what I guessed the interface to be, I got

d = MvNormal(μ=ones(2), Σ=Diagonal(ones(2)))
logdensityof(d, zeros(2))
ERROR: inverse of AffineTransform(μ = [1.0, 1.0], Σ = [1.0 0.0; 0.0, 1.0]) is not defined
Stacktrace:
 [1] error(::String, ::AffineTransform{(:μ, :Σ), Tuple{Vector{Float64}, Diagonal{Float64, Vector{Float64}}}}, ::String)
   @ Base ./error.jl:44
 [2] (::InverseFunctions.NoInverse{AffineTransform{(:μ, :Σ), Tuple{Vector{Float64}, Diagonal{Float64, Vector{Float64}}}}})(x::Vector{Float64})
   @ InverseFunctions ~/.julia/packages/InverseFunctions/NUvSJ/src/inverse.jl:67
 [3] logdensity_def
   @ ~/.julia/packages/MeasureTheory/gA2Wa/src/combinators/affine.jl:243 [inlined]
 [4] logdensity_def
   @ ~/.julia/packages/MeasureBase/brgOa/src/proxies.jl:17 [inlined]
 [5] unsafe_logdensityof
   @ ~/.julia/packages/MeasureBase/brgOa/src/density-core.jl:59 [inlined]
 [6] logdensityof(μ::MvNormal{(:μ, :Σ), Tuple{Vector{Float64}, Diagonal{Float64, Vector{Float64}}}}, x::Vector{Float64})
   @ MeasureBase ~/.julia/packages/MeasureBase/brgOa/src/density-core.jl:32
 [7] top-level scope
   @ REPL[6]:1

Looking into the code, it appears that this happens because MvNormal expects the covariance matrix to be of type Cholesky. This is pretty unintuitive for a non-power user and demands a high level of sophistication from someone who just wants to compute the density of an MvNormal: they have to know what the Cholesky decomposition is and why it makes sense for evaluating the density.
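To illustrate why a Cholesky factor is a natural input for this computation, here is a minimal sketch that evaluates the log-density from one using only LinearAlgebra. The function name `mvnormal_logdensity` is hypothetical and written purely for illustration; it is not MeasureTheory's API or implementation.

```julia
using LinearAlgebra

# Hypothetical helper, not part of MeasureTheory or Distributions.
# With Σ = L * L', the quadratic form (x - μ)' Σ⁻¹ (x - μ) equals ‖L \ (x - μ)‖²,
# and logdet(Σ) is read off the diagonal of L, so Σ is never inverted explicitly.
function mvnormal_logdensity(μ::AbstractVector, C::Cholesky, x::AbstractVector)
    k = length(μ)
    z = C.L \ (x .- μ)   # triangular solve: whitened residual
    return -0.5 * (k * log(2π) + logdet(C) + dot(z, z))
end

μ = ones(2)
Σ = [1.0 0.0; 0.0 1.0]
C = cholesky(Σ)          # factor once, reuse for every evaluation

mvnormal_logdensity(μ, C, zeros(2))   # ≈ -2.8379
```

A constructor that accepted a plain covariance matrix could perform the `cholesky(Σ)` step itself, so that users would not need to know about the factorization at all.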

Another thing I noticed while looking into this is that there are some fairly common use cases where MvNormal in MeasureTheory performs substantially worse than in Distributions. For example, consider the following:

import MeasureTheory as MT
import Distributions as Dists
using MeasureTheory: logdensityof
using BenchmarkTools

μ = ones(5)
σ = ones(5)
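x = zeros(5)

# ddist is the Distributions version discussed below; dmt is the
# MeasureTheory MvNormal with the same parameters (its construction is not shown here).
ddist = Dists.MvNormal(μ, σ)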

@btime logdensityof($dmt, $x)
#  139.685 ns (5 allocations: 480 bytes)
# -7.094692666023363

@btime logdensityof($ddist, $x)
# 39.214 ns (1 allocation: 96 bytes)
# -7.094692666023364

I understand that part of this performance gap arises because MeasureTheory doesn't compute the normalization when MvNormal is constructed, but instead computes it on the fly as part of basemeasure. However, this doesn't entirely explain the gap: constructing Dists.MvNormal(μ, σ) takes only 31 ns on my machine, so construction plus evaluation in Distributions costs roughly 70 ns against about 140 ns for MeasureTheory, which still leaves a factor of 2. Additionally, I think caching the normalization during the construction of MvNormal is a good idea. Recomputing it on every call is wasteful, especially if you plan on evaluating the distribution many times, which is common in statistical computing. My guess is that the hope is for the compiler to do some constant propagation to figure this out, but that seems pretty dicey since these aren't simple computations.
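As a rough sketch of the caching idea, the normalization term can be computed once in the constructor and stored alongside the parameters. The type `CachedDiagNormal` and the function `cached_logdensity` below are hypothetical names for illustration only and do not reflect how MeasureTheory or Distributions is actually structured; the sketch also assumes a diagonal covariance given as a vector of standard deviations.

```julia
# Hypothetical type for illustration; not part of MeasureTheory or Distributions.
struct CachedDiagNormal{T}
    μ::Vector{T}
    σ::Vector{T}
    lognorm::T   # -(k/2) * log(2π) - sum(log, σ), computed once at construction
end

function CachedDiagNormal(μ::Vector{T}, σ::Vector{T}) where {T}
    lognorm = -length(μ) * log(2 * T(π)) / 2 - sum(log, σ)
    return CachedDiagNormal{T}(μ, σ, lognorm)
end

# Each evaluation only needs the quadratic term; the normalization is reused.
function cached_logdensity(d::CachedDiagNormal, x::AbstractVector)
    s = zero(eltype(d.μ))
    for i in eachindex(d.μ, d.σ, x)
        z = (x[i] - d.μ[i]) / d.σ[i]
        s += z * z
    end
    return d.lognorm - s / 2
end

d = CachedDiagNormal(ones(5), ones(5))
cached_logdensity(d, zeros(5))   # ≈ -7.0947, matching the values reported above
```

The trade-off is a slightly more expensive constructor in exchange for cheaper repeated evaluations, which is the common case when the same distribution is evaluated many times.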

Environment

versioninfo()
Julia Version 1.8.2
Commit 36034abf260 (2022-09-29 15:21 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 32 × AMD Ryzen 9 7950X 16-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, znver3)
  Threads: 1 on 32 virtual cores
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 1
julia> Pkg.status()
Status `~/test/Project.toml`
  [31c24e10] Distributions v0.25.78
  [eadaa1a4] MeasureTheory v0.19.0 `https://github.com/cscherrer/MeasureTheory.jl.git#dev`

I have tested this on MeasureTheory 0.18.1 and on the current dev branch and see consistent behavior.
