diff --git a/src/macros.jl b/src/macros.jl index 9ee37ce..8f6257d 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -83,6 +83,8 @@ Regular `*` is matrix multiplication, broadcasted `*` is elementwise multiplicat two have different gradients as defined above. `unbroadcast(a,b)` reduces `b` to the same shape as `a` by performing the necessary summations. """ +:(@primitive), :(@primitive1), :(@primitive2) + macro primitive(f,g...) # @primitive sin(x::Number),dy,y (dy.*cos.(x)) (f,dy,y) = fparse(f) b = Expr(:block) @@ -151,6 +153,8 @@ define the broadcasting version and `@zerograd2` if you only want to define the version. Note that `kwargs` are NOT unboxed. """ +:(@zerograd), :(@zerograd1), :(@zerograd2) + macro zerograd(f) # @zerograd sign(x::Number) (f,dy,y) = fparse(f) b = Expr(:block) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index a51919d..f942bd6 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -151,7 +151,7 @@ gradient of `sum(f(x...; kw...))`. Keyword arguments: * `verbose=1`: 0 prints nothing, 1 shows failing tests, 2 shows all tests. """ -gcheck, @gcheck +gcheck, :(@gcheck) function gcheck(f, x...; kw=(), nsample=10, verbose=1, rtol=0.05, atol=0.01, delta=0.0001) y = @diff gcsum(f, x...; kw...)