Skip to content

Commit

Permalink
Small simplification of compiler (#221)
Browse files Browse the repository at this point in the history
## Overview
At the moment, we perform a check at model-expansion as to whether or not `vsym(left) in args`, where `args` is the arguments of the model. 
1. If `true`, we return a block of code which uses `DynamicPPL.isassumption` to check whether or not to call `assume` or `observe` for the the variable present in `args`. 
2. Otherwise, we generate a block which is identical to the `assume` block in the if-statement mentioned in (1).

The thing is, `DynamicPPL.isassumption` performs exactly the same check as above but using `DynamicPPL.inargnames`, i.e. at runtime. So if we're using  `TypedVarInfo`, the check at macro-expansion vs. at runtime is completely redundant since all the information necessary to determine `DynamicPPL.inargnames` is available at compile-time.

Therefore I suggest we remove this check at model-expansion, and simply handle it using `DynamicPPL.isassumption`.

## Pros & cons
Pros:
- No need to pass `args` around everywhere
- `generate_tilde` and `generate_dot_tilde` are much simpler: two possible blocks we can generate, either a) assume/observe, or b) observe literal.

Cons:
- We need to perform _one_ more check at runtime when using `UntypedVarInfo`.


**IMO, this is really worth it.**

## Motivation (sort of)
The main motivation behind this PR is simplification, but there's a different reason why I came across this.

I came to this because I was thinking about trying to "customize" the behavior of `~`, and I was thinking of using a macro to do it, e.g. `@mymacro x ~ Normal()`. Atm we're actually performing model-expansion on the code passed to the macro and thus trying to alter the way DynamicPPL treats `~` using a macro is veeeery difficult since you actually have to work with the *expanded* code, but let's ignore that issue for now (and take that discussion somewhere else, because IMO we shouldn't do this). 

Suppose we didn't perform model-expansions of the code fed to the macros, then you can just copy-paste `generate_tilde`, customize it do what you want, and BAM, you got yourself a working `@mymacro x ~ Normal()` which can do neat stuff! This is *not* possible atm because we don't have access to `args`, and so you have to take the approach in this PR to get there. That means that it's of course possible to do atm, but it's a bit icky since it ends up looking fundamentally different from `generate_tilde` rather than just slightly different.

Then we can implement things like a `@tilde` which will expand to `generate_tilde` which can be used *internally* in functions (if the "internal" variables are present in the functions of course, but we can also simplify this in different ways), actually allowing people to modularize their models a bit, and `@reparam` from #220 using very similar pieces of code, a `@track` macro can be introduced to deal with the explicit tracking of variables rather than putting this directly in the compiler, etc. Endless opportunities! (Of course, I'm not suggesting we add these, but this makes it a bit easier to explore.)

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
torfjelde and devmotion committed Apr 7, 2021
1 parent 068e5d3 commit b25fdab
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 58 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/IntegrationTest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: 1
version: 1.5
arch: x64
- uses: julia-actions/julia-buildpkg@latest
- name: Clone Downstream
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.10.11"
version = "0.10.12"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
94 changes: 38 additions & 56 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ function model(mod, linenumbernode, expr, warn)

# Generate main body
modelinfo[:body] = generate_mainbody(
mod, modelinfo[:modeldef][:body], modelinfo[:allargs_syms], warn
mod, modelinfo[:modeldef][:body], warn
)

return build_output(modelinfo, linenumbernode)
Expand Down Expand Up @@ -155,92 +155,84 @@ function build_model_info(input_expr)
end

"""
generate_mainbody(mod, expr, args, warn)
generate_mainbody(mod, expr, warn)
Generate the body of the main evaluation function from expression `expr` and arguments
`args`.
If `warn` is true, a warning is displayed if internal variables are used in the model
definition.
"""
generate_mainbody(mod, expr, args, warn) = generate_mainbody!(mod, Symbol[], expr, args, warn)
generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn)

generate_mainbody!(mod, found, x, args, warn) = x
function generate_mainbody!(mod, found, sym::Symbol, args, warn)
generate_mainbody!(mod, found, x, warn) = x
function generate_mainbody!(mod, found, sym::Symbol, warn)
if warn && sym in INTERNALNAMES && sym found
@warn "you are using the internal variable `$(sym)`"
push!(found, sym)
end
return sym
end
function generate_mainbody!(mod, found, expr::Expr, args, warn)
function generate_mainbody!(mod, found, expr::Expr, warn)
# Do not touch interpolated expressions
expr.head === :$ && return expr.args[1]

# If it's a macro, we expand it
if Meta.isexpr(expr, :macrocall)
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), args, warn)
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
end

# Modify dotted tilde operators.
args_dottilde = getargs_dottilde(expr)
if args_dottilde !== nothing
L, R = args_dottilde
return generate_dot_tilde(generate_mainbody!(mod, found, L, args, warn),
generate_mainbody!(mod, found, R, args, warn),
args) |> Base.remove_linenums!
return generate_dot_tilde(
generate_mainbody!(mod, found, L, warn),
generate_mainbody!(mod, found, R, warn),
) |> Base.remove_linenums!
end

# Modify tilde operators.
args_tilde = getargs_tilde(expr)
if args_tilde !== nothing
L, R = args_tilde
return generate_tilde(generate_mainbody!(mod, found, L, args, warn),
generate_mainbody!(mod, found, R, args, warn),
args) |> Base.remove_linenums!
return generate_tilde(
generate_mainbody!(mod, found, L, warn),
generate_mainbody!(mod, found, R, warn),
) |> Base.remove_linenums!
end

return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, args, warn), expr.args)...)
return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...)
end



"""
generate_tilde(left, right, args)
generate_tilde(left, right)
Generate an `observe` expression for data variables and `assume` expression for parameter
variables.
"""
function generate_tilde(left, right, args)
function generate_tilde(left, right)
@gensym tmpright
top = [:($tmpright = $right),
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
|| throw(ArgumentError($DISTMSG)))]

if left isa Symbol || left isa Expr
@gensym out vn inds
@gensym out vn inds isassumption
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))

# It can only be an observation if the LHS is an argument of the model
if vsym(left) in args
@gensym isassumption
return quote
$(top...)
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left = $(DynamicPPL.tilde_assume)(
_rng, _context, _sampler, $tmpright, $vn, $inds, _varinfo)
else
$(DynamicPPL.tilde_observe)(
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
end
end
end

return quote
$(top...)
$left = $(DynamicPPL.tilde_assume)(_rng, _context, _sampler, $tmpright, $vn,
$inds, _varinfo)
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left = $(DynamicPPL.tilde_assume)(
_rng, _context, _sampler, $tmpright, $vn, $inds, _varinfo)
else
$(DynamicPPL.tilde_observe)(
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
end
end
end

Expand All @@ -252,40 +244,30 @@ function generate_tilde(left, right, args)
end

"""
generate_dot_tilde(left, right, args)
generate_dot_tilde(left, right)
Generate the expression that replaces `left .~ right` in the model body.
"""
function generate_dot_tilde(left, right, args)
function generate_dot_tilde(left, right)
@gensym tmpright
top = [:($tmpright = $right),
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
|| throw(ArgumentError($DISTMSG)))]

if left isa Symbol || left isa Expr
@gensym out vn inds
@gensym out vn inds isassumption
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))

# It can only be an observation if the LHS is an argument of the model
if vsym(left) in args
@gensym isassumption
return quote
$(top...)
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left .= $(DynamicPPL.dot_tilde_assume)(
_rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
else
$(DynamicPPL.dot_tilde_observe)(
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
end
end
end

return quote
$(top...)
$left .= $(DynamicPPL.dot_tilde_assume)(
_rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
$isassumption = $(DynamicPPL.isassumption(left)) || $left === missing
if $isassumption
$left .= $(DynamicPPL.dot_tilde_assume)(
_rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
else
$(DynamicPPL.dot_tilde_observe)(
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
end
end
end

Expand Down

2 comments on commit b25fdab

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/33800

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.12 -m "<description of version>" b25fdabebb69b4a1dcf281f19c8ac39fd5b6a4bd
git push origin v0.10.12

Please sign in to comment.