Supernodal elimination trees, sampling, and the Lauritzen-Spiegelhalter architecture.
samuelsonric committed Sep 19, 2023
1 parent 8d1b5f9 commit 03ffb60
Showing 22 changed files with 1,605 additions and 592 deletions.
3 changes: 2 additions & 1 deletion Project.toml
@@ -16,13 +16,14 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Metis = "2679e427-3c69-5b7f-982b-ece356f1e94b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
AbstractTrees = "0.4"
AMD = "0.5"
AbstractTrees = "0.4"
BayesNets = "3.4"
Catlab = "0.15"
CommonSolve = "0.2"
142 changes: 94 additions & 48 deletions docs/literate/kalman.jl
@@ -6,20 +6,21 @@ using Distributions
using FillArrays
using LinearAlgebra
using Random
using StatsPlots
# A Kalman filter with ``n`` steps is a probability distribution over states
# ``(s_1, \dots, s_n)`` and measurements ``(z_1, \dots, z_n)`` determined by the equations
# ``(x_1, \dots, x_n)`` and measurements ``(y_1, \dots, y_n)`` determined by the equations
# ```math
# s_{i+1} \mid s_i \sim \mathcal{N}(As_i, P)
# p(x_{i+1} \mid x_i) = \mathcal{N}(Ax_i, P)
# ```
# and
# ```math
# z_i \mid s_i \sim \mathcal{N}(Bs_i, Q).
# p(y_i \mid x_i) = \mathcal{N}(Bx_i, Q).
# ```
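# Since the dynamics are linear and every factor is Gaussian, these conditionals
# determine a jointly Gaussian distribution
# ```math
# p(x_{1:n}, y_{1:n}) = \prod_{i=1}^{n} p(x_i \mid x_{i-1}) \, p(y_i \mid x_i),
# ```
# where ``x_0 = 0``, so that ``p(x_1 \mid x_0) = \mathcal{N}(0, P)``.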
θ = π / 15

A = [
cos(θ) -sin(θ)
sin(θ)  cos(θ)
]

B = [
@@ -37,73 +38,118 @@ Q = [
0.0 10.0
]

function generate_data(n; seed=42)

function generate_data(n::Integer; seed::Integer=42)
Random.seed!(seed)
x = zeros(2)
data = Vector{Float64}[]

data = Matrix{Float64}(undef, 2, n)

N₁ = MvNormal(P)
N₂ = MvNormal(Q)

x = Zeros(2)

for i in 1:n
x = rand(MvNormal(A * x, P))
push!(data, rand(MvNormal(B * x, Q)))
x = rand(N₁) + A * x
data[:, i] .= rand(N₂) + B * x
end

data
end;
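# For example, each column of the returned matrix holds one simulated
# measurement; a small usage sketch:
# ```julia
# sample = generate_data(5)
# size(sample) # (2, 5)
# ```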
# The *filtering problem* involves predicting the value of the state ``s_n`` given
# observations of ``(z_1, \dots, z_n)``. The function `kalman` constructs a wiring diagram
# that represents the filtering problem.
function kalman_step(i)
kf = HypergraphDiagram{String, String}(["X"])
add_box!(kf, ["X"]; name="state")
add_box!(kf, ["X", "X"]; name="predict")
add_box!(kf, ["X", "Z"]; name="measure")
add_box!(kf, ["Z"]; name="z$i")

add_wires!(kf, [
(0, 1) => (2, 2),
(1, 1) => (2, 1),
(1, 1) => (3, 1),
(3, 2) => (4, 1)])

kf
end
# The *prediction problem* involves finding the posterior mean and covariance of the state
# ``x_{n + 1}`` given observations of ``(y_1, \dots, y_n)``.
function make_diagram(n::Integer)
outer_ports = ["X"]

uwd = TypedRelationDiagram{String, String, Tuple{Int, Int}}(outer_ports)

x = add_junction!(uwd, "X"; variable=(1, 1))
y = add_junction!(uwd, "Y"; variable=(2, 1))

state = add_box!(uwd, ["X"]; name="state")
predict = add_box!(uwd, ["X", "X"]; name="predict")
measure = add_box!(uwd, ["X", "Y"]; name="measure")
context = add_box!(uwd, ["Y"]; name="y1")

set_junction!(uwd, (state, 1), x)
set_junction!(uwd, (predict, 1), x)
set_junction!(uwd, (measure, 1), x)
set_junction!(uwd, (measure, 2), y)
set_junction!(uwd, (context, 1), y)

for i in 2:n
x = add_junction!(uwd, "X"; variable=(1, i))
y = add_junction!(uwd, "Y"; variable=(2, i))

set_junction!(uwd, (predict, 2), x)

predict = add_box!(uwd, ["X", "X"]; name="predict")
measure = add_box!(uwd, ["X", "Y"]; name="measure")
context = add_box!(uwd, ["Y"]; name="y$i")

function kalman(n)
reduce((kf, i) -> ocompose(kalman_step(i), 1, kf), 2:n; init=kalman_step(1))
set_junction!(uwd, (predict, 1), x)
set_junction!(uwd, (measure, 1), x)
set_junction!(uwd, (measure, 2), y)
set_junction!(uwd, (context, 1), y)
end

i = n + 1
x = add_junction!(uwd, "X"; variable=(1, i))

set_junction!(uwd, (0, 1), x)
set_junction!(uwd, (predict, 2), x)

uwd
end

to_graphviz(kalman(5), box_labels=:name; implicit_junctions=true)
# We generate ``100`` points of data and solve the filtering problem.
n = 100; kf = kalman(n); data = generate_data(n)
to_graphviz(make_diagram(5), box_labels=:name; junction_labels=:variable)
# We generate ``100`` points of data and solve the prediction problem.
n = 100

evidence = Dict("z$i" => normal(data[i], Zeros(2, 2)) for i in 1:n)
diagram = make_diagram(n)

hom_map = Dict{String, DenseGaussianSystem{Float64}}(
evidence...,
"state" => normal(Zeros(2), 100I(2)),
"predict" => kernel(A, Zeros(2), P),
"measure" => kernel(B, Zeros(2), Q))

ob_map = Dict(
"X" => 2,
"Z" => 2)
"Y" => 2)

ob_attr = :junction_type
data = generate_data(n)

Σ = oapply(kf, hom_map, ob_map; ob_attr)
for i in 1:n
hom_map["y$i"] = normal(data[:, i], Zeros(2, 2))
end

μ = mean(Σ)
problem = InferenceProblem(diagram, hom_map, ob_map)

round.(μ; digits=4)
#
@benchmark oapply(kf, hom_map, ob_map; ob_attr)
# Since the filtering problem is large, we may wish to solve it using belief propagation.
ip = InferenceProblem(kf, hom_map, ob_map; ob_attr)
solver = init(problem)

Σ = solve(ip, MinFill())
Σ = solve!(solver)

μ = mean(Σ)
m = mean(Σ)

round.(μ; digits=4)
round.(m; digits=4)
#
@benchmark solve(ip, MinFill())
V = cov(Σ)

round.(V; digits=4)
# The *smoothing problem* involves finding the posterior means and covariances of the states
# ``(x_1, \dots, x_{n - 1})`` given observations of ``(y_1, \dots, y_n)``.
#
# Calling `mean(solver)` computes a dictionary with the posterior mean of every variable in
# the model.
ms = mean(solver)

x = Matrix{Float64}(undef, 2, n)
y = Matrix{Float64}(undef, 2, n)

for i in 1:n
x[:, i] .= ms[1, i]
y[:, i] .= ms[2, i]
end

plot()
plot!(x[1, :], label="x₁")
plot!(x[2, :], label="x₂")
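# A joint posterior sample can also be drawn with `rand`. A sketch, assuming the
# returned dictionary is keyed by variable like `mean(solver)` above:
# ```julia
# s = rand(Random.default_rng(), solver)
# xs = reduce(hcat, [s[1, i] for i in 1:n]) # one sampled state trajectory
# plot!(xs[1, :], label="sampled x₁", linestyle=:dash)
# ```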
19 changes: 12 additions & 7 deletions docs/literate/regression.jl
@@ -62,7 +62,7 @@ P = [
0 0 1 0 0 1
]

hom_map = Dict(
hom_map = Dict{Symbol, DenseGaussianSystem{Float64}}(
:X => kernel(X, Zeros(3), Zeros(3, 3)),
:+ => kernel(P, Zeros(3), Zeros(3, 3)),
:ε => normal(Zeros(3), W),
Expand All @@ -72,9 +72,11 @@ ob_map = Dict(
:m => 2,
:n => 3)

ob_attr = :junction_type
problem = InferenceProblem(wd, hom_map, ob_map)

β̂ = mean(oapply(wd, hom_map, ob_map; ob_attr))
Σ̂ = solve(problem)

β̂ = mean(Σ̂)

round.(β̂; digits=4)
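# As a sanity check: with a flat prior, the posterior mean should agree with the
# generalized least squares estimate. A sketch, assuming the design matrix `X`
# and observations `y` defined earlier in this file, and that `W` is invertible:
# ```julia
# β̂ ≈ (X' * inv(W) * X) \ (X' * inv(W) * y)
# ```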
# ## Bayesian Linear Regression
@@ -116,7 +118,7 @@ end

to_graphviz(wd; box_labels=:name, implicit_junctions=true)
# Then we assign values to the boxes in `wd` and compute the result.
hom_map = Dict(
hom_map = Dict{Symbol, DenseGaussianSystem{Float64}}(
:β => normal(m, V),
:X => kernel(X, Zeros(3), Zeros(3, 3)),
:+ => kernel(P, Zeros(3), Zeros(3, 3)),
@@ -127,15 +129,18 @@ ob_map = Dict(
:m => 2,
:n => 3)

ob_attr = :junction_type
problem = InferenceProblem(wd, hom_map, ob_map)

Σ̂ = solve(problem)

m̂ = mean(oapply(wd, hom_map, ob_map; ob_attr))
m̂ = mean(Σ̂)

round.(m̂; digits=4)
#
V̂ = cov(oapply(wd, hom_map, ob_map; ob_attr))
V̂ = cov(Σ̂)

round.(V̂; digits=4)
#
plot()
covellipse!(m, V, aspect_ratio=:equal, label="prior")
covellipse!(m̂, V̂, aspect_ratio=:equal, label="posterior")
1 change: 1 addition & 0 deletions docs/make.jl
@@ -3,6 +3,7 @@ using BayesNets
using Catlab, Catlab.WiringDiagrams
using Documenter
using Literate
using Random

for file in readdir(joinpath(@__DIR__, "literate"))
Literate.markdown(
Expand Down
28 changes: 20 additions & 8 deletions docs/src/api.md
@@ -1,9 +1,11 @@
# Library Reference
## Systems

```@docs
GaussianSystem
CanonicalForm
DenseGaussianSystem
DenseCanonicalForm
GaussianSystem(::AbstractMatrix, ::AbstractMatrix, ::AbstractVector, ::AbstractVector, ::Real)
CanonicalForm(::AbstractMatrix, ::AbstractVector)
@@ -16,35 +18,45 @@ cov(::GaussianSystem)
invcov(::GaussianSystem)
var(::GaussianSystem)
mean(::GaussianSystem)
oapply(::AbstractUWD, ::AbstractVector{<:GaussianSystem}, ::AbstractVector)
```

## Problems
```@docs
InferenceProblem
InferenceProblem(::AbstractUWD, ::AbstractDict, ::AbstractDict)
InferenceProblem(::AbstractUWD, ::AbstractVector, ::AbstractVector)
InferenceProblem(::RelationDiagram, ::AbstractDict, ::AbstractDict, ::AbstractDict)
InferenceProblem(::BayesNet, ::AbstractVector, ::AbstractDict)
solve(::InferenceProblem, alg::EliminationAlgorithm)
init(::InferenceProblem, alg::EliminationAlgorithm)
solve(::InferenceProblem, ::EliminationAlgorithm, ::SupernodeType, ::ArchitectureType)
init(::InferenceProblem, ::EliminationAlgorithm, ::SupernodeType, ::ArchitectureType)
```

## Solvers
```@docs
InferenceSolver
solve!(::InferenceSolver)
mean(::InferenceSolver)
rand(::AbstractRNG, ::InferenceSolver)
```

## Algorithms
## Elimination
```@docs
EliminationAlgorithm
MinDegree
MinFill
CuthillMcKeeJL_RCM
AMDJL_AMD
MetisJL_ND
SupernodeType
Node
MaximalSupernode
```

## Architectures
```@docs
ArchitectureType
ShenoyShafer
LauritzenSpiegelhalter
```
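
A sketch of how these pieces fit together, using only the signatures documented
above; `problem` is assumed to be an already-constructed `InferenceProblem`:

```julia
# choose an elimination algorithm, a supernode type, and an architecture
solver = init(problem, MinFill(), MaximalSupernode(), LauritzenSpiegelhalter())

Σ = solve!(solver) # posterior of the query variables
ms = mean(solver)  # posterior mean of every variable
```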