Skip to content

Commit

Permalink
test: SciMLJacobianOperators testing and bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 24, 2024
1 parent 36b70c8 commit 91662b6
Show file tree
Hide file tree
Showing 5 changed files with 315 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI_NonlinearSolve.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ jobs:
GROUP: ${{ matrix.group }}
- uses: julia-actions/julia-processcoverage@v1
with:
directories: src,ext
directories: src,ext,lib/SciMLJacobianOperators/src
- uses: codecov/codecov-action@v4
with:
file: lcov.info
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/CI_SciMLJacobianOperators.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/SciMLJacobianOperators {0}
- uses: julia-actions/julia-processcoverage@v1
with:
directories: src,ext
directories: lib/SciMLJacobianOperators/src
- uses: codecov/codecov-action@v4
with:
file: lcov.info
Expand Down
20 changes: 19 additions & 1 deletion lib/SciMLJacobianOperators/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Expand All @@ -16,23 +17,40 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"

[compat]
ADTypes = "1.8.1"
Aqua = "0.8.7"
ConcreteStructs = "0.2.3"
ConstructionBase = "1.5"
DifferentiationInterface = "0.5"
Enzyme = "0.12, 0.13"
EnzymeCore = "0.7, 0.8"
ExplicitImports = "1.9.0"
FastClosures = "0.3.2"
FiniteDiff = "2.24.0"
ForwardDiff = "0.10.36"
InteractiveUtils = "<0.0.1, 1"
LinearAlgebra = "1.10"
ReverseDiff = "1.15"
SciMLBase = "2.54.0"
SciMLOperators = "0.3"
Setfield = "1"
Test = "1.10"
TestItemRunner = "1"
Tracker = "0.2.35"
Zygote = "0.6.70"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["InteractiveUtils", "Test", "TestItemRunner"]
test = ["Aqua", "Enzyme", "ExplicitImports", "FiniteDiff", "ForwardDiff", "InteractiveUtils", "ReverseDiff", "Test", "TestItemRunner", "Tracker", "Zygote"]
26 changes: 20 additions & 6 deletions lib/SciMLJacobianOperators/src/SciMLJacobianOperators.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
module SciMLJacobianOperators

using ADTypes: ADTypes, AutoSparse
using ADTypes: ADTypes, AutoSparse, AutoEnzyme
using ConcreteStructs: @concrete
using ConstructionBase: ConstructionBase
using DifferentiationInterface: DifferentiationInterface
using EnzymeCore: EnzymeCore
using FastClosures: @closure
using LinearAlgebra: LinearAlgebra
using SciMLBase: SciMLBase, AbstractNonlinearProblem, AbstractNonlinearFunction
Expand Down Expand Up @@ -115,10 +116,10 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
iip = SciMLBase.isinplace(prob)
T = promote_type(eltype(u), eltype(fu))

vjp_autodiff = get_dense_ad(vjp_autodiff)
vjp_autodiff = set_function_as_const(get_dense_ad(vjp_autodiff))
vjp_op = prepare_vjp(skip_vjp, prob, f, u, fu; autodiff = vjp_autodiff)

jvp_autodiff = get_dense_ad(jvp_autodiff)
jvp_autodiff = set_function_as_const(get_dense_ad(jvp_autodiff))
jvp_op = prepare_jvp(skip_jvp, prob, f, u, fu; autodiff = jvp_autodiff)

output_cache = fu isa Number ? T(fu) : similar(fu, T)
Expand Down Expand Up @@ -259,6 +260,9 @@ end
function Base.:*(JᵀJ::StatefulJacobianNormalFormOperator, x::AbstractArray)
return JᵀJ.vjp_operator * (JᵀJ.jvp_operator * x)
end
function Base.:*(JᵀJ::StatefulJacobianNormalFormOperator, x::Number)
return JᵀJ.vjp_operator * (JᵀJ.jvp_operator * x)
end

function LinearAlgebra.mul!(
JᵀJx::AbstractArray, JᵀJ::StatefulJacobianNormalFormOperator, x::AbstractArray)
Expand All @@ -284,7 +288,7 @@ function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,
jac_cache = similar(u, eltype(fu), length(fu), length(u))
return @closure (vJ, v, u, p) -> begin
f.jac(jac_cache, u, p)
mul!(vec(vJ), jac_cache', vec(v))
LinearAlgebra.mul!(vec(vJ), jac_cache', vec(v))
return
end
return vjp_op, vjp_extras
Expand All @@ -298,6 +302,8 @@ function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,
# TODO: Once DI supports const params we can use `p`
fₚ = SciMLBase.JacobianWrapper{SciMLBase.isinplace(f)}(f, prob.p)
if SciMLBase.isinplace(f)
@assert DI.check_twoarg(autodiff) "Backend: $(autodiff) doesn't support in-place \
problems."
fu_cache = copy(fu)
v_fake = copy(fu)
di_extras = DI.prepare_pullback(fₚ, fu_cache, autodiff, u, v_fake)
Expand Down Expand Up @@ -326,11 +332,11 @@ function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem,
jac_cache = similar(u, eltype(fu), length(fu), length(u))
return @closure (Jv, v, u, p) -> begin
f.jac(jac_cache, u, p)
mul!(vec(Jv), jac_cache, vec(v))
LinearAlgebra.mul!(vec(Jv), jac_cache, vec(v))
return
end
else
return @closure (v, u, p, _) -> reshape(f.jac(u, p) * vec(v), size(u))
return @closure (v, u, p) -> reshape(f.jac(u, p) * vec(v), size(u))
end
end

Expand All @@ -339,6 +345,8 @@ function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem,
# TODO: Once DI supports const params we can use `p`
fₚ = SciMLBase.JacobianWrapper{SciMLBase.isinplace(f)}(f, prob.p)
if SciMLBase.isinplace(f)
@assert DI.check_twoarg(autodiff) "Backend: $(autodiff) doesn't support in-place \
problems."
fu_cache = copy(fu)
di_extras = DI.prepare_pushforward(fₚ, fu_cache, autodiff, u, u)
return @closure (Jv, v, u, p) -> begin
Expand Down Expand Up @@ -375,6 +383,12 @@ function get_dense_ad(ad::AutoSparse)
return dense_ad
end

# In our case we know that it is safe to mark the function as const
set_function_as_const(ad) = ad
function set_function_as_const(ad::AutoEnzyme{M, Nothing}) where {M}
return AutoEnzyme(; ad.mode, function_annotation = EnzymeCore.Const)
end

export JacobianOperator, VecJacOperator, JacVecOperator
export StatefulJacobianOperator
export StatefulJacobianNormalFormOperator
Expand Down
Loading

0 comments on commit 91662b6

Please sign in to comment.