diff --git a/Project.toml b/Project.toml index c4b692b30..07c5ef7d7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.17.7" +version = "0.17.8" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/model.jl b/src/model.jl index 144ac03d0..f973a047a 100644 --- a/src/model.jl +++ b/src/model.jl @@ -517,6 +517,27 @@ Get the name of the `model` as `Symbol`. """ Base.nameof(model::Model) = model.name +""" + rand([rng=Random.GLOBAL_RNG], [T=NamedTuple], model::Model) + +Generate a sample of type `T` from the prior distribution of the `model`. +""" +function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} + x = last( + evaluate!!( + model, + SimpleVarInfo{Float64}(), + SamplingContext(rng, SampleFromPrior(), DefaultContext()), + ), + ) + return DynamicPPL.values_as(x, T) +end + +# Default RNG and type +Base.rand(rng::Random.AbstractRNG, model::Model) = rand(rng, NamedTuple, model) +Base.rand(::Type{T}, model::Model) where {T} = rand(Random.GLOBAL_RNG, T, model) +Base.rand(model::Model) = rand(Random.GLOBAL_RNG, NamedTuple, model) + """ logjoint(model::Model, varinfo::AbstractVarInfo) diff --git a/test/model.jl b/test/model.jl index 466a7d1f4..fbc6e2ad4 100644 --- a/test/model.jl +++ b/test/model.jl @@ -81,4 +81,26 @@ call_retval = model() @test !any(map(x -> x isa DynamicPPL.AbstractVarInfo, call_retval)) end + + @testset "rand" begin + model = gdemo_default + + Random.seed!(1776) + s, m = model() + sample_namedtuple = (; s=s, m=m) + sample_dict = Dict(:s => s, :m => m) + + # With explicit RNG + @test rand(Random.seed!(1776), model) == sample_namedtuple + @test rand(Random.seed!(1776), NamedTuple, model) == sample_namedtuple + @test rand(Random.seed!(1776), Dict, model) == sample_dict + + # Without explicit RNG + Random.seed!(1776) + @test rand(model) == sample_namedtuple + Random.seed!(1776) + @test rand(NamedTuple, model) == sample_namedtuple + Random.seed!(1776) + @test rand(Dict, model) == sample_dict + end end