Skip to content

Commit

Permalink
Don't "compile" inputs of macros (#222)
Browse files Browse the repository at this point in the history
At the moment we will actually call `generate_mainbody!` on inputs to macros inside the model, e.g. in a model `@mymacro x ~ Normal()` will actually result in code `@mymacro $(generate_mainbody!(:(x ~ Normal())))` (or something, you get the idea). 

IMO, this shouldn't be done for the following reasons:
1. Breaks with what you'd expect in Julia, IMO, which is that a macro eats the "raw" code.
2. Means that if we want to do stuff like `@reparam` from #220  (and a bunch of others, see #221 for a small list of possibilities), we need touch the compiler rather than just make a small macro that will perform transformations *after* the compiler has done it's job (referring to DynamicPPL compiler here). 
3. If the user wants to use a macro on some variables, but they want the actual variable rather than messing around with the sample-statement, they can just separate it into two lines, e.g. `x ~ Normal(); @mymacro ...`. 

Also, to be completely honest, for the longest time I've just assumed that I'm not even allowed to do `@mymacro x ~ Normal()` and have things work 😅 I bet a lot of people have the same impression by default (though this might of course just not be true:) )
  • Loading branch information
torfjelde committed Apr 4, 2021
1 parent 2d6ef3f commit f531f12
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 18 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.10.9"
version = "0.10.10"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
33 changes: 16 additions & 17 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,15 @@ To generate a `Model`, call `model(xvalue)` or `model(xvalue, yvalue)`.
macro model(expr, warn=true)
# include `LineNumberNode` with information about the call site in the
# generated function for easier debugging and interpretation of error messages
esc(model(expr, __source__, warn))
esc(model(__module__, __source__, expr, warn))
end

function model(expr, linenumbernode, warn)
function model(mod, linenumbernode, expr, warn)
modelinfo = build_model_info(expr)

# Generate main body
modelinfo[:body] = generate_mainbody(
modelinfo[:modeldef][:body], modelinfo[:allargs_syms], warn
mod, modelinfo[:modeldef][:body], modelinfo[:allargs_syms], warn
)

return build_output(modelinfo, linenumbernode)
Expand Down Expand Up @@ -155,53 +155,52 @@ function build_model_info(input_expr)
end

"""
generate_mainbody(expr, args, warn)
generate_mainbody(mod, expr, args, warn)
Generate the body of the main evaluation function from expression `expr` and arguments
`args`.
If `warn` is true, a warning is displayed if internal variables are used in the model
definition.
"""
generate_mainbody(expr, args, warn) = generate_mainbody!(Symbol[], expr, args, warn)
generate_mainbody(mod, expr, args, warn) = generate_mainbody!(mod, Symbol[], expr, args, warn)

generate_mainbody!(found, x, args, warn) = x
function generate_mainbody!(found, sym::Symbol, args, warn)
generate_mainbody!(mod, found, x, args, warn) = x
function generate_mainbody!(mod, found, sym::Symbol, args, warn)
if warn && sym in INTERNALNAMES && sym found
@warn "you are using the internal variable `$(sym)`"
push!(found, sym)
end
return sym
end
function generate_mainbody!(found, expr::Expr, args, warn)
function generate_mainbody!(mod, found, expr::Expr, args, warn)
# Do not touch interpolated expressions
expr.head === :$ && return expr.args[1]

# Apply the `@.` macro first.
if Meta.isexpr(expr, :macrocall) && length(expr.args) > 1 &&
expr.args[1] === Symbol("@__dot__")
return generate_mainbody!(found, Base.Broadcast.__dot__(expr.args[end]), args, warn)
# If it's a macro, we expand it
if Meta.isexpr(expr, :macrocall)
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), args, warn)
end

# Modify dotted tilde operators.
args_dottilde = getargs_dottilde(expr)
if args_dottilde !== nothing
L, R = args_dottilde
return generate_dot_tilde(generate_mainbody!(found, L, args, warn),
generate_mainbody!(found, R, args, warn),
return generate_dot_tilde(generate_mainbody!(mod, found, L, args, warn),
generate_mainbody!(mod, found, R, args, warn),
args) |> Base.remove_linenums!
end

# Modify tilde operators.
args_tilde = getargs_tilde(expr)
if args_tilde !== nothing
L, R = args_tilde
return generate_tilde(generate_mainbody!(found, L, args, warn),
generate_mainbody!(found, R, args, warn),
return generate_tilde(generate_mainbody!(mod, found, L, args, warn),
generate_mainbody!(mod, found, R, args, warn),
args) |> Base.remove_linenums!
end

return Expr(expr.head, map(x -> generate_mainbody!(found, x, args, warn), expr.args)...)
return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, args, warn), expr.args)...)
end


Expand Down
44 changes: 44 additions & 0 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,29 @@ macro custom(expr)
end
end

macro mymodel1(ex)
# check if expression was modified by the DynamicPPL "compiler"
if ex == :(y ~ Uniform())
return esc(:(x ~ Normal()))
else
return esc(:(z ~ Exponential()))
end
end

struct MyModelStruct{T}
x::T
end
Base.:~(x, y::MyModelStruct) = y.x
macro mymodel2(ex)
# check if expression was modified by the DynamicPPL "compiler"
if ex == :(y ~ Uniform())
# Just returns 42
return :(4 ~ MyModelStruct(42))
else
return :(return -1)
end
end

@testset "compiler.jl" begin
@testset "model macro" begin
@model function testmodel_comp(x, y)
Expand Down Expand Up @@ -269,4 +292,25 @@ end
end
@test isempty(VarInfo(demo_with(0.0)))
end

@testset "macros within model" begin
# Macro expansion
@model function demo()
@mymodel1(y ~ Uniform())
end

@test haskey(VarInfo(demo()), @varname(x))

# Interpolation
# Will fail if:
# 1. Compiler expands `y ~ Uniform()` before expanding the macros
# => returns -1.
# 2. `@mymodel` is expanded before entire `@model` has been
# expanded => errors since `MyModelStruct` is not a distribution,
# and hence `tilde_observe` errors.
@model function demo()
$(@mymodel2(y ~ Uniform()))
end
@test demo()() == 42
end
end

2 comments on commit f531f12

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/33524

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.10 -m "<description of version>" f531f12a47f10c06166bbb1e5469abf7b330c475
git push origin v0.10.10

Please sign in to comment.