Skip to content

Commit

Permalink
3-arg max and min
Browse files Browse the repository at this point in the history
  • Loading branch information
christiangnrd committed Feb 3, 2025
1 parent 50b6292 commit 5c22476
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
8 changes: 8 additions & 0 deletions src/device/intrinsics/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,18 @@ using Base.Math: throw_complex_domainerror
@device_override Base.min(x::Float32, y::Float32) = ccall("extern air.fmin.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override Base.min(x::Float16, y::Float16) = ccall("extern air.fmin.f16", llvmcall, Float16, (Float16, Float16), x, y)

@device_override FastMath.min_fast(x::Float32, y::Float32, z::Float32) = ccall("extern air.fast_fmin3.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_override Base.min(x::Float32, y::Float32, z::Float32) = ccall("extern air.fmin3.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_override Base.min(x::Float16, y::Float16, z::Float16) = ccall("extern air.fmin3.f16", llvmcall, Float16, (Float16, Float16, Float16), x, y, z)

@device_override FastMath.max_fast(x::Float32, y::Float32) = ccall("extern air.fast_fmax.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override Base.max(x::Float32, y::Float32) = ccall("extern air.fmax.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override Base.max(x::Float16, y::Float16) = ccall("extern air.fmax.f16", llvmcall, Float16, (Float16, Float16), x, y)

@device_override FastMath.max_fast(x::Float32, y::Float32, z::Float32) = ccall("extern air.fast_fmax3.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_override Base.max(x::Float32, y::Float32, z::Float32) = ccall("extern air.fmax3.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_override Base.max(x::Float16, y::Float16, z::Float16) = ccall("extern air.fmax3.f16", llvmcall, Float16, (Float16, Float16, Float16), x, y, z)

@device_override FastMath.acos_fast(x::Float32) = ccall("extern air.fast_acos.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.acos(x::Float32) = ccall("extern air.acos.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.acos(x::Float16) = ccall("extern air.acos.f16", llvmcall, Float16, (Float16,), x)
Expand Down
4 changes: 2 additions & 2 deletions test/device/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ MATH_INTR_FUNCS_3_ARG = [
# mix, # T mix(T x, T y, T a) # x+(y-x)*a
# smoothstep, # T smoothstep(T edge0, T edge1, T x)
fma, # T fma(T a, T b, T c)
# max3, # T max3(T x, T y, T z)
max, # T max3(T x, T y, T z)
# median3, # T median3(T x, T y, T z)
# min3, # T min3(T x, T y, T z)
min, # T min3(T x, T y, T z)
]

@testset "math" begin
Expand Down

0 comments on commit 5c22476

Please sign in to comment.