From 8d1b5f9a2c3253def2fca0047355fa3dbc20fa6c Mon Sep 17 00:00:00 2001 From: Richard Samuelson Date: Sun, 10 Sep 2023 16:17:09 -0700 Subject: [PATCH] Type `EliminationTree` subtypes `AbstractVector{Vector{Int}}`. Also, more tests. --- src/architectures.jl | 4 +- src/elimination.jl | 4 +- src/systems.jl | 51 ++++------- test/runtests.jl | 203 ++++++++++++++++++++++++++++++++++--------- 4 files changed, 184 insertions(+), 78 deletions(-) diff --git a/src/architectures.jl b/src/architectures.jl index 76efedc..ee85c2a 100644 --- a/src/architectures.jl +++ b/src/architectures.jl @@ -1,4 +1,4 @@ -# A mailbox in the Shanoy-Shafer architecture. +# A mailbox in the Shenoy-Shafer architecture. mutable struct SSMailbox{T₁, T₂} factor::Union{Nothing, Factor{T₁, T₂}} message_to_parent::Union{Nothing, Factor{T₁, T₂}} @@ -107,7 +107,7 @@ end # Compute the join tree factor -# ψ(v) +# ψᵥ function factor!(arch::SSArchitecture{<:Any, T₁, T₂}, v::Int) where {T₁, T₂} mbx = arch.mailboxes[v] diff --git a/src/elimination.jl b/src/elimination.jl index 12e28be..43d7e82 100644 --- a/src/elimination.jl +++ b/src/elimination.jl @@ -54,14 +54,14 @@ end # An elimination tree. -struct EliminationTree +struct EliminationTree <: AbstractVector{Vector{Int}} rootindex::Int parent::Vector{Int} # pa(v) children::Vector{Vector{Int}} # ch(v) neighbors::Vector{Vector{Int}} # adj⁺(v) end -# Return `true` if +# Determine if # v₁ < v₂ # in the given order. function (order::EliminationOrder)(v₁::Int, v₂::Int) diff --git a/src/systems.jl b/src/systems.jl index eb43762..d890be9 100644 --- a/src/systems.jl +++ b/src/systems.jl @@ -70,55 +70,40 @@ function GaussianSystem(P::T₁, S::T₂, p::T₃, s::T₄, σ::T₅) where { end -function GaussianSystem{T₁, T₂, T₃, T₄, T₅}(d::MvNormalCanon) where { - T₁, T₂, T₃, T₄, T₅} +function GaussianSystem(d::MvNormalCanon) + CanonicalForm(d.J, d.h) +end - Σ = CanonicalForm(d.J, d.h) - - convert(GaussianSystem{T₁, T₂, T₃, T₄, T₅}, Σ) + +function GaussianSystem(d::NormalCanon) + CanonicalForm([d.λ;;], [d.η]) end -function GaussianSystem{T₁, T₂, T₃, T₄, T₅}(d::MvNormal) where { - T₁, T₂, T₃, T₄, T₅} - - Σ = normal(d.μ, d.Σ) - - convert(GaussianSystem{T₁, T₂, T₃, T₄, T₅}, Σ) +function GaussianSystem(d::MvNormal) + normal(d.μ, d.Σ) end -function GaussianSystem{T₁, T₂, T₃, T₄, T₅}(d::NormalCanon) where { - T₁, T₂, T₃, T₄, T₅} - - Σ = CanonicalForm([d.λ;;], [d.η]) - - convert(GaussianSystem{T₁, T₂, T₃, T₄, T₅}, Σ) -end +function GaussianSystem(d::Normal) + normal(d.μ, d.σ) +end -function GaussianSystem{T₁, T₂, T₃, T₄, T₅}(d::Normal) where { - T₁, T₂, T₃, T₄, T₅} - - Σ = normal(d.μ, d.σ) - - convert(GaussianSystem{T₁, T₂, T₃, T₄, T₅}, Σ) +function GaussianSystem(cpd::LinearGaussianCPD) + kernel(cpd.a, cpd.b, cpd.σ) end -function GaussianSystem{T₁, T₂, T₃, T₄, T₅}(cpd::LinearGaussianCPD) where { - T₁, T₂, T₃, T₄, T₅} - - Σ = normal(cpd.b, cpd.σ) * [-cpd.a' I] - - convert(GaussianSystem{T₁, T₂, T₃, T₄, T₅}, Σ) -end +function GaussianSystem(cpd::StaticCPD) + GaussianSystem(cpd.d) +end -function GaussianSystem{T₁, T₂, T₃, T₄, T₅}(cpd::StaticCPD) where { +function GaussianSystem{T₁, T₂, T₃, T₄, T₅}(Σ) where { T₁, T₂, T₃, T₄, T₅} - GaussianSystem{T₁, T₂, T₃, T₄, T₅}(cpd.d) + convert(GaussianSystem{T₁, T₂, T₃, T₄, T₅}, GaussianSystem(Σ)) end diff --git a/test/runtests.jl b/test/runtests.jl index d638c96..c4a76e6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,47 +9,168 @@ using Test @testset "Construction" begin - Σ = normal([3, 1], [1 1; 1 1]) - @test Σ.P ≈ [1/4 1/4; 1/4 1/4] - @test Σ.S ≈ [1/2 -1/2; -1/2 1/2] - @test Σ.p ≈ [1, 1] - @test Σ.s ≈ [1, -1] - @test Σ.σ ≈ 2 - - Σ = normal([3, 1], Eye(2)) - @test Σ.P == [1 0; 0 1] - @test Σ.S == [0 0; 0 0] - @test Σ.p == [3, 1] - @test Σ.s == [0, 0] - @test Σ.σ == 0 - - Σ = normal([3, 1], Zeros(2, 2)) - @test Σ.P == [0 0; 0 0] - @test Σ.S == [1 0; 0 1] - @test Σ.p == [0, 0] - @test Σ.s == [3, 1] - @test Σ.σ == 10 - - Σ = normal(1, 1/2) - @test Σ.P == [4;;] - @test Σ.S == [0;;] - @test Σ.p == [4] - @test Σ.s == [0] - @test Σ.σ == 0 - - Σ = kernel([1 0; 0 1], [3, 1], [1 1; 1 1]) - @test Σ.P ≈ [1/4 1/4 -1/4 -1/4; 1/4 1/4 -1/4 -1/4; -1/4 -1/4 1/4 1/4; -1/4 -1/4 1/4 1/4] - @test Σ.S ≈ [1/2 -1/2 -1/2 1/2; -1/2 1/2 1/2 -1/2; -1/2 1/2 1/2 -1/2; 1/2 -1/2 -1/2 1/2] - @test Σ.p ≈ [-1, -1, 1, 1] - @test Σ.s ≈ [-1, 1, 1, -1] - @test Σ.σ ≈ 2 - - Σ = kernel([1], 1, 1/2) - @test Σ.P == [4 -4; -4 4] - @test Σ.S == [0 0; 0 0] - @test Σ.p == [-4, 4] - @test Σ.s == [ 0, 0] - @test Σ.σ == 0 + @testset "GaussianSystem" begin + d = MvNormalCanon([3, 1], [3 1; 1 2]) + Σ = GaussianSystem(d) + + @test Σ.P == [ + 3 1 + 1 2 + ] + + @test Σ.p == [3, 1] + + @test iszero(Σ.S) + @test iszero(Σ.s) + @test iszero(Σ.σ) + + d = NormalCanon(1, 2) + Σ = GaussianSystem(d) + + @test Σ.P == [2;;] + @test Σ.p == [1] + + @test iszero(Σ.S) + @test iszero(Σ.s) + @test iszero(Σ.σ) + + d = MvNormal([3, 1], [3 1; 1 2]) + Σ = GaussianSystem(d) + + @test Σ.P ≈ [ + 2/5 -1/5 + -1/5 3/5 + ] + + @test Σ.p ≈ [1, 0] + + @test iszero(Σ.S) + @test iszero(Σ.s) + @test iszero(Σ.σ) + + d = Normal(1, √2) + Σ = GaussianSystem(d) + + @test Σ.P ≈ [1/2;;] + @test Σ.p ≈ [1/2] + + @test iszero(Σ.S) + @test iszero(Σ.s) + @test iszero(Σ.σ) + + cpd = LinearGaussianCPD(:z, [:x, :y], [1, 2], 1, √2) + Σ = GaussianSystem(cpd) + + @test Σ.P ≈ [ + 0.5 1.0 -0.5 + 1.0 2.0 -1.0 + -0.5 -1.0 0.5 + ] + + @test Σ.p ≈ [-0.5, -1.0, 0.5] + + @test iszero(Σ.S) + @test iszero(Σ.s) + @test iszero(Σ.σ) + + cpd = StaticCPD(:x, Normal(1, √2)) + Σ = GaussianSystem(cpd) + + @test Σ.P ≈ [1/2;;] + @test Σ.p ≈ [1/2] + + @test iszero(Σ.S) + @test iszero(Σ.s) + @test iszero(Σ.σ) + end + + @testset "normal" begin + Σ = normal([3, 1], [1 1; 1 1]) + + @test Σ.P ≈ [ + 1/4 1/4 + 1/4 1/4 + ] + + @test Σ.S ≈ [ + 1/2 -1/2 + -1/2 1/2 + ] + + @test Σ.p ≈ [1, 1] + @test Σ.s ≈ [1, -1] + @test Σ.σ ≈ 2 + + Σ = normal([3, 1], Eye(2)) + + @test Σ.P == [ + 1 0 + 0 1 + ] + + @test Σ.p == [3, 1] + + @test iszero(Σ.S) + @test iszero(Σ.s) + @test iszero(Σ.σ) + + Σ = normal([3, 1], Zeros(2, 2)) + + @test Σ.S == [ + 1 0 + 0 1 + ] + + @test Σ.s == [3, 1] + @test Σ.σ == 10 + + @test iszero(Σ.P) + @test iszero(Σ.p) + + Σ = normal(1, 1/2) + + @test Σ.P == [4;;] + @test Σ.p == [4] + + @test iszero(Σ.S) + @test iszero(Σ.s) + @test iszero(Σ.σ) + end + + @testset "kernel" begin + Σ = kernel([1 0; 0 1], [3, 1], [1 1; 1 1]) + + @test Σ.P ≈ [ + 1/4 1/4 -1/4 -1/4 + 1/4 1/4 -1/4 -1/4 + -1/4 -1/4 1/4 1/4 + -1/4 -1/4 1/4 1/4 + ] + + @test Σ.S ≈ [ + 1/2 -1/2 -1/2 1/2 + -1/2 1/2 1/2 -1/2 + -1/2 1/2 1/2 -1/2 + 1/2 -1/2 -1/2 1/2 + ] + + @test Σ.p ≈ [-1, -1, 1, 1] + @test Σ.s ≈ [-1, 1, 1, -1] + @test Σ.σ ≈ 2 + + Σ = kernel([1], 1, 1/2) + + @test Σ.P == [ + 4 -4 + -4 4 + ] + + @test Σ.p == [-4, 4] + + @test iszero(Σ.S) + @test iszero(Σ.s) + @test iszero(Σ.σ) + end end