Skip to content

Commit

Permalink
WIP: fix KrylovJL_GMRES with Enzyme
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Oct 28, 2023
1 parent 5a8aa51 commit 8fc4ae3
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
4 changes: 3 additions & 1 deletion ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ using LinearSolve
using LinearSolve.LinearAlgebra
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)


using Enzyme

using EnzymeCore

@inline EnzymeCore.EnzymeRules.inactive_type(v::Type{LinearSolve.KrylovJL}) = true
@inline EnzymeCore.EnzymeRules.inactive_type(v::Type{LinearSolve.Krylov.GmresSolver}) = true

function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
res = func.val(prob.val, alg.val; kwargs...)
dres = if EnzymeRules.width(config) == 1
Expand Down
8 changes: 5 additions & 3 deletions test/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA), Duplicated(copy(b1), db1),
@test db1 db12
@test db2 db22

#=

function f3(A, b1, b2; alg = KrylovJL_GMRES())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
Expand All @@ -117,9 +117,11 @@ function f3(A, b1, b2; alg = KrylovJL_GMRES())
norm(s1 + s2)
end

dA = zeros(n, n);
db1 = zeros(n);
db2 = zeros(n);
Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))

@test dA dA2 atol=5e-5
@test db1 db12
@test db2 ≈ db22
=#
@test db2 db22

0 comments on commit 8fc4ae3

Please sign in to comment.