From 9043d657a7d1580ba24c19e7f7128f5edd53bda5 Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Mon, 23 Nov 2020 02:28:51 +1100 Subject: [PATCH] allow redefinition of inputs in logprob (#192) Co-authored-by: David Widmann --- Project.toml | 2 +- src/context_implementations.jl | 4 ++-- src/prob_macro.jl | 2 +- test/prob_macro.jl | 19 +++++++++++++++++++ 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index e38ab9eee..1a07c0b4a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.9.7" +version = "0.9.8" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/context_implementations.jl b/src/context_implementations.jl index cf7e7f7f3..6b3542acd 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -30,7 +30,7 @@ function tilde(rng, ctx::PriorContext, sampler, right, vn::VarName, inds, vi) return _tilde(rng, sampler, right, vn, vi) end function tilde(rng, ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi) - if ctx.vars !== nothing + if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end @@ -169,7 +169,7 @@ function dot_tilde( inds, vi, ) - if ctx.vars !== nothing + if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) var = _getindex(getfield(ctx.vars, getsym(vn)), inds) vns, dist = get_vns_and_dist(right, var, vn) set_val!(vi, vns, dist, var) diff --git a/src/prob_macro.jl b/src/prob_macro.jl index 6d7b5ebee..52d1c7466 100644 --- a/src/prob_macro.jl +++ b/src/prob_macro.jl @@ -190,7 +190,7 @@ function Distributions.loglikelihood( if isdefined(right, :chain) # Element-wise likelihood for each value in chain chain = right.chain - ctx = LikelihoodContext() + ctx = LikelihoodContext(right) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) logps = map(iters) do (sample_idx, chain_idx) setval!(vi, chain, sample_idx, chain_idx) diff --git a/test/prob_macro.jl b/test/prob_macro.jl index 5bebbffdf..28a8ee6cc 100644 --- a/test/prob_macro.jl +++ b/test/prob_macro.jl @@ -128,4 +128,23 @@ Random.seed!(129) chain2 = sample(model2(y, group, n_groups), NUTS(0.65), 2_000; save_state=true) logprob"y = y[[1]] | group = group[[1]], n_groups = n_groups, chain = chain2" end + + @testset "issue190" begin + @model function gdemo(x, y) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x ~ filldist(Normal(m, sqrt(s)), length(y)) + for i in 1:length(y) + y[i] ~ Normal(x[i], sqrt(s)) + end + end + c = Chains(rand(10, 2), [:m, :s]) + model_gdemo = gdemo([1.0, 0.0], [1.5, 0.0]) + r1 = prob"y = [1.5] | chain=c, model = model_gdemo, x = [1.0]" + r2 = map(c[:s]) do s + # exp(logpdf(..)) not pdf because this is exactly what the prob"" macro does, so we test r1 == r2 + exp(logpdf(Normal(1, sqrt(s)), 1.5)) + end + @test r1 == r2 + end end