From 963465242f56c208a6fce6e106835f2ad7bafe86 Mon Sep 17 00:00:00 2001 From: Zentrik Date: Wed, 29 May 2024 14:24:49 +0100 Subject: [PATCH] Fix fastmath for vararg +, *, min, max methods (#54513) Currently using the fastmath vararg +, *, min, max methods only actually sets fastmath if they are specifically overloaded even when the correct 2 argument methods have been defined. As such, `ComplexF32, ComplexF64` do not currently set fastmath when using the vararg methods. This will also fix any other types, such as those in SIMD.jl, which don't overload the vararg methods. E.g. ```julia x = ComplexF64(1) f(x) = @fastmath x + x + x ``` now works correctly. I see no reason why the vararg methods shouldn't default to using the fastmath 2 argument methods instead of the non fastmath ones, which is the current behaviour. I also switched the implementation to use `afoldl` as that's what the non fastmath vararg methods use. Fixes #54456 and https://github.com/eschnett/SIMD.jl/issues/108. --- base/fastmath.jl | 36 +++++++++++++++++++++++++++--------- test/fastmath.jl | 23 +++++++++++++++++++++++ 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/base/fastmath.jl b/base/fastmath.jl index 3f908f73d9742..824491acaf877 100644 --- a/base/fastmath.jl +++ b/base/fastmath.jl @@ -30,6 +30,7 @@ export @fastmath import Core.Intrinsics: sqrt_llvm_fast, neg_float_fast, add_float_fast, sub_float_fast, mul_float_fast, div_float_fast, eq_float_fast, ne_float_fast, lt_float_fast, le_float_fast +import Base: afoldl const fast_op = Dict(# basic arithmetic @@ -168,11 +169,6 @@ sub_fast(x::T, y::T) where {T<:FloatTypes} = sub_float_fast(x, y) mul_fast(x::T, y::T) where {T<:FloatTypes} = mul_float_fast(x, y) div_fast(x::T, y::T) where {T<:FloatTypes} = div_float_fast(x, y) -add_fast(x::T, y::T, zs::T...) where {T<:FloatTypes} = - add_fast(add_fast(x, y), zs...) -mul_fast(x::T, y::T, zs::T...) where {T<:FloatTypes} = - mul_fast(mul_fast(x, y), zs...) 
- @fastmath begin cmp_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(x==y, 0, ifelse(x x, y, x) min_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, x, y) minmax_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, (x,y), (y,x)) - - max_fast(x::T, y::T, z::T...) where {T<:FloatTypes} = max_fast(max_fast(x, y), z...) - min_fast(x::T, y::T, z::T...) where {T<:FloatTypes} = min_fast(min_fast(x, y), z...) end # fall-back implementations and type promotion @@ -260,7 +253,7 @@ for op in (:abs, :abs2, :conj, :inv, :sign) end end -for op in (:+, :-, :*, :/, :(==), :!=, :<, :<=, :cmp, :rem, :min, :max, :minmax) +for op in (:-, :/, :(==), :!=, :<, :<=, :cmp, :rem, :minmax) op_fast = fast_op[op] @eval begin # fall-back implementation for non-numeric types @@ -273,6 +266,31 @@ for op in (:+, :-, :*, :/, :(==), :!=, :<, :<=, :cmp, :rem, :min, :max, :minmax) end end +for op in (:+, :*, :min, :max) + op_fast = fast_op[op] + @eval begin + $op_fast(x) = $op(x) + # fall-back implementation for non-numeric types + $op_fast(x, y) = $op(x, y) + # type promotion + $op_fast(x::Number, y::Number) = + $op_fast(promote(x,y)...) + # fall-back implementation that applies after promotion + $op_fast(x::T,y::T) where {T<:Number} = $op(x,y) + # note: these definitions must not cause a dispatch loop when +(a,b) is + # not defined, and must only try to call 2-argument definitions, so + # that defining +(a,b) is sufficient for full functionality. + ($op_fast)(a, b, c, xs...) = (@inline; afoldl($op_fast, ($op_fast)(($op_fast)(a,b),c), xs...)) + # a further concern is that it's easy for a type like (Int,Int...) + # to match many definitions, so we need to keep the number of + # definitions down to avoid losing type information. + # type promotion + $op_fast(a::Number, b::Number, c::Number, xs::Number...) = + $op_fast(promote(a,b,c,xs...)...) + # fall-back implementation that applies after promotion + $op_fast(a::T, b::T, c::T, xs::T...) 
where {T<:Number} = (@inline; afoldl($op_fast, ($op_fast)(($op_fast)(a,b),c), xs...)) + end +end # Math functions exp2_fast(x::Union{Float32,Float64}) = Base.Math.exp2_fast(x) diff --git a/test/fastmath.jl b/test/fastmath.jl index abcecce96f031..d10b5a739483d 100644 --- a/test/fastmath.jl +++ b/test/fastmath.jl @@ -1,7 +1,30 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license +using InteractiveUtils: code_llvm # fast math +@testset "check fast present in LLVM" begin + for T in (Float16, Float32, Float64, ComplexF32, ComplexF64) + f(x) = @fastmath x + x + x + llvm = sprint(code_llvm, f, (T,)) + @test occursin("fast", llvm) + + g(x) = @fastmath x * x * x + llvm = sprint(code_llvm, g, (T,)) + @test occursin("fast", llvm) + end + + for T in (Float16, Float32, Float64) + f(x, y, z) = @fastmath min(x, y, z) + llvm = sprint(code_llvm, f, (T,T,T)) + @test occursin("fast", llvm) + + g(x, y, z) = @fastmath max(x, y, z) + llvm = sprint(code_llvm, g, (T,T,T)) + @test occursin("fast", llvm) + end +end + @testset "check expansions" begin @test macroexpand(Main, :(@fastmath 1+2)) == :(Base.FastMath.add_fast(1,2)) @test macroexpand(Main, :(@fastmath +)) == :(Base.FastMath.add_fast)