diff --git a/src/valuations.jl b/src/valuations.jl index c858dff..ebdf728 100644 --- a/src/valuations.jl +++ b/src/valuations.jl @@ -50,7 +50,7 @@ Get the domain of ``\\phi``. domain(ϕ::Valuation) function domain(ϕ::IdentityValuation{T}) where T - Set{T}() + T[] end function domain(ϕ::LabeledBox) @@ -135,6 +135,7 @@ function project(ϕ::LabeledBox, x) end function project(ϕ::LabeledBox{<:Any, <:GaussianSystem}, x) + @assert x ⊆ ϕ.labels m = [X in x for X in ϕ.labels] LabeledBox(ϕ.labels[m], marginal(m, ϕ.box)) end diff --git a/test/runtests.jl b/test/runtests.jl index 402c67a..35d6e14 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -87,28 +87,57 @@ using Test :measure => kernel(R, H), :observe₁ => normal(z₁), :observe₂ => normal(z₂)) + Σ = oapply(composite, box_map) @test isapprox(true_cov, cov(Σ); rtol=1e-3) @test isapprox(true_mean, mean(Σ); rtol=1e-3) kb, query = inference_problem(composite, box_map) + @test query == Set([:x21, :x22, :x23, :x24, :x25, :x26]) + jt = architecture(kb, minfill!(primal_graph(kb), query)) + @test_throws ErrorException("Query not covered by join tree.") answer_query(jt, [:x31]) + @test_throws ErrorException("Query not covered by join tree.") answer_query!(jt, [:x31]) + ϕ = answer_query(jt, query) M = [i == j for i in [:x21, :x22, :x23, :x24, :x25, :x26], j in ϕ.labels] + @test length(ϕ) == length(query) @test Set(domain(ϕ)) == query @test isapprox(true_cov, M * cov(ϕ.box) * M'; rtol=1e-3) @test isapprox(true_mean, M * mean(ϕ.box); rtol=1e-3) - jt = architecture(kb, minwidth!(primal_graph(kb), query)) + ϕ = answer_query!(jt, query) + M = [i == j for i in [:x21, :x22, :x23, :x24, :x25, :x26], j in ϕ.labels] + @test length(ϕ) == length(query) + @test Set(domain(ϕ)) == query + @test isapprox(true_cov, M * cov(ϕ.box) * M'; rtol=1e-3) + @test isapprox(true_mean, M * mean(ϕ.box); rtol=1e-3) + + jt = architecture(kb, minwidth!(primal_graph(kb), [])) + ϕ = answer_query(jt, query) M = [i == j for i in [:x21, :x22, :x23, :x24, :x25, :x26], j in ϕ.labels] + @test length(ϕ) == length(query) @test Set(domain(ϕ)) == query @test isapprox(true_cov, M * cov(ϕ.box) * M'; rtol=1e-3) @test isapprox(true_mean, M * mean(ϕ.box); rtol=1e-3) ϕ = answer_query!(jt, query) M = [i == j for i in [:x21, :x22, :x23, :x24, :x25, :x26], j in ϕ.labels] + @test length(ϕ) == length(query) @test Set(domain(ϕ)) == query @test isapprox(true_cov, M * cov(ϕ.box) * M'; rtol=1e-3) @test isapprox(true_mean, M * mean(ϕ.box); rtol=1e-3) end + +@testset "Identity Valuation" begin + ϕ = LabeledBox([:x, :y], normal([1 0; 0 1])) + e = IdentityValuation{Symbol}() + + @test isempty(domain(e)) + @test eltype(domain(e)) == Symbol + @test combine(e, e) === e + @test combine(ϕ, e) === ϕ + @test combine(e, ϕ) === ϕ + @test project(e, []) === e +end