Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor cvi projection marginal rule (with proposal distribution) #430

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
4 changes: 4 additions & 0 deletions docs/src/lib/nodes/delta.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ ReactiveMP.UnscentedTransform
ReactiveMP.ProdCVI
ReactiveMP.CVI
ReactiveMP.CVIProjection
ReactiveMP.CVISamplingStrategy
ReactiveMP.FullSampling
ReactiveMP.MeanBased
ReactiveMP.ProposalDistributionContainer
ReactiveMP.cvi_setup!
ReactiveMP.cvi_update!
ReactiveMP.DeltaFnDefaultRuleLayout
Expand Down
47 changes: 45 additions & 2 deletions ext/ReactiveMPProjectionExt/rules/marginals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,51 @@ end
return FactorizedJoint((q,))
end

"""
    create_density_function(forms_match, i, pre_samples, logp_nc_drop_index, m_in)

Return a one-argument closure `z -> ...` for the `i`-th input.

The closure always evaluates `logp_nc_drop_index(z, i, pre_samples)`. When
`forms_match` is `false`, `logpdf(m_in, z)` is added to that value; when it is
`true`, `m_in` is not touched by the returned closure.
"""
function create_density_function(forms_match, i, pre_samples, logp_nc_drop_index, m_in)
    # Term shared by both variants of the density.
    node_term = z -> logp_nc_drop_index(z, i, pre_samples)
    return forms_match ? node_term : (z -> node_term(z) + logpdf(m_in, z))
end

"""
    optimize_parameters(i, pre_samples, m_ins, logp_nc_drop_index, method)

Project the density associated with the `i`-th input onto the form selected by
`create_project_to_ins(method, m_ins[i], i)` and return the result of `project_to`.

The projection target "matches" the incoming message when its type equals the
message's exponential-family type tag and its dims equal `size(m_ins[i])`. In that
case the density built by `create_density_function` omits the message log-pdf and
the message itself is passed to `project_to` as an extra argument; otherwise the
message log-pdf is part of the density and `project_to` receives only the density.
"""
function optimize_parameters(i, pre_samples, m_ins, logp_nc_drop_index, method)
    message = m_ins[i]
    message_form = ExponentialFamily.exponential_family_typetag(message)
    projection = create_project_to_ins(method, message, i)

    target_form = ExponentialFamilyProjection.get_projected_to_type(projection)
    target_dims = ExponentialFamilyProjection.get_projected_to_dims(projection)
    forms_match = target_form === message_form && target_dims == size(message)

    density = create_density_function(forms_match, i, pre_samples, logp_nc_drop_index, message)
    # Wrap the closure into a generic log-pdf object of the message's variate form.
    logpdf_type = promote_variate_type(
        variate_form(typeof(message)), BayesBase.AbstractContinuousGenericLogPdf
    )
    target = convert(logpdf_type, UnspecifiedDomain(), density)

    if forms_match
        return project_to(projection, target, message)
    end
    return project_to(projection, target)
end

"""
    generate_samples(rng, ::Nothing, m_ins, sampling_strategy::FullSampling)

No proposal distribution is available: draw `sampling_strategy.samples` samples from
every incoming message in `m_ins` using `rng`, pass each batch through
`ReactiveMP.cvilinearize`, and return the batches zipped together.
"""
function generate_samples(rng, ::Nothing, m_ins, sampling_strategy::FullSampling)
    nsamples = sampling_strategy.samples
    batches = map(m_ins) do m_in
        ReactiveMP.cvilinearize(rand(rng, m_in, nsamples))
    end
    return zip(batches...)
end

"""
    generate_samples(::Any, ::Nothing, m_ins, ::MeanBased)

No proposal distribution is available: use the mean of every incoming message in
`m_ins` as its only sample. The random number generator argument is ignored.
Returns the one-element vectors zipped together.
"""
function generate_samples(::Any, ::Nothing, m_ins, ::MeanBased)
    singletons = map(m_ins) do m_in
        [mean(m_in)]
    end
    return zip(singletons...)
end

"""
    generate_samples(rng, proposal_distribution::FactorizedJoint, ::Any, sampling_strategy::FullSampling)

A proposal distribution is available: draw `sampling_strategy.samples` samples from
every factor in `proposal_distribution.multipliers` using `rng`, pass each batch
through `ReactiveMP.cvilinearize`, and return the batches zipped together. The
incoming messages (third argument) are ignored.
"""
function generate_samples(rng, proposal_distribution::FactorizedJoint, ::Any, sampling_strategy::FullSampling)
    nsamples = sampling_strategy.samples
    factors = proposal_distribution.multipliers
    batches = map(factors) do q_in
        ReactiveMP.cvilinearize(rand(rng, q_in, nsamples))
    end
    return zip(batches...)
end

"""
    generate_samples(::Any, proposal_distribution::FactorizedJoint, ::Any, ::MeanBased)

A proposal distribution is available: use the mean of every factor in
`proposal_distribution.multipliers` as its only sample. The random number generator
and the incoming messages are ignored. Returns the one-element vectors zipped together.
"""
function generate_samples(::Any, proposal_distribution::FactorizedJoint, ::Any, ::MeanBased)
    factors = proposal_distribution.multipliers
    singletons = map(factors) do q_in
        [mean(q_in)]
    end
    return zip(singletons...)
end

@marginalrule DeltaFn(:ins) (m_out::Any, m_ins::ManyOf{N, Any}, meta::DeltaMeta{M}) where {N, M <: CVIProjection} = begin
method = ReactiveMP.getmethod(meta)
rng = method.rng
pre_samples = zip(map(m_in_k -> ReactiveMP.cvilinearize(rand(rng, m_in_k, method.marginalsamples)), m_ins)...)
proposal_distribution_container = method.proposal_distribution
sampling_strategy = method.sampling_strategy

pre_samples = generate_samples(rng, proposal_distribution_container.distribution, m_ins, sampling_strategy)

logp_nc_drop_index = let g = getnodefn(meta, Val(:out)), pre_samples = pre_samples
(z, i, pre_samples) -> begin
Expand Down Expand Up @@ -84,5 +125,7 @@ end
end
end

return FactorizedJoint(ntuple(i -> optimize_natural_parameters(i, pre_samples), length(m_ins)))
result = FactorizedJoint(ntuple(i -> optimize_natural_parameters(i, pre_samples), length(m_ins)))
proposal_distribution_container.distribution = result
return result
end
110 changes: 101 additions & 9 deletions src/approximations/cvi_projection.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,124 @@
export CVIProjection

export CVISamplingStrategy, FullSampling, MeanBased

"""
CVISamplingStrategy

An abstract type representing the sampling strategy for the CVI projection method.
Concrete subtypes implement different approaches for generating samples used in
approximating distributions.
"""
abstract type CVISamplingStrategy end

"""
FullSampling <: CVISamplingStrategy
FullSampling(samples::Int = 10)

A sampling strategy that uses multiple samples drawn from distributions.

# Arguments
- `samples::Int`: The number of samples to draw from each distribution. Default is 10.

# Example
```julia
# Use 100 samples for more accurate approximation
strategy = FullSampling(100)
```
"""
struct FullSampling <: CVISamplingStrategy
samples::Int

FullSampling(samples::Int = 10) = new(samples)
end

"""
MeanBased <: CVISamplingStrategy

A sampling strategy that uses only the mean of the proposal distribution as a single sample.
"""
struct MeanBased <: CVISamplingStrategy end

"""
ProposalDistributionContainer{PD}

A mutable wrapper for proposal distributions used in the CVI projection method.

The container allows the proposal distribution to be updated during inference
without recreating the entire approximation method structure.

# Fields
- `distribution::PD`: The wrapped proposal distribution, can be of any compatible type.
"""
mutable struct ProposalDistributionContainer{PD}
distribution::PD
end

"""
CVIProjection(; parameters...)

A structure representing the parameters for the Conjugate Variational Inference (CVI) projection method.
This structure is a subtype of `AbstractApproximationMethod` and is used to configure the settings for CVI.

!!! note
The `CVIProjection` method requires `ExponentialFamilyProjection` package installed in the current environment.
CVI approximates the posterior distribution by projecting it onto a family of distributions with a conjugate form.

# Requirements

The `CVIProjection` method requires the `ExponentialFamilyProjection` package to be installed and loaded
in the current environment with `using ExponentialFamilyProjection`.

# Parameters

- `rng::R`: The random number generator used for sampling. Default is `Random.MersenneTwister(42)`.
- `marginalsamples::S`: The number of samples used for approximating marginal distributions. Default is `10`.
- `outsamples::S`: The number of samples used for approximating output message distributions. Default is `100`.
- `out_prjparams::OF`: the form parameter used to select the distribution form on which one to project out edge, if it's not provided will be infered from marginal form
- `in_prjparams::IFS`: a namedtuple like object to select the form on which one to project in the input edge, if it's not provided will be infered from the incoming message onto this edge
- `out_prjparams::OF`: The form parameter used to specify the target distribution family for the output message.
If `nothing` (default), the form will be inferred from the marginal form.
- `in_prjparams::IFS`: A NamedTuple-like object that specifies the target distribution family for each input edge.
Keys should be of the form `:in_k` where `k` is the input edge index. If `nothing` (default), the forms
will be inferred from the incoming messages.
- `proposal_distribution::PD`: The proposal distribution used for generating samples. If not provided or set to
`nothing`, it will be inferred from incoming messages and automatically updated during iterations.
- `sampling_strategy::SS`: The strategy for approximating the logpdf:
- `FullSampling(n)`: Uses `n` samples drawn from distributions (default: `n=10`). Provides more accurate
approximation at the cost of increased computation time.
- `MeanBased()`: Uses only the mean of each distribution as a single sample. Significantly faster but
less accurate for non-linear nodes or complex distributions.

# Examples

```julia
# Standard CVI projection with default settings
method = CVIProjection()

# Fast approximation using mean-based sampling
method = CVIProjection(sampling_strategy = MeanBased())

# Custom proposal with increased sample count
using Distributions
proposal = FactorizedJoint((NormalMeanVariance(0.0, 1.0), NormalMeanVariance(0.0, 1.0)))
method = CVIProjection(
proposal_distribution = ProposalDistributionContainer(proposal),
sampling_strategy = FullSampling(1000)
)

# Specify projection family for the output message
method = CVIProjection(out_prjparams = NormalMeanPrecision)

# Specify projection family for input edges
method = CVIProjection(in_prjparams = (in_1 = NormalMeanVariance, in_2 = GammaMeanShape))
```

!!! note
The `CVIProjection` method is an experimental enhancement of the now-deprecated `CVI`, offering better stability and improved accuracy.
Note that the parameters of this structure, as well as their defaults, are subject to change during the experimentation phase.
The `CVIProjection` method is an enhanced version of the deprecated `CVI`, offering better stability
and improved accuracy. Parameters and defaults may change as the implementation evolves.
"""
Base.@kwdef struct CVIProjection{R, S, OF, IFS} <: AbstractApproximationMethod
Base.@kwdef struct CVIProjection{R, S, OF, IFS, PD, SS} <: AbstractApproximationMethod
rng::R = Random.MersenneTwister(42)
marginalsamples::S = 10
outsamples::S = 100
out_prjparams::OF = nothing
in_prjparams::IFS = nothing
proposal_distribution::PD = ProposalDistributionContainer{Any}(nothing)
sampling_strategy::SS = FullSampling(10)
end

function get_kth_in_form(::CVIProjection{R, S, OF, Nothing}, ::Int) where {R, S, OF}
Expand Down
86 changes: 86 additions & 0 deletions test/ext/ReactiveMPProjectionExt/cvi_projection_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Unit tests for two helpers of the ReactiveMPProjectionExt extension:
# `create_density_function` and `optimize_parameters`.
@testitem "CVI Projection Extension Tests" begin
    using ExponentialFamily
    using ExponentialFamilyProjection
    using BayesBase
    using ReactiveMP
    using Distributions
    using Random

    # The helpers are not exported, so they are reached through the extension module.
    ReactiveMPProjectionExt = Base.get_extension(ReactiveMP, :ReactiveMPProjectionExt)
    @test !isnothing(ReactiveMPProjectionExt)
    using .ReactiveMPProjectionExt

    @testset "create_density_function" begin
        # Mock functions and data for testing
        pre_samples = [1.0, 2.0, 3.0]

        # Mock message with a simple normal distribution
        m_in = NormalMeanVariance(0.0, 1.0)

        # Mock logp_nc_drop_index: ignores `i` and `samples` and returns -0.5 * z^2
        logp_nc_drop_index = (z, i, samples) -> -0.5 * z^2

        # Test when forms match (should not include the message logpdf)
        forms_match = true
        df_match = ReactiveMPProjectionExt.create_density_function(forms_match, 1, pre_samples, logp_nc_drop_index, m_in)
        @test df_match(0.5) ≈ logp_nc_drop_index(0.5, 1, pre_samples)
        @test df_match(1.0) ≈ -0.5 # Just the logp_nc_drop_index result

        # Test when forms don't match (should include the message logpdf)
        forms_match = false
        df_no_match = ReactiveMPProjectionExt.create_density_function(forms_match, 1, pre_samples, logp_nc_drop_index, m_in)
        # Expected: logp_nc_drop_index + logpdf of the message
        expected_value = logp_nc_drop_index(0.5, 1, pre_samples) + logpdf(m_in, 0.5)
        @test df_no_match(0.5) ≈ expected_value
    end

    @testset "optimize_parameters" begin
        # Test with normal distribution - we can derive exact expected results
        m_in = NormalMeanVariance(0.0, 1.0)  # Prior: mean=0, variance=1 (precision=1)
        m_ins = [m_in]
        # The mock log-functions below ignore their `samples` argument, so the values
        # here do not influence the expected posteriors.
        pre_samples = [0.0, 0.5, -0.5]
        # One method object (default settings) is shared by all four cases.
        method = CVIProjection()

        # Case 1: Quadratic log-likelihood centered at 0 (-0.5*z²) corresponds to Normal(0,1)
        # When combining Normal(0,1) prior with Normal(0,1) likelihood:
        # Expected posterior: Normal(0, 0.5) - precision adds (1+1=2, variance=1/2=0.5)
        log_fn1 = (z, i, samples) -> -0.5 * z^2
        result1 = ReactiveMPProjectionExt.optimize_parameters(1, pre_samples, m_ins, log_fn1, method)

        @test result1 isa NormalMeanVariance
        @test mean(result1) ≈ 0.0 atol = 1e-1
        @test var(result1) ≈ 0.5 atol = 1e-1

        # Case 2: Quadratic centered at 2.0 (-0.5*(z-2)²) corresponds to Normal(2,1)
        # Combining Normal(0,1) prior with Normal(2,1) likelihood:
        # Expected posterior: Normal(1, 0.5) - weighted average of means
        log_fn2 = (z, i, samples) -> -0.5 * (z - 2.0)^2
        result2 = ReactiveMPProjectionExt.optimize_parameters(1, pre_samples, m_ins, log_fn2, method)

        @test result2 isa NormalMeanVariance
        @test mean(result2) ≈ 1.0 atol = 1e-1  # (0*1 + 2*1)/(1+1) = 1.0
        @test var(result2) ≈ 0.5 atol = 1e-1   # 1/(1+1) = 0.5

        # Case 3: Stronger quadratic (-2.0*(z-2)²) corresponds to Normal(2,0.25)
        # Combining Normal(0,1) prior with Normal(2,0.25) likelihood:
        # Expected posterior: Normal(1.6, 0.2)
        log_fn3 = (z, i, samples) -> -2.0 * (z - 2.0)^2
        result3 = ReactiveMPProjectionExt.optimize_parameters(1, pre_samples, m_ins, log_fn3, method)

        @test result3 isa NormalMeanVariance
        @test mean(result3) ≈ 1.6 atol = 1e-1  # (0*1 + 2*4)/(1+4) = 8/5 = 1.6
        @test var(result3) ≈ 0.2 atol = 1e-1   # 1/(1+4) = 0.2

        # Case 4: Test with a different prior
        m_in2 = NormalMeanVariance(1.0, 2.0)  # Prior: mean=1, variance=2 (precision=0.5)
        m_ins2 = [m_in2]

        # Combining Normal(1,2) prior with Normal(2,1) likelihood:
        # Expected posterior: Normal(5/3, 2/3)
        result4 = ReactiveMPProjectionExt.optimize_parameters(1, pre_samples, m_ins2, log_fn2, method)

        @test result4 isa NormalMeanVariance
        @test mean(result4) ≈ 5 / 3 atol = 1e-1  # (1*0.5 + 2*1)/(0.5+1) = 1.67
        @test var(result4) ≈ 2 / 3 atol = 1e-1   # 1/(0.5+1) = 0.67
    end
end
89 changes: 88 additions & 1 deletion test/ext/ReactiveMPProjectionExt/rules/marginals_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ end
# Test with partial specification
meta_partial = DeltaMeta(method = CVIProjection(
in_prjparams = (in_2 = form2,), # Only specify second input
marginalsamples = 10
sampling_strategy = FullSampling(10)
), inverse = nothing)

# Setup messages
Expand All @@ -111,3 +111,90 @@ end
@test isa(result[2], MvNormalMeanScalePrecision)
end
end

# Calls the `DeltaFn` marginal rule repeatedly with the same `meta` object and checks
# that a sample-based divergence estimate is lower after the last call than after the
# first one.
@testitem "CVIProjection proposal distribution convergence tests" begin
    using ExponentialFamily, ExponentialFamilyProjection, BayesBase, LinearAlgebra
    using Random, Distributions

    @testset "Posterior approximation quality" begin
        # The same seeded `rng` object is handed to the method and is also used by
        # `estimate_kl_divergence` below, so the two share one random stream.
        rng = MersenneTwister(123)
        method = CVIProjection(rng = rng, sampling_strategy = FullSampling(2000))
        meta = DeltaMeta(method = method, inverse = nothing)

        f(x, y) = x * y

        # Define distributions
        m_out = NormalMeanVariance(2.0, 0.1)
        m_in1 = NormalMeanVariance(0.0, 2.0)
        m_in2 = NormalMeanVariance(0.0, 2.0)

        # Function to compute unnormalized log posterior for a sample
        function log_posterior(x, y)
            return logpdf(m_in1, x) + logpdf(m_in2, y) + logpdf(m_out, f(x, y))
        end

        # Estimate KL divergence using samples
        # Because `log_posterior` is unnormalized, the returned value differs from the
        # KL divergence by an additive constant that is the same for every `q_result`,
        # so comparing two estimates is still meaningful.
        function estimate_kl_divergence(q_result)
            n_samples = 10000
            samples_q = [(rand(rng, q_result[1]), rand(rng, q_result[2])) for _ in 1:n_samples]

            # Compute E_q[log q(x,y) - log p(x,y)]
            log_q_terms = [logpdf(q_result[1], x) + logpdf(q_result[2], y) for (x, y) in samples_q]
            log_p_terms = [log_posterior(x, y) for (x, y) in samples_q]

            return mean(log_q_terms .- log_p_terms)
        end

        # Run multiple iterations and collect KL divergences
        n_iterations = 10
        kl_divergences = Vector{Float64}(undef, n_iterations)

        # NOTE(review): `meta` is reused across iterations; presumably the improvement
        # relies on the rule storing its result in the method's proposal container
        # between calls - confirm against the marginal rule.
        for i in 1:n_iterations
            result = @call_marginalrule DeltaFn{f}(:ins) (m_out = m_out, m_ins = ManyOf(m_in1, m_in2), meta = meta)
            kl_divergences[i] = estimate_kl_divergence(result)
        end

        # Only the first and the last estimates are compared; a monotone decrease over
        # the intermediate iterations is not asserted.
        @test kl_divergences[1] > kl_divergences[end]
    end
end

# Checks the `DeltaFn` marginal rule with the `MeanBased` sampling strategy against a
# closed-form Gaussian result.
@testitem "Basic checks for marginal rule with mean based approximation" begin
    using ExponentialFamily, ExponentialFamilyProjection, BayesBase
    # NOTE(review): `@test_rules` is imported but not used in this test item.
    import ReactiveMP: @test_rules, @test_marginalrules

    @testset "f(x, y) -> [x, y], x~Normal, y~Normal, out~MvNormal (marginalization)" begin
        f(x, y) = [x, y]
        meta = DeltaMeta(method = CVIProjection(sampling_strategy = MeanBased()), inverse = nothing)
        # `f` stacks its inputs and `m_out` has a diagonal covariance, so each input
        # combines independently with one component of `m_out` (mean 1, variance 2):
        #   x: precision 1 + 1/2 = 3/2   -> mean (0*1 + 1*1/2) / (3/2) = 1/3, variance 2/3
        #   y: precision 1/2 + 1/2 = 1   -> mean (1*1/2 + 1*1/2) / 1 = 1,    variance 1
        # Results are compared with an absolute tolerance of 1e-1.
        @test_marginalrules [check_type_promotion = false, atol = 1e-1] DeltaFn{f}(:ins) [(
            input = (m_out = MvGaussianMeanCovariance(ones(2), [2 0; 0 2]), m_ins = ManyOf(NormalMeanVariance(0, 1), NormalMeanVariance(1, 2)), meta = meta),
            output = FactorizedJoint((NormalMeanVariance(1 / 3, 2 / 3), NormalMeanVariance(1.0, 1.0)))
        )]
    end
end

# Benchmarks the `DeltaFn` marginal rule with `FullSampling(10)` and with `MeanBased()`
# and asserts that the mean-based run is the faster of the two.
@testitem "DeltaNode - CVI sampling strategy performance comparison" begin
    using Test
    using BenchmarkTools
    using BayesBase, ExponentialFamily, ExponentialFamilyProjection

    f(x, y) = [x, y]

    # Build a fresh `meta` for `strategy` and return the elapsed time reported by
    # `@belapsed` for one marginal-rule call (presumably the minimum over the collected
    # samples, in seconds - confirm against the BenchmarkTools documentation).
    function run_marginal_test(strategy)
        # `inverse` is not passed here, unlike the other test items that use
        # `inverse = nothing`; this relies on the default of `DeltaMeta`.
        meta = DeltaMeta(method = CVIProjection(sampling_strategy = strategy))
        m_out = MvGaussianMeanCovariance(ones(2), [2 0; 0 2])
        m_in1 = NormalMeanVariance(0.0, 2.0)
        m_in2 = NormalMeanVariance(0.0, 2.0)
        # The same `meta` object is interpolated into every benchmark evaluation.
        return @belapsed begin
            @call_marginalrule DeltaFn{f}(:ins) (m_out = $m_out, m_ins = ManyOf($m_in1, $m_in2), meta = $meta)
        end samples = 2
    end

    # Run benchmarks
    full_time = run_marginal_test(FullSampling(10))
    mean_time = run_marginal_test(MeanBased())

    # NOTE(review): this is a wall-clock comparison based on only two benchmark samples
    # per strategy, so it may be sensitive to machine load.
    @test mean_time < full_time

    # Optional: Print the actual times for verification
    @info "Sampling strategy performance" full_time mean_time ratio = (full_time / mean_time)
end