
Support for eigen-solvers #2264

Open
MasonProtter opened this issue Jan 13, 2025 · 2 comments

@MasonProtter

This currently doesn't work:

julia> using Enzyme

julia> let 
           autodiff(Forward, Duplicated(1.0, 1.0)) do x
               M = [x    1+x
                    1+x' x^2]
               
               sum(eigvecs(M))
           end
       end
ERROR: 
No forward mode derivative found for ejlstr$dsyevr_64_$libblastrampoline.so.5
 at context:   call void @"ejlstr$dsyevr_64_$libblastrampoline.so.5"(i8* noundef nonnull %5, i8* noundef nonnull %6, i8* noundef nonnull %7, i8* noundef nonnull %9, i64 %174, i8* noundef nonnull %11, i8* noundef nonnull %13, i8* noundef nonnull %15, i8* noundef nonnull %17, i8* noundef nonnull %19, i8* noundef nonnull %21, i64 noundef %149, i64 %157, i64 %175, i8* noundef nonnull %23, i64 %150, i64 %176, i8* noundef nonnull %25, i64 %177, i8* noundef nonnull %27, i8* noundef nonnull %4, i64 noundef 1, i64 noundef 1, i64 noundef 1) #141 [ "jl_roots"({} addrspace(10)* null, {} addrspace(10)* null, { i8*, {} addrspace(10)* } %173, {} addrspace(10)* null, { i8*, {} addrspace(10)* } %169, { i8*, {} addrspace(10)* } %107, {} addrspace(10)* null, { i8*, {} addrspace(10)* } %165, { i8*, {} addrspace(10)* } %156, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, { i8*, {} addrspace(10)* } %161, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null) ], !dbg !262

Stacktrace:
 [1] syevr!
   @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/lapack.jl:5397


Stacktrace:
  [1] syevr!
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/lapack.jl:5397
  [2] eigen!
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/symmetriceigen.jl:8 [inlined]
  [3] #_eigen#96
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:250 [inlined]
  [4] fwddiffejulia___eigen_96_12524wrap
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:0
  [5] macro expansion
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5340 [inlined]
  [6] enzyme_call
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4878 [inlined]
  [7] ForwardModeThunk
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4766 [inlined]
  [8] runtime_generic_fwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, RT::Val{…}, f::LinearAlgebra.var"##_eigen#96", df::Nothing, primal_1::Bool, shadow_1_1::Nothing, primal_2::Bool, shadow_2_1::Nothing, primal_3::typeof(LinearAlgebra.eigsortby), shadow_3_1::Nothing, primal_4::typeof(LinearAlgebra._eigen), shadow_4_1::Nothing, primal_5::Matrix{…}, shadow_5_1::Matrix{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/rules/jitrules.jl:303
  [9] _eigen
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:247 [inlined]
 [10] #eigen#94
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:239 [inlined]
 [11] eigen
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:238 [inlined]
 [12] eigvecs
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:274 [inlined]
 [13] #1
    @ ./REPL[2]:6 [inlined]
 [14] fwddiffejulia__1_7197_inner_1wrap
    @ ./REPL[2]:0
 [15] macro expansion
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5340 [inlined]
 [16] enzyme_call
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4878 [inlined]
 [17] ForwardModeThunk
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4766 [inlined]
 [18] autodiff
    @ ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:654 [inlined]
 [19] autodiff(mode::ForwardMode{false, FFIABI, true, false}, f::Const{var"#1#2"}, args::Duplicated{Float64})
    @ Enzyme ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:544
 [20] autodiff
    @ ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:516 [inlined]
 [21] autodiff(f::Function, m::ForwardMode{false, FFIABI, false, false}, args::Duplicated{Float64})
    @ Enzyme ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:1019
 [22] top-level scope
    @ REPL[2]:2
Some type information was truncated. Use `show(err)` to see complete types.

It'd be really nice to be able to differentiate through eigen-solver calls.
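For reference, here is a minimal sketch of the derivative a forward-mode rule should reproduce for the symmetric example above, computed in plain Julia without Enzyme. It uses the standard first-order perturbation result dλᵢ/dx = vᵢᵀ (dM/dx) vᵢ, which holds for real symmetric matrices with distinct eigenvalues and unit eigenvectors vᵢ, and compares it with a central finite difference. The helpers `M`, `dM`, `eigvals_derivative`, and `eigvals_fd` are hypothetical names chosen for illustration. It checks eigenvalues rather than `sum(eigvecs(M))` because eigenvector derivatives also depend on the solver's sign convention.

```julia
using LinearAlgebra

# Symmetric matrix from the example above (x' == x for real x) and its derivative dM/dx.
M(x)  = [x 1+x; 1+x x^2]
dM(x) = [1.0 1.0; 1.0 2x]

# First-order perturbation: dλᵢ/dx = vᵢᵀ (dM/dx) vᵢ for unit eigenvectors vᵢ.
# Assumes the eigenvalues are distinct.
function eigvals_derivative(x)
    F = eigen(Symmetric(M(x)))   # eigenvalues sorted ascending
    return [dot(v, dM(x) * v) for v in eachcol(F.vectors)]
end

# Central finite difference of the (ascending) eigenvalues, for comparison.
function eigvals_fd(x; h = 1e-6)
    λp = eigvals(Symmetric(M(x + h)))
    λm = eigvals(Symmetric(M(x - h)))
    return (λp .- λm) ./ (2h)
end

x = 1.0
@show eigvals_derivative(x)
@show eigvals_fd(x)
```

At x = 1.0 the matrix is [1 2; 2 1] with eigenvalues -1 and 3, and both functions should return approximately [0.5, 2.5].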

@wsmoses
Member

wsmoses commented Jan 13, 2025

@michel2323 potentially another one for the BLAS rules?

@MasonProtter
Author

So I tried today to write an EnzymeRule for this, based on the `eigen!` forward rule in ChainRules.jl, but I'm running into trouble because `eigen` is type-unstable. Any advice on how to deal with type-unstable forward rules, @wsmoses?
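The `Union` reported in the error further down reflects the fact that `eigen` on a `Matrix{ComplexF64}` returns different `Eigen` types depending on a runtime check: Hermitian input is routed to the Hermitian solver, which produces real eigenvalues. A small plain-Julia illustration (the two matrices are arbitrary examples):

```julia
using LinearAlgebra

A_herm    = ComplexF64[2 1im; -1im 3]    # Hermitian: A_herm == A_herm'
A_nonherm = ComplexF64[1 2; 0.5im 3]     # not Hermitian

# `eigen` checks `ishermitian` at runtime; the Hermitian path returns real
# eigenvalues, the general path returns complex ones.
F_herm    = eigen(A_herm)
F_nonherm = eigen(A_nonherm)

println(typeof(F_herm))
println(typeof(F_nonherm))

# Inference cannot know which branch is taken for this input type,
# so the inferred return type has to cover both; inspect it with:
println(Base.return_types(eigen, (Matrix{ComplexF64},)))
```

A rule written for `Matrix{ComplexF64}` input therefore has to cope with either concrete `Eigen` type.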

using Enzyme, LinearAlgebra
using LinearAlgebra: BlasFloat

import .EnzymeRules: forward, reverse, augmented_primal
using .EnzymeRules

function forward(config::FwdConfig,
                 func::Const{typeof(eigen!)},
                 ::Type{<:Duplicated},
                 A::Duplicated; kwargs...)
    A, ΔA = A.val, A.dval
    if ishermitian(A)
        error("Not yet implemented")
    end
    # adapted from ChainRules.jl
    F = eigen!(A; kwargs...)::Eigen{ComplexF64, ComplexF64, Matrix{ComplexF64}, Vector{ComplexF64}}
    λ, V = F.values, F.vectors
    tmp = V \ ΔA
    ∂K = tmp * V
    ∂Kdiag = @view ∂K[diagind(∂K)]
    ∂λ = eltype(λ) <: Real ? real.(∂Kdiag) : copy(∂Kdiag)
    ∂K ./= transpose(λ) .- λ
    fill!(∂Kdiag, 0)
    ∂V = mul!(tmp, V, ∂K)
    _eigen_norm_phase_fwd!(∂V, A, V)
    ∂F = Eigen(∂λ, ∂V)::Eigen{ComplexF64, ComplexF64, Matrix{ComplexF64}, Vector{ComplexF64}}
    Duplicated(F, ∂F)
end

function _eigen_norm_phase_fwd!(∂V, A, V)
    # From ChainRules.jl
    @inbounds for i in axes(V, 2)
        v, ∂v = @views V[:, i], ∂V[:, i]
        # account for unit normalization
        ∂c_norm = -realdot(v, ∂v)
        if eltype(V) <: Real
            ∂c = ∂c_norm
        else
            # account for rotation of largest element to real
            k = _findrealmaxabs2(v)
            ∂c_phase = -imag(∂v[k]) / real(v[k])
            ∂c = complex(∂c_norm, ∂c_phase)
        end
        ∂v .+= v .* ∂c
    end
    return ∂V
end

# From https://github.com/JuliaMath/RealDot.jl/blob/main/src/RealDot.jl
@inline realdot(x, y) = real(LinearAlgebra.dot(x, y))
@inline realdot(x::Complex, y::Complex) = muladd(real(x), real(y), imag(x) * imag(y))
@inline realdot(x::Real, y::Number) = x * real(y)
@inline realdot(x::Number, y::Real) = real(x) * y
@inline realdot(x::Real, y::Real) = x * y

# From ChainRules.jl
function _findrealmaxabs2(x)
    amax = abs2(first(x))
    imax = 1
    @inbounds for i in 2:length(x)
        xi = x[i]
        !isreal(xi) && continue
        a = abs2(xi)
        a < amax && continue
        amax, imax = a, i
    end
    return imax
end

With this rule defined, calling `eigen` on a complex matrix fails:

julia> let
           autodiff(Forward, Duplicated(1.0, 1.0)) do x
               M = [x-im    1+x
                    1-x x^2]
               λ, V = eigen(M)
               sum(V) - sum(λ)
           end
       end
ERROR: Enzyme execution failed.
Enzyme: incorrect return type of prima/shadow forward custom rule - FwdConfigWidth{1, true, true, false} Duplicated{Union{Eigen{ComplexF64, ComplexF64, Matrix{ComplexF64}, Vector{ComplexF64}}, Eigen{ComplexF64, Float64, Matrix{ComplexF64}, Vector{Float64}}}} Type[Const{typeof(Core.kwcall)}, Const{typeof(eigen!)}, Duplicated{Matrix{ComplexF64}}] want just shadow type Duplicated{Union{Eigen{ComplexF64, ComplexF64, Matrix{ComplexF64}, Vector{ComplexF64}}, Eigen{ComplexF64, Float64, Matrix{ComplexF64}, Vector{Float64}}}} found Duplicated{Eigen{ComplexF64, ComplexF64, Matrix{ComplexF64}, Vector{ComplexF64}}}

Stacktrace:
  [1] #_eigen#96
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:252 [inlined]
  [2] fwddiffejulia___eigen_96_47518wrap
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:0
  [3] macro expansion
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5340 [inlined]
  [4] enzyme_call
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4878 [inlined]
  [5] ForwardModeThunk
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4766 [inlined]
  [6] runtime_generic_fwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, RT::Val{…}, f::LinearAlgebra.var"##_eigen#96", df::Nothing, primal_1::Bool, shadow_1_1::Nothing, primal_2::Bool, shadow_2_1::Nothing, primal_3::typeof(LinearAlgebra.eigsortby), shadow_3_1::Nothing, primal_4::typeof(LinearAlgebra._eigen), shadow_4_1::Nothing, primal_5::Matrix{…}, shadow_5_1::Matrix{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/rules/jitrules.jl:303
  [7] _eigen
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:247 [inlined]
  [8] #eigen#94
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:239 [inlined]
  [9] eigen
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:238 [inlined]
 [10] #23
    @ ./REPL[6]:5 [inlined]
 [11] fwddiffejulia__23_47406_inner_1wrap
    @ ./REPL[6]:0
 [12] macro expansion
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5340 [inlined]
 [13] enzyme_call
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4878 [inlined]
 [14] ForwardModeThunk
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4766 [inlined]
 [15] autodiff
    @ ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:654 [inlined]
 [16] autodiff(mode::ForwardMode{false, FFIABI, true, false}, f::Const{var"#23#24"}, args::Duplicated{Float64})
    @ Enzyme ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:544
 [17] autodiff
    @ ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:516 [inlined]
 [18] autodiff(f::Function, m::ForwardMode{false, FFIABI, false, false}, args::Duplicated{Float64})
    @ Enzyme ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:1019
 [19] top-level scope
    @ REPL[6]:2
Some type information was truncated. Use `show(err)` to see complete types.
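Independently of the type-instability problem, the tangent formulas that the rule above borrows from ChainRules.jl can be sanity-checked in plain Julia without Enzyme. The sketch below uses a hypothetical helper `eigen_tangent` and arbitrary test matrices. It recomputes ∂λ = diag(V⁻¹ ΔA V) and ∂V = V K, where K[i, j] = (V⁻¹ ΔA V)[i, j] / (λ[j] - λ[i]) off the diagonal and 0 on it, then checks that the tangents satisfy the differentiated eigen-equation and that ∂λ agrees with a central finite difference. It omits the correction in `_eigen_norm_phase_fwd!`; that step only adds a multiple of each eigenvector to the corresponding column of ∂V, which leaves both checks unchanged.

```julia
using LinearAlgebra

# Hypothetical helper: forward-mode tangent of eigen(A) along direction ΔA,
# following the same steps as the ChainRules-style rule above, but without
# the normalization/phase correction. Assumes distinct eigenvalues.
function eigen_tangent(A, ΔA)
    F = eigen(A)
    λ, V = F.values, F.vectors
    K = (V \ ΔA) * V              # K = V⁻¹ ΔA V
    ∂λ = diag(K)                  # eigenvalue tangents (copied before K is modified)
    K ./= transpose(λ) .- λ       # K[i, j] / (λ[j] - λ[i])
    K[diagind(K)] .= 0            # diagonal was divided by zero; set it to 0
    ∂V = V * K                    # eigenvector tangents
    return λ, V, ∂λ, ∂V
end

# Arbitrary non-Hermitian test matrix and perturbation direction.
A  = ComplexF64[1 2; 0.5im 3]
ΔA = ComplexF64[0.1 -0.2im; 0.3 0.05]
λ, V, ∂λ, ∂V = eigen_tangent(A, ΔA)

# Check 1: differentiating A V = V Λ gives ΔA V + A ∂V = ∂V Λ + V ∂Λ.
residual = ΔA * V + A * ∂V - ∂V * Diagonal(λ) - V * Diagonal(∂λ)
@show norm(residual)

# Check 2: ∂λ should match a central finite difference of the eigenvalues
# (both sorted the same way; the eigenvalues here are well separated).
h = 1e-6
∂λ_fd = (eigvals(A + h * ΔA) - eigvals(A - h * ΔA)) / (2h)
@show norm(∂λ - ∂λ_fd)
```

Both norms should be small, at the level of roundoff and finite-difference error respectively.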
