Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CVIProjection with constraints on interfaces #428

Merged
merged 20 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Distributions = "0.24, 0.25"
DomainIntegrals = "0.3.2, 0.4"
DomainSets = "0.5.2, 0.6, 0.7"
ExponentialFamily = "1.6.0"
ExponentialFamilyProjection = "1.1"
ExponentialFamilyProjection = "1.2"
FastCholesky = "1.3.0"
FastGaussQuadrature = "0.4, 0.5"
FixedArguments = "0.1"
Expand Down Expand Up @@ -91,6 +91,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ExponentialFamilyProjection = "17f509fa-9a96-44ba-99b2-1c5f01f0931b"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand All @@ -102,4 +103,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04"

[targets]
test = ["Aqua", "CpuId", "ReTestItems", "Test", "Pkg", "Logging", "InteractiveUtils", "TestSetExtensions", "Coverage", "Dates", "Distributed", "Documenter", "BenchmarkCI", "BenchmarkTools", "PkgBenchmark", "StableRNGs", "Optimisers", "DiffResults", "ExponentialFamilyProjection", "REPL"]
test = ["Aqua", "CpuId", "ReTestItems", "Test", "Pkg", "Logging", "InteractiveUtils", "TestSetExtensions", "Coverage", "Dates", "Distributed", "Documenter", "BenchmarkCI", "BenchmarkTools", "JET", "PkgBenchmark", "StableRNGs", "Optimisers", "DiffResults", "ExponentialFamilyProjection", "REPL"]
69 changes: 59 additions & 10 deletions ext/ReactiveMPProjectionExt/rules/marginals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,39 @@ using TupleTools
import Distributions: Distribution
import BayesBase: AbstractContinuousGenericLogPdf

"""
    create_project_to_ins(method::CVIProjection, ::Nothing, m_in)

Build the `ProjectedTo` target for an input edge that has no user-specified form.

The target is derived from the incoming message `m_in` exactly as in the
`ProjectionParameters` method of this function, using
`ExponentialFamilyProjection.DefaultProjectionParameters()` as the parameters.
"""
function create_project_to_ins(method::CVIProjection, ::Nothing, m_in::Any)
    # Delegate rather than repeat the construction logic, so the default path and the
    # user-parameterised path cannot drift apart.
    params = ExponentialFamilyProjection.DefaultProjectionParameters()
    return create_project_to_ins(method, params, m_in)
end

"""
    create_project_to_ins(::CVIProjection, form::ProjectedTo, m_in)

Return `form` unchanged. A fully specified `ProjectedTo` takes precedence, so the
incoming message is ignored.
"""
create_project_to_ins(::CVIProjection, form::ProjectedTo, ::Any) = form

"""
    create_project_to_ins(::CVIProjection, params::ProjectionParameters, m_in)

Build a `ProjectedTo` whose type tag, dimensions and conditioner are taken from the
incoming message `m_in`, and whose projection parameters are `params`.
"""
function create_project_to_ins(::CVIProjection, params::ProjectionParameters, m_in::Any)
    typetag = ExponentialFamily.exponential_family_typetag(m_in)
    # The conditioner is read from the exponential-family representation of `m_in`.
    cond = getconditioner(convert(ExponentialFamilyDistribution, m_in))
    return ProjectedTo(typetag, size(m_in)...; conditioner = cond, parameters = params)
end

"""
    create_project_to_ins(method::CVIProjection, m_in, k::Int)

Entry point for the `k`-th input edge: look up the form configured for that input via
`ReactiveMP.get_kth_in_form` and dispatch on it to build the projection target for
`m_in`.
"""
function create_project_to_ins(method::CVIProjection, m_in::Any, k::Int)
    return create_project_to_ins(method, ReactiveMP.get_kth_in_form(method, k), m_in)
end

@marginalrule DeltaFn(:ins) (m_out::Any, m_ins::ManyOf{1, Any}, meta::DeltaMeta{M}) where {M <: CVIProjection} = begin
method = ReactiveMP.getmethod(meta)
g = getnodefn(meta, Val(:out))
Expand All @@ -13,8 +46,7 @@ import BayesBase: AbstractContinuousGenericLogPdf
F = promote_variate_type(variate_form(typeof(m_in)), BayesBase.AbstractContinuousGenericLogPdf)
f = convert(F, UnspecifiedDomain(), (z) -> logpdf(m_out, g(z)))

T = ExponentialFamily.exponential_family_typetag(m_in)
prj = ProjectedTo(T, size(m_in)...; conditioner = getconditioner(ef_in), parameters = something(method.prjparams, ExponentialFamilyProjection.DefaultProjectionParameters()))
prj = create_project_to_ins(method, m_in, 1)
q = project_to(prj, f, first(m_ins))

return FactorizedJoint((q,))
Expand All @@ -36,16 +68,33 @@ end

optimize_natural_parameters = let m_ins = m_ins, logp_nc_drop_index = logp_nc_drop_index
(i, pre_samples) -> begin
# Create an `AbstractContinuousGenericLogPdf` with an unspecified domain and the transformed `logpdf` function
df = let i = i, pre_samples = pre_samples, logp_nc_drop_index = logp_nc_drop_index
(z) -> logp_nc_drop_index(z, i, pre_samples)
m_in = m_ins[i]
default_type = ExponentialFamily.exponential_family_typetag(m_in)

prj = create_project_to_ins(method, m_in, i)

typeform = ExponentialFamilyProjection.get_projected_to_type(prj)
dims = ExponentialFamilyProjection.get_projected_to_dims(prj)
forms_match = typeform === default_type && dims == size(m_in)

# Create log probability function
df = if forms_match
let i = i, pre_samples = pre_samples, logp_nc_drop_index = logp_nc_drop_index
(z) -> logp_nc_drop_index(z, i, pre_samples)
end
else
let i = i, pre_samples = pre_samples, logp_nc_drop_index = logp_nc_drop_index, m_in = m_in
(z) -> logp_nc_drop_index(z, i, pre_samples) + logpdf(m_in, z)
end
end
logp = convert(promote_variate_type(variate_form(typeof(m_ins[i])), BayesBase.AbstractContinuousGenericLogPdf), UnspecifiedDomain(), df)
conditioner = getconditioner(convert(ExponentialFamilyDistribution, m_ins[i]))
T = ExponentialFamily.exponential_family_typetag(m_ins[i])
prj = ProjectedTo(T, size(m_ins[i])...; conditioner=conditioner, parameters = something(method.prjparams, ExponentialFamilyProjection.DefaultProjectionParameters()))

return project_to(prj, logp, m_ins[i])
logp = convert(
promote_variate_type(variate_form(typeof(m_in)), BayesBase.AbstractContinuousGenericLogPdf),
UnspecifiedDomain(),
df
)

return forms_match ? project_to(prj, logp, m_in) : project_to(prj, logp)
end
end

Expand Down
36 changes: 25 additions & 11 deletions ext/ReactiveMPProjectionExt/rules/out.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,33 @@
"""
    create_project_to(::CVIProjection{R, S, Nothing}, q_out, q_out_samples)

Build the `ProjectedTo` target for the `out` edge when no output form was configured
(`out_prjparams` is `nothing`).

The type tag and conditioner come from the marginal `q_out`, the dimensions from the
first element of `q_out_samples`, and the parameters are
`ExponentialFamilyProjection.DefaultProjectionParameters()`.
"""
function create_project_to(::CVIProjection{R, S, Nothing}, q_out::Any, q_out_samples::Any) where {R, S}
    typetag = ExponentialFamily.exponential_family_typetag(q_out)
    cond = getconditioner(convert(ExponentialFamilyDistribution, q_out))
    return ProjectedTo(
        typetag,
        size(first(q_out_samples))...;
        conditioner = cond,
        parameters = ExponentialFamilyProjection.DefaultProjectionParameters()
    )
end

"""
    create_project_to(method::CVIProjection{R, S, <:ProjectedTo}, q_out, q_out_samples)

Return the user-supplied `ProjectedTo` stored in `method.out_prjparams` unchanged; the
marginal and the samples are ignored.
"""
create_project_to(method::CVIProjection{R, S, OF}, ::Any, ::Any) where {R, S, OF <: ProjectedTo} = method.out_prjparams

"""
    create_project_to(method::CVIProjection{R, S, <:ProjectionParameters}, q_out, q_out_samples)

Build the `ProjectedTo` target for the `out` edge when only projection parameters were
configured.

The type tag and conditioner come from the marginal `q_out`, the dimensions from the
first element of `q_out_samples`, and the parameters are `method.out_prjparams`.
"""
function create_project_to(method::CVIProjection{R, S, OF}, q_out::Any, q_out_samples::Any) where {R, S, OF <: ProjectionParameters}
    typetag = ExponentialFamily.exponential_family_typetag(q_out)
    cond = getconditioner(convert(ExponentialFamilyDistribution, q_out))
    return ProjectedTo(
        typetag,
        size(first(q_out_samples))...;
        conditioner = cond,
        parameters = method.out_prjparams
    )
end

@rule DeltaFn(:out, Marginalisation) (m_out::Any, q_out::Any, q_ins::FactorizedJoint, meta::DeltaMeta{U}) where {U <: CVIProjection} = begin
node_function = getnodefn(meta, Val(:out))
method = ReactiveMP.getmethod(meta)
rng = method.rng
q_ins_components = components(q_ins)
method = ReactiveMP.getmethod(meta)
rng = method.rng
q_ins_components = components(q_ins)
q_ins_sample_friendly = map(q_in -> sampling_optimized(q_in), q_ins_components)

samples = map(ReactiveMP.cvilinearize, map(q_in -> rand(rng, q_in, method.outsamples), q_ins_sample_friendly))
q_out_samples = map(x -> node_function(x...), zip(samples...))

T = ExponentialFamily.exponential_family_typetag(q_out)
q_out_ef = convert(ExponentialFamilyDistribution, q_out)
conditioner = getconditioner(q_out_ef)

prj = ProjectedTo(T, size(first(q_out_samples))...; conditioner = conditioner, parameters = something(method.prjparams, ExponentialFamilyProjection.DefaultProjectionParameters()))
est = project_to(prj, q_out_samples)

prj = create_project_to(method, q_out, q_out_samples)
est = project_to(prj, q_out_samples)
return DivisionOf(est, m_out)
end
end
17 changes: 14 additions & 3 deletions src/approximations/cvi_projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,28 @@ This structure is a subtype of `AbstractApproximationMethod` and is used to conf
- `rng::R`: The random number generator used for sampling. Default is `Random.MersenneTwister(42)`.
- `marginalsamples::S`: The number of samples used for approximating marginal distributions. Default is `10`.
- `outsamples::S`: The number of samples used for approximating output message distributions. Default is `100`.
- `prjparams::P`: Parameters for the exponential family projection. Default is `nothing`, in which case it will use `ExponentialFamilyProjection.DefaultProjectionParameters()`.
- `out_prjparams::OF`: The form (or projection parameters) used to select the distribution onto which the `out` edge is projected. Default is `nothing`, in which case the form is inferred from the marginal on the `out` edge.
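- `in_prjparams::IFS`: A `NamedTuple`-like object with keys `in_1`, `in_2`, … that selects, for each input edge, the form (or projection parameters) onto which that edge is projected. Default is `nothing`; for any input without an entry, the form is inferred from the incoming message on that edge.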

!!! note
The `CVIProjection` method is an experimental enhancement of the now-deprecated `CVI`, offering better stability and improved accuracy.
Note that the parameters of this structure, as well as their defaults, are subject to change during the experimentation phase.
"""
Base.@kwdef struct CVIProjection{R, S, P} <: AbstractApproximationMethod
Base.@kwdef struct CVIProjection{R, S, OF, IFS} <: AbstractApproximationMethod
rng::R = Random.MersenneTwister(42)
marginalsamples::S = 10
outsamples::S = 100
prjparams::P = nothing # ExponentialFamilyProjection.DefaultProjectionParameters()
out_prjparams::OF = nothing
in_prjparams::IFS = nothing
end

"""
    get_kth_in_form(::CVIProjection{R, S, OF, Nothing}, k::Int)

Return `nothing`: no input forms were configured (`in_prjparams` is `nothing`), so no
input index has a form.
"""
get_kth_in_form(::CVIProjection{R, S, OF, Nothing}, ::Int) where {R, S, OF} = nothing

"""
    get_kth_in_form(method::CVIProjection, k::Int)

Look up the entry stored under the key `in_k` (e.g. `:in_1` for `k == 1`) in
`method.in_prjparams`, returning `nothing` when there is no such key.
"""
function get_kth_in_form(method::CVIProjection, k::Int)
    return get(method.in_prjparams, Symbol("in_", k), nothing)
end

# This method should only be invoked if a user did not install `ExponentialFamilyProjection`
Expand Down
138 changes: 138 additions & 0 deletions test/ext/ReactiveMPProjectionExt/ReactiveMPProjectionExt_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,141 @@
@test ext.DivisionOf(d1, d2) == prod(GenericProd(), ext.DivisionOf(d1, d2), missing)
@test ext.DivisionOf(d1, d2) == prod(GenericProd(), missing, ext.DivisionOf(d1, d2))
end

# Checks `create_project_to_ins` for each kind of per-input configuration: every case
# runs JET's `@test_opt` on the call and then asserts the type of the returned `ProjectedTo`.
@testitem "create_project_to_ins type stability" begin
using ExponentialFamily, ExponentialFamilyProjection, BayesBase, Test
using ReactiveMP: CVIProjection
using JET

# `create_project_to_ins` is internal to the extension, so it is reached through the
# extension module rather than through `ReactiveMP` itself
ext = Base.get_extension(ReactiveMP, :ReactiveMPProjectionExt)
@test !isnothing(ext)
using .ext

@testset "Complete type stability tests for create_project_to_ins" begin
# Test Case 1: Default form (nothing) - the target follows the incoming message
let
method = CVIProjection()
m_in = NormalMeanVariance(0.0, 1.0)
k = 1

@test_opt ext.create_project_to_ins(method, m_in, k)
result = ext.create_project_to_ins(method, m_in, k)
@test result isa ProjectedTo{<:NormalMeanVariance}
end

# Test Case 2: Custom form specified - the configured form wins over the message type
let
form = ProjectedTo(MvNormalMeanScalePrecision, 2)
method = CVIProjection(in_prjparams = (in_1 = form,))
m_in = NormalMeanVariance(0.0, 1.0) # Input type different from target
k = 1

@test_opt ext.create_project_to_ins(method, m_in, k)
result = ext.create_project_to_ins(method, m_in, k)
@test result isa ProjectedTo{<:MvNormalMeanScalePrecision}
end

# Test Case 3: Multiple forms specified - each input index picks its own `in_k` entry
let
forms = (in_1 = ProjectedTo(NormalMeanVariance), in_2 = ProjectedTo(MvNormalMeanCovariance))
method = CVIProjection(in_prjparams = forms)
m_in = Gamma(2.0, 2.0)

for k in 1:2
@test_opt ext.create_project_to_ins(method, m_in, k)
result = ext.create_project_to_ins(method, m_in, k)

if k == 1
@test result isa ProjectedTo{<:NormalMeanVariance}
else
@test result isa ProjectedTo{<:MvNormalMeanCovariance}
end
end
end

# Test Case 4: no form, only projection parameters are specified - the target type
# still follows the incoming message
let
params = ExponentialFamilyProjection.DefaultProjectionParameters()
method = CVIProjection(in_prjparams = (in_1 = params,))
m_in = Gamma(1.0, 1.0)
k = 1

@test_opt ext.create_project_to_ins(method, m_in, k)
result = ext.create_project_to_ins(method, m_in, k)
@test result isa ProjectedTo{<:Gamma}
end
end
end

# Checks `create_project_to` for each kind of `out_prjparams` configuration: every case
# runs JET's `@test_opt` on the call and then asserts properties of the returned `ProjectedTo`.
@testitem "create_project_to type stability" begin
using ExponentialFamily, ExponentialFamilyProjection, BayesBase, Test
using ReactiveMP: CVIProjection
using JET

# `create_project_to` is internal to the extension, so it is reached through the
# extension module rather than through `ReactiveMP` itself
ext = Base.get_extension(ReactiveMP, :ReactiveMPProjectionExt)
@test !isnothing(ext)
using .ext

@testset "Complete type stability tests for create_project_to" begin
# Test Case 1: Default form (Nothing case) - type from `q_out`, dims from the samples
let
method = CVIProjection()
q_out = NormalMeanVariance(0.0, 1.0)
q_out_samples = [[1.0], [2.0], [3.0]]

@test_opt ext.create_project_to(method, q_out, q_out_samples)
result = ext.create_project_to(method, q_out, q_out_samples)

@test result isa ProjectedTo{<:NormalMeanVariance}
@test ExponentialFamilyProjection.get_projected_to_dims(result) == size(first(q_out_samples))
end

# Test Case 2: Existing ProjectedTo form - returned as-is (same object)
let
form = ProjectedTo(MvNormalMeanScalePrecision, 2)
method = CVIProjection(out_prjparams = form)
q_out = MvNormalMeanCovariance([0.0, 0.0], [1.0 0.0; 0.0 1.0])
q_out_samples = [[1.0, 2.0], [3.0, 4.0]]

@test_opt ext.create_project_to(method, q_out, q_out_samples)
result = ext.create_project_to(method, q_out, q_out_samples)

@test result === method.out_prjparams
@test result isa ProjectedTo{<:MvNormalMeanScalePrecision}
end

# Test Case 3: Custom ProjectionParameters - type from `q_out`, parameters from the method
let
params = ExponentialFamilyProjection.DefaultProjectionParameters()
method = CVIProjection(out_prjparams = params)
q_out = Gamma(2.0, 2.0)
q_out_samples = [[1.0], [2.0], [3.0]]

@test_opt ext.create_project_to(method, q_out, q_out_samples)
result = ext.create_project_to(method, q_out, q_out_samples)

@test result isa ProjectedTo{<:Gamma}
@test result.parameters === method.out_prjparams
@test ExponentialFamilyProjection.get_projected_to_dims(result) == size(first(q_out_samples))
end

# Test Case 4: Different dimensions and distributions, all with the default method
let
method = CVIProjection()
distributions_and_samples = [
(MvNormalMeanScalePrecision([1, 2, 3], 1), [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), (NormalMeanVariance(0.0, 1.0), [[1.0], [2.0]]), (Gamma(2.0, 2.0), [[1.0], [2.0]])
]

for (dist, samples) in distributions_and_samples
@test_opt ext.create_project_to(method, dist, samples)
result = ext.create_project_to(method, dist, samples)

@test result isa ProjectedTo
@test ExponentialFamilyProjection.get_projected_to_dims(result) == size(first(samples))
@test result.parameters isa ExponentialFamilyProjection.ProjectionParameters
end
end
end
end
48 changes: 47 additions & 1 deletion test/ext/ReactiveMPProjectionExt/rules/marginals_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@
end

@testset "f(x) -> x, x~EF, out~EF with Categorical" begin
meta = DeltaMeta(method = CVIProjection(), inverse = nothing)
meta = DeltaMeta(
method = CVIProjection(
in_prjparams = (in_1 = ExponentialFamilyProjection.ProjectionParameters(strategy = ExponentialFamilyProjection.ControlVariateStrategy(nsamples = 4000)),)
)
)
inputs_outputs = [
(Categorical([1 / 4, 1 / 4, 1 / 2]), Categorical([1 / 2, 1 / 8, 3 / 8])),
(Categorical([1 / 8, 1 / 8, 3 / 4]), Categorical([1 / 16, 13 / 16, 1 / 8])),
Expand All @@ -65,3 +69,45 @@
end
end
end

# Covers `get_kth_in_form` lookups and a `:ins` marginal rule call in which only one of
# the two inputs has a configured form.
@testitem "CVIProjection form access tests" begin
using ExponentialFamily, ExponentialFamilyProjection, BayesBase, LinearAlgebra
import ReactiveMP: get_kth_in_form

@testset "Testing input edge form access with get_kth_in_form" begin
# Create forms for specific inputs
form1 = ProjectedTo(NormalMeanVariance)

form2 = ProjectedTo(MvNormalMeanScalePrecision, 2)

# Check form access behavior: configured indices return a form, others return `nothing`
method_with_forms = CVIProjection(in_prjparams = (in_1 = form1, in_2 = form2))
@test !isnothing(get_kth_in_form(method_with_forms, 1))
@test !isnothing(get_kth_in_form(method_with_forms, 2))
@test isnothing(get_kth_in_form(method_with_forms, 3)) # Non-existent index

# Without any `in_prjparams`, every index yields `nothing`
method_default = CVIProjection()
@test isnothing(get_kth_in_form(method_default, 1))
@test isnothing(get_kth_in_form(method_default, 2))

# Test with partial specification
meta_partial = DeltaMeta(method = CVIProjection(
in_prjparams = (in_2 = form2,), # Only specify second input
marginalsamples = 10
), inverse = nothing)

# Setup messages
m_out = MvNormalMeanCovariance([2.0, 3.0], Matrix{Float64}(I, 2, 2))
m_in1 = Gamma(2.0, 2.0)
m_in2 = MvNormalMeanCovariance([1.0, 1.0], [2.0 0.0; 0.0 2.0])

# Node function: scales the vector input `y` elementwise by the scalar input `x`
f(x, y) = x .* y

result = @call_marginalrule DeltaFn{f}(:ins) (m_out = m_out, m_ins = ManyOf(m_in1, m_in2), meta = meta_partial)

# First input should use default form (nothing specified)
# Second input should be MvNormalMeanScalePrecision as specified
@test isa(result[1], Gamma)
@test isa(result[2], MvNormalMeanScalePrecision)
end
end
Loading
Loading