Skip to content

Commit

Permalink
added ci
Browse files Browse the repository at this point in the history
  • Loading branch information
Devetak committed Jan 18, 2024
1 parent cf38c13 commit 4d251fe
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
name: CI

# Controls when the action will run.
on:
on: [push, pull_request]
# Triggers the workflow on push or pull request events but only for the main branch
push:
branches: [main]
Expand Down
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ uuid = "944e5dbc-7108-4f31-a215-df58b8009117"
authors = ["Mitja Devetak <[email protected]>"]
version = "0.1.0"


[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2 changes: 1 addition & 1 deletion src/DiscoDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function heaviside(x)
return x > zero(T) ? one(T) : zero(T)
end


include("./ignore_gradient.jl")



Expand Down
6 changes: 6 additions & 0 deletions src/ignore_gradient.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
export ignore_gradient

using ForwardDiff, ChainRulesCore

ignore_gradient(x) = ChainRulesCore.@ignore_derivatives x
ignore_gradient(x::ForwardDiff.Dual) = typeof(x)(x.value)
17 changes: 17 additions & 0 deletions test/ignore_gradient.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using DiscoDiff
using Test
using ForwardDiff, Zygote

@testset "ignore gradient" begin

function f(x)

return 3x + x^2 - ignore_gradient(x^2)

end

x = 2.0
@test isapprox(ForwardDiff.derivative(f, x), 3 + 2x, atol = 1e-8)
@test isapprox(Zygote.gradient(f, x)[1], 3 + 2x, atol = 1e-8)

end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ using DiscoDiff, Test
@test heaviside(2.0) == 1.0
@test heaviside(-2.0) == 0.0
end

include("./ignore_gradient.jl")

0 comments on commit 4d251fe

Please sign in to comment.