Skip to content

Commit

Permalink
tag doubles
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich committed Sep 17, 2024
1 parent 6ba1672 commit 72b20b1
Showing 1 changed file with 31 additions and 3 deletions.
34 changes: 31 additions & 3 deletions src/device/intrinsics/output.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ const __METAL_OS_LOG_TYPE_DEFAULT__ = Int32(0)
const __METAL_OS_LOG_TYPE_ERROR__ = Int32(16)
const __METAL_OS_LOG_TYPE_FAULT__ = Int32(17)

const ALLOW_DOUBLE_META = "allowdouble"

export @mtlprintf

@generated function promote_c_argument(arg)
Expand All @@ -18,13 +20,39 @@ export @mtlprintf

if arg == Cchar || arg == Cshort
return :(Cint(arg))
elseif arg == Cfloat
return :(Cdouble(arg))
else
return :(arg)
end
end

@generated function tag_doubles(arg)
@dispose ctx=Context() begin
ret = arg == Cfloat ? Cdouble : arg
T_arg = convert(LLVMType, arg)
T_ret = convert(LLVMType, ret)

f, ft = create_function(T_ret, [T_arg])

@dispose builder=IRBuilder() begin
entry = BasicBlock(f, "entry")
position!(builder, entry)

p1 = parameters(f)[1]

if arg == Cfloat
res = fpext!(builder, p1, LLVM.DoubleType())
metadata(res)["ir_check_ignore"] = MDNode([])
ret!(builder, res)
else
ret!(builder, p1)
end
end

call_function(f, ret, Tuple{arg}, :arg)
end
end


"""
@mtlprintf("%Fmt", args...)
Expand All @@ -33,7 +61,7 @@ Print a formatted string in device context on the host standard output.
macro mtlprintf(fmt::String, args...)
fmt_val = Val(Symbol(fmt))

return :(_mtlprintf($fmt_val, $(map(arg -> :(promote_c_argument($arg)), esc.(args))...)))
return :(_mtlprintf($fmt_val, $(map(arg -> :(tag_doubles(promote_c_argument($arg))), esc.(args))...)))
end

@generated function _mtlprintf(::Val{fmt}, argspec...) where {fmt}
Expand Down

0 comments on commit 72b20b1

Please sign in to comment.