Skip to content

Commit

Permalink
Support functors (#398)
Browse files Browse the repository at this point in the history
Fixes #367.

Additionally, I removed the `name` field of `Model` since it seemed redundant with `nameof(model.f)` (if `model.f isa Function`) and `Symbol(model.f)` (otherwise). This could be separated or reverted.

TODO:
- [x] Add tests

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
devmotion and devmotion committed Mar 12, 2022
1 parent de40505 commit 748b191
Show file tree
Hide file tree
Showing 10 changed files with 63 additions and 31 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.18.0"
version = "0.19.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/benchmark_body.jmd
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ results["evaluation_typed"]
```

```julia; echo=false; results="hidden";
BenchmarkTools.save(joinpath("results", WEAVE_ARGS[:name], "$(m.name)_benchmarks.json"), results)
BenchmarkTools.save(joinpath("results", WEAVE_ARGS[:name], "$(nameof(m))_benchmarks.json"), results)
```

```julia; wrap=false
Expand All @@ -32,15 +32,15 @@ end
```julia; echo=false; results="hidden"
if WEAVE_ARGS[:include_typed_code]
# Serialize the output of `typed_code` so we can compare later.
haskey(WEAVE_ARGS, :name) && serialize(joinpath("results", WEAVE_ARGS[:name],"$(m.name).jls"), string(typed));
haskey(WEAVE_ARGS, :name) && serialize(joinpath("results", WEAVE_ARGS[:name],"$(nameof(m)).jls"), string(typed));
end
```

```julia; wrap=false; echo=false;
if haskey(WEAVE_ARGS, :name_old)
# We want to compare the generated code to the previous version.
import DiffUtils
typed_old = deserialize(joinpath("results", WEAVE_ARGS[:name_old], "$(m.name).jls"));
typed_old = deserialize(joinpath("results", WEAVE_ARGS[:name_old], "$(nameof(m)).jls"));
DiffUtils.diff(typed_old, string(typed), width=130)
end
```
19 changes: 12 additions & 7 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -594,19 +594,24 @@ function build_output(modelinfo, linenumbernode)
allargs_namedtuple = modelinfo[:allargs_namedtuple]
defaults_namedtuple = modelinfo[:defaults_namedtuple]

# Obtain or generate the name of the model to support functors:
# https://github.com/TuringLang/DynamicPPL.jl/issues/367
modeldef = modelinfo[:modeldef]
if MacroTools.@capture(modeldef[:name], ::T_)
name = gensym(:f)
modeldef[:name] = Expr(:(::), name, T)
elseif MacroTools.@capture(modeldef[:name], (name_::_ | name_))
else
throw(ArgumentError("unsupported format of model function"))
end

# Update the function body of the user-specified model.
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
# to the call site
modeldef = modelinfo[:modeldef]
modeldef[:body] = MacroTools.@q begin
$(linenumbernode)
return $(DynamicPPL.Model)(
$(QuoteNode(modeldef[:name])),
$(modeldef[:name]),
$allargs_namedtuple,
$defaults_namedtuple,
)
return $(DynamicPPL.Model)($name, $allargs_namedtuple, $defaults_namedtuple)
end

return MacroTools.@q begin
Expand Down
24 changes: 10 additions & 14 deletions src/model.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults}
name::Symbol
f::F
args::NamedTuple{argnames,Targs}
defaults::NamedTuple{defaultnames,Tdefaults}
Expand Down Expand Up @@ -34,53 +33,49 @@ Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x
"""
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <:
AbstractProbabilisticProgram
name::Symbol
f::F
args::NamedTuple{argnames,Targs}
defaults::NamedTuple{defaultnames,Tdefaults}
context::Ctx

@doc """
Model{missings}(name::Symbol, f, args::NamedTuple, defaults::NamedTuple)
Model{missings}(f, args::NamedTuple, defaults::NamedTuple)
Create a model of name `name` with evaluation function `f` and missing arguments
overwritten by `missings`.
Create a model with evaluation function `f` and missing arguments overwritten by
`missings`.
"""
function Model{missings}(
name::Symbol,
f::F,
args::NamedTuple{argnames,Targs},
defaults::NamedTuple{defaultnames,Tdefaults},
context::Ctx=DefaultContext(),
) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx}
return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}(
name, f, args, defaults, context
f, args, defaults, context
)
end
end

"""
Model(name::Symbol, f, args::NamedTuple[, defaults::NamedTuple = ()])
Model(f, args::NamedTuple[, defaults::NamedTuple = ()])
Create a model of name `name` with evaluation function `f` and missing arguments deduced
from `args`.
Create a model with evaluation function `f` and missing arguments deduced from `args`.
Default arguments `defaults` are used internally when constructing instances of the same
model with different arguments.
"""
@generated function Model(
name::Symbol,
f::F,
args::NamedTuple{argnames,Targs},
defaults::NamedTuple=NamedTuple(),
context::AbstractContext=DefaultContext(),
) where {F,argnames,Targs}
missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing)
return :(Model{$missings}(name, f, args, defaults, context))
return :(Model{$missings}(f, args, defaults, context))
end

function contextualize(model::Model, context::AbstractContext)
return Model(model.name, model.f, model.args, model.defaults, context)
return Model(model.f, model.args, model.defaults, context)
end

"""
Expand Down Expand Up @@ -518,7 +513,8 @@ getmissings(model::Model{_F,_a,_d,missings}) where {missings,_F,_a,_d} = missing
Get the name of the `model` as `Symbol`.
"""
Base.nameof(model::Model) = model.name
Base.nameof(model::Model) = Symbol(model.f)
Base.nameof(model::Model{<:Function}) = nameof(model.f)

"""
rand([rng=Random.GLOBAL_RNG], [T=NamedTuple], model::Model)
Expand Down
4 changes: 2 additions & 2 deletions src/prob_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ end
return quote
$(warnings...)
Model{$(Tuple(missings))}(
model.name, model.f, $(to_namedtuple_expr(argnames, argvals)), model.defaults
model.f, $(to_namedtuple_expr(argnames, argvals)), model.defaults
)
end
end
Expand Down Expand Up @@ -237,6 +237,6 @@ end
# `args` is inserted as properly typed NamedTuple expression;
# `missings` is splatted into a tuple at compile time and inserted as literal
return :(Model{$(Tuple(missings))}(
model.name, model.f, $(to_namedtuple_expr(argnames, argvals)), model.defaults
model.f, $(to_namedtuple_expr(argnames, argvals)), model.defaults
))
end
2 changes: 1 addition & 1 deletion src/submodel_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ julia> @varname(var"my prefix.x") in keys(VarInfo(outer()))
true
julia> # Using string interpolation.
@model outer() = @submodel prefix="\$(inner().name)" a = inner()
@model outer() = @submodel prefix="\$(nameof(inner()))" a = inner()
outer (generic function with 2 methods)
julia> @varname(var"inner.x") in keys(VarInfo(outer()))
Expand Down
2 changes: 1 addition & 1 deletion src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ function test_sampler_demo_models(
rtol=1e-3,
kwargs...,
)
@testset "$(nameof(typeof(sampler))) on $(m.name)" for model in DEMO_MODELS
@testset "$(nameof(typeof(sampler))) on $(nameof(m))" for model in DEMO_MODELS
chain = AbstractMCMC.sample(model, sampler, args...; kwargs...)
μ = meanfunction(chain)
@test μ target atol = atol rtol = rtol
Expand Down
30 changes: 30 additions & 0 deletions test/model.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# some functors (#367)
struct MyModel
a::Int
end
@model function (f::MyModel)(x)
m ~ Normal(f.a, 1)
return x ~ Normal(m, 1)
end
struct MyZeroModel end
@model function (::MyZeroModel)(x)
m ~ Normal(0, 1)
return x ~ Normal(m, 1)
end

@testset "model.jl" begin
@testset "convenience functions" begin
model = gdemo_default
Expand Down Expand Up @@ -61,9 +75,25 @@
m ~ Normal(0, 1)
x ~ Normal(m, 1)
end
function test3 end
@model function (::typeof(test3))(x)
m ~ Normal(0, 1)
return x ~ Normal(m, 1)
end
function test4 end
@model function (a::typeof(test4))(x)
m ~ Normal(0, 1)
return x ~ Normal(m, 1)
end

@test nameof(test1(rand())) == :test1
@test nameof(test2(rand())) == :test2
@test nameof(test3(rand())) == :test3
@test nameof(test4(rand())) == :test4

# callables
@test nameof(MyModel(3)(rand())) == Symbol("MyModel(3)")
@test nameof(MyZeroModel()(rand())) == Symbol("MyZeroModel()")
end

@testset "Internal methods" begin
Expand Down
3 changes: 2 additions & 1 deletion test/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@
end
end

@testset "SimpleVarInfo on $(model.name)" for model in DynamicPPL.TestUtils.DEMO_MODELS
@testset "SimpleVarInfo on $(nameof(model))" for model in
DynamicPPL.TestUtils.DEMO_MODELS
# We might need to pre-allocate for the variable `m`, so we need
# to see whether this is the case.
m = model().m
Expand Down
2 changes: 1 addition & 1 deletion test/turing/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
DynamicPPL = "0.18"
DynamicPPL = "0.19"
Turing = "0.18, 0.19, 0.20"
julia = "1.3"

2 comments on commit 748b191

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/56493

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.19.0 -m "<description of version>" 748b19184e60b25313e267836aedf8d6c7fc47fa
git push origin v0.19.0

Please sign in to comment.