Skip to content

Commit

Permalink
No longer assuming junctions are R1.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelsonric committed Jun 29, 2023
1 parent 5c58f94 commit 3e21016
Show file tree
Hide file tree
Showing 9 changed files with 329 additions and 301 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
4 changes: 3 additions & 1 deletion src/AlgebraicInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ using FillArrays
using Graphs
using LinearAlgebra
using LinearSolve
using OrderedCollections

using Base: OneTo
using Catlab.CategoricalAlgebra: FinSet, StructuredCospanOb, StructuredMulticospan
using FillArrays: ZerosMatrix
using FillArrays: SquareEye, ZerosMatrix, ZerosVector
using Graphs: neighbors
using LinearAlgebra: checksquare

Expand Down
29 changes: 14 additions & 15 deletions src/problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ An inference problem over a valuation algebra. Construct a solver for an inferen
with the function [`init`](@ref), or solve it directly with [`solve`](@ref).
"""
mutable struct InferenceProblem{T₁, T₂}
kb::Vector{Valuation{T₁}}
objects::T₂
pg::Graph{Int}
factors::Vector{Valuation{T₁}}
objects::Vector{T₂}
graph::Graph{Int}
query::Vector{Int}
end

Expand Down Expand Up @@ -35,12 +35,12 @@ struct MinFill end
Construct an inference problem that performs undirected composition. Before being composed,
the values of `hom_map` are converted to type `T`.
"""
function InferenceProblem{T}(wd::AbstractUWD, hom_map::AbstractDict,
function InferenceProblem{T₁, T₂}(wd::AbstractUWD, hom_map::AbstractDict,
ob_map::Union{Nothing, AbstractDict}=nothing;
hom_attr=:name, ob_attr=:variable) where T
hom_attr=:name, ob_attr=:variable) where {T₁, T₂}
homs = [hom_map[x] for x in subpart(wd, hom_attr)]
obs = isnothing(ob_map) ? nothing : [ob_map[x] for x in subpart(wd, ob_attr)]
InferenceProblem{T}(wd, homs, obs)
InferenceProblem{T₁, T₂}(wd, homs, obs)
end

"""
Expand All @@ -50,26 +50,25 @@ end
Construct an inference problem that performs undirected composition. Before being composed,
the elements of `homs` are converted to type `T`.
"""
function InferenceProblem{T}(wd::AbstractUWD, homs::AbstractVector,
obs::Union{Nothing, AbstractVector}=nothing) where T
function InferenceProblem{T₁, T₂}(wd::AbstractUWD, homs::AbstractVector, obs::AbstractVector) where {T₁, T₂}
@assert nboxes(wd) == length(homs)
@assert isnothing(obs) || njunctions(wd) == length(obs)
@assert njunctions(wd) == length(obs)
query = collect(subpart(wd, :outer_junction))
ports = collect(subpart(wd, :junction))
kb = Vector{Valuation{T}}(undef, nboxes(wd))
pg = Graph(njunctions(wd))
factors = Vector{Valuation{T}}(undef, nboxes(wd))
graph = Graph(njunctions(wd))
i = 1
for i₁ in 2:length(ports)
for i₂ in i:i₁ - 1
if ports[i₁] != ports[i₂]
add_edge!(pg, ports[i₁], ports[i₂])
add_edge!(graph, ports[i₁], ports[i₂])
end
end
if box(wd, i) != box(wd, i₁)
kb[box(wd, i)] = Valuation{T}(homs[box(wd, i)], ports[i:i₁ - 1])
factors[box(wd, i)] = Valuation{T}(homs[box(wd, i)], ports[i:i₁ - 1])
i = i₁
end
end
kb[end] = Valuation{T}(homs[end], ports[i:end])
InferenceProblem(kb, obs, pg, query)
factors[end] = Valuation{T}(homs[end], ports[i:end])
InferenceProblem{T₁, T₂}(factors, obs, graph, query)
end
62 changes: 34 additions & 28 deletions src/solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ sol2 = solve(is)
```
"""
mutable struct InferenceSolver{T₁, T₂}
jt::JoinTree{T₁}
objects::T₂
tree::JoinTree{T₁}
objects::Vector{T₂}
query::Vector{Int}
end

Expand All @@ -27,30 +27,32 @@ Construct a solver for an inference problem. The options for `alg` are
"""
init(ip::InferenceProblem, alg)

function init(ip::InferenceProblem{T}, ::MinDegree) where T
pg = copy(ip.pg)
function init(ip::InferenceProblem{T₁, T₂}, ::MinDegree) where {T₁, T₂}
graph = copy(ip.graph)
for i₁ in 2:length(ip.query)
for i₂ in 1:i₁ - 1
if ip.query[i₁] != ip.query[i₂]
add_edge!(pg, ip.query[i₁], ip.query[i₂])
add_edge!(graph, ip.query[i₁], ip.query[i₂])
end
end
end
order = mindegree!(copy(pg))
InferenceSolver(JoinTree(ip.kb, pg, order), ip.objects, ip.query)
order = mindegree!(copy(graph))
tree = JoinTree{T₁}(ip.factors, ip.objects, graph, order)
InferenceSolver{T₁, T₂}(tree, ip.objects, ip.query)
end

function init(ip::InferenceProblem{T}, ::MinFill) where T
pg = copy(ip.pg)
function init(ip::InferenceProblem{T₁, T₂}, ::MinFill) where {T₁, T₂}
graph = copy(ip.graph)
for i₁ in 2:length(ip.query)
for i₂ in 1:i₁ - 1
if ip.query[i₁] != ip.query[i₂]
add_edge!(pg, ip.query[i₁], ip.query[i₂])
add_edge!(graph, ip.query[i₁], ip.query[i₂])
end
end
end
order = minfill!(copy(pg))
InferenceSolver(JoinTree(ip.kb, pg, order), ip.objects, ip.query)
order = minfill!(copy(graph))
tree = JoinTree{T₁}(ip.factors, ip.objects, graph, order)
InferenceSolver{T₁, T₂}(tree, ip.objects, ip.query)
end

"""
Expand All @@ -68,19 +70,21 @@ solve(ip::InferenceProblem, alg)
Solve an inference problem.
"""
function solve(is::InferenceSolver{T}) where T
dom = collect(Set(is.query))
for node in PreOrderDFS(is.jt)
if dom node.domain
variables = unique(is.query)
for node in PreOrderDFS(is.tree)
if variables node.domain
factor = node.factor
for child in node.children
factor = combine(factor, message_to_parent(child)::Valuation{T})
message = message_to_parent(child, is.objects)::Valuation{T}
factor = combine(factor, message, is.objects)
end
if !isroot(node)
factor = combine(factor, message_from_parent(node)::Valuation{T})
message = message_from_parent(node, is.objects)::Valuation{T}
factor = combine(factor, message, is.objects)
end
factor = project(factor, domain(factor) dom)
factor = extend(factor, dom, is.objects)
return expand(factor, is.query)
factor = project(factor, domain(factor) variables, is.objects)
factor = extend(factor, variables, is.objects)
return expand(factor, is.query, is.objects)
end
end
error("Query not covered by join tree.")
Expand All @@ -92,19 +96,21 @@ end
Solve an inference problem, caching intermediate computations.
"""
function solve!(is::InferenceSolver{T}) where T
dom = collect(Set(is.query))
for node in PreOrderDFS(is.jt)
if dom node.domain
variables = unique(is.query)
for node in PreOrderDFS(is.tree)
if variables node.domain
factor = node.factor
for child in node.children
factor = combine(factor, message_to_parent!(child)::Valuation{T})
message = message_to_parent!(child, is.objects)::Valuation{T}
factor = combine(factor, message, is.objects)
end
if !isroot(node)
factor = combine(factor, message_from_parent!(node)::Valuation{T})
message = message_from_parent!(node, is.objects)::Valuation{T}
factor = combine(factor, message, is.objects)
end
factor = project(factor, domain(factor) dom)
factor = extend(factor, dom, is.objects)
return expand(factor, is.query)
factor = project(factor, domain(factor) variables, is.objects)
factor = extend(factor, variables, is.objects)
return expand(factor, is.query, is.objects)
end
end
error("Query not covered by join tree.")
Expand Down
92 changes: 41 additions & 51 deletions src/systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ struct GaussianSystem{T₁, T₂, T₃, T₄, T₅}
end
end

const AbstractGaussianSystem{T} = GaussianSystem{
<:AbstractMatrix{T},
<:AbstractMatrix{T},
<:AbstractVector{T},
<:AbstractVector{T},
T}
const AbstractGaussianSystem{T₁, T₂, T₃, T₄, T₅} = GaussianSystem{
<:AbstractMatrix{T},
<:AbstractMatrix{T},
<:AbstractVector{T},
<:AbstractVector{T},
T}

const DenseGaussianSystem{T} = GaussianSystem{
Matrix{T},
Expand Down Expand Up @@ -74,12 +74,25 @@ function GaussianSystem(P::T₁, S::T₂, p::T₃, s::T₄, σ::T₅) where {
GaussianSystem{T₁, T₂, T₃, T₄, T₅}(P, S, p, s, σ)
end


function convert(::Type{GaussianSystem{T₁, T₂, T₃, T₄, T₅}}, Σ::GaussianSystem) where {
T₁, T₂, T₃, T₄, T₅}
GaussianSystem{T₁, T₂, T₃, T₄, T₅}.P, Σ.S, Σ.p, Σ.s, Σ.σ)
end

function convert(::Type{T}, L::AbstractMatrix) where T <: GaussianSystem
n = size(L, 1)
convert(T, kernel(L, Zeros(n), Zeros(n, n)))
end

function convert(::Type{T}, μ::AbstractVector) where T <: GaussianSystem
n = length(μ)
convert(T, normal(μ, Zeros(n, n)))
end

function convert(::Type{T}, μ::Real) where T <: GaussianSystem
convert(T, [μ])
end

"""
normal(μ::AbstractVector, Σ::AbstractMatrix)
Expand Down Expand Up @@ -235,6 +248,7 @@ function zero(Σ::GaussianSystem)
end

function zero(::Type{GaussianSystem{T₁, T₂, T₃, T₄, T₅}}, n) where {T₁, T₂, T₃, T₄, T₅}
@assert n >= 0
GaussianSystem{T₁, T₂, T₃, T₄, T₅}(Zeros(n, n), Zeros(n, n), Zeros(n), Zeros(n), 0)
end

Expand Down Expand Up @@ -265,57 +279,33 @@ function pushforward(Σ::GaussianSystem, M::AbstractMatrix)
end

"""
marginal(Σ::GaussianSystem, is::AbstractVector{Int})
oapply(wd::AbstractUWD, homs::AbstractVector{<:GaussianSystem}, obs::AbstractVector)
Compute the marginal of `Σ` along the indices specified by `is`.
Compose Gaussian systems according to the undirected wiring diagram `wd`.
"""
function marginal::GaussianSystem, is::AbstractVector{Int})
P, S = Σ.P, Σ.S
p, s = Σ.p, Σ.s
σ = Σ.σ

n = length(Σ)
js = setdiff(1:n, is)
function oapply(wd::AbstractUWD, homs::AbstractVector{<:GaussianSystem}, obs::AbstractVector)
@assert nboxes(wd) == length(homs)
@assert njunctions(wd) == length(obs)

P₁₁ = P[is, is]; P₁₂ = P[is, js]; P₂₁ = P[js, is]; P₂₂ = P[js, js]
S₁₁ = S[is, is]; S₁₂ = S[is, js]; S₂₁ = S[js, is]; S₂₂ = S[js, js]
p₁ = p[is]; p₂ = p[js]
s₁ = s[is]; s₂ = s[js]

K = KKT(P₂₂, S₂₂)
ports = collect(subpart(wd, :junction))
query = collect(subpart(wd, :outer_junction))

A = solve!(K, P₂₁, S₂₁)
a = solve!(K, p₂, s₂)
n = sum(obs)
L = falses(sum(obs[ports]), n)
R = falses(sum(obs[query]), n)

GaussianSystem(
P₁₁ + A' * P₂₂ * A - P₁₂ * A - A' * P₂₁,
S₁₁ - A' * S₂₂ * A,
p₁ + A' * P₂₂ * a - P₁₂ * a - A' * p₂,
s₁ - S₁₂ * a,
σ - s₂' * a)
end

"""
oapply(wd::AbstractUWD, systems::AbstractVector{<:GaussianSystem})
cs = cumsum(obs)

Compose Gaussian systems according to the undirected wiring diagram `wd`.
"""
function oapply(wd::AbstractUWD, systems::AbstractVector{<:GaussianSystem})
@assert nboxes(wd) == length(systems)
ports = collect(subpart(wd, :junction))
query = collect(subpart(wd, :outer_junction))
L = falses(length(ports), njunctions(wd))
R = falses(length(query), njunctions(wd))
for (i, j) in enumerate(ports)
L[i, j] = true
for ((i, j), m) in zip(enumerate(ports), cumsum(obs[ports]))
o = obs[j]
L[m - o + 1:m, cs[j] - o + 1:cs[j]] = I(o)
end
for (i, j) in enumerate(query)
R[i, j] = true

for ((i, j), m) in zip(enumerate(query), cumsum(obs[query]))
o = obs[j]
R[m - o + 1:m, cs[j] - o + 1:cs[j]] = I(o)
end
Σ = reduce(, systems; init=zero(DenseGaussianSystem{Bool}, 0))
pushforward* L, R)
end

function oapply(wd::AbstractUWD, systems::AbstractVector{<:GaussianSystem}, ::Nothing)
oapply(wd, systems)
Σ = reduce(, homs; init=zero(DenseGaussianSystem{Bool}, 0))
pushforward* L, R)
end
22 changes: 11 additions & 11 deletions src/trees.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@ mutable struct JoinTree{T} <: AbstractNode{Int}
message_from_parent::Union{Nothing, Valuation{T}}
message_to_parent::Union{Nothing, Valuation{T}}

function JoinTree(factor::Valuation{T}, id, domain) where T
function JoinTree{T}(factor, id, domain) where T
new{T}(factor, id, domain, JoinTree{T}[], nothing, nothing, nothing)
end
end

function JoinTree(kb::Vector{Valuation{T}}, pg::AbstractGraph, order) where T
pg = copy(pg)
ls = collect(vertices(pg))
vs = collect(vertices(pg))
function JoinTree{T}(factors, objects, graph, order) where T
graph = copy(graph)
ls = collect(vertices(graph))
vs = collect(vertices(graph))
ns = JoinTree{T}[]
vpll = map(_ -> Set{Int}(), ls)
for j in 1:length(kb)
for js in vpll[domain(kb[j])]
for j in 1:length(factors)
for js in vpll[domain(factors[j])]
push!(js, j)
end
end
Expand All @@ -28,12 +28,12 @@ function JoinTree(kb::Vector{Valuation{T}}, pg::AbstractGraph, order) where T
v = vs[l]
factor = one(Valuation{T})
for j in vpll[l]
factor = combine(factor, kb[j])
for js in vpll[domain(kb[j])]
factor = combine(factor, factors[j], objects)
for js in vpll[domain(factors[j])]
delete!(js, j)
end
end
node = JoinTree(factor, i, [l; ls[neighbors(pg, v)]])
node = JoinTree{T}(factor, i, [l; ls[neighbors(graph, v)]])
for j in length(ns):-1:1
if l in ns[j].domain
ns[j].parent = node
Expand All @@ -43,7 +43,7 @@ function JoinTree(kb::Vector{Valuation{T}}, pg::AbstractGraph, order) where T
end
vs[ls[end]] = v
push!(ns, node)
eliminate!(pg, ls, v)
eliminate!(graph, ls, v)
end
ns[end]
end
Expand Down
Loading

0 comments on commit 3e21016

Please sign in to comment.