diff --git a/test/forwarddiffext.jl b/test/forwarddiffext.jl index 90b167e3..b4b905c7 100644 --- a/test/forwarddiffext.jl +++ b/test/forwarddiffext.jl @@ -1,3 +1,4 @@ +using Base: Forward using NNlib, LoopVectorization, VectorizationBase, ForwardDiff, Test randnvec() = Vec(ntuple(_ -> randn(), pick_vector_width(Float64))...) @@ -15,6 +16,20 @@ function tovec(x::ForwardDiff.Dual{T,V,N}) where {T,V,N} return ret end +if LoopVectorization.ifelse !== Base.ifelse + @inline function NNlib.leakyrelu( + x::LoopVectorization.AbstractSIMD, + a = NNlib.oftf(x, NNlib.leakyrelu_a), + ) + LoopVectorization.ifelse(x > zero(x), float(x), NNlib.oftf(x, a * x)) # max(a*x, x) is 3x slower + end + @inline function NNlib.leakyrelu( + x::ForwardDiff.Dual{<:Any,<:LoopVectorization.AbstractSIMD}, + a = NNlib.oftf(x, NNlib.leakyrelu_a), + ) + LoopVectorization.ifelse(x > zero(x), float(x), NNlib.oftf(x, a * x)) # max(a*x, x) is 3x slower + end +end vx0 = randnvec() vx1 = randnvec() diff --git a/test/gemm.jl b/test/gemm.jl index c0330b2e..73cb7955 100644 --- a/test/gemm.jl +++ b/test/gemm.jl @@ -8,7 +8,10 @@ Unum, Tnum = LoopVectorization.register_count() == 16 ? (2, 6) : (4, 6) end Unumt, Tnumt = LoopVectorization.register_count() == 16 ? (2, 6) : (5, 5) - if LoopVectorization.register_count() != 8 + if (LoopVectorization.register_count() != 8) && ( + (LoopVectorization.pick_vector_width(Float64) != 2) || + (LoopVectorization.register_count() != 16) + ) @test @inferred(LoopVectorization.matmul_params()) == (Unum, Tnum) end @@ -30,7 +33,10 @@ end ) lsAmulBt1 = LoopVectorization.loopset(AmulBtq1) - if LoopVectorization.register_count() != 8 + if (LoopVectorization.register_count() != 8) && ( + (LoopVectorization.pick_vector_width(Float64) != 2) || + (LoopVectorization.register_count() != 16) + ) @test LoopVectorization.choose_order(lsAmulBt1) == (Symbol[:n, :m, :k], :m, :n, :m, Unum, Tnum) end @@ -43,7 +49,10 @@ end ) lsAmulB1 = LoopVectorization.loopset(AmulBq1) - if LoopVectorization.register_count() != 8 + if (LoopVectorization.register_count() != 8) && ( + (LoopVectorization.pick_vector_width(Float64) != 2) || + (LoopVectorization.register_count() != 16) + ) @test LoopVectorization.choose_order(lsAmulB1) == (Symbol[:n, :m, :k], :m, :n, :m, Unum, Tnum) end @@ -56,7 +65,10 @@ end ) lsAmulB2 = LoopVectorization.loopset(AmulBq2) - if LoopVectorization.register_count() != 8 + if (LoopVectorization.register_count() != 8) && ( + (LoopVectorization.pick_vector_width(Float64) != 2) || + (LoopVectorization.register_count() != 16) + ) @test LoopVectorization.choose_order(lsAmulB2) == (Symbol[:n, :m, :k], :m, :n, :m, Unum, Tnum) end @@ -70,11 +82,12 @@ end ) lsAmulB3 = LoopVectorization.loopset(AmulBq3) - if LoopVectorization.register_count() != 8 + if (LoopVectorization.register_count() != 8) && ( + (LoopVectorization.pick_vector_width(Float64) != 2) || + (LoopVectorization.register_count() != 16) + ) @test LoopVectorization.choose_order(lsAmulB3) == (Symbol[:n, :m, :k], :m, :n, :m, Unum, Tnum) - end - if LoopVectorization.register_count() != 8 for (fA, fB, v, Un, Tn) ∈ [ (identity, identity, :m, Unum, Tnum), (adjoint, identity, :k, Unumt, Tnumt), @@ -177,7 +190,8 @@ end ) lsAmuladd = LoopVectorization.loopset(Amuladdq) - if LoopVectorization.register_count() != 8 + if LoopVectorization.register_count() != 8 && + LoopVectorization.pick_vector_width(Float64) != 2 @test LoopVectorization.choose_order(lsAmuladd) == (Symbol[:n, :m, :k], :m, :n, :m, Unum, Tnum) end @@ -410,9 +424,13 @@ @test LoopVectorization.choose_order(lsr2amb) == ([:m, :n, :k], :m, :n, :m, 3, 7) end elseif LoopVectorization.register_count() == 16 - # @test LoopVectorization.choose_order(lsr2amb) == ([:m, :n, :k], :m, :n, :m, 1, 6) - # @test LoopVectorization.choose_order(lsr2amb) == ([:m, :n, :k], :m, :n, :m, 2, 4) - @test LoopVectorization.choose_order(lsr2amb) == ([:n, :m, :k], :n, :m, :m, 3, 3) + if LoopVectorization.pick_vector_width(Float64) == 4 + # @test LoopVectorization.choose_order(lsr2amb) == ([:m, :n, :k], :m, :n, :m, 1, 6) + # @test LoopVectorization.choose_order(lsr2amb) == ([:m, :n, :k], :m, :n, :m, 2, 4) + @test LoopVectorization.choose_order(lsr2amb) == ([:n, :m, :k], :n, :m, :m, 3, 3) + elseif LoopVectorization.pick_vector_width(Float64) == 2 + @test LoopVectorization.choose_order(lsr2amb) == ([:m, :n, :k], :n, :m, :m, 3, 3) + end end function rank2AmulBavx!(C, Aₘ, Aₖ, B) @turbo for m ∈ axes(C, 1), n ∈ axes(C, 2)