Error differentiating composed cross product with Zygote #689

Closed
benjaminfaber opened this issue Jan 12, 2023 · 2 comments

@benjaminfaber

I have run into an error when trying to compute the gradient or Jacobian with Zygote for a function that contains a cross product. I'm relatively new to AD, so I would like to know whether this is user error or something that needs to be fixed or added in the ChainRules package. The MWE and stack trace are below:

using LinearAlgebra, Zygote

g(a, b) = hypot(cross(a, b)...)
h(a, b) = (c = cross(a, b); dot(c, c))   # squared norm of a × b; dot needs two arguments

jacobian(g, rand(3), rand(3))

The stacktrace:

ERROR: MethodError: no method matching cross(::Vector{Float64}, ::Tangent{Any, Tuple{Float64, Float64, Float64}})
Closest candidates are:
  cross(::AbstractVector, ::AbstractVector) at ~/build/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/generic.jl:310
  cross(::Any, ::AbstractThunk) at ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:89
  cross(::AbstractThunk, ::Any) at ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:88
  ...
Stacktrace:
  [1] (::ChainRules.var"#1949#1952"{Tangent{Any, Tuple{Float64, Float64, Float64}}, Vector{Float64}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}})()
    @ ChainRules ~/.julia/packages/ChainRules/ajkp7/src/rulesets/LinearAlgebra/dense.jl:109
  [2] unthunk
    @ ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:204 [inlined]
  [3] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/tFaxC/src/compiler/chainrules.jl:105 [inlined]
  [4] map
    @ ./tuple.jl:223 [inlined]
  [5] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/tFaxC/src/compiler/chainrules.jl:106 [inlined]
  [6] (::Zygote.ZBack{ChainRules.var"#cross_pullback#1951"{Vector{Float64}, Vector{Float64}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}})(dy::Tuple{Float64, Float64, Float64})
    @ Zygote ~/.julia/packages/Zygote/tFaxC/src/compiler/chainrules.jl:206
  [7] Pullback
    @ ./REPL[76]:1 [inlined]
  [8] (::typeof((g)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/tFaxC/src/compiler/interface2.jl:0
  [9] #208
    @ ~/.julia/packages/Zygote/tFaxC/src/lib/lib.jl:206 [inlined]
 [10] (::Zygote.var"#2066#back#210"{Zygote.var"#208#209"{Tuple{Tuple{Nothing, Nothing}}, typeof((g))}})(Δ::Float64)
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [11] Pullback
    @ ./operators.jl:1035 [inlined]
 [12] (::typeof((#_#95)))(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/tFaxC/src/compiler/interface2.jl:0
 [13] (::Zygote.var"#208#209"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing, Nothing}}, typeof((#_#95))})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/tFaxC/src/lib/lib.jl:206
 [14] #2066#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [15] Pullback
    @ ./operators.jl:1033 [inlined]
 [16] (::typeof((ComposedFunction{typeof(Zygote._jvec), typeof(g)}(Zygote._jvec, g))))(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/tFaxC/src/compiler/interface2.jl:0
 [17] (::Zygote.var"#60#61"{typeof((ComposedFunction{typeof(Zygote._jvec), typeof(g)}(Zygote._jvec, g)))})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/tFaxC/src/compiler/interface.jl:45
 [18] withjacobian(::Function, ::Vector{Float64}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/tFaxC/src/lib/grad.jl:150
 [19] jacobian(::Function, ::Vector{Float64}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/tFaxC/src/lib/grad.jl:128
 [20] top-level scope
    @ REPL[77]:1

A quick fix is to extend cross so that it accepts a Tangent:

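using ChainRulesCore: Tangent   # requires ChainRulesCore in the active environment

# Workaround (type piracy): let LinearAlgebra.cross accept a Tuple-backed Tangent
LinearAlgebra.cross(a::Tangent, b::AbstractVector) = -LinearAlgebra.cross(b, a)
LinearAlgebra.cross(a::AbstractVector, b::Tangent) = [a[2]*b[3]-a[3]*b[2], a[3]*b[1]-a[1]*b[3], a[1]*b[2]-a[2]*b[1]]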
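The second method of the quick fix relies only on indexing b[1], b[2] and b[3], which the report suggests works for the Tuple-backed Tangent. A minimal sketch of the same idea as an ordinary function, where cross3 is a hypothetical helper (not part of LinearAlgebra or ChainRules) that accepts any indexable three-element container, such as a Vector or a Tuple, and always returns a Vector:

```julia
using LinearAlgebra

# Hypothetical helper (not a LinearAlgebra or ChainRules API): cross product of
# two indexable length-3 containers, e.g. a Vector and a Tuple
function cross3(a, b)
    (length(a) == 3 && length(b) == 3) ||
        throw(DimensionMismatch("cross3 requires two length-3 inputs"))
    # Same component formulas as the quick fix above
    return [a[2]*b[3] - a[3]*b[2],
            a[3]*b[1] - a[1]*b[3],
            a[1]*b[2] - a[2]*b[1]]
end

a = rand(3)
Δ = (0.1, -0.2, 0.3)                  # a Tuple, like the one inside the Tangent
cross3(a, Δ) ≈ cross(a, collect(Δ))   # true
```

Converting the Tuple with collect before calling cross gives the same result, which is why the two approaches agree.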

(a) Is this fix needed, or am I doing something incorrectly? (b) If a fix is needed, should it go in the ChainRules package? Does the Tangent need to be projected onto the subspace of the AbstractVector?

@mcabbott
Member

This is the rule for cross that is being used:

function rrule(::typeof(cross), a::AbstractVector{<:Real}, b::AbstractVector{<:Real})
    project_a = ProjectTo(a)
    project_b = ProjectTo(b)
    Ω = cross(a, b)
    function cross_pullback(Ω̄)
        ΔΩ = unthunk(Ω̄)
        da = @thunk(project_a(cross(b, ΔΩ)))
        db = @thunk(project_b(cross(ΔΩ, a)))
        return (NoTangent(), da, db)
    end
    return Ω, cross_pullback
end
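The two thunks in this rule follow from the scalar triple product: for Ω = a × b and an output cotangent ΔΩ, ΔΩ ⋅ (a × b) = a ⋅ (b × ΔΩ) = b ⋅ (ΔΩ × a), so the cotangents are b × ΔΩ for a and ΔΩ × a for b. A short sketch that checks this numerically with only LinearAlgebra, using central finite differences; fd_grad and pairing are hypothetical helpers made up for this illustration:

```julia
using LinearAlgebra

# Hypothetical helper: central finite-difference gradient of a scalar function f at x
function fd_grad(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zero(x)
        e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2h)
    end
    return g
end

a, b, ΔΩ = rand(3), rand(3), rand(3)

# Scalar whose gradients are the pullback outputs: pairing(a, b) = ΔΩ ⋅ (a × b)
pairing(a, b) = dot(ΔΩ, cross(a, b))

da = cross(b, ΔΩ)   # the rule's cotangent for a (before projection)
db = cross(ΔΩ, a)   # the rule's cotangent for b (before projection)

isapprox(fd_grad(x -> pairing(x, b), a), da; atol = 1e-6)
isapprox(fd_grad(x -> pairing(a, x), b), db; atol = 1e-6)
```

Because pairing is linear in each argument, the central difference is exact up to rounding, so a loose absolute tolerance is enough.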

Things like jacobian(cross, rand(3), rand(3))[1] work fine.
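Since cross is linear in each argument, its Jacobians are constant matrices built from the other argument. Writing [v]× for the skew-symmetric matrix with [v]× w = v × w, the Jacobian of a × b with respect to a is -[b]×, and with respect to b it is [a]×. A sketch of what jacobian(cross, a, b)[1] and [2] should match, using a hypothetical helper skew (not a LinearAlgebra function):

```julia
using LinearAlgebra

# Hypothetical helper: the skew-symmetric matrix [v]ₓ with [v]ₓ * w == cross(v, w)
skew(v) = [ 0     -v[3]   v[2]
            v[3]   0     -v[1]
           -v[2]   v[1]   0    ]

a, b = rand(3), rand(3)

J_a = -skew(b)   # ∂(a × b)/∂a, because a × b == -(b × a) == -skew(b) * a
J_b = skew(a)    # ∂(a × b)/∂b

# cross is linear in each argument, so each Jacobian reproduces it
J_a * a ≈ cross(a, b)   # true
J_b * b ≈ cross(a, b)   # true
```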

I think the problem is the splat (...), which is #599.

@benjaminfaber
Author

Thank you, that appears to be the problem. I can use norm instead of hypot, and that works.
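Since hypot(c...) and norm(c) both return the Euclidean length of a 3-vector, g keeps the same value when written with norm, and no splat is involved. A sketch, using only LinearAlgebra, of the gradient that Zygote should then produce: with c = a × b and u = c / ‖c‖, the cross rule above gives b × u for a and u × a for b. The helper g_grads and the finite-difference check are illustrative assumptions, not part of Zygote:

```julia
using LinearAlgebra

g(a, b) = norm(cross(a, b))   # same value as hypot(cross(a, b)...), without the splat

# Hypothetical helper: analytic gradients of g, i.e. the cross pullback with ΔΩ = c / ‖c‖
function g_grads(a, b)
    c = cross(a, b)
    u = c / norm(c)   # NaN when a and b are parallel (c == 0), where g is not differentiable
    return cross(b, u), cross(u, a)
end

a, b = rand(3), rand(3)
ga, gb = g_grads(a, b)

# Central-difference check of ∂g/∂a₁
δ = 1e-6
e1 = [δ, 0.0, 0.0]
isapprox((g(a + e1, b) - g(a - e1, b)) / (2δ), ga[1]; atol = 1e-6)
```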
