diff --git a/docs/src/api/kernel.md b/docs/src/api/kernel.md index a995d03de..dd24000bb 100644 --- a/docs/src/api/kernel.md +++ b/docs/src/api/kernel.md @@ -53,4 +53,13 @@ MtlThreadGroupArray MemoryFlags threadgroup_barrier simdgroup_barrier +``` + +## Printing + +```@docs +@mtlprintf +@mtlprint +@mtlprintln +@mtlshow ``` \ No newline at end of file diff --git a/docs/src/usage/kernel.md b/docs/src/usage/kernel.md index e4e84f21f..8e910334c 100644 --- a/docs/src/usage/kernel.md +++ b/docs/src/usage/kernel.md @@ -84,6 +84,34 @@ Additional notes: - Kernels must always return nothing - Kernels are asynchronous. To synchronize, use the `Metal.@sync` macro. +## Printing + +When debugging, it's not uncommon to want to print some values. This is achieved with `@mtlprintf`: + +```julia +function gpu_add2_print!(y, x) + index = thread_position_in_grid_1d() + @mtlprintf("thread %d\n", index) + @inbounds y[index] += x[index] + return nothing +end + +A = Metal.ones(Float32, 8); +B = Metal.rand(Float32, 8); + +@metal threads=length(A) gpu_add2_print!(A, B) +``` + +`@mtlprintf` is supported on macOS 15 and later. `@mtlprintf` supports most of the format specifiers that `printf` +supports in C with the following exceptions: + - `%n` and `%s` conversion specifiers are not supported + - Default argument promotion applies to arguments of half type which promote to the `double` type + - The format string must be a string literal + +Metal places output from `@mtlprintf` into a log buffer. The system only removes the messages from the log buffer when the command buffer completes. When the log buffer becomes full, the system drops all subsequent messages. 
+ +See also: `@mtlprint`, `@mtlprintln` and `@mtlshow` + ## Other Helpful Links [Metal Shading Language Specification](https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf) diff --git a/lib/mtl/MTL.jl b/lib/mtl/MTL.jl index 74a6f0921..bfeb2ecce 100644 --- a/lib/mtl/MTL.jl +++ b/lib/mtl/MTL.jl @@ -34,6 +34,7 @@ include("events.jl") include("fences.jl") include("heap.jl") include("buffer.jl") +include("log_state.jl") include("command_queue.jl") include("command_buf.jl") include("compute_pipeline.jl") diff --git a/lib/mtl/command_queue.jl b/lib/mtl/command_queue.jl index fe51b9aba..673260c68 100644 --- a/lib/mtl/command_queue.jl +++ b/lib/mtl/command_queue.jl @@ -1,3 +1,24 @@ + +export MTLCommandQueueDescriptor + +# @objcwrapper immutable=false MTLCommandQueueDescriptor <: NSObject + +function MTLCommandQueueDescriptor() + handle = @objc [MTLCommandQueueDescriptor alloc]::id{MTLCommandQueueDescriptor} + obj = MTLCommandQueueDescriptor(handle) + finalizer(release, obj) + @objc [obj::id{MTLCommandQueueDescriptor} init]::id{MTLCommandQueueDescriptor} + return obj +end + +function MTLCommandQueue(dev::MTLDevice, descriptor::MTLCommandQueueDescriptor) + handle = @objc [dev::id{MTLDevice} newCommandQueueWithDescriptor:descriptor::id{MTLCommandQueueDescriptor}]::id{MTLCommandQueue} + obj = MTLCommandQueue(handle) + finalizer(release, obj) + return obj +end + + export MTLCommandQueue # @objcwrapper immutable=false MTLCommandQueue <: NSObject @@ -8,3 +29,4 @@ function MTLCommandQueue(dev::MTLDevice) finalizer(release, obj) return obj end + diff --git a/lib/mtl/libmtl.jl b/lib/mtl/libmtl.jl index db5b98339..21aaea707 100644 --- a/lib/mtl/libmtl.jl +++ b/lib/mtl/libmtl.jl @@ -1395,7 +1395,7 @@ end @autoproperty dispatchType::MTLDispatchType end -@objcwrapper immutable = true availability = macos(v"15.0.0") MTLCommandQueueDescriptor <: NSObject +@objcwrapper immutable = false availability = macos(v"15.0.0") MTLCommandQueueDescriptor <: NSObject 
@objcproperties MTLCommandQueueDescriptor begin @autoproperty maxCommandBufferCount::UInt64 setter = setMaxCommandBufferCount @@ -2675,7 +2675,7 @@ end MTLLogLevelFault = 5 end -@objcwrapper immutable = true availability = macos(v"15.0.0") MTLLogStateDescriptor <: NSObject +@objcwrapper immutable = false availability = macos(v"15.0.0") MTLLogStateDescriptor <: NSObject @objcproperties MTLLogStateDescriptor begin @autoproperty level::MTLLogLevel setter = setLevel diff --git a/lib/mtl/log_state.jl b/lib/mtl/log_state.jl new file mode 100644 index 000000000..2e7459f3a --- /dev/null +++ b/lib/mtl/log_state.jl @@ -0,0 +1,26 @@ +export MTLLogLevel + +export MTLLogStateDescriptor + +# @objcwrapper immutable = false MTLLogStateDescriptor <: NSObject + +function MTLLogStateDescriptor() + handle = @objc [MTLLogStateDescriptor alloc]::id{MTLLogStateDescriptor} + obj = MTLLogStateDescriptor(handle) + finalizer(release, obj) + @objc [obj::id{MTLLogStateDescriptor} init]::id{MTLLogStateDescriptor} + return obj +end + + +export MTLLogState + +# @objcwrapper immutable = true MTLLogState <: NSObject + +function MTLLogState(dev::MTLDevice, descriptor::MTLLogStateDescriptor) + err = Ref{id{NSError}}(nil) + handle = @objc [dev::id{MTLDevice} newLogStateWithDescriptor:descriptor::id{MTLLogStateDescriptor} + error:err::Ptr{id{NSError}}]::id{MTLLogState} + err[] == nil || throw(NSError(err[])) + MTLLogState(handle) +end diff --git a/res/wrap/libmtl.toml b/res/wrap/libmtl.toml index 9f1fcc331..41c70f28c 100644 --- a/res/wrap/libmtl.toml +++ b/res/wrap/libmtl.toml @@ -50,6 +50,9 @@ immutable=false [api.MTLCommandQueue] immutable=false +[api.MTLCommandQueueDescriptor] +immutable=false + [api.MTLCompileOptions] immutable=false [api.MTLCompileOptions.proptype] @@ -85,6 +88,9 @@ immutable=false [api.MTLLibrary] immutable=false +[api.MTLLogStateDescriptor] +immutable=false + [api.MTLSharedEvent] immutable=false diff --git a/src/Metal.jl b/src/Metal.jl index 5e6f53754..a988bfbd3 100644 --- 
a/src/Metal.jl +++ b/src/Metal.jl @@ -35,6 +35,7 @@ include("device/intrinsics/synchronization.jl") include("device/intrinsics/memory.jl") include("device/intrinsics/simd.jl") include("device/intrinsics/atomics.jl") +include("device/intrinsics/output.jl") include("device/quirks.jl") # array essentials diff --git a/src/compiler/compilation.jl b/src/compiler/compilation.jl index d1c3019e1..432573b54 100644 --- a/src/compiler/compilation.jl +++ b/src/compiler/compilation.jl @@ -104,9 +104,9 @@ function compile(@nospecialize(job::CompilerJob)) @signpost_interval log=log_compiler() "Generate LLVM IR" begin # TODO: on 1.9, this actually creates a context. cache those. - ir, entry = JuliaContext() do ctx + ir, entry, loggingEnabled = JuliaContext() do ctx mod, meta = GPUCompiler.compile(:llvm, job) - string(mod), LLVM.name(meta.entry) + string(mod), LLVM.name(meta.entry), haskey(functions(mod), "air.os_log") end end @@ -172,7 +172,7 @@ function compile(@nospecialize(job::CompilerJob)) end end - return (; ir, air, metallib, entry) + return (; ir, air, metallib, entry, loggingEnabled) end # link into an executable kernel @@ -210,5 +210,5 @@ end end end - pipeline_state + pipeline_state, compiled.loggingEnabled end diff --git a/src/compiler/execution.jl b/src/compiler/execution.jl index 54a089a3a..0e911e550 100644 --- a/src/compiler/execution.jl +++ b/src/compiler/execution.jl @@ -161,6 +161,7 @@ mtlconvert(arg, cce=nothing) = adapt(Adaptor(cce), arg) struct HostKernel{F,TT} f::F pipeline::MTLComputePipelineState + loggingEnabled::Bool end const mtlfunction_lock = ReentrantLock() @@ -186,7 +187,7 @@ function mtlfunction(f::F, tt::TT=Tuple{}; name=nothing, kwargs...) 
where {F,TT} cache = compiler_cache(dev) source = methodinstance(F, tt) config = compiler_config(dev; name, kwargs...)::MetalCompilerConfig - pipeline = GPUCompiler.cached_compilation(cache, source, config, compile, link) + pipeline, loggingEnabled = GPUCompiler.cached_compilation(cache, source, config, compile, link) # create a callable object that captures the function instance. we don't need to think # about world age here, as GPUCompiler already does and will return a different object @@ -194,7 +195,7 @@ function mtlfunction(f::F, tt::TT=Tuple{}; name=nothing, kwargs...) where {F,TT} kernel = get(_kernel_instances, h, nothing) if kernel === nothing # create the kernel state object - kernel = HostKernel{F,tt}(f, pipeline) + kernel = HostKernel{F, tt}(f, pipeline, loggingEnabled) _kernel_instances[h] = kernel end return kernel::HostKernel{F,tt} @@ -275,7 +276,35 @@ end (threads.width * threads.height * threads.depth) > kernel.pipeline.maxTotalThreadsPerThreadgroup && throw(ArgumentError("Number of threads in group ($(threads.width * threads.height * threads.depth)) should not exceed $(kernel.pipeline.maxTotalThreadsPerThreadgroup)")) - cmdbuf = MTLCommandBuffer(queue) + cmdbuf = if kernel.loggingEnabled + # TODO: make this a dynamic error, i.e., from the kernel (JuliaGPU/Metal.jl#433) + @static if !is_macos(v"15.0.0") + error("Logging is only supported on macOS 15 or higher") + end + + if MTLCaptureManager().isCapturing + error("Logging is not supported while GPU frame capturing") + end + + log_state_descriptor = MTLLogStateDescriptor() + log_state_descriptor.level = MTL.MTLLogLevelDebug + log_state = MTLLogState(queue.device, log_state_descriptor) + + function log_handler(subSystem, category, logLevel, message) + Core.print(String(NSString(message))) + return nothing + end + + block = @objcblock(log_handler, Nothing, (id{NSString}, id{NSString}, NSInteger, id{NSString})) + @objc [log_state::id{MTLLogState} addLogHandler:block::id{NSBlock}]::Nothing + + 
cmdbuf_descriptor = MTLCommandBufferDescriptor() + cmdbuf_descriptor.logState = log_state + MTLCommandBuffer(queue, cmdbuf_descriptor) + else + MTLCommandBuffer(queue) + end + cmdbuf.label = "MTLCommandBuffer($(nameof(kernel.f)))" cce = MTLComputeCommandEncoder(cmdbuf) argument_buffers = try diff --git a/src/device/intrinsics/output.jl b/src/device/intrinsics/output.jl new file mode 100644 index 000000000..78bd0f6a5 --- /dev/null +++ b/src/device/intrinsics/output.jl @@ -0,0 +1,310 @@ +const MTLLOG_SUBSYSTEM = "com.juliagpu.metal.jl" +const MTLLOG_CATEGORY = "mtlprintf" + +const __METAL_OS_LOG_TYPE_DEBUG__ = Int32(2) +const __METAL_OS_LOG_TYPE_INFO__ = Int32(1) +const __METAL_OS_LOG_TYPE_DEFAULT__ = Int32(0) +const __METAL_OS_LOG_TYPE_ERROR__ = Int32(16) +const __METAL_OS_LOG_TYPE_FAULT__ = Int32(17) + +export @mtlprintf + +@generated function promote_c_argument(arg) + # > When a function with a variable-length argument list is called, the variable + # > arguments are passed using C's old ``default argument promotions.'' These say that + # > types char and short int are automatically promoted to int, and type float is + # > automatically promoted to double. Therefore, varargs functions will never receive + # > arguments of type char, short int, or float. + + if arg == Cchar || arg == Cshort + return :(Cint(arg)) + elseif arg == Cfloat + return :(Cdouble(arg)) + else + return :(arg) + end +end + +# Compute the size in bytes of the va_list buffer for `param_types`, +# placing each argument at an offset aligned to its own size. +function valist_size(dl, param_types) + size = 0 + for pty in param_types + ps = sizeof(dl, pty) + ps == 0 && continue # zero-sized arguments take no space (and would divide by zero) + # round the running offset up to a multiple of `ps`, then append the argument + size = cld(size, ps) * ps + ps + end + + return size +end + +""" + @mtlprintf("%Fmt", args...) + +Print a formatted string in device context on the host standard output. +""" +macro mtlprintf(fmt::String, args...) + fmt_val = Val(Symbol(fmt)) + + return quote + _mtlprintf($fmt_val, $(map(arg -> :(promote_c_argument($arg)), esc.(args))...)) + end +end + +@generated function _mtlprintf(::Val{fmt}, argspec...) 
where {fmt} + return @dispose ctx = Context() begin + arg_exprs = [:(argspec[$i]) for i in 1:length(argspec)] + arg_types = [argspec...] + + T_void = LLVM.VoidType() + T_int32 = LLVM.Int32Type() + T_int64 = LLVM.Int64Type() + T_pint8 = LLVM.PointerType(LLVM.Int8Type()) + T_pint8a2 = LLVM.PointerType(LLVM.Int8Type(), 2) + + # create functions + param_types = LLVMType[convert(LLVMType, typ) for typ in arg_types] + wrapper_f, wrapper_ft = create_function(T_void, param_types) + mod = LLVM.parent(wrapper_f) + + llvm_ft = LLVM.FunctionType(T_void, LLVMType[]; vararg = true) + llvm_f = LLVM.Function(mod, "metal_os_log", llvm_ft) + push!(function_attributes(llvm_f), EnumAttribute("alwaysinline", 0)) + + # generate IR + @dispose builder = IRBuilder() begin + entry = BasicBlock(llvm_f, "entry") + position!(builder, entry) + + str = globalstring_ptr!(builder, String(fmt), addrspace = 2) + subsystem_str = globalstring_ptr!(builder, MTLLOG_SUBSYSTEM, addrspace = 2) + category_str = globalstring_ptr!(builder, MTLLOG_CATEGORY, addrspace = 2) + log_type = LLVM.ConstantInt(T_int32, __METAL_OS_LOG_TYPE_DEBUG__) + + # compute argsize + dl = datalayout(mod) + arg_size = LLVM.ConstantInt(T_int64, valist_size(dl, param_types)) + + alloc = alloca!(builder, T_pint8) + buffer = bitcast!(builder, alloc, T_pint8) + alloc_size = LLVM.ConstantInt(T_int64, sizeof(dl, T_pint8)) + + lifetime_start_fty = LLVM.FunctionType(T_void, [T_int64, T_pint8]) + lifetime_start = LLVM.Function(mod, "llvm.lifetime.start.p0i8", lifetime_start_fty) + call!(builder, lifetime_start_fty, lifetime_start, [alloc_size, buffer]) + + va_start_fty = LLVM.FunctionType(T_void, [T_pint8]) + va_start = LLVM.Function(mod, "llvm.va_start", va_start_fty) + call!(builder, va_start_fty, va_start, [buffer]) + + arg_ptr = load!(builder, T_pint8, alloc) + + os_log_fty = LLVM.FunctionType(T_void, [T_pint8a2, T_pint8a2, T_int32, T_pint8a2, T_pint8, T_int64]) + os_log = LLVM.Function(mod, "air.os_log", os_log_fty) + call!(builder, 
os_log_fty, os_log, [subsystem_str, category_str, log_type, str, arg_ptr, arg_size]) + + va_end_fty = LLVM.FunctionType(T_void, [T_pint8]) + va_end = LLVM.Function(mod, "llvm.va_end", va_end_fty) + call!(builder, va_end_fty, va_end, [buffer]) + + lifetime_end_fty = LLVM.FunctionType(T_void, [T_int64, T_pint8]) + lifetime_end = LLVM.Function(mod, "llvm.lifetime.end.p0i8", lifetime_end_fty) + call!(builder, lifetime_end_fty, lifetime_end, [alloc_size, buffer]) + + ret!(builder) + end + + @dispose builder = IRBuilder() begin + entry = BasicBlock(wrapper_f, "entry") + position!(builder, entry) + + call!(builder, llvm_ft, llvm_f, collect(parameters(wrapper_f))) + + ret!(builder) + end + + + call_function(wrapper_f, Nothing, Tuple{arg_types...}, arg_exprs...) + end +end + + +## print-like functionality + +export @mtlprint, @mtlprintln + +# simple conversions, defining an expression and the resulting argument type. nothing fancy, +# `@mtlprint` pretty directly maps to `@mtlprintf`; we should just support `write(::IO)`. +const mtlprint_conversions = [ + Float32 => (x -> :(Float64($x)), Float64), + Ptr{<:Any} => (x -> :(reinterpret(Int, $x)), Ptr{Cvoid}), + LLVMPtr{<:Any} => (x -> :(reinterpret(Int, $x)), Ptr{Cvoid}), + Bool => (x -> :(Int32($x)), Int32), +] + +# format specifiers +const mtlprint_specifiers = Dict( + # integers + Int16 => "%hd", + Int32 => "%d", + Int64 => "%ld", + UInt16 => "%hu", + UInt32 => "%u", + UInt64 => "%lu", + + # floating-point + Float32 => "%f", + + # other + Cchar => "%c", + Ptr{Cvoid} => "%p", + Cstring => "%s", +) + +@inline @generated function _mtlprint(parts...) 
+ fmt = "" + args = Expr[] + + for i in 1:length(parts) + part = :(parts[$i]) + T = parts[i] + + # put literals directly in the format string + if T <: Val + fmt *= string(T.parameters[1]) + continue + end + + # try to convert arguments if they are not supported directly + if !haskey(mtlprint_specifiers, T) + for (Tmatch, rule) in mtlprint_conversions + if T <: Tmatch + part = rule[1](part) + T = rule[2] + break + end + end + end + + # render the argument + if haskey(mtlprint_specifiers, T) + fmt *= mtlprint_specifiers[T] + push!(args, part) + elseif T <: Tuple + fmt *= "(" + for (j, U) in enumerate(T.parameters) + if haskey(mtlprint_specifiers, U) + fmt *= mtlprint_specifiers[U] + push!(args, :($part[$j])) + if j < length(T.parameters) + fmt *= ", " + elseif length(T.parameters) == 1 + fmt *= "," + end + else + @error("@mtlprint does not support values of type $U") + end + end + fmt *= ")" + elseif T <: String + @error("@mtlprint does not support non-literal strings") + elseif T <: Type + fmt *= string(T.parameters[1]) + else + @warn("@mtlprint does not support values of type $T") + fmt *= "$(T)(...)" + end + end + + return quote + @mtlprintf($fmt, $(args...)) + end +end + +""" + @mtlprint(xs...) + @mtlprintln(xs...) + +Print a textual representation of values `xs` to standard output from the GPU. The +functionality builds on `@mtlprintf`, and is intended as a more user-friendly alternative to +that API. However, that also means there's only limited support for argument types, handling +16/32/64-bit signed and unsigned integers, 32-bit floating point numbers, `Bool`s, `Cchar`s and +pointers. For more complex output, use `@mtlprintf` directly. + +Limited string interpolation is also possible: + +```julia + @mtlprint("Hello, World ", 42, "\\n") + @mtlprint "Hello, World \$(42)\\n" +``` +""" +macro mtlprint(parts...) + args = Union{Val, Expr, Symbol}[] + + parts = [parts...] 
+ while true + isempty(parts) && break + + part = popfirst!(parts) + + # handle string interpolation + if isa(part, Expr) && part.head == :string + parts = vcat(part.args, parts) + continue + end + + # expose literals to the generator by using Val types + if isbits(part) # literal numbers, etc + push!(args, Val(part)) + elseif isa(part, QuoteNode) # literal symbols + push!(args, Val(part.value)) + elseif isa(part, String) # literal strings need to be interned + push!(args, Val(Symbol(part))) + else # actual values that will be passed to printf + push!(args, part) + end + end + + return quote + _mtlprint($(map(esc, args)...)) + end +end + +@doc (@doc @mtlprint) -> +macro mtlprintln(parts...) + return esc( + quote + Metal.@mtlprint($(parts...), "\n") + end + ) +end + +export @mtlshow + +""" + @mtlshow(ex) + +GPU analog of `Base.@show`. It comes with the same type restrictions as [`@mtlprint`](@ref). + +```julia +@mtlshow thread_position_in_grid_1d() +``` +""" +macro mtlshow(exs...) + blk = Expr(:block) + for ex in exs + push!( + blk.args, :( + Metal.@mtlprintln( + $(sprint(Base.show_unquoted, ex) * " = "), + begin + local value = $(esc(ex)) + end + ) + ) + ) + end + isempty(exs) || push!(blk.args, :value) + return blk +end diff --git a/src/device/quirks.jl b/src/device/quirks.jl index 7ccaebec1..5a09749f5 100644 --- a/src/device/quirks.jl +++ b/src/device/quirks.jl @@ -1,10 +1,3 @@ -macro print_and_throw(args...) - quote - #@println "ERROR: " $(args...) "." - throw(nothing) - end -end - # math.jl @device_override @noinline Base.Math.throw_complex_domainerror(f::Symbol, x) = @print_and_throw "This operation requires a complex input to return a complex result" diff --git a/src/device/utils.jl b/src/device/utils.jl index 5edc9c832..07ac82d3f 100644 --- a/src/device/utils.jl +++ b/src/device/utils.jl @@ -1,6 +1,13 @@ # local method table for device functions Base.Experimental.@MethodTable(method_table) +macro print_and_throw(args...) 
+ return quote + #@println "ERROR: " $(args...) "." + throw(nothing) + end +end + macro device_override(ex) ex = macroexpand(__module__, ex) esc(quote diff --git a/src/state.jl b/src/state.jl index 3a0512e52..e4c926277 100644 --- a/src/state.jl +++ b/src/state.jl @@ -47,6 +47,7 @@ function global_queue(dev::MTLDevice) @autoreleasepool begin # NOTE: MTLCommandQueue itself is manually reference-counted, # the release pool is for resources used during its construction. + queue = MTLCommandQueue(dev) queue.label = "global_queue($(current_task()))" global_queues[queue] = nothing diff --git a/test/output.jl b/test/output.jl new file mode 100644 index 000000000..170475277 --- /dev/null +++ b/test/output.jl @@ -0,0 +1,150 @@ +@testset "output" begin + +@static if Metal.macos_version() < v"15" + +@warn "Skipping output tests in macOS 14 and below" + +function kernel() + @mtlprint("Hello, World\n") + return +end +@test_throws "Logging is only supported on macOS 15 or higher" @metal kernel() + +else + +@testset "formatted output" begin + _, out = @grab_output @on_device @mtlprintf("") + @test out == "" + + _, out = @grab_output @on_device @mtlprintf("Testing...\n") + @test out == "Testing...\n" + + # narrow integer + _, out = @grab_output @on_device @mtlprintf("Testing %d %d...\n", Int32(1), Int32(2)) + @test out == "Testing 1 2...\n" + + # wide integer + _, out = @grab_output @on_device @mtlprintf("Testing %ld %ld...\n", Int64(1), Int64(2)) + @test out == "Testing 1 2...\n" + + _, out = @grab_output @on_device begin + @mtlprintf("foo") + @mtlprintf("bar\n") + end + @test out == "foobar\n" + + # c argument promotions + function kernel(A) + @mtlprintf("%f %f\n", A[1], A[1]) + return + end + x = mtl(ones(2, 2)) + _, out = @grab_output begin + Metal.@sync @metal kernel(x) + end + @test out == "1.000000 1.000000\n" +end + +@testset "@mtlprint" begin + # basic @mtlprint/@mtlprintln + + _, out = @grab_output @on_device @mtlprint("Hello, World\n") + @test out == "Hello, World\n" + + 
_, out = @grab_output @on_device @mtlprintln("Hello, World") + @test out == "Hello, World\n" + + + # argument interpolation (by the macro, so can use literals) + + _, out = @grab_output @on_device @mtlprint("foobar") + @test out == "foobar" + + _, out = @grab_output @on_device @mtlprint(:foobar) + @test out == "foobar" + + _, out = @grab_output @on_device @mtlprint("foo", "bar") + @test out == "foobar" + + _, out = @grab_output @on_device @mtlprint("foobar ", 42) + @test out == "foobar 42" + + _, out = @grab_output @on_device @mtlprint("foobar $(42)") + @test out == "foobar 42" + + _, out = @grab_output @on_device @mtlprint("foobar $(4)", 2) + @test out == "foobar 42" + + _, out = @grab_output @on_device @mtlprint("foobar ", 4, "$(2)") + @test out == "foobar 42" + + _, out = @grab_output @on_device @mtlprint(42) + @test out == "42" + + _, out = @grab_output @on_device @mtlprint(4, 2) + @test out == "42" + + _, out = @grab_output @on_device @mtlprint(Any) + @test out == "Any" + + _, out = @grab_output @on_device @mtlprintln("foobar $(42)") + @test out == "foobar 42\n" + + + # argument types + + # we're testing the generated functions now, so can't use literals + function test_output(val, str) + canary = rand(Int32) # if we mess up the main arg, this one will print wrong + _, out = @grab_output @on_device @mtlprint(val, " (", canary, ")") + @test out == "$(str) ($(Int(canary)))" + end + + for typ in (Int16, Int32, Int64, UInt16, UInt32, UInt64) + test_output(typ(42), "42") + end + + for typ in (Float32,) + test_output(typ(42), "42.000000") + end + + test_output(Cchar('c'), "c") + + for typ in (Ptr{Cvoid}, Ptr{Int}) + ptr = convert(typ, Int(0x12345)) + test_output(ptr, "0x12345") + end + + test_output(true, "1") + test_output(false, "0") + + test_output((1,), "(1,)") + test_output((1,2), "(1, 2)") + test_output((1,2,3.0f0), "(1, 2, 3.000000)") + + # escaping + + kernel1(val) = (@mtlprint(val); nothing) + _, out = @grab_output @on_device kernel1(42) + @test out == "42" 
+ + kernel2(val) = (@mtlprintln(val); nothing) + _, out = @grab_output @on_device kernel2(42) + @test out == "42\n" +end + +@testset "@mtlshow" begin + function kernel() + seven_i32 = Int32(7) + three_f32 = Float32(3) + @mtlshow seven_i32 + @mtlshow three_f32 + @mtlshow 1f0 + 4f0 + return + end + + _, out = @grab_output @on_device kernel() + @test out == "seven_i32 = 7\nthree_f32 = 3.000000\n1.0f0 + 4.0f0 = 5.000000\n" +end +end +end