From ab40f0e67a176f9ae58e1e8c9d9f33637bae0f07 Mon Sep 17 00:00:00 2001 From: Richard Samuelson Date: Wed, 20 Sep 2023 18:46:53 -0700 Subject: [PATCH] Removed `context` from `InferenceProblem` constructor. --- docs/src/api.md | 2 +- src/AlgebraicInference.jl | 2 +- src/architectures.jl | 170 ++++++++++++++++++++++---------------- src/cpds.jl | 16 +++- src/elimination.jl | 80 ++++++++++-------- src/factors.jl | 66 +++------------ src/models.jl | 101 +++++++++++----------- src/problems.jl | 85 +++++++++---------- src/solvers.jl | 27 +++--- src/systems.jl | 2 +- test/runtests.jl | 49 +++++------ 11 files changed, 302 insertions(+), 298 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index b844198..a52149f 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -24,7 +24,7 @@ mean(::GaussianSystem) ```@docs InferenceProblem -InferenceProblem(::RelationDiagram, ::AbstractDict, ::AbstractDict, ::AbstractDict) +InferenceProblem(::RelationDiagram, ::AbstractDict, ::AbstractDict) InferenceProblem(::BayesNet, ::AbstractVector, ::AbstractDict) solve(::InferenceProblem, ::EliminationAlgorithm, ::SupernodeType, ::ArchitectureType) diff --git a/src/AlgebraicInference.jl b/src/AlgebraicInference.jl index 6c16df4..246b291 100644 --- a/src/AlgebraicInference.jl +++ b/src/AlgebraicInference.jl @@ -8,7 +8,7 @@ export ⊗, cov, invcov, normal, kernel, mean, oapply, var # Inference Problems export InferenceProblem -export init +export init, reduce_to_context # Inference Solvers diff --git a/src/architectures.jl b/src/architectures.jl index 6d43188..ad010b5 100644 --- a/src/architectures.jl +++ b/src/architectures.jl @@ -73,12 +73,12 @@ end # Construct an architecture. function Architecture( model::GraphicalModel, - elalg::EliminationAlgorithm, - stype::SupernodeType) + elimination_algorithm::EliminationAlgorithm, + supernode_type::SupernodeType) labels = model.labels factors = model.factors - tree = JoinTree(model.graph, elalg, stype) + tree = JoinTree(model.graph, elimination_algorithm, supernode_type) vvll = deepcopy(model.vvll) assignments = Vector{Vector{Int}}(undef, length(tree)) @@ -106,23 +106,29 @@ end # Answer a query. # Algorithm 4.2 in doi:10.1002/9781118010877. -function CommonSolve.solve!(arch::Architecture, atype::ShenoyShafer, query) - arch.collect_phase_complete || collect_phase!(arch, atype) +function CommonSolve.solve!( + architecture::Architecture, + architecture_type::ShenoyShafer, + query::AbstractVector) - vars = [arch.labels.index[l] for l in query] + if !architecture.collect_phase_complete + collect_phase!(architecture, architecture_type) + end + + vars = [architecture.labels.index[l] for l in query] - for n in arch.tree.order - node = IndexNode(arch.tree, n) + for n in architecture.tree.order + node = IndexNode(architecture.tree, n) sep, res = nodevalue(node) if vars ⊆ [sep; res] - distribute_phase!(arch, atype, node.index) + distribute_phase!(architecture, architecture_type, node.index) - mbx = mailbox(arch, node.index) + mbx = mailbox(architecture, node.index) fac = combine(mbx.factor, mbx.message_from_parent) for child in children(node) - mbx = mailbox(arch, child.index) + mbx = mailbox(architecture, child.index) fac = combine(fac, mbx.message_to_parent) end @@ -139,19 +145,25 @@ end # Answer a query. # Algorithm 4.4 in doi:10.1002/9781118010877. -function CommonSolve.solve!(arch::Architecture, atype::LauritzenSpiegelhalter, query) - arch.collect_phase_complete || collect_phase!(arch, atype) +function CommonSolve.solve!( + architecture::Architecture, + architecture_type::LauritzenSpiegelhalter, + query::AbstractVector) + + if !architecture.collect_phase_complete + collect_phase!(architecture, architecture_type) + end - vars = [arch.labels.index[l] for l in query] + vars = [architecture.labels.index[l] for l in query] - for n in arch.tree.order - node = IndexNode(arch.tree, n) + for n in architecture.tree.order + node = IndexNode(architecture.tree, n) sep, res = nodevalue(node) if vars ⊆ [sep; res] - distribute_phase!(arch, atype, node.index) + distribute_phase!(architecture, architecture_type, node.index) - mbx = mailbox(arch, node.index) + mbx = mailbox(architecture, node.index) fac = combine(mbx.cpd, mbx.message_from_parent) fac = project(fac, vars) @@ -166,19 +178,20 @@ end # Sample from an architecture. -function Base.rand(rng::AbstractRNG, arch::Architecture) - @assert arch.collect_phase_complete +function Base.rand(rng::AbstractRNG, architecture::Architecture) + @assert architecture.collect_phase_complete - x = Vector{Vector{Float64}}(undef, length(arch.labels)) + m = length(architecture.labels) + x = Vector{Vector{Float64}}(undef, m) - for n in reverse(arch.tree.order) - node = IndexNode(arch.tree, n) + for n in reverse(architecture.tree.order) + node = IndexNode(architecture.tree, n) - mbx = mailbox(arch, node.index) + mbx = mailbox(architecture, node.index) rand!(rng, mbx.cpd, x) end - Dict(zip(arch.labels, x)) + Dict(zip(architecture.labels, x)) end @@ -188,127 +201,144 @@ end # Compute the mean of an architecture. -function Statistics.mean(arch::Architecture) - @assert arch.collect_phase_complete +function Statistics.mean(architecture::Architecture) + @assert architecture.collect_phase_complete - x = Vector{Vector{Float64}}(undef, length(arch.labels)) + m = length(architecture.labels) + x = Vector{Vector{Float64}}(undef, m) - for n in reverse(arch.tree.order) - node = IndexNode(arch.tree, n) + for n in reverse(architecture.tree.order) + node = IndexNode(architecture.tree, n) - mbx = mailbox(arch, node.index) + mbx = mailbox(architecture, node.index) mean!(mbx.cpd, x) end - Dict(zip(arch.labels, x)) + Dict(zip(architecture.labels, x)) end # The collect phase of the Shenoy-Shafer architecture. # Algorithm 4.1 in doi:10.1002/9781118010877. -function collect_phase!(arch::Architecture{<:Any, T₁, T₂}, atype::ShenoyShafer) where {T₁, T₂} - for n in arch.tree.order - node = IndexNode(arch.tree, n) +function collect_phase!( + architecture::Architecture{<:Any, T₁, T₂}, + architecture_type::ShenoyShafer) where {T₁, T₂} - mbx = mailbox(arch, node.index) - mbx.factor = factor(arch, node.index) + for n in architecture.tree.order + node = IndexNode(architecture.tree, n) + + mbx = mailbox(architecture, node.index) + mbx.factor = factor(architecture, node.index) msg = mbx.factor for child in children(node) - mbx = mailbox(arch, child.index) + mbx = mailbox(architecture, child.index) msg = combine(msg, mbx.message_to_parent) end - mbx = mailbox(arch, node.index) + mbx = mailbox(architecture, node.index) mbx.message_to_parent, mbx.cpd = disintegrate(msg, first(nodevalue(node))) end - mbx = mailbox(arch, rootindex(arch.tree)) + mbx = mailbox(architecture, rootindex(architecture.tree)) mbx.message_from_parent = zero(Factor{T₁, T₂}) - arch.collect_phase_complete = true + architecture.collect_phase_complete = true end # The collect phase of the Lauritzen-Spiegelhalter architecture. # Algorithm 4.3 in doi:10.1002/9781118010877. -function collect_phase!(arch::Architecture{<:Any, T₁, T₂}, atype::LauritzenSpiegelhalter) where {T₁, T₂} - for n in arch.tree.order - node = IndexNode(arch.tree, n) +function collect_phase!( + architecture::Architecture{<:Any, T₁, T₂}, + architecture_type::LauritzenSpiegelhalter) where {T₁, T₂} + + for n in architecture.tree.order + node = IndexNode(architecture.tree, n) - msg = factor(arch, node.index) + msg = factor(architecture, node.index) for child in children(node) - mbx = mailbox(arch, child.index) + mbx = mailbox(architecture, child.index) msg = combine(msg, mbx.message_to_parent) end - mbx = mailbox(arch, node.index) + mbx = mailbox(architecture, node.index) mbx.message_to_parent, mbx.cpd = disintegrate(msg, first(nodevalue(node))) end - mbx = mailbox(arch, rootindex(arch.tree)) + mbx = mailbox(architecture, rootindex(architecture.tree)) mbx.message_to_parent = nothing mbx.message_from_parent = zero(Factor{T₁, T₂}) - arch.collect_phase_complete = true + architecture.collect_phase_complete = true end -# The distribute phase of the Shenoy-Shafer architecture. +# The distribute phase of the Shenoy-Shafer architecture. Only distributes from the root to +# node n. # Algorithm 4.1 in doi:10.1002/9781118010877. -function distribute_phase!(arch::Architecture, atype::ShenoyShafer, n::Integer) - node = IndexNode(arch.tree, n) - mbx = mailbox(arch, node.index) +function distribute_phase!( + architecture::Architecture, + architecture_type::ShenoyShafer, + n::Integer) + + node = IndexNode(architecture.tree, n) + mbx = mailbox(architecture, node.index) ancestors = Int[] while !isroot(node) && isnothing(mbx.message_from_parent) push!(ancestors, node.index) node = parent(node) - mbx = mailbox(arch, node.index) + mbx = mailbox(architecture, node.index) end for n in ancestors[end:-1:1] - node = IndexNode(arch.tree, n) + node = IndexNode(architecture.tree, n) prnt = parent(node) - mbx = mailbox(arch, prnt.index) + mbx = mailbox(architecture, prnt.index) msg = combine(mbx.factor, mbx.message_from_parent) for sibling in children(prnt) if node != sibling - mbx = mailbox(arch, sibling.index) + mbx = mailbox(architecture, sibling.index) msg = combine(msg, mbx.message_to_parent) end end - mbx = mailbox(arch, node.index) + mbx = mailbox(architecture, node.index) mbx.message_from_parent = project(msg, first(nodevalue(node))) end end -# The distribute phase of the Lauritzen-Spiegelhalter architecture. +# The distribute phase of the Lauritzen-Spiegelhalter architecture. Only distributes from +# the root to node n. # Algorithm 4.3 in doi:10.1002/9781118010877. -function distribute_phase!(arch::Architecture, atype::LauritzenSpiegelhalter, n::Integer) - node = IndexNode(arch.tree, n) - mbx = mailbox(arch, node.index) +function distribute_phase!( + architecture::Architecture, + architecture_type::LauritzenSpiegelhalter, + n::Integer) + + node = IndexNode(architecture.tree, n) + mbx = mailbox(architecture, node.index) ancestors = Int[] while !isroot(node) && isnothing(mbx.message_from_parent) push!(ancestors, node.index) node = parent(node) - mbx = mailbox(arch, node.index) + mbx = mailbox(architecture, node.index) end for n in ancestors[end:-1:1] - node = IndexNode(arch.tree, n) + node = IndexNode(architecture.tree, n) prnt = parent(node) - mbx = mailbox(arch, prnt.index) + mbx = mailbox(architecture, prnt.index) msg = combine(mbx.cpd, mbx.message_from_parent) - mbx = mailbox(arch, node.index) + mbx = mailbox(architecture, node.index) mbx.message_to_parent = nothing mbx.message_from_parent = project(msg, first(nodevalue(node))) end @@ -319,18 +349,18 @@ end # μ: n → pa(n) # and # μ: pa(n) → n -function mailbox(arch::Architecture, n::Int) +function mailbox(arch::Architecture, n::Integer) arch.mailboxes[n] end # Compute the join tree factor # ψₙ -function factor(arch::Architecture{<:Any, T₁, T₂}, n::Int) where {T₁, T₂} +function factor(architecture::Architecture{<:Any, T₁, T₂}, n::Integer) where {T₁, T₂} fac = zero(Factor{T₁, T₂}) - for f in arch.assignments[n] - fac = combine(fac, arch.factors[f]) + for f in architecture.assignments[n] + fac = combine(fac, architecture.factors[f]) end fac diff --git a/src/cpds.jl b/src/cpds.jl index 5134fcd..0f05a74 100644 --- a/src/cpds.jl +++ b/src/cpds.jl @@ -63,7 +63,14 @@ function combine(cpd₁::CPD, fac₂::Factor{T₁, T₂}) where {T₁, T₂} end -function disintegrate(fac::Factor{T₁, T₂}, vars::Vector{Int}) where {T₁, T₂} +# Compute the projection +# fac ↓ vars +function project(fac::Factor{T₁, T₂}, vars::AbstractVector) where {T₁, T₂} + first(disintegrate(fac, vars)) +end + + +function disintegrate(fac::Factor{T₁, T₂}, vars::AbstractVector) where {T₁, T₂} i₁ = Int[] i₂ = Int[] @@ -88,7 +95,12 @@ function disintegrate(fac::Factor{T₁, T₂}, vars::Vector{Int}) where {T₁, T end -function disintegrate(hom::GaussianSystem, i₁::Vector{Int}, i₂::Vector{Int}, obs::Vector{Int}) +function disintegrate( + hom::GaussianSystem, + i₁::AbstractVector, + i₂::AbstractVector, + obs::AbstractVector) + cms = cumsum(obs) j₁ = Int[] diff --git a/src/elimination.jl b/src/elimination.jl index ed75c84..24741ec 100644 --- a/src/elimination.jl +++ b/src/elimination.jl @@ -106,7 +106,7 @@ end # Determine if # v₁ < v₂ # in the given order. -function (order::Order)(v₁::Int, v₂::Int) +function (order::Order)(v₁::Integer, v₂::Integer) order.index[v₁] < order.index[v₂] end @@ -130,20 +130,20 @@ end # Construct an elimination order using the minimum-degree heuristic. -function Order(graph::Graphs.Graph, elalg::MinDegree) +function Order(graph::Graphs.AbstractGraph, elimination_algorithm::MinDegree) mindegree!(copy(graph)) end # Construct an elimination order using the minimum-fill heuristic. -function Order(graph::Graphs.Graph, elalg::MinFill) +function Order(graph::Graphs.AbstractGraph, elimination_algorithm::MinFill) minfill!(copy(graph)) end # Construct an elimination order using the reverse Cuthill-McKee algorithm. Uses # CuthillMcKee.jl. -function Order(graph::Graphs.Graph, elalg::CuthillMcKeeJL_RCM) +function Order(graph::Graphs.AbstractGraph, elimination_algorithm::CuthillMcKeeJL_RCM) order = CuthillMcKee.symrcm(Graphs.adjacency_matrix(graph)) Order(order) end @@ -151,22 +151,27 @@ end # Construct an elimination order using the approximate minimum degree algorithm. Uses # AMD.jl. -function Order(graph::Graphs.Graph, elalg::AMDJL_AMD) +function Order(graph::Graphs.AbstractGraph, elimination_algorithm::AMDJL_AMD) order = AMD.symamd(Graphs.adjacency_matrix(graph)) Order(order) end # Construct an elimination order using the nested dissection heuristic. Uses Metis.jl. -function Order(graph::Graphs.Graph, elalg::MetisJL_ND) +function Order(graph::Graphs.AbstractGraph, elimination_algorithm::MetisJL_ND) order, index = Metis.permutation(graph) Order(order, index) end # Construct an elimination tree using the given elimination algorithm. -function EliminationTree(graph::Graphs.Graph, elalg::EliminationAlgorithm) - EliminationTree(OrderedGraph(Order(graph, elalg), graph)) +function EliminationTree( + graph::Graphs.AbstractGraph, + elimination_algorithm::EliminationAlgorithm) + + order = Order(graph, elimination_algorithm) + ordered_graph = OrderedGraph(order, graph) + EliminationTree(ordered_graph) end @@ -239,13 +244,18 @@ function JoinTree( end -function JoinTree(graph::Graphs.Graph, elalg::EliminationAlgorithm, stype::SupernodeType) - JoinTree(EliminationTree(graph, elalg), stype) +function JoinTree( + graph::Graphs.AbstractGraph, + elimination_algorithm::EliminationAlgorithm, + supernode_type::SupernodeType) + + elimination_tree = EliminationTree(graph, elimination_algorithm) + JoinTree(elimination_tree, supernode_type) end # Construct a nodal elimination tree. -function JoinTree(tree::EliminationTree, ::Node) +function JoinTree(tree::EliminationTree, supernode_type::Node) order = tree.order parent = tree.parent children = tree.children @@ -258,7 +268,7 @@ end # Construct a supernodal elimination tree with maximal supernodes. # Algorithm 4.1 in doi:10.1561/2400000006. -function JoinTree(tree::EliminationTree, ::MaximalSupernode) +function JoinTree(tree::EliminationTree, supernode_type::MaximalSupernode) parent = Vector{Int}() children = Vector{Vector{Int}}() residuals = Vector{Vector{Int}}() @@ -310,7 +320,7 @@ end # Get the fill-in number of vertex v. -function fillin(graph::Graphs.Graph, v::Int) +function fillin(graph::Graphs.AbstractGraph, v::Integer) count = 0 ns = Graphs.neighbors(graph, v) n = length(ns) @@ -326,7 +336,7 @@ end # Compute an elimination order using the minimum degree heuristic. -function mindegree!(graph::Graphs.Graph) +function mindegree!(graph::Graphs.AbstractGraph) n = Graphs.nv(graph) order = Order(n) labels = Labels(1:n) @@ -343,7 +353,7 @@ end # Compute a vertex elimination order using the minimum fill heuristic. -function minfill!(graph::Graphs.Graph) +function minfill!(graph::Graphs.AbstractGraph) n = Graphs.nv(graph) order = Order(n) labels = Labels(1:n) @@ -362,7 +372,7 @@ end # Eliminate the vertex v. -function eliminate!(labels::Labels, graph::Graphs.Graph, l) +function eliminate!(labels::Labels, graph::Graphs.AbstractGraph, l) v = labels.index[l] ns = Graphs.neighbors(graph, v) n = length(ns) @@ -379,7 +389,7 @@ end # Eliminate the vertex v. # Adapted from https://github.com/JuliaQX/QXGraphDecompositions.jl/blob/ # 22ee3d75bcd267bf462eec8f03930af2129e34b7/src/LabeledGraph.jl#L326 -function eliminate!(labels::Labels, graph::Graphs.Graph, fillins::Vector{Int}, l) +function eliminate!(labels::Labels, graph::Graphs.AbstractGraph, fillins::Vector{Int}, l) v = labels.index[l] ns = Graphs.neighbors(graph, v) n = length(ns) @@ -449,18 +459,18 @@ function Base.size(A::JoinTree) end -function Base.getindex(A::Order, i::Int) - A.order[i] +function Base.getindex(A::Order, n::Integer) + A.order[n] end -function Base.getindex(A::EliminationTree, i::Int) - A.outneighbors[i] +function Base.getindex(A::EliminationTree, n::Integer) + A.outneighbors[n] end -function Base.getindex(A::JoinTree, i::Int) - A.seperators[i], A.residuals[i] +function Base.getindex(A::JoinTree, n::Integer) + A.seperators[n], A.residuals[n] end @@ -479,13 +489,13 @@ function Base.IndexStyle(::Type{JoinTree}) end -function Base.setindex!(A::Order, v, i::Int) +function Base.setindex!(A::Order, v::Integer, i::Integer) A.order[i] = v A.index[v] = i end -function Base.push!(A::Order, v) +function Base.push!(A::Order, v::Integer) n = length(A) push!(A.order, v) push!(A.index, n + 1) @@ -507,23 +517,23 @@ function AbstractTrees.rootindex(tree::JoinTree) end -function AbstractTrees.parentindex(tree::EliminationTree, i::Int) - i == rootindex(tree) ? nothing : tree.parent[i] +function AbstractTrees.parentindex(tree::EliminationTree, n::Integer) + n == rootindex(tree) ? nothing : tree.parent[n] end -function AbstractTrees.parentindex(tree::JoinTree, i::Int) - i == rootindex(tree) ? nothing : tree.parent[i] +function AbstractTrees.parentindex(tree::JoinTree, n::Integer) + n == rootindex(tree) ? nothing : tree.parent[n] end -function AbstractTrees.childindices(tree::EliminationTree, i::Int) - tree.children[i] +function AbstractTrees.childindices(tree::EliminationTree, n::Integer) + tree.children[n] end -function AbstractTrees.childindices(tree::JoinTree, i::Int) - tree.children[i] +function AbstractTrees.childindices(tree::JoinTree, n::Integer) + tree.children[n] end @@ -562,11 +572,11 @@ function Graphs.nv(g::OrderedGraph) end -function Graphs.outneighbors(g::OrderedGraph, v::Int) +function Graphs.outneighbors(g::OrderedGraph, v::Integer) filter(u -> g.order(v, u), Graphs.neighbors(g.graph, v)) end -function Graphs.inneighbors(g::OrderedGraph, v::Int) +function Graphs.inneighbors(g::OrderedGraph, v::Integer) filter(u -> g.order(u, v), Graphs.neighbors(g.graph, v)) end diff --git a/src/factors.jl b/src/factors.jl index 33fe928..dab0e1e 100644 --- a/src/factors.jl +++ b/src/factors.jl @@ -69,9 +69,9 @@ end function combine( hom₁::GaussianSystem, hom₂::GaussianSystem, - i₁::Vector{Int}, - i₂::Vector{Int}, - obs::Vector{Int}) + i₁::AbstractVector, + i₂::AbstractVector, + obs::AbstractVector) cms = cumsum(obs) @@ -92,48 +92,6 @@ function combine( end -# Compute the projection -# fac ↓ vars -function project(fac::Factor{T₁, T₂}, vars::Vector{Int}) where {T₁, T₂} - i₁ = Int[] - i₂ = Int[] - - for (x, y) in enumerate(fac.vars) - if y in vars - push!(i₁, x) - else - push!(i₂, x) - end - end - - hom = project(fac.hom, i₁, i₂, fac.obs) - obs = fac.obs[i₁] - vars = fac.vars[i₁] - - Factor{T₁, T₂}(hom, obs, vars) -end - - -# Compute the composite -# hom ; F(i₁†) -function project(hom::GaussianSystem, i₁::Vector{Int}, i₂::Vector{Int}, obs::Vector{Int}) - cms = cumsum(obs) - - j₁ = Int[] - j₂ = Int[] - - for y₁ in i₁ - append!(j₁, cms[y₁] - obs[y₁] + 1:cms[y₁]) - end - - for y₂ in i₂ - append!(j₂, cms[y₂] - obs[y₂] + 1:cms[y₂]) - end - - first(disintegrate(hom, j₁, j₂)) -end - - # Construct an identity element # e # of type Factor{T₁, T₂} @@ -142,7 +100,7 @@ function Base.zero(::Type{Factor{T₁, T₂}}) where {T₁ <: GaussianSystem, T end -function permute(fac::Factor, vars::Vector{Int}) +function permute(fac::Factor, vars::AbstractVector) i = Vector{Int}(undef, length(fac)) for (x₁, y₁) in enumerate(fac.vars) @@ -159,7 +117,7 @@ end # Compute the composite # hom ; F(i) -function permute(hom::GaussianSystem, i::Vector{Int}, obs::Vector{Int}) +function permute(hom::GaussianSystem, i::AbstractVector, obs::AbstractVector) cms = cumsum(obs) j = Int[] @@ -172,7 +130,7 @@ function permute(hom::GaussianSystem, i::Vector{Int}, obs::Vector{Int}) end -function observe(fac::Factor{T₁, T₂}, ctx::Pair{Int}) where {T₁, T₂} +function reduce_to_context(fac::Factor{T₁, T₂}, ctx::Pair) where {T₁, T₂} i₁ = Int[] i₂ = Int[] @@ -184,7 +142,7 @@ function observe(fac::Factor{T₁, T₂}, ctx::Pair{Int}) where {T₁, T₂} end end - hom = observe(fac.hom, ctx.second, i₁, i₂, fac.obs) + hom = reduce_to_context(fac.hom, ctx.second, i₁, i₂, fac.obs) obs = fac.obs[i₁] vars = fac.vars[i₁] @@ -192,12 +150,12 @@ function observe(fac::Factor{T₁, T₂}, ctx::Pair{Int}) where {T₁, T₂} end -function observe( +function reduce_to_context( hom₁::GaussianSystem, hom₂::AbstractVector, - i₁::Vector{Int}, - i₂::Vector{Int}, - obs::Vector{Int}) + i₁::AbstractVector, + i₂::AbstractVector, + obs::AbstractVector) cms = cumsum(obs) @@ -212,5 +170,5 @@ function observe( append!(j₂, cms[y₂] - obs[y₂] + 1:cms[y₂]) end - observe(hom₁, hom₂, j₁, j₂) + reduce_to_context(hom₁, hom₂, j₁, j₂) end diff --git a/src/models.jl b/src/models.jl index 6320873..d5f421f 100644 --- a/src/models.jl +++ b/src/models.jl @@ -8,67 +8,70 @@ end function GraphicalModel{T₁, T₂, T₃}( - fg::AbstractUndirectedBipartiteGraph, - homs::AbstractVector, - obs::AbstractVector, - labels::AbstractVector) where {T₁, T₂, T₃} + factor_graph::AbstractUndirectedBipartiteGraph, + labels::AbstractVector, + morphisms::AbstractVector, + objects::AbstractVector) where {T₁, T₂, T₃} - @assert nv₁(fg) == length(homs) - @assert nv₂(fg) == length(obs) + @assert nv₁(factor_graph) == length(morphisms) + @assert nv₂(factor_graph) == length(labels) == length(objects) - scopes = [Int[] for _ in vertices₁(fg)] - vvll = [Int[] for _ in vertices₂(fg)] + scopes = [Int[] for _ in vertices₁(factor_graph)] + vvll = [Int[] for _ in vertices₂(factor_graph)] - for i in edges(fg) - f = src(fg, i) - v = tgt(fg, i) + for i in edges(factor_graph) + f = src(factor_graph, i) + v = tgt(factor_graph, i) push!(scopes[f], v) push!(vvll[v], f) end - labels = Labels{T₁}(labels) - factors = Vector{Factor{T₂, T₃}}(undef, nv₁(fg)) - graph = Graphs.Graph(nv₂(fg)) + factors = Vector{Factor{T₂, T₃}}(undef, nv₁(factor_graph)) + graph = Graphs.Graph(nv₂(factor_graph)) for (f, vs) in enumerate(scopes) - factors[f] = Factor(homs[f], obs[vs], vs) n = length(vs) for i₁ in 2:n, i₂ in 1:i₁ - 1 Graphs.add_edge!(graph, vs[i₁], vs[i₂]) end + + factors[f] = Factor(morphisms[f], objects[vs], vs) end + labels = Labels{T₁}(labels) GraphicalModel(labels, factors, graph, vvll) end -function GraphicalModel{T₁, T₂, T₃}(bn::BayesNets.BayesNet) where {T₁, T₂, T₃} - n = length(bn) +function GraphicalModel{T₁, T₂, T₃}(network::BayesNets.BayesNet) where {T₁, T₂, T₃} + n = length(network) - labels = Labels{T₁}(names(bn)) + labels = Labels{T₁}(names(network)) factors = Vector{Factor{T₂, T₃}}(undef, n) graph = Graphs.Graph{Int}(n) vvll = [[i] for i in 1:n] for i in 1:n - cpd = bn.cpds[i] - pas = [bn.name_to_index[l] for l in BayesNets.parents(cpd)] - m = length(pas) + cpd = network.cpds[i] + parents = map(l -> network.name_to_index[l], BayesNets.parents(cpd)) + m = length(parents) - hom = GaussianSystem(cpd) - obs = ones(Int, m + 1) - vars = [pas; i] + morphism = GaussianSystem(cpd) + objects = ones(Int, m + 1) + variables = [parents; i] - factors[i] = Factor(hom, obs, vars) + factors[i] = Factor(morphism, objects, variables) for j₁ in 1:m - Graphs.add_edge!(graph, pas[j₁], i) - push!(vvll[pas[j₁]], i) + i₁ = parents[j₁] + push!(vvll[i₁], i) + Graphs.add_edge!(graph, i₁, i) for j₂ in 1:j₁ - 1 - Graphs.add_edge!(graph, pas[j₁], pas[j₂]) + i₂ = parents[i₂] + Graphs.add_edge!(graph, i₁, i₂) end end end @@ -87,32 +90,30 @@ function Base.copy(model::GraphicalModel) end -function observe!(model::GraphicalModel, context::Pair) - l, hom = context - - v = model.labels.index[l] +function reduce_to_context(model::GraphicalModel, context::AbstractDict) + labels = copy(model.labels) + factors = copy(model.factors) + graph = copy(model.graph) + vvll = copy(model.vvll) - for f in model.vvll[v] - model.factors[f] = observe(model.factors[f], v => hom) - end + for (l, hom) in context + v = labels.index[l] - u = length(model.labels) + for f in vvll[v] + factors[f] = reduce_to_context(factors[f], v => hom) + end - model.vvll[v] = model.vvll[u] + for f in vvll[end] + fac = factors[f] + factors[f] = Factor(fac.hom, fac.obs, replace(fac.vars, length(vvll) => v)) + end - for f in model.vvll[v] - fac = model.factors[f] - model.factors[f] = Factor(fac.hom, fac.obs, replace(fac.vars, u => v)) + delete!(labels, l) + Graphs.rem_vertex!(graph, v) + + vvll[v] = vvll[end] + pop!(vvll) end - delete!(model.labels, l) - Graphs.rem_vertex!(model.graph, v) - pop!(model.vvll) -end - - -function observe!(model::GraphicalModel, context::AbstractDict) - for (l, hom) in context - observe!(model, l => hom) - end + GraphicalModel(labels, factors, graph, vvll) end diff --git a/src/problems.jl b/src/problems.jl index 08b2aa3..167d6cd 100644 --- a/src/problems.jl +++ b/src/problems.jl @@ -13,35 +13,28 @@ end function InferenceProblem{T₁, T₂, T₃, T₄}( - uwd::AbstractUWD, + diagram::AbstractUWD, hom_map::AbstractDict, - ob_map::AbstractDict, - context::AbstractDict=Dict(); + ob_map::AbstractDict; hom_attr::Symbol=:name, ob_attr::Symbol=:junction_type, var_attr::Symbol=:variable) where {T₁, T₂, T₃, T₄} - homs = [hom_map[x] for x in subpart(uwd, hom_attr)] - obs = [ob_map[x] for x in subpart(uwd, ob_attr)] + labels = subpart(diagram, var_attr) + morphisms = map(x -> hom_map[x], subpart(diagram, hom_attr)) + objects = map(x -> ob_map[x], subpart(diagram, ob_attr)) - labels = subpart(uwd, var_attr) - - InferenceProblem{T₁, T₂, T₃, T₄}(uwd, homs, obs, context, labels) + InferenceProblem{T₁, T₂, T₃, T₄}(diagram, labels, morphisms, objects) end function InferenceProblem{T₁, T₂, T₃, T₄}( - uwd::AbstractUWD, - homs::AbstractVector, - obs::AbstractVector, - context::AbstractDict=Dict(), - labels::AbstractVector=junctions(uwd)) where {T₁, T₂, T₃, T₄} - - query = [ - labels[v] for v in subpart(uwd, :outer_junction) - if !haskey(context, labels[v])] + diagram::AbstractUWD, + labels::AbstractVector, + morphisms::AbstractVector, + objects::AbstractVector) where {T₁, T₂, T₃, T₄} - fg = @migrate UndirectedBipartiteGraph uwd begin + factor_graph = @migrate UndirectedBipartiteGraph diagram begin E => Port V₁ => Box V₂ => Junction @@ -50,18 +43,20 @@ function InferenceProblem{T₁, T₂, T₃, T₄}( tgt => junction end - model = GraphicalModel{T₁, T₂, T₃}(fg, homs, obs, labels) + model = GraphicalModel{T₁, T₂, T₃}(factor_graph, labels, morphisms, objects) + query = labels[junction(diagram, :; outer=true)] + context = Dict() InferenceProblem{T₁, T₂, T₃, T₄}(model, query, context) end function InferenceProblem{T₁, T₂, T₃, T₄}( - bn::BayesNets.BayesNet, + network::BayesNets.BayesNet, query::AbstractVector, context::AbstractDict=Dict()) where {T₁, T₂, T₃, T₄} - model = GraphicalModel{T₁, T₂, T₃}(bn) + model = GraphicalModel{T₁, T₂, T₃}(network) context = Dict(l => [v] for (l, v) in context) InferenceProblem{T₁, T₂, T₃, T₄}(model, query, context) @@ -70,10 +65,9 @@ end """ InferenceProblem( - uwd::RelationDiagram, + diagram::RelationDiagram, hom_map::AbstractDict, - ob_map::AbstractDict, - evidence::AbstractDict=Dict(); + ob_map::AbstractDict; hom_attr::Symbol=:name, ob_attr::Symbol=:junction_type, var_attr::Symbol=:variable) @@ -81,56 +75,63 @@ end Construct an inference problem that performs undirected compositon. """ InferenceProblem( - uwd::RelationDiagram, + diagram::RelationDiagram, hom_map::AbstractDict, - ob_map::AbstractDict, - evidence::AbstractDict=Dict(); + ob_map::AbstractDict; hom_attr::Symbol=:name, ob_attr::Symbol=:junction_type, var_attr::Symbol=:variable) function InferenceProblem( - uwd::Union{TypedRelationDiagram{<:Any, <:Any, T₁}, UntypedRelationDiagram{<:Any, T₁}}, + diagram::Union{TypedRelationDiagram{<:Any, <:Any, T₁}, UntypedRelationDiagram{<:Any, T₁}}, hom_map::AbstractDict{<:Any, T₂}, - ob_map::AbstractDict{<:Any, T₃}, - context::AbstractDict{<:Any, T₄}=Dict(); + ob_map::AbstractDict{<:Any, T₃}; hom_attr::Symbol=:name, ob_attr::Symbol=:junction_type, - var_attr::Symbol=:variable) where {T₁, T₂, T₃, T₄} + var_attr::Symbol=:variable) where {T₁, T₂, T₃} - InferenceProblem{T₁, T₂, T₃, T₄}(uwd, hom_map, ob_map, context; hom_attr, ob_attr, var_attr) + InferenceProblem{T₁, T₂, T₃, Union{}}(diagram, hom_map, ob_map; hom_attr, ob_attr, var_attr) end """ InferenceProblem( - bn::BayesNet, + network::BayesNet, query::AbstractVector, evidence::AbstractDict=Dict()) Construct an inference problem that queries a Bayesian network. """ function InferenceProblem( - bn::BayesNets.BayesNet, + network::BayesNets.BayesNet, query::AbstractVector, - context::AbstractDict=Dict()) + context::AbstractDict) - InferenceProblem{Symbol, DenseCanonicalForm{Float64}, Int, Vector{Float64}}(bn, query, context) + InferenceProblem{Symbol, DenseCanonicalForm{Float64}, Int, Vector{Float64}}(network, query, context) end """ solve( problem::InferenceProblem, - elalg::EliminationAlgorithm=MinFill() - stype::SupernodeType=Node() - atype::ArchitectureType=ShenoyShafer()) + elimination_algorithm::EliminationAlgorithm=MinFill() + supernode_type::SupernodeType=Node() + architecture_type::ArchitectureType=ShenoyShafer()) Solve an inference problem. """ CommonSolve.solve( problem::InferenceProblem, - elalg::EliminationAlgorithm=MinFill(), - stype::SupernodeType=Node(), - atype::ArchitectureType=ShenoyShafer()) + elimination_algorithm::EliminationAlgorithm=MinFill(), + supernode_type::SupernodeType=Node(), + architecture_type::ArchitectureType=ShenoyShafer()) + + +function reduce_to_context(problem::InferenceProblem, context::AbstractDict) + model = problem.model + query = filter(l -> !haskey(context, l), problem.query) + context = merge(problem.context, context) + + InferenceProblem(model, query, context) +end diff --git a/src/solvers.jl b/src/solvers.jl index d87a027..7c2d520 100644 --- a/src/solvers.jl +++ b/src/solvers.jl @@ -21,12 +21,11 @@ end # Construct a solver for an inference problem. function InferenceSolver( problem::InferenceProblem, - elalg::EliminationAlgorithm, - stype::SupernodeType, - atype::ArchitectureType) + elimination_algorithm::EliminationAlgorithm, + supernode_type::SupernodeType, + architecture_type::ArchitectureType) - model = copy(problem.model) - observe!(model, problem.context) + model = reduce_to_context(problem.model, problem.context) for i₁ in eachindex(problem.query), i₂ in 1:i₁ - 1 v₁ = model.labels.index[problem.query[i₁]] @@ -37,9 +36,7 @@ function InferenceSolver( end end - architecture = Architecture(model, elalg, stype) - architecture_type = atype - + architecture = Architecture(model, elimination_algorithm, supernode_type) InferenceSolver(architecture, architecture_type, problem.query) end @@ -47,19 +44,19 @@ end """ init( problem::InferenceProblem, - elalg::EliminationAlgorithm=MinFill(), - stype::SupernodeType=Node(), - atype::ArchitectureType=ShenoyShafer()) + elimination_algorithm::EliminationAlgorithm=MinFill(), + supernode_type::SupernodeType=Node(), + architecture_type::ArchitectureType=ShenoyShafer()) Construct a solver for an inference problem. """ function CommonSolve.init( problem::InferenceProblem, - elalg::EliminationAlgorithm=MinFill(), - stype::SupernodeType=Node(), - atype::ArchitectureType=ShenoyShafer()) + elimination_algorithm::EliminationAlgorithm=MinFill(), + supernode_type::SupernodeType=Node(), + architecture_type::ArchitectureType=ShenoyShafer()) - InferenceSolver(problem, elalg, stype, atype) + InferenceSolver(problem, elimination_algorithm, supernode_type, architecture_type) end diff --git a/src/systems.jl b/src/systems.jl index 8687c13..cc09cbf 100644 --- a/src/systems.jl +++ b/src/systems.jl @@ -458,7 +458,7 @@ function permute(Σ::GaussianSystem, i::AbstractVector) end -function observe( +function reduce_to_context( Σ::GaussianSystem, v::AbstractVector, i₁::AbstractVector, diff --git a/test/runtests.jl b/test/runtests.jl index 54be6e4..e14f1e9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -399,7 +399,7 @@ end -63.6 ] - uwd = @relation (x₂,) where (x₀::X, x₁::X, x₂::X, z₁::Z, z₂::Z) begin + diagram = @relation (x₂,) where (x₀::X, x₁::X, x₂::X, z₁::Z, z₂::Z) begin state(x₀) predict(x₀, x₁) predict(x₁, x₂) @@ -420,44 +420,41 @@ end :X => 6, :Z => 2) - Σ = oapply(uwd, hom_map, ob_map; ob_attr=:junction_type) + Σ = oapply(diagram, hom_map, ob_map; ob_attr=:junction_type) @test isapprox(true_cov, cov(Σ); atol=0.3) @test isapprox(true_mean, mean(Σ); atol=0.3) - problem = InferenceProblem(uwd, hom_map, ob_map; ob_attr=:junction_type) + problem = InferenceProblem(diagram, hom_map, ob_map) @test problem.query == [:x₂] - elalg = MinFill() - stype = Node() - atype = ShenoyShafer() + elimination_algorithm = MinFill() + supernode_type = Node() + architecture_type = ShenoyShafer() - solver = init(problem, elalg, stype, atype) + solver = init(problem, elimination_algorithm, supernode_type, architecture_type) Σ = solve!(solver) @test isapprox(true_cov, cov(Σ); atol=0.3) @test isapprox(true_mean, mean(Σ); atol=0.3) + x = rand(solver) x = mean(solver) @test isapprox(true_mean, x[:x₂]; atol=0.3) - Random.seed!(42) - x = rand(solver) - - elalg = MinDegree() - stype = MaximalSupernode() - atype = LauritzenSpiegelhalter() + elimination_algorithm = MinDegree() + supernode_type = MaximalSupernode() + architecture_type = LauritzenSpiegelhalter() problem.query = [] - solver = init(problem, elalg, stype, atype); solver.query = [:x₂] + solver = init(problem, elimination_algorithm, supernode_type, architecture_type) + solver.query = [:x₂] Σ = solve!(solver) @test isapprox(true_cov, cov(Σ); atol=0.3) @test isapprox(true_mean, mean(Σ); atol=0.3) + x = rand(solver) x = mean(solver) @test isapprox(true_mean, x[:x₂]; atol=0.3) - Random.seed!(42) - x = rand(solver) - solver.query = [:x₀, :x₁, :x₂, :z₁, :z₂] @test_throws ErrorException("Query not covered by join tree.") solve!(solver) end @@ -473,20 +470,18 @@ end true_var = 0.0094 true_mean = 50.934 - bn = BayesNet() - push!(bn, StaticCPD(:x₀, Normal(x₀, √p₀))) - push!(bn, LinearGaussianCPD(:x₁, [:x₀], [1], 0, √q)) - push!(bn, LinearGaussianCPD(:x₂, [:x₁], [1], 0, √q)) - push!(bn, LinearGaussianCPD(:z₁, [:x₁], [1], 0, √r)) - push!(bn, LinearGaussianCPD(:z₂, [:x₂], [1], 0, √r)) + network = BayesNet() + push!(network, StaticCPD(:x₀, Normal(x₀, √p₀))) + push!(network, LinearGaussianCPD(:x₁, [:x₀], [1], 0, √q)) + push!(network, LinearGaussianCPD(:x₂, [:x₁], [1], 0, √q)) + push!(network, LinearGaussianCPD(:z₁, [:x₁], [1], 0, √r)) + push!(network, LinearGaussianCPD(:z₂, [:x₂], [1], 0, √r)) query = [:x₂] context = Dict(:z₁ => 50.486, :z₂ => 50.963) - problem = InferenceProblem(bn, query, context) - solver = init(problem, MinFill()) - - Σ = solve!(solver) + problem = InferenceProblem(network, query, context) + Σ = solve(problem) @test isapprox(true_var, only(var(Σ)); atol=0.001) @test isapprox(true_mean, only(mean(Σ)); atol=0.001) end