Skip to content

Commit

Permalink
Add more ifelse methods
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Oct 2, 2021
1 parent eb2aa6c commit 469c4bb
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 12 deletions.
32 changes: 30 additions & 2 deletions src/SIMDDualNumbers.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
module SIMDDualNumbers

using VectorizationBase, SLEEFPirates, ForwardDiff
using VectorizationBase: AbstractSIMD
using IfElse: ifelse
using VectorizationBase: AbstractSIMD, AbstractMask
import IfElse: ifelse

@generated function Base.abs(x::ForwardDiff.Dual{TAG,S,N}) where {TAG,S<:AbstractSIMD,N}
quote
Expand Down Expand Up @@ -78,5 +78,33 @@ end
end
end

@generated function ifelse(m::AbstractMask, x::ForwardDiff.Dual{TAG,V,P}, y::ForwardDiff.Dual{TAG,V,P}) where {TAG,V,P}
quote
$(Expr(:meta,:inline))
z = $ifelse(m, ForwardDiff.value(x), ForwardDiff.value(y))
px = ForwardDiff.partials(x)
py = ForwardDiff.partials(y)
p = Base.Cartesian.@ntuple $P p -> $ifelse(m, px[p], py[p])
ForwardDiff.Dual{$TAG}(z, ForwardDiff.Partials(p))
end
end
@generated function ifelse(m::AbstractMask, x::Number, y::ForwardDiff.Dual{TAG,V,P}) where {TAG,V,P}
quote
$(Expr(:meta,:inline))
z = $ifelse(m, x, ForwardDiff.value(y))
py = ForwardDiff.partials(y)
p = Base.Cartesian.@ntuple $P p -> $ifelse(m, zero($V), py[p])
ForwardDiff.Dual{$TAG}(z, ForwardDiff.Partials(p))
end
end
@generated function ifelse(m::AbstractMask, x::ForwardDiff.Dual{TAG,V,P}, y::Number) where {TAG,V,P}
quote
$(Expr(:meta,:inline))
z = $ifelse(m, ForwardDiff.value(x), y)
px = ForwardDiff.partials(x)
p = Base.Cartesian.@ntuple $P p -> $ifelse(m, px[p], zero($V))
ForwardDiff.Dual{$TAG}(z, ForwardDiff.Partials(p))
end
end

end
26 changes: 16 additions & 10 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ function toaos(d::ForwardDiff.Dual{TAG,Vec{W,T},P}) where {TAG,W,T,P}
end
end

function test(ref, vanswer)
answer = toaos(vanswer)
for i eachindex(ref)
@test ref[i] answer[i]
end
end

@testset "SIMDDualNumbers.jl" begin

dx = ForwardDiff.Dual(
Expand All @@ -26,20 +33,19 @@ end
dxaos = toaos(dx)
dyaos = toaos(dy)
for uf [SIMDDualNumbers.tanh_fast, SIMDDualNumbers.sigmoid_fast, abs, VectorizationBase.relu]
ref = map(uf, dxaos)
answer = toaos(uf(dx))
for i eachindex(ref)
@test ref[i] answer[i]
end
test(map(uf, dxaos), uf(dx))
end
for bf [max, min]
ref = map(bf, dxaos, dyaos)
answer = toaos(bf(dx, dy))
for i eachindex(ref)
@test ref[i] answer[i]
end
test(map(bf, dxaos, dyaos), bf(dx, dy))
end

vz = Vec(ntuple(_ -> rand(), VectorizationBase.pick_vector_width(Float64))...)
tz = Tuple(vz)
cmp = dx > dy
cmpaos = dxaos .> dyaos
test(map(ifelse, cmpaos, dxaos, dyaos), SIMDDualNumbers.ifelse(cmp, dx, dy))
test(map(ifelse, cmpaos, tz, dyaos), SIMDDualNumbers.ifelse(cmp, vz, dy))
test(map(ifelse, cmpaos, dxaos, tz), SIMDDualNumbers.ifelse(cmp, dx, vz))

Aqua.test_all(SIMDDualNumbers, ambiguities=false) #TODO: test ambiguities once ForwardDiff fixes them, or once ForwardDiff is dropped
end

2 comments on commit 469c4bb

@chriselrod
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/45977

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.0 -m "<description of version>" 469c4bb462e581b38c5639e1baac6a071bc855ac
git push origin v0.1.0

Please sign in to comment.