-
Notifications
You must be signed in to change notification settings - Fork 69
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support for eigen-solvers #2264
Comments
@michel2323, potentially another one for the BLAS rules?
So today I tried to write an EnzymeRule for this, based on the one in ChainRules.jl, but I'm running into trouble caused by the return type Enzyme expects from the rule (see the error below):
using Enzyme, LinearAlgebra
using LinearAlgebra: BlasFloat
import .EnzymeRules: forward, reverse, augmented_primal
using .EnzymeRules
"""
    forward(config::FwdConfig, func::Const{typeof(eigen!)}, ::Type{<:Duplicated},
            A::Duplicated; kwargs...)

Forward-mode Enzyme rule for `eigen!` on a general (non-Hermitian) matrix,
adapted from the `eigen!` frule in ChainRules.jl.

Computes `F = eigen!(A.val; kwargs...)` and the tangent `∂F` of its eigenvalues
`λ` and eigenvectors `V` along the direction `ΔA = A.dval`:

- with `K = V⁻¹ ΔA V`, the eigenvalue tangent is `∂λ = diag(K)`; it is kept real
  when `λ` is real;
- the eigenvector tangent is `∂V = V * K̃`, where
  `K̃[i, j] = K[i, j] / (λ[j] - λ[i])` off the diagonal and `0` on it;
- `∂V` is then corrected in place by `_eigen_norm_phase_fwd!` for the
  normalization and phase convention of `V`.

Returns `Duplicated(F, ∂F)`, `∂F`, `F` or `nothing`, depending on
`EnzymeRules.needs_primal(config)` and `EnzymeRules.needs_shadow(config)`.
`eigen!` is always called, because its effect on `A.val` is part of the primal
computation.

Throws an error for Hermitian input, which is not yet implemented.

NOTE(review): `A.dval` is never written here, even though `eigen!` may overwrite
`A.val`. TODO confirm this is acceptable for Enzyme's mutation semantics.

NOTE(review): when the inferred return type of `eigen!` is a `Union` (general vs.
Hermitian path), Enzyme has been observed to reject a `Duplicated` of the
concrete `Eigen` type returned here. Unresolved.
"""
function forward(config::FwdConfig,
                 func::Const{typeof(eigen!)},
                 ::Type{<:Duplicated},
                 A::Duplicated; kwargs...)
    # Use distinct names instead of rebinding the argument `A` to a matrix.
    Aval, ΔA = A.val, A.dval

    # Checked before `eigen!`, which may overwrite its argument.
    if ishermitian(Aval)
        error("Not yet implemented")
    end

    # Adapted from ChainRules.jl. No element-type assertion on `F`: real inputs
    # can produce either real or complex eigen-decompositions.
    F = eigen!(Aval; kwargs...)
    if !EnzymeRules.needs_shadow(config)
        return EnzymeRules.needs_primal(config) ? F : nothing
    end

    λ, V = F.values, F.vectors
    tmp = V \ ΔA
    ∂K = tmp * V                        # K = V⁻¹ ΔA V
    ∂Kdiag = @view ∂K[diagind(∂K)]
    ∂λ = eltype(λ) <: Real ? real.(∂Kdiag) : copy(∂Kdiag)
    ∂K ./= transpose(λ) .- λ            # element (i, j) divided by λ[j] - λ[i]
    fill!(∂Kdiag, 0)                    # the diagonal was 0/0 in the line above
    ∂V = mul!(tmp, V, ∂K)               # reuse `tmp`'s storage; it is no longer needed
    _eigen_norm_phase_fwd!(∂V, Aval, V)
    ∂F = Eigen(∂λ, ∂V)

    if EnzymeRules.needs_primal(config)
        return Duplicated(F, ∂F)
    else
        return ∂F
    end
end
"""
    _eigen_norm_phase_fwd!(∂V, A, V)

Adjust the eigenvector tangents `∂V` in place for the gauge fixing applied to the
eigenvectors `V`, and return `∂V`. Adapted from ChainRules.jl.

For each column `v = V[:, j]` with tangent `∂v = ∂V[:, j]`, the multiple
`v * ∂c` is added to `∂v`, where:
- the real part of `∂c` is `-realdot(v, ∂v)`, accounting for the unit
  normalization of `v`;
- for complex `V` only, the imaginary part of `∂c` is
  `-imag(∂v[k]) / real(v[k])` with `k = _findrealmaxabs2(v)`, accounting for
  the rotation that makes that element real.

`A` is accepted for interface compatibility but is not used.
"""
function _eigen_norm_phase_fwd!(∂V, A, V)
    @inbounds for j in axes(V, 2)
        v = view(V, :, j)
        ∂v = view(∂V, :, j)
        # Unit-normalization correction.
        norm_shift = -realdot(v, ∂v)
        ∂c = if eltype(V) <: Real
            norm_shift
        else
            # Phase correction: element k of v was rotated to be real.
            k = _findrealmaxabs2(v)
            complex(norm_shift, -imag(∂v[k]) / real(v[k]))
        end
        ∂v .+= v .* ∂c
    end
    return ∂V
end
# `realdot(x, y)`: the real part of the inner product of `x` and `y` (with `x`
# conjugated), i.e. `real(dot(x, y))`. The scalar methods avoid computing the
# imaginary part.
# From https://github.com/JuliaMath/RealDot.jl/blob/main/src/RealDot.jl
@inline function realdot(x::Real, y::Real)
    return x * y
end
@inline function realdot(x::Real, y::Number)
    return x * real(y)
end
@inline function realdot(x::Number, y::Real)
    return real(x) * y
end
@inline function realdot(x::Complex, y::Complex)
    return muladd(real(x), real(y), imag(x) * imag(y))
end
@inline function realdot(x, y)
    return real(LinearAlgebra.dot(x, y))
end
# From ChainRules.jl
"""
    _findrealmaxabs2(x)

Return the index of the real-valued element of `x` with the largest `abs2`.

The search starts from the first element of `x`, whether or not that element is
real. So the first index is returned when no later real element has an `abs2`
at least as large as the first element's. On ties, the later index wins.

NOTE(review): callers appear to rely on the largest-magnitude entry being real,
as LAPACK's `geev` eigenvector normalization is documented to produce. Confirm
for other eigen-solver backends.

Throws `ArgumentError` if `x` is empty.
"""
function _findrealmaxabs2(x)
    isempty(x) && throw(ArgumentError("_findrealmaxabs2: input must be non-empty"))
    # `firstindex`/`lastindex` instead of `1:length(x)`, so that `@inbounds` stays
    # valid for arrays whose indices do not start at 1.
    imax = firstindex(x)
    amax = abs2(x[imax])
    @inbounds for i in (imax + 1):lastindex(x)
        xi = x[i]
        isreal(xi) || continue
        a = abs2(xi)
        a < amax && continue
        amax, imax = a, i
    end
    return imax
end
julia> let
autodiff(Forward, Duplicated(1.0, 1.0)) do x
M = [x-im 1+x
1-x x^2]
λ, V = eigen(M)
sum(V) - sum(λ)
end
end
ERROR: Enzyme execution failed.
Enzyme: incorrect return type of prima/shadow forward custom rule - FwdConfigWidth{1, true, true, false} Duplicated{Union{Eigen{ComplexF64, ComplexF64, Matrix{ComplexF64}, Vector{ComplexF64}}, Eigen{ComplexF64, Float64, Matrix{ComplexF64}, Vector{Float64}}}} Type[Const{typeof(Core.kwcall)}, Const{typeof(eigen!)}, Duplicated{Matrix{ComplexF64}}] want just shadow type Duplicated{Union{Eigen{ComplexF64, ComplexF64, Matrix{ComplexF64}, Vector{ComplexF64}}, Eigen{ComplexF64, Float64, Matrix{ComplexF64}, Vector{Float64}}}} found Duplicated{Eigen{ComplexF64, ComplexF64, Matrix{ComplexF64}, Vector{ComplexF64}}}
Stacktrace:
[1] #_eigen#96
@ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:252 [inlined]
[2] fwddiffejulia___eigen_96_47518wrap
@ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:0
[3] macro expansion
@ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5340 [inlined]
[4] enzyme_call
@ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4878 [inlined]
[5] ForwardModeThunk
@ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4766 [inlined]
[6] runtime_generic_fwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, RT::Val{…}, f::LinearAlgebra.var"##_eigen#96", df::Nothing, primal_1::Bool, shadow_1_1::Nothing, primal_2::Bool, shadow_2_1::Nothing, primal_3::typeof(LinearAlgebra.eigsortby), shadow_3_1::Nothing, primal_4::typeof(LinearAlgebra._eigen), shadow_4_1::Nothing, primal_5::Matrix{…}, shadow_5_1::Matrix{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/rules/jitrules.jl:303
[7] _eigen
@ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:247 [inlined]
[8] #eigen#94
@ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:239 [inlined]
[9] eigen
@ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:238 [inlined]
[10] #23
@ ./REPL[6]:5 [inlined]
[11] fwddiffejulia__23_47406_inner_1wrap
@ ./REPL[6]:0
[12] macro expansion
@ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5340 [inlined]
[13] enzyme_call
@ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4878 [inlined]
[14] ForwardModeThunk
@ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4766 [inlined]
[15] autodiff
@ ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:654 [inlined]
[16] autodiff(mode::ForwardMode{false, FFIABI, true, false}, f::Const{var"#23#24"}, args::Duplicated{Float64})
@ Enzyme ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:544
[17] autodiff
@ ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:516 [inlined]
[18] autodiff(f::Function, m::ForwardMode{false, FFIABI, false, false}, args::Duplicated{Float64})
@ Enzyme ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:1019
[19] top-level scope
@ REPL[6]:2
Some type information was truncated. Use `show(err)` to see complete types.
|
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This currently doesn't work:
It'd be really nice to be able to differentiate through eigen-solver calls.
The text was updated successfully, but these errors were encountered: