Skip to content

Commit

Permalink
Overload AbstractPPL.condition and AbstractPPL.decondition (#337)
Browse files Browse the repository at this point in the history
Fixes #336.

Tests will fail until the Zygote bug is fixed... Maybe we should just mark them as broken so we can merge and release some PRs?

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
devmotion and devmotion committed Nov 4, 2021
1 parent 7a8ba7e commit dd1d301
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.16.0"
version = "0.16.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
28 changes: 17 additions & 11 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,15 @@ otherwise return `context` which is [`DefaultContext`](@ref) by default.
See also: [`decondition`](@ref)
"""
condition(; values...) = condition(DefaultContext(), NamedTuple(values))
condition(values::NamedTuple) = condition(DefaultContext(), values)
condition(context::AbstractContext, values::NamedTuple{()}) = context
condition(context::AbstractContext, values::NamedTuple) = ConditionContext(values, context)
condition(context::AbstractContext; values...) = condition(context, NamedTuple(values))
AbstractPPL.condition(; values...) = condition(DefaultContext(), NamedTuple(values))
AbstractPPL.condition(values::NamedTuple) = condition(DefaultContext(), values)
AbstractPPL.condition(context::AbstractContext, values::NamedTuple{()}) = context
function AbstractPPL.condition(context::AbstractContext, values::NamedTuple)
return ConditionContext(values, context)
end
function AbstractPPL.condition(context::AbstractContext; values...)
return condition(context, NamedTuple(values))
end

"""
decondition(context::AbstractContext, syms...)
Expand All @@ -381,20 +385,22 @@ Note that this recursively traverses contexts, deconditioning all along the way.
See also: [`condition`](@ref)
"""
decondition(::IsLeaf, context, args...) = context
function decondition(::IsParent, context, args...)
AbstractPPL.decondition(::IsLeaf, context, args...) = context
function AbstractPPL.decondition(::IsParent, context, args...)
return setchildcontext(context, decondition(childcontext(context), args...))
end
decondition(context, args...) = decondition(NodeTrait(context), context, args...)
function decondition(context::ConditionContext)
function AbstractPPL.decondition(context, args...)
return decondition(NodeTrait(context), context, args...)
end
function AbstractPPL.decondition(context::ConditionContext)
return decondition(childcontext(context))
end
function decondition(context::ConditionContext, sym)
function AbstractPPL.decondition(context::ConditionContext, sym)
return condition(
decondition(childcontext(context), sym), BangBang.delete!!(context.values, sym)
)
end
function decondition(context::ConditionContext, sym, syms...)
function AbstractPPL.decondition(context::ConditionContext, sym, syms...)
return decondition(
condition(
decondition(childcontext(context), syms...),
Expand Down
6 changes: 3 additions & 3 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,8 @@ From this we can tell what the correct way to condition `m` within `demo_inner`
is in the two different models.
"""
condition(model::Model; values...) = condition(model, NamedTuple(values))
function condition(model::Model, values)
AbstractPPL.condition(model::Model; values...) = condition(model, NamedTuple(values))
function AbstractPPL.condition(model::Model, values)
return contextualize(model, condition(model.context, values))
end

Expand Down Expand Up @@ -307,7 +307,7 @@ julia> model(rng)
(m = 0.683947930996541, x = 10.0)
```
"""
function decondition(model::Model, syms...)
function AbstractPPL.decondition(model::Model, syms...)
return contextualize(model, decondition(model.context, syms...))
end

Expand Down
2 changes: 1 addition & 1 deletion test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ end

# Ensure we can specialize on arguments.
@model demo(x) = x ~ Normal()
length(methods(demo))
@test length(methods(demo)) == 4
@test f(demo(1.0))
f(::Model{typeof(demo),(:x,)}) = false
@test !f(demo(1.0))
Expand Down
7 changes: 6 additions & 1 deletion test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@ function test_model_ad(model, logp_manual)

y, back = Zygote.pullback(logp_model, x)
@test y lp
@test back(1)[1] grad
# will be fixed by https://github.com/FluxML/Zygote.jl/pull/1104
if Threads.nthreads() > 1
@test_broken back(1)[1] grad
else
@test back(1)[1] grad
end
end

"""
Expand Down

2 comments on commit dd1d301

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/48180

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.16.1 -m "<description of version>" dd1d30115dea98885d5e03ce361b376d131e1b28
git push origin v0.16.1

Please sign in to comment.