From 75c98e33f9bf35687cc84694db6fc82fe449168b Mon Sep 17 00:00:00 2001 From: Yuto Horikawa Date: Thu, 21 Sep 2023 17:21:29 +0900 Subject: [PATCH] Add more elementary functions (#266) * rename logexp to elementary_functions * add sqrt method * fix sqrt method * add cbrt method * add power methods * add tests for sqrt * fix typos * comment out some `cbrt` methods that requires `cbrt(::Quaternion)` * add tests for cbrt * add tests for power(`^`) * add tests for general dimensions * fix methods for general dimensions --- src/Rotations.jl | 2 +- src/angleaxis_types.jl | 5 -- src/core_types.jl | 2 - src/elementary_functions.jl | 117 +++++++++++++++++++++++++++ src/logexp.jl | 42 ---------- test/elementary_functions.jl | 153 +++++++++++++++++++++++++++++++++++ test/logexp.jl | 43 ---------- test/runtests.jl | 2 +- 8 files changed, 272 insertions(+), 94 deletions(-) create mode 100644 src/elementary_functions.jl delete mode 100644 src/logexp.jl create mode 100644 test/elementary_functions.jl delete mode 100644 test/logexp.jl diff --git a/src/Rotations.jl b/src/Rotations.jl index 22cfd9bd..ff7c2f88 100644 --- a/src/Rotations.jl +++ b/src/Rotations.jl @@ -23,7 +23,7 @@ include("rodrigues_params.jl") include("error_maps.jl") include("rotation_error.jl") include("rotation_generator.jl") -include("logexp.jl") +include("elementary_functions.jl") include("eigen.jl") include("rand.jl") include("rotation_between.jl") diff --git a/src/angleaxis_types.jl b/src/angleaxis_types.jl index b4511623..3c45605a 100644 --- a/src/angleaxis_types.jl +++ b/src/angleaxis_types.jl @@ -125,9 +125,6 @@ end @inline Base.:*(aa1::AngleAxis, aa2::AngleAxis) = QuatRotation(aa1) * QuatRotation(aa2) @inline Base.inv(aa::AngleAxis) = AngleAxis(-aa.theta, aa.axis_x, aa.axis_y, aa.axis_z) -@inline Base.:^(aa::AngleAxis, t::Real) = AngleAxis(aa.theta*t, aa.axis_x, aa.axis_y, aa.axis_z) -@inline Base.:^(aa::AngleAxis, t::Integer) = AngleAxis(aa.theta*t, aa.axis_x, aa.axis_y, aa.axis_z) # to avoid ambiguity - # define identity rotations for convenience @inline Base.one(::Type{AngleAxis}) = AngleAxis(0.0, 1.0, 0.0, 0.0) @@ -204,8 +201,6 @@ end @inline Base.:*(rv1::RotationVec, rv2::RotationVec) = QuatRotation(rv1) * QuatRotation(rv2) @inline Base.inv(rv::RotationVec) = RotationVec(-rv.sx, -rv.sy, -rv.sz) -@inline Base.:^(rv::RotationVec, t::Real) = RotationVec(rv.sx*t, rv.sy*t, rv.sz*t) -@inline Base.:^(rv::RotationVec, t::Integer) = RotationVec(rv.sx*t, rv.sy*t, rv.sz*t) # to avoid ambiguity # rotation properties @inline rotation_angle(rv::RotationVec) = sqrt(rv.sx * rv.sx + rv.sy * rv.sy + rv.sz * rv.sz) diff --git a/src/core_types.jl b/src/core_types.jl index 1c5900a6..5238987c 100644 --- a/src/core_types.jl +++ b/src/core_types.jl @@ -175,8 +175,6 @@ end end Base.:*(r1::Angle2d, r2::Angle2d) = Angle2d(r1.theta + r2.theta) -Base.:^(r::Angle2d, t::Real) = Angle2d(r.theta*t) -Base.:^(r::Angle2d, t::Integer) = Angle2d(r.theta*t) Base.inv(r::Angle2d) = Angle2d(-r.theta) @inline function Base.getindex(r::Angle2d, i::Int) diff --git a/src/elementary_functions.jl b/src/elementary_functions.jl new file mode 100644 index 00000000..ff28d841 --- /dev/null +++ b/src/elementary_functions.jl @@ -0,0 +1,117 @@ +## log +# 2d +Base.log(R::Angle2d) = Angle2dGenerator(R.theta) +Base.log(R::RotMatrix{2}) = RotMatrixGenerator(log(Angle2d(R))) +#= We can define log for Rotation{2} like this, +but the subtypes of Rotation{2} are only Angle2d and RotMatrix{2}, +so we don't need this defnition. =# +# Base.log(R::Rotation{2}) = log(Angle2d(R)) + +# 3d +Base.log(R::RotationVec) = RotationVecGenerator(R.sx,R.sy,R.sz) +Base.log(R::RotMatrix{3}) = RotMatrixGenerator(log(RotationVec(R))) +Base.log(R::Rotation{3}) = log(RotationVec(R)) + +# General dimensions +function Base.log(R::RotMatrix{N}) where N + # This will be faster when log(::SMatrix) is implemented in StaticArrays.jl + @static if VERSION < v"1.7" + # This if block is related to this PR. + # https://github.com/JuliaLang/julia/pull/40573 + S = SMatrix{N,N}(real(log(Matrix(R)))) + else + S = SMatrix{N,N}(log(Matrix(R))) + end + RotMatrixGenerator((S-S')/2) +end + +## exp +# 2d +Base.exp(R::Angle2dGenerator) = Angle2d(R.v) +Base.exp(R::RotMatrixGenerator{2}) = RotMatrix(exp(Angle2dGenerator(R))) +# Same as log(R::Rotation{2}), this definition is not necessary until someone add a new subtype of RotationGenerator{2}. +# Base.exp(R::RotationGenerator{2}) = exp(Angle2dGenerator(R)) + +# 3d +Base.exp(R::RotationVecGenerator) = RotationVec(R.x,R.y,R.z) +Base.exp(R::RotMatrixGenerator{3}) = RotMatrix(exp(RotationVecGenerator(R))) +# Same as log(R::Rotation{2}), this definition is not necessary until someone add a new subtype of RotationGenerator{3}. +# Base.exp(R::RotationGenerator{3}) = exp(RotationVecGenerator(R)) + +# General dimensions +Base.exp(R::RotMatrixGenerator{N}) where N = RotMatrix(exp(SMatrix(R))) + +## sqrt +# 2d +Base.sqrt(r::Angle2d) = Angle2d(r.theta/2) +Base.sqrt(r::RotMatrix{2}) = RotMatrix(sqrt(Angle2d(r))) + +# 3d +Base.sqrt(r::RotX) = RotX(r.theta/2) +Base.sqrt(r::RotY) = RotY(r.theta/2) +Base.sqrt(r::RotZ) = RotZ(r.theta/2) +Base.sqrt(r::AngleAxis) = AngleAxis(r.theta/2, r.axis_x, r.axis_y, r.axis_z) +Base.sqrt(r::RotationVec) = RotationVec(r.sx/2, r.sy/2, r.sz/2) +Base.sqrt(r::QuatRotation) = QuatRotation(sqrt(r.q)) +Base.sqrt(r::RotMatrix{3}) = RotMatrix{3}(sqrt(QuatRotation(r))) +Base.sqrt(r::RodriguesParam) = RodriguesParam(sqrt(QuatRotation(r))) +Base.sqrt(r::MRP) = MRP(sqrt(QuatRotation(r))) +Base.sqrt(r::Rotation{3}) = sqrt(QuatRotation(r)) + +# General dimensions +Base.sqrt(r::Rotation{N}) where N = RotMatrix(sqrt(SMatrix(r))) + +## cbrt +# 2d +Base.cbrt(r::Angle2d) = Angle2d(r.theta/3) +Base.cbrt(r::RotMatrix{2}) = RotMatrix(cbrt(Angle2d(r))) + +# 3d +Base.cbrt(r::RotX) = RotX(r.theta/3) +Base.cbrt(r::RotY) = RotY(r.theta/3) +Base.cbrt(r::RotZ) = RotZ(r.theta/3) +Base.cbrt(r::AngleAxis) = AngleAxis(r.theta/3, r.axis_x, r.axis_y, r.axis_z) +Base.cbrt(r::RotationVec) = RotationVec(r.sx/3, r.sy/3, r.sz/3) +# We can implement these `cbrt` methods when https://github.com/JuliaLang/julia/issues/36534 is resolved. +# Base.cbrt(r::QuatRotation) = QuatRotation(cbrt(r.q)) +# Base.cbrt(r::RotMatrix{3}) = RotMatrix{3}(cbrt(QuatRotation(r))) +# Base.cbrt(r::RodriguesParam) = RodriguesParam(cbrt(QuatRotation(r))) +# Base.cbrt(r::MRP) = MRP(cbrt(QuatRotation(r))) +# Base.cbrt(r::Rotation{3}) = cbrt(QuatRotation(r)) + +# General dimensions +Base.cbrt(r::Rotation{N}) where N = exp(log(r)/3) + +## power +# 2d +Base.:^(r::Angle2d, p::Real) = Angle2d(r.theta*p) +Base.:^(r::Angle2d, p::Integer) = Angle2d(r.theta*p) +Base.:^(r::RotMatrix{2}, p::Real) = RotMatrix(Angle2d(r)^p) +Base.:^(r::RotMatrix{2}, p::Integer) = RotMatrix(Angle2d(r)^p) + +# 3d +Base.:^(r::RotX, p::Real) = RotX(r.theta*p) +Base.:^(r::RotX, p::Integer) = RotX(r.theta*p) +Base.:^(r::RotY, p::Real) = RotY(r.theta*p) +Base.:^(r::RotY, p::Integer) = RotY(r.theta*p) +Base.:^(r::RotZ, p::Real) = RotZ(r.theta*p) +Base.:^(r::RotZ, p::Integer) = RotZ(r.theta*p) +Base.:^(r::AngleAxis, p::Real) = AngleAxis(r.theta*p, r.axis_x, r.axis_y, r.axis_z) +Base.:^(r::AngleAxis, p::Integer) = AngleAxis(r.theta*p, r.axis_x, r.axis_y, r.axis_z) +Base.:^(r::RotationVec, p::Real) = RotationVec(r.sx*p, r.sy*p, r.sz*p) +Base.:^(r::RotationVec, p::Integer) = RotationVec(r.sx*p, r.sy*p, r.sz*p) +Base.:^(r::QuatRotation, p::Real) = QuatRotation((r.q)^p) +Base.:^(r::QuatRotation, p::Integer) = QuatRotation((r.q)^p) +Base.:^(r::RotMatrix{3}, p::Real) = RotMatrix{3}(QuatRotation(r)^p) +Base.:^(r::RotMatrix{3}, p::Integer) = RotMatrix{3}(QuatRotation(r)^p) +Base.:^(r::RodriguesParam, p::Real) = RodriguesParam(QuatRotation(r)^p) +Base.:^(r::RodriguesParam, p::Integer) = RodriguesParam(QuatRotation(r)^p) +Base.:^(r::MRP, p::Real) = MRP(QuatRotation(r)^p) +Base.:^(r::MRP, p::Integer) = MRP(QuatRotation(r)^p) +Base.:^(r::Rotation{3}, p::Real) = QuatRotation(r)^p +Base.:^(r::Rotation{3}, p::Integer) = QuatRotation(r)^p + +# General dimensions +Base.:^(r::Rotation{N}, p::Real) where N = exp(log(r)*p) +# There is the same implementation of ^(r::Rotation{N}, p::Integer) as in Base +Base.:^(A::Rotation{N}, p::Integer) where N = p < 0 ? Base.power_by_squaring(inv(A), -p) : Base.power_by_squaring(A, p) diff --git a/src/logexp.jl b/src/logexp.jl deleted file mode 100644 index f3c4b0d5..00000000 --- a/src/logexp.jl +++ /dev/null @@ -1,42 +0,0 @@ -## log -# 2d -Base.log(R::Angle2d) = Angle2dGenerator(R.theta) -Base.log(R::RotMatrix{2}) = RotMatrixGenerator(log(Angle2d(R))) -#= We can define log for Rotation{2} like this, -but the subtypes of Rotation{2} are only Angle2d and RotMatrix{2}, -so we don't need this defnition. =# -# Base.log(R::Rotation{2}) = log(Angle2d(R)) - -# 3d -Base.log(R::RotationVec) = RotationVecGenerator(R.sx,R.sy,R.sz) -Base.log(R::RotMatrix{3}) = RotMatrixGenerator(log(RotationVec(R))) -Base.log(R::Rotation{3}) = log(RotationVec(R)) - -# General dimensions -function Base.log(R::RotMatrix{N}) where N - # This will be faster when log(::SMatrix) is implemented in StaticArrays.jl - @static if VERSION < v"1.7" - # This if block is related to this PR. - # https://github.com/JuliaLang/julia/pull/40573 - S = SMatrix{N,N}(real(log(Matrix(R)))) - else - S = SMatrix{N,N}(log(Matrix(R))) - end - RotMatrixGenerator((S-S')/2) -end - -## exp -# 2d -Base.exp(R::Angle2dGenerator) = Angle2d(R.v) -Base.exp(R::RotMatrixGenerator{2}) = RotMatrix(exp(Angle2dGenerator(R))) -# Same as log(R::Rotation{2}), this definition is not necessary until someone add a new subtype of RotationGenerator{2}. -# Base.exp(R::RotationGenerator{2}) = exp(Angle2dGenerator(R)) - -# 3d -Base.exp(R::RotationVecGenerator) = RotationVec(R.x,R.y,R.z) -Base.exp(R::RotMatrixGenerator{3}) = RotMatrix(exp(RotationVecGenerator(R))) -# Same as log(R::Rotation{2}), this definition is not necessary until someone add a new subtype of RotationGenerator{3}. -# Base.exp(R::RotationGenerator{3}) = exp(RotationVecGenerator(R)) - -# General dimensions -Base.exp(R::RotMatrixGenerator{N}) where N = RotMatrix(exp(SMatrix(R))) diff --git a/test/elementary_functions.jl b/test/elementary_functions.jl new file mode 100644 index 00000000..5bcca4b3 --- /dev/null +++ b/test/elementary_functions.jl @@ -0,0 +1,153 @@ +@testset "log" begin + all_types = (RotMatrix3, RotMatrix{3}, AngleAxis, RotationVec, + QuatRotation, RodriguesParam, MRP, + RotXYZ, RotYZX, RotZXY, RotXZY, RotYXZ, RotZYX, + RotXYX, RotYZY, RotZXZ, RotXZX, RotYXY, RotZYZ, + RotX, RotY, RotZ, + RotXY, RotYZ, RotZX, RotXZ, RotYX, RotZY, + RotMatrix2, RotMatrix{2}, Angle2d) + + @testset "$(T)" for T in all_types, F in (one, rand) + R = F(T) + @test R ≈ exp(log(R)) + @test log(R) isa RotationGenerator + @test exp(log(R)) isa Rotation + end + + @testset "$(N)-dim" for N in 1:5 + M = @SMatrix rand(N,N) + R = nearest_rotation(M) + @test isrotationgenerator(log(R)) + @test log(R) isa RotMatrixGenerator + @test exp(log(R)) isa RotMatrix + end +end + +@testset "exp(zero)" begin + all_types = (RotMatrixGenerator{3}, RotationVecGenerator, + RotMatrixGenerator{2}, Angle2dGenerator) + + @testset "$(T)" for T in all_types + r = zero(T) + @test one(exp(r)) ≈ exp(r) + @test exp(r) isa Rotation + end +end + +@testset "exp(::RotMatrixGenerator)" begin + for N in 2:3 + r = zero(RotMatrixGenerator{N}) + @test r isa RotMatrixGenerator{N} + @test exp(r) isa RotMatrix{N} + end +end + +@testset "sqrt" begin + all_types = ( + RotMatrix3, RotMatrix{3}, AngleAxis, RotationVec, + QuatRotation, RodriguesParam, MRP, + RotXYZ, RotYZX, RotZXY, RotXZY, RotYXZ, RotZYX, + RotXYX, RotYZY, RotZXZ, RotXZX, RotYXY, RotZYZ, + RotX, RotY, RotZ, + RotXY, RotYZ, RotZX, RotXZ, RotYX, RotZY, + RotMatrix2, RotMatrix{2}, Angle2d + ) + + compat_types = ( + RotMatrix3, RotMatrix{3}, AngleAxis, RotationVec, + QuatRotation, RodriguesParam, MRP, + RotX, RotY, RotZ, + RotMatrix2, RotMatrix{2}, Angle2d + ) + + @testset "$(T)" for T in all_types, F in (one, rand) + R = F(T) + @test R ≈ sqrt(R) * sqrt(R) + @test sqrt(R) isa Rotation + end + + @testset "$(T)-compat" for T in compat_types + R = one(T) + @test sqrt(R) isa T + end + + @testset "$(T)-noncompat3d" for T in setdiff(all_types, compat_types) + R = one(T) + @test sqrt(R) isa QuatRotation + end + + @testset "$(N)-dim" for N in 1:5 + M = @SMatrix rand(N,N) + R = nearest_rotation(M) + @test R ≈ sqrt(R) * sqrt(R) + @test sqrt(R) isa RotMatrix{N} + end +end + +@testset "cbrt" begin + supported_types = ( + AngleAxis, RotationVec, + RotX, RotY, RotZ, + RotMatrix2, RotMatrix{2}, Angle2d + ) + + @testset "$(T)" for T in supported_types, F in (one, rand) + R = F(T) + @test R ≈ cbrt(R) * cbrt(R) * cbrt(R) + @test cbrt(R) isa Rotation + end + + @testset "$(N)-dim" for N in 1:5 + M = @SMatrix rand(N,N) + R = nearest_rotation(M) + @test R ≈ cbrt(R) * cbrt(R) * cbrt(R) + @test cbrt(R) isa RotMatrix{N} + end +end + +@testset "power" begin + all_types = ( + RotMatrix3, RotMatrix{3}, AngleAxis, RotationVec, + QuatRotation, RodriguesParam, MRP, + RotXYZ, RotYZX, RotZXY, RotXZY, RotYXZ, RotZYX, + RotXYX, RotYZY, RotZXZ, RotXZX, RotYXY, RotZYZ, + RotX, RotY, RotZ, + RotXY, RotYZ, RotZX, RotXZ, RotYX, RotZY, + RotMatrix2, RotMatrix{2}, Angle2d + ) + + compat_types = ( + RotMatrix3, RotMatrix{3}, AngleAxis, RotationVec, + QuatRotation, RodriguesParam, MRP, + RotX, RotY, RotZ, + RotMatrix2, RotMatrix{2}, Angle2d + ) + + @testset "$(T)" for T in all_types, F in (one, rand) + R = F(T) + @test R^2 ≈ R * R + @test R^1.5 ≈ sqrt(R) * sqrt(R) * sqrt(R) + @test R isa Rotation + end + + @testset "$(T)-compat" for T in compat_types + R = one(T) + @test R^2 isa T + @test R^1.5 isa T + end + + @testset "$(T)-noncompat3d" for T in setdiff(all_types, compat_types) + R = one(T) + @test R^2 isa QuatRotation + @test R^1.5 isa QuatRotation + end + + @testset "$(N)-dim" for N in 1:5 + M = @SMatrix rand(N,N) + R = nearest_rotation(M) + @test R^2 ≈ R * R + @test R^1.5 ≈ sqrt(R) * sqrt(R) * sqrt(R) + @test R^2 isa RotMatrix{N} + @test R^1.5 isa RotMatrix{N} + end +end diff --git a/test/logexp.jl b/test/logexp.jl deleted file mode 100644 index f86af96e..00000000 --- a/test/logexp.jl +++ /dev/null @@ -1,43 +0,0 @@ -@testset "log" begin - all_types = (RotMatrix3, RotMatrix{3}, AngleAxis, RotationVec, - QuatRotation, RodriguesParam, MRP, - RotXYZ, RotYZX, RotZXY, RotXZY, RotYXZ, RotZYX, - RotXYX, RotYZY, RotZXZ, RotXZX, RotYXY, RotZYZ, - RotX, RotY, RotZ, - RotXY, RotYZ, RotZX, RotXZ, RotYX, RotZY, - RotMatrix2, RotMatrix{2}, Angle2d) - - @testset "$(T)" for T in all_types, F in (one, rand) - R = F(T) - @test R ≈ exp(log(R)) - @test log(R) isa RotationGenerator - @test exp(log(R)) isa Rotation - end - - @testset "$(N)-dim" for N in 1:5 - M = @SMatrix rand(N,N) - R = nearest_rotation(M) - @test isrotationgenerator(log(R)) - @test log(R) isa RotMatrixGenerator - @test exp(log(R)) isa RotMatrix - end -end - -@testset "exp(zero)" begin - all_types = (RotMatrixGenerator{3}, RotationVecGenerator, - RotMatrixGenerator{2}, Angle2dGenerator) - - @testset "$(T)" for T in all_types - r = zero(T) - @test one(exp(r)) ≈ exp(r) - @test exp(r) isa Rotation - end -end - -@testset "exp(::RotMatrixGenerator)" begin - for N in 2:3 - r = zero(RotMatrixGenerator{N}) - @test r isa RotMatrixGenerator{N} - @test exp(r) isa RotMatrix{N} - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 5e2909b4..1cebd2ed 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,7 +23,7 @@ include("distribution_tests.jl") include("eigen.jl") include("nearest_rotation.jl") include("rotation_generator.jl") -include("logexp.jl") +include("elementary_functions.jl") include("deprecated.jl") include(joinpath(@__DIR__, "..", "perf", "runbenchmarks.jl"))