diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cafeeac..865c785 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,7 +3,7 @@ name: CI # Controls when the action will run. -on: +on: [push, pull_request] # Triggers the workflow on push or pull request events but only for the main branch push: branches: [main] diff --git a/Project.toml b/Project.toml index 97c130e..3f6c3ff 100644 --- a/Project.toml +++ b/Project.toml @@ -3,8 +3,9 @@ uuid = "944e5dbc-7108-4f31-a215-df58b8009117" authors = ["Mitja Devetak "] version = "0.1.0" - [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/DiscoDiff.jl b/src/DiscoDiff.jl index 9b7fc2d..3ffbd05 100644 --- a/src/DiscoDiff.jl +++ b/src/DiscoDiff.jl @@ -7,7 +7,7 @@ function heaviside(x) return x > zero(T) ? one(T) : zero(T) end - +include("./ignore_gradient.jl") diff --git a/src/ignore_gradient.jl b/src/ignore_gradient.jl new file mode 100644 index 0000000..59550b6 --- /dev/null +++ b/src/ignore_gradient.jl @@ -0,0 +1,6 @@ +export ignore_gradient + +using ForwardDiff, ChainRulesCore + +ignore_gradient(x) = ChainRulesCore.@ignore_derivatives x +ignore_gradient(x::ForwardDiff.Dual) = typeof(x)(x.value) diff --git a/test/ignore_gradient.jl b/test/ignore_gradient.jl new file mode 100644 index 0000000..e9cb67f --- /dev/null +++ b/test/ignore_gradient.jl @@ -0,0 +1,17 @@ +using DiscoDiff +using Test +using ForwardDiff, Zygote + +@testset "ignore gradient" begin + + function f(x) + + return 3x + x^2 - ignore_gradient(x^2) + + end + + x = 2.0 + @test isapprox(ForwardDiff.derivative(f, x), 3 + 2x, atol = 1e-8) + @test isapprox(Zygote.gradient(f, x)[1], 3 + 2x, atol = 1e-8) + +end diff --git a/test/runtests.jl b/test/runtests.jl index 1613a79..c7b8482 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,3 +4,5 @@ using DiscoDiff, Test @test heaviside(2.0) == 1.0 @test heaviside(-2.0) == 0.0 end + +include("./ignore_gradient.jl")