I have run into an error when trying to compute the `gradient` or `jacobian` with Zygote for a function that contains a cross product. I'm relatively new to AD, so I would like to know whether this is user error or something that needs to be fixed/added in the ChainRules package. The MWE and stacktrace are below:
```julia
using LinearAlgebra, Zygote

g(a, b) = hypot(cross(a, b)...)          # length of a × b, via splatting
h(a, b) = (c = cross(a, b); dot(c, c))   # squared length of a × b

jacobian(g, rand(3), rand(3))
```
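For checking any fix, the gradient Zygote should return for `g` can be worked out by hand. `hypot(c...)` is the Euclidean length of c = a × b, so with u = c / ‖c‖ the chain rule gives ∇ₐg = b × u and ∇_b g = u × a. The sketch below checks these formulas against central finite differences using only `LinearAlgebra`. `grad_g` and `fd_grad` are hypothetical helpers, not Zygote or ChainRules functions. Fixed inputs replace `rand(3)` so the run is reproducible.

```julia
using LinearAlgebra

# g from the MWE: Euclidean length of a × b
g(a, b) = hypot(cross(a, b)...)

# Hand-derived gradient of g (hypothetical helper).
# With c = a × b and u = c / norm(c):  grad_a g = b × u,  grad_b g = u × a
function grad_g(a, b)
    c = cross(a, b)
    u = c / norm(c)
    return cross(b, u), cross(u, a)
end

# Central finite-difference gradient of a scalar function f at x (hypothetical helper)
function fd_grad(f, x; h = 1e-6)
    map(eachindex(x)) do i
        e = zeros(length(x))
        e[i] = h
        (f(x + e) - f(x - e)) / (2h)
    end
end

a = [1.0, 2.0, 3.0]
b = [0.5, -1.0, 2.0]
ga, gb = grad_g(a, b)

@show isapprox(ga, fd_grad(x -> g(x, b), a); atol = 1e-6)
@show isapprox(gb, fd_grad(x -> g(a, x), b); atol = 1e-6)
```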
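A quick fix is to extend the cross product:

```julia
using ChainRulesCore: Tangent  # the cotangent type the methods below dispatch on
# a × b = -(b × a); both methods are type piracy, so treat them as a local workaround
LinearAlgebra.cross(a::Tangent, b::AbstractVector) = -LinearAlgebra.cross(b, a)
LinearAlgebra.cross(a::AbstractVector, b::Tangent) = [a[2]*b[3]-a[3]*b[2], a[3]*b[1]-a[1]*b[3], a[1]*b[2]-a[2]*b[1]]
```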
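The `cross` methods for `Tangent` only need to supply what the `cross` pullback already computes. For Ω = a × b and an incoming cotangent Ω̄, the scalar triple product identity gives ā = b × Ω̄ and b̄ = Ω̄ × a. Any container holding the three components of Ω̄ will therefore do once it is turned into a vector. The plain-Julia sketch below does not use Zygote or ChainRules, and `cross_pullback` is a hypothetical name, not the actual ChainRules rule. It assumes the failing cotangent is tuple-like, as a splat such as the one in `g` would plausibly produce, and converts it with `collect` before taking the cross products. This is one way to read "projecting" the `Tangent`: mapping it back to a `Vector` shaped like the primal.

```julia
using LinearAlgebra

# Hypothetical pullback for Ω = cross(a, b), written in plain Julia (not the
# actual ChainRules rule). ΔΩ may be any length-3 iterable, e.g. a Tuple like
# the cotangent a splat produces; `collect` turns it into a Vector for `cross`.
function cross_pullback(a::AbstractVector, b::AbstractVector, ΔΩ)
    Δ = collect(ΔΩ)
    Δa = cross(b, Δ)   # gradient of dot(ΔΩ, a × b) with respect to a
    Δb = cross(Δ, a)   # gradient of dot(ΔΩ, a × b) with respect to b
    return Δa, Δb
end

a  = [1.0, 2.0, 3.0]
b  = [0.5, -1.0, 2.0]
ΔΩ = (0.3, -0.7, 1.1)   # tuple-shaped cotangent

Δa, Δb = cross_pullback(a, b, ΔΩ)

# Directional check: dot(ΔΩ, δa × b) should equal dot(Δa, δa), and likewise for b
δa = [0.2, 0.4, -0.1]
δb = [-0.6, 0.1, 0.5]
@show dot(collect(ΔΩ), cross(δa, b)) ≈ dot(Δa, δa)
@show dot(collect(ΔΩ), cross(a, δb)) ≈ dot(Δb, δb)
```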
(a) Is this fix actually needed, or am I doing something wrong? (b) If a fix is needed, should it go in the ChainRules package? Also, does the `Tangent` need to be projected onto the subspace of the `AbstractVector`?