From 4938a31a4ced04c30e11d9d4875345f17fe69e70 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 15 Jun 2022 12:26:54 +0000 Subject: [PATCH] Fix for #371 (#372) This PR adds a method called `resolve_varnames(varname, dist)` and adds an additional generated variable for each `~` which now holds the RHS of `~`. It does address #371 but uncertain if this is the best way, so wouldn't recommend merging this just yet. But putting it here so we can colab on it. Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- Project.toml | 2 +- src/compiler.jl | 18 +++++++++++++----- src/distribution_wrappers.jl | 11 +++++++++++ test/compiler.jl | 12 ++++++++++++ 4 files changed, 37 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index e060d685a..088de1752 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.19.1" +version = "0.19.2" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/compiler.jl b/src/compiler.jl index 4038d5d14..0ea6981f3 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -175,6 +175,9 @@ function unwrap_right_left_vns( return unwrap_right_left_vns(right, left, vns) end +resolve_varnames(vn::VarName, _) = vn +resolve_varnames(vn::VarName, dist::NamedDist) = dist.name + ################# # Main Compiler # ################# @@ -379,16 +382,19 @@ function generate_tilde(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn isassumption value + @gensym vn isassumption value dist # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact # that in DynamicPPL we the entire function body. Instead we should be # more selective with our escape. Until that's the case, we remove them all. return quote - $vn = $(AbstractPPL.drop_escape(varname(left))) + $dist = $right + $vn = $(DynamicPPL.resolve_varnames)( + $(AbstractPPL.drop_escape(varname(left))), $dist + ) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $isassumption - $(generate_tilde_assume(left, right, vn)) + $(generate_tilde_assume(left, dist, vn)) else # If `vn` is not in `argnames`, we need to make sure that the variable is defined. if !$(DynamicPPL.inargnames)($vn, __model__) @@ -397,7 +403,7 @@ function generate_tilde(left, right) $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( __context__, - $(DynamicPPL.check_tilde_rhs)($right), + $(DynamicPPL.check_tilde_rhs)($dist), $(maybe_view(left)), $vn, __varinfo__, @@ -442,7 +448,9 @@ function generate_dot_tilde(left, right) # if the LHS represents an observation @gensym vn isassumption value return quote - $vn = $(AbstractPPL.drop_escape(varname(left))) + $vn = $(DynamicPPL.resolve_varnames)( + $(AbstractPPL.drop_escape(varname(left))), $right + ) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $isassumption $(generate_dot_tilde_assume(left, right, vn)) diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index a0f8dbc47..4045cc089 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -13,6 +13,17 @@ end NamedDist(dist::Distribution, name::Symbol) = NamedDist(dist, VarName{name}()) +Distributions.logpdf(dist::NamedDist, x::Real) = Distributions.logpdf(dist.dist, x) +function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real}) + return Distributions.logpdf(dist.dist, x) +end +function Distributions.loglikelihood(dist::NamedDist, x::Real) + return Distributions.loglikelihood(dist.dist, x) +end +function Distributions.loglikelihood(dist::NamedDist, x::AbstractArray{<:Real}) + return Distributions.loglikelihood(dist.dist, x) +end + struct NoDist{variate,support,Td<:Distribution{variate,support}} <: Distribution{variate,support} dist::Td diff --git a/test/compiler.jl b/test/compiler.jl index f59f013ac..e7e43102b 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -312,6 +312,18 @@ end @test vi2.metadata.y.vns[1] == @varname(y[2][:, 1]) @test haskey(vi3.metadata, :y) @test vi3.metadata.y.vns[1] == @varname(y[1]) + + # Conditioning + f1_c = f1() | (y=1,) + f2_c = f2() | NamedTuple((Symbol(@varname(y[2][:, 1])) => 1,)) + f3_c = f3() | NamedTuple((Symbol(@varname(y[1])) => 1,)) + @test f1_c() == 1 + # TODO(torfjelde): We need conditioning for `Dict`. + @test_broken f2_c() == 1 + @test_broken f3_c() == 1 + @test_broken getlogp(VarInfo(f1_c)) == + getlogp(VarInfo(f2_c)) == + getlogp(VarInfo(f3_c)) end @testset "custom tilde" begin @model demo() = begin