Skip to content

Commit

Permalink
Removed context from InferenceProblem constructor.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelsonric committed Sep 21, 2023
1 parent 84a581f commit ab40f0e
Show file tree
Hide file tree
Showing 11 changed files with 302 additions and 298 deletions.
2 changes: 1 addition & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ mean(::GaussianSystem)
```@docs
InferenceProblem
InferenceProblem(::RelationDiagram, ::AbstractDict, ::AbstractDict, ::AbstractDict)
InferenceProblem(::RelationDiagram, ::AbstractDict, ::AbstractDict)
InferenceProblem(::BayesNet, ::AbstractVector, ::AbstractDict)
solve(::InferenceProblem, ::EliminationAlgorithm, ::SupernodeType, ::ArchitectureType)
Expand Down
2 changes: 1 addition & 1 deletion src/AlgebraicInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ export ⊗, cov, invcov, normal, kernel, mean, oapply, var

# Inference Problems
export InferenceProblem
export init
export init, reduce_to_context


# Inference Solvers
Expand Down
170 changes: 100 additions & 70 deletions src/architectures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ end
# Construct an architecture.
function Architecture(
model::GraphicalModel,
elalg::EliminationAlgorithm,
stype::SupernodeType)
elimination_algorithm::EliminationAlgorithm,
supernode_type::SupernodeType)

labels = model.labels
factors = model.factors
tree = JoinTree(model.graph, elalg, stype)
tree = JoinTree(model.graph, elimination_algorithm, supernode_type)

vvll = deepcopy(model.vvll)
assignments = Vector{Vector{Int}}(undef, length(tree))
Expand Down Expand Up @@ -106,23 +106,29 @@ end

# Answer a query.
# Algorithm 4.2 in doi:10.1002/9781118010877.
function CommonSolve.solve!(arch::Architecture, atype::ShenoyShafer, query)
arch.collect_phase_complete || collect_phase!(arch, atype)
function CommonSolve.solve!(
architecture::Architecture,
architecture_type::ShenoyShafer,
query::AbstractVector)

vars = [arch.labels.index[l] for l in query]
if !architecture.collect_phase_complete
collect_phase!(architecture, architecture_type)
end

vars = [architecture.labels.index[l] for l in query]

for n in arch.tree.order
node = IndexNode(arch.tree, n)
for n in architecture.tree.order
node = IndexNode(architecture.tree, n)
sep, res = nodevalue(node)

if vars [sep; res]
distribute_phase!(arch, atype, node.index)
distribute_phase!(architecture, architecture_type, node.index)

mbx = mailbox(arch, node.index)
mbx = mailbox(architecture, node.index)
fac = combine(mbx.factor, mbx.message_from_parent)

for child in children(node)
mbx = mailbox(arch, child.index)
mbx = mailbox(architecture, child.index)
fac = combine(fac, mbx.message_to_parent)
end

Expand All @@ -139,19 +145,25 @@ end

# Answer a query.
# Algorithm 4.4 in doi:10.1002/9781118010877.
function CommonSolve.solve!(arch::Architecture, atype::LauritzenSpiegelhalter, query)
arch.collect_phase_complete || collect_phase!(arch, atype)
function CommonSolve.solve!(
architecture::Architecture,
architecture_type::LauritzenSpiegelhalter,
query::AbstractVector)

if !architecture.collect_phase_complete
collect_phase!(architecture, architecture_type)
end

vars = [arch.labels.index[l] for l in query]
vars = [architecture.labels.index[l] for l in query]

for n in arch.tree.order
node = IndexNode(arch.tree, n)
for n in architecture.tree.order
node = IndexNode(architecture.tree, n)
sep, res = nodevalue(node)

if vars [sep; res]
distribute_phase!(arch, atype, node.index)
distribute_phase!(architecture, architecture_type, node.index)

mbx = mailbox(arch, node.index)
mbx = mailbox(architecture, node.index)
fac = combine(mbx.cpd, mbx.message_from_parent)

fac = project(fac, vars)
Expand All @@ -166,19 +178,20 @@ end


# Sample from an architecture.
function Base.rand(rng::AbstractRNG, arch::Architecture)
@assert arch.collect_phase_complete
function Base.rand(rng::AbstractRNG, architecture::Architecture)
@assert architecture.collect_phase_complete

x = Vector{Vector{Float64}}(undef, length(arch.labels))
m = length(architecture.labels)
x = Vector{Vector{Float64}}(undef, m)

for n in reverse(arch.tree.order)
node = IndexNode(arch.tree, n)
for n in reverse(architecture.tree.order)
node = IndexNode(architecture.tree, n)

mbx = mailbox(arch, node.index)
mbx = mailbox(architecture, node.index)
rand!(rng, mbx.cpd, x)
end

Dict(zip(arch.labels, x))
Dict(zip(architecture.labels, x))
end


Expand All @@ -188,127 +201,144 @@ end


# Compute the mean of an architecture.
function Statistics.mean(arch::Architecture)
@assert arch.collect_phase_complete
function Statistics.mean(architecture::Architecture)
@assert architecture.collect_phase_complete

x = Vector{Vector{Float64}}(undef, length(arch.labels))
m = length(architecture.labels)
x = Vector{Vector{Float64}}(undef, m)

for n in reverse(arch.tree.order)
node = IndexNode(arch.tree, n)
for n in reverse(architecture.tree.order)
node = IndexNode(architecture.tree, n)

mbx = mailbox(arch, node.index)
mbx = mailbox(architecture, node.index)
mean!(mbx.cpd, x)
end

Dict(zip(arch.labels, x))
Dict(zip(architecture.labels, x))
end


# The collect phase of the Shenoy-Shafer architecture.
# Algorithm 4.1 in doi:10.1002/9781118010877.
function collect_phase!(arch::Architecture{<:Any, T₁, T₂}, atype::ShenoyShafer) where {T₁, T₂}
for n in arch.tree.order
node = IndexNode(arch.tree, n)
function collect_phase!(
architecture::Architecture{<:Any, T₁, T₂},
architecture_type::ShenoyShafer) where {T₁, T₂}

mbx = mailbox(arch, node.index)
mbx.factor = factor(arch, node.index)
for n in architecture.tree.order
node = IndexNode(architecture.tree, n)

mbx = mailbox(architecture, node.index)
mbx.factor = factor(architecture, node.index)
msg = mbx.factor

for child in children(node)
mbx = mailbox(arch, child.index)
mbx = mailbox(architecture, child.index)
msg = combine(msg, mbx.message_to_parent)
end

mbx = mailbox(arch, node.index)
mbx = mailbox(architecture, node.index)
mbx.message_to_parent, mbx.cpd = disintegrate(msg, first(nodevalue(node)))
end

mbx = mailbox(arch, rootindex(arch.tree))
mbx = mailbox(architecture, rootindex(architecture.tree))
mbx.message_from_parent = zero(Factor{T₁, T₂})
arch.collect_phase_complete = true
architecture.collect_phase_complete = true
end


# The collect phase of the Lauritzen-Spiegelhalter architecture.
# Algorithm 4.3 in doi:10.1002/9781118010877.
function collect_phase!(arch::Architecture{<:Any, T₁, T₂}, atype::LauritzenSpiegelhalter) where {T₁, T₂}
for n in arch.tree.order
node = IndexNode(arch.tree, n)
function collect_phase!(
architecture::Architecture{<:Any, T₁, T₂},
architecture_type::LauritzenSpiegelhalter) where {T₁, T₂}

for n in architecture.tree.order
node = IndexNode(architecture.tree, n)

msg = factor(arch, node.index)
msg = factor(architecture, node.index)

for child in children(node)
mbx = mailbox(arch, child.index)
mbx = mailbox(architecture, child.index)
msg = combine(msg, mbx.message_to_parent)
end

mbx = mailbox(arch, node.index)
mbx = mailbox(architecture, node.index)
mbx.message_to_parent, mbx.cpd = disintegrate(msg, first(nodevalue(node)))
end

mbx = mailbox(arch, rootindex(arch.tree))
mbx = mailbox(architecture, rootindex(architecture.tree))
mbx.message_to_parent = nothing
mbx.message_from_parent = zero(Factor{T₁, T₂})
arch.collect_phase_complete = true
architecture.collect_phase_complete = true
end


# The distribute phase of the Shenoy-Shafer architecture.
# The distribute phase of the Shenoy-Shafer architecture. Only distributes from the root to
# node n.
# Algorithm 4.1 in doi:10.1002/9781118010877.
function distribute_phase!(arch::Architecture, atype::ShenoyShafer, n::Integer)
node = IndexNode(arch.tree, n)
mbx = mailbox(arch, node.index)
function distribute_phase!(
architecture::Architecture,
architecture_type::ShenoyShafer,
n::Integer)

node = IndexNode(architecture.tree, n)
mbx = mailbox(architecture, node.index)

ancestors = Int[]

while !isroot(node) && isnothing(mbx.message_from_parent)
push!(ancestors, node.index)
node = parent(node)
mbx = mailbox(arch, node.index)
mbx = mailbox(architecture, node.index)
end

for n in ancestors[end:-1:1]
node = IndexNode(arch.tree, n)
node = IndexNode(architecture.tree, n)
prnt = parent(node)

mbx = mailbox(arch, prnt.index)
mbx = mailbox(architecture, prnt.index)
msg = combine(mbx.factor, mbx.message_from_parent)

for sibling in children(prnt)
if node != sibling
mbx = mailbox(arch, sibling.index)
mbx = mailbox(architecture, sibling.index)
msg = combine(msg, mbx.message_to_parent)
end
end

mbx = mailbox(arch, node.index)
mbx = mailbox(architecture, node.index)
mbx.message_from_parent = project(msg, first(nodevalue(node)))
end
end


# The distribute phase of the Lauritzen-Spiegelhalter architecture.
# The distribute phase of the Lauritzen-Spiegelhalter architecture. Only distributes from
# the root to node n.
# Algorithm 4.3 in doi:10.1002/9781118010877.
function distribute_phase!(arch::Architecture, atype::LauritzenSpiegelhalter, n::Integer)
node = IndexNode(arch.tree, n)
mbx = mailbox(arch, node.index)
function distribute_phase!(
architecture::Architecture,
architecture_type::LauritzenSpiegelhalter,
n::Integer)

node = IndexNode(architecture.tree, n)
mbx = mailbox(architecture, node.index)

ancestors = Int[]

while !isroot(node) && isnothing(mbx.message_from_parent)
push!(ancestors, node.index)
node = parent(node)
mbx = mailbox(arch, node.index)
mbx = mailbox(architecture, node.index)
end

for n in ancestors[end:-1:1]
node = IndexNode(arch.tree, n)
node = IndexNode(architecture.tree, n)
prnt = parent(node)

mbx = mailbox(arch, prnt.index)
mbx = mailbox(architecture, prnt.index)
msg = combine(mbx.cpd, mbx.message_from_parent)

mbx = mailbox(arch, node.index)
mbx = mailbox(architecture, node.index)
mbx.message_to_parent = nothing
mbx.message_from_parent = project(msg, first(nodevalue(node)))
end
Expand All @@ -319,18 +349,18 @@ end
# μ: n → pa(n)
# and
# μ: pa(n) → n
function mailbox(arch::Architecture, n::Int)
function mailbox(arch::Architecture, n::Integer)
arch.mailboxes[n]
end


# Compute the join tree factor
# ψₙ
function factor(arch::Architecture{<:Any, T₁, T₂}, n::Int) where {T₁, T₂}
function factor(architecture::Architecture{<:Any, T₁, T₂}, n::Integer) where {T₁, T₂}
fac = zero(Factor{T₁, T₂})

for f in arch.assignments[n]
fac = combine(fac, arch.factors[f])
for f in architecture.assignments[n]
fac = combine(fac, architecture.factors[f])
end

fac
Expand Down
16 changes: 14 additions & 2 deletions src/cpds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,14 @@ function combine(cpd₁::CPD, fac₂::Factor{T₁, T₂}) where {T₁, T₂}
end


function disintegrate(fac::Factor{T₁, T₂}, vars::Vector{Int}) where {T₁, T₂}
# Compute the projection
# fac ↓ vars
function project(fac::Factor{T₁, T₂}, vars::AbstractVector) where {T₁, T₂}
first(disintegrate(fac, vars))
end


function disintegrate(fac::Factor{T₁, T₂}, vars::AbstractVector) where {T₁, T₂}
i₁ = Int[]
i₂ = Int[]

Expand All @@ -88,7 +95,12 @@ function disintegrate(fac::Factor{T₁, T₂}, vars::Vector{Int}) where {T₁, T
end


function disintegrate(hom::GaussianSystem, i₁::Vector{Int}, i₂::Vector{Int}, obs::Vector{Int})
function disintegrate(
hom::GaussianSystem,
i₁::AbstractVector,
i₂::AbstractVector,
obs::AbstractVector)

cms = cumsum(obs)

j₁ = Int[]
Expand Down
Loading

0 comments on commit ab40f0e

Please sign in to comment.