From c6ccb0810532d1ade8037a78d9b7ec9856c4e221 Mon Sep 17 00:00:00 2001 From: Carlos Parada Date: Thu, 17 Feb 2022 07:30:53 -0800 Subject: [PATCH] Extra context constructors (#374) * Extra context constructors * Simplify * Update src/contexts.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * add test * fix * formatting * Update test/contexts.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/contexts.jl Co-authored-by: David Widmann * Apply suggestions from review * Update test/contexts.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * More readable * Remove inner constructor * remove brackets * Formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/contexts.jl * Update src/contexts.jl Co-authored-by: David Widmann Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: David Widmann --- Project.toml | 2 +- src/contexts.jl | 30 +++++++++++++++++++++++++----- test/contexts.jl | 15 +++++++++++++++ 3 files changed, 41 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 223032d63..d7c561ab1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.17.5" +version = "0.17.6" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/contexts.jl b/src/contexts.jl index bfefb07d1..23a6128e1 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -120,7 +120,11 @@ setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right # Contexts """ - SamplingContext(rng, sampler, context) + SamplingContext( + [rng::Random.AbstractRNG=Random.GLOBAL_RNG], + [sampler::AbstractSampler=SampleFromPrior()], + [context::AbstractContext=DefaultContext()], + ) Create a context that allows you to sample parameters with the `sampler` when running the model. The `context` determines how the returned log density is computed when running the model. @@ -132,10 +136,26 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte sampler::S context::C end -SamplingContext(sampler, context) = SamplingContext(Random.GLOBAL_RNG, sampler, context) -SamplingContext(context::AbstractContext) = SamplingContext(SampleFromPrior(), context) -SamplingContext(sampler::AbstractSampler) = SamplingContext(sampler, DefaultContext()) -SamplingContext() = SamplingContext(SampleFromPrior()) + +function SamplingContext( + rng::Random.AbstractRNG=Random.GLOBAL_RNG, sampler::AbstractSampler=SampleFromPrior() +) + return SamplingContext(rng, sampler, DefaultContext()) +end + +function SamplingContext( + sampler::AbstractSampler, context::AbstractContext=DefaultContext() +) + return SamplingContext(Random.GLOBAL_RNG, sampler, context) +end + +function SamplingContext(rng::Random.AbstractRNG, context::AbstractContext) + return SamplingContext(rng, SampleFromPrior(), context) +end + +function SamplingContext(context::AbstractContext) + return SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), context) +end NodeTrait(context::SamplingContext) = IsParent() childcontext(context::SamplingContext) = context.context diff --git a/test/contexts.jl b/test/contexts.jl index f3a1ae800..65629afec 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -256,4 +256,19 @@ end @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") @test getlens(vn_prefixed) === getlens(vn) end + + @testset "SamplingContext" begin + context = SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) + @test context isa SamplingContext + + # convenience constructors + @test SamplingContext() == context + @test SamplingContext(Random.GLOBAL_RNG) == context + @test SamplingContext(SampleFromPrior()) == context + @test SamplingContext(DefaultContext()) == context + @test SamplingContext(Random.GLOBAL_RNG, SampleFromPrior()) == context + @test SamplingContext(Random.GLOBAL_RNG, DefaultContext()) == context + @test SamplingContext(SampleFromPrior(), DefaultContext()) == context + @test SamplingContext(SampleFromPrior(), DefaultContext()) == context + end end