diff --git a/Project.toml b/Project.toml index b7cbce2..4bdc28f 100644 --- a/Project.toml +++ b/Project.toml @@ -10,18 +10,22 @@ Catlab = "134e5e36-593f-5add-ad60-77f754baafbe" CuthillMcKee = "17f17636-5e38-52e3-a803-7ae3aaaf3da9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LinkedLists = "70f5e60a-1556-5f34-a19e-a48b3e4aaee9" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" Metis = "2679e427-3c69-5b7f-982b-ece356f1e94b" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +TreeWidthSolver = "7d267fc5-9ace-409f-a54c-cd2374872a55" [compat] -AbstractTrees = "0.4" AMD = "0.5" +AbstractTrees = "0.4" Catlab = "0.16" CuthillMcKee = "0.1" DataStructures = "0.18" +LinkedLists = "0.1" MLStyle = "0.4" Metis = "1" PartialFunctions = "1.1" +TreeWidthSolver = "0.3" julia = "1.7" diff --git a/src/Decompositions.jl b/src/Decompositions.jl index 4115cb2..ef6b6db 100644 --- a/src/Decompositions.jl +++ b/src/Decompositions.jl @@ -5,16 +5,23 @@ export StructuredDecomposition, StrDecomp, 𝐃, bags, adhesions, adhesionSpans, ∫ +using ..JunctionTrees +using ..JunctionTrees: EliminationAlgorithm, SupernodeType, DEFAULT_ELIMINATION_ALGORITHM, DEFAULT_SUPERNODE_TYPE + using PartialFunctions using MLStyle +using AbstractTrees +using Base.Threads using Catlab using Catlab.CategoricalAlgebra using Catlab.Graphs using Catlab.ACSetInterface using Catlab.CategoricalAlgebra.Diagrams + import Catlab.CategoricalAlgebra.Diagrams: ob_map, hom_map, colimit, limit + ##################### # DATA ##################### @@ -180,4 +187,162 @@ function 𝐃(f, d ::StructuredDecomposition, t::DecompType = d.decomp_type)::St StrDecomp(d.decomp_shape, Q, t) end + +################################## +# Integration with JunctionTrees # +################################## + + +""" + StrDecomp(graph::AbstractSymmetricGraph[, ealg::Union{Order, EliminationAlgorithm}[, stype::SupernodeType]]) + +Construct a structured decomposition of a simple graph, optionally specifying an elimination algorithm and +supernode type. +""" +function StrDecomp( + graph::AbstractSymmetricGraph, + ealg::Union{Order, EliminationAlgorithm}=DEFAULT_ELIMINATION_ALGORITHM, + stype::SupernodeType=DEFAULT_SUPERNODE_TYPE) + + merge_decompositions(decompositions(graph, ealg, stype)) +end + + +# Construct a tree decomposition. +# ---------------------------------------- +# graph simple connected graph +# jtree junction tree +# ---------------------------------------- +function StrDecomp(graph::AbstractSymmetricGraph, jtree::JunctionTree) + n = length(jtree) + tree = Graph(n) + + for i in 1:n - 1 + add_edge!(tree, i, parentindex(jtree, i)) + end + + diagram = FinDomFunctor(homomorphisms(graph, jtree)..., ∫(tree)) + StrDecomp(tree, diagram, Decomposition, dom(diagram)) +end + + +function merge_decompositions(decomposition::AbstractVector) + tree = apex(coproduct(map(d -> d.decomp_shape, decomposition))) + l = length(decomposition) + m = nv(tree) + subgraph = Vector(undef, 2m - l) + homomorphism = Vector(undef, 2m - 2l) + + i = 0 + + for j in 1:l + n = nv(decomposition[j].decomp_shape) + + for k in 1:n + subgraph[i + k] = ob_map(decomposition[j].diagram, k) + end + + for k in 1:n - 1 + subgraph[i - j + k + m + 1] = ob_map(decomposition[j].diagram, k + n) + end + + for k in 1:n - 1 + homomorphism[i - j + k + 1] = hom_map(decomposition[j].diagram, k) + end + + for k in 1:n - 1 + homomorphism[i - j + k - l + m + 1] = hom_map(decomposition[j].diagram, k + n - 1) + end + + i += n + end + + diagram = FinDomFunctor(subgraph, homomorphism, ∫(tree)) + StrDecomp(tree, diagram, Decomposition, dom(diagram)) +end + + +function decompositions(graph::AbstractSymmetricGraph, ealg::EliminationAlgorithm, stype::SupernodeType) + component = connected_components(graph) + + n = length(component) + decomposition = Vector(undef, n) + + @threads for i in 1:n + subgraph = induced_subgraph(graph, component[i]) + decomposition[i] = StrDecomp(subgraph, JunctionTree(subgraph, ealg, stype)) + end + + decomposition +end + + +function decompositions(graph::AbstractSymmetricGraph, order::Order, stype::SupernodeType) + component = connected_components(graph) + + n = length(component) + decomposition = Vector(undef, n) + + @threads for i in 1:n + subgraph = induced_subgraph(graph, component[i]) + decomposition[i] = StrDecomp(subgraph, JunctionTree(subgraph, induced_order(order, component[i]), stype)) + end + + decomposition +end + + +function homomorphisms(graph::AbstractSymmetricGraph, jtree::JunctionTree) + n = length(jtree) + subgraph = Vector{Any}(undef, 2n - 1) + homomorphism = Vector{Any}(undef, 2n - 2) + + for i in 1:n + # clique(i) + subgraph[i] = induced_subgraph(graph, clique(jtree, i)) + end + + for i in 1:n - 1 + # seperator(i) + subgraph[n + i] = induced_subgraph(graph, seperator(jtree, i)) + end + + for i in 1:n - 1 + # seperator(i) → clique(parent(i)) + j = parentindex(jtree, i) + homomorphism[i] = induced_homomorphism(subgraph[n + i], subgraph[j], seperator_to_parent(jtree, i)) + end + + for i in 1:n - 1 + # seperator(i) → clique(i) + homomorphism[n + i - 1] = induced_homomorphism(subgraph[n + i], subgraph[i], seperator_to_self(jtree, i)) + end + + subgraph, homomorphism +end + + +function induced_order(order::Order, elements::AbstractVector) + Order(sortperm(inverse(order, elements))) +end + + +function induced_homomorphism(domain::AbstractSymmetricGraph, codomain::AbstractSymmetricGraph, V::AbstractVector) + index = Dict{Tuple{Int, Int}, Int}() + sizehint!(index, ne(codomain)) + + for e in edges(codomain) + index[src(codomain, e), tgt(codomain, e)] = e + end + + E = Vector{Int}(undef, ne(domain)) + + for e in edges(domain) + E[e] = index[V[src(domain, e)], V[tgt(domain, e)]] + end + + ACSetTransformation(domain, codomain; V, E) +end + + end diff --git a/src/JunctionTrees.jl b/src/JunctionTrees.jl new file mode 100644 index 0000000..5437baa --- /dev/null +++ b/src/JunctionTrees.jl @@ -0,0 +1,43 @@ +module JunctionTrees + + +import AMD +import CuthillMcKee +import LinkedLists +import Metis +import TreeWidthSolver + +using AbstractTrees +using Catlab.BasicGraphs +using DataStructures +using SparseArrays + + +# Orders +export Order, inverse + + +# Elimination Algorithms +export AMDJL_AMD, CuthillMcKeeJL_RCM, MetisJL_ND, TreeWidthSolverJL_BT, MCS + + +# Supernode Types +export Node, Maximal, Fundamental + + +# Junction Trees +export JunctionTree, width, height, seperator, residual, clique, seperator_to_parent, seperator_to_self + + +include("junction_trees/orders.jl") +include("junction_trees/elimination_algorithms.jl") +include("junction_trees/ordered_graphs.jl") +include("junction_trees/trees.jl") +include("junction_trees/postorder_trees.jl") +include("junction_trees/elimination_trees.jl") +include("junction_trees/supernode_types.jl") +include("junction_trees/supernode_trees.jl") +include("junction_trees/junction_trees.jl") + + +end diff --git a/src/StructuredDecompositions.jl b/src/StructuredDecompositions.jl index 8c90cc6..2feb087 100644 --- a/src/StructuredDecompositions.jl +++ b/src/StructuredDecompositions.jl @@ -1,11 +1,10 @@ module StructuredDecompositions +include("JunctionTrees.jl") include("Decompositions.jl") include("FunctorUtils.jl") include("DecidingSheaves.jl") -include("junction_trees/JunctionTrees.jl") -include("nested_uwds/NestedUWDs.jl") end diff --git a/src/junction_trees/JunctionTrees.jl b/src/junction_trees/JunctionTrees.jl deleted file mode 100644 index 87ae3d3..0000000 --- a/src/junction_trees/JunctionTrees.jl +++ /dev/null @@ -1,83 +0,0 @@ -module JunctionTrees - - -import AMD -import CuthillMcKee -import Metis - -using AbstractTrees -using Catlab.BasicGraphs -using DataStructures -using SparseArrays - -# Elimination Algorithms -export EliminationAlgorithm, AMDJL_AMD, CuthillMcKeeJL_RCM, MetisJL_ND, MCS - -# Supernodes -export Supernode, Node, MaximalSupernode, FundamentalSupernode - -# Orders -export Order - -# Elimination Trees -export EliminationTree -export getwidth, getsupernode, getsubtree, getlevel - -# Junction Trees -export JunctionTree -export getseperator, getresidual - - -# Add an element x to a sorted set v. -# Returns true if x ∉ v. -# Returns false if x ∈ v. -function insertsorted!(v::Vector, x::Integer) - i = searchsortedfirst(v, x) - - if i > length(v) || v[i] != x - insert!(v, i, x) - true - else - false - end -end - - -# Delete an element x from a sorted set v. -# Returns true if x ∈ v. -# Returns false if x ∉ v. -function deletesorted!(v::Vector, x::Integer) - i = searchsortedfirst(v, x) - - if i <= length(v) && v[i] == x - deleteat!(v, i) - true - else - false - end -end - - -# Delete the elements xs from a sorted set v. -# Returns true if xs and v intersect. -# Returns false if xs and v are disjoint. -function deletesorted!(v::Vector, xs::AbstractVector) - isintersecting = true - - for x in xs - isintersecting = deletesorted!(v, x) || isintersecting - end - - isintersecting -end - - -include("elimination_algorithms.jl") -include("supernodes.jl") -include("orders.jl") -include("trees.jl") -include("elimination_trees.jl") -include("junction_trees.jl") - - -end diff --git a/src/junction_trees/elimination_algorithms.jl b/src/junction_trees/elimination_algorithms.jl index d302a74..71d5d5c 100644 --- a/src/junction_trees/elimination_algorithms.jl +++ b/src/junction_trees/elimination_algorithms.jl @@ -5,6 +5,7 @@ A graph elimination algorithm. The options are - [`CuthillMcKeeJL_RCM`](@ref) - [`AMDJL_AMD`](@ref) - [`MetisJL_ND`](@ref) +- [`TreeWidthSolverJL_BT`](@ref) - [`MCS`](@ref) """ abstract type EliminationAlgorithm end @@ -34,6 +35,14 @@ The nested dissection heuristic. Uses Metis.jl. struct MetisJL_ND <: EliminationAlgorithm end +""" + TreeWidthSolverJL_BT <: EliminationAlgorithm + +The Bouchitte-Todinca algorithm. Uses TreeWidthSolver.jl. +""" +struct TreeWidthSolverJL_BT <: EliminationAlgorithm end + + """ MCS <: EliminationAlgorithm @@ -42,4 +51,128 @@ The maximum cardinality search algorithm. struct MCS <: EliminationAlgorithm end +""" + Order(graph::AbstractSymmetricGraph[, ealg::EliminationAlgorithm]) + +Construct an elimination order for a simple graph, optionally specifying an elimination algorithm. +""" +function Order(graph::AbstractSymmetricGraph) + Order(graph, DEFAULT_ELIMINATION_ALGORITHM) +end + + +# Construct an order using the reverse Cuthill-McKee algorithm. Uses CuthillMcKee.jl. +function Order(graph::AbstractSymmetricGraph, ealg::CuthillMcKeeJL_RCM) + order = CuthillMcKee.symrcm(adjacencymatrix(graph)) + Order(order) +end + + +# Construct an order using the approximate minimum degree algorithm. Uses AMD.jl. +function Order(graph::AbstractSymmetricGraph, ealg::AMDJL_AMD) + order = AMD.symamd(adjacencymatrix(graph)) + Order(order) +end + + +# Construct an order using the nested dissection heuristic. Uses Metis.jl. +function Order(graph::AbstractSymmetricGraph, ealg::MetisJL_ND) + order, index = Metis.permutation(adjacencymatrix(graph)) + Order(order, index) +end + + +# Construct an order using the Bouchitte-Todinca algorithm. Uses TreeWidthSolver.jl. +function Order(graph::AbstractSymmetricGraph, ealg::TreeWidthSolverJL_BT) + n = nv(graph) + T = TreeWidthSolver.LongLongUInt{n ÷ 64 + 1} + fadjlist = Vector{Vector{Int}}(undef, n) + bitfadjlist = Vector{T}(undef, n) + + for i in 1:n + fadjlist[i] = sort(collect(outneighbors(graph, i))) + bitfadjlist[i] = TreeWidthSolver.bmask(T, fadjlist[i]) + end + + bitgraph = TreeWidthSolver.MaskedBitGraph(bitfadjlist, fadjlist, TreeWidthSolver.bmask(T, 1:n)) + decomposition = TreeWidthSolver.bt_algorithm(bitgraph, TreeWidthSolver.all_pmc_enmu(bitgraph, false), ones(n), false, true) + order = reverse(vcat(TreeWidthSolver.EliminationOrder(decomposition.tree).order...)) + Order(order) +end + + +# Construct an order using the maximum cardinality search algorithm. +function Order(graph::AbstractSymmetricGraph, ealg::MCS) + order, index = mcs(graph) + Order(order, index) +end + + +# Construct the adjacency matrix of a graph. +function adjacencymatrix(graph::AbstractSymmetricGraph) + m = ne(graph) + n = nv(graph) + + colptr = ones(Int, n + 1) + rowval = sizehint!(Vector{Int}(), 2m) + + for j in 1:n + ns = collect(neighbors(graph, j)) + sort!(ns) + colptr[j + 1] = colptr[j] + length(ns) + append!(rowval, ns) + end + + nzval = ones(Int, length(rowval)) + SparseMatrixCSC(n, n, colptr, rowval, nzval) +end + + +# Simple Linear-Time Algorithms to Test Chordality of Graphs, Test Acyclicity of Hypergraphs, and Selectively Reduce Acyclic Hypergraphs +# Tarjan and Yannakakis +# Maximum Cardinality Search +function mcs(graph::AbstractSymmetricGraph) + n = nv(graph) + α = Vector{Int}(undef, n) + β = Vector{Int}(undef, n) + size = Vector{Int}(undef, n) + set = Vector{LinkedLists.LinkedList{Int}}(undef, n) + pointer = Vector{LinkedLists.ListNode{Int}}(undef, n) + + for i in 1:n + size[i] = 1 + set[i] = LinkedLists.LinkedList{Int}() + pointer[i] = push!(set[1], i) + end + + i = n + j = 1 + + while i >= 1 + v = first(set[j]) + deleteat!(set[j], pointer[v]) + α[v] = i + β[i] = v + size[v] = 0 + + for w in neighbors(graph, v) + if size[w] >= 1 + deleteat!(set[size[w]], pointer[w]) + size[w] += 1 + pointer[w] = push!(set[size[w]], w) + end + end + + i -= 1 + j += 1 + + while j >= 1 && isempty(set[j]) + j -= 1 + end + end + + β, α +end + + const DEFAULT_ELIMINATION_ALGORITHM = AMDJL_AMD() diff --git a/src/junction_trees/elimination_trees.jl b/src/junction_trees/elimination_trees.jl index c5d57d9..7a1afc5 100644 --- a/src/junction_trees/elimination_trees.jl +++ b/src/junction_trees/elimination_trees.jl @@ -1,268 +1,99 @@ -# A supernodal elimination tree. -struct EliminationTree - order::Order - tree::Tree - firstsupernodelist::Vector{Int} - lastsupernodelist::Vector{Int} - subtreelist::Vector{Int} - width::Int +# An ordered graph (G, σ) equipped with the elimination tree T of its elimination graph. +# Nodes i in T correspond to vertices σ(i) in G. +struct EliminationTree{T <: Union{Tree, PostorderTree}} + tree::T # elimination tree + graph::OrderedGraph # ordered graph end +# Construct an elimination tree using an elimination algorithm. +# ---------------------------------------- +# graph simple connected graph +# ealg elimination algorithm +# ---------------------------------------- function EliminationTree( - order::Order, - tree::Tree, - supernodelist::AbstractVector, - subtreelist::AbstractVector, - width::Integer) - - n = length(order) - m = length(tree) - postorder = Order(n) - firstsupernodelist = Vector{Int}(undef, m) - lastsupernodelist = Vector{Int}(undef, m) - - i₂ = 0 - - for j in 1:m - supernode = supernodelist[j] - i₁ = i₂ + 1 - i₂ = i₂ + length(supernode) - firstsupernodelist[j] = i₁ - lastsupernodelist[j] = i₂ - postorder[i₁:i₂] .= supernode - end - - order = compose(postorder, order) - subtreelist = subtreelist[postorder] - - EliminationTree( - order, - tree, - firstsupernodelist, - lastsupernodelist, - subtreelist, - width) + graph::AbstractSymmetricGraph, + ealg::Union{Order, EliminationAlgorithm}=DEFAULT_ELIMINATION_ALGORITHM) + EliminationTree(OrderedGraph(graph, ealg)) end -# Construct a supernodal elimination tree. -# -# The complexity is -# 𝒪(m α(m, n) + n) -# where m = |E|, n = |V|, and α is the inverse Ackermann function. -function EliminationTree( - graph::AbstractSymmetricGraph, - order::Order, - supernode::Supernode=DEFAULT_SUPERNODE) - - etree = Tree(graph, order) - _, outdegreelist = getdegrees(graph, order, etree) - - supernodelist, subtreelist, parentlist = makestree( - etree, - outdegreelist, - supernode) - - n = nv(graph) - tree = Tree(subtreelist[n], parentlist) - postorder, tree = makepostorder(tree) - - supernodelist = supernodelist[postorder] - subtreelist = postorder.index[subtreelist] - width = maximum(outdegreelist) - - EliminationTree( - order, - tree, - supernodelist, - subtreelist, - width) -end - - -# Construct a supernodal elimination tree, first computing an elimination order. -function EliminationTree( - graph::AbstractSymmetricGraph, - algorithm::EliminationAlgorithm=DEFAULT_ELIMINATION_ALGORITHM, - supernode::Supernode=DEFAULT_SUPERNODE) - - order = Order(graph, algorithm) - EliminationTree(graph, order, supernode) +# Construct the elimination tree of an ordered graph. +# ---------------------------------------- +# graph ordered graph +# ---------------------------------------- +function EliminationTree(graph::OrderedGraph) + EliminationTree(Tree(etree(graph)), graph) end -# Get the number of nodes in a supernodal elimination tree. -function Base.length(stree::EliminationTree) - length(stree.tree) +# Postorder an elimination tree. +# ---------------------------------------- +# graph ordered graph +# order postorder +# ---------------------------------------- +function EliminationTree{PostorderTree}(etree::EliminationTree, order::Order) + EliminationTree(PostorderTree(etree.tree, order), OrderedGraph(etree.graph, order)) end -# Get the width of a supernodal elimination tree. -function getwidth(stree::EliminationTree) - stree.width +# An Efficient Algorithm to Compute Row and Column Counts for Sparse Cholesky Factorization +# Gilbert, Ng, and Peyton +# Figure 3: Implementation of algorithm to compute row and column counts. +function supcnt(etree::EliminationTree) + order = postorder(etree.tree) + index = inverse(order) + rc, cc = supcnt(EliminationTree{PostorderTree}(etree, order)) + rc[index], cc[index] end -# Get the supernode at node i. -function getsupernode(stree::EliminationTree, i::Integer) - i₁ = stree.firstsupernodelist[i] - i₂ = stree.lastsupernodelist[i] - stree.order[i₁:i₂] -end - - -# Get the highest node containing a vertex v. -function getsubtree(stree::EliminationTree, v::Integer) - stree.subtreelist[stree.order.index[v]] -end - - -# Get the highest node containing vertices vs. -function getsubtree(stree::EliminationTree, vs::AbstractVector) - init = length(stree.order) - stree.subtreelist[minimum(stree.order.index[vs]; init)] -end - - -# Get the level of node i. -function getlevel(stree::EliminationTree, i::Integer) - getlevel(stree.tree, i) -end - - -# Evaluate whether node i₁ is a descendant of node i₂. -function AbstractTrees.isdescendant(stree::EliminationTree, i₁::Integer, i₂::Integer) - isdescendant(stree.tree, i₁, i₂) -end - - -# Compute the supernodes, parent function, and first ancestor of a -# supernodal elimination tree. -# -# The complexity is -# 𝒪(n) -# where n = |V|. -# -# doi:10.1561/2400000006 -# Algorithm 4.1: Maximal supernodes and supernodal elimination tree. -function makestree(etree::Tree, outdegrees::AbstractVector, supernode::Supernode) - n = length(etree) - sbt = Vector{Int}(undef, n) - snd = Vector{Int}[] - q = Int[] - a = Int[] - - for v in 1:n - w′ = findchild(etree, outdegrees, v, supernode) - - if isnothing(w′) - i = length(snd) + 1 - sbt[v] = i - push!(snd, [v]) - push!(q, 0) - push!(a, 0) - else - i = sbt[w′] - sbt[v] = i - push!(snd[i], v) - end - - for w in childindices(etree, v) - if w !== w′ - j = sbt[w] - q[j] = i - a[j] = v - end - end - end - - snd, sbt, q, a -end - - -# Find a child w of v such that -# v ∈ snd(w). -# If no such child exists, return nothing. -function findchild(etree::Tree, outdegrees::AbstractVector, v::Integer, ::Supernode) end - - -function findchild(etree::Tree, outdegrees::AbstractVector, v::Integer, ::MaximalSupernode) - for w in childindices(etree, v) - if outdegrees[w] == outdegrees[v] + 1 - return w - end - end -end - - -function findchild(etree::Tree, outdegrees::AbstractVector, v::Integer, ::FundamentalSupernode) - ws = childindices(etree, v) - - if length(ws) == 1 - w = only(ws) - - if outdegrees[w] == outdegrees[v] + 1 - return w - end - end -end - - -# Compute the row and column counts of a graph's elimination graph. -# -# The complexity is -# 𝒪(m α(m, n)) -# where m = |E|, n = |V|, and α is the inverse Ackermann function. -# -# doi:10.1137/S089547989223692 -# Figure 3: Implementation of algorithm to compute row and column counts -function getdegrees(graph::AbstractSymmetricGraph, order::Order, etree::Tree) - n = nv(graph) - forest = IntDisjointSets(n) - rvert = Vector{Int}(undef, n) - index = Vector{Int}(undef, n) - rvert .= index .= 1:n - - function FIND(p) - index[find_root!(forest, p)] +# An Efficient Algorithm to Compute Row and Column Counts for Sparse Cholesky Factorization +# Gilbert, Ng, and Peyton +# Figure 3: Implementation of algorithm to compute row and column counts. +function supcnt(etree::EliminationTree{PostorderTree}) + n = length(etree.tree) + + #### Disjoint Set Union #### + + rvert = collect(1:n) + index = collect(1:n) + forest = IntDisjointSets(n) + + function find(u) + index[find_root!(forest, u)] end - - function UNION(u, v) + + function union(u, v) w = max(u, v) rvert[w] = root_union!(forest, rvert[u], rvert[v]) index[rvert[w]] = w end - postorder, etree = makepostorder(etree) - graph = Graph(graph, compose(postorder, order)) - prev_p = Vector{Int}(undef, n) - prev_nbr = Vector{Int}(undef, n) - rc = Vector{Int}(undef, n) - wt = Vector{Int}(undef, n) + ############################ + + prev_p = zeros(Int, n) + prev_nbr = zeros(Int, n) + rc = ones(Int, n) + wt = ones(Int, n) - for u in 1:n - prev_p[u] = 0 - prev_nbr[u] = 0 - rc[u] = 1 - wt[u] = isempty(childindices(etree, u)) + for u in 1:n - 1 + wt[parentindex(etree.tree, u)] = 0 end + + for p in 1:n - 1 + wt[parentindex(etree.tree, p)] -= 1 - for p in 1:n - if p != n - wt[parentindex(etree, p)] -= 1 - end - - for u in neighbors(graph, p) - if getfirstdescendant(etree, p) > prev_nbr[u] + for u in outneighbors(etree.graph, p) + if firstdescendant(etree.tree, p) > prev_nbr[u] wt[p] += 1 - p′ = prev_p[u] + pp = prev_p[u] - if p′ == 0 - rc[u] += getlevel(etree, p) - getlevel(etree, u) + if iszero(pp) + rc[u] += level(etree.tree, p) - level(etree.tree, u) else - q = FIND(p′) - rc[u] += getlevel(etree, p) - getlevel(etree, q) + q = find(pp) + rc[u] += level(etree.tree, p) - level(etree.tree, q) wt[q] -= 1 end @@ -272,48 +103,22 @@ function getdegrees(graph::AbstractSymmetricGraph, order::Order, etree::Tree) prev_nbr[u] = p end - if p != n - UNION(p, parentindex(etree, p)) - end + union(p, parentindex(etree.tree, p)) end cc = wt for v in 1:n - 1 - cc[parentindex(etree, v)] += cc[v] + cc[parentindex(etree.tree, v)] += cc[v] end - indegrees = rc[postorder.index] .- 1 - outdegrees = cc[postorder.index] .- 1 - indegrees, outdegrees -end - - -########################## -# Indexed Tree Interface # -########################## - - -function AbstractTrees.rootindex(stree::EliminationTree) - rootindex(stree.tree) -end - - -function AbstractTrees.parentindex(stree::EliminationTree, i::Integer) - parentindex(stree.tree, i) -end - - -function AbstractTrees.childindices(stree::EliminationTree, i::Integer) - childindices(stree.tree, i) -end - - -function AbstractTrees.NodeType(::Type{IndexNode{EliminationTree, Int}}) - HasNodeType() + rc, cc end -function AbstractTrees.nodetype(::Type{IndexNode{EliminationTree, Int}}) - IndexNode{EliminationTree, Int} +# Compute higher degree of every vertex in the elimination graph of +# (G, σ). +function outdegrees(etree::EliminationTree) + rc, cc = supcnt(etree) + cc .- 1 end diff --git a/src/junction_trees/junction_trees.jl b/src/junction_trees/junction_trees.jl index 87e22e0..39f525a 100644 --- a/src/junction_trees/junction_trees.jl +++ b/src/junction_trees/junction_trees.jl @@ -1,161 +1,136 @@ -# A junction tree. +""" + JunctionTree + +A junction tree. +""" struct JunctionTree - stree::EliminationTree - seperatorlist::Vector{Vector{Int}} + stree::SupernodeTree # supernodal elimination tree + seperator::Vector{Vector{Int}} # seperator end -# Construct a tree decomposition. -function JunctionTree(graph::AbstractSymmetricGraph, stree::EliminationTree) - graph = makeeliminationgraph(graph, stree) - - n = length(stree) - seperatorlist = Vector{Vector{Int}}(undef, n) - seperatorlist[n] = [] +""" + JunctionTree(graph::AbstractSymmetricGraph[, ealg::Union{Order, EliminationAlgorithm}[, stype::SupernodeType]]) - for i in 1:n - 1 - v₁ = stree.firstsupernodelist[i] - v₂ = stree.lastsupernodelist[i] - bag = collect(neighbors(graph, v₁)) - sort!(bag) - seperatorlist[i] = bag[v₂ - v₁ + 1:end] - end +Construct a tree decomposition of a connected simple graph, optionally specifying an elimination algorithm and +a supernode type. +""" +function JunctionTree( + graph::AbstractSymmetricGraph, + ealg::Union{Order, EliminationAlgorithm}=DEFAULT_ELIMINATION_ALGORITHM, + stype::SupernodeType=DEFAULT_SUPERNODE_TYPE) - JunctionTree(stree, seperatorlist) + JunctionTree(SupernodeTree(graph, ealg, stype)) end -# Reorient a juncton tree towards the given root. -function JunctionTree(root::Integer, jtree::JunctionTree) - m = length(jtree.stree.order) - n = length(jtree) - seperatorlist = Vector{Vector{Int}}(undef, n) - supernodelist = Vector{Vector{Int}}(undef, n) - subtreelist = Vector{Int}(undef, m) - - v₁ = jtree.stree.firstsupernodelist[root] - v₂ = jtree.stree.lastsupernodelist[root] - seperatorlist[n] = [] - supernodelist[n] = [v₁:v₂; jtree.seperatorlist[root]] - subtreelist[supernodelist[n]] .= n +# Construct a junction tree. +# ---------------------------------------- +# stree supernodal elimination tree +# ---------------------------------------- +function JunctionTree(stree::SupernodeTree) + JunctionTree(stree, map(sort ∘ collect, seperators(stree))) +end - tree = Tree(root, jtree.stree.tree) - postorder, tree = makepostorder(tree) - for i in 1:n - 1 - j = postorder[i] - v₁ = jtree.stree.firstsupernodelist[j] - v₂ = jtree.stree.lastsupernodelist[j] +""" + clique(jtree::JunctionTree, i::Integer) - if isdescendant(jtree, root, j) - seperatorlist[i] = jtree.seperatorlist[postorder[parentindex(tree, i)]] - supernodelist[i] = [v₁:v₂; jtree.seperatorlist[j]] - deletesorted!(supernodelist[i], seperatorlist[i]) - else - seperatorlist[i] = jtree.seperatorlist[j] - supernodelist[i] = v₁:v₂ - end - - subtreelist[supernodelist[i]] .= i - end +Get the clique at node i. +""" +function clique(jtree::JunctionTree, i::Integer) + [residual(jtree, i); seperator(jtree, i)] +end - order = jtree.stree.order - width = jtree.stree.width - stree = EliminationTree(order, tree, supernodelist, subtreelist, width) - for i in 1:n - seperatorlist[i] = stree.order.index[order[seperatorlist[i]]] - sort!(seperatorlist[i]) - end +""" + seperator(jtree::JunctionTree, i::Integer) - JunctionTree(stree, seperatorlist) +Get the seperator at node i. +""" +function seperator(jtree::JunctionTree, i::Integer) + permutation(jtree.stree.graph, jtree.seperator[i]) end -# Construct a tree decomposition, first computing an elimination order and a supernodal -# elimination tree. -function JunctionTree( - graph::AbstractSymmetricGraph, - algorithm::Union{Order, EliminationAlgorithm}=DEFAULT_ELIMINATION_ALGORITHM, - supernode::Supernode=DEFAULT_SUPERNODE) +""" + residual(jtree::JunctionTree, i::Integer) - stree = EliminationTree(graph, algorithm, supernode) - JunctionTree(graph, stree) +Get the residual at node i. +""" +function residual(jtree::JunctionTree, i::Integer) + permutation(jtree.stree.graph, supernode(jtree.stree, i)) end -# Get the number of nodes in a junction tree. -function Base.length(jtree::JunctionTree) - length(jtree.stree) -end +# Construct the inclusion seperator(i) → clique(parent(i)). +function seperator_to_parent(jtree::JunctionTree, i::Integer) + j = parentindex(jtree, i) + sep = jtree.seperator[i] + sep_parent = jtree.seperator[j] + res_parent = supernode(jtree.stree, j) + i = 0 + index = Vector{Int}(undef, length(sep)) + + for (j, v) in enumerate(sep) + if v in res_parent + index[j] = v - first(res_parent) + 1 + else + i += searchsortedfirst(view(sep_parent, i + 1:length(sep_parent)), v) + index[j] = length(res_parent) + i + end + end -# Get the width of a junction tree. -function getwidth(jtree::JunctionTree) - getwidth(jtree.stree) + index end -# Get the seperator at node i. -function getseperator(jtree::JunctionTree, i::Integer) - jtree.stree.order[jtree.seperatorlist[i]] +# Construct the inclusion seperator(i) → clique(i). +function seperator_to_self(jtree::JunctionTree, i::Integer) + sep = jtree.seperator[i] + res = supernode(jtree.stree, i) + length(res) + 1:length(res) + length(sep) end -# Get the residual at node i. -function getresidual(jtree::JunctionTree, i::Integer) - getsupernode(jtree.stree, i) -end - +""" + length(jtree::JunctionTree) -# Get the highest node containing the vertex v. -function getsubtree(jtree::JunctionTree, v::Union{Integer, AbstractVector}) - getsubtree(jtree.stree, v) +Get the number of nodes in a junction tree. +""" +function Base.length(jtree::JunctionTree) + length(jtree.stree.tree) end -# Get the level of node i. -function getlevel(jtree::JunctionTree, i::Integer) - getlevel(jtree.stree, i) -end +""" + height(jtree::JunctionTree) - -# Evaluate whether node i₁ is a descendant of node i₂. -function AbstractTrees.isdescendant(jtree::JunctionTree, i₁::Integer, i₂::Integer) - isdescendant(jtree.stree, i₁, i₂) +Compute the height of a junction tree. +""" +function height(jtree::JunctionTree) + height(jtree.stree.tree) end -# Construct an elimination graph. -function makeeliminationgraph(graph::AbstractSymmetricGraph, stree::EliminationTree) - n = length(stree) - graph = Graph(graph, stree.order) - - for i in 1:n - 1 - u₁ = stree.firstsupernodelist[i] - u₂ = stree.lastsupernodelist[i] +""" + width(jtree::JunctionTree) - for u in u₁:u₂ - 1 - v = u + 1 - - for w in neighbors(graph, u) - if v != w && !has_edge(graph, v, w) - add_edge!(graph, v, w) - end - end - end +Compute the width of a junction tree. +""" +function width(jtree::JunctionTree) + width(jtree.stree) +end - u = u₂ - v = stree.firstsupernodelist[parentindex(stree, i)] - for w in neighbors(graph, u) - if v != w && !has_edge(graph, v, w) - add_edge!(graph, v, w) - end - end +function Base.show(io::IO, jtree::JunctionTree) + n = width(jtree) + print(io, "width: $n\njunction tree:\n") + + print_tree(io, IndexNode(jtree)) do io, node + show(IOContext(io, :compact => true, :limit => true), clique(jtree, node.index)) end - - graph end @@ -165,17 +140,17 @@ end function AbstractTrees.rootindex(jtree::JunctionTree) - rootindex(jtree.stree) + rootindex(jtree.stree.tree) end function AbstractTrees.parentindex(jtree::JunctionTree, i::Integer) - parentindex(jtree.stree, i) + parentindex(jtree.stree.tree, i) end function AbstractTrees.childindices(jtree::JunctionTree, i::Integer) - childindices(jtree.stree, i) + childindices(jtree.stree.tree, i) end diff --git a/src/junction_trees/ordered_graphs.jl b/src/junction_trees/ordered_graphs.jl new file mode 100644 index 0000000..f06fa96 --- /dev/null +++ b/src/junction_trees/ordered_graphs.jl @@ -0,0 +1,162 @@ +# An ordered graph (G, σ). +struct OrderedGraph + graph::Graph # graph + order::Order # permutation +end + + +# Given a graph G, construct the ordered graph +# (G, σ), +# where the permutation σ is computed using an elimination algorithm. +# ---------------------------------------- +# sgraph simple connected graph +# ealg elimination algorithm +# ---------------------------------------- +function OrderedGraph(sgraph::AbstractSymmetricGraph, ealg::EliminationAlgorithm=DEFAULT_ELIMINATION_ALGORITHM) + OrderedGraph(sgraph, Order(sgraph, ealg)) +end + + +# Given a graph G and permutation σ, construct the ordered graph +# (G, σ). +# ---------------------------------------- +# sgraph simple connected graph +# order vertex order +# ---------------------------------------- +function OrderedGraph(sgraph::AbstractSymmetricGraph, order::Order) + n = nv(sgraph) + graph = Graph(n) + + for e in edges(sgraph) + u = src(sgraph, e) + v = tgt(sgraph, e) + + if order(u, v) + add_edge!(graph, inverse(order, u), inverse(order, v)) + end + end + + OrderedGraph(graph, order) +end + + +# Given an ordered graph (G, σ) and permutation μ, construct the ordered graph +# (G, σ ∘ μ). +# ---------------------------------------- +# ograph ordered graph +# order permutation +# ---------------------------------------- +function OrderedGraph(ograph::OrderedGraph, order::Order) + n = nv(ograph) + graph = Graph(n) + + for e in edges(ograph) + u = src(ograph, e) + v = tgt(ograph, e) + + if order(u, v) + add_edge!(graph, inverse(order, u), inverse(order, v)) + else + add_edge!(graph, inverse(order, v), inverse(order, u)) + end + end + + OrderedGraph(graph, compose(order, ograph.order)) +end + + +# A Compact Row Storage Scheme for Cholesky Factors Using Elimination Trees +# Liu +# Algorithm 4.2: Elimination Tree by Path Compression. +function etree(graph::OrderedGraph) + n = nv(graph) + parent = Vector{Int}(undef, n) + ancestor = Vector{Int}(undef, n) + + for i in 1:n + parent[i] = 0 + ancestor[i] = 0 + + for k in inneighbors(graph, i) + r = k + + while ancestor[r] != 0 && ancestor[r] != i + t = ancestor[r] + ancestor[r] = i + r = t + end + + if ancestor[r] == 0 + ancestor[r] = i + parent[r] = i + end + end + end + + parent[n] = n + parent +end + + +function Base.deepcopy(ograph::OrderedGraph) + order = deepcopy(ograph.order) + graph = deepcopy(ograph.graph) + OrderedGraph(graph, order) +end + + +# Get the vertex σ(i). +function permutation(ograph::OrderedGraph, i) + ograph.order[i] +end + + +# Get the index σ⁻¹(v). +function inverse(ograph::OrderedGraph, v) + inverse(ograph.order, v) +end + + +############################ +# Abstract Graph Interface # +############################ + + +function BasicGraphs.ne(ograph::OrderedGraph) + ne(ograph.graph) +end + + +function BasicGraphs.nv(ograph::OrderedGraph) + nv(ograph.graph) +end + + +function BasicGraphs.inneighbors(ograph::OrderedGraph, i) + inneighbors(ograph.graph, i) +end + + +function BasicGraphs.outneighbors(ograph::OrderedGraph, i) + outneighbors(ograph.graph, i) +end + + +function BasicGraphs.edges(ograph::OrderedGraph) + edges(ograph.graph) +end + + +function BasicGraphs.vertices(ograph::OrderedGraph) + vertices(ograph.graph) +end + + +function BasicGraphs.src(ograph::OrderedGraph, i) + src(ograph.graph, i) +end + + +function BasicGraphs.tgt(ograph::OrderedGraph, i) + tgt(ograph.graph, i) +end diff --git a/src/junction_trees/orders.jl b/src/junction_trees/orders.jl index 48f2a67..849f06e 100644 --- a/src/junction_trees/orders.jl +++ b/src/junction_trees/orders.jl @@ -1,170 +1,77 @@ -# A total ordering of the numbers {1, ..., n}. +""" + Order <: AbstractVector{Int} + +A permutation of the set ``\\{1, \\dots, n\\}.`` +""" struct Order <: AbstractVector{Int} - order::Vector{Int} - index::Vector{Int} + order::Vector{Int} # permutation + index::Vector{Int} # inverse permutation end -# Given a vector σ, construct the order ≺, where -# σ(i₁) ≺ σ(i₂) -# if -# i₁ < i₂. +""" + Order(order::AbstractVector) + +Construct a permutation ``\\sigma`` from a sequence ``(\\sigma(1), \\dots, \\sigma(n)).`` +""" function Order(order::AbstractVector) n = length(order) index = Vector{Int}(undef, n) for i in 1:n index[order[i]] = i - end - - Order(order, index) -end - - -# Construct an empty order of length n. -function Order(n::Integer) - order = Vector{Int}(undef, n) - index = Vector{Int}(undef, n) + end + Order(order, index) end -# Construct an elimination order using the reverse Cuthill-McKee algorithm. Uses -# CuthillMcKee.jl. -function Order(graph::AbstractSymmetricGraph, ::CuthillMcKeeJL_RCM) - order = CuthillMcKee.symrcm(adjacencymatrix(graph)) - Order(order) +# Determine if i < j, where +# u = σ(i) +# v = σ(j) +function (order::Order)(u, v) + inverse(order, u) < inverse(order, v) end -# Construct an elimination order using the approximate minimum degree algorithm. Uses -# AMD.jl. -function Order(graph::AbstractSymmetricGraph, ::AMDJL_AMD) - order = AMD.symamd(adjacencymatrix(graph)) - Order(order) +# Compose two permutations. +function compose(left::Order, right::Order) + Order(right.order[left.order], left.index[right.index]) end -# Construct an elimination order using the nested dissection heuristic. Uses Metis.jl. -function Order(graph::AbstractSymmetricGraph, ::MetisJL_ND) - order, index = Metis.permutation(adjacencymatrix(graph)) - Order(order, index) +# Construct the inverse permutation. +function inverse(order::Order) + Order(order.index, order.order) end -# Construct an elimination order using the maximum cardinality search algorithm. -function Order(graph::AbstractSymmetricGraph, ::MCS) - order, index = mcs(graph) - Order(order, index) +# Get the index σ⁻¹(v), +function inverse(order::Order, v) + order.index[v] end -# Compose as permutations. -function compose(order₁::Order, order₂::Order) - order = order₂.order[order₁.order] - index = order₁.index[order₂.index] - Order(order, index) -end - +############################# +# Abstract Vector Interface # +############################# -# Evaluate whether -# n₁ < n₂ -# in the given order. -function Base.isless(order::Order, n₁::Integer, n₂::Integer) - order.index[n₁] < order.index[n₂] -end - - -# Compute a vertex elimination order using the maximum cardinality search algorithm. -# -# The complexity is -# 𝒪(m + n), -# where m = |E| and n = |V|. -# -# https://doi.org/10.1137/0213035 -# Maximum cardinality search -function mcs(graph::AbstractSymmetricGraph) - n = nv(graph) - α = Vector{Int}(undef, n) - α⁻¹ = Vector{Int}(undef, n) - size = Vector{Int}(undef, n) - set = Vector{Vector{Int}}(undef, n) - - set .= [[]] - size .= 1 - append!(set[1], vertices(graph)) - - i = n - j = 1 - - while i >= 1 - v = pop!(set[j]) - α[v] = i - α⁻¹[i] = v - size[v] = 0 - - for w in neighbors(graph, v) - if size[w] >= 1 - deletesorted!(set[size[w]], w) - size[w] += 1 - insertsorted!(set[size[w]], w) - end - end - - i -= 1 - j += 1 - - while j >= 1 && isempty(set[j]) - j -= 1 - end - end - α⁻¹, α +function Base.getindex(order::Order, i) + order.order[i] end -# Construct the adjacency matrix of a graph. -function adjacencymatrix(graph::AbstractSymmetricGraph) - m = ne(graph) - n = nv(graph) - - colptr = ones(Int, n + 1) - rowval = sizehint!(Vector{Int}(), 2m) - - for j in 1:n - ns = collect(neighbors(graph, j)) - sort!(ns) - colptr[j + 1] = colptr[j] + length(ns) - append!(rowval, ns) - end - - nzval = ones(Int, length(rowval)) - SparseMatrixCSC(n, n, colptr, rowval, nzval) +function Base.IndexStyle(::Type{Order}) + IndexLinear() end -############################ -# AbstractVector Interface # -############################ - - function Base.size(order::Order) (length(order.order),) end -function Base.getindex(order::Order, i::Integer) - order.order[i] -end - - -function Base.setindex!(order::Order, v::Integer, i::Integer) - order.order[i] = v - order.index[v] = i - v -end - - -function Base.IndexStyle(::Type{Order}) - IndexLinear() +function Base.deepcopy(order::Order) + Order(copy(order.order), copy(order.index)) end diff --git a/src/junction_trees/postorder_trees.jl b/src/junction_trees/postorder_trees.jl new file mode 100644 index 0000000..23ed2d8 --- /dev/null +++ b/src/junction_trees/postorder_trees.jl @@ -0,0 +1,117 @@ +# A postordered rooted tree. +struct PostorderTree + parent::Vector{Int} # parent + children::Vector{Vector{Int}} # children + level::Vector{Int} # level + descendant::Vector{Int} # first descendant +end + + +# Construct a tree from a postordered list of parents. +# ---------------------------------------- +# parent list of parents +# ---------------------------------------- +function PostorderTree(parent::AbstractVector) + n = length(parent) + children = Vector{Vector{Int}}(undef, n) + level = Vector{Int}(undef, n) + descendant = Vector{Int}(undef, n) + + for i in 1:n + children[i] = [] + level[i] = 0 + descendant[i] = i + end + + for i in 1:n - 1 + j = parent[i] + push!(children[j], i) + descendant[j] = min(descendant[i], descendant[j]) + end + + for i in n - 1:-1:1 + j = parent[i] + level[i] = level[j] + 1 + end + + PostorderTree(parent, children, level, descendant) +end + + +# Postorder a tree. +# ---------------------------------------- +# tree tree +# order postorder +# ---------------------------------------- +function PostorderTree(tree::Tree, order::Order) + n = length(tree) + parent = collect(1:n) + + for i in 1:n - 1 + parent[i] = inverse(order, parentindex(tree, order[i])) + end + + PostorderTree(parent) +end + + +# The number of node in a tree. +function Base.length(tree::PostorderTree) + length(tree.parent) +end + + +# Get the level of a node i. +function level(tree::PostorderTree, i::Integer) + tree.level[i] +end + + +# Get the first descendant of a node i. +function firstdescendant(tree::PostorderTree, i::Integer) + tree.descendant[i] +end + + +# Determine whether the node i is a descendant of the node j. +function isdescendant(tree::PostorderTree, i::Integer, j::Integer) + getdescendant(tree, j) <= i < j +end + + +# Get the height of a tree. +function height(tree::PostorderTree) + maximum(tree.level) +end + + +########################## +# Indexed Tree Interface # +########################## + + +function AbstractTrees.parentindex(tree::PostorderTree, i::Integer) + if i != rootindex(tree) + tree.parent[i] + end +end + + +function AbstractTrees.childindices(tree::PostorderTree, i::Integer) + tree.children[i] +end + + +function AbstractTrees.rootindex(tree::PostorderTree) + length(tree) +end + + +function AbstractTrees.NodeType(::Type{IndexNode{PostorderTree, Int}}) + HasNodeType() +end + + +function AbstractTrees.nodetype(::Type{IndexNode{PostorderTree, Int}}) + IndexNode{PostorderTree, Int} +end diff --git a/src/junction_trees/supernode_trees.jl b/src/junction_trees/supernode_trees.jl new file mode 100644 index 0000000..cf7ff0e --- /dev/null +++ b/src/junction_trees/supernode_trees.jl @@ -0,0 +1,96 @@ +# An ordered graph (G, σ) equipped with a supernodal elimination tree T. +struct SupernodeTree + tree::PostorderTree # supernodal elimination tree + graph::OrderedGraph # ordered graph + representative::Vector{Int} # representative vertex + cardinality::Vector{Int} # supernode cardinality + ancestor::Vector{Int} # first ancestor + degree::Vector{Int} # higher degrees +end + + +# Construct a supernodal elimination tree using an elimination algorithm. +# ---------------------------------------- +# graph simple connected graph +# ealg elimination algorithm +# stype supernode type +# ---------------------------------------- +function SupernodeTree( + graph::AbstractSymmetricGraph, + ealg::Union{Order, EliminationAlgorithm}=DEFAULT_ELIMINATION_ALGORITHM, + stype::SupernodeType=DEFAULT_SUPERNODE_TYPE) + + SupernodeTree(EliminationTree(graph, ealg), stype) +end + + +# Construct a supernodal elimination tree. +# ---------------------------------------- +# etree elimination tree +# stype supernode type +# ---------------------------------------- +function SupernodeTree(etree::EliminationTree, stype::SupernodeType=DEFAULT_SUPERNODE_TYPE) + degree = outdegrees(etree) + supernode, parent, ancestor = stree(etree, degree, stype) + tree = Tree(parent) + + order = postorder(tree) + tree = PostorderTree(tree, order) + permute!(supernode, order) + permute!(ancestor, order) + + order = Order(vcat(supernode...)) + graph = OrderedGraph(etree.graph, order) + permute!(degree, order) + + representative = map(first, supernode) + cardinality = map(length, supernode) + map!(i -> inverse(order, i), ancestor, ancestor) + map!(i -> inverse(order, i), representative, representative) + + SupernodeTree(tree, graph, representative, cardinality, ancestor, degree) +end + + +# Compute the width of a supernodal elimination tree. +function width(stree::SupernodeTree) + maximum(stree.degree[stree.representative]) +end + + +# Get the (sorted) supernode at node i. +function supernode(stree::SupernodeTree, i::Integer) + v = stree.representative[i] + n = stree.cardinality[i] + v:v + n - 1 +end + + +# Compute the (unsorted) seperators of every node in T. +function seperators(stree::SupernodeTree) + n = length(stree.tree) + seperator = Vector{Set{Int}}(undef, n) + + for i in 1:n - 1 + seperator[i] = Set(stree.ancestor[i]) + + for v in outneighbors(stree.graph, stree.representative[i]) + if stree.ancestor[i] < v + push!(seperator[i], v) + end + end + end + + for j in 1:n - 1 + for i in childindices(stree.tree, j) + for v in seperator[i] + if stree.ancestor[j] < v + push!(seperator[j], v) + end + end + end + end + + seperator[n] = Set() + seperator +end diff --git a/src/junction_trees/supernode_types.jl b/src/junction_trees/supernode_types.jl new file mode 100644 index 0000000..84ced33 --- /dev/null +++ b/src/junction_trees/supernode_types.jl @@ -0,0 +1,107 @@ +""" + SupernodeType + +A type of supernode. The options are +- [`Node`](@ref) +- [`MaximalSupernode`](@ref) +- [`FundamentalSupernode`](@ref) +""" +abstract type SupernodeType end + + +""" + Node <: Supernode + +A single-vertex supernode. +""" +struct Node <: SupernodeType end + + +""" + MaximalSupernode <: Supernode + +A maximal supernode. +""" +struct Maximal <: SupernodeType end + + +""" + FundamentalSupernode <: Supernode + +A fundamental supernode. +""" +struct Fundamental <: SupernodeType end + + +# Compact Clique Tree Data Structures in Sparse Matrix Factorizations +# Pothen and Sun +# Figure 4: The Clique Tree Algorithm 2 +function stree(etree::EliminationTree, degree::AbstractVector, stype::SupernodeType) + n = length(etree.tree) + new_in_clique = Vector{Int}(undef, n) + new = Vector{Int}[] + parent = Int[] + first_anc = Int[] + + i = 0 + + for v in 1:n + u = child_in_supernode(etree, degree, stype, v) + + if !isnothing(u) + new_in_clique[v] = new_in_clique[u] + push!(new[new_in_clique[v]], v) + else + new_in_clique[v] = i += 1 + push!(new, [v]) + push!(parent, i) + push!(first_anc, n) + end + + for s in childindices(etree.tree, v) + if s !== u + parent[new_in_clique[s]] = new_in_clique[v] + first_anc[new_in_clique[s]] = v + end + end + end + + new, parent, first_anc +end + + +# Find a child w of v such that +# v ∈ snd(w). +# If no such child exists, return nothing. +function child_in_supernode(etree::EliminationTree, degree::AbstractVector, stype::Node, v::Integer) end + + +# Find a child w of v such that +# v ∈ snd(w). +# If no such child exists, return nothing. +function child_in_supernode(etree::EliminationTree, degree::AbstractVector, stype::Maximal, v::Integer) + for w in childindices(etree.tree, v) + if degree[w] == degree[v] + 1 + return w + end + end +end + + +# Find a child w of v such that +# v ∈ snd(w). +# If no such child exists, return nothing. +function child_in_supernode(etree::EliminationTree, degree::AbstractVector, stype::Fundamental, v::Integer) + ws = childindices(etree.tree, v) + + if length(ws) == 1 + w = only(ws) + + if degree[w] == degree[v] + 1 + return w + end + end +end + + +const DEFAULT_SUPERNODE_TYPE = Maximal() diff --git a/src/junction_trees/supernodes.jl b/src/junction_trees/supernodes.jl deleted file mode 100644 index af940bb..0000000 --- a/src/junction_trees/supernodes.jl +++ /dev/null @@ -1,36 +0,0 @@ -""" - Supernode - -A type of supernode. The options are -- [`Node`](@ref) -- [`MaximalSupernode`](@ref) -- [`FundamentalSupernode`](@ref) -""" -abstract type Supernode end - - -""" - Node <: Supernode - -A single-vertex supernode. -""" -struct Node <: Supernode end - - -""" - MaximalSupernode <: Supernode - -A maximal supernode. -""" -struct MaximalSupernode <: Supernode end - - -""" - FundamentalSupernode <: Supernode - -A fundamental supernode. -""" -struct FundamentalSupernode <: Supernode end - - -const DEFAULT_SUPERNODE = MaximalSupernode() diff --git a/src/junction_trees/trees.jl b/src/junction_trees/trees.jl index ab8655f..759d19f 100644 --- a/src/junction_trees/trees.jl +++ b/src/junction_trees/trees.jl @@ -1,202 +1,56 @@ # A rooted tree. struct Tree - root::Int - parentlist::Vector{Int} - childrenlist::Vector{Vector{Int}} - levellist::Vector{Int} - firstdescendantlist::Vector{Int} -end - - -# Orient a tree towards the given root. -function Tree(root::Integer, tree::Tree) - i = root - parent = parentindex(tree, i) - parentlist = copy(tree.parentlist) - childrenlist = deepcopy(tree.childrenlist) - - while !isnothing(parent) - parentlist[parent] = i - push!(childrenlist[i], parent) - deletesorted!(childrenlist[parent], i) - i = parent - parent = parentindex(tree, i) - end - - Tree(root, parentlist, childrenlist) -end - - -# Construct a tree from a list of parent and a list of children. -function Tree(root::Integer, parentlist::AbstractVector, childrenlist::AbstractVector) - n = length(parentlist) - levellist = Vector{Int}(undef, n) - firstdescendantlist = Vector{Int}(undef, n) - Tree(root, parentlist, childrenlist, levellist, firstdescendantlist) + root::Int # root + parent::Vector{Int} # parent + children::Vector{Vector{Int}} # children end # Construct a tree from a list of parents. -function Tree(root::Integer, parentlist::AbstractVector) - n = length(parentlist) - childrenlist = Vector{Vector{Int}}(undef, n) - childrenlist .= [[]] - +# ---------------------------------------- +# parent list of parents +# ---------------------------------------- +function Tree(parent::AbstractVector) + n = root = length(parent) + children = Vector{Vector{Int}}(undef, n) + for i in 1:n - if i != root - push!(childrenlist[parentlist[i]], i) - end + children[i] = [] end - Tree(root, parentlist, childrenlist) -end - - -# Construct an elimination tree. -function Tree(graph::AbstractSymmetricGraph, order::Order) - n = nv(graph) - parentlist = makeetree(graph, order) - @assert count(parentlist .== 0) == 1 - Tree(n, parentlist) -end - - -function Base.length(tree::Tree) - length(tree.parentlist) -end - - -# Compute the parent vector of the elimination tree of the elimination graph of a ordered -# graph. -# -# The complexity is -# 𝒪(m log(n)) -# where m = |E| and n = |V|. -# -# doi:10.1145/6497.6499 -# Algorithm 4.2: Elimination Tree by Path Compression -function makeetree(graph::AbstractSymmetricGraph, order::Order) - graph = Graph(graph, order) - - n = nv(graph) - parent = Vector{Int}(undef, n) - ancestor = Vector{Int}(undef, n) - for i in 1:n - parent[i] = 0 - ancestor[i] = 0 - - for k in inneighbors(graph, i) - r = k - - while ancestor[r] != 0 && ancestor[r] != i - t = ancestor[r] - ancestor[r] = i - r = t - end - - if ancestor[r] == 0 - ancestor[r] = i - parent[r] = i - end + j = parent[i] + + if i == j + root = i + else + push!(children[j], i) end end - parent -end - - -# Given an ordered graph -# (G, σ), -# construct a directed graph by ordering the edges in G from lower to higher index. -# -# The complexity is -# 𝒪(m) -# where m = |E|. -function BasicGraphs.Graph(graph::AbstractSymmetricGraph, order::Order) - n = nv(graph) - digraph = Graph(n) - - for v in vertices(graph) - i = order.index[v] - - for w in neighbors(graph, v) - j = order.index[w] - - if i < j - add_edge!(digraph, i, j) - end - end - end - - digraph -end - - -############## -# Postorders # -############## - - -# Get the level of node i. -# This function only works on postordered trees. -function getlevel(tree::Tree, i::Integer) - tree.levellist[i] + Tree(root, parent, children) end -# Get the first descendant of node i. -# This function only works on postordered trees. -function getfirstdescendant(tree::Tree, i::Integer) - tree.firstdescendantlist[i] -end - - -# Evaluate whether node i₁ is a descendant of node i₂. -# This function only works on postordered trees. -function AbstractTrees.isdescendant(tree::Tree, i₁::Integer, i₂::Integer) - getfirstdescendant(tree, i₂) <= i₁ < i₂ +# Get the number of nodes in a tree. +function Base.length(tree::Tree) + length(tree.parent) end -# Compute a postordering of a tree. -# -# The complexity is -# 𝒪(n) -# where n = |V|. -function makepostorder(tree::Tree) +# Compute a postordering of tree's vertices. +function postorder(tree::Tree) n = length(tree) - order = Order(n) - parentlist = Vector{Int}(undef, n) - childrenlist = Vector{Vector{Int}}(undef, n) - levellist = Vector{Int}(undef, n) - firstdescendantlist = Vector{Int}(undef, n) - - root, nodes... = PreOrderDFS(IndexNode(tree)) - - order[n] = root.index - parentlist[n] = 0 - childrenlist[n] = [] - levellist[n] = 0 - - for (i, node) in enumerate(nodes) - j = n - i - order[j] = node.index - - k = order.index[parentindex(tree, node.index)] - parentlist[j] = k - childrenlist[j] = [] - pushfirst!(childrenlist[k], j) - levellist[j] = 1 + levellist[k] - end - - for i in 1:n - init = i - firstdescendantlist[i] = minimum(firstdescendantlist[childrenlist[i]]; init) + order = Vector{Int}(undef, n) + index = Vector{Int}(undef, n) + + for node in PreOrderDFS(IndexNode(tree)) + order[n] = node.index + index[node.index] = n + n -= 1 end - - tree = Tree(n, parentlist, childrenlist, levellist, firstdescendantlist) - order, tree + + Order(order, index) end @@ -212,13 +66,13 @@ end function AbstractTrees.parentindex(tree::Tree, i::Integer) if i != rootindex(tree) - tree.parentlist[i] + tree.parent[i] end end function AbstractTrees.childindices(tree::Tree, i::Integer) - tree.childrenlist[i] + tree.children[i] end diff --git a/src/nested_uwds/NestedUWDs.jl b/src/nested_uwds/NestedUWDs.jl deleted file mode 100644 index 5de7eca..0000000 --- a/src/nested_uwds/NestedUWDs.jl +++ /dev/null @@ -1,34 +0,0 @@ -module NestedUWDs - - -using AbstractTrees -using Catlab.ACSetInterface -using Catlab.BasicGraphs -using Catlab.DirectedWiringDiagrams -using Catlab.DirectedWiringDiagrams: WiringDiagramACSet -using Catlab.MonoidalUndirectedWiringDiagrams -using Catlab.MonoidalUndirectedWiringDiagrams: UntypedHypergraphDiagram -using Catlab.RelationalPrograms -using Catlab.RelationalPrograms: TypedUnnamedRelationDiagram -using Catlab.Theories -using Catlab.UndirectedWiringDiagrams -using Catlab.WiringDiagramAlgebras - -using ..JunctionTrees -using ..JunctionTrees: insertsorted!, DEFAULT_ELIMINATION_ALGORITHM, DEFAULT_SUPERNODE - -# Elimination Algorithms -export EliminationAlgorithm, AMDJL_AMD, CuthillMcKeeJL_RCM, MetisJL_ND, MCS - -# Supernodes -export Supernode, Node, MaximalSupernode, FundamentalSupernode - -# Nested UWDs -export NestedUWD -export evalschedule, makeschedule, makeoperations - - -include("nested_uwds.jl") - - -end diff --git a/src/nested_uwds/nested_uwds.jl b/src/nested_uwds/nested_uwds.jl deleted file mode 100644 index 54dd648..0000000 --- a/src/nested_uwds/nested_uwds.jl +++ /dev/null @@ -1,312 +0,0 @@ -""" - NestedUWD{T, B, V} - -An undirected wiring diagram, represented as a nested collected of undirected wiring -diagrams. -""" -struct NestedUWD{T, B, V} - diagram::TypedUnnamedRelationDiagram{T, B, V} - jtree::JunctionTree - assignmentlist::Vector{Int} - assignmentindex::Vector{Vector{Int}} -end - - -function NestedUWD( - diagram::D, - jtree::JunctionTree, - assignmentlist::AbstractVector, - assignmentindex::AbstractVector) where D <: UndirectedWiringDiagram - - T, B, V = getattributetypes(D) - relation = TypedUnnamedRelationDiagram{T, B, V}() - copy_parts!(relation, diagram) - NestedUWD{T, B, V}(relation, jtree, assignmentlist, assignmentindex) -end - - -function NestedUWD(diagram::UndirectedWiringDiagram, jtree::JunctionTree) - n = nparts(diagram, :Box) - m = length(jtree) - assignmentlist = Vector{Int}(undef, n) - assignmentindex = Vector{Vector{Int}}(undef, m) - assignmentindex .= [[]] - - for b in 1:n - i = getsubtree(jtree, diagram[incident(diagram, b, :box), :junction]) - assignmentlist[b] = i - push!(assignmentindex[i], b) - end - - NestedUWD(diagram, jtree, assignmentlist, assignmentindex) -end - - -""" - NestedUWD( - diagram::UndirectedWiringDiagram, - [, algorithm::Union{Order, EliminationAlgorithm}] - [, supernode::Supernode]) - -Construct a nested undirected wiring diagram. -""" -function NestedUWD( - diagram::UndirectedWiringDiagram, - algorithm::Union{Order, EliminationAlgorithm}=DEFAULT_ELIMINATION_ALGORITHM, - supernode::Supernode=DEFAULT_SUPERNODE) - - jtree = JunctionTree(diagram, algorithm, supernode) - NestedUWD(diagram, jtree) -end - - -# Construct a tree decomposition of the line graph of an undirected wiring diagram. -function JunctionTree( - diagram::UndirectedWiringDiagram, - algorithm::Union{Order, EliminationAlgorithm}=DEFAULT_ELIMINATION_ALGORITHM, - supernode::Supernode=DEFAULT_SUPERNODE) - - graph = makegraph(diagram) - jtree = JunctionTree(graph, algorithm, supernode) - - query = diagram[:outer_junction] - JunctionTree(getsubtree(jtree, query), jtree) -end - - -# Construct the line graph of an undirected wiring diagram. -function makegraph(diagram::UndirectedWiringDiagram) - n = nparts(diagram, :Junction) - m = nparts(diagram, :Box) - graph = SymmetricGraph(n) - - for i in 1:m - junctions = diagram[incident(diagram, i, :box), :junction] - l = length(junctions) - - for i₁ in 1:l - 1 - j₁ = junctions[i₁] - - for i₂ in i₁ + 1:l - j₂ = junctions[i₂] - - if !has_edge(graph, j₁, j₂) - add_edge!(graph, j₁, j₂) - end - end - end - end - - junctions = diagram[:, :outer_junction] - l = length(junctions) - - for i₁ in 1:l - 1 - j₁ = junctions[i₁] - - for i₂ in i₁ + 1:l - j₂ = junctions[i₂] - - if !has_edge(graph, j₁, j₂) - add_edge!(graph, j₁, j₂) - end - end - end - - graph -end - - -""" - makeschedule(nuwd::NestedUWD) - -Construct a directed wiring diagram that represents the nesting structure of a nested UWD. -""" -function makeschedule(nuwd::NestedUWD{<:Any, T}) where T - m = length(nuwd.assignmentlist) - n = length(nuwd.jtree) - - parents = map(1:n - 1) do i - parentindex(nuwd.jtree, i) - end - - costs = map(1:n) do i - length(getresidual(nuwd.jtree, i)) + length(getseperator(nuwd.jtree, i)) - end - - schedule = WiringDiagramACSet{T, Nothing, Union{Int, AbstractBox}, DataType}() - - add_parts!(schedule, :Box, n) - add_parts!(schedule, :Wire, n - 1) - add_parts!(schedule, :InPort, m + n - 1) - add_parts!(schedule, :InWire, m) - add_parts!(schedule, :OutPort, n) - add_parts!(schedule, :OutWire, 1) - add_parts!(schedule, :OuterInPort, m) - add_parts!(schedule, :OuterOutPort, 1) - - schedule[:, :src] = 1:n - 1 - schedule[:, :tgt] = m + 1:m + n - 1 - schedule[:, :in_src] = 1:m - schedule[:, :in_tgt] = 1:m - schedule[:, :out_src] = n:n - schedule[:, :out_tgt] = 1:1 - schedule[:, :in_port_box] = [nuwd.assignmentlist; parents] - schedule[:, :out_port_box] = 1:n - - schedule[:, :value] = costs - schedule[:, :box_type] = Box{Int} - schedule[:, :outer_in_port_type] = nuwd.diagram[:, :name] - - Theory = ThSymmetricMonoidalCategory.Meta.T - WiringDiagram{Theory, T, Nothing, Int}(schedule, nothing) -end - - -""" - function evalschedule( - f, - nuwd::NestedUWD, - generators::Union{AbstractVector, AbstractDict} - [, operations::AbstractVector]) - -Evaluate an undirected wiring diagrams given a set of generators for the boxes. The -optional first argument `f` should be callable with the signature -``` - f(diagram, generators) -``` -where `diagram` is an undirected wiring diagram, and `generators` is a vector. If `f` is not -specified, then it defaults to `oapply`. -""" -function evalschedule( - f, - nuwd::NestedUWD, - generators::AbstractVector{T}, - operations::AbstractVector=makeoperations(nuwd)) where T - - n = length(nuwd.jtree) - mailboxes = Vector{T}(undef, n) - - for i in 1:n - g₁ = generators[nuwd.assignmentindex[i]] - g₂ = mailboxes[childindices(nuwd.jtree, i)] - mailboxes[i] = f(operations[i], [g₁; g₂]) - end - - mailboxes[n] -end - - -function evalschedule( - f, - nuwd::NestedUWD, - generators::AbstractDict{<:Any, T}, - operations::AbstractVector=makeoperations(nuwd)) where T - - g = generators - n = nparts(nuwd.diagram, :Box) - generators = Vector{T}(undef, n) - - for i in 1:n - generators[i] = g[nuwd.diagram[i, :name]] - end - - evalschedule(f, nuwd, generators, operations) -end - - -function evalschedule( - nuwd::NestedUWD, - generators::Union{AbstractVector, AbstractDict}, - operations::AbstractVector=makeoperations(nuwd)) - - evalschedule(oapply, nuwd, generators, operations) -end - - -# For each node i of a nested UWD, construct the undirected wiring diagram corresponding to i. -function makeoperations(nuwd::NestedUWD) - m = length(nuwd.jtree) - - map(1:m) do i - makeoperation(nuwd, i) - end -end - - -# Construct the undirected wiring diagram corresponding to node i of a nested UWD. -function makeoperation(nuwd::NestedUWD{T, B, V}, i::Integer) where {T, B, V} - function findjunction(j::Integer) - v = nuwd.jtree.stree.order.index[j] - v₁ = nuwd.jtree.stree.firstsupernodelist[i] - v₂ = nuwd.jtree.stree.lastsupernodelist[i] - - if v <= v₂ - v - v₁ + 1 - else - v₂ - v₁ + 1 + searchsortedfirst(nuwd.jtree.seperatorlist[i], v) - end - end - - residual = getresidual(nuwd.jtree, i) - seperator = getseperator(nuwd.jtree, i) - m = length(residual) - n = length(seperator) - - operation = TypedUnnamedRelationDiagram{T, B, V}() - add_parts!(operation, :Junction, m + n) - - operation[1:m, :junction_type] = nuwd.diagram[residual, :junction_type] - operation[1:m, :variable] = nuwd.diagram[residual, :variable] - operation[m + 1:m + n, :junction_type] = nuwd.diagram[seperator, :junction_type] - operation[m + 1:m + n, :variable] = nuwd.diagram[seperator, :variable] - - if i < length(nuwd.jtree) - for j in seperator - p′ = add_part!(operation, :OuterPort) - operation[p′, :outer_junction] = m + p′ - operation[p′, :outer_port_type] = nuwd.diagram[j, :junction_type] - end - else - for j in nuwd.diagram[:outer_junction] - p′ = add_part!(operation, :OuterPort) - operation[p′, :outer_junction] = findjunction(j) - operation[p′, :outer_port_type] = nuwd.diagram[j, :junction_type] - end - end - - for b in nuwd.assignmentindex[i] - b′ = add_part!(operation, :Box) - operation[b′, :name] = nuwd.diagram[b, :name] - - for j in nuwd.diagram[incident(nuwd.diagram, b, :box), :junction] - p′ = add_part!(operation, :Port) - operation[p′, :box] = b′ - operation[p′, :junction] = findjunction(j) - operation[p′, :port_type] = nuwd.diagram[j, :junction_type] - end - end - - for b in childindices(nuwd.jtree, i) - b′ = add_part!(operation, :Box) - - for j in getseperator(nuwd.jtree, b) - p′ = add_part!(operation, :Port) - operation[p′, :box] = b′ - operation[p′, :junction] = findjunction(j) - operation[p′, :port_type] = nuwd.diagram[j, :junction_type] - end - end - - operation -end - - -# Get the attribute types of an undirected wiring diagram. -function getattributetypes(::Type{<:UntypedRelationDiagram{B, V}}) where {B, V} - Nothing, B, V -end - - -function getattributetypes(::Type{<:TypedRelationDiagram{T, B, V}}) where {T, B, V} - T, B, V -end diff --git a/test/Decompositions.jl b/test/Decompositions.jl index 4761456..c9cb286 100644 --- a/test/Decompositions.jl +++ b/test/Decompositions.jl @@ -5,13 +5,14 @@ using PartialFunctions using StructuredDecompositions.Decompositions using StructuredDecompositions.FunctorUtils +using StructuredDecompositions.JunctionTrees: Order, Maximal using Catlab.Graphics using Catlab.Graphs using Catlab.ACSetInterface using Catlab.CategoricalAlgebra -#using Catlab.Graphi +#using Catlab.Graph #Define the instance####################### #bag 1 @@ -99,4 +100,61 @@ bigdecomp_skeleton = 𝐃ₛ(bigdecomp_to_sets) adhesionSpans(bigdecomp_skeleton) ) -end \ No newline at end of file + +################################## +# Integration with JunctionTrees # +################################## + + +graph = SymmetricGraph(17) + +add_edges!(graph, + [1, 1, 1, 1, 2, 2, 5, 5, 6, 6, 7, 7, 7, 10, 10, 10, 10, 12, 12, 12, 12, 15], + [3, 4, 5, 15, 3, 4, 9, 16, 9, 16, 8, 9, 15, 11, 13, 14, 17, 13, 14, 16, 17, 17]) + +decomposition = StrDecomp(graph, Order(1:17), Maximal()) + +@test decomposition.decomp_shape == @acset Graph begin + V = 8 + E = 7 + src = [1, 2, 3, 4, 5, 6, 7] + tgt = [5, 5, 4, 5, 6, 8, 8] +end + +@test map(i -> ob_map(decomposition.diagram, i), 1:15) == [ + induced_subgraph(graph, [7, 8, 9, 15]), # g h i o + induced_subgraph(graph, [6, 9, 16]), # f i p + induced_subgraph(graph, [2, 3, 4]), # b c d + induced_subgraph(graph, [1, 3, 4, 5, 15]), # a c d e o + induced_subgraph(graph, [5, 9, 15, 16]), # e i o p + induced_subgraph(graph, [15, 16, 17]), # o p q + induced_subgraph(graph, [10, 11, 13, 14, 17]), # j k m n q + induced_subgraph(graph, [12, 13, 14, 16, 17]), # l m n p q + induced_subgraph(graph, [9, 15]), # i o + induced_subgraph(graph, [9, 16]), # i p + induced_subgraph(graph, [3, 4]), # c d + induced_subgraph(graph, [5, 15]), # e o + induced_subgraph(graph, [15, 16]), # o p + induced_subgraph(graph, [16, 17]), # p q + induced_subgraph(graph, [13, 14, 17]), # m n q +] + +@test map(i -> hom_map(decomposition.diagram, i), 1:14) == [ + ACSetTransformation(induced_subgraph(graph, [9, 15]), induced_subgraph(graph, [5, 9, 15, 16]), V=[2, 3], E=Int[]), # i o → e i o p + ACSetTransformation(induced_subgraph(graph, [9, 16]), induced_subgraph(graph, [5, 9, 15, 16]), V=[2, 4], E=Int[]), # i p → e i o p + ACSetTransformation(induced_subgraph(graph, [3, 4]), induced_subgraph(graph, [1, 3, 4, 5, 15]), V=[2, 3], E=Int[]), # c d → a c d e o + ACSetTransformation(induced_subgraph(graph, [5, 15]), induced_subgraph(graph, [5, 9, 15, 16]), V=[1, 3], E=Int[]), # e o → e i o p + ACSetTransformation(induced_subgraph(graph, [15, 16]), induced_subgraph(graph, [15, 16, 17]), V=[1, 2], E=Int[]), # o p → o p q + ACSetTransformation(induced_subgraph(graph, [16, 17]), induced_subgraph(graph, [12, 13, 14, 16, 17]), V=[4, 5], E=Int[]), # p q → l m n p q + ACSetTransformation(induced_subgraph(graph, [13, 14, 17]), induced_subgraph(graph, [12, 13, 14, 16, 17]), V=[2, 3, 5], E=Int[]), # m n q → l m n p q + ACSetTransformation(induced_subgraph(graph, [9, 15]), induced_subgraph(graph, [7, 8, 9, 15]), V=[3, 4], E=Int[]), # i o → g h i o + ACSetTransformation(induced_subgraph(graph, [9, 16]), induced_subgraph(graph, [6, 9, 16]), V=[2, 3], E=Int[]), # i p → f i p + ACSetTransformation(induced_subgraph(graph, [3, 4]), induced_subgraph(graph, [2, 3, 4]), V=[2, 3], E=Int[]), # c d → b c d + ACSetTransformation(induced_subgraph(graph, [5, 15]), induced_subgraph(graph, [1, 3, 4, 5, 15]), V=[4, 5], E=Int[]), # e o → a c d e o + ACSetTransformation(induced_subgraph(graph, [15, 16]), induced_subgraph(graph, [5, 9, 15, 16]), V=[3, 4], E=Int[]), # o p → e i o p + ACSetTransformation(induced_subgraph(graph, [16, 17]), induced_subgraph(graph, [15, 16, 17]), V=[2, 3], E=Int[]), # p q → o p q + ACSetTransformation(induced_subgraph(graph, [13, 14, 17]), induced_subgraph(graph, [10, 11, 13, 14, 17]), V=[3, 4, 5], E=Int[]), # m n q → j k m n q +] + + +end diff --git a/test/JunctionTrees.jl b/test/JunctionTrees.jl index 0e9b724..89ecf85 100644 --- a/test/JunctionTrees.jl +++ b/test/JunctionTrees.jl @@ -1,13 +1,15 @@ +using StructuredDecompositions.JunctionTrees + + using AbstractTrees using Catlab.BasicGraphs using Catlab.RelationalPrograms using Catlab.UndirectedWiringDiagrams -using StructuredDecompositions.JunctionTrees using Test -# Vandenberghe and Andersen # Chordal Graphs and Semidefinite Optimization +# Vandenberghe and Andersen graph = SymmetricGraph(17) add_edges!(graph, @@ -15,274 +17,222 @@ add_edges!(graph, [3, 4, 5, 15, 3, 4, 9, 16, 9, 16, 8, 9, 15, 11, 13, 14, 17, 13, 14, 16, 17, 17]) order = JunctionTrees.Order(graph, CuthillMcKeeJL_RCM()) -@test order == [2, 14, 13, 11, 4, 3, 12, 10, 16, 1, 17, 5, 6, 15, 9, 7, 8] +@test length(order) == 17 order = JunctionTrees.Order(graph, AMDJL_AMD()) -@test order == [8, 11, 7, 2, 4, 3, 1, 6, 13, 14, 10, 12, 17, 16, 5, 9, 15] +@test length(order) == 17 order = JunctionTrees.Order(graph, MetisJL_ND()) -# @test order == [11, 17, 14, 13, 10, 12, 8, 6, 7, 5, 4, 3, 9, 2, 1, 16, 15] -# changing test case to only check that the size of the order object as the object in previous test case -# reference change request in PR#20 StructuredDecompositions -@test length(order) == length([11, 17, 14, 13, 10, 12, 8, 6, 7, 5, 4, 3, 9, 2, 1, 16, 15]) +@test length(order) == 17 + +order = JunctionTrees.Order(graph, TreeWidthSolverJL_BT()) +@test length(order) == 17 order = JunctionTrees.Order(graph, MCS()) -@test order == [2, 3, 4, 8, 1, 5, 6, 9, 7, 11, 13, 10, 14, 16, 12, 15, 17] +@test length(order) == 17 order = JunctionTrees.Order(1:17) -parent = JunctionTrees.makeetree(graph, order) - -# Figure 4.2 -@test parent == [3, 3, 4, 5, 9, 9, 8, 9, 15, 11, 13, 13, 14, 16, 16, 17, 0] - -etree = JunctionTrees.Tree(17, parent) -indegrees, outdegrees = JunctionTrees.getdegrees(graph, order, etree) +@test length(order) == 17 -@test indegrees == [0, 0, 2, 3, 3, 0, 0, 1, 4, 0, 1, 0, 3, 4, 7, 7, 7] -@test outdegrees == [4, 2, 3, 2, 3, 2, 3, 2, 2, 4, 3, 4, 3, 2, 2, 1, 0] - -# Figure 4.3 -snd, sbt, q, a = JunctionTrees.makestree(etree, outdegrees, Node()) +# Figure 4.3 +jtree = JunctionTree(graph, order, Node()) +@test width(jtree) == 4 +@test height(jtree) == 7 +@test length(jtree) == 17 -@test snd == [ +@test map(i -> parentindex(jtree, i), 1:17) == [ + 2, + 9, + 9, + 6, + 6, + 7, + 8, + 9, + 10, + 16, + 14, + 13, + 14, + 15, + 16, + 17, + nothing, +] + +@test map(i -> childindices(jtree, i), 1:17) == [ + [], [1], - [2], - [3], - [4], - [5], + [], + [], + [], + [4, 5], [6], [7], - [8], + [2, 3, 8], [9], - [10], - [11], + [], + [], [12], - [13], + [11, 13], [14], - [15], + [10, 15], [16], - [17]] +] + +@test map(i -> residual(jtree, i), 1:17) == [ + [7], # g + [8], # h + [6], # f + [2], # b + [1], # a + [3], # c + [4], # d + [5], # e + [9], # i + [15], # o + [12], # l + [10], # j + [11], # k + [13], # m + [14], # n + [16], # p + [17], # q +] + +@test map(i -> seperator(jtree, i), 1:17) == [ + [8, 9, 15], # h i o + [9, 15], # i o + [9, 16], # i p + [3, 4], # c d + [3, 4, 5, 15], # c d e o + [4, 5, 15], # d e o + [5, 15], # e o + [9, 15, 16], # i o p + [15, 16], # o p + [16, 17], # p q + [13, 14, 16, 17], # m n p q + [11, 13, 14, 17], # k m n q + [13, 14, 17], # m n q + [14, 16, 17], # n p q + [16, 17], # p q + [17], # q + [], # +] -@test sbt == 1:17 -@test q == [3, 3, 4, 5, 9, 9, 8, 9, 15, 11, 13, 13, 14, 16, 16, 17, 0] -@test a == [3, 3, 4, 5, 9, 9, 8, 9, 15, 11, 13, 13, 14, 16, 16, 17, 0] # Figure 4.7 (left) -snd, sbt, q, a = JunctionTrees.makestree(etree, outdegrees, MaximalSupernode()) +jtree = JunctionTree(graph, order, Maximal()) +@test width(jtree) == 4 +@test height(jtree) == 4 +@test length(jtree) == 8 -@test snd == [ - [1, 3, 4], - [2], - [5, 9], - [6], - [7, 8], - [10, 11], - [12, 13, 14, 16, 17], - [15] ] - -@test sbt == [1, 2, 1, 1, 3, 4, 5, 5, 3, 6, 6, 7, 7, 7, 8, 7, 7] -@test q == [3, 1, 8, 3, 3, 7, 0, 7] -@test a == [5, 3, 15, 9, 9, 13, 0, 16] - -# Figure 4.9 -snd, sbt, q, a = JunctionTrees.makestree(etree, outdegrees, FundamentalSupernode()) - -@test snd == [ - [1], - [2], - [3, 4], - [5], - [6], - [7, 8], - [9], - [10, 11], - [12], - [13, 14], - [15], - [16, 17] ] - -@test sbt == [1, 2, 3, 3, 4, 5, 6, 6, 7, 8, 8, 9, 10, 10, 11, 12, 12] -@test q == [3, 3, 4, 7, 7, 7, 11, 10, 10, 12, 12, 0] -@test a == [3, 3, 5, 9, 9, 9, 15, 13, 13, 16, 16, 0] - -# Figure 4.3 -jtree = JunctionTree(graph, order, Node()) +@test map(i -> parentindex(jtree, i), 1:8) == [ + 5, + 5, + 4, + 5, + 6, + 8, + 8, + nothing +] -@test getresidual.([jtree], getsubtree.([jtree], 1:17)) == [ - [1], - [2], +@test map(i -> childindices(jtree, i), 1:8) == [ + [], + [], + [], [3], - [4], + [1, 2, 4], [5], - [6], - [7], - [8], - [9], - [10], - [11], - [12], - [13], - [14], - [15], - [16], - [17] ] + [], + [6, 7], +] + +@test map(i -> residual(jtree, i), 1:8) == [ + [7, 8], # g h + [6], # f + [2], # b + [1, 3, 4], # a c d + [5, 9], # e i + [15], # o + [10, 11], # j k + [12, 13, 14, 16, 17], # l m n p q +] + +@test map(i -> seperator(jtree, i), 1:8) == [ + [9, 15], # i o + [9, 16], # i p + [3, 4], # c d + [5, 15], # e o + [15, 16], # o p + [16, 17], # p q + [13, 14, 17], # m n q + [], # +] -@test getseperator.([jtree], getsubtree.([jtree], 1:17)) == [ - [3, 4, 5, 15], - [3, 4], - [4, 5, 15], - [5, 15], - [9, 15, 16], - [9, 16], - [8, 9, 15], - [9, 15], - [15, 16], - [11, 13, 14, 17], - [13, 14, 17], - [13, 14, 16, 17], - [14, 16, 17], - [16, 17], - [16, 17], - [17], - [] ] +# Figure 4.9 +jtree = JunctionTree(graph, order, Fundamental()) +@test width(jtree) == 4 +@test height(jtree) == 5 +@test length(jtree) == 12 -@test getlevel.([jtree], getsubtree.([jtree], 1:17)) == [ +@test map(i -> parentindex(jtree, i), 1:12) == [ 7, 7, - 6, 5, - 4, - 4, - 5, - 4, - 3, 5, - 4, - 4, - 3, - 2, - 2, - 1, - 0 ] - -@test isdescendant(jtree, getsubtree(jtree, 5), getsubtree(jtree, 15)) -@test !isdescendant(jtree, getsubtree(jtree, 15), getsubtree(jtree, 5)) -@test !isdescendant(jtree, getsubtree(jtree, 10), getsubtree(jtree, 15)) -@test !isdescendant(jtree, getsubtree(jtree, 15), getsubtree(jtree, 10)) -@test !isdescendant(jtree, getsubtree(jtree, 1), getsubtree(jtree, 1)) -@test getwidth(jtree) == 4 - -# Figure 4.7 (left) -jtree = JunctionTree(graph, order, MaximalSupernode()) - -@test getresidual.([jtree], getsubtree.([jtree], 1:17)) == [ - [1, 3, 4], - [2], - [1, 3, 4], - [1, 3, 4], - [5, 9], - [6], - [7, 8], - [7, 8], - [5, 9], - [10, 11], - [10, 11], - [12, 13, 14, 16, 17], - [12, 13, 14, 16, 17], - [12, 13, 14, 16, 17], - [15], - [12, 13, 14, 16, 17], - [12, 13, 14, 16, 17]] - -@test getseperator.([jtree], getsubtree.([jtree], 1:17)) == [ - [5, 15], - [3, 4], - [5, 15], - [5, 15], - [15, 16], - [9, 16], - [9, 15], - [9, 15], - [15, 16], - [13, 14, 17], - [13, 14, 17], + 6, + 7, + 8, + 12, + 11, + 11, + 12, + nothing, +] + +@test map(i -> childindices(jtree, i), 1:12) == [ [], [], [], - [16, 17], [], - []] - -@test getlevel.([jtree], getsubtree.([jtree], 1:17)) == [ - 3, - 4, - 3, - 3, - 2, - 3, - 3, - 3, - 2, - 1, - 1, - 0, - 0, - 0, - 1, - 0, - 0 ] - -@test isdescendant(jtree, getsubtree(jtree, 5), getsubtree(jtree, 15)) -@test !isdescendant(jtree, getsubtree(jtree, 15), getsubtree(jtree, 5)) -@test !isdescendant(jtree, getsubtree(jtree, 10), getsubtree(jtree, 15)) -@test !isdescendant(jtree, getsubtree(jtree, 15), getsubtree(jtree, 10)) -@test !isdescendant(jtree, getsubtree(jtree, 1), getsubtree(jtree, 1)) -@test getwidth(jtree) == 4 - -# Figure 4.9 -jtree = JunctionTree(graph, order, FundamentalSupernode()) - -@test getresidual.([jtree], getsubtree.([jtree], 1:17)) == [ - [1], - [2], - [3, 4], [3, 4], [5], - [6], - [7, 8], - [7, 8], - [9], - [10, 11], - [10, 11], - [12], - [13, 14], - [13, 14], - [15], - [16, 17], - [16, 17]] - -@test getseperator.([jtree], getsubtree.([jtree], 1:17)) == [ - [3, 4, 5, 15], - [3, 4], - [5, 15], - [5, 15], - [9, 15, 16], - [9, 16], - [9, 15], - [9, 15], - [15, 16], - [13, 14, 17], - [13, 14, 17], - [13, 14, 16, 17], - [16, 17], - [16, 17], - [16, 17], + [1, 2, 6], + [7], [], - []] - -@test isdescendant(jtree, getsubtree(jtree, 5), getsubtree(jtree, 15)) -@test !isdescendant(jtree, getsubtree(jtree, 15), getsubtree(jtree, 5)) -@test !isdescendant(jtree, getsubtree(jtree, 10), getsubtree(jtree, 15)) -@test !isdescendant(jtree, getsubtree(jtree, 15), getsubtree(jtree, 10)) -@test !isdescendant(jtree, getsubtree(jtree, 1), getsubtree(jtree, 1)) -@test getwidth(jtree) == 4 + [], + [9, 10], + [8, 11], +] + +@test map(i -> residual(jtree, i), 1:12) == [ + [7, 8], # g h + [6], # f + [2], # b + [1], # a + [3, 4], # c d + [5], # e + [9], # i + [15], # o + [12], # l + [10, 11], # j k + [13, 14], # m n + [16, 17], # p q +] + +@test map(i -> seperator(jtree, i), 1:12) == [ + [9, 15], # i o + [9, 16], # i p + [3, 4], # c d + [3, 4, 5, 15], # c d e o + [5, 15], # e o + [9, 15, 16], # i o p + [15, 16], # o p + [16, 17], # p q + [13, 14, 16, 17], # m n p q + [13, 14, 17], # m n q + [16, 17], # p q + [], # +] diff --git a/test/NestedUWDs.jl b/test/NestedUWDs.jl deleted file mode 100644 index e156351..0000000 --- a/test/NestedUWDs.jl +++ /dev/null @@ -1,122 +0,0 @@ -using Catlab.RelationalPrograms -using Catlab.UndirectedWiringDiagrams -using LinearAlgebra -using StructuredDecompositions.NestedUWDs -using Test - - -# CategoricalTensorNetworks.jl -# https://github.com/AlgebraicJulia/CategoricalTensorNetworks.jl/ -function contract_tensor_network(d::UndirectedWiringDiagram, - tensors::AbstractVector{<:AbstractArray}) - @assert nboxes(d) == length(tensors) - juncs = [junction(d, ports(d, b)) for b in boxes(d)] - j_out = junction(d, ports(d, outer=true), outer=true) - contract_tensor_network(tensors, juncs, j_out) -end - - -function contract_tensor_network(tensors::AbstractVector{<:AbstractArray{T}}, - juncs::AbstractVector, j_out) where T - # Handle important binary case with specialized code. - if length(tensors) == 2 && length(juncs) == 2 - return contract_tensor_network(Tuple(tensors), Tuple(juncs), j_out) - end - - jsizes = Tuple(infer_junction_sizes(tensors, juncs, j_out)) - juncs, j_out = map(Tuple, juncs), Tuple(j_out) - C = zeros(T, Tuple(jsizes[j] for j in j_out)) - for index in CartesianIndices(jsizes) - x = one(T) - for (A, junc) in zip(tensors, juncs) - x *= A[(index[j] for j in junc)...] - end - C[(index[j] for j in j_out)...] += x - end - return C -end - - -function contract_tensor_network( # Binary case. - (A, B)::Tuple{<:AbstractArray{T},<:AbstractArray{T}}, - (jA, jB), j_out) where T - jsizes = Tuple(infer_junction_sizes((A, B), (jA, jB), j_out)) - jA, jB, j_out = Tuple(jA), Tuple(jB), Tuple(j_out) - C = zeros(T, Tuple(jsizes[j] for j in j_out)) - for index in CartesianIndices(jsizes) - C[(index[j] for j in j_out)...] += - A[(index[j] for j in jA)...] * B[(index[j] for j in jB)...] - end - return C -end - - -function infer_junction_sizes(tensors, juncs, j_out) - @assert length(tensors) == length(juncs) - njunc = maximum(Iterators.flatten((Iterators.flatten(juncs), j_out))) - jsizes = fill(-1, njunc) - for (A, junc) in zip(tensors, juncs) - for (i, j) in enumerate(junc) - if jsizes[j] == -1 - jsizes[j] = size(A, i) - else - @assert jsizes[j] == size(A, i) - end - end - end - @assert all(s >= 0 for s in jsizes) - jsizes -end - - -# out[v,z] = A[v,w] * B[w,x] * C[x,y] * D[y,z] -diagram = @relation (v, z) begin - A(v, w) - B(w, x) - C(x, y) - D(y, z) -end - -nuwd = NestedUWD(diagram) -A, B, C, D = map(randn, [(3, 4), (4, 5), (5, 6), (6, 7)]) -generators = Dict{Symbol, Array{Float64}}(:A => A, :B => B, :C => C, :D => D) -out = evalschedule(contract_tensor_network, nuwd, generators) -@test out ≈ A * B * C * D - -# out[] = A[w,x] * B[x,y] * C[y,z] * D[z,w] -diagram = @relation () begin - A(w, x) - B(x, y) - C(y, z) - D(z, w) -end - -nuwd = NestedUWD(diagram) -A, B, C, D = map(randn, [(10, 5), (5, 5), (5, 5), (5, 10)]) -generators = Dict{Symbol, Array{Float64}}(:A => A, :B => B, :C => C, :D => D) -out = evalschedule(contract_tensor_network, nuwd, generators) -@test out[] ≈ tr(A * B * C * D) - -# out[w,x,y,z] = A[w,x] * B[y,z] -diagram = @relation (w, x, y, z) begin - A(w, x) - B(y, z) -end - -nuwd = NestedUWD(diagram) -A, B = map(randn, [(3, 4), (5, 6)]) -generators = Dict{Symbol, Array{Float64}}(:A => A, :B => B) -out = evalschedule(contract_tensor_network, nuwd, generators) -@test out ≈ (reshape(A, (3, 4, 1, 1)) .* reshape(B, (1, 1, 5, 6))) - -# out[] = A[x,y] * B[x,y] -diagram = @relation () begin - A(x, y) - B(x, y) -end - -nuwd = NestedUWD(diagram) -A, B = map(randn, [(5, 5), (5, 5)]) -generators = Dict{Symbol, Array{Float64}}(:A => A, :B => B) -out = evalschedule(contract_tensor_network, nuwd, generators) -@test out[] ≈ dot(vec(A), vec(B)) diff --git a/test/runtests.jl b/test/runtests.jl index 6fdd6b8..fbb1669 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,7 +15,3 @@ end @testset "JunctionTrees" begin include("JunctionTrees.jl") end - -@testset "NestedUWDs" begin - include("NestedUWDs.jl") -end