Skip to content

Commit

Permalink
Type EliminationTree subtypes AbstractVector{Vector{Int}}. Also, …
Browse files Browse the repository at this point in the history
…more tests.
  • Loading branch information
samuelsonric committed Sep 10, 2023
1 parent 5e37a32 commit 8d1b5f9
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 78 deletions.
4 changes: 2 additions & 2 deletions src/architectures.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# A mailbox in the Shanoy-Shafer architecture.
# A mailbox in the Shenoy-Shafer architecture.
mutable struct SSMailbox{T₁, T₂}
factor::Union{Nothing, Factor{T₁, T₂}}
message_to_parent::Union{Nothing, Factor{T₁, T₂}}
Expand Down Expand Up @@ -107,7 +107,7 @@ end


# Compute the join tree factor
# ψ(v)
# ψᵥ
function factor!(arch::SSArchitecture{<:Any, T₁, T₂}, v::Int) where {T₁, T₂}
mbx = arch.mailboxes[v]

Expand Down
4 changes: 2 additions & 2 deletions src/elimination.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ end


# An elimination tree.
struct EliminationTree
struct EliminationTree <: AbstractVector{Vector{Int}}
rootindex::Int
parent::Vector{Int} # pa(v)
children::Vector{Vector{Int}} # ch(v)
neighbors::Vector{Vector{Int}} # adj⁺(v)
end

# Return `true` if
# Determine if
# v₁ < v₂
# in the given order.
function (order::EliminationOrder)(v₁::Int, v₂::Int)
Expand Down
51 changes: 18 additions & 33 deletions src/systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,55 +70,40 @@ function GaussianSystem(P::T₁, S::T₂, p::T₃, s::T₄, σ::T₅) where {
end


function GaussianSystem{T₁, T₂, T₃, T₄, T₅}(d::MvNormalCanon) where {
T₁, T₂, T₃, T₄, T₅}
function GaussianSystem(d::MvNormalCanon)
CanonicalForm(d.J, d.h)
end

Σ = CanonicalForm(d.J, d.h)
convert(GaussianSystem{T₁, T₂, T₃, T₄, T₅}, Σ)

function GaussianSystem(d::NormalCanon)
CanonicalForm([d.λ;;], [d.η])
end


function GaussianSystem{T₁, T₂, T₃, T₄, T₅}(d::MvNormal) where {
T₁, T₂, T₃, T₄, T₅}

Σ = normal(d.μ, d.Σ)

convert(GaussianSystem{T₁, T₂, T₃, T₄, T₅}, Σ)
function GaussianSystem(d::MvNormal)
normal(d.μ, d.Σ)
end


function GaussianSystem{T₁, T₂, T₃, T₄, T₅}(d::NormalCanon) where {
T₁, T₂, T₃, T₄, T₅}

Σ = CanonicalForm([d.λ;;], [d.η])

convert(GaussianSystem{T₁, T₂, T₃, T₄, T₅}, Σ)
end
function GaussianSystem(d::Normal)
normal(d.μ, d.σ)
end


function GaussianSystem{T₁, T₂, T₃, T₄, T₅}(d::Normal) where {
T₁, T₂, T₃, T₄, T₅}

Σ = normal(d.μ, d.σ)

convert(GaussianSystem{T₁, T₂, T₃, T₄, T₅}, Σ)
function GaussianSystem(cpd::LinearGaussianCPD)
kernel(cpd.a, cpd.b, cpd.σ)
end


function GaussianSystem{T₁, T₂, T₃, T₄, T₅}(cpd::LinearGaussianCPD) where {
T₁, T₂, T₃, T₄, T₅}

Σ = normal(cpd.b, cpd.σ) * [-cpd.a' I]

convert(GaussianSystem{T₁, T₂, T₃, T₄, T₅}, Σ)
end
function GaussianSystem(cpd::StaticCPD)
GaussianSystem(cpd.d)
end


function GaussianSystem{T₁, T₂, T₃, T₄, T₅}(cpd::StaticCPD) where {
function GaussianSystem{T₁, T₂, T₃, T₄, T₅}(Σ) where {
T₁, T₂, T₃, T₄, T₅}

GaussianSystem{T₁, T₂, T₃, T₄, T₅}(cpd.d)
convert(GaussianSystem{T₁, T₂, T₃, T₄, T₅}, GaussianSystem(Σ))
end


Expand Down
203 changes: 162 additions & 41 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,47 +9,168 @@ using Test


@testset "Construction" begin
Σ = normal([3, 1], [1 1; 1 1])
@test Σ.P [1/4 1/4; 1/4 1/4]
@test Σ.S [1/2 -1/2; -1/2 1/2]
@test Σ.p [1, 1]
@test Σ.s [1, -1]
@test Σ.σ 2

Σ = normal([3, 1], Eye(2))
@test Σ.P == [1 0; 0 1]
@test Σ.S == [0 0; 0 0]
@test Σ.p == [3, 1]
@test Σ.s == [0, 0]
@test Σ.σ == 0

Σ = normal([3, 1], Zeros(2, 2))
@test Σ.P == [0 0; 0 0]
@test Σ.S == [1 0; 0 1]
@test Σ.p == [0, 0]
@test Σ.s == [3, 1]
@test Σ.σ == 10

Σ = normal(1, 1/2)
@test Σ.P == [4;;]
@test Σ.S == [0;;]
@test Σ.p == [4]
@test Σ.s == [0]
@test Σ.σ == 0

Σ = kernel([1 0; 0 1], [3, 1], [1 1; 1 1])
@test Σ.P [1/4 1/4 -1/4 -1/4; 1/4 1/4 -1/4 -1/4; -1/4 -1/4 1/4 1/4; -1/4 -1/4 1/4 1/4]
@test Σ.S [1/2 -1/2 -1/2 1/2; -1/2 1/2 1/2 -1/2; -1/2 1/2 1/2 -1/2; 1/2 -1/2 -1/2 1/2]
@test Σ.p [-1, -1, 1, 1]
@test Σ.s [-1, 1, 1, -1]
@test Σ.σ 2

Σ = kernel([1], 1, 1/2)
@test Σ.P == [4 -4; -4 4]
@test Σ.S == [0 0; 0 0]
@test Σ.p == [-4, 4]
@test Σ.s == [ 0, 0]
@test Σ.σ == 0
@testset "GaussianSystem" begin
d = MvNormalCanon([3, 1], [3 1; 1 2])
Σ = GaussianSystem(d)

@test Σ.P == [
3 1
1 2
]

@test Σ.p == [3, 1]

@test iszero.S)
@test iszero.s)
@test iszero.σ)

d = NormalCanon(1, 2)
Σ = GaussianSystem(d)

@test Σ.P == [2;;]
@test Σ.p == [1]

@test iszero.S)
@test iszero.s)
@test iszero.σ)

d = MvNormal([3, 1], [3 1; 1 2])
Σ = GaussianSystem(d)

@test Σ.P [
2/5 -1/5
-1/5 3/5
]

@test Σ.p [1, 0]

@test iszero.S)
@test iszero.s)
@test iszero.σ)

d = Normal(1, 2)
Σ = GaussianSystem(d)

@test Σ.P [1/2;;]
@test Σ.p [1/2]

@test iszero.S)
@test iszero.s)
@test iszero.σ)

cpd = LinearGaussianCPD(:z, [:x, :y], [1, 2], 1, 2)
Σ = GaussianSystem(cpd)

@test Σ.P [
0.5 1.0 -0.5
1.0 2.0 -1.0
-0.5 -1.0 0.5
]

@test Σ.p [-0.5, -1.0, 0.5]

@test iszero.S)
@test iszero.s)
@test iszero.σ)

cpd = StaticCPD(:x, Normal(1, 2))
Σ = GaussianSystem(cpd)

@test Σ.P [1/2;;]
@test Σ.p [1/2]

@test iszero.S)
@test iszero.s)
@test iszero.σ)
end

@testset "normal" begin
Σ = normal([3, 1], [1 1; 1 1])

@test Σ.P [
1/4 1/4
1/4 1/4
]

@test Σ.S [
1/2 -1/2
-1/2 1/2
]

@test Σ.p [1, 1]
@test Σ.s [1, -1]
@test Σ.σ 2

Σ = normal([3, 1], Eye(2))

@test Σ.P == [
1 0
0 1
]

@test Σ.p == [3, 1]

@test iszero.S)
@test iszero.s)
@test iszero.σ)

Σ = normal([3, 1], Zeros(2, 2))

@test Σ.S == [
1 0
0 1
]

@test Σ.s == [3, 1]
@test Σ.σ == 10

@test iszero.P)
@test iszero.p)

Σ = normal(1, 1/2)

@test Σ.P == [4;;]
@test Σ.p == [4]

@test iszero.S)
@test iszero.s)
@test iszero.σ)
end

@testset "kernel" begin
Σ = kernel([1 0; 0 1], [3, 1], [1 1; 1 1])

@test Σ.P [
1/4 1/4 -1/4 -1/4
1/4 1/4 -1/4 -1/4
-1/4 -1/4 1/4 1/4
-1/4 -1/4 1/4 1/4
]

@test Σ.S [
1/2 -1/2 -1/2 1/2
-1/2 1/2 1/2 -1/2
-1/2 1/2 1/2 -1/2
1/2 -1/2 -1/2 1/2
]

@test Σ.p [-1, -1, 1, 1]
@test Σ.s [-1, 1, 1, -1]
@test Σ.σ 2

Σ = kernel([1], 1, 1/2)

@test Σ.P == [
4 -4
-4 4
]

@test Σ.p == [-4, 4]

@test iszero.S)
@test iszero.s)
@test iszero.σ)
end
end


Expand Down

0 comments on commit 8d1b5f9

Please sign in to comment.