From 153244f5ebac15f17572aeeedb8e181f439ee48a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 25 Sep 2024 14:09:17 -0400 Subject: [PATCH] feat: use contexts from DifferentiationInterface.jl --- lib/SciMLJacobianOperators/Project.toml | 2 +- .../src/SciMLJacobianOperators.jl | 56 ++++++++----------- lib/SciMLJacobianOperators/test/core_tests.jl | 6 +- 3 files changed, 26 insertions(+), 38 deletions(-) diff --git a/lib/SciMLJacobianOperators/Project.toml b/lib/SciMLJacobianOperators/Project.toml index 5c1bc20ff..5209241bc 100644 --- a/lib/SciMLJacobianOperators/Project.toml +++ b/lib/SciMLJacobianOperators/Project.toml @@ -19,7 +19,7 @@ ADTypes = "1.8.1" Aqua = "0.8.7" ConcreteStructs = "0.2.3" ConstructionBase = "1.5" -DifferentiationInterface = "0.5" +DifferentiationInterface = "0.6" Enzyme = "0.12, 0.13" EnzymeCore = "0.7, 0.8" ExplicitImports = "1.9.0" diff --git a/lib/SciMLJacobianOperators/src/SciMLJacobianOperators.jl b/lib/SciMLJacobianOperators/src/SciMLJacobianOperators.jl index a173dc395..e807e983c 100644 --- a/lib/SciMLJacobianOperators/src/SciMLJacobianOperators.jl +++ b/lib/SciMLJacobianOperators/src/SciMLJacobianOperators.jl @@ -1,9 +1,9 @@ module SciMLJacobianOperators -using ADTypes: ADTypes, AutoSparse, AutoEnzyme +using ADTypes: ADTypes, AutoSparse using ConcreteStructs: @concrete using ConstructionBase: ConstructionBase -using DifferentiationInterface: DifferentiationInterface +using DifferentiationInterface: DifferentiationInterface, Constant using EnzymeCore: EnzymeCore using FastClosures: @closure using LinearAlgebra: LinearAlgebra @@ -112,10 +112,10 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff = iip = SciMLBase.isinplace(prob) T = promote_type(eltype(u), eltype(fu)) - vjp_autodiff = set_function_as_const(get_dense_ad(vjp_autodiff)) + vjp_autodiff = get_dense_ad(vjp_autodiff) vjp_op = prepare_vjp(skip_vjp, prob, f, u, fu; autodiff = vjp_autodiff) - jvp_autodiff = set_function_as_const(get_dense_ad(jvp_autodiff)) + jvp_autodiff = get_dense_ad(jvp_autodiff) jvp_op = prepare_jvp(skip_jvp, prob, f, u, fu; autodiff = jvp_autodiff) output_cache = fu isa Number ? T(fu) : similar(fu, T) @@ -295,23 +295,21 @@ function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem, @assert autodiff!==nothing "`vjp_autodiff` must be provided if `f` doesn't have \ analytic `vjp` or `jac`." - # TODO: Once DI supports const params we can use `p` - fₚ = SciMLBase.JacobianWrapper{SciMLBase.isinplace(f)}(f, prob.p) if SciMLBase.isinplace(f) - @assert DI.check_twoarg(autodiff) "Backend: $(autodiff) doesn't support in-place \ - problems." + @assert DI.check_inplace(autodiff) "Backend: $(autodiff) doesn't support in-place \ + problems." fu_cache = copy(fu) - v_fake = copy(fu) - di_extras = DI.prepare_pullback(fₚ, fu_cache, autodiff, u, v_fake) + di_extras = DI.prepare_pullback(f, fu_cache, autodiff, u, (fu,), Constant(prob.p)) return @closure (vJ, v, u, p) -> begin - DI.pullback!(fₚ, fu_cache, reshape(vJ, size(u)), autodiff, - u, reshape(v, size(fu_cache)), di_extras) + DI.pullback!(f, fu_cache, (reshape(vJ, size(u)),), di_extras, autodiff, + u, (reshape(v, size(fu_cache)),), Constant(p)) return end else - di_extras = DI.prepare_pullback(fₚ, autodiff, u, fu) + di_extras = DI.prepare_pullback(f, autodiff, u, (fu,), Constant(prob.p)) return @closure (v, u, p) -> begin - return DI.pullback(fₚ, autodiff, u, reshape(v, size(fu)), di_extras) + return only(DI.pullback( + f, di_extras, autodiff, u, (reshape(v, size(fu)),), Constant(p))) end end end @@ -342,23 +340,21 @@ function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem, @assert autodiff!==nothing "`jvp_autodiff` must be provided if `f` doesn't have \ analytic `vjp` or `jac`." - # TODO: Once DI supports const params we can use `p` - fₚ = SciMLBase.JacobianWrapper{SciMLBase.isinplace(f)}(f, prob.p) if SciMLBase.isinplace(f) - @assert DI.check_twoarg(autodiff) "Backend: $(autodiff) doesn't support in-place \ - problems." + @assert DI.check_inplace(autodiff) "Backend: $(autodiff) doesn't support in-place \ + problems." fu_cache = copy(fu) - di_extras = DI.prepare_pushforward(fₚ, fu_cache, autodiff, u, u) + di_extras = DI.prepare_pushforward(f, fu_cache, autodiff, u, (u,), Constant(prob.p)) return @closure (Jv, v, u, p) -> begin - DI.pushforward!( - fₚ, fu_cache, reshape(Jv, size(fu_cache)), - autodiff, u, reshape(v, size(u)), di_extras) + DI.pushforward!(f, fu_cache, (reshape(Jv, size(fu_cache)),), di_extras, + autodiff, u, (reshape(v, size(u)),), Constant(p)) return end else - di_extras = DI.prepare_pushforward(fₚ, autodiff, u, u) + di_extras = DI.prepare_pushforward(f, autodiff, u, (u,), Constant(prob.p)) return @closure (v, u, p) -> begin - return DI.pushforward(fₚ, autodiff, u, reshape(v, size(u)), di_extras) + return only(DI.pushforward( + f, di_extras, autodiff, u, (reshape(v, size(u)),), Constant(p))) end end end @@ -371,10 +367,8 @@ function prepare_scalar_op(::Val{false}, prob::AbstractNonlinearProblem, @assert autodiff!==nothing "`autodiff` must be provided if `f` doesn't have \ analytic `vjp` or `jvp` or `jac`." - # TODO: Once DI supports const params we can use `p` - fₚ = Base.Fix2(f, prob.p) - di_extras = DI.prepare_derivative(fₚ, autodiff, u) - return @closure (v, u, p) -> DI.derivative(fₚ, autodiff, u, di_extras) * v + di_extras = DI.prepare_derivative(f, autodiff, u, Constant(prob.p)) + return @closure (v, u, p) -> DI.derivative(f, di_extras, autodiff, u, Constant(p)) * v end get_dense_ad(::Nothing) = nothing @@ -386,12 +380,6 @@ function get_dense_ad(ad::AutoSparse) return dense_ad end -# In our case we know that it is safe to mark the function as const -set_function_as_const(ad) = ad -function set_function_as_const(ad::AutoEnzyme{M, Nothing}) where {M} - return AutoEnzyme(; ad.mode, function_annotation = EnzymeCore.Const) -end - export JacobianOperator, VecJacOperator, JacVecOperator export StatefulJacobianOperator export StatefulJacobianNormalFormOperator diff --git a/lib/SciMLJacobianOperators/test/core_tests.jl b/lib/SciMLJacobianOperators/test/core_tests.jl index e3b595221..6fb024a0c 100644 --- a/lib/SciMLJacobianOperators/test/core_tests.jl +++ b/lib/SciMLJacobianOperators/test/core_tests.jl @@ -7,7 +7,7 @@ AutoEnzyme(), AutoEnzyme(; mode = Enzyme.Reverse), AutoZygote(), - AutoReverseDiff(), + # AutoReverseDiff(), # FIXME: https://github.com/gdalle/DifferentiationInterface.jl/issues/503 AutoTracker(), AutoFiniteDiff() ] @@ -91,7 +91,7 @@ end reverse_ADs = [ AutoEnzyme(), AutoEnzyme(; mode = Enzyme.Reverse), - AutoReverseDiff(), + # AutoReverseDiff(), # FIXME: https://github.com/gdalle/DifferentiationInterface.jl/issues/503 AutoFiniteDiff() ] @@ -182,7 +182,7 @@ end AutoEnzyme(; mode = Enzyme.Reverse), AutoZygote(), AutoTracker(), - AutoReverseDiff(), + # AutoReverseDiff(), # FIXME: https://github.com/gdalle/DifferentiationInterface.jl/issues/503 AutoFiniteDiff() ]