Skip to content

Commit

Permalink
style: 🎨
Browse files Browse the repository at this point in the history
  • Loading branch information
Nimrais committed Nov 15, 2024
1 parent 9ae05c9 commit 4509abb
Showing 1 changed file with 10 additions and 18 deletions.
28 changes: 10 additions & 18 deletions test/ext/ReactiveMPProjectionExt/rules/marginals_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,14 @@ end
rng = MersenneTwister(123)
method = CVIProjection(rng = rng, marginalsamples = 2000)
meta = DeltaMeta(method = method, inverse = nothing)

f(x, y) = x * y

# Define distributions
m_out = NormalMeanVariance(2.0, 0.1)
m_in1 = NormalMeanVariance(0.0, 2.0)
m_in2 = NormalMeanVariance(0.0, 2.0)

# Function to compute unnormalized log posterior for a sample
function log_posterior(x, y)
return logpdf(m_in1, x) + logpdf(m_in2, y) + logpdf(m_out, f(x, y))
Expand All @@ -137,24 +137,20 @@ end
function estimate_kl_divergence(q_result)
n_samples = 10000
samples_q = [(rand(rng, q_result[1]), rand(rng, q_result[2])) for _ in 1:n_samples]

# Compute E_q[log q(x,y) - log p(x,y)]
log_q_terms = [logpdf(q_result[1], x) + logpdf(q_result[2], y) for (x, y) in samples_q]
log_p_terms = [log_posterior(x, y) for (x, y) in samples_q]

return mean(log_q_terms .- log_p_terms)
end

# Run multiple iterations and collect KL divergences
n_iterations = 10
kl_divergences = Vector{Float64}(undef, n_iterations)

for i in 1:n_iterations
result = @call_marginalrule DeltaFn{f}(:ins) (
m_out = m_out,
m_ins = ManyOf(m_in1, m_in2),
meta = meta
)
result = @call_marginalrule DeltaFn{f}(:ins) (m_out = m_out, m_ins = ManyOf(m_in1, m_in2), meta = meta)
kl_divergences[i] = estimate_kl_divergence(result)
end

Expand All @@ -180,7 +176,7 @@ end
using Test
using BenchmarkTools
using BayesBase, ExponentialFamily, ExponentialFamilyProjection

f(x, y) = [x, y]

function run_marginal_test(strategy)
Expand All @@ -189,11 +185,7 @@ end
m_in1 = NormalMeanVariance(0.0, 2.0)
m_in2 = NormalMeanVariance(0.0, 2.0)
return @belapsed begin
@call_marginalrule DeltaFn{f}(:ins) (
m_out = $m_out,
m_ins = ManyOf($m_in1, $m_in2),
meta = $meta
)
@call_marginalrule DeltaFn{f}(:ins) (m_out = $m_out, m_ins = ManyOf($m_in1, $m_in2), meta = $meta)
end samples = 2
end

Expand All @@ -204,5 +196,5 @@ end
@test mean_time < full_time

# Optional: Print the actual times for verification
@info "Sampling strategy performance" full_time mean_time ratio=(full_time/mean_time)
@info "Sampling strategy performance" full_time mean_time ratio = (full_time / mean_time)
end

0 comments on commit 4509abb

Please sign in to comment.