Skip to content

Commit

Permalink
update README
Browse files Browse the repository at this point in the history
  • Loading branch information
Devetak committed Feb 22, 2024
1 parent f8ba79b commit 799d2ea
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 15 deletions.
58 changes: 47 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,36 +1,72 @@
# DiscoDiff.jl

A small package for differentiable discontinuities in Julia. Implements a simple API to generate differentiable discontinuous functions using the pass-through trick.
A small package for differentiable discontinuities in Julia. Implements a simple API to generate differentiable discontinuous functions using the pass-through trick. Works both in forward and reverse mode with scalars and arrays.

## Main API

To generate a differentiable function version of a discontinuous function `f` such that the gradient of `f` is the gradient of `g`, simply use:

````julia
new_f = construct_diff_version(f,g)

````

Note that it is recommended that `g` satisfies the limit:
This is used in the case

$$
\lim_{k \to \infty}g(kx) = f(x).
\frac{df}{dx}
$$

is either not defined or does not have the desired properties. For example where $f$ is the sign function. Sometimes we want to still be able to propagate gradients trough this. In this case we impose

$$
\frac{df}{dx} = \frac{dg}{dx}
$$

Use it as:

```julia
new_f(x)
# control gradient steppes
new_f(x, k = 100)

new_f(2.0, k = 100.0)
```

Currently supports only scalar to scalar functions. Currently assumes that the discontinuity is at `f(0)` only.
In the second case we have

$$
\frac{df}{dx}(2.0) = \frac{dg}{dx}(100.0 \cdot 2.0)
$$

Note: to avoid type instabilities ensure $x$ and $k$ are of the same type. The package works both with forward and reverse mode automatic differentiation.

````julia
using Zygote, ForwardDiff
using DiscoDiff
using LinearAlgebra

f(x) = 1.0
g(x) = x
new_f = construct_diff_version(f,g)

f(1.0) == 1.0
Zygote.gradient(new_f, 1.0)[1] == 1.0
ForwardDiff.derivative(new_f, 1.0) == 1.0
````

And it supports not scalar functions

````julia
using Zygote, ForwardDiff
using DiscoDiff
f = construct_diff_version(x -> x, x -> x.^2)
x = rand(10)
f(x) == x
Zygote.jacobian(f, x)[1] == diagm(2 * x)
ForwardDiff.jacobian(f, x) == diagm(2 * x)
````

# Other

# Heaviside Function Documentation
We also export to read-made function.

## Overview

Expand Down Expand Up @@ -70,14 +106,14 @@ We implement a differentiable version of the sign function, where the derivative
For the Heaviside function:

```julia
heaviside(1.0) == 1.0
heaviside(1.0, k = 2.0) == 1.0
heaviside(1.0)
heaviside(1.0, k = 2.0)
```

For the sign function

```julia
sign_diff(2.0) == 1.0
sign_diff(2.0)
sign_diff(2.0, k = 2.0)
```

Expand Down
18 changes: 14 additions & 4 deletions src/ignore_gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ API. In forward mode it simply returns a new number with a zero dual.
ignore_gradient(x) = ChainRulesCore.@ignore_derivatives x
ignore_gradient(x::ForwardDiff.Dual) = typeof(x)(x.value)

function ignore_gradient(arr::AbstractArray{<:ForwardDiff.Dual})
return typeof(arr)(ignore_gradient.(arr))
end

"""
construct_diff_version(f, g) -> pass_trough_function
Expand All @@ -25,11 +29,17 @@ the steepnes of the gradient, default is 1.
"""
function construct_diff_version(f, g)
@inline function pass_trough_function(x::T; k = nothing) where {T}
if k == nothing
k = one(T)
if isnothing(k)
if T <: Number
k = one(T)
elseif T <: AbstractArray
k = one(eltype(T))
else
error("Type not supported only supports Number and AbstractArray.")
end
end
zero = g(x * k) - ignore_gradient(g(x * k))
return ignore_gradient(f(x)) + zero
zero = g(x * k) .- ignore_gradient(g(x * k))
return ignore_gradient(f(x)) .+ zero
end
return pass_trough_function
end

0 comments on commit 799d2ea

Please sign in to comment.