Skip to content

Commit

Permalink
Introducing AbstractPPL dependency (#214)
Browse files Browse the repository at this point in the history
This has three very rudimentary consequences:

- `VarName` and its helpers are moved over to AbstractPPL completely
- The new abstract base type for models is `AbstractPPL.AbstractProbabilisticProgram <: AbstractMCMC.AbstractModel`
- `AbstractVarInfo <: AbstractPPL.AbstractModelTrace`

More abstractions (and hopefully concrete generalizations, too) are about to come.
  • Loading branch information
phipsgabler committed Mar 29, 2021
1 parent 1ab2e4e commit 3602c56
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 262 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.10.7"
version = "0.10.8"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -13,6 +14,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
AbstractMCMC = "2"
AbstractPPL = "0.1.2"
Bijectors = "0.5.2, 0.6, 0.7, 0.8"
ChainRulesCore = "0.9.7"
Distributions = "0.23.8, 0.24"
Expand Down
9 changes: 5 additions & 4 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module DynamicPPL

using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel
using AbstractMCMC: AbstractSampler, AbstractChains
using AbstractPPL
using Distributions
using Bijectors

Expand Down Expand Up @@ -49,13 +50,13 @@ export AbstractVarInfo,
link!,
invlink!,
tonamedtuple,
#VarName
# VarName (reexport from AbstractPPL)
VarName,
inspace,
subsumes,
@varname,
# Compiler
@model,
@varname,
# Utilities
vectorize,
reconstruct,
Expand Down Expand Up @@ -104,7 +105,7 @@ export loglikelihood
function getspace end

# Necessary forward declarations
abstract type AbstractVarInfo end
abstract type AbstractVarInfo <: AbstractModelTrace end
abstract type AbstractContext end


Expand Down
2 changes: 1 addition & 1 deletion src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition
Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
```
"""
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractModel
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractProbabilisticProgram
name::Symbol
f::F
args::NamedTuple{argnames,Targs}
Expand Down
235 changes: 9 additions & 226 deletions src/varname.jl
Original file line number Diff line number Diff line change
@@ -1,240 +1,23 @@
"""
VarName(sym[, indexing=()])
inargnames(varname::VarName, model::Model)
A variable identifier for a symbol `sym` and indices `indexing` in the format
returned by [`@vinds`](@ref).
Statically check whether the variable of name `varname` is an argument of the `model`.
The Julia variable in the model corresponding to `sym` can refer to a single value or to a
hierarchical array structure of univariate, multivariate or matrix variables. The field `indexing`
stores the indices requires to access the random variable from the Julia variable indicated by `sym`
as a tuple of tuples. Each element of the tuple thereby contains the indices of one indexing
operation.
`VarName`s can be manually constructed using the `VarName(sym, indexing)` constructor, or from an
indexing expression through the [`@varname`](@ref) convenience macro.
# Examples
```jldoctest
julia> vn = VarName(:x, ((Colon(), 1), (2,)))
x[Colon(),1][2]
julia> vn.indexing
((Colon(), 1), (2,))
julia> VarName(DynamicPPL.@vsym(x[:, 1][1+1]), DynamicPPL.@vinds(x[:, 1][1+1]))
x[Colon(),1][2]
```
"""
struct VarName{sym, T<:Tuple}
indexing::T
end

VarName(sym::Symbol, indexing::Tuple = ()) = VarName{sym, typeof(indexing)}(indexing)

"""
VarName(vn::VarName[, indexing=()])
Return a copy of `vn` with a new index `indexing`.
"""
function VarName(vn::VarName, indexing::Tuple = ())
return VarName{getsym(vn), typeof(indexing)}(indexing)
end


"""
getsym(vn::VarName)
Return the symbol of the Julia variable used to generate `vn`.
"""
getsym(vn::VarName{sym}) where sym = sym


"""
getindexing(vn::VarName)
Return the indexing tuple of the Julia variable used to generate `vn`.
"""
getindexing(vn::VarName) = vn.indexing


Base.hash(vn::VarName, h::UInt) = hash((getsym(vn), getindexing(vn)), h)
Base.:(==)(x::VarName, y::VarName) = getsym(x) == getsym(y) && getindexing(x) == getindexing(y)

function Base.show(io::IO, vn::VarName)
print(io, getsym(vn))
for indices in getindexing(vn)
print(io, "[")
join(io, indices, ",")
print(io, "]")
end
end


"""
Symbol(vn::VarName)
Return a `Symbol` represenation of the variable identifier `VarName`.
"""
Base.Symbol(vn::VarName) = Symbol(string(vn)) # simplified symbol


"""
inspace(vn::Union{VarName, Symbol}, space::Tuple)
Check whether `vn`'s variable symbol is in `space`.
"""
inspace(vn, space::Tuple{}) = true # empty space is treated as universal set
inspace(vn, space::Tuple) = vn in space
inspace(vn::VarName, space::Tuple{}) = true
inspace(vn::VarName, space::Tuple) = any(_in(vn, s) for s in space)

_in(vn::VarName, s::Symbol) = getsym(vn) == s
_in(vn::VarName, s::VarName) = subsumes(s, vn)


"""
subsumes(u::VarName, v::VarName)
Check whether the variable name `v` describes a sub-range of the variable `u`. Supported
indexing:
- Scalar: `x` subsumes `x[1, 2]`, `x[1, 2]` subsumes `x[1, 2][3]`, etc.
- Array of scalar: `x[[1, 2], 3]` subsumes `x[1, 3]`, `x[1:3]` subsumes `x[2][1]`, etc.
(basically everything that fulfills `issubset`).
- Slices: `x[2, :]` subsumes `x[2, 10][1]`, etc.
Currently _not_ supported are:
- Boolean indexing, literal `CartesianIndex` (these could be added, though)
- Linear indexing of multidimensional arrays: `x[4]` does not subsume `x[2, 2]` for `x` a matrix
- Trailing ones: `x[2, 1]` does not subsume `x[2]` for `x` a vector
"""
function subsumes(u::VarName, v::VarName)
return getsym(u) == getsym(v) && subsumes(u.indexing, v.indexing)
end

subsumes(::Tuple{}, ::Tuple{}) = true # x subsumes x
subsumes(::Tuple{}, ::Tuple) = true # x subsumes x[1]
subsumes(::Tuple, ::Tuple{}) = false # x[1] does not subsume x
function subsumes(t::Tuple, u::Tuple) # does x[i]... subsume x[j]...?
return _issubindex(first(t), first(u)) && subsumes(Base.tail(t), Base.tail(u))
end

const AnyIndex = Union{Int, AbstractVector{Int}, Colon}
_issubindex_(::Tuple{Vararg{AnyIndex}}, ::Tuple{Vararg{AnyIndex}}) = false
function _issubindex(t::NTuple{N, AnyIndex}, u::NTuple{N, AnyIndex}) where {N}
return all(_issubrange(j, i) for (i, j) in zip(t, u))
end

const ConcreteIndex = Union{Int, AbstractVector{Int}} # this include all kinds of ranges
"""Determine whether indices `i` are contained in `j`, treating `:` as universal set."""
_issubrange(i::ConcreteIndex, j::ConcreteIndex) = issubset(i, j)
_issubrange(i::Union{ConcreteIndex, Colon}, j::Colon) = true
_issubrange(i::Colon, j::ConcreteIndex) = true



"""
@varname(expr)
A macro that returns an instance of [`VarName`](@ref) given a symbol or indexing expression `expr`.
The `sym` value is taken from the actual variable name, and the index values are put appropriately
into the constructor (and resolved at runtime).
# Examples
```jldoctest
julia> @varname(x).indexing
()
julia> @varname(x[1]).indexing
((1,),)
julia> @varname(x[:, 1]).indexing
((Colon(), 1),)
julia> @varname(x[:, 1][2]).indexing
((Colon(), 1), (2,))
julia> @varname(x[1,2][1+5][45][3]).indexing
((1, 2), (6,), (45,), (3,))
```
!!! compat "Julia 1.5"
Using `begin` in an indexing expression to refer to the first index requires at least
Julia 1.5.
Possibly existing indices of `varname` are neglected.
"""
macro varname(expr::Union{Expr, Symbol})
return esc(varname(expr))
end

varname(expr::Symbol) = VarName(expr)
function varname(expr::Expr)
if Meta.isexpr(expr, :ref)
sym, inds = vsym(expr), vinds(expr)
return :($(DynamicPPL.VarName)($(QuoteNode(sym)), $inds))
else
throw("VarName: Mis-formed variable name $(expr)!")
end
end


"""
@vsym(expr)
A macro that returns the variable symbol given the input variable expression `expr`.
For example, `@vsym x[1]` returns `:x`.
"""
macro vsym(expr::Union{Expr, Symbol})
return QuoteNode(vsym(expr))
@generated function inargnames(::VarName{s}, ::Model{_F, argnames}) where {s, argnames, _F}
return s in argnames
end

vsym(expr::Symbol) = expr
function vsym(expr::Expr)
if Meta.isexpr(expr, :ref)
return vsym(expr.args[1])
else
throw("VarName: Mis-formed variable name $(expr)!")
end
end

"""
@vinds(expr)
inmissings(varname::VarName, model::Model)
Returns a tuple of tuples of the indices in `expr`. For example, `@vinds x[1, :][2]` returns
`((1, Colon()), (2,))`.
Statically check whether the variable of name `varname` is a statically declared unobserved variable
of the `model`.
!!! compat "Julia 1.5"
Using `begin` in an indexing expression to refer to the first index requires at least
Julia 1.5.
Possibly existing indices of `varname` are neglected.
"""
macro vinds(expr::Union{Expr, Symbol})
return esc(vinds(expr))
end

vinds(expr::Symbol) = Expr(:tuple)
function vinds(expr::Expr)
if Meta.isexpr(expr, :ref)
ex = copy(expr)
@static if VERSION < v"1.5.0-DEV.666"
Base.replace_ref_end!(ex)
else
Base.replace_ref_begin_end!(ex)
end
last = Expr(:tuple, ex.args[2:end]...)
init = vinds(ex.args[1]).args
return Expr(:tuple, init..., last)
else
throw("VarName: Mis-formed variable name $(expr)!")
end
end

@generated function inargnames(::VarName{s}, ::Model{_F, argnames}) where {s, argnames, _F}
return s in argnames
end

@generated function inmissings(::VarName{s}, ::Model{_F, _a, _T, missings}) where {s, missings, _F, _a, _T}
return s in missings
end
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -17,6 +18,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractMCMC = "2.1"
AbstractPPL = "0.1.2"
Bijectors = "0.8.2"
Distributions = "0.24"
DistributionsAD = "0.6.3"
Expand Down
29 changes: 0 additions & 29 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,35 +222,6 @@ end
varinfo = VarInfo(model)
@test getlogp(varinfo) == lp
end
@testset "var name splitting" begin
var_expr = :(x)
@test vsym(var_expr) == :x
@test vinds(var_expr) == :(())

var_expr = :(x[1,1][2,3])
@test vsym(var_expr) == :x
@test vinds(var_expr) == :(((1, 1), (2, 3)))

var_expr = :(x[:,1][2,:])
@test vsym(var_expr) == :x
@test vinds(var_expr) == :(((:, 1), (2, :)))

var_expr = :(x[2:3,1][2,1:2])
@test vsym(var_expr) == :x
@test vinds(var_expr) == :(((2:3, 1), (2, 1:2)))

var_expr = :(x[2:3,2:3][[1,2],[1,2]])
@test vsym(var_expr) == :x
@test vinds(var_expr) == :(((2:3, 2:3), ([1, 2], [1, 2])))

var_expr = :(x[end])
@test vsym(var_expr) == :x
@test vinds(var_expr) == :((($lastindex(x),),))

var_expr = :(x[1, end])
@test vsym(var_expr) == :x
@test vinds(var_expr) == :(((1, $lastindex(x, 2)),))
end
@testset "user-defined variable name" begin
@model f1() = x ~ NamedDist(Normal(), :y)
@model f2() = x ~ NamedDist(Normal(), @varname(y[2][:,1]))
Expand Down
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using DynamicPPL
using AbstractMCMC
using AbstractPPL
using Bijectors
using Distributions
using DistributionsAD
Expand All @@ -16,7 +17,7 @@ using Random
using Serialization
using Test

using DynamicPPL: vsym, vinds, getargs_dottilde, getargs_tilde, Selector
using DynamicPPL: getargs_dottilde, getargs_tilde, Selector

const DIRECTORY_DynamicPPL = dirname(dirname(pathof(DynamicPPL)))
const DIRECTORY_Turing_tests = joinpath(DIRECTORY_DynamicPPL, "test", "turing")
Expand Down

2 comments on commit 3602c56

@devmotion
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/33095

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.8 -m "<description of version>" 3602c56dba16bc2b60efbef6bf6ff31226354c7d
git push origin v0.10.8

Please sign in to comment.