From 7ef5da709564802fc2ccae182e0f77d1b7af5958 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Wed, 9 Aug 2023 07:03:56 +0100 Subject: [PATCH] Bugfix in `VarInfo`. (#516) * Bugfix in `VarInfo`. * Update Project.toml * Update src/varinfo.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/varinfo.jl Co-authored-by: David Widmann * Added tests. * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fix for tests in #516 (#517) * fixed tests for linking of dirichlet with different dimensionality * added usage of same logp in TestUtils.setup_varinfos --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: David Widmann Co-authored-by: Tor Erlend Fjelde --- Project.toml | 2 +- src/test_utils.jl | 3 ++- src/varinfo.jl | 2 +- test/linking.jl | 33 +++++++++++++++++++-------------- 4 files changed, 23 insertions(+), 17 deletions(-) diff --git a/Project.toml b/Project.toml index e85f3b59f..c5b7a0241 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.23.12" +version = "0.23.13" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/test_utils.jl b/src/test_utils.jl index fd3bf3d62..6604b5df1 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -51,9 +51,10 @@ function setup_varinfos(model::Model, example_values::NamedTuple, varnames) svi_typed = SimpleVarInfo(example_values) svi_untyped = SimpleVarInfo(OrderedDict()) + lp = getlogp(vi_typed) return map((vi_untyped, vi_typed, svi_typed, svi_untyped)) do vi # Set them all to the same values. - update_values!!(vi, example_values, varnames) + DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp) end end diff --git a/src/varinfo.jl b/src/varinfo.jl index bda979eef..c1ccc34b9 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -366,7 +366,7 @@ setall!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val) for f in names length = :(sum(length, metadata.$f.ranges)) finish = :($start + $length - 1) - push!(expr.args, :(metadata.$f.vals .= val[($start):($finish)])) + push!(expr.args, :(copyto!(metadata.$f.vals, 1, val, $start, $length))) start = :($start + $length) end return expr diff --git a/test/linking.jl b/test/linking.jl index 26d28c13d..bb0081780 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -91,21 +91,26 @@ end end end + # Related: https://github.com/TuringLang/DynamicPPL.jl/issues/504 @testset "dirichlet" begin - @model demo_dirichlet() = x ~ Dirichlet(2, 1.0) - model = demo_dirichlet() - vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(x),)) - @testset "$(short_varinfo_name(vi))" for vi in vis - @test length(vi[:]) == 2 - @test iszero(getlogp(vi)) - # Linked. - vi_linked = DynamicPPL.link!!(deepcopy(vi), model) - @test length(vi_linked[:]) == 1 - @test !iszero(getlogp(vi_linked)) # should now include the log-absdet-jacobian correction - # Invlinked. - vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model) - @test length(vi_invlinked[:]) == 2 - @test iszero(getlogp(vi_invlinked)) + @model demo_dirichlet(d::Int) = x ~ Dirichlet(d, 1.0) + @testset "d=$d" for d in [2, 3, 5] + model = demo_dirichlet(d) + vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(x),)) + @testset "$(short_varinfo_name(vi))" for vi in vis + lp = logpdf(Dirichlet(d, 1.0), vi[:]) + @test length(vi[:]) == d + @test getlogp(vi) ≈ lp + # Linked. + vi_linked = DynamicPPL.link!!(deepcopy(vi), model) + @test length(vi_linked[:]) == d - 1 + # Should now include the log-absdet-jacobian correction. + @test !(getlogp(vi_linked) ≈ lp) + # Invlinked. + vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model) + @test length(vi_invlinked[:]) == d + @test getlogp(vi_invlinked) ≈ lp + end end end end