Skip to content

Commit

Permalink
Condition with Dict as underlying storage (#419)
Browse files Browse the repository at this point in the history
This PR allows usage of `Dict` as the underlying storage in addition to the currently supported `NamedTuple`.

Similarly to `SimpleVarInfo`, this gives us two approaches: one with somewhat limited support, as outlined in the docstring, but with (usually) zero runtime overhead (`NamedTuple`), and one with full support but with runtime overead (`Dict`).
  • Loading branch information
torfjelde committed Aug 29, 2022
1 parent 0ba86e2 commit 08ef935
Show file tree
Hide file tree
Showing 5 changed files with 313 additions and 128 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.20.0"
version = "0.20.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
73 changes: 45 additions & 28 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,28 +272,18 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
end
end

struct ConditionContext{Names,Values,Ctx<:AbstractContext} <: AbstractContext
struct ConditionContext{Values,Ctx<:AbstractContext} <: AbstractContext
values::Values
context::Ctx

function ConditionContext{Values}(
values::Values, context::AbstractContext
) where {names,Values<:NamedTuple{names}}
return new{names,typeof(values),typeof(context)}(values, context)
end
end

function ConditionContext(values::NamedTuple)
return ConditionContext(values, DefaultContext())
end
function ConditionContext(values::NamedTuple, context::AbstractContext)
return ConditionContext{typeof(values)}(values, context)
end
const NamedConditionContext{Names} = ConditionContext{<:NamedTuple{Names}}
const DictConditionContext = ConditionContext{<:AbstractDict}

ConditionContext(values) = ConditionContext(values, DefaultContext())

# Try to avoid nested `ConditionContext`.
function ConditionContext(
values::NamedTuple{Names}, context::ConditionContext
) where {Names}
function ConditionContext(values::NamedTuple, context::NamedConditionContext)
# Note that this potentially overrides values from `context`, thus giving
# precedence to the outmost `ConditionContext`.
return ConditionContext(merge(context.values, values), childcontext(context))
Expand All @@ -303,7 +293,7 @@ function Base.show(io::IO, context::ConditionContext)
return print(io, "ConditionContext($(context.values), $(childcontext(context)))")
end

NodeTrait(context::ConditionContext) = IsParent()
NodeTrait(::ConditionContext) = IsParent()
childcontext(context::ConditionContext) = context.context
setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child)

Expand All @@ -313,14 +303,9 @@ setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.value
Return `true` if `vn` is found in `context`.
"""
hasvalue(context, vn) = false

function hasvalue(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym}
return sym in vars
end
function hasvalue(
context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}}
) where {vars,sym}
return sym in vars
hasvalue(context::ConditionContext, vn::VarName) = nested_haskey(context.values, vn)
function hasvalue(context::ConditionContext, vns::AbstractArray{<:VarName})
return all(Base.Fix1(nested_haskey, context.values), vns)
end

"""
Expand All @@ -331,7 +316,8 @@ Return value of `vn` in `context`.
function getvalue(context::AbstractContext, vn)
return error("context $(context) does not contain value for $vn")
end
getvalue(context::ConditionContext, vn) = get(context.values, vn)
getvalue(context::NamedConditionContext, vn) = get(context.values, vn)
getvalue(context::ConditionContext, vn) = nested_getindex(context.values, vn)

"""
hasvalue_nested(context, vn)
Expand Down Expand Up @@ -386,15 +372,33 @@ otherwise return `context` which is [`DefaultContext`](@ref) by default.
See also: [`decondition`](@ref)
"""
AbstractPPL.condition(; values...) = condition(DefaultContext(), NamedTuple(values))
AbstractPPL.condition(; values...) = condition(NamedTuple(values))
AbstractPPL.condition(values::NamedTuple) = condition(DefaultContext(), values)
function AbstractPPL.condition(value::Pair{<:VarName}, values::Pair{<:VarName}...)
return condition((value, values...))
end
function AbstractPPL.condition(values::NTuple{<:Any,<:Pair{<:VarName}})
return condition(DefaultContext(), values)
end
AbstractPPL.condition(context::AbstractContext, values::NamedTuple{()}) = context
function AbstractPPL.condition(context::AbstractContext, values::NamedTuple)
function AbstractPPL.condition(
context::AbstractContext, values::Union{AbstractDict,NamedTuple}
)
return ConditionContext(values, context)
end
function AbstractPPL.condition(context::AbstractContext; values...)
return condition(context, NamedTuple(values))
end
function AbstractPPL.condition(
context::AbstractContext, value::Pair{<:VarName}, values::Pair{<:VarName}...
)
return condition(context, (value, values...))
end
function AbstractPPL.condition(
context::AbstractContext, values::NTuple{<:Any,Pair{<:VarName}}
)
return condition(context, Dict(values))
end

"""
decondition(context::AbstractContext, syms...)
Expand Down Expand Up @@ -430,6 +434,19 @@ function AbstractPPL.decondition(context::ConditionContext, sym, syms...)
)
end

function AbstractPPL.decondition(
context::NamedConditionContext, vn::VarName{sym}
) where {sym}
return condition(
decondition(childcontext(context), vn), BangBang.delete!!(context.values, sym)
)
end
function AbstractPPL.decondition(context::ConditionContext, vn::VarName)
return condition(
decondition(childcontext(context), vn), BangBang.delete!!(context.values, vn)
)
end

"""
conditioned(context::AbstractContext)
Expand Down
Loading

0 comments on commit 08ef935

Please sign in to comment.