Skip to content

Commit

Permalink
Add more elementary functions (#266)
Browse files Browse the repository at this point in the history
* rename logexp to elementary_functions

* add sqrt method

* fix sqrt method

* add cbrt method

* add power methods

* add tests for sqrt

* fix typos

* comment out some `cbrt` methods that requires `cbrt(::Quaternion)`

* add tests for cbrt

* add tests for power(`^`)

* add tests for general dimensions

* fix methods for general dimensions
  • Loading branch information
hyrodium authored Sep 21, 2023
1 parent 4240829 commit 75c98e3
Show file tree
Hide file tree
Showing 8 changed files with 272 additions and 94 deletions.
2 changes: 1 addition & 1 deletion src/Rotations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ include("rodrigues_params.jl")
include("error_maps.jl")
include("rotation_error.jl")
include("rotation_generator.jl")
include("logexp.jl")
include("elementary_functions.jl")
include("eigen.jl")
include("rand.jl")
include("rotation_between.jl")
Expand Down
5 changes: 0 additions & 5 deletions src/angleaxis_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,6 @@ end
@inline Base.:*(aa1::AngleAxis, aa2::AngleAxis) = QuatRotation(aa1) * QuatRotation(aa2)

@inline Base.inv(aa::AngleAxis) = AngleAxis(-aa.theta, aa.axis_x, aa.axis_y, aa.axis_z)
@inline Base.:^(aa::AngleAxis, t::Real) = AngleAxis(aa.theta*t, aa.axis_x, aa.axis_y, aa.axis_z)
@inline Base.:^(aa::AngleAxis, t::Integer) = AngleAxis(aa.theta*t, aa.axis_x, aa.axis_y, aa.axis_z) # to avoid ambiguity


# define identity rotations for convenience
@inline Base.one(::Type{AngleAxis}) = AngleAxis(0.0, 1.0, 0.0, 0.0)
Expand Down Expand Up @@ -204,8 +201,6 @@ end
@inline Base.:*(rv1::RotationVec, rv2::RotationVec) = QuatRotation(rv1) * QuatRotation(rv2)

@inline Base.inv(rv::RotationVec) = RotationVec(-rv.sx, -rv.sy, -rv.sz)
@inline Base.:^(rv::RotationVec, t::Real) = RotationVec(rv.sx*t, rv.sy*t, rv.sz*t)
@inline Base.:^(rv::RotationVec, t::Integer) = RotationVec(rv.sx*t, rv.sy*t, rv.sz*t) # to avoid ambiguity

# rotation properties
@inline rotation_angle(rv::RotationVec) = sqrt(rv.sx * rv.sx + rv.sy * rv.sy + rv.sz * rv.sz)
Expand Down
2 changes: 0 additions & 2 deletions src/core_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,6 @@ end
end

Base.:*(r1::Angle2d, r2::Angle2d) = Angle2d(r1.theta + r2.theta)
Base.:^(r::Angle2d, t::Real) = Angle2d(r.theta*t)
Base.:^(r::Angle2d, t::Integer) = Angle2d(r.theta*t)
Base.inv(r::Angle2d) = Angle2d(-r.theta)

@inline function Base.getindex(r::Angle2d, i::Int)
Expand Down
117 changes: 117 additions & 0 deletions src/elementary_functions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
## log
# 2d
Base.log(R::Angle2d) = Angle2dGenerator(R.theta)
Base.log(R::RotMatrix{2}) = RotMatrixGenerator(log(Angle2d(R)))
#= We can define log for Rotation{2} like this,
but the subtypes of Rotation{2} are only Angle2d and RotMatrix{2},
so we don't need this defnition. =#
# Base.log(R::Rotation{2}) = log(Angle2d(R))

# 3d
Base.log(R::RotationVec) = RotationVecGenerator(R.sx,R.sy,R.sz)
Base.log(R::RotMatrix{3}) = RotMatrixGenerator(log(RotationVec(R)))
Base.log(R::Rotation{3}) = log(RotationVec(R))

# General dimensions
function Base.log(R::RotMatrix{N}) where N
# This will be faster when log(::SMatrix) is implemented in StaticArrays.jl
@static if VERSION < v"1.7"
# This if block is related to this PR.
# https://github.com/JuliaLang/julia/pull/40573
S = SMatrix{N,N}(real(log(Matrix(R))))
else
S = SMatrix{N,N}(log(Matrix(R)))
end
RotMatrixGenerator((S-S')/2)
end

## exp
# 2d
Base.exp(R::Angle2dGenerator) = Angle2d(R.v)
Base.exp(R::RotMatrixGenerator{2}) = RotMatrix(exp(Angle2dGenerator(R)))
# Same as log(R::Rotation{2}), this definition is not necessary until someone add a new subtype of RotationGenerator{2}.
# Base.exp(R::RotationGenerator{2}) = exp(Angle2dGenerator(R))

# 3d
Base.exp(R::RotationVecGenerator) = RotationVec(R.x,R.y,R.z)
Base.exp(R::RotMatrixGenerator{3}) = RotMatrix(exp(RotationVecGenerator(R)))
# Same as log(R::Rotation{2}), this definition is not necessary until someone add a new subtype of RotationGenerator{3}.
# Base.exp(R::RotationGenerator{3}) = exp(RotationVecGenerator(R))

# General dimensions
Base.exp(R::RotMatrixGenerator{N}) where N = RotMatrix(exp(SMatrix(R)))

## sqrt
# 2d
Base.sqrt(r::Angle2d) = Angle2d(r.theta/2)
Base.sqrt(r::RotMatrix{2}) = RotMatrix(sqrt(Angle2d(r)))

# 3d
Base.sqrt(r::RotX) = RotX(r.theta/2)
Base.sqrt(r::RotY) = RotY(r.theta/2)
Base.sqrt(r::RotZ) = RotZ(r.theta/2)
Base.sqrt(r::AngleAxis) = AngleAxis(r.theta/2, r.axis_x, r.axis_y, r.axis_z)
Base.sqrt(r::RotationVec) = RotationVec(r.sx/2, r.sy/2, r.sz/2)
Base.sqrt(r::QuatRotation) = QuatRotation(sqrt(r.q))
Base.sqrt(r::RotMatrix{3}) = RotMatrix{3}(sqrt(QuatRotation(r)))
Base.sqrt(r::RodriguesParam) = RodriguesParam(sqrt(QuatRotation(r)))
Base.sqrt(r::MRP) = MRP(sqrt(QuatRotation(r)))
Base.sqrt(r::Rotation{3}) = sqrt(QuatRotation(r))

# General dimensions
Base.sqrt(r::Rotation{N}) where N = RotMatrix(sqrt(SMatrix(r)))

## cbrt
# 2d
Base.cbrt(r::Angle2d) = Angle2d(r.theta/3)
Base.cbrt(r::RotMatrix{2}) = RotMatrix(cbrt(Angle2d(r)))

# 3d
Base.cbrt(r::RotX) = RotX(r.theta/3)
Base.cbrt(r::RotY) = RotY(r.theta/3)
Base.cbrt(r::RotZ) = RotZ(r.theta/3)
Base.cbrt(r::AngleAxis) = AngleAxis(r.theta/3, r.axis_x, r.axis_y, r.axis_z)
Base.cbrt(r::RotationVec) = RotationVec(r.sx/3, r.sy/3, r.sz/3)
# We can implement these `cbrt` methods when https://github.com/JuliaLang/julia/issues/36534 is resolved.
# Base.cbrt(r::QuatRotation) = QuatRotation(cbrt(r.q))
# Base.cbrt(r::RotMatrix{3}) = RotMatrix{3}(cbrt(QuatRotation(r)))
# Base.cbrt(r::RodriguesParam) = RodriguesParam(cbrt(QuatRotation(r)))
# Base.cbrt(r::MRP) = MRP(cbrt(QuatRotation(r)))
# Base.cbrt(r::Rotation{3}) = cbrt(QuatRotation(r))

# General dimensions
Base.cbrt(r::Rotation{N}) where N = exp(log(r)/3)

## power
# 2d
Base.:^(r::Angle2d, p::Real) = Angle2d(r.theta*p)
Base.:^(r::Angle2d, p::Integer) = Angle2d(r.theta*p)
Base.:^(r::RotMatrix{2}, p::Real) = RotMatrix(Angle2d(r)^p)
Base.:^(r::RotMatrix{2}, p::Integer) = RotMatrix(Angle2d(r)^p)

# 3d
Base.:^(r::RotX, p::Real) = RotX(r.theta*p)
Base.:^(r::RotX, p::Integer) = RotX(r.theta*p)
Base.:^(r::RotY, p::Real) = RotY(r.theta*p)
Base.:^(r::RotY, p::Integer) = RotY(r.theta*p)
Base.:^(r::RotZ, p::Real) = RotZ(r.theta*p)
Base.:^(r::RotZ, p::Integer) = RotZ(r.theta*p)
Base.:^(r::AngleAxis, p::Real) = AngleAxis(r.theta*p, r.axis_x, r.axis_y, r.axis_z)
Base.:^(r::AngleAxis, p::Integer) = AngleAxis(r.theta*p, r.axis_x, r.axis_y, r.axis_z)
Base.:^(r::RotationVec, p::Real) = RotationVec(r.sx*p, r.sy*p, r.sz*p)
Base.:^(r::RotationVec, p::Integer) = RotationVec(r.sx*p, r.sy*p, r.sz*p)
Base.:^(r::QuatRotation, p::Real) = QuatRotation((r.q)^p)
Base.:^(r::QuatRotation, p::Integer) = QuatRotation((r.q)^p)
Base.:^(r::RotMatrix{3}, p::Real) = RotMatrix{3}(QuatRotation(r)^p)
Base.:^(r::RotMatrix{3}, p::Integer) = RotMatrix{3}(QuatRotation(r)^p)
Base.:^(r::RodriguesParam, p::Real) = RodriguesParam(QuatRotation(r)^p)
Base.:^(r::RodriguesParam, p::Integer) = RodriguesParam(QuatRotation(r)^p)
Base.:^(r::MRP, p::Real) = MRP(QuatRotation(r)^p)
Base.:^(r::MRP, p::Integer) = MRP(QuatRotation(r)^p)
Base.:^(r::Rotation{3}, p::Real) = QuatRotation(r)^p
Base.:^(r::Rotation{3}, p::Integer) = QuatRotation(r)^p

# General dimensions
Base.:^(r::Rotation{N}, p::Real) where N = exp(log(r)*p)
# There is the same implementation of ^(r::Rotation{N}, p::Integer) as in Base
Base.:^(A::Rotation{N}, p::Integer) where N = p < 0 ? Base.power_by_squaring(inv(A), -p) : Base.power_by_squaring(A, p)
42 changes: 0 additions & 42 deletions src/logexp.jl

This file was deleted.

153 changes: 153 additions & 0 deletions test/elementary_functions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
@testset "log" begin
all_types = (RotMatrix3, RotMatrix{3}, AngleAxis, RotationVec,
QuatRotation, RodriguesParam, MRP,
RotXYZ, RotYZX, RotZXY, RotXZY, RotYXZ, RotZYX,
RotXYX, RotYZY, RotZXZ, RotXZX, RotYXY, RotZYZ,
RotX, RotY, RotZ,
RotXY, RotYZ, RotZX, RotXZ, RotYX, RotZY,
RotMatrix2, RotMatrix{2}, Angle2d)

@testset "$(T)" for T in all_types, F in (one, rand)
R = F(T)
@test R exp(log(R))
@test log(R) isa RotationGenerator
@test exp(log(R)) isa Rotation
end

@testset "$(N)-dim" for N in 1:5
M = @SMatrix rand(N,N)
R = nearest_rotation(M)
@test isrotationgenerator(log(R))
@test log(R) isa RotMatrixGenerator
@test exp(log(R)) isa RotMatrix
end
end

@testset "exp(zero)" begin
all_types = (RotMatrixGenerator{3}, RotationVecGenerator,
RotMatrixGenerator{2}, Angle2dGenerator)

@testset "$(T)" for T in all_types
r = zero(T)
@test one(exp(r)) exp(r)
@test exp(r) isa Rotation
end
end

@testset "exp(::RotMatrixGenerator)" begin
for N in 2:3
r = zero(RotMatrixGenerator{N})
@test r isa RotMatrixGenerator{N}
@test exp(r) isa RotMatrix{N}
end
end

@testset "sqrt" begin
all_types = (
RotMatrix3, RotMatrix{3}, AngleAxis, RotationVec,
QuatRotation, RodriguesParam, MRP,
RotXYZ, RotYZX, RotZXY, RotXZY, RotYXZ, RotZYX,
RotXYX, RotYZY, RotZXZ, RotXZX, RotYXY, RotZYZ,
RotX, RotY, RotZ,
RotXY, RotYZ, RotZX, RotXZ, RotYX, RotZY,
RotMatrix2, RotMatrix{2}, Angle2d
)

compat_types = (
RotMatrix3, RotMatrix{3}, AngleAxis, RotationVec,
QuatRotation, RodriguesParam, MRP,
RotX, RotY, RotZ,
RotMatrix2, RotMatrix{2}, Angle2d
)

@testset "$(T)" for T in all_types, F in (one, rand)
R = F(T)
@test R sqrt(R) * sqrt(R)
@test sqrt(R) isa Rotation
end

@testset "$(T)-compat" for T in compat_types
R = one(T)
@test sqrt(R) isa T
end

@testset "$(T)-noncompat3d" for T in setdiff(all_types, compat_types)
R = one(T)
@test sqrt(R) isa QuatRotation
end

@testset "$(N)-dim" for N in 1:5
M = @SMatrix rand(N,N)
R = nearest_rotation(M)
@test R sqrt(R) * sqrt(R)
@test sqrt(R) isa RotMatrix{N}
end
end

@testset "cbrt" begin
supported_types = (
AngleAxis, RotationVec,
RotX, RotY, RotZ,
RotMatrix2, RotMatrix{2}, Angle2d
)

@testset "$(T)" for T in supported_types, F in (one, rand)
R = F(T)
@test R cbrt(R) * cbrt(R) * cbrt(R)
@test cbrt(R) isa Rotation
end

@testset "$(N)-dim" for N in 1:5
M = @SMatrix rand(N,N)
R = nearest_rotation(M)
@test R cbrt(R) * cbrt(R) * cbrt(R)
@test cbrt(R) isa RotMatrix{N}
end
end

@testset "power" begin
all_types = (
RotMatrix3, RotMatrix{3}, AngleAxis, RotationVec,
QuatRotation, RodriguesParam, MRP,
RotXYZ, RotYZX, RotZXY, RotXZY, RotYXZ, RotZYX,
RotXYX, RotYZY, RotZXZ, RotXZX, RotYXY, RotZYZ,
RotX, RotY, RotZ,
RotXY, RotYZ, RotZX, RotXZ, RotYX, RotZY,
RotMatrix2, RotMatrix{2}, Angle2d
)

compat_types = (
RotMatrix3, RotMatrix{3}, AngleAxis, RotationVec,
QuatRotation, RodriguesParam, MRP,
RotX, RotY, RotZ,
RotMatrix2, RotMatrix{2}, Angle2d
)

@testset "$(T)" for T in all_types, F in (one, rand)
R = F(T)
@test R^2 R * R
@test R^1.5 sqrt(R) * sqrt(R) * sqrt(R)
@test R isa Rotation
end

@testset "$(T)-compat" for T in compat_types
R = one(T)
@test R^2 isa T
@test R^1.5 isa T
end

@testset "$(T)-noncompat3d" for T in setdiff(all_types, compat_types)
R = one(T)
@test R^2 isa QuatRotation
@test R^1.5 isa QuatRotation
end

@testset "$(N)-dim" for N in 1:5
M = @SMatrix rand(N,N)
R = nearest_rotation(M)
@test R^2 R * R
@test R^1.5 sqrt(R) * sqrt(R) * sqrt(R)
@test R^2 isa RotMatrix{N}
@test R^1.5 isa RotMatrix{N}
end
end
43 changes: 0 additions & 43 deletions test/logexp.jl

This file was deleted.

Loading

0 comments on commit 75c98e3

Please sign in to comment.