sincos VecUnroll of scalars
chriselrod committed May 20, 2023
1 parent cbaf937 commit a64124a
Showing 2 changed files with 151 additions and 60 deletions.
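
For context, a minimal usage sketch of what this commit enables (not part of the diff; the tuple constructor for VecUnroll is assumed from VectorizationBase and may differ across versions): sincos and sincos_fast, along with the unary functions in the loop below, now also accept a VecUnroll whose elements are scalars (VecUnroll{N,1,T,T}), returning VecUnrolls of scalars.

using VectorizationBase, SLEEFPirates

# A VecUnroll holding four Float64 scalars, i.e. VecUnroll{3,1,Float64,Float64}
vu = VectorizationBase.VecUnroll((0.1, 0.2, 0.3, 0.4))

s, c = SLEEFPirates.sincos(vu)         # both results are VecUnrolls of scalars
sf, cf = SLEEFPirates.sincos_fast(vu)  # fast variant added in the same commit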
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "SLEEFPirates"
uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa"
authors = ["chriselrod <[email protected]>"]
version = "0.6.38"
version = "0.6.39"

[deps]
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
209 changes: 150 additions & 59 deletions src/SLEEFPirates.jl
@@ -1,23 +1,46 @@
module SLEEFPirates
if isdefined(Base, :Experimental) &&
isdefined(Base.Experimental, Symbol("@max_methods"))
if isdefined(Base, :Experimental) && isdefined(Base.Experimental, Symbol("@max_methods"))
@eval Base.Experimental.@max_methods 1
end
using Base: llvmcall
using Base.Math: uinttype, exponent_bias, exponent_mask, significand_bits, IEEEFloat, exponent_raw_max
using Base.Math:
uinttype, exponent_bias, exponent_mask, significand_bits, IEEEFloat, exponent_raw_max

using VectorizationBase
using Static: True, False, One, lt, StaticInt

using VectorizationBase: vzero, AbstractSIMD, _Vec, fma_fast, data, VecUnroll, NativeTypes, FloatingTypes, vIEEEFloat,
vfmadd, vfnmadd, vfmsub, vfnmsub, Double, dadd, dadd2, dsub, dsub2, dmul, dsqu, dsqrt, ddiv, drec, scale, dnormalize
using VectorizationBase:
vzero,
AbstractSIMD,
_Vec,
fma_fast,
data,
VecUnroll,
NativeTypes,
FloatingTypes,
vIEEEFloat,
vfmadd,
vfnmadd,
vfmsub,
vfnmsub,
Double,
dadd,
dadd2,
dsub,
dsub2,
dmul,
dsqu,
dsqrt,
ddiv,
drec,
scale,
dnormalize


import IfElse: ifelse

export Vec, sigmoid_fast, tanh_fast,
PReLu, gelu, softplus, silu, Elu
#, loggamma
export Vec, sigmoid_fast, tanh_fast, PReLu, gelu, softplus, silu, Elu
#, loggamma

const FloatType64 = Union{Float64,AbstractSIMD{<:Any,Float64}}
const FloatType32 = Union{Float32,AbstractSIMD{<:Any,Float32}}
@@ -29,30 +52,38 @@ const IntegerType = Union{IntegerType64,IntegerType32}
fpinttype(::Type{Float64}) = Int
fpinttype(::Type{Float32}) = Int32
function fpinttype(::Type{Vec{N,Float64}}) where {N}
Vec{N,Int}
Vec{N,Int}
end
function fpinttype(::Type{Vec{N,Float32}}) where {N}
Vec{N,Int32}
Vec{N,Int32}
end


## constants

const MLN2 = 6.931471805599453094172321214581765680755001343602552541206800094933936219696955e-01 # log(2)
const MLN2 =
6.931471805599453094172321214581765680755001343602552541206800094933936219696955e-01 # log(2)
const MLN2E = 1.442695040888963407359924681001892137426645954152985934135449406931 # log2(e)

const M_PI = 3.141592653589793238462643383279502884 # pi
const PI_2 = 1.570796326794896619231321691639751442098584699687552910487472296153908203143099 # pi/2
const PI_4 = 7.853981633974483096156608458198757210492923498437764552437361480769541015715495e-01 # pi/4
const M_PI = 3.141592653589793238462643383279502884 # pi
const PI_2 =
1.570796326794896619231321691639751442098584699687552910487472296153908203143099 # pi/2
const PI_4 =
7.853981633974483096156608458198757210492923498437764552437361480769541015715495e-01 # pi/4
const M_1_PI = 0.318309886183790671537767526745028724 # 1/pi
const M_2_PI = 0.636619772367581343075535053490057448 # 2/pi
const M_4_PI = 1.273239544735162686151070106980114896275677165923651589981338752471174381073817 # 4/pi
const M_4_PI =
1.273239544735162686151070106980114896275677165923651589981338752471174381073817 # 4/pi

const MSQRT2 = 1.414213562373095048801688724209698078569671875376948073176679737990732478462102 # sqrt(2)
const M1SQRT2 = 7.071067811865475244008443621048490392848359376884740365883398689953662392310596e-01 # 1/sqrt(2)
const MSQRT2 =
1.414213562373095048801688724209698078569671875376948073176679737990732478462102 # sqrt(2)
const M1SQRT2 =
7.071067811865475244008443621048490392848359376884740365883398689953662392310596e-01 # 1/sqrt(2)

const M2P13 = 1.259921049894873164767210607278228350570251464701507980081975112155299676513956 # 2^1/3
const M2P23 = 1.587401051968199474751705639272308260391493327899853009808285761825216505624206 # 2^2/3
const M2P13 =
1.259921049894873164767210607278228350570251464701507980081975112155299676513956 # 2^1/3
const M2P23 =
1.587401051968199474751705639272308260391493327899853009808285761825216505624206 # 2^2/3

const MLOG10_2 = 3.3219280948873623478703194294893901758648313930

@@ -62,11 +93,12 @@ const MDLN10E(::Type{Float32}) = Double(0.4342945f0, -1.010305f-8)
const MDLN2E(::Type{Float64}) = Double(1.4426950408889634, 2.0355273740931033e-17) # log2(e)
const MDLN2E(::Type{Float32}) = Double(1.442695f0, 1.925963f-8)

const MDLN2(::Type{Float64}) = Double(0.693147180559945286226764, 2.319046813846299558417771e-17) # log(2)
const MDLN2(::Type{Float64}) =
Double(0.693147180559945286226764, 2.319046813846299558417771e-17) # log(2)
const MDLN2(::Type{Float32}) = Double(0.69314718246459960938f0, -1.904654323148236017f-9)

const MDPI(::Type{Float64}) = Double(3.141592653589793, 1.2246467991473532e-16) # pi
const MDPI(::Type{Float32}) = Double(3.1415927f0, -8.742278f-8)
const MDPI(::Type{Float64}) = Double(3.141592653589793, 1.2246467991473532e-16) # pi
const MDPI(::Type{Float32}) = Double(3.1415927f0, -8.742278f-8)
const MDPI2(::Type{Float64}) = Double(1.5707963267948966, 6.123233995736766e-17) # pi/2
const MDPI2(::Type{Float32}) = Double(1.5707964f0, -4.371139f-8)

@@ -109,10 +141,10 @@ const L2U(::Type{Float32}) = 0.693145751953125f0
const L2L(::Type{Float32}) = 1.428606765330187045f-06

const TRIG_MAX(::Type{Float64}) = 1e14
const TRIG_MAX(::Type{Float32}) = 1f7
const TRIG_MAX(::Type{Float32}) = 1.0f7

const SQRT_MAX(::Type{Float64}) = 1.3407807929942596355e154
const SQRT_MAX(::Type{Float32}) = 18446743523953729536f0
const SQRT_MAX(::Type{Float32}) = 18446743523953729536.0f0

include("estrin.jl")
include("utils.jl") # utility functions
@@ -139,16 +171,19 @@ include("rectifier.jl")
for n 0:N
push!(t.args, :(VectorizationBase.extractelement(v, $n)))
end
Expr(:block, Expr(:meta,:inline), :(VecUnroll($t)))
Expr(:block, Expr(:meta, :inline), :(VecUnroll($t)))
end
@generated function to_vecunrollscalar(v::VecUnroll{M,W,T,V}, ::StaticInt{N}) where {M,W,T,N,V<:VectorizationBase.AbstractSIMDVector{W,T}}
@generated function to_vecunrollscalar(
v::VecUnroll{M,W,T,V},
::StaticInt{N},
) where {M,W,T,N,V<:VectorizationBase.AbstractSIMDVector{W,T}}
t = Expr(:tuple)
n = 0
q = Expr(:block, Expr(:meta,:inline), :(d = VectorizationBase.data(v)))
q = Expr(:block, Expr(:meta, :inline), :(d = VectorizationBase.data(v)))
dobreak = false
for m 0:M
vm = Symbol(:v_,m)
push!(q.args, :($vm = getfield(d, $(m+1))))
vm = Symbol(:v_, m)
push!(q.args, :($vm = getfield(d, $(m + 1))))
for w 0:W-1
push!(t.args, :(VectorizationBase.extractelement($vm, $w)))
dobreak = n == N
@@ -161,38 +196,76 @@
q
end

for func in (:sin, :cos, :tan, :asin, :acos, :atan, :sinh, :cosh, :tanh,
:asinh, :acosh, :atanh, :log, :log2, :log10, :log1p, :expm1, :cbrt,
:sin_fast, :cos_fast, :tan_fast, :asin_fast, :acos_fast, :atan_fast,# :atan2_fast,
:log_fast, :log2_fast, :log10_fast, :cbrt_fast)#, :exp, :exp2, :exp10
for func in (
:sin,
:cos,
:tan,
:asin,
:acos,
:atan,
:sinh,
:cosh,
:tanh,
:asinh,
:acosh,
:atanh,
:log,
:log2,
:log10,
:log1p,
:expm1,
:cbrt,
:sin_fast,
:cos_fast,
:tan_fast,
:asin_fast,
:acos_fast,
:atan_fast,# :atan2_fast,
:log_fast,
:log2_fast,
:log10_fast,
:cbrt_fast,
)#, :exp, :exp2, :exp10
@eval begin
@inline $func(a::Float16) = Float16.($func(Float32(a)))
@inline $func(x::Real) = $func(float(x))
@inline $func(v::AbstractSIMD{W,I}) where {W,I<:Integer} = $func(float(v))
@inline $func(i::MM) = $func(Vec(i))
@inline $func(v::VecUnroll{N,1,T,T}) where {N,T} = to_vecunrollscalar($func(VectorizationBase.transpose_vecunroll(v)), StaticInt{N}())
@inline $func(v::VecUnroll{N,1,T,T}) where {N,T<:NativeTypes} =
to_vecunrollscalar($func(VectorizationBase.transpose_vecunroll(v)), StaticInt{N}())
end
end
@inline function sincos(v::VecUnroll{N,1,T,T}) where {N,T<:NativeTypes}
s, c = sincos(VectorizationBase.transpose_vecunroll(v))
to_vecunrollscalar(s, StaticInt{N}()), to_vecunrollscalar(c, StaticInt{N}())
end
@inline function sincos_fast(v::VecUnroll{N,1,T,T}) where {N,T<:NativeTypes}
s, c = sincos_fast(VectorizationBase.transpose_vecunroll(v))
to_vecunrollscalar(s, StaticInt{N}()), to_vecunrollscalar(c, StaticInt{N}())
end
# Tπ(::Type{T}) where {T} = promote_type(T, typeof(π))(π)
for func (:sin, :cos)
funcpi = Symbol(func, :pi)
funcfast = Symbol(func, :_fast)
funcpifast = Symbol(func, :pi_fast)
@eval @inline $funcpi(v::AbstractSIMD{W,T}) where {W,T} = $func(vbroadcast(Val{W}(), (T) * v))
@eval @inline Base.$funcpi(v::AbstractSIMD{W,T}) where {W,T} = $func(T(π) * v)
@eval @inline $funcpifast(v::AbstractSIMD{W,T}) where {W,T} = $funcfast(T(π) * v)
@eval @inline $funcpi(i::MM) = $funcpi(float(i))
funcpi = Symbol(func, :pi)
funcfast = Symbol(func, :_fast)
funcpifast = Symbol(func, :pi_fast)
@eval @inline $funcpi(v::AbstractSIMD{W,T}) where {W,T} =
$func(vbroadcast(Val{W}(), (T) * v))
@eval @inline Base.$funcpi(v::AbstractSIMD{W,T}) where {W,T} = $func(T(π) * v)
@eval @inline $funcpifast(v::AbstractSIMD{W,T}) where {W,T} = $funcfast(T(π) * v)
@eval @inline $funcpi(i::MM) = $funcpi(float(i))
end
if VERSION v"1.6"
@inline Base.sincospi(v::AbstractSIMD{W,T}) where {W,T} = sincos(T(π) * v)
@inline Base.sincospi(v::Vec{W,T}) where {W,T} = sincos(T(π) * v)
@inline Base.sincospi(v::AbstractSIMD{W,T}) where {W,T} = sincos(T(π) * v)
@inline Base.sincospi(v::Vec{W,T}) where {W,T} = sincos(T(π) * v)
end
@inline sincospi_fast(v::AbstractSIMD{W,T}) where {W,T} = sincos_fast(T(π) * v)
@inline sincospi_fast(v::Vec{W,T}) where {W,T} = sincos_fast(T(π) * v)

for func in (:sinh, :cosh, :tanh, :asinh, :acosh, :atanh, :log1p, :expm1)#, :exp, :exp2, :exp10
@eval begin
@inline Base.$func(x::AbstractSIMD{W,T}) where {W,T<:Union{Float32,Float64,Int32,UInt32,Int64,UInt64}} = $func(x)
@inline Base.$func(
x::AbstractSIMD{W,T},
) where {W,T<:Union{Float32,Float64,Int32,UInt32,Int64,UInt64}} = $func(x)
@inline Base.$func(x::MM) = $func(Vec(x))
end
end
@@ -203,10 +276,13 @@ for func ∈ (:sin, :cos, :tan, :asin, :acos, :atan, :log, :log2, :log10, :cbrt,
@inline Base.FastMath.$func_fast(x::AbstractSIMD) = $func_fast(float(x))
end
end
@inline Base.FastMath.atan_fast(a::T, b::Number) where {T<:AbstractSIMD} = atan_fast(a, T(b))
@inline Base.FastMath.atan_fast(a::Number, b::T) where {T<:AbstractSIMD} = atan_fast(T(a), b)
@inline Base.FastMath.atan_fast(a::T, b::Number) where {T<:AbstractSIMD} =
atan_fast(a, T(b))
@inline Base.FastMath.atan_fast(a::Number, b::T) where {T<:AbstractSIMD} =
atan_fast(T(a), b)
@inline Base.FastMath.atan_fast(a::T, b::T) where {T<:AbstractSIMD} = atan_fast(a, b)
@inline Base.FastMath.atan_fast(a::AbstractSIMD, b::AbstractSIMD) = ((c,d) = promote(a,b); atan_fast(c, d))
@inline Base.FastMath.atan_fast(a::AbstractSIMD, b::AbstractSIMD) =
((c, d) = promote(a, b); atan_fast(c, d))
for func in (:atan, :hypot, :pow)
func2 = func === :pow ? :^ : func
ptyp = func === :pow ? :FloatingTypes : :NativeTypes
@@ -215,12 +291,24 @@ for func in (:atan, :hypot, :pow)
@inline $func(a::Float16, b::Float16) = Float16($func(Float32(a), Float32(b)))
# @inline Base.$func2(x::AbstractSIMD{W,T}, y::Vec{W,T}) where {W,T<:Union{Float32,Float64}} = $func(x, Vec(y))
# @inline Base.$func2(x::Vec{W,T}, y::AbstractSIMD{W,T}) where {W,T<:Union{Float32,Float64}} = $func(Vec(x), y)
@inline Base.$func2(x::AbstractSIMD{W,T}, y::T) where {W,T<:Union{Float32,Float64}} = $func(x, convert(Vec{W,T}, y))
@inline Base.$func2(x::T, y::AbstractSIMD{W,T}) where {W,T<:Union{Float32,Float64}} = $func(convert(Vec{W,T}, x), y)
@inline Base.$func2(x::AbstractSIMD{W,T1}, y::T2) where {W,T1<:Union{Float32,Float64},T2<:$ptyp} = $func(x, convert(Vec{W,T1}, y))
@inline Base.$func2(x::T2, y::AbstractSIMD{W,T1}) where {W,T1<:Union{Float32,Float64},T2<:NativeTypes} = $func(convert(Vec{W,T1}, x), y)
@inline Base.$func2(x::AbstractSIMD{W,T}, y::AbstractSIMD{W,T}) where {W,T<:Union{Float32,Float64}} = $func(x, y)
@inline $func(v1::AbstractSIMD{W,I}, v2::AbstractSIMD{W,I}) where {W,I<:Integer} = $func(float(v1), float(v2))
@inline Base.$func2(x::AbstractSIMD{W,T}, y::T) where {W,T<:Union{Float32,Float64}} =
$func(x, convert(Vec{W,T}, y))
@inline Base.$func2(x::T, y::AbstractSIMD{W,T}) where {W,T<:Union{Float32,Float64}} =
$func(convert(Vec{W,T}, x), y)
@inline Base.$func2(
x::AbstractSIMD{W,T1},
y::T2,
) where {W,T1<:Union{Float32,Float64},T2<:$ptyp} = $func(x, convert(Vec{W,T1}, y))
@inline Base.$func2(
x::T2,
y::AbstractSIMD{W,T1},
) where {W,T1<:Union{Float32,Float64},T2<:NativeTypes} = $func(convert(Vec{W,T1}, x), y)
@inline Base.$func2(
x::AbstractSIMD{W,T},
y::AbstractSIMD{W,T},
) where {W,T<:Union{Float32,Float64}} = $func(x, y)
@inline $func(v1::AbstractSIMD{W,I}, v2::AbstractSIMD{W,I}) where {W,I<:Integer} =
$func(float(v1), float(v2))
end
end
@inline ldexp(x::Float16, q::Int) = Float16(ldexpk(Float32(x), q))
@@ -231,8 +319,10 @@ end
# @inline ninvlogit(x) = Base.FastMath.inv_fast(Base.FastMath.add_fast(one(x), exp(x)))
# @inline log1m(x) = log1p(Base.FastMath.sub_fast(x))

max_tanh(::Type{Float64}) = 19.06154746539849599509609553228539867418786340504817671278462587964799037885145
max_tanh(::Type{Float32}) = 9.010913339828708369989037671244720498805572920317272822795576296065428827978905f0
max_tanh(::Type{Float64}) =
19.06154746539849599509609553228539867418786340504817671278462587964799037885145
max_tanh(::Type{Float32}) =
9.010913339828708369989037671244720498805572920317272822795576296065428827978905f0

@inline function tanh_fast(x::Union{Float32,AbstractSIMD{<:Any,Float32}})
# stolen from https://github.com/FluxML/NNlib.jl/pull/345
@@ -246,7 +336,7 @@ max_tanh(::Type{Float32}) = 9.01091333982870836998903767124472049880557292031727
d2 = muladd(d1, x2, 0.4679937f0)
n = muladd(n2, x2, 1.0f0)
d = muladd(d2, x2, 1.0f0)
ifelse(x2 < 66f0, @fastmath(x * (n / d)), sign(x))
ifelse(x2 < 66.0f0, @fastmath(x * (n / d)), sign(x))
end
@inline function tanh_fast(x::Union{Float64,AbstractSIMD{<:Any,Float64}})
exp2xm1 = expm1_fast(Base.FastMath.add_fast(x, x))
@@ -259,15 +349,16 @@ end
@inline Base.FastMath.tanh_fast(x::AbstractSIMD) = tanh_fast(x)
@inline function Base.:(^)(
x::AbstractSIMD{W,<:Base.BitInteger},
y::AbstractSIMD{W,<:Base.BitInteger}
y::AbstractSIMD{W,<:Base.BitInteger},
) where {W}
float(x) ^ y
float(x)^y
end
# sigmoid_max(::Type{Float64}) = 36.42994775023704665301938332748370611415146834112402863375388447785857586583462
# sigmoid_max(::Type{Float32}) = 17.3286794841963099036462718631317335849086302638474573162299687307067828965093f0

# @inline sigmoid_fast(x) = Base.FastMath.inv_fast(Base.FastMath.add_fast(one(x), exp(Base.FastMath.sub_fast(x))))
@inline sigmoid_fast(x) = inv(Base.FastMath.add_fast(one(x), Base.exp(Base.FastMath.sub_fast(x))))
@inline sigmoid_fast(x) =
inv(Base.FastMath.add_fast(one(x), Base.exp(Base.FastMath.sub_fast(x))))
# `inv_fast` was slower than `inv`
# @inline sigmoid_fast(x) = Base.FastMath.inv_fast(Base.FastMath.add_fast(one(x), exp(Base.FastMath.sub_fast(x))))

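
For a concrete sense of the new to_vecunrollscalar helper, here is approximately what its first method generates for a four-lane vector and StaticInt{3}() (an illustration derived from the generated-function body above, not code contained in the commit; the exact method signature is hypothetical since the diff omits it):

# to_vecunrollscalar(v, StaticInt{3}()) with v a 4-lane SIMD vector expands to roughly:
VecUnroll((
  VectorizationBase.extractelement(v, 0),
  VectorizationBase.extractelement(v, 1),
  VectorizationBase.extractelement(v, 2),
  VectorizationBase.extractelement(v, 3),
))

Each lane is pulled out as a scalar and the scalars are re-wrapped in a VecUnroll, which is how the transposed sincos results are mapped back to the caller's VecUnroll-of-scalars layout.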

2 comments on commit a64124a

@chriselrod
Member Author

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/83955

After the above pull request is merged, it is recommended that a tag be created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.6.39 -m "<description of version>" a64124aa8c1e3b4afde5183ea1937e1487243570
git push origin v0.6.39
