Skip to content

Commit

Permalink
Add faster exp2, and a naive log1m(x) = log(1-x).
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Feb 24, 2020
1 parent a568d4b commit 2f73ead
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 60 deletions.
12 changes: 4 additions & 8 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,9 @@ uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[SIMDPirates]]
deps = ["VectorizationBase"]
git-tree-sha1 = "4a33e07340324e2d3200c72231b5db829b7bad1f"
repo-rev = "master"
repo-url = "https://github.com/chriselrod/SIMDPirates.jl"
git-tree-sha1 = "ecacd3f808e559d9e363f2620041c6286c8efaca"
uuid = "21efa798-c60a-11e8-04d3-e1a92915a26a"
version = "0.3.16"
version = "0.4.0"

[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Expand All @@ -55,8 +53,6 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[VectorizationBase]]
deps = ["CpuId", "LinearAlgebra"]
git-tree-sha1 = "794a8d4ad8c817f1c7b7598b3d858891ab100722"
repo-rev = "master"
repo-url = "https://github.com/chriselrod/VectorizationBase.jl"
git-tree-sha1 = "b9b5c8fa55e9b859989e759f405624d16b0b0ca2"
uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
version = "0.4.1"
version = "0.4.2"
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
name = "SLEEFPirates"
uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa"
authors = ["chriselrod <[email protected]>"]
version = "0.3.7"
version = "0.3.8"

[deps]
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
SIMDPirates = "21efa798-c60a-11e8-04d3-e1a92915a26a"
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"

[compat]
SIMDPirates = "~0.3.15"
SIMDPirates = "~0.3.15, 0.4"
VectorizationBase = "~0.4"
julia = "1"

Expand Down
6 changes: 4 additions & 2 deletions src/SLEEFPirates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using Base.Math: uinttype, exponent_bias, exponent_mask, significand_bits, IEEEF
using SIMDPirates
using SIMDPirates: vifelse, vzero, AbstractStructVec

export SVec, loggamma, logit, invlogit, nlogit, ninvlogit
export SVec, loggamma, logit, invlogit, nlogit, ninvlogit, log1m

const FloatType64 = Union{Float64,SVec{<:Any,Float64}}
const FloatType32 = Union{Float32,SVec{<:Any,Float32}}
Expand Down Expand Up @@ -193,7 +193,9 @@ end
@inline ninvlogit(x::SIMDPirates.Vec{W,T}) where {W,T} = SIMDPirates.vfdiv( vbroadcast(Vec{W,T},one(T)), vadd(vbroadcast(Vec{W,T},one(T)), exp(x)))
@inline SIMDPirates.vexp(v::AbstractStructVec{W,Float32}) where {W} = exp(v)
@inline SIMDPirates.vlog(v::AbstractStructVec{W,Float32}) where {W} = log(v)

@inline log1m(x) = Base.log(Base.FastMath.sub_fast(one(x), x))
@inline log1m(v::Vec{W,T}) where {W,T} = log(vsub(vone(Vec{W,T}), v))
@inline log1m(v::AbstractStructVec{W,T}) where {W,T} = log(vsub(vone(Vec{W,T}), v))

include("precompile.jl")
_precompile_()
Expand Down
85 changes: 37 additions & 48 deletions src/sleef.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,59 +113,48 @@ end
@inline log2(v::SVec{8,Float64}) = SVec(log2(extract_data(v)))
@inline Base.log2(v::SVec{8,Float64}) = SVec(log2(extract_data(v)))



@inline function exp2(v::Vec{8,Float64})
Base.llvmcall(("""
declare <4 x double> @llvm.x86.avx.round.pd.256(<4 x double>, i32)
declare <8 x i32> @llvm.x86.avx512.mask.cvtpd2dq.512(<8 x double>, <8 x i32>, i8, i32)
declare <8 x i64> @llvm.x86.avx512.mask.cvtpd2qq.512(<8 x double>, <8 x i64>, i8, i32)
declare <8 x double> @llvm.fma.v8f64(<8 x double>, <8 x double>, <8 x double>)
""","""
%2 = shufflevector <8 x double> %0, <8 x double> undef, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
%3 = shufflevector <8 x double> %0, <8 x double> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
%4 = tail call <4 x double> @llvm.x86.avx.round.pd.256(<4 x double> %2, i32 8) #13
%5 = tail call <4 x double> @llvm.x86.avx.round.pd.256(<4 x double> %3, i32 8) #13
%6 = shufflevector <4 x double> %5, <4 x double> %4, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
%7 = tail call <8 x i32> @llvm.x86.avx512.mask.cvtpd2dq.512(<8 x double> %0, <8 x i32> zeroinitializer, i8 -1, i32 8) #13
%8 = fsub <8 x double> %0, %6
%9 = fmul <8 x double> %8, %8
%10 = fmul <8 x double> %9, %9
%11 = fmul <8 x double> %10, %10
%12 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %8, <8 x double> <double 0x3DFE7901CA95E150, double 0x3DFE7901CA95E150, double 0x3DFE7901CA95E150, double 0x3DFE7901CA95E150, double 0x3DFE7901CA95E150, double 0x3DFE7901CA95E150, double 0x3DFE7901CA95E150, double 0x3DFE7901CA95E150>, <8 x double> <double 0x3E3E6106D72C1C17, double 0x3E3E6106D72C1C17, double 0x3E3E6106D72C1C17, double 0x3E3E6106D72C1C17, double 0x3E3E6106D72C1C17, double 0x3E3E6106D72C1C17, double 0x3E3E6106D72C1C17, double 0x3E3E6106D72C1C17>) #13
%13 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %8, <8 x double> <double 0x3E7B5266946BF979, double 0x3E7B5266946BF979, double 0x3E7B5266946BF979, double 0x3E7B5266946BF979, double 0x3E7B5266946BF979, double 0x3E7B5266946BF979, double 0x3E7B5266946BF979, double 0x3E7B5266946BF979>, <8 x double> <double 0x3EB62BFCDABCBB81, double 0x3EB62BFCDABCBB81, double 0x3EB62BFCDABCBB81, double 0x3EB62BFCDABCBB81, double 0x3EB62BFCDABCBB81, double 0x3EB62BFCDABCBB81, double 0x3EB62BFCDABCBB81, double 0x3EB62BFCDABCBB81>) #13
%14 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %8, <8 x double> <double 0x3EEFFCBFBC12CC80, double 0x3EEFFCBFBC12CC80, double 0x3EEFFCBFBC12CC80, double 0x3EEFFCBFBC12CC80, double 0x3EEFFCBFBC12CC80, double 0x3EEFFCBFBC12CC80, double 0x3EEFFCBFBC12CC80, double 0x3EEFFCBFBC12CC80>, <8 x double> <double 0x3F24309130CB34EC, double 0x3F24309130CB34EC, double 0x3F24309130CB34EC, double 0x3F24309130CB34EC, double 0x3F24309130CB34EC, double 0x3F24309130CB34EC, double 0x3F24309130CB34EC, double 0x3F24309130CB34EC>) #13
%15 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %9, <8 x double> %13, <8 x double> %14) #13
%16 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %8, <8 x double> <double 0x3F55D87FE78C5960, double 0x3F55D87FE78C5960, double 0x3F55D87FE78C5960, double 0x3F55D87FE78C5960, double 0x3F55D87FE78C5960, double 0x3F55D87FE78C5960, double 0x3F55D87FE78C5960, double 0x3F55D87FE78C5960>, <8 x double> <double 0x3F83B2AB6FBA08F0, double 0x3F83B2AB6FBA08F0, double 0x3F83B2AB6FBA08F0, double 0x3F83B2AB6FBA08F0, double 0x3F83B2AB6FBA08F0, double 0x3F83B2AB6FBA08F0, double 0x3F83B2AB6FBA08F0, double 0x3F83B2AB6FBA08F0>) #13
%17 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %8, <8 x double> <double 0x3FAC6B08D704A01F, double 0x3FAC6B08D704A01F, double 0x3FAC6B08D704A01F, double 0x3FAC6B08D704A01F, double 0x3FAC6B08D704A01F, double 0x3FAC6B08D704A01F, double 0x3FAC6B08D704A01F, double 0x3FAC6B08D704A01F>, <8 x double> <double 0x3FCEBFBDFF82C5A1, double 0x3FCEBFBDFF82C5A1, double 0x3FCEBFBDFF82C5A1, double 0x3FCEBFBDFF82C5A1, double 0x3FCEBFBDFF82C5A1, double 0x3FCEBFBDFF82C5A1, double 0x3FCEBFBDFF82C5A1, double 0x3FCEBFBDFF82C5A1>) #13
%18 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %9, <8 x double> %16, <8 x double> %17) #13
%19 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %10, <8 x double> %15, <8 x double> %18) #13
%20 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %11, <8 x double> %12, <8 x double> %19) #13
%21 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %20, <8 x double> %8, <8 x double> <double 0x3FE62E42FEFA39EF, double 0x3FE62E42FEFA39EF, double 0x3FE62E42FEFA39EF, double 0x3FE62E42FEFA39EF, double 0x3FE62E42FEFA39EF, double 0x3FE62E42FEFA39EF, double 0x3FE62E42FEFA39EF, double 0x3FE62E42FEFA39EF>) #13
%22 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %21, <8 x double> %8, <8 x double> <double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00>) #13
%23 = ashr <8 x i32> %7, <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>
%24 = add nsw <8 x i32> %23, <i32 1023, i32 1023, i32 1023, i32 1023, i32 1023, i32 1023, i32 1023, i32 1023>
%25 = bitcast <8 x i32> %24 to <4 x i64>
%26 = shufflevector <4 x i64> %25, <4 x i64> undef, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 undef, i32 undef, i32 undef, i32 undef>
%27 = bitcast <8 x i64> %26 to <16 x i32>
%28 = shufflevector <16 x i32> %27, <16 x i32> undef, <16 x i32> <i32 undef, i32 0, i32 undef, i32 1, i32 undef, i32 2, i32 undef, i32 3, i32 undef, i32 4, i32 undef, i32 5, i32 undef, i32 6, i32 undef, i32 7>
%29 = shufflevector <16 x i32> <i32 0, i32 undef, i32 0, i32 undef, i32 0, i32 undef, i32 0, i32 undef, i32 0, i32 undef, i32 0, i32 undef, i32 0, i32 undef, i32 0, i32 undef>, <16 x i32> %28, <16 x i32> <i32 0, i32 17, i32 2, i32 19, i32 4, i32 21, i32 6, i32 23, i32 8, i32 25, i32 10, i32 27, i32 12, i32 29, i32 14, i32 31>
%30 = shl <16 x i32> %29, <i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20>
%31 = bitcast <16 x i32> %30 to <8 x double>
%32 = fmul <8 x double> %22, %31
%33 = add <8 x i32> %7, <i32 1023, i32 1023, i32 1023, i32 1023, i32 1023, i32 1023, i32 1023, i32 1023>
%34 = sub <8 x i32> %33, %23
%35 = bitcast <8 x i32> %34 to <4 x i64>
%36 = shufflevector <4 x i64> %35, <4 x i64> undef, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 undef, i32 undef, i32 undef, i32 undef>
%37 = bitcast <8 x i64> %36 to <16 x i32>
%38 = shufflevector <16 x i32> %37, <16 x i32> undef, <16 x i32> <i32 undef, i32 0, i32 undef, i32 1, i32 undef, i32 2, i32 undef, i32 3, i32 undef, i32 4, i32 undef, i32 5, i32 undef, i32 6, i32 undef, i32 7>
%39 = shufflevector <16 x i32> <i32 0, i32 undef, i32 0, i32 undef, i32 0, i32 undef, i32 0, i32 undef, i32 0, i32 undef, i32 0, i32 undef, i32 0, i32 undef, i32 0, i32 undef>, <16 x i32> %38, <16 x i32> <i32 0, i32 17, i32 2, i32 19, i32 4, i32 21, i32 6, i32 23, i32 8, i32 25, i32 10, i32 27, i32 12, i32 29, i32 14, i32 31>
%40 = shl <16 x i32> %39, <i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20>
%41 = bitcast <16 x i32> %40 to <8 x double>
%42 = fmul <8 x double> %32, %41
%43 = fcmp oge <8 x double> %0, <double 1.024000e+03, double 1.024000e+03, double 1.024000e+03, double 1.024000e+03, double 1.024000e+03, double 1.024000e+03, double 1.024000e+03, double 1.024000e+03>
%44 = select <8 x i1> %43, <8 x double> <double 0x7FF0000000000000, double 0x7FF0000000000000, double 0x7FF0000000000000, double 0x7FF0000000000000, double 0x7FF0000000000000, double 0x7FF0000000000000, double 0x7FF0000000000000, double 0x7FF0000000000000>, <8 x double> %42
%45 = fcmp olt <8 x double> %0, <double -2.000000e+03, double -2.000000e+03, double -2.000000e+03, double -2.000000e+03, double -2.000000e+03, double -2.000000e+03, double -2.000000e+03, double -2.000000e+03>
%46 = select <8 x i1> %45, <8 x double> zeroinitializer, <8 x double> %44
ret <8 x double> %46
%2 = tail call <8 x i64> @llvm.x86.avx512.mask.cvtpd2qq.512(<8 x double> %0, <8 x i64> zeroinitializer, i8 -1, i32 8) #13
%3 = sitofp <8 x i64> %2 to <8 x double>
%4 = fsub <8 x double> %0, %3
%5 = fmul <8 x double> %4, %4
%6 = fmul <8 x double> %5, %5
%7 = fmul <8 x double> %6, %6
%8 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %4, <8 x double> <double 0x3DFE7901CA95E150, double 0x3DFE7901CA95E150, double 0x3DFE7901CA95E150, double 0x3DFE7901CA95E150, double 0x3DFE7901CA95E150, double 0x3DFE7901CA95E150, double 0x3DFE7901CA95E150, double 0x3DFE7901CA95E150>, <8 x double> <double 0x3E3E6106D72C1C17, double 0x3E3E6106D72C1C17, double 0x3E3E6106D72C1C17, double 0x3E3E6106D72C1C17, double 0x3E3E6106D72C1C17, double 0x3E3E6106D72C1C17, double 0x3E3E6106D72C1C17, double 0x3E3E6106D72C1C17>) #13
%9 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %4, <8 x double> <double 0x3E7B5266946BF979, double 0x3E7B5266946BF979, double 0x3E7B5266946BF979, double 0x3E7B5266946BF979, double 0x3E7B5266946BF979, double 0x3E7B5266946BF979, double 0x3E7B5266946BF979, double 0x3E7B5266946BF979>, <8 x double> <double 0x3EB62BFCDABCBB81, double 0x3EB62BFCDABCBB81, double 0x3EB62BFCDABCBB81, double 0x3EB62BFCDABCBB81, double 0x3EB62BFCDABCBB81, double 0x3EB62BFCDABCBB81, double 0x3EB62BFCDABCBB81, double 0x3EB62BFCDABCBB81>) #13
%10 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %4, <8 x double> <double 0x3EEFFCBFBC12CC80, double 0x3EEFFCBFBC12CC80, double 0x3EEFFCBFBC12CC80, double 0x3EEFFCBFBC12CC80, double 0x3EEFFCBFBC12CC80, double 0x3EEFFCBFBC12CC80, double 0x3EEFFCBFBC12CC80, double 0x3EEFFCBFBC12CC80>, <8 x double> <double 0x3F24309130CB34EC, double 0x3F24309130CB34EC, double 0x3F24309130CB34EC, double 0x3F24309130CB34EC, double 0x3F24309130CB34EC, double 0x3F24309130CB34EC, double 0x3F24309130CB34EC, double 0x3F24309130CB34EC>) #13
%11 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %5, <8 x double> %9, <8 x double> %10) #13
%12 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %4, <8 x double> <double 0x3F55D87FE78C5960, double 0x3F55D87FE78C5960, double 0x3F55D87FE78C5960, double 0x3F55D87FE78C5960, double 0x3F55D87FE78C5960, double 0x3F55D87FE78C5960, double 0x3F55D87FE78C5960, double 0x3F55D87FE78C5960>, <8 x double> <double 0x3F83B2AB6FBA08F0, double 0x3F83B2AB6FBA08F0, double 0x3F83B2AB6FBA08F0, double 0x3F83B2AB6FBA08F0, double 0x3F83B2AB6FBA08F0, double 0x3F83B2AB6FBA08F0, double 0x3F83B2AB6FBA08F0, double 0x3F83B2AB6FBA08F0>) #13
%13 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %4, <8 x double> <double 0x3FAC6B08D704A01F, double 0x3FAC6B08D704A01F, double 0x3FAC6B08D704A01F, double 0x3FAC6B08D704A01F, double 0x3FAC6B08D704A01F, double 0x3FAC6B08D704A01F, double 0x3FAC6B08D704A01F, double 0x3FAC6B08D704A01F>, <8 x double> <double 0x3FCEBFBDFF82C5A1, double 0x3FCEBFBDFF82C5A1, double 0x3FCEBFBDFF82C5A1, double 0x3FCEBFBDFF82C5A1, double 0x3FCEBFBDFF82C5A1, double 0x3FCEBFBDFF82C5A1, double 0x3FCEBFBDFF82C5A1, double 0x3FCEBFBDFF82C5A1>) #13
%14 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %5, <8 x double> %12, <8 x double> %13) #13
%15 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %6, <8 x double> %11, <8 x double> %14) #13
%16 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %7, <8 x double> %8, <8 x double> %15) #13
%17 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %16, <8 x double> %4, <8 x double> <double 0x3FE62E42FEFA39EF, double 0x3FE62E42FEFA39EF, double 0x3FE62E42FEFA39EF, double 0x3FE62E42FEFA39EF, double 0x3FE62E42FEFA39EF, double 0x3FE62E42FEFA39EF, double 0x3FE62E42FEFA39EF, double 0x3FE62E42FEFA39EF>) #13
%18 = tail call <8 x double> @llvm.fma.v8f64(<8 x double> %17, <8 x double> %4, <8 x double> <double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00>) #13
%19 = ashr <8 x i64> %2, <i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1>
%20 = add nsw <8 x i64> %19, <i64 1023, i64 1023, i64 1023, i64 1023, i64 1023, i64 1023, i64 1023, i64 1023>
%21 = shl <8 x i64> %20, <i64 52, i64 52, i64 52, i64 52, i64 52, i64 52, i64 52, i64 52>
%22 = bitcast <8 x i64> %21 to <8 x double>
%23 = fmul <8 x double> %18, %22
%24 = add <8 x i64> %2, <i64 1023, i64 1023, i64 1023, i64 1023, i64 1023, i64 1023, i64 1023, i64 1023>
%25 = sub <8 x i64> %24, %19
%26 = shl <8 x i64> %25, <i64 52, i64 52, i64 52, i64 52, i64 52, i64 52, i64 52, i64 52>
%27 = bitcast <8 x i64> %26 to <8 x double>
%28 = fmul <8 x double> %23, %27
%29 = fcmp oge <8 x double> %0, <double 1.024000e+03, double 1.024000e+03, double 1.024000e+03, double 1.024000e+03, double 1.024000e+03, double 1.024000e+03, double 1.024000e+03, double 1.024000e+03>
%30 = select <8 x i1> %29, <8 x double> <double 0x7FF0000000000000, double 0x7FF0000000000000, double 0x7FF0000000000000, double 0x7FF0000000000000, double 0x7FF0000000000000, double 0x7FF0000000000000, double 0x7FF0000000000000, double 0x7FF0000000000000>, <8 x double> %28
%31 = fcmp olt <8 x double> %0, <double -2.000000e+03, double -2.000000e+03, double -2.000000e+03, double -2.000000e+03, double -2.000000e+03, double -2.000000e+03, double -2.000000e+03, double -2.000000e+03>
%32 = select <8 x i1> %31, <8 x double> zeroinitializer, <8 x double> %30
ret <8 x double> %32
"""), Vec{8,Float64}, Tuple{Vec{8,Float64}}, v)
end

@inline exp2(v::SVec{8,Float64}) = SVec(exp2(extract_data(v)))
@inline Base.exp2(v::SVec{8,Float64}) = SVec(exp2(extract_data(v)))

2 comments on commit 2f73ead

@chriselrod
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/9966

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.8 -m "<description of version>" 2f73ead34ef70965bc1654a10ef5664a583812d1
git push origin v0.3.8

Please sign in to comment.