From 3602c56dba16bc2b60efbef6bf6ff31226354c7d Mon Sep 17 00:00:00 2001 From: Philipp Gabler Date: Mon, 29 Mar 2021 13:03:39 +0000 Subject: [PATCH] Introducing AbstractPPL dependency (#214) This has three very rudimentary consequences: - `VarName` and its helpers are moved over to AbstractPPL completely - The new abstract base type for models is `AbstractPPL.AbstractProbabilisticProgram <: AbstractMCMC.AbstractModel` - `AbstractVarInfo <: AbstractPPL.AbstractModelTrace` More abstractions (and hopefully concrete generalizations, too) are about to come. --- Project.toml | 4 +- src/DynamicPPL.jl | 9 +- src/model.jl | 2 +- src/varname.jl | 235 ++-------------------------------------------- test/Project.toml | 2 + test/compiler.jl | 29 ------ test/runtests.jl | 3 +- 7 files changed, 22 insertions(+), 262 deletions(-) diff --git a/Project.toml b/Project.toml index 8cd3f5b73..cfd21170e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,10 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.10.7" +version = "0.10.8" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -13,6 +14,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] AbstractMCMC = "2" +AbstractPPL = "0.1.2" Bijectors = "0.5.2, 0.6, 0.7, 0.8" ChainRulesCore = "0.9.7" Distributions = "0.23.8, 0.24" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index ffc196bea..e369fee3f 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -1,6 +1,7 @@ module DynamicPPL -using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel +using AbstractMCMC: AbstractSampler, AbstractChains +using AbstractPPL using Distributions using Bijectors @@ -49,13 +50,13 @@ export AbstractVarInfo, link!, invlink!, tonamedtuple, -#VarName +# VarName (reexport from AbstractPPL) VarName, inspace, subsumes, + @varname, # Compiler @model, - @varname, # Utilities vectorize, reconstruct, @@ -104,7 +105,7 @@ export loglikelihood function getspace end # Necessary forward declarations -abstract type AbstractVarInfo end +abstract type AbstractVarInfo <: AbstractModelTrace end abstract type AbstractContext end diff --git a/src/model.jl b/src/model.jl index 5e217859a..fbb83c8a9 100644 --- a/src/model.jl +++ b/src/model.jl @@ -32,7 +32,7 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractModel +struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractProbabilisticProgram name::Symbol f::F args::NamedTuple{argnames,Targs} diff --git a/src/varname.jl b/src/varname.jl index ed58e4754..f45b0b430 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -1,240 +1,23 @@ """ - VarName(sym[, indexing=()]) + inargnames(varname::VarName, model::Model) -A variable identifier for a symbol `sym` and indices `indexing` in the format -returned by [`@vinds`](@ref). +Statically check whether the variable of name `varname` is an argument of the `model`. -The Julia variable in the model corresponding to `sym` can refer to a single value or to a -hierarchical array structure of univariate, multivariate or matrix variables. The field `indexing` -stores the indices requires to access the random variable from the Julia variable indicated by `sym` -as a tuple of tuples. Each element of the tuple thereby contains the indices of one indexing -operation. - -`VarName`s can be manually constructed using the `VarName(sym, indexing)` constructor, or from an -indexing expression through the [`@varname`](@ref) convenience macro. - -# Examples - -```jldoctest -julia> vn = VarName(:x, ((Colon(), 1), (2,))) -x[Colon(),1][2] - -julia> vn.indexing -((Colon(), 1), (2,)) - -julia> VarName(DynamicPPL.@vsym(x[:, 1][1+1]), DynamicPPL.@vinds(x[:, 1][1+1])) -x[Colon(),1][2] -``` -""" -struct VarName{sym, T<:Tuple} - indexing::T -end - -VarName(sym::Symbol, indexing::Tuple = ()) = VarName{sym, typeof(indexing)}(indexing) - -""" - VarName(vn::VarName[, indexing=()]) - -Return a copy of `vn` with a new index `indexing`. -""" -function VarName(vn::VarName, indexing::Tuple = ()) - return VarName{getsym(vn), typeof(indexing)}(indexing) -end - - -""" - getsym(vn::VarName) - -Return the symbol of the Julia variable used to generate `vn`. -""" -getsym(vn::VarName{sym}) where sym = sym - - -""" - getindexing(vn::VarName) - -Return the indexing tuple of the Julia variable used to generate `vn`. -""" -getindexing(vn::VarName) = vn.indexing - - -Base.hash(vn::VarName, h::UInt) = hash((getsym(vn), getindexing(vn)), h) -Base.:(==)(x::VarName, y::VarName) = getsym(x) == getsym(y) && getindexing(x) == getindexing(y) - -function Base.show(io::IO, vn::VarName) - print(io, getsym(vn)) - for indices in getindexing(vn) - print(io, "[") - join(io, indices, ",") - print(io, "]") - end -end - - -""" - Symbol(vn::VarName) - -Return a `Symbol` represenation of the variable identifier `VarName`. -""" -Base.Symbol(vn::VarName) = Symbol(string(vn)) # simplified symbol - - -""" - inspace(vn::Union{VarName, Symbol}, space::Tuple) - -Check whether `vn`'s variable symbol is in `space`. -""" -inspace(vn, space::Tuple{}) = true # empty space is treated as universal set -inspace(vn, space::Tuple) = vn in space -inspace(vn::VarName, space::Tuple{}) = true -inspace(vn::VarName, space::Tuple) = any(_in(vn, s) for s in space) - -_in(vn::VarName, s::Symbol) = getsym(vn) == s -_in(vn::VarName, s::VarName) = subsumes(s, vn) - - -""" - subsumes(u::VarName, v::VarName) - -Check whether the variable name `v` describes a sub-range of the variable `u`. Supported -indexing: - -- Scalar: `x` subsumes `x[1, 2]`, `x[1, 2]` subsumes `x[1, 2][3]`, etc. -- Array of scalar: `x[[1, 2], 3]` subsumes `x[1, 3]`, `x[1:3]` subsumes `x[2][1]`, etc. - (basically everything that fulfills `issubset`). -- Slices: `x[2, :]` subsumes `x[2, 10][1]`, etc. - -Currently _not_ supported are: - -- Boolean indexing, literal `CartesianIndex` (these could be added, though) -- Linear indexing of multidimensional arrays: `x[4]` does not subsume `x[2, 2]` for `x` a matrix -- Trailing ones: `x[2, 1]` does not subsume `x[2]` for `x` a vector -""" -function subsumes(u::VarName, v::VarName) - return getsym(u) == getsym(v) && subsumes(u.indexing, v.indexing) -end - -subsumes(::Tuple{}, ::Tuple{}) = true # x subsumes x -subsumes(::Tuple{}, ::Tuple) = true # x subsumes x[1] -subsumes(::Tuple, ::Tuple{}) = false # x[1] does not subsume x -function subsumes(t::Tuple, u::Tuple) # does x[i]... subsume x[j]...? - return _issubindex(first(t), first(u)) && subsumes(Base.tail(t), Base.tail(u)) -end - -const AnyIndex = Union{Int, AbstractVector{Int}, Colon} -_issubindex_(::Tuple{Vararg{AnyIndex}}, ::Tuple{Vararg{AnyIndex}}) = false -function _issubindex(t::NTuple{N, AnyIndex}, u::NTuple{N, AnyIndex}) where {N} - return all(_issubrange(j, i) for (i, j) in zip(t, u)) -end - -const ConcreteIndex = Union{Int, AbstractVector{Int}} # this include all kinds of ranges -"""Determine whether indices `i` are contained in `j`, treating `:` as universal set.""" -_issubrange(i::ConcreteIndex, j::ConcreteIndex) = issubset(i, j) -_issubrange(i::Union{ConcreteIndex, Colon}, j::Colon) = true -_issubrange(i::Colon, j::ConcreteIndex) = true - - - -""" - @varname(expr) - -A macro that returns an instance of [`VarName`](@ref) given a symbol or indexing expression `expr`. - -The `sym` value is taken from the actual variable name, and the index values are put appropriately -into the constructor (and resolved at runtime). - -# Examples - -```jldoctest -julia> @varname(x).indexing -() - -julia> @varname(x[1]).indexing -((1,),) - -julia> @varname(x[:, 1]).indexing -((Colon(), 1),) - -julia> @varname(x[:, 1][2]).indexing -((Colon(), 1), (2,)) - -julia> @varname(x[1,2][1+5][45][3]).indexing -((1, 2), (6,), (45,), (3,)) -``` - -!!! compat "Julia 1.5" - Using `begin` in an indexing expression to refer to the first index requires at least - Julia 1.5. +Possibly existing indices of `varname` are neglected. """ -macro varname(expr::Union{Expr, Symbol}) - return esc(varname(expr)) -end - -varname(expr::Symbol) = VarName(expr) -function varname(expr::Expr) - if Meta.isexpr(expr, :ref) - sym, inds = vsym(expr), vinds(expr) - return :($(DynamicPPL.VarName)($(QuoteNode(sym)), $inds)) - else - throw("VarName: Mis-formed variable name $(expr)!") - end -end - - -""" - @vsym(expr) - -A macro that returns the variable symbol given the input variable expression `expr`. -For example, `@vsym x[1]` returns `:x`. -""" -macro vsym(expr::Union{Expr, Symbol}) - return QuoteNode(vsym(expr)) +@generated function inargnames(::VarName{s}, ::Model{_F, argnames}) where {s, argnames, _F} + return s in argnames end -vsym(expr::Symbol) = expr -function vsym(expr::Expr) - if Meta.isexpr(expr, :ref) - return vsym(expr.args[1]) - else - throw("VarName: Mis-formed variable name $(expr)!") - end -end """ - @vinds(expr) + inmissings(varname::VarName, model::Model) -Returns a tuple of tuples of the indices in `expr`. For example, `@vinds x[1, :][2]` returns -`((1, Colon()), (2,))`. +Statically check whether the variable of name `varname` is a statically declared unobserved variable +of the `model`. -!!! compat "Julia 1.5" - Using `begin` in an indexing expression to refer to the first index requires at least - Julia 1.5. +Possibly existing indices of `varname` are neglected. """ -macro vinds(expr::Union{Expr, Symbol}) - return esc(vinds(expr)) -end - -vinds(expr::Symbol) = Expr(:tuple) -function vinds(expr::Expr) - if Meta.isexpr(expr, :ref) - ex = copy(expr) - @static if VERSION < v"1.5.0-DEV.666" - Base.replace_ref_end!(ex) - else - Base.replace_ref_begin_end!(ex) - end - last = Expr(:tuple, ex.args[2:end]...) - init = vinds(ex.args[1]).args - return Expr(:tuple, init..., last) - else - throw("VarName: Mis-formed variable name $(expr)!") - end -end - -@generated function inargnames(::VarName{s}, ::Model{_F, argnames}) where {s, argnames, _F} - return s in argnames -end - @generated function inmissings(::VarName{s}, ::Model{_F, _a, _T, missings}) where {s, missings, _F, _a, _T} return s in missings end diff --git a/test/Project.toml b/test/Project.toml index aefa44438..0084a3668 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -17,6 +18,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractMCMC = "2.1" +AbstractPPL = "0.1.2" Bijectors = "0.8.2" Distributions = "0.24" DistributionsAD = "0.6.3" diff --git a/test/compiler.jl b/test/compiler.jl index ee5e1e126..55ad0c706 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -222,35 +222,6 @@ end varinfo = VarInfo(model) @test getlogp(varinfo) == lp end - @testset "var name splitting" begin - var_expr = :(x) - @test vsym(var_expr) == :x - @test vinds(var_expr) == :(()) - - var_expr = :(x[1,1][2,3]) - @test vsym(var_expr) == :x - @test vinds(var_expr) == :(((1, 1), (2, 3))) - - var_expr = :(x[:,1][2,:]) - @test vsym(var_expr) == :x - @test vinds(var_expr) == :(((:, 1), (2, :))) - - var_expr = :(x[2:3,1][2,1:2]) - @test vsym(var_expr) == :x - @test vinds(var_expr) == :(((2:3, 1), (2, 1:2))) - - var_expr = :(x[2:3,2:3][[1,2],[1,2]]) - @test vsym(var_expr) == :x - @test vinds(var_expr) == :(((2:3, 2:3), ([1, 2], [1, 2]))) - - var_expr = :(x[end]) - @test vsym(var_expr) == :x - @test vinds(var_expr) == :((($lastindex(x),),)) - - var_expr = :(x[1, end]) - @test vsym(var_expr) == :x - @test vinds(var_expr) == :(((1, $lastindex(x, 2)),)) - end @testset "user-defined variable name" begin @model f1() = x ~ NamedDist(Normal(), :y) @model f2() = x ~ NamedDist(Normal(), @varname(y[2][:,1])) diff --git a/test/runtests.jl b/test/runtests.jl index cbd657243..05e3a74e8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using DynamicPPL using AbstractMCMC +using AbstractPPL using Bijectors using Distributions using DistributionsAD @@ -16,7 +17,7 @@ using Random using Serialization using Test -using DynamicPPL: vsym, vinds, getargs_dottilde, getargs_tilde, Selector +using DynamicPPL: getargs_dottilde, getargs_tilde, Selector const DIRECTORY_DynamicPPL = dirname(dirname(pathof(DynamicPPL))) const DIRECTORY_Turing_tests = joinpath(DIRECTORY_DynamicPPL, "test", "turing")