Skip to content

Commit

Permalink
Can construct a GraphicalModel without specifying types. The Shenoy-S…
Browse files Browse the repository at this point in the history
…hafer architecture is no longer implemented as a recursive algorithm. Factors store objects.
  • Loading branch information
samuelsonric committed Sep 7, 2023
1 parent 0bacd4e commit 41a4e3f
Show file tree
Hide file tree
Showing 19 changed files with 1,271 additions and 938 deletions.
13 changes: 3 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ wd = @relation (x,) where (x::X, y::Y) begin
evidence(y)
end

hom_map = Dict(
hom_map = Dict{Symbol, DenseGaussianSystem{Float64}}(
:prior => normal(0, 1), # x ~ N(0, 1)
:likelihood => kernel([1], 0, 1), # y | x ~ N(x, 1)
:evidence => normal(2, 0)) # y = 2
Expand All @@ -36,15 +36,8 @@ ob_attr = :junction_type
Σ = oapply(wd, hom_map, ob_map; ob_attr)

# Solve using belief propagation.
T₁ = Int
T₂ = DenseGaussianSystem{Float64}
T₃ = Int
T₄ = Vector{Float64}

ip = InferenceProblem{T₁, T₂, T₃, T₄}(wd, hom_map, ob_map; ob_attr)
alg = MinFill()

Σ = solve(ip, alg)
ip = InferenceProblem(wd, hom_map, ob_map; ob_attr)
Σ = solve(ip, MinFill())
```

![inference](./inference.svg)
14 changes: 4 additions & 10 deletions docs/literate/kalman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ n = 100; kf = kalman(n); data = generate_data(n)

evidence = Dict("z$i" => normal(data[i], Zeros(2, 2)) for i in 1:n)

hom_map = Dict(
hom_map = Dict{String, DenseGaussianSystem{Float64}}(
evidence...,
"state" => normal(Zeros(2), 100I(2)),
"predict" => kernel(A, Zeros(2), P),
Expand All @@ -94,14 +94,8 @@ mean(oapply(kf, hom_map, ob_map; ob_attr))
#
@benchmark oapply(kf, hom_map, ob_map; ob_attr)
# Since the filtering problem is large, we may wish to solve it using belief propagation.
T₁ = Int
T₂ = DenseGaussianSystem{Float64}
T₃ = Int
T₄ = Vector{Float64}
ip = InferenceProblem(kf, hom_map, ob_map; ob_attr)

ip = InferenceProblem{T₁, T₂, T₃, T₄}(kf, hom_map, ob_map; ob_attr)
is = init(ip, MinFill())

mean(solve(is))
mean(solve(ip, MinFill()))
#
@benchmark solve(is)
@benchmark solve(ip, MinFill())
10 changes: 6 additions & 4 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@

```@docs
GaussianSystem
CanonicalForm
GaussianSystem(::AbstractMatrix, ::AbstractMatrix, ::AbstractVector, ::AbstractVector, ::Real)
CanonicalForm(::AbstractMatrix, ::AbstractVector)
normal
kernel
Expand All @@ -21,9 +24,9 @@ oapply(::AbstractUWD, ::AbstractVector{<:GaussianSystem}, ::AbstractVector)
```@docs
InferenceProblem
InferenceProblem{T₁, T₂, T₃, T₄}(::AbstractUWD, ::AbstractDict, ::AbstractDict) where {T₁, T₂, T₃, T₄}
InferenceProblem{T₁, T₂, T₃, T₄}(::AbstractUWD, ::AbstractVector, ::AbstractVector) where {T₁, T₂, T₃, T₄}
InferenceProblem{T₁, T₂, T₃, T₄}(::BayesNet, ::AbstractVector, ::AbstractDict) where {T₁, T₂, T₃, T₄}
InferenceProblem(::AbstractUWD, ::AbstractDict, ::AbstractDict)
InferenceProblem(::AbstractUWD, ::AbstractVector, ::AbstractVector)
InferenceProblem(::BayesNet, ::AbstractVector, ::AbstractDict)
solve(::InferenceProblem, alg::EliminationAlgorithm)
init(::InferenceProblem, alg::EliminationAlgorithm)
Expand All @@ -34,7 +37,6 @@ init(::InferenceProblem, alg::EliminationAlgorithm)
```@docs
InferenceSolver
solve(::InferenceSolver)
solve!(::InferenceSolver)
```

Expand Down
14 changes: 4 additions & 10 deletions docs/src/generated/kalman.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ n = 100; kf = kalman(n); data = generate_data(n)
evidence = Dict("z$i" => normal(data[i], Zeros(2, 2)) for i in 1:n)
hom_map = Dict(
hom_map = Dict{String, DenseGaussianSystem{Float64}}(
evidence...,
"state" => normal(Zeros(2), 100I(2)),
"predict" => kernel(A, Zeros(2), P),
Expand All @@ -119,18 +119,12 @@ mean(oapply(kf, hom_map, ob_map; ob_attr))
Since the filtering problem is large, we may wish to solve it using belief propagation.

````@example kalman
T₁ = Int
T₂ = DenseGaussianSystem{Float64}
T₃ = Int
T₄ = Vector{Float64}
ip = InferenceProblem(kf, hom_map, ob_map; ob_attr)
ip = InferenceProblem{T₁, T₂, T₃, T₄}(kf, hom_map, ob_map; ob_attr)
is = init(ip, MinFill())
mean(solve(is))
mean(solve(ip, MinFill()))
````

````@example kalman
@benchmark solve(is)
@benchmark solve(ip, MinFill())
````

18 changes: 14 additions & 4 deletions src/AlgebraicInference.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
module AlgebraicInference


# Systems
export CanonicalForm, DenseCanonicalForm, DenseGaussianSystem, GaussianSystem
export , cov, invcov, normal, kernel, mean, oapply, pushforward, var
export , cov, invcov, normal, kernel, mean, oapply, var


# Inference Problems
export InferenceProblem
export init


# Inference Solvers
export InferenceSolver
export solve, solve!


# Algorithms
export EliminationAlgorithm, MinDegree, MinFill


using AbstractTrees
using BayesNets
using Catlab.ACSetInterface, Catlab.Graphs, Catlab.Programs, Catlab.Theories,
Expand All @@ -24,25 +29,30 @@ using FillArrays
using LinearAlgebra
using LinearSolve


using Base: OneTo
using FillArrays: SquareEye, ZerosMatrix, ZerosVector
using LinearAlgebra: checksquare


import AbstractTrees
import Catlab
import CommonSolve
import Distributions
import Graphs
import Statistics


include("./kkt.jl")
include("./systems.jl")
include("./factors.jl")
include("./graphs.jl")
include("./trees.jl")
include("./labels.jl")
include("./models.jl")
include("./algorithms.jl")
include("./elimination.jl")
include("./architectures.jl")
include("./problems.jl")
include("./solvers.jl")
include("./utils.jl")


end
21 changes: 0 additions & 21 deletions src/algorithms.jl

This file was deleted.

178 changes: 178 additions & 0 deletions src/architectures.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# A mailbox in the Shanoy-Shafer architecture.
mutable struct SSMailbox{T₁, T₂}
factor::Union{Nothing, Factor{T₁, T₂}}
message_to_parent::Union{Nothing, Factor{T₁, T₂}}
message_from_parent::Union{Nothing, Factor{T₁, T₂}}
end


# The Shenoy-Shafer architecture.
mutable struct SSArchitecture{T₁, T₂, T₃}
labels::Labels{T₁}
factors::Vector{Factor{T₂, T₃}}
order::EliminationOrder
tree::EliminationTree
v_to_fs::Vector{Vector{Int}}
mailboxes::Vector{SSMailbox{T₂, T₃}}
mailboxes_full::Bool
end


function SSMailbox{T₁, T₂}() where {T₁, T₂}
SSMailbox{T₁, T₂}(nothing, nothing, nothing)
end


function SSArchitecture(
labels::Labels,
factors::Vector{Factor{T₁, T₂}},
order::EliminationOrder,
tree::EliminationTree,
v_to_fs::Vector{Vector{Int}}) where {T₁, T₂}

mailboxes = [SSMailbox{T₁, T₂}() for _ in 1:length(labels)]
mailboxes_full = false

SSArchitecture(
labels,
factors,
order,
tree,
v_to_fs,
mailboxes,
mailboxes_full)
end


function SSArchitecture(model::GraphicalModel, order::EliminationOrder)
labels = copy(model.labels)
factors = copy(model.factors)
v_to_fs = deepcopy(model.v_to_fs)

tree = EliminationTree(model.graph, order)

for v in order, f in v_to_fs[v], w in factors[f].vars
if v != w
setdiff!(v_to_fs[w], f)
end
end

SSArchitecture(labels, factors, order, tree, v_to_fs)
end


# Answer a query.
function CommonSolve.solve!(arch::SSArchitecture, query)
if !arch.mailboxes_full
fill_mailboxes!(arch)
end

vars = [arch.labels.index[l] for l in query]

for v in arch.order
if vars [v; arch.tree[v]]
fac = factor!(arch, v)

for t in childindices(arch.tree, v)
msg = message_to_parent!(arch, t)
fac = combine(fac, msg)
end

if v != rootindex(arch.tree)
msg = message_from_parent!(arch, v)
fac = combine(fac, msg)
end

fac = project(fac, vars)

return permute(fac, vars)
end
end

error("Query not covered by join tree.")
end


function fill_mailboxes!(arch::SSArchitecture)
for v in arch.order[1:end - 1] # Collect phase
message_to_parent!(arch, v)
end

for v in arch.order[end - 1:-1:1] # Distribute phase
message_from_parent!(arch, v)
end

arch.mailboxes_full = true
end


# Compute the join tree factor
# ψ(v)
function factor!(arch::SSArchitecture{T₁, T₂, T₃}, v::Int) where {T₁, T₂, T₃}
mbx = arch.mailboxes[v]

if isnothing(mbx.factor)
fac = zero(Factor{T₂, T₃})

for f in arch.v_to_fs[v]
fac = combine(fac, arch.factors[f])
end

mbx.factor = fac
end

mbx.factor::Factor{T₂, T₃}
end


# Compute the message
# μ v → pa(v)
function message_to_parent!(arch::SSArchitecture{T₁, T₂, T₃}, v::Int) where {T₁, T₂, T₃}
@assert v != rootindex(arch.tree)

mbx = arch.mailboxes[v]

if isnothing(mbx.message_to_parent)
fac = factor!(arch, v)

for t in childindices(arch.tree, v)
msg = message_to_parent!(arch, t)
fac = combine(fac, msg)
end

mbx.message_to_parent = project(fac, arch.tree[v])
end

mbx.message_to_parent::Factor{T₂, T₃}
end


# Compute the message
# μ pa(v) → v
function message_from_parent!(arch::SSArchitecture{T₁, T₂, T₃}, v::Int) where {T₁, T₂, T₃}
@assert v != rootindex(arch.tree)

mbx = arch.mailboxes[v]

if isnothing(mbx.message_from_parent)
u = parentindex(arch.tree, v)

fac = factor!(arch, u)

for t in childindices(arch.tree, u)
if t != v
msg = message_to_parent!(arch, t)
fac = combine(fac, msg)
end
end

if u != rootindex(arch.tree)
msg = message_from_parent!(arch, u)
fac = combine(fac, msg)
end

mbx.message_from_parent = project(fac, arch.tree[v])
end

mbx.message_from_parent::Factor{T₂, T₃}
end
Loading

0 comments on commit 41a4e3f

Please sign in to comment.