Skip to content

Commit

Permalink
WIP: fix KrylovJL_GMRES with Enzyme
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Sep 12, 2024
1 parent 33911f6 commit 39b3717
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
3 changes: 3 additions & 0 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ using Enzyme

using EnzymeCore

@inline EnzymeCore.EnzymeRules.inactive_type(v::Type{LinearSolve.KrylovJL}) = true
@inline EnzymeCore.EnzymeRules.inactive_type(v::Type{LinearSolve.Krylov.GmresSolver}) = true

function EnzymeCore.EnzymeRules.forward(
func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP},
alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
Expand Down
6 changes: 4 additions & 2 deletions test/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA),
@test db1 db12
@test db2 db22

#=

function f3(A, b1, b2; alg = KrylovJL_GMRES())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
Expand All @@ -168,12 +168,14 @@ function f3(A, b1, b2; alg = KrylovJL_GMRES())
norm(s1 + s2)
end

dA = zeros(n, n);
db1 = zeros(n);
db2 = zeros(n);
Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))

@test dA dA2 atol=5e-5
@test db1 db12
@test db2 db22
=#

A = rand(n, n);
dA = zeros(n, n);
Expand Down

0 comments on commit 39b3717

Please sign in to comment.