From 799d2ea72749c3150b3bc5517e2714710b28aae4 Mon Sep 17 00:00:00 2001
From: Mitja Devetak
Date: Thu, 22 Feb 2024 13:39:07 +0100
Subject: [PATCH] update README

---
 README.md              | 58 ++++++++++++++++++++++++++++++++++--------
 src/ignore_gradient.jl | 18 ++++++++++---
 2 files changed, 61 insertions(+), 15 deletions(-)

diff --git a/README.md b/README.md
index 46e3c38..027b10c 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 # DiscoDiff.jl
 
-A small package for differentiable discontinuities in Julia. Implements a simple API to generate differentiable discontinuous functions using the pass-through trick.
+A small package for differentiable discontinuities in Julia. Implements a simple API to generate differentiable discontinuous functions using the pass-through trick. Works in both forward and reverse mode with scalars and arrays.
 
 ## Main API
 
@@ -8,29 +8,65 @@ To generate a differentiable function version of a discontinuous function `f` su
 
 ````julia
 new_f = construct_diff_version(f,g)
-
 ````
 
-Note that it is recommended that `g` satisfies the limit:
+This is used when
 
 $$
-\lim_{k \to \infty}g(kx) = f(x).
+\frac{df}{dx}
 $$
+is either not defined or does not have the desired properties, for example when $f$ is the sign function. Sometimes we still want to propagate gradients through it. In this case we impose
+
+$$
+\frac{df}{dx} = \frac{dg}{dx}
+$$
 
 Use it as:
 
 ```julia
 new_f(x)
 
 # control gradient steppes
-new_f(x, k = 100)
-
+new_f(2.0, k = 100.0)
 ```
 
-Currently supports only scalar to scalar functions. Currently assumes that the discontinuity is at `f(0)` only.
+In the second case we have
+
+$$
+\frac{df}{dx}(2.0) = \frac{dg}{dx}(100.0 \cdot 2.0)
+$$
+
+Note: to avoid type instabilities, ensure $x$ and $k$ are of the same type. The package works with both forward- and reverse-mode automatic differentiation.
+
+````julia
+using Zygote, ForwardDiff
+using DiscoDiff
+using LinearAlgebra
+
+f(x) = 1.0
+g(x) = x
+new_f = construct_diff_version(f,g)
+
+f(1.0) == 1.0
+Zygote.gradient(new_f, 1.0)[1] == 1.0
+ForwardDiff.derivative(new_f, 1.0) == 1.0
+````
+
+It also supports non-scalar functions:
+
+````julia
+using Zygote, ForwardDiff
+using DiscoDiff
+f = construct_diff_version(x -> x, x -> x.^2)
+x = rand(10)
+f(x) == x
+Zygote.jacobian(f, x)[1] == diagm(2 * x)
+ForwardDiff.jacobian(f, x) == diagm(2 * x)
+````
 
+# Other
 
-# Heaviside Function Documentation
+We also export two ready-made functions.
 
 ## Overview
 
@@ -70,14 +106,14 @@ We implement a differentiable version of the sign function, where the derivative
 For the Heaviside function:
 
 ```julia
-heaviside(1.0) == 1.0
-heaviside(1.0, k = 2.0) == 1.0
+heaviside(1.0)
+heaviside(1.0, k = 2.0)
 ```
 
 For the sign function
 
 ```julia
-sign_diff(2.0) == 1.0
+sign_diff(2.0)
 sign_diff(2.0, k = 2.0)
 ```
 
diff --git a/src/ignore_gradient.jl b/src/ignore_gradient.jl
index 223412f..1eaef8c 100644
--- a/src/ignore_gradient.jl
+++ b/src/ignore_gradient.jl
@@ -14,6 +14,10 @@ API. In forward mode it simply returns a new number with a zero dual.
 
 ignore_gradient(x) = ChainRulesCore.@ignore_derivatives x
 ignore_gradient(x::ForwardDiff.Dual) = typeof(x)(x.value)
+function ignore_gradient(arr::AbstractArray{<:ForwardDiff.Dual})
+    return typeof(arr)(ignore_gradient.(arr))
+end
+
 """
     construct_diff_version(f, g) -> pass_trough_function
 
@@ -25,11 +29,17 @@ the steepnes of the gradient, default is 1.
""" function construct_diff_version(f, g) @inline function pass_trough_function(x::T; k = nothing) where {T} - if k == nothing - k = one(T) + if isnothing(k) + if T <: Number + k = one(T) + elseif T <: AbstractArray + k = one(eltype(T)) + else + error("Type not supported only supports Number and AbstractArray.") + end end - zero = g(x * k) - ignore_gradient(g(x * k)) - return ignore_gradient(f(x)) + zero + zero = g(x * k) .- ignore_gradient(g(x * k)) + return ignore_gradient(f(x)) .+ zero end return pass_trough_function end