Skip to content

Commit

Permalink
Fix fastmath for vararg +, *, min, max methods (#54513)
Browse files Browse the repository at this point in the history
Currently using the fastmath vararg +, *, min, max methods only actually
sets fastmath if they are specifically overloaded even when the correct
2 argument methods have been defined.
As such, `ComplexF32, ComplexF64` do not currently set fastmath when
using the vararg methods. This will also fix any other types, such as
those in SIMD.jl, which don't overload the vararg methods.

E.g. 
```julia
x = ComplexF64(1)
f(x) = @fastmath x + x + x
```
now works correctly.

I see no reason why the vararg methods shouldn't default to using the
fastmath 2 argument methods instead of the non fastmath ones, which is
the current behaviour.

I also switched the implementation to use `afoldl` as that's what the
non fastmath vararg methods use.

Fixes #54456 and eschnett/SIMD.jl#108.
  • Loading branch information
Zentrik authored May 29, 2024
1 parent fe35189 commit 9634652
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 9 deletions.
36 changes: 27 additions & 9 deletions base/fastmath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ export @fastmath
import Core.Intrinsics: sqrt_llvm_fast, neg_float_fast,
add_float_fast, sub_float_fast, mul_float_fast, div_float_fast,
eq_float_fast, ne_float_fast, lt_float_fast, le_float_fast
import Base: afoldl

const fast_op =
Dict(# basic arithmetic
Expand Down Expand Up @@ -168,11 +169,6 @@ sub_fast(x::T, y::T) where {T<:FloatTypes} = sub_float_fast(x, y)
mul_fast(x::T, y::T) where {T<:FloatTypes} = mul_float_fast(x, y)
div_fast(x::T, y::T) where {T<:FloatTypes} = div_float_fast(x, y)

add_fast(x::T, y::T, zs::T...) where {T<:FloatTypes} =
add_fast(add_fast(x, y), zs...)
mul_fast(x::T, y::T, zs::T...) where {T<:FloatTypes} =
mul_fast(mul_fast(x, y), zs...)

@fastmath begin
cmp_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(x==y, 0, ifelse(x<y, -1, +1))
log_fast(b::T, x::T) where {T<:FloatTypes} = log_fast(x)/log_fast(b)
Expand Down Expand Up @@ -245,9 +241,6 @@ ComplexTypes = Union{ComplexF32, ComplexF64}
max_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, y, x)
min_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, x, y)
minmax_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, (x,y), (y,x))

max_fast(x::T, y::T, z::T...) where {T<:FloatTypes} = max_fast(max_fast(x, y), z...)
min_fast(x::T, y::T, z::T...) where {T<:FloatTypes} = min_fast(min_fast(x, y), z...)
end

# fall-back implementations and type promotion
Expand All @@ -260,7 +253,7 @@ for op in (:abs, :abs2, :conj, :inv, :sign)
end
end

for op in (:+, :-, :*, :/, :(==), :!=, :<, :<=, :cmp, :rem, :min, :max, :minmax)
for op in (:-, :/, :(==), :!=, :<, :<=, :cmp, :rem, :minmax)
op_fast = fast_op[op]
@eval begin
# fall-back implementation for non-numeric types
Expand All @@ -273,6 +266,31 @@ for op in (:+, :-, :*, :/, :(==), :!=, :<, :<=, :cmp, :rem, :min, :max, :minmax)
end
end

for op in (:+, :*, :min, :max)
op_fast = fast_op[op]
@eval begin
$op_fast(x) = $op(x)
# fall-back implementation for non-numeric types
$op_fast(x, y) = $op(x, y)
# type promotion
$op_fast(x::Number, y::Number) =
$op_fast(promote(x,y)...)
# fall-back implementation that applies after promotion
$op_fast(x::T,y::T) where {T<:Number} = $op(x,y)
# note: these definitions must not cause a dispatch loop when +(a,b) is
# not defined, and must only try to call 2-argument definitions, so
# that defining +(a,b) is sufficient for full functionality.
($op_fast)(a, b, c, xs...) = (@inline; afoldl($op_fast, ($op_fast)(($op_fast)(a,b),c), xs...))
# a further concern is that it's easy for a type like (Int,Int...)
# to match many definitions, so we need to keep the number of
# definitions down to avoid losing type information.
# type promotion
$op_fast(a::Number, b::Number, c::Number, xs::Number...) =
$op_fast(promote(x,y,c,xs...)...)
# fall-back implementation that applies after promotion
$op_fast(a::T, b::T, c::T, xs::T...) where {T<:Number} = (@inline; afoldl($op_fast, ($op_fast)(($op_fast)(a,b),c), xs...))
end
end

# Math functions
exp2_fast(x::Union{Float32,Float64}) = Base.Math.exp2_fast(x)
Expand Down
23 changes: 23 additions & 0 deletions test/fastmath.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,30 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

using InteractiveUtils: code_llvm
# fast math

@testset "check fast present in LLVM" begin
for T in (Float16, Float32, Float64, ComplexF32, ComplexF64)
f(x) = @fastmath x + x + x
llvm = sprint(code_llvm, f, (T,))
@test occursin("fast", llvm)

g(x) = @fastmath x * x * x
llvm = sprint(code_llvm, g, (T,))
@test occursin("fast", llvm)
end

for T in (Float16, Float32, Float64)
f(x, y, z) = @fastmath min(x, y, z)
llvm = sprint(code_llvm, f, (T,T,T))
@test occursin("fast", llvm)

g(x, y, z) = @fastmath max(x, y, z)
llvm = sprint(code_llvm, g, (T,T,T))
@test occursin("fast", llvm)
end
end

@testset "check expansions" begin
@test macroexpand(Main, :(@fastmath 1+2)) == :(Base.FastMath.add_fast(1,2))
@test macroexpand(Main, :(@fastmath +)) == :(Base.FastMath.add_fast)
Expand Down

0 comments on commit 9634652

Please sign in to comment.