Skip to content

Commit

Permalink
Add nextafter intrinsic
Browse files Browse the repository at this point in the history
  • Loading branch information
christiangnrd committed Jan 29, 2025
1 parent 1b811cb commit 2569fe1
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
7 changes: 6 additions & 1 deletion src/device/intrinsics/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,11 @@ end
@device_override Base.trunc(x::Float32) = ccall("extern air.trunc.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.trunc(x::Float16) = ccall("extern air.trunc.f16", llvmcall, Float16, (Float16,), x)

@static if Metal.is_macos(v"14")
@device_function nextafter(x::Float32, y::Float32) = ccall("extern air.nextafter.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_function nextafter(x::Float16, y::Float16) = ccall("extern air.nextafter.f16", llvmcall, Float16, (Float16, Float16), x, y)
end

# hypot without use of double
#
# taken from Cosmopolitan Libc
Expand Down Expand Up @@ -418,7 +423,7 @@ end
j = fma(1.442695f0, a, 12582912.0f0)
j = j - 12582912.0f0
i = unsafe_trunc(Int32, j)
f = fma(j, -6.93145752f-1, a) # log_2_hi
f = fma(j, -6.93145752f-1, a) # log_2_hi
f = fma(j, -1.42860677f-6, f) # log_2_lo

# approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]
Expand Down
25 changes: 21 additions & 4 deletions test/device/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,27 @@ end
vecF = Array(SpecialFunctions.erfinv.(bufferF))
@test vecF SpecialFunctions.erfinv.(f)

f = collect(LinRange(nextfloat(-88f0), 88f0, 100))
bufferF = MtlArray(f)
vecF = Array(expm1.(bufferF))
@test vecF expm1.(f)
g = collect(LinRange(nextfloat(-88f0), 88f0, 100))
bufferG = MtlArray(g)
vecG = Array(expm1.(bufferG))
@test vecG expm1.(g)

if Metal.is_macos(v"14")
function nextafter_test(X, y)
idx = thread_position_in_grid_1d()
X[idx] = Metal.nextafter(X[idx], y)
return nothing
end
h = rand(Float32,1)
bufferH = MtlArray(h)
@metal nextafter_test(bufferH,typemax(Float32))
synchronize()
@test Array(bufferH) nextfloat.(h)

@metal nextafter_test(bufferH,typemin(Float32))
synchronize()
@test Array(bufferH) h
end
end

############################################################################################
Expand Down

0 comments on commit 2569fe1

Please sign in to comment.