Skip to content

Commit

Permalink
Extra context constructors (#374)
Browse files Browse the repository at this point in the history
* Extra context constructors

* Simplify

* Update src/contexts.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* add test

* fix

* formatting

* Update test/contexts.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/contexts.jl

Co-authored-by: David Widmann <[email protected]>

* Apply suggestions from review

* Update test/contexts.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* More readable

* Remove inner constructor

* remove brackets

* Formatting

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/contexts.jl

* Update src/contexts.jl

Co-authored-by: David Widmann <[email protected]>

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
3 people authored Feb 17, 2022
1 parent aeb5e03 commit c6ccb08
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.17.5"
version = "0.17.6"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
30 changes: 25 additions & 5 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,11 @@ setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right

# Contexts
"""
SamplingContext(rng, sampler, context)
SamplingContext(
[rng::Random.AbstractRNG=Random.GLOBAL_RNG],
[sampler::AbstractSampler=SampleFromPrior()],
[context::AbstractContext=DefaultContext()],
)
Create a context that allows you to sample parameters with the `sampler` when running the model.
The `context` determines how the returned log density is computed when running the model.
Expand All @@ -132,10 +136,26 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte
sampler::S
context::C
end
SamplingContext(sampler, context) = SamplingContext(Random.GLOBAL_RNG, sampler, context)
SamplingContext(context::AbstractContext) = SamplingContext(SampleFromPrior(), context)
SamplingContext(sampler::AbstractSampler) = SamplingContext(sampler, DefaultContext())
SamplingContext() = SamplingContext(SampleFromPrior())

function SamplingContext(
rng::Random.AbstractRNG=Random.GLOBAL_RNG, sampler::AbstractSampler=SampleFromPrior()
)
return SamplingContext(rng, sampler, DefaultContext())
end

function SamplingContext(
sampler::AbstractSampler, context::AbstractContext=DefaultContext()
)
return SamplingContext(Random.GLOBAL_RNG, sampler, context)
end

function SamplingContext(rng::Random.AbstractRNG, context::AbstractContext)
return SamplingContext(rng, SampleFromPrior(), context)
end

function SamplingContext(context::AbstractContext)
return SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), context)
end

NodeTrait(context::SamplingContext) = IsParent()
childcontext(context::SamplingContext) = context.context
Expand Down
15 changes: 15 additions & 0 deletions test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,4 +256,19 @@ end
@test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x")
@test getlens(vn_prefixed) === getlens(vn)
end

@testset "SamplingContext" begin
context = SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext())
@test context isa SamplingContext

# convenience constructors
@test SamplingContext() == context
@test SamplingContext(Random.GLOBAL_RNG) == context
@test SamplingContext(SampleFromPrior()) == context
@test SamplingContext(DefaultContext()) == context
@test SamplingContext(Random.GLOBAL_RNG, SampleFromPrior()) == context
@test SamplingContext(Random.GLOBAL_RNG, DefaultContext()) == context
@test SamplingContext(SampleFromPrior(), DefaultContext()) == context
@test SamplingContext(SampleFromPrior(), DefaultContext()) == context
end
end

2 comments on commit c6ccb08

@devmotion
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/55653

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.17.6 -m "<description of version>" c6ccb0810532d1ade8037a78d9b7ec9856c4e221
git push origin v0.17.6

Please sign in to comment.