diff --git a/Project.toml b/Project.toml index 4bdc28f..7741486 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Catlab = "134e5e36-593f-5add-ad60-77f754baafbe" CuthillMcKee = "17f17636-5e38-52e3-a803-7ae3aaaf3da9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinkedLists = "70f5e60a-1556-5f34-a19e-a48b3e4aaee9" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" diff --git a/src/Decompositions.jl b/src/Decompositions.jl index ef6b6db..5a52f3b 100644 --- a/src/Decompositions.jl +++ b/src/Decompositions.jl @@ -214,7 +214,7 @@ end # jtree junction tree # ---------------------------------------- function StrDecomp(graph::AbstractSymmetricGraph, jtree::JunctionTree) - n = length(jtree) + n = treesize(jtree) tree = Graph(n) for i in 1:n - 1 @@ -293,7 +293,7 @@ end function homomorphisms(graph::AbstractSymmetricGraph, jtree::JunctionTree) - n = length(jtree) + n = treesize(jtree) subgraph = Vector{Any}(undef, 2n - 1) homomorphism = Vector{Any}(undef, 2n - 2) @@ -310,12 +310,12 @@ function homomorphisms(graph::AbstractSymmetricGraph, jtree::JunctionTree) for i in 1:n - 1 # seperator(i) → clique(parent(i)) j = parentindex(jtree, i) - homomorphism[i] = induced_homomorphism(subgraph[n + i], subgraph[j], seperator_to_parent(jtree, i)) + homomorphism[i] = induced_homomorphism(subgraph[n + i], subgraph[j], lift_par(jtree, i)) end for i in 1:n - 1 # seperator(i) → clique(i) - homomorphism[n + i - 1] = induced_homomorphism(subgraph[n + i], subgraph[i], seperator_to_self(jtree, i)) + homomorphism[n + i - 1] = induced_homomorphism(subgraph[n + i], subgraph[i], lift_sep(jtree, i)) end subgraph, homomorphism @@ -323,7 +323,7 @@ end function induced_order(order::Order, elements::AbstractVector) - Order(sortperm(inverse(order, elements))) + Order(sortperm(view(inv(order), elements))) end diff --git a/src/JunctionTrees.jl b/src/JunctionTrees.jl index 5437baa..4649042 100644 --- a/src/JunctionTrees.jl +++ b/src/JunctionTrees.jl @@ -2,41 +2,52 @@ module JunctionTrees import AMD +import Catlab.BasicGraphs import CuthillMcKee import LinkedLists import Metis import TreeWidthSolver using AbstractTrees -using Catlab.BasicGraphs +using Base.Order: Ordering using DataStructures +using Graphs +using Graphs.SimpleGraphs +using LinearAlgebra using SparseArrays +using SparseArrays: AbstractSparseMatrixCSC # Orders -export Order, inverse +export Order # Elimination Algorithms export AMDJL_AMD, CuthillMcKeeJL_RCM, MetisJL_ND, TreeWidthSolverJL_BT, MCS +# Ordered Graphs +export OrderedGraph + + # Supernode Types export Node, Maximal, Fundamental # Junction Trees -export JunctionTree, width, height, seperator, residual, clique, seperator_to_parent, seperator_to_self +export JunctionTree, treewidth, seperator, residual, clique, find_clique, lift_par, lift_sep, lift +include("junction_trees/fixed_stacks.jl") +include("junction_trees/disjoint_sets.jl") include("junction_trees/orders.jl") include("junction_trees/elimination_algorithms.jl") -include("junction_trees/ordered_graphs.jl") include("junction_trees/trees.jl") +include("junction_trees/child_indices.jl") include("junction_trees/postorder_trees.jl") -include("junction_trees/elimination_trees.jl") +include("junction_trees/abstract_ordered_graphs.jl") +include("junction_trees/ordered_graphs.jl") include("junction_trees/supernode_types.jl") -include("junction_trees/supernode_trees.jl") include("junction_trees/junction_trees.jl") diff --git a/src/junction_trees/abstract_ordered_graphs.jl b/src/junction_trees/abstract_ordered_graphs.jl new file mode 100644 index 0000000..c37788c --- /dev/null +++ b/src/junction_trees/abstract_ordered_graphs.jl @@ -0,0 +1,31 @@ +abstract type AbstractOrderedGraph <: AbstractSimpleGraph{Int} end + + +############################ +# Abstract Graph Interface # +############################ + + +function SimpleGraphs.is_directed(::Type{AbstractOrderedGraph}) + true +end + + +function SimpleGraphs.edgetype(graph::AbstractOrderedGraph) + SimpleEdge{Int} +end + + +function SimpleGraphs.has_edge(graph::AbstractOrderedGraph, edge::SimpleEdge{Int}) + i = src(edge) + j = dst(edge) + i < j && insorted(j, outneighbors(graph, i)) +end + + +# Multiline printing. +function Base.show(io::IO, ::MIME"text/plain", graph::AbstractOrderedGraph) + print(io, "ordered graph:\n") + SparseArrays._show_with_braille_patterns(io, adjacencymatrix(graph)) +end + diff --git a/src/junction_trees/child_indices.jl b/src/junction_trees/child_indices.jl new file mode 100644 index 0000000..b87be2d --- /dev/null +++ b/src/junction_trees/child_indices.jl @@ -0,0 +1,33 @@ +struct ChildIndices + tree::Tree + index::Int +end + + +function Base.iterate(iterator::ChildIndices) + iterate(iterator, iterator.tree.child[iterator.index]) +end + + +function Base.iterate(iterator::ChildIndices, i::Integer) + if iszero(i) + nothing + else + i, iterator.tree.brother[i] + end +end + + +function Base.IteratorSize(::Type{ChildIndices}) + Base.SizeUnknown() +end + + +function Base.eltype(::Type{ChildIndices}) + Int +end + + +function AbstractTrees.childindices(tree::Tree, i::Integer) + ChildIndices(tree, i) +end diff --git a/src/junction_trees/disjoint_sets.jl b/src/junction_trees/disjoint_sets.jl new file mode 100644 index 0000000..ac49f7d --- /dev/null +++ b/src/junction_trees/disjoint_sets.jl @@ -0,0 +1,21 @@ +struct DisjointSets + sets::IntDisjointSets{Int} + index::Vector{Int} + root::Vector{Int} + + function DisjointSets(n::Integer) + new(IntDisjointSets(n), collect(1:n), collect(1:n)) + end +end + + +function find!(sets::DisjointSets, u::Integer) + sets.index[find_root!(sets.sets, u)] +end + + +function Base.union!(sets::DisjointSets, u::Integer, v::Integer) + w = max(u, v) + sets.root[w] = root_union!(sets.sets, sets.root[u], sets.root[v]) + sets.index[sets.root[w]] = w +end diff --git a/src/junction_trees/elimination_algorithms.jl b/src/junction_trees/elimination_algorithms.jl index 71d5d5c..fb0f616 100644 --- a/src/junction_trees/elimination_algorithms.jl +++ b/src/junction_trees/elimination_algorithms.jl @@ -52,45 +52,51 @@ struct MCS <: EliminationAlgorithm end """ - Order(graph::AbstractSymmetricGraph[, ealg::EliminationAlgorithm]) + Order(graph[, ealg::EliminationAlgorithm]) Construct an elimination order for a simple graph, optionally specifying an elimination algorithm. """ -function Order(graph::AbstractSymmetricGraph) +function Order(graph, ealg::EliminationAlgorithm=DEFAULT_ELIMINATION_ALGORITHM) + Order(adjacencymatrix(graph), ealg) +end + + +# Construct an elimination order. +function Order(graph::AbstractMatrix) Order(graph, DEFAULT_ELIMINATION_ALGORITHM) end # Construct an order using the reverse Cuthill-McKee algorithm. Uses CuthillMcKee.jl. -function Order(graph::AbstractSymmetricGraph, ealg::CuthillMcKeeJL_RCM) - order = CuthillMcKee.symrcm(adjacencymatrix(graph)) +function Order(graph::AbstractMatrix, ealg::CuthillMcKeeJL_RCM) + order = CuthillMcKee.symrcm(graph) Order(order) end # Construct an order using the approximate minimum degree algorithm. Uses AMD.jl. -function Order(graph::AbstractSymmetricGraph, ealg::AMDJL_AMD) - order = AMD.symamd(adjacencymatrix(graph)) +function Order(graph::AbstractMatrix, ealg::AMDJL_AMD) + order = AMD.symamd(graph) Order(order) end # Construct an order using the nested dissection heuristic. Uses Metis.jl. -function Order(graph::AbstractSymmetricGraph, ealg::MetisJL_ND) - order, index = Metis.permutation(adjacencymatrix(graph)) +function Order(graph::AbstractMatrix, ealg::MetisJL_ND) + order, index = Metis.permutation(graph) Order(order, index) end # Construct an order using the Bouchitte-Todinca algorithm. Uses TreeWidthSolver.jl. -function Order(graph::AbstractSymmetricGraph, ealg::TreeWidthSolverJL_BT) - n = nv(graph) +function Order(graph::AbstractSparseMatrixCSC, ealg::TreeWidthSolverJL_BT) + n = size(graph, 1) T = TreeWidthSolver.LongLongUInt{n ÷ 64 + 1} fadjlist = Vector{Vector{Int}}(undef, n) bitfadjlist = Vector{T}(undef, n) for i in 1:n - fadjlist[i] = sort(collect(outneighbors(graph, i))) + fadjlist[i] = rowvals(graph)[nzrange(graph, i)] bitfadjlist[i] = TreeWidthSolver.bmask(T, fadjlist[i]) end @@ -102,45 +108,54 @@ end # Construct an order using the maximum cardinality search algorithm. -function Order(graph::AbstractSymmetricGraph, ealg::MCS) - order, index = mcs(graph) - Order(order, index) +function Order(graph::AbstractMatrix, ealg::MCS) + mcs(graph) end # Construct the adjacency matrix of a graph. -function adjacencymatrix(graph::AbstractSymmetricGraph) - m = ne(graph) - n = nv(graph) - - colptr = ones(Int, n + 1) - rowval = sizehint!(Vector{Int}(), 2m) - - for j in 1:n - ns = collect(neighbors(graph, j)) - sort!(ns) - colptr[j + 1] = colptr[j] + length(ns) - append!(rowval, ns) +function adjacencymatrix(graph::BasicGraphs.AbstractSymmetricGraph) + m = BasicGraphs.ne(graph) + n = BasicGraphs.nv(graph) + colptr = Vector{Int}(undef, n + 1) + rowval = Vector{Int}(undef, m) + count = 1 + + for i in 1:n + colptr[i] = count + neighbor = collect(BasicGraphs.all_neighbors(graph, i)) + sort!(neighbor) + + for j in neighbor + rowval[count] = j + count += 1 + end end - nzval = ones(Int, length(rowval)) + colptr[n + 1] = m + 1 + nzval = ones(Bool, m) SparseMatrixCSC(n, n, colptr, rowval, nzval) end +# Construct the adjacency matrix of a graph. +function adjacencymatrix(graph::AbstractGraph) + adjacency_matrix(graph; dir=:both) +end + + # Simple Linear-Time Algorithms to Test Chordality of Graphs, Test Acyclicity of Hypergraphs, and Selectively Reduce Acyclic Hypergraphs # Tarjan and Yannakakis # Maximum Cardinality Search -function mcs(graph::AbstractSymmetricGraph) - n = nv(graph) - α = Vector{Int}(undef, n) - β = Vector{Int}(undef, n) - size = Vector{Int}(undef, n) +function mcs(graph::AbstractSparseMatrixCSC) + n = size(graph, 1) + α = Order(undef, n) + len = Vector{Int}(undef, n) set = Vector{LinkedLists.LinkedList{Int}}(undef, n) pointer = Vector{LinkedLists.ListNode{Int}}(undef, n) for i in 1:n - size[i] = 1 + len[i] = 1 set[i] = LinkedLists.LinkedList{Int}() pointer[i] = push!(set[1], i) end @@ -151,15 +166,14 @@ function mcs(graph::AbstractSymmetricGraph) while i >= 1 v = first(set[j]) deleteat!(set[j], pointer[v]) - α[v] = i - β[i] = v - size[v] = 0 - - for w in neighbors(graph, v) - if size[w] >= 1 - deleteat!(set[size[w]], pointer[w]) - size[w] += 1 - pointer[w] = push!(set[size[w]], w) + α[i] = v + len[v] = 0 + + for w in view(rowvals(graph), nzrange(graph, v)) + if len[w] >= 1 + deleteat!(set[len[w]], pointer[w]) + len[w] += 1 + pointer[w] = push!(set[len[w]], w) end end @@ -171,7 +185,7 @@ function mcs(graph::AbstractSymmetricGraph) end end - β, α + α end diff --git a/src/junction_trees/elimination_trees.jl b/src/junction_trees/elimination_trees.jl deleted file mode 100644 index 7a1afc5..0000000 --- a/src/junction_trees/elimination_trees.jl +++ /dev/null @@ -1,124 +0,0 @@ -# An ordered graph (G, σ) equipped with the elimination tree T of its elimination graph. -# Nodes i in T correspond to vertices σ(i) in G. -struct EliminationTree{T <: Union{Tree, PostorderTree}} - tree::T # elimination tree - graph::OrderedGraph # ordered graph -end - - -# Construct an elimination tree using an elimination algorithm. -# ---------------------------------------- -# graph simple connected graph -# ealg elimination algorithm -# ---------------------------------------- -function EliminationTree( - graph::AbstractSymmetricGraph, - ealg::Union{Order, EliminationAlgorithm}=DEFAULT_ELIMINATION_ALGORITHM) - EliminationTree(OrderedGraph(graph, ealg)) -end - - -# Construct the elimination tree of an ordered graph. -# ---------------------------------------- -# graph ordered graph -# ---------------------------------------- -function EliminationTree(graph::OrderedGraph) - EliminationTree(Tree(etree(graph)), graph) -end - - -# Postorder an elimination tree. -# ---------------------------------------- -# graph ordered graph -# order postorder -# ---------------------------------------- -function EliminationTree{PostorderTree}(etree::EliminationTree, order::Order) - EliminationTree(PostorderTree(etree.tree, order), OrderedGraph(etree.graph, order)) -end - - -# An Efficient Algorithm to Compute Row and Column Counts for Sparse Cholesky Factorization -# Gilbert, Ng, and Peyton -# Figure 3: Implementation of algorithm to compute row and column counts. -function supcnt(etree::EliminationTree) - order = postorder(etree.tree) - index = inverse(order) - rc, cc = supcnt(EliminationTree{PostorderTree}(etree, order)) - rc[index], cc[index] -end - - -# An Efficient Algorithm to Compute Row and Column Counts for Sparse Cholesky Factorization -# Gilbert, Ng, and Peyton -# Figure 3: Implementation of algorithm to compute row and column counts. -function supcnt(etree::EliminationTree{PostorderTree}) - n = length(etree.tree) - - #### Disjoint Set Union #### - - rvert = collect(1:n) - index = collect(1:n) - forest = IntDisjointSets(n) - - function find(u) - index[find_root!(forest, u)] - end - - function union(u, v) - w = max(u, v) - rvert[w] = root_union!(forest, rvert[u], rvert[v]) - index[rvert[w]] = w - end - - ############################ - - prev_p = zeros(Int, n) - prev_nbr = zeros(Int, n) - rc = ones(Int, n) - wt = ones(Int, n) - - for u in 1:n - 1 - wt[parentindex(etree.tree, u)] = 0 - end - - for p in 1:n - 1 - wt[parentindex(etree.tree, p)] -= 1 - - for u in outneighbors(etree.graph, p) - if firstdescendant(etree.tree, p) > prev_nbr[u] - wt[p] += 1 - pp = prev_p[u] - - if iszero(pp) - rc[u] += level(etree.tree, p) - level(etree.tree, u) - else - q = find(pp) - rc[u] += level(etree.tree, p) - level(etree.tree, q) - wt[q] -= 1 - end - - prev_p[u] = p - end - - prev_nbr[u] = p - end - - union(p, parentindex(etree.tree, p)) - end - - cc = wt - - for v in 1:n - 1 - cc[parentindex(etree.tree, v)] += cc[v] - end - - rc, cc -end - - -# Compute higher degree of every vertex in the elimination graph of -# (G, σ). -function outdegrees(etree::EliminationTree) - rc, cc = supcnt(etree) - cc .- 1 -end diff --git a/src/junction_trees/fixed_stacks.jl b/src/junction_trees/fixed_stacks.jl new file mode 100644 index 0000000..cf14c8f --- /dev/null +++ b/src/junction_trees/fixed_stacks.jl @@ -0,0 +1,40 @@ +struct FixedStack{T} <: AbstractVector{T} + top::Array{Int, 0} + items::Vector{T} + + function FixedStack{T}(n::Integer) where T + new(zeros(Int), Vector{T}(undef, n)) + end +end + + +function Base.push!(stack::FixedStack, v) + stack.top[] += 1 + stack.items[stack.top[]] = v +end + + +function Base.pop!(stack::FixedStack) + stack.top[] -= 1 + stack.items[stack.top[] + 1] +end + + +############################# +# Abstract Vector Interface # +############################# + + +function Base.getindex(stack::FixedStack, i) + stack.items[i] +end + + +function Base.IndexStyle(::Type{FixedStack}) + IndexLinear() +end + + +function Base.size(stack::FixedStack) + (stack.top[],) +end diff --git a/src/junction_trees/junction_trees.jl b/src/junction_trees/junction_trees.jl index 39f525a..b463663 100644 --- a/src/junction_trees/junction_trees.jl +++ b/src/junction_trees/junction_trees.jl @@ -1,131 +1,168 @@ -""" - JunctionTree - -A junction tree. -""" struct JunctionTree - stree::SupernodeTree # supernodal elimination tree - seperator::Vector{Vector{Int}} # seperator -end + order::Order + tree::PostorderTree + partition::Vector{Int} + # supernode(i) + sndptr::Vector{Int} -""" - JunctionTree(graph::AbstractSymmetricGraph[, ealg::Union{Order, EliminationAlgorithm}[, stype::SupernodeType]]) + # seperator(i) + sepptr::Vector{Int} + sepval::Vector{Int} -Construct a tree decomposition of a connected simple graph, optionally specifying an elimination algorithm and -a supernode type. -""" -function JunctionTree( - graph::AbstractSymmetricGraph, - ealg::Union{Order, EliminationAlgorithm}=DEFAULT_ELIMINATION_ALGORITHM, - stype::SupernodeType=DEFAULT_SUPERNODE_TYPE) + # seperator(i) → clique(parent(i)) + #relptr::Vector{Int} + #relval::Vector{Int} +end - JunctionTree(SupernodeTree(graph, ealg, stype)) + +function JunctionTree(graph, ealg::EliminationAlgorithm=DEFAULT_ELIMINATION_ALGORITHM, stype::SupernodeType=DEFAULT_SUPERNODE_TYPE) + JunctionTree(graph, Order(graph, ealg), stype) end -# Construct a junction tree. -# ---------------------------------------- -# stree supernodal elimination tree -# ---------------------------------------- -function JunctionTree(stree::SupernodeTree) - JunctionTree(stree, map(sort ∘ collect, seperators(stree))) +function JunctionTree(graph, order::Order, stype::SupernodeType=DEFAULT_SUPERNODE_TYPE) + graph = OrderedGraph(graph, order) + tree = etree(graph) + rowcount, colcount = supcnt(graph, tree) + supernode, tree = stree(graph, tree, colcount, stype) + + n = 0 + postorder = Order(undef, nv(graph)) + partition = Vector{Int}(undef, nv(graph)) + sndptr = Vector{Int}(undef, treesize(tree) + 1) + sndptr[1] = 1 + + for (i, snd) in enumerate(supernode) + sndptr[i + 1] = sndptr[i] + length(snd) + partition[sndptr[i]:sndptr[i + 1] - 1] .= i + postorder[sndptr[i]:sndptr[i + 1] - 1] = snd + n += colcount[snd[end]] - 1 + end + + order = compose(postorder, order) + graph = OrderedGraph(graph, postorder) + + sepval = Vector{Int}(undef, n) + sepptr = Vector{Int}(undef, treesize(tree) + 1) + sepptr[1] = 1 + + fullarray = zeros(Int, nv(graph)) + + for j in 1:treesize(tree) + u = sndptr[j + 1] - 1 + column = Int[] + + for v in outneighbors(graph, sndptr[j]) + if u < v + push!(column, v) + fullarray[v] = j + end + end + + for i in childindices(tree, j) + for v in view(sepval, sepptr[i]:sepptr[i + 1] - 1) + if u < v && fullarray[v] != j + push!(column, v) + fullarray[v] = j + end + end + end + + sepptr[j + 1] = sepptr[j] + length(column) + sepval[sepptr[j]:sepptr[j + 1] - 1] = sort(column) + end + + JunctionTree(order, tree, partition, sndptr, sepptr, sepval) end """ - clique(jtree::JunctionTree, i::Integer) + Order(jtree::JunctionTree) -Get the clique at node i. +Construct a perfect elimination ordering. """ -function clique(jtree::JunctionTree, i::Integer) - [residual(jtree, i); seperator(jtree, i)] +function Order(jtree::JunctionTree) + Order(jtree.order) end """ - seperator(jtree::JunctionTree, i::Integer) + clique(jtree::JunctionTree, i::Integer) -Get the seperator at node i. +Get the clique at node ``i``. """ -function seperator(jtree::JunctionTree, i::Integer) - permutation(jtree.stree.graph, jtree.seperator[i]) +function clique(jtree::JunctionTree, i::Integer) + view(jtree.order, [residualindices(jtree, i); seperatorindices(jtree, i)]) end """ - residual(jtree::JunctionTree, i::Integer) + treewidth(jtree::JunctionTree) -Get the residual at node i. +Compute the width of a junction tree. """ -function residual(jtree::JunctionTree, i::Integer) - permutation(jtree.stree.graph, supernode(jtree.stree, i)) +function treewidth(jtree::JunctionTree) + n = treesize(jtree) + maximum(map(i -> length(residualindices(jtree, i)) + length(seperatorindices(jtree, i)) - 1, 1:n)) end -# Construct the inclusion seperator(i) → clique(parent(i)). -function seperator_to_parent(jtree::JunctionTree, i::Integer) - j = parentindex(jtree, i) - sep = jtree.seperator[i] - sep_parent = jtree.seperator[j] - res_parent = supernode(jtree.stree, j) +function residualindices(jtree::JunctionTree, i::Integer) + jtree.sndptr[i]:jtree.sndptr[i + 1] - 1 +end - i = 0 - index = Vector{Int}(undef, length(sep)) - - for (j, v) in enumerate(sep) - if v in res_parent - index[j] = v - first(res_parent) + 1 - else - i += searchsortedfirst(view(sep_parent, i + 1:length(sep_parent)), v) - index[j] = length(res_parent) + i - end - end - index +function seperatorindices(jtree::JunctionTree, i::Integer) + view(jtree.sepval, jtree.sepptr[i]:jtree.sepptr[i + 1] - 1) end -# Construct the inclusion seperator(i) → clique(i). -function seperator_to_self(jtree::JunctionTree, i::Integer) - sep = jtree.seperator[i] - res = supernode(jtree.stree, i) - length(res) + 1:length(res) + length(sep) +""" + seperator(jtree::JunctionTree, i::Integer) + +Get the seperator at node ``i``. +""" +function seperator(jtree::JunctionTree, i::Integer) + view(jtree.order, seperatorindices(jtree, i)) end """ - length(jtree::JunctionTree) + residual(jtree::JunctionTree, i::Integer) -Get the number of nodes in a junction tree. +Get the residual at node ``i``. """ -function Base.length(jtree::JunctionTree) - length(jtree.stree.tree) +function residual(jtree::JunctionTree, i::Integer) + view(jtree.order, residualindices(jtree, i)) end +#= """ - height(jtree::JunctionTree) + find_clique(jtree::JunctionTree, v::Integer) -Compute the height of a junction tree. +Find a node `i` safisfying `v ∈ clique(jtree, i)`. """ -function height(jtree::JunctionTree) - height(jtree.stree.tree) +function find_clique(jtree::JunctionTree, v::Integer) + jtree.partition[inv(jtree.order)[v]] end """ - width(jtree::JunctionTree) + find_clique(jtree::JunctionTree, set::AbstractVector) -Compute the width of a junction tree. +Find a node `i` satisfying `vertices ⊆ clique(jtree, i)`. """ -function width(jtree::JunctionTree) - width(jtree.stree) +function find_clique(jtree::JunctionTree, vertices::AbstractVector) + jtree.partition[minimum(view(jtree.order, vertices))] end +=# -function Base.show(io::IO, jtree::JunctionTree) - n = width(jtree) +# Multiline printing. +function Base.show(io::IO, ::MIME"text/plain", jtree::JunctionTree) + n = treewidth(jtree) print(io, "width: $n\njunction tree:\n") print_tree(io, IndexNode(jtree)) do io, node @@ -134,23 +171,82 @@ function Base.show(io::IO, jtree::JunctionTree) end +######### +# Lifts # +######### +# In principal, cliques are subsets C ⊆ V. In practice, we represent them by vectors +# C: n → V +# Given another vector S: m → V, we may wish to find a vector L: m → n satisfying +# C +# n → V +# ↖ ↑ S +# L m +# The vector L is called a lift, and we write L: S → C. + + +# Compute the lift L: seperator(i) → clique(i). This satisfies +# seperator(jtree, i) == clique(jtree, i)[lift_sep(jtree, i)] +function lift_sep(jtree::JunctionTree, i::Integer) + residual = residualindices(jtree, i) + seperator = seperatorindices(jtree, i) + length(residual) + 1:length(residual) + length(seperator) +end + +# Compute the lift L: seperator(i) → clique(parent(i)). This satisfies +# seperator(jtree, i) == clique(jtree, parentindex(jtree, i)[lift_sep_par(jtree, i)] +function lift_par(jtree::JunctionTree, i::Integer) + lift_ind(jtree, seperatorindices(jtree, i), parentindex(jtree, i)) +end + + +# Compute the lift L: vertices → clique(i). This satisfies +# vertices == clique(jtree, i)[lift(jtree, vertices, i)] +function lift(jtree::JunctionTree, vertices::AbstractVector, i::Integer) + lift_ind(jtree, view(inv(jtree.order), vertices), i) +end + + +function lift_ind(jtree::JunctionTree, indices::AbstractVector, i::Integer) + residual = residualindices(jtree, i) + seperator = seperatorindices(jtree,i) + + map(indices) do v + if v in residual + v - first(residual) + 1 + else + length(residual) + searchsortedfirst(seperator, v) + end + end +end + + ########################## # Indexed Tree Interface # ########################## +function AbstractTrees.treesize(jtree::JunctionTree) + treesize(jtree.tree) +end + + +function AbstractTrees.treeheight(jtree::JunctionTree) + treeheight(jtree.tree) +end + + function AbstractTrees.rootindex(jtree::JunctionTree) - rootindex(jtree.stree.tree) + rootindex(jtree.tree) end function AbstractTrees.parentindex(jtree::JunctionTree, i::Integer) - parentindex(jtree.stree.tree, i) + parentindex(jtree.tree, i) end function AbstractTrees.childindices(jtree::JunctionTree, i::Integer) - childindices(jtree.stree.tree, i) + childindices(jtree.tree, i) end diff --git a/src/junction_trees/ordered_graphs.jl b/src/junction_trees/ordered_graphs.jl index f06fa96..0cfeef1 100644 --- a/src/junction_trees/ordered_graphs.jl +++ b/src/junction_trees/ordered_graphs.jl @@ -1,73 +1,71 @@ -# An ordered graph (G, σ). -struct OrderedGraph - graph::Graph # graph - order::Order # permutation +""" + OrderedGraph <: AbstractOrderedGraph + +A directed simple graph whose edges ``(i, j)`` satisfy the inequality ``i < j``. +This type implements the [abstract graph interface](https://juliagraphs.org/Graphs.jl/stable/core_functions/interface/). +""" +struct OrderedGraph <: AbstractOrderedGraph + colptr::Vector{Int} + adjptr::Vector{Int} + rowval::Vector{Int} +end + + +""" + OrderedGraph(graph[, ealg::Union{Order, EliminationAlgorithm}]) + +Construct an ordered graph by permuting the vertices of a simple graph and directing them from lower to higher. +""" +function OrderedGraph(graph, ealg::Union{Order, EliminationAlgorithm}=DEFAULT_ELIMINATION_ALGORITHM) + OrderedGraph(adjacencymatrix(graph), ealg) end -# Given a graph G, construct the ordered graph -# (G, σ), -# where the permutation σ is computed using an elimination algorithm. +# Construct an ordered graph by permuting the vertices of a simple graph and directing them from lower to higher. # ---------------------------------------- -# sgraph simple connected graph +# graph simple graph # ealg elimination algorithm # ---------------------------------------- -function OrderedGraph(sgraph::AbstractSymmetricGraph, ealg::EliminationAlgorithm=DEFAULT_ELIMINATION_ALGORITHM) - OrderedGraph(sgraph, Order(sgraph, ealg)) +function OrderedGraph(graph::AbstractMatrix, ealg::EliminationAlgorithm=DEFAULT_ELIMINATION_ALGORITHM) + OrderedGraph(graph, Order(graph, ealg)) end -# Given a graph G and permutation σ, construct the ordered graph -# (G, σ). +# Construct an ordered graph by permuting the vertices of a simple graph and directing them from lower to higher. # ---------------------------------------- -# sgraph simple connected graph +# graph simple graph # order vertex order # ---------------------------------------- -function OrderedGraph(sgraph::AbstractSymmetricGraph, order::Order) - n = nv(sgraph) - graph = Graph(n) - - for e in edges(sgraph) - u = src(sgraph, e) - v = tgt(sgraph, e) - - if order(u, v) - add_edge!(graph, inverse(order, u), inverse(order, v)) - end +function OrderedGraph(graph::AbstractSparseMatrixCSC, order::Order) + OrderedGraph(permute(graph, order, order)) +end + + +function OrderedGraph(graph::AbstractSparseMatrixCSC) + n = size(graph, 1) + colptr = Vector{Int}(undef, n) + + for i in 1:n + colptr[i] = graph.colptr[i] + searchsortedfirst(view(rowvals(graph), nzrange(graph, i)), i) - 1 end - OrderedGraph(graph, order) + OrderedGraph(colptr, graph.colptr, graph.rowval) end -# Given an ordered graph (G, σ) and permutation μ, construct the ordered graph -# (G, σ ∘ μ). -# ---------------------------------------- -# ograph ordered graph -# order permutation -# ---------------------------------------- -function OrderedGraph(ograph::OrderedGraph, order::Order) - n = nv(ograph) - graph = Graph(n) - - for e in edges(ograph) - u = src(ograph, e) - v = tgt(ograph, e) - - if order(u, v) - add_edge!(graph, inverse(order, u), inverse(order, v)) - else - add_edge!(graph, inverse(order, v), inverse(order, u)) - end - end - - OrderedGraph(graph, compose(order, ograph.order)) + +# Construct the adjacency matrix of an ordered graph. +function adjacencymatrix(graph::OrderedGraph) + SparseMatrixCSC(nv(graph), nv(graph), graph.adjptr, graph.rowval, ones(Bool, length(graph.rowval))) end # A Compact Row Storage Scheme for Cholesky Factors Using Elimination Trees # Liu # Algorithm 4.2: Elimination Tree by Path Compression. +# ---------------------------------------- +# graph simple connected graph +# ---------------------------------------- function etree(graph::OrderedGraph) n = nv(graph) parent = Vector{Int}(undef, n) @@ -80,83 +78,117 @@ function etree(graph::OrderedGraph) for k in inneighbors(graph, i) r = k - while ancestor[r] != 0 && ancestor[r] != i + while !iszero(ancestor[r]) && ancestor[r] != i t = ancestor[r] ancestor[r] = i r = t end - if ancestor[r] == 0 + if iszero(ancestor[r]) ancestor[r] = i parent[r] = i end end end - parent[n] = n - parent + Tree(parent) end - -function Base.deepcopy(ograph::OrderedGraph) - order = deepcopy(ograph.order) - graph = deepcopy(ograph.graph) - OrderedGraph(graph, order) -end - -# Get the vertex σ(i). -function permutation(ograph::OrderedGraph, i) - ograph.order[i] +# An Efficient Algorithm to Compute Row and Column Counts for Sparse Cholesky Factorization +# Gilbert, Ng, and Peyton +# Figure 3: Implementation of algorithm to compute row and column counts. +# ---------------------------------------- +# graph simple connected graph +# tree elimination tree +# ---------------------------------------- +function supcnt(graph::OrderedGraph, tree::Tree) + order = postorder(tree) + rc, cc = supcnt(OrderedGraph(graph, order), PostorderTree(tree, order)) + view(rc, inv(order)), view(cc, inv(order)) end -# Get the index σ⁻¹(v). -function inverse(ograph::OrderedGraph, v) - inverse(ograph.order, v) -end +# An Efficient Algorithm to Compute Row and Column Counts for Sparse Cholesky Factorization +# Gilbert, Ng, and Peyton +# Figure 3: Implementation of algorithm to compute row and column counts. +# ---------------------------------------- +# graph simple connected graph +# tree elimination tree +# ---------------------------------------- +function supcnt(graph::OrderedGraph, tree::PostorderTree) + n = treesize(tree) + sets = DisjointSets(n) + + prev_p = zeros(Int, n) + prev_nbr = zeros(Int, n) + rc = ones(Int, n) + wt = ones(Int, n) + for u in 1:n - 1 + wt[parentindex(tree, u)] = 0 + end + + for p in 1:n - 1 + wt[parentindex(tree, p)] -= 1 + + for u in outneighbors(graph, p) + if first(descendantindices(tree, p)) > prev_nbr[u] + wt[p] += 1 + pp = prev_p[u] + + if iszero(pp) + rc[u] += level(tree, p) - level(tree, u) + else + q = find!(sets, pp) + rc[u] += level(tree, p) - level(tree, q) + wt[q] -= 1 + end + + prev_p[u] = p + end -############################ -# Abstract Graph Interface # -############################ + prev_nbr[u] = p + end + union!(sets, p, parentindex(tree, p)) + end -function BasicGraphs.ne(ograph::OrderedGraph) - ne(ograph.graph) -end + cc = wt + for v in 1:n - 1 + cc[parentindex(tree, v)] += cc[v] + end -function BasicGraphs.nv(ograph::OrderedGraph) - nv(ograph.graph) + rc, cc end -function BasicGraphs.inneighbors(ograph::OrderedGraph, i) - inneighbors(ograph.graph, i) -end +############################ +# Abstract Graph Interface # +############################ -function BasicGraphs.outneighbors(ograph::OrderedGraph, i) - outneighbors(ograph.graph, i) +function SimpleGraphs.ne(graph::OrderedGraph) + (last(graph.adjptr) - 1) ÷ 2 end -function BasicGraphs.edges(ograph::OrderedGraph) - edges(ograph.graph) +function SimpleGraphs.nv(graph::OrderedGraph) + length(graph.adjptr) - 1 end -function BasicGraphs.vertices(ograph::OrderedGraph) - vertices(ograph.graph) +function SimpleGraphs.badj(graph::OrderedGraph, i::Integer) + view(graph.rowval, graph.adjptr[i]:graph.colptr[i] - 1) end - -function BasicGraphs.src(ograph::OrderedGraph, i) - src(ograph.graph, i) + +function SimpleGraphs.fadj(graph::OrderedGraph, i::Integer) + view(graph.rowval, graph.colptr[i]:graph.adjptr[i + 1] - 1) end -function BasicGraphs.tgt(ograph::OrderedGraph, i) - tgt(ograph.graph, i) +function SimpleGraphs.all_neighbors(graph::OrderedGraph, i::Integer) + view(graph.rowval, graph.adjptr[i]:graph.adjptr[i + 1] - 1) end diff --git a/src/junction_trees/orders.jl b/src/junction_trees/orders.jl index 849f06e..f6763ce 100644 --- a/src/junction_trees/orders.jl +++ b/src/junction_trees/orders.jl @@ -1,7 +1,8 @@ """ Order <: AbstractVector{Int} -A permutation of the set ``\\{1, \\dots, n\\}.`` +A [permutation](https://en.wikipedia.org/wiki/Permutation) ``\\sigma`` of the set ``\\{1, \\dots, n\\}``. +This type implements the [abstract array interface](https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-array). """ struct Order <: AbstractVector{Int} order::Vector{Int} # permutation @@ -12,25 +13,18 @@ end """ Order(order::AbstractVector) -Construct a permutation ``\\sigma`` from a sequence ``(\\sigma(1), \\dots, \\sigma(n)).`` +Construct a permutation ``\\sigma`` from a sequence ``(\\sigma(1), \\dots, \\sigma(n))``. """ function Order(order::AbstractVector) n = length(order) index = Vector{Int}(undef, n) - - for i in 1:n - index[order[i]] = i - end - + index[order] = 1:n Order(order, index) end -# Determine if i < j, where -# u = σ(i) -# v = σ(j) -function (order::Order)(u, v) - inverse(order, u) < inverse(order, v) +function Order(::UndefInitializer, n::Integer) + Order(Vector{Int}(undef, n), Vector{Int}(undef, n)) end @@ -40,15 +34,19 @@ function compose(left::Order, right::Order) end -# Construct the inverse permutation. -function inverse(order::Order) - Order(order.index, order.order) +# Construct a copy of a permutation. +function Base.copy(order::Order) + Order(order.order, order.index) end -# Get the index σ⁻¹(v), -function inverse(order::Order, v) - order.index[v] +""" + inverse(order::Order) + +Construct the inverse permutation ``\\sigma^{-1}``. +""" +function Base.inv(order::Order) + Order(order.index, order.order) end @@ -62,6 +60,12 @@ function Base.getindex(order::Order, i) end +function Base.setindex!(order::Order, v, i) + order.index[v] = i + order.order[i] = v +end + + function Base.IndexStyle(::Type{Order}) IndexLinear() end @@ -70,8 +74,3 @@ end function Base.size(order::Order) (length(order.order),) end - - -function Base.deepcopy(order::Order) - Order(copy(order.order), copy(order.index)) -end diff --git a/src/junction_trees/postorder_trees.jl b/src/junction_trees/postorder_trees.jl index 23ed2d8..dd8c523 100644 --- a/src/junction_trees/postorder_trees.jl +++ b/src/junction_trees/postorder_trees.jl @@ -1,9 +1,9 @@ # A postordered rooted tree. +# This type implements the indexed tree interface. struct PostorderTree - parent::Vector{Int} # parent - children::Vector{Vector{Int}} # children - level::Vector{Int} # level - descendant::Vector{Int} # first descendant + tree::Tree # rooted tree + level::Vector{Int} # vector of levels + fdesc::Vector{Int} # vector of first descendants end @@ -12,29 +12,29 @@ end # parent list of parents # ---------------------------------------- function PostorderTree(parent::AbstractVector) - n = length(parent) - children = Vector{Vector{Int}}(undef, n) + tree = Tree(parent) + n = treesize(tree) level = Vector{Int}(undef, n) - descendant = Vector{Int}(undef, n) - - for i in 1:n - children[i] = [] - level[i] = 0 - descendant[i] = i - end - - for i in 1:n - 1 - j = parent[i] - push!(children[j], i) - descendant[j] = min(descendant[i], descendant[j]) - end - + fdesc = Vector{Int}(undef, n) + level[n] = 0 + fdesc[n] = 0 + for i in n - 1:-1:1 - j = parent[i] + j = parentindex(tree, i) level[i] = level[j] + 1 + fdesc[i] = i + fdesc[j] = 0 + end + + for i in 1:n - 1 + j = parentindex(tree, i) + + if iszero(fdesc[j]) + fdesc[j] = fdesc[i] + end end - PostorderTree(parent, children, level, descendant) + PostorderTree(tree, level, fdesc) end @@ -44,66 +44,56 @@ end # order postorder # ---------------------------------------- function PostorderTree(tree::Tree, order::Order) - n = length(tree) - parent = collect(1:n) + n = treesize(tree) + parent = Vector{Int}(undef, n) for i in 1:n - 1 - parent[i] = inverse(order, parentindex(tree, order[i])) + parent[i] = inv(order)[parentindex(tree, order[i])] end + parent[n] = 0 PostorderTree(parent) end -# The number of node in a tree. -function Base.length(tree::PostorderTree) - length(tree.parent) -end - - # Get the level of a node i. function level(tree::PostorderTree, i::Integer) tree.level[i] end -# Get the first descendant of a node i. -function firstdescendant(tree::PostorderTree, i::Integer) - tree.descendant[i] +function descendantindices(tree::PostorderTree, i::Integer) + tree.fdesc[i]:i - 1 end -# Determine whether the node i is a descendant of the node j. -function isdescendant(tree::PostorderTree, i::Integer, j::Integer) - getdescendant(tree, j) <= i < j -end +########################## +# Indexed Tree Interface # +########################## -# Get the height of a tree. -function height(tree::PostorderTree) - maximum(tree.level) +function AbstractTrees.treesize(tree::PostorderTree) + treesize(tree.tree) end -########################## -# Indexed Tree Interface # -########################## +function AbstractTrees.treeheight(tree::PostorderTree) + maximum(tree.level) +end function AbstractTrees.parentindex(tree::PostorderTree, i::Integer) - if i != rootindex(tree) - tree.parent[i] - end + parentindex(tree.tree, i) end function AbstractTrees.childindices(tree::PostorderTree, i::Integer) - tree.children[i] + childindices(tree.tree, i) end function AbstractTrees.rootindex(tree::PostorderTree) - length(tree) + rootindex(tree.tree) end diff --git a/src/junction_trees/supernode_trees.jl b/src/junction_trees/supernode_trees.jl deleted file mode 100644 index cf7ff0e..0000000 --- a/src/junction_trees/supernode_trees.jl +++ /dev/null @@ -1,96 +0,0 @@ -# An ordered graph (G, σ) equipped with a supernodal elimination tree T. -struct SupernodeTree - tree::PostorderTree # supernodal elimination tree - graph::OrderedGraph # ordered graph - representative::Vector{Int} # representative vertex - cardinality::Vector{Int} # supernode cardinality - ancestor::Vector{Int} # first ancestor - degree::Vector{Int} # higher degrees -end - - -# Construct a supernodal elimination tree using an elimination algorithm. -# ---------------------------------------- -# graph simple connected graph -# ealg elimination algorithm -# stype supernode type -# ---------------------------------------- -function SupernodeTree( - graph::AbstractSymmetricGraph, - ealg::Union{Order, EliminationAlgorithm}=DEFAULT_ELIMINATION_ALGORITHM, - stype::SupernodeType=DEFAULT_SUPERNODE_TYPE) - - SupernodeTree(EliminationTree(graph, ealg), stype) -end - - -# Construct a supernodal elimination tree. -# ---------------------------------------- -# etree elimination tree -# stype supernode type -# ---------------------------------------- -function SupernodeTree(etree::EliminationTree, stype::SupernodeType=DEFAULT_SUPERNODE_TYPE) - degree = outdegrees(etree) - supernode, parent, ancestor = stree(etree, degree, stype) - tree = Tree(parent) - - order = postorder(tree) - tree = PostorderTree(tree, order) - permute!(supernode, order) - permute!(ancestor, order) - - order = Order(vcat(supernode...)) - graph = OrderedGraph(etree.graph, order) - permute!(degree, order) - - representative = map(first, supernode) - cardinality = map(length, supernode) - map!(i -> inverse(order, i), ancestor, ancestor) - map!(i -> inverse(order, i), representative, representative) - - SupernodeTree(tree, graph, representative, cardinality, ancestor, degree) -end - - -# Compute the width of a supernodal elimination tree. -function width(stree::SupernodeTree) - maximum(stree.degree[stree.representative]) -end - - -# Get the (sorted) supernode at node i. -function supernode(stree::SupernodeTree, i::Integer) - v = stree.representative[i] - n = stree.cardinality[i] - v:v + n - 1 -end - - -# Compute the (unsorted) seperators of every node in T. -function seperators(stree::SupernodeTree) - n = length(stree.tree) - seperator = Vector{Set{Int}}(undef, n) - - for i in 1:n - 1 - seperator[i] = Set(stree.ancestor[i]) - - for v in outneighbors(stree.graph, stree.representative[i]) - if stree.ancestor[i] < v - push!(seperator[i], v) - end - end - end - - for j in 1:n - 1 - for i in childindices(stree.tree, j) - for v in seperator[i] - if stree.ancestor[j] < v - push!(seperator[j], v) - end - end - end - end - - seperator[n] = Set() - seperator -end diff --git a/src/junction_trees/supernode_types.jl b/src/junction_trees/supernode_types.jl index 760284a..d19bcf5 100644 --- a/src/junction_trees/supernode_types.jl +++ b/src/junction_trees/supernode_types.jl @@ -3,7 +3,7 @@ A type of supernode. The options are - [`Node`](@ref) -- [`Maximal](@ref) +- [`Maximal`](@ref) - [`Fundamental`](@ref) """ abstract type SupernodeType end @@ -36,8 +36,8 @@ struct Fundamental <: SupernodeType end # Compact Clique Tree Data Structures in Sparse Matrix Factorizations # Pothen and Sun # Figure 4: The Clique Tree Algorithm 2 -function stree(etree::EliminationTree, degree::AbstractVector, stype::SupernodeType) - n = length(etree.tree) +function cta(tree::Tree, colcount::AbstractVector, stype::SupernodeType) + n = treesize(tree) new_in_clique = Vector{Int}(undef, n) new = Vector{Int}[] parent = Int[] @@ -46,7 +46,7 @@ function stree(etree::EliminationTree, degree::AbstractVector, stype::SupernodeT i = 0 for v in 1:n - u = child_in_supernode(etree, degree, stype, v) + u = child_in_supernode(tree, colcount, stype, v) if !isnothing(u) new_in_clique[v] = new_in_clique[u] @@ -54,11 +54,11 @@ function stree(etree::EliminationTree, degree::AbstractVector, stype::SupernodeT else new_in_clique[v] = i += 1 push!(new, [v]) - push!(parent, i) + push!(parent, 0) push!(first_anc, n) end - for s in childindices(etree.tree, v) + for s in childindices(tree, v) if s !== u parent[new_in_clique[s]] = new_in_clique[v] first_anc[new_in_clique[s]] = v @@ -66,41 +66,53 @@ function stree(etree::EliminationTree, degree::AbstractVector, stype::SupernodeT end end - new, parent, first_anc + new, Tree(parent) end -# Find a child w of v such that -# v ∈ snd(w). +function stree(graph::OrderedGraph, tree::Tree, colcount::AbstractVector, stype::SupernodeType) + supernode, tree = cta(tree, colcount, stype) + order = postorder(tree) + view(supernode, order), PostorderTree(tree, order) +end + + +# Find a child w of v such that v ∈ supernode(w). # If no such child exists, return nothing. -function child_in_supernode(etree::EliminationTree, degree::AbstractVector, stype::Node, v::Integer) end +function child_in_supernode(tree::Tree, colcount::AbstractVector, stype::Node, v::Integer) end -# Find a child w of v such that -# v ∈ snd(w). +# Find a child w of v such that v ∈ supernode(w). # If no such child exists, return nothing. -function child_in_supernode(etree::EliminationTree, degree::AbstractVector, stype::Maximal, v::Integer) - for w in childindices(etree.tree, v) - if degree[w] == degree[v] + 1 - return w +function child_in_supernode(tree::Tree, colcount::AbstractVector, stype::Maximal, v::Integer) + u = nothing + + for w in childindices(tree, v) + if colcount[w] == colcount[v] + 1 + u = w + break end end + + u end -# Find a child w of v such that -# v ∈ snd(w). +# Find a child w of v such that v ∈ supernode(w). # If no such child exists, return nothing. -function child_in_supernode(etree::EliminationTree, degree::AbstractVector, stype::Fundamental, v::Integer) - ws = childindices(etree.tree, v) - - if length(ws) == 1 - w = only(ws) +function child_in_supernode(tree::Tree, colcount::AbstractVector, stype::Fundamental, v::Integer) + u = nothing - if degree[w] == degree[v] + 1 - return w + for w in childindices(tree, v) + if isnothing(u) && colcount[w] == colcount[v] + 1 + u = w + else + u = nothing + break end end + + u end diff --git a/src/junction_trees/trees.jl b/src/junction_trees/trees.jl index 759d19f..7fbd994 100644 --- a/src/junction_trees/trees.jl +++ b/src/junction_trees/trees.jl @@ -1,8 +1,10 @@ # A rooted tree. +# This type implements the indexed tree interface. struct Tree - root::Int # root - parent::Vector{Int} # parent - children::Vector{Vector{Int}} # children + parent::Vector{Int} # vector of parents + child::Vector{Int} # vector of left-children + brother::Vector{Int} # vector of right-siblings + root::Int # root end @@ -11,46 +13,48 @@ end # parent list of parents # ---------------------------------------- function Tree(parent::AbstractVector) - n = root = length(parent) - children = Vector{Vector{Int}}(undef, n) - - for i in 1:n - children[i] = [] - end + n = length(parent) + child = zeros(Int, n) + brother = zeros(Int, n) + root = n - for i in 1:n - j = parent[i] - - if i == j + for i in n:-1:1 + if iszero(parent[i]) root = i else - push!(children[j], i) + brother[i] = child[parent[i]] + child[parent[i]] = i end end - Tree(root, parent, children) -end - - -# Get the number of nodes in a tree. -function Base.length(tree::Tree) - length(tree.parent) + Tree(parent, child, brother, root) end # Compute a postordering of tree's vertices. function postorder(tree::Tree) - n = length(tree) - order = Vector{Int}(undef, n) - index = Vector{Int}(undef, n) - - for node in PreOrderDFS(IndexNode(tree)) - order[n] = node.index - index[node.index] = n - n -= 1 + n = treesize(tree) + stack = FixedStack{Int}(n) + push!(stack, rootindex(tree)) + order = Order(undef, n) + child = copy(tree.child) + count = 1 + + while !isempty(stack) + j = pop!(stack) + i = child[j] + + if iszero(i) + order[count] = j + count += 1 + else + child[j] = tree.brother[i] + push!(stack, j) + push!(stack, i) + end end - - Order(order, index) + + order end @@ -59,20 +63,36 @@ end ########################## +function AbstractTrees.treesize(tree::Tree) + length(tree.parent) +end + + function AbstractTrees.rootindex(tree::Tree) tree.root end function AbstractTrees.parentindex(tree::Tree, i::Integer) - if i != rootindex(tree) - tree.parent[i] + j = tree.parent[i] + + if !iszero(j) + j + end +end + + +function AbstractTrees.nextsiblingindex(tree::Tree, i::Integer) + j = tree.brother[i] + + if !iszero(j) + j end end -function AbstractTrees.childindices(tree::Tree, i::Integer) - tree.children[i] +function AbstractTrees.SiblingLinks(::Type{IndexNode{Tree, Int}}) + StoredSiblings() end diff --git a/test/Decompositions.jl b/test/Decompositions.jl index c9cb286..8ea00ec 100644 --- a/test/Decompositions.jl +++ b/test/Decompositions.jl @@ -114,47 +114,47 @@ add_edges!(graph, decomposition = StrDecomp(graph, Order(1:17), Maximal()) + @test decomposition.decomp_shape == @acset Graph begin V = 8 E = 7 src = [1, 2, 3, 4, 5, 6, 7] - tgt = [5, 5, 4, 5, 6, 8, 8] + tgt = [8, 3, 6, 6, 6, 7, 8] end @test map(i -> ob_map(decomposition.diagram, i), 1:15) == [ - induced_subgraph(graph, [7, 8, 9, 15]), # g h i o - induced_subgraph(graph, [6, 9, 16]), # f i p + induced_subgraph(graph, [10, 11, 13, 14, 17]), # j k m n q induced_subgraph(graph, [2, 3, 4]), # b c d induced_subgraph(graph, [1, 3, 4, 5, 15]), # a c d e o + induced_subgraph(graph, [6, 9, 16]), # f i p + induced_subgraph(graph, [7, 8, 9, 15]), # g h i o induced_subgraph(graph, [5, 9, 15, 16]), # e i o p induced_subgraph(graph, [15, 16, 17]), # o p q - induced_subgraph(graph, [10, 11, 13, 14, 17]), # j k m n q induced_subgraph(graph, [12, 13, 14, 16, 17]), # l m n p q - induced_subgraph(graph, [9, 15]), # i o - induced_subgraph(graph, [9, 16]), # i p + induced_subgraph(graph, [13, 14, 17]), # m n q induced_subgraph(graph, [3, 4]), # c d induced_subgraph(graph, [5, 15]), # e o + induced_subgraph(graph, [9, 16]), # i p + induced_subgraph(graph, [9, 15]), # i o induced_subgraph(graph, [15, 16]), # o p induced_subgraph(graph, [16, 17]), # p q - induced_subgraph(graph, [13, 14, 17]), # m n q ] @test map(i -> hom_map(decomposition.diagram, i), 1:14) == [ - ACSetTransformation(induced_subgraph(graph, [9, 15]), induced_subgraph(graph, [5, 9, 15, 16]), V=[2, 3], E=Int[]), # i o → e i o p - ACSetTransformation(induced_subgraph(graph, [9, 16]), induced_subgraph(graph, [5, 9, 15, 16]), V=[2, 4], E=Int[]), # i p → e i o p + ACSetTransformation(induced_subgraph(graph, [13, 14, 17]), induced_subgraph(graph, [12, 13, 14, 16, 17]), V=[2, 3, 5], E=Int[]), # m n q → l m n p q ACSetTransformation(induced_subgraph(graph, [3, 4]), induced_subgraph(graph, [1, 3, 4, 5, 15]), V=[2, 3], E=Int[]), # c d → a c d e o ACSetTransformation(induced_subgraph(graph, [5, 15]), induced_subgraph(graph, [5, 9, 15, 16]), V=[1, 3], E=Int[]), # e o → e i o p + ACSetTransformation(induced_subgraph(graph, [9, 16]), induced_subgraph(graph, [5, 9, 15, 16]), V=[2, 4], E=Int[]), # i p → e i o p + ACSetTransformation(induced_subgraph(graph, [9, 15]), induced_subgraph(graph, [5, 9, 15, 16]), V=[2, 3], E=Int[]), # i o → e i o p ACSetTransformation(induced_subgraph(graph, [15, 16]), induced_subgraph(graph, [15, 16, 17]), V=[1, 2], E=Int[]), # o p → o p q ACSetTransformation(induced_subgraph(graph, [16, 17]), induced_subgraph(graph, [12, 13, 14, 16, 17]), V=[4, 5], E=Int[]), # p q → l m n p q - ACSetTransformation(induced_subgraph(graph, [13, 14, 17]), induced_subgraph(graph, [12, 13, 14, 16, 17]), V=[2, 3, 5], E=Int[]), # m n q → l m n p q - ACSetTransformation(induced_subgraph(graph, [9, 15]), induced_subgraph(graph, [7, 8, 9, 15]), V=[3, 4], E=Int[]), # i o → g h i o - ACSetTransformation(induced_subgraph(graph, [9, 16]), induced_subgraph(graph, [6, 9, 16]), V=[2, 3], E=Int[]), # i p → f i p + ACSetTransformation(induced_subgraph(graph, [13, 14, 17]), induced_subgraph(graph, [10, 11, 13, 14, 17]), V=[3, 4, 5], E=Int[]), # m n q → j k m n q ACSetTransformation(induced_subgraph(graph, [3, 4]), induced_subgraph(graph, [2, 3, 4]), V=[2, 3], E=Int[]), # c d → b c d ACSetTransformation(induced_subgraph(graph, [5, 15]), induced_subgraph(graph, [1, 3, 4, 5, 15]), V=[4, 5], E=Int[]), # e o → a c d e o + ACSetTransformation(induced_subgraph(graph, [9, 16]), induced_subgraph(graph, [6, 9, 16]), V=[2, 3], E=Int[]), # i p → f i p + ACSetTransformation(induced_subgraph(graph, [9, 15]), induced_subgraph(graph, [7, 8, 9, 15]), V=[3, 4], E=Int[]), # i o → g h i o ACSetTransformation(induced_subgraph(graph, [15, 16]), induced_subgraph(graph, [5, 9, 15, 16]), V=[3, 4], E=Int[]), # o p → e i o p ACSetTransformation(induced_subgraph(graph, [16, 17]), induced_subgraph(graph, [15, 16, 17]), V=[2, 3], E=Int[]), # p q → o p q - ACSetTransformation(induced_subgraph(graph, [13, 14, 17]), induced_subgraph(graph, [10, 11, 13, 14, 17]), V=[3, 4, 5], E=Int[]), # m n q → j k m n q ] - end diff --git a/test/JunctionTrees.jl b/test/JunctionTrees.jl index 89ecf85..4957084 100644 --- a/test/JunctionTrees.jl +++ b/test/JunctionTrees.jl @@ -36,21 +36,21 @@ order = JunctionTrees.Order(1:17) # Figure 4.3 jtree = JunctionTree(graph, order, Node()) -@test width(jtree) == 4 -@test height(jtree) == 7 -@test length(jtree) == 17 +@test treewidth(jtree) == 4 +@test treeheight(jtree) == 7 +@test treesize(jtree) == 17 @test map(i -> parentindex(jtree, i), 1:17) == [ 2, - 9, - 9, - 6, - 6, - 7, + 4, + 4, + 5, + 16, + 8, 8, 9, 10, - 16, + 14, 14, 13, 14, @@ -60,62 +60,62 @@ jtree = JunctionTree(graph, order, Node()) nothing, ] -@test map(i -> childindices(jtree, i), 1:17) == [ +@test map(i -> collect(childindices(jtree, i)), 1:17) == [ [], [1], [], + [2, 3], + [4], [], [], - [4, 5], - [6], - [7], - [2, 3, 8], + [6, 7], + [8], [9], [], [], [12], - [11, 13], + [10, 11, 13], [14], - [10, 15], + [5, 15], [16], ] @test map(i -> residual(jtree, i), 1:17) == [ - [7], # g - [8], # h - [6], # f - [2], # b + [10], # j + [11], # k + [12], # l + [13], # m + [14], # n [1], # a + [2], # b [3], # c [4], # d [5], # e + [6], # f + [7], # g + [8], # h [9], # i [15], # o - [12], # l - [10], # j - [11], # k - [13], # m - [14], # n [16], # p [17], # q ] @test map(i -> seperator(jtree, i), 1:17) == [ - [8, 9, 15], # h i o - [9, 15], # i o - [9, 16], # i p - [3, 4], # c d + [11, 13, 14, 17], # k m n q + [13, 14, 17], # m n q + [13, 14, 16, 17], # m n p q + [14, 16, 17], # n p q + [16, 17], # p q [3, 4, 5, 15], # c d e o + [3, 4], # c d [4, 5, 15], # d e o [5, 15], # e o [9, 15, 16], # i o p + [9, 16], # i p + [8, 9, 15], # h i o + [9, 15], # i o [15, 16], # o p [16, 17], # p q - [13, 14, 16, 17], # m n p q - [11, 13, 14, 17], # k m n q - [13, 14, 17], # m n q - [14, 16, 17], # n p q - [16, 17], # p q [17], # q [], # ] @@ -123,116 +123,116 @@ jtree = JunctionTree(graph, order, Node()) # Figure 4.7 (left) jtree = JunctionTree(graph, order, Maximal()) -@test width(jtree) == 4 -@test height(jtree) == 4 -@test length(jtree) == 8 +@test treewidth(jtree) == 4 +@test treeheight(jtree) == 4 +@test treesize(jtree) == 8 @test map(i -> parentindex(jtree, i), 1:8) == [ - 5, - 5, - 4, - 5, - 6, 8, + 3, + 6, + 6, + 6, + 7, 8, nothing ] -@test map(i -> childindices(jtree, i), 1:8) == [ +@test map(i -> collect(childindices(jtree, i)), 1:8) == [ [], [], + [2], [], - [3], - [1, 2, 4], - [5], [], - [6, 7], + [3, 4, 5], + [6], + [1, 7], ] @test map(i -> residual(jtree, i), 1:8) == [ - [7, 8], # g h - [6], # f + [10, 11], # j k [2], # b [1, 3, 4], # a c d + [6], # f + [7, 8], # g h [5, 9], # e i [15], # o - [10, 11], # j k [12, 13, 14, 16, 17], # l m n p q ] @test map(i -> seperator(jtree, i), 1:8) == [ - [9, 15], # i o - [9, 16], # i p + [13, 14, 17], # m n q [3, 4], # c d [5, 15], # e o + [9, 16], # i p + [9, 15], # i o [15, 16], # o p [16, 17], # p q - [13, 14, 17], # m n q [], # ] # Figure 4.9 jtree = JunctionTree(graph, order, Fundamental()) -@test width(jtree) == 4 -@test height(jtree) == 5 -@test length(jtree) == 12 +@test treewidth(jtree) == 4 +@test treeheight(jtree) == 5 +@test treesize(jtree) == 12 @test map(i -> parentindex(jtree, i), 1:12) == [ - 7, - 7, - 5, - 5, + 3, + 3, + 12, + 6, 6, 7, - 8, - 12, - 11, + 10, + 10, + 10, 11, 12, nothing, ] -@test map(i -> childindices(jtree, i), 1:12) == [ +@test map(i -> collect(childindices(jtree, i)), 1:12) == [ [], [], + [1, 2], [], [], - [3, 4], - [5], - [1, 2, 6], - [7], + [4, 5], + [6], [], [], - [9, 10], - [8, 11], + [7, 8, 9], + [10], + [3, 11], ] @test map(i -> residual(jtree, i), 1:12) == [ - [7, 8], # g h - [6], # f - [2], # b + [10, 11], # j k + [12], # l + [13, 14], # m n [1], # a + [2], # b [3, 4], # c d [5], # e + [6], # f + [7, 8], # g h [9], # i [15], # o - [12], # l - [10, 11], # j k - [13, 14], # m n [16, 17], # p q ] @test map(i -> seperator(jtree, i), 1:12) == [ - [9, 15], # i o - [9, 16], # i p - [3, 4], # c d + [13, 14, 17], # m n q + [13, 14, 16, 17], # m n p q + [16, 17], # p q [3, 4, 5, 15], # c d e o + [3, 4], # c d [5, 15], # e o [9, 15, 16], # i o p + [9, 16], # i p + [9, 15], # i o [15, 16], # o p [16, 17], # p q - [13, 14, 16, 17], # m n p q - [13, 14, 17], # m n q - [16, 17], # p q [], # ]