Skip to content

Commit

Permalink
Added more tests, and an @assert to function project.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelsonric committed Jun 5, 2023
1 parent 0062c5c commit e78a3ba
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/valuations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ Get the domain of ``\\phi``.
domain::Valuation)

function domain::IdentityValuation{T}) where T
Set{T}()
T[]
end

function domain::LabeledBox)
Expand Down Expand Up @@ -135,6 +135,7 @@ function project(ϕ::LabeledBox, x)
end

function project::LabeledBox{<:Any, <:GaussianSystem}, x)
@assert x ϕ.labels
m = [X in x for X in ϕ.labels]
LabeledBox.labels[m], marginal(m, ϕ.box))
end
Expand Down
31 changes: 30 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,28 +87,57 @@ using Test
:measure => kernel(R, H),
:observe₁ => normal(z₁),
:observe₂ => normal(z₂))

Σ = oapply(composite, box_map)
@test isapprox(true_cov, cov(Σ); rtol=1e-3)
@test isapprox(true_mean, mean(Σ); rtol=1e-3)

kb, query = inference_problem(composite, box_map)
@test query == Set([:x21, :x22, :x23, :x24, :x25, :x26])

jt = architecture(kb, minfill!(primal_graph(kb), query))
@test_throws ErrorException("Query not covered by join tree.") answer_query(jt, [:x31])
@test_throws ErrorException("Query not covered by join tree.") answer_query!(jt, [:x31])

ϕ = answer_query(jt, query)
M = [i == j for i in [:x21, :x22, :x23, :x24, :x25, :x26], j in ϕ.labels]
@test length(ϕ) == length(query)
@test Set(domain(ϕ)) == query
@test isapprox(true_cov, M * cov.box) * M'; rtol=1e-3)
@test isapprox(true_mean, M * mean.box); rtol=1e-3)

jt = architecture(kb, minwidth!(primal_graph(kb), query))
ϕ = answer_query!(jt, query)
M = [i == j for i in [:x21, :x22, :x23, :x24, :x25, :x26], j in ϕ.labels]
@test length(ϕ) == length(query)
@test Set(domain(ϕ)) == query
@test isapprox(true_cov, M * cov.box) * M'; rtol=1e-3)
@test isapprox(true_mean, M * mean.box); rtol=1e-3)

jt = architecture(kb, minwidth!(primal_graph(kb), []))

ϕ = answer_query(jt, query)
M = [i == j for i in [:x21, :x22, :x23, :x24, :x25, :x26], j in ϕ.labels]
@test length(ϕ) == length(query)
@test Set(domain(ϕ)) == query
@test isapprox(true_cov, M * cov.box) * M'; rtol=1e-3)
@test isapprox(true_mean, M * mean.box); rtol=1e-3)

ϕ = answer_query!(jt, query)
M = [i == j for i in [:x21, :x22, :x23, :x24, :x25, :x26], j in ϕ.labels]
@test length(ϕ) == length(query)
@test Set(domain(ϕ)) == query
@test isapprox(true_cov, M * cov.box) * M'; rtol=1e-3)
@test isapprox(true_mean, M * mean.box); rtol=1e-3)
end

@testset "Identity Valuation" begin
ϕ = LabeledBox([:x, :y], normal([1 0; 0 1]))
e = IdentityValuation{Symbol}()

@test isempty(domain(e))
@test eltype(domain(e)) == Symbol
@test combine(e, e) === e
@test combine(ϕ, e) === ϕ
@test combine(e, ϕ) === ϕ
@test project(e, []) === e
end

0 comments on commit e78a3ba

Please sign in to comment.