diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index abd2232e..cc2e5249 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -5,6 +5,8 @@ using LinearSolve.LinearAlgebra using EnzymeCore using EnzymeCore: EnzymeRules +@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:LinearSolve.SciMLLinearSolveAlgorithm}) = true + function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1}, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} diff --git a/test/enzyme.jl b/test/enzyme.jl index b09c0de5..01239dc2 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -157,7 +157,7 @@ Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA), @test db1 ≈ db12 @test db2 ≈ db22 -#= + function f3(A, b1, b2; alg = KrylovJL_GMRES()) prob = LinearProblem(A, b1) cache = init(prob, alg) @@ -167,12 +167,14 @@ function f3(A, b1, b2; alg = KrylovJL_GMRES()) norm(s1 + s2) end -Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) +dA = zeros(n, n); +db1 = zeros(n); +db2 = zeros(n); +Enzyme.autodiff(set_runtime_activity(Reverse), f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) @test dA ≈ dA2 atol=5e-5 @test db1 ≈ db12 @test db2 ≈ db22 -=# A = rand(n, n); dA = zeros(n, n);