Skip to content

Commit

Permalink
Fixes failing tests (#276)
Browse files Browse the repository at this point in the history
Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
torfjelde and yebai committed Jul 13, 2021
1 parent 4de6f54 commit 6c60c3b
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 20 deletions.
3 changes: 2 additions & 1 deletion test/compat/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

return logpdf(InverseGamma(2, 3), s) +
logpdf(Normal(0, sqrt(s)), m) +
logpdf(dist, 1.5) + logpdf(dist, 2.0)
logpdf(dist, 1.5) +
logpdf(dist, 2.0)
end

test_model_ad(gdemo_default, logp_gdemo_default)
Expand Down
12 changes: 7 additions & 5 deletions test/turing/loglikelihoods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
y = randn()
model = demo(xs, y)
chain = sample(model, MH(), MCMCThreads(), 100, 2)
var_to_likelihoods = pointwise_loglikelihoods(model, chain)
var_to_likelihoods = pointwise_loglikelihoods(
model, MCMCChains.get_sections(chain, :parameters)
)
@test haskey(var_to_likelihoods, "xs[1]")
@test haskey(var_to_likelihoods, "xs[2]")
@test haskey(var_to_likelihoods, "xs[3]")
Expand All @@ -32,8 +34,8 @@
results = pointwise_loglikelihoods(model, var_info)
var_to_likelihoods = Dict(string(vn) =>for (vn, ℓ) in results)
s, m = var_info[SampleFromPrior()]
@test logpdf(Normal(m, s), xs[1]) == var_to_likelihoods["xs[1]"]
@test logpdf(Normal(m, s), xs[2]) == var_to_likelihoods["xs[2]"]
@test logpdf(Normal(m, s), xs[3]) == var_to_likelihoods["xs[3]"]
@test logpdf(Normal(m, s), y) == var_to_likelihoods["y"]
@test [logpdf(Normal(m, s), xs[1])] == var_to_likelihoods["xs[1]"]
@test [logpdf(Normal(m, s), xs[2])] == var_to_likelihoods["xs[2]"]
@test [logpdf(Normal(m, s), xs[3])] == var_to_likelihoods["xs[3]"]
@test [logpdf(Normal(m, s), y)] == var_to_likelihoods["y"]
end
20 changes: 10 additions & 10 deletions test/turing/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
chain1 = sample(model1, MH(), 100)
chain2 = sample(model2, MH(), 100)

res11 = generated_quantities(model1, chain1)
res21 = generated_quantities(model2, chain1)
res11 = generated_quantities(model1, MCMCChains.get_sections(chain1, :parameters))
res21 = generated_quantities(model2, MCMCChains.get_sections(chain1, :parameters))

res12 = generated_quantities(model1, chain2)
res22 = generated_quantities(model2, chain2)
res12 = generated_quantities(model1, MCMCChains.get_sections(chain2, :parameters))
res22 = generated_quantities(model2, MCMCChains.get_sections(chain2, :parameters))

# Check that the two different models produce the same values for
# the same chains.
Expand All @@ -43,8 +43,8 @@
# Ensure that they're not all the same (some can be, because rejected samples)
@test any(res12[1:(end - 1)] .!= res12[2:end])

test_setval!(model1, chain1)
test_setval!(model2, chain2)
test_setval!(model1, MCMCChains.get_sections(chain1, :parameters))
test_setval!(model2, MCMCChains.get_sections(chain2, :parameters))

# Next level
@model function demo3(xs, ::Type{TV}=Vector{Float64}) where {TV}
Expand Down Expand Up @@ -79,11 +79,11 @@
chain3 = sample(model3, MH(), 100)
chain4 = sample(model4, MH(), 100)

res33 = generated_quantities(model3, chain3)
res43 = generated_quantities(model4, chain3)
res33 = generated_quantities(model3, MCMCChains.get_sections(chain3, :parameters))
res43 = generated_quantities(model4, MCMCChains.get_sections(chain3, :parameters))

res34 = generated_quantities(model3, chain4)
res44 = generated_quantities(model4, chain4)
res34 = generated_quantities(model3, MCMCChains.get_sections(chain4, :parameters))
res44 = generated_quantities(model4, MCMCChains.get_sections(chain4, :parameters))

# Check that the two different models produce the same values for
# the same chains.
Expand Down
18 changes: 14 additions & 4 deletions test/turing/prob_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

model = demo(xval)
varinfo = VarInfo(model)
chain = sample(model, IS(), iters; save_state=true)
chain = MCMCChains.get_sections(
sample(model, IS(), iters; save_state=true), :parameters
)
chain2 = Chains(chain.value, chain.logevidence, chain.name_map, NamedTuple())
lps = logpdf.(Normal.(chain["m"], 1), xval)
@test logprob"x = xval | chain = chain" == lps
Expand Down Expand Up @@ -40,7 +42,9 @@

model = demo(xval)
varinfo = VarInfo(model)
chain = sample(model, HMC(0.5, 1), iters; save_state=true)
chain = MCMCChains.get_sections(
sample(model, HMC(0.5, 1), iters; save_state=true), :parameters
)
chain2 = Chains(chain.value, chain.logevidence, chain.name_map, NamedTuple())

names = namesingroup(chain, "m")
Expand Down Expand Up @@ -74,7 +78,10 @@
group = rand(1:4, 100)
n_groups = 4

chain1 = sample(model1(y, group, n_groups), NUTS(0.65), 2_000; save_state=true)
chain1 = MCMCChains.get_sections(
sample(model1(y, group, n_groups), NUTS(0.65), 2_000; save_state=true),
:parameters,
)
logprob"y = y[[1]] | group = group[[1]], n_groups = n_groups, chain = chain1"

@model function model2(y, group, n_groups)
Expand All @@ -85,7 +92,10 @@
end
end

chain2 = sample(model2(y, group, n_groups), NUTS(0.65), 2_000; save_state=true)
chain2 = MCMCChains.get_sections(
sample(model2(y, group, n_groups), NUTS(0.65), 2_000; save_state=true),
:parameters,
)
logprob"y = y[[1]] | group = group[[1]], n_groups = n_groups, chain = chain2"
end
end

0 comments on commit 6c60c3b

Please sign in to comment.