diff --git a/lib/mtl/MTL.jl b/lib/mtl/MTL.jl
index 12f62d5e8..b9fa742ee 100644
--- a/lib/mtl/MTL.jl
+++ b/lib/mtl/MTL.jl
@@ -6,7 +6,7 @@ using ObjectiveC, .Foundation, .Dispatch
 
 ## version information
 
-export darwin_version, macos_version
+export darwin_version, macos_version, metal_version
 
 @noinline function _syscall_version(name)
     size = Ref{Csize_t}()
@@ -39,6 +39,25 @@ function macos_version()
     _macos_version[]
 end
 
+function metal_version()
+    macos = macos_version()
+    if macos >= v"13"
+        v"3.0"
+    elseif macos >= v"12"
+        v"2.4"
+    elseif macos >= v"11"
+        v"2.3"
+    elseif macos >= v"10.15"
+        v"2.2"
+    elseif macos >= v"10.14"
+        v"2.1"
+    elseif macos >= v"10.13"
+        v"2.0"
+    else
+        error("Metal is not supported on macOS < 10.13")
+    end
+end
+
 
 ## source code includes
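
The new `metal_version()` query maps the host macOS release to the highest supported Metal
language version, so feature support can be gated at runtime. A minimal host-side sketch of
the intended use (illustrative, not part of the patch; the test additions below rely on the
same pattern):

```julia
using Metal

# Float32 device atomics require Metal 3.0, i.e. macOS 13 or newer
supports_float32_atomics = metal_version() >= v"3.0"
```
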
diff --git a/src/Metal.jl b/src/Metal.jl
index 285f4a81c..fb9551cb3 100644
--- a/src/Metal.jl
+++ b/src/Metal.jl
@@ -31,6 +31,8 @@ include("device/intrinsics/math.jl")
 include("device/intrinsics/synchronization.jl")
 include("device/intrinsics/memory.jl")
 include("device/intrinsics/simd.jl")
+include("device/intrinsics/version.jl")
+include("device/intrinsics/atomics.jl")
 include("device/quirks.jl")
 
 # array essentials
diff --git a/src/compiler/compilation.jl b/src/compiler/compilation.jl
index 35fceac8e..ca2af5005 100644
--- a/src/compiler/compilation.jl
+++ b/src/compiler/compilation.jl
@@ -40,10 +40,12 @@ end
 @noinline function _compiler_config(dev; kernel=true, name=nothing,
                                          always_inline=false, kwargs...)
     # TODO: configure the compiler target based on the device
-    macos=macos_version()
+    macos = macos_version()
+    metal = metal_version()
+    air = metal     # XXX: do these ever differ?
 
     # create GPUCompiler objects
-    target = MetalCompilerTarget(macos; kwargs...)
+    target = MetalCompilerTarget(; macos, air, metal, kwargs...)
     params = MetalCompilerParams()
     CompilerConfig(target, params; kernel, name, always_inline)
 end
diff --git a/src/compiler/execution.jl b/src/compiler/execution.jl
index add2a3100..bb058b084 100644
--- a/src/compiler/execution.jl
+++ b/src/compiler/execution.jl
@@ -211,7 +211,7 @@ function (kernel::HostKernel)(args...; groups=1, threads=1, queue=global_queue(c
         else
             # everything else is passed by reference, and requires an argument buffer
             arg = mtlconvert(arg, cce)
-            argtyp = Core.typeof(arg)
+            argtyp = Core.Typeof(arg)
             if isghosttype(argtyp) || Core.Compiler.isconstType(argtyp)
                 continue
             elseif !isbitstype(argtyp)
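
The `Core.Typeof` change matters for type-valued kernel arguments, such as the `Int` passed
in the reworked `test/execution.jl` test below: `typeof` erases the identity of a type
object, while `Core.Typeof` keeps it, which is what the ghost/constant-argument checks need.
A REPL-style illustration (not part of the patch):

```julia
julia> typeof(Int)        # just a DataType; the launcher can no longer tell which type
DataType

julia> Core.Typeof(Int)   # precise, so the argument can be recognized as a constant and skipped
Type{Int64}
```
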
diff --git a/src/device/intrinsics/atomics.jl b/src/device/intrinsics/atomics.jl
new file mode 100644
index 000000000..d2f1f87db
--- /dev/null
+++ b/src/device/intrinsics/atomics.jl
@@ -0,0 +1,238 @@
+# Atomic Functions
+
+@enum memory_order::Int32 begin
+    memory_order_relaxed = 0
+end
+
+# XXX: the integers should come from some enum
+const atomic_memory_names = Dict(
+    AS.Device      => ("global", Int32(2)),
+    AS.ThreadGroup => ("local",  Int32(1))
+)
+
+const atomic_type_names = Dict(
+    :Int32   => "i32",
+    :UInt32  => "i32",
+    :Int64   => "i64",
+    :UInt64  => "i64",
+    :Float32 => "f32"
+)
+
+
+## low-level functions
+
+# NOTE: Float32 atomics are only available on Metal 3.0, but we can't check that at runtime
+
+for typ in (:Int32, :UInt32, :Float32), as in (AS.Device, AS.ThreadGroup)
+    typnam = atomic_type_names[typ]
+    memnam, memid = atomic_memory_names[as]
+    @eval begin
+        function atomic_store_explicit(ptr::LLVMPtr{$typ,$as}, desired::$typ)
+            @typed_ccall($"air.atomic.$memnam.store.$typnam", llvmcall, Nothing,
+                         (LLVMPtr{$typ,$as}, $typ, Int32, Int32, Bool),
+                         ptr, desired, Val(memory_order_relaxed), Val($memid), Val(true))
+        end
+
+        function atomic_load_explicit(ptr::LLVMPtr{$typ,$as})
+            @typed_ccall($"air.atomic.$memnam.load.$typnam", llvmcall, $typ,
+                         (LLVMPtr{$typ,$as}, Int32, Int32, Bool),
+                         ptr, Val(memory_order_relaxed), Val($memid), Val(true))
+        end
+
+        function atomic_exchange_explicit(ptr::LLVMPtr{$typ,$as}, desired::$typ)
+            @typed_ccall($"air.atomic.$memnam.xchg.$typnam", llvmcall, $typ,
+                         (LLVMPtr{$typ,$as}, $typ, Int32, Int32, Bool),
+                         ptr, desired, Val(memory_order_relaxed), Val($memid), Val(true))
+        end
+
+        function atomic_compare_exchange_weak_explicit(ptr::LLVMPtr{$typ,$as},
+                                                       expected::$typ, desired::$typ)
+            # NOTE: we deviate slightly from the Metal/C++ API here, not returning the
+            #       status boolean, but the contents of the expected value box, which will
+            #       have been changed to the current value if the exchange failed.
+            expected_box = Ref(expected)
+            @typed_ccall($"air.atomic.$memnam.cmpxchg.weak.$typnam", llvmcall, Bool,
+                         (LLVMPtr{$typ,$as}, Ptr{$typ}, $typ, Int32, Int32, Int32, Bool),
+                         ptr, expected_box, desired, Val(memory_order_relaxed),
+                         Val(memory_order_relaxed), Val($memid), Val(true))
+            expected_box[]
+        end
+    end
+end
+
+const atomic_fetch_and_modify = [
+    :add => [:Int32, :UInt32, :Float32],
+    :sub => [:Int32, :UInt32, :Float32],
+    :min => [:Int32, :UInt32],
+    :max => [:Int32, :UInt32],
+    :and => [:Int32, :UInt32],
+    :or  => [:Int32, :UInt32],
+    :xor => [:Int32, :UInt32]
+]
+
+for (op, types) in atomic_fetch_and_modify, typ in types, as in (AS.Device, AS.ThreadGroup)
+    typnam = atomic_type_names[typ]
+    if typ in [:Int32, :Int64]
+        typnam = "s.$typnam"
+    elseif typ in [:UInt32, :UInt64]
+        typnam = "u.$typnam"
+    end
+    memnam, memid = atomic_memory_names[as]
+    f = Symbol("atomic_fetch_$(op)_explicit")
+    @eval begin
+        function $f(ptr::LLVMPtr{$typ,$as}, desired::$typ)
+            @typed_ccall($"air.atomic.$memnam.$op.$typnam", llvmcall, $typ,
+                         (LLVMPtr{$typ,$as}, $typ, Int32, Int32, Bool),
+                         ptr, desired, Val(memory_order_relaxed), Val($memid), Val(true))
+        end
+    end
+end
+
+# TODO: non-fetch 64-bit min/max atomics (hardware support?)
+
+# generic atomic support using compare-and-swap
+@inline function atomic_fetch_op_explicit(ptr::LLVMPtr{T}, op::Function, val) where {T}
+    old = Base.unsafe_load(ptr)
+    while true
+        cmp = old
+        new = convert(T, op(old, val))
+        old = atomic_compare_exchange_weak_explicit(ptr, cmp, new)
+        isequal(old, cmp) && return new
+    end
+end
+
+
+## high-level interface
+
+# copied from CUDA.jl -- should be generalized or integrated with Base
+
+const inplace_ops = Dict(
+    :(+=)   => :(+),
+    :(-=)   => :(-),
+    :(*=)   => :(*),
+    :(/=)   => :(/),
+    :(\=)   => :(\),
+    :(%=)   => :(%),
+    :(^=)   => :(^),
+    :(&=)   => :(&),
+    :(|=)   => :(|),
+    :(⊻=)   => :(⊻),
+    :(>>>=) => :(>>>),
+    :(>>=)  => :(>>),
+    :(<<=)  => :(<<),
+)
+
+struct AtomicError <: Exception
+    msg::AbstractString
+end
+
+Base.showerror(io::IO, err::AtomicError) =
+    print(io, "AtomicError: ", err.msg)
+
+"""
+    @atomic a[I] = op(a[I], val)
+    @atomic a[I] ...= val
+
+Atomically perform a sequence of operations that loads an array element `a[I]`, performs the
+operation `op` on that value and a second value `val`, and writes the result back to the
+array. This sequence can be written out as a regular assignment, in which case the same
+array element should be used in the left and right hand side of the assignment, or as an
+in-place application of a known operator. In both cases, the array reference should be pure
+and not induce any side-effects.
+
+!!! warn
+    This interface is experimental, and might change without warning. Use the lower-level
+    `atomic_...!` functions for a stable API, albeit one limited to natively-supported ops.
+"""
+macro atomic(ex)
+    # decode assignment and call
+    if ex.head == :ref
+        # @atomic b[i]
+        ref = ex
+        op = nothing
+        val = nothing
+    elseif ex.head == :(=)
+        # @atomic b[i] = ...
+        ref = ex.args[1]
+        rhs = ex.args[2]
+        if !isa(rhs, Expr)
+            # @atomic b[i] = val
+            op = nothing
+            val = rhs
+        elseif Meta.isexpr(rhs, :call)
+            # @atomic b[i] = b[i] + val
+            # TODO: matching on a call is ambiguous (`@atomic b[i] = Int32(0)` is a call)
+            #       so we should probably only support in-place assignment?
+            op = rhs.args[1]
+            if rhs.args[2] != ref
+                throw(AtomicError("right-hand side of a non-inplace @atomic assignment should reference the left-hand side"))
+            end
+            val = rhs.args[3]
+        else
+            throw(AtomicError("right-hand side of an @atomic assignment should be a value or a call"))
+        end
+    elseif haskey(inplace_ops, ex.head)
+        # @atomic b[i] += val
+        op = inplace_ops[ex.head]
+        ref = ex.args[1]
+        val = ex.args[2]
+    else
+        throw(AtomicError("unknown @atomic expression"))
+    end
+
+    # decode array expression
+    Meta.isexpr(ref, :ref) || throw(AtomicError("@atomic should be applied to an array reference expression"))
+    array = ref.args[1]
+    indices = Expr(:tuple, ref.args[2:end]...)
+
+    if val === nothing
+        esc(quote
+            $atomic_arrayref($array, $indices)
+        end)
+    else
+        esc(quote
+            $atomic_arrayset($array, $indices, $op, $val)
+        end)
+    end
+end
+
+# FIXME: make this respect the indexing style
+@inline atomic_arrayref(A::AbstractArray{T}, Is::Tuple) where {T} =
+    atomic_arrayref(A, Base._to_linear_index(A, Is...))
+@inline atomic_arrayset(A::AbstractArray{T}, Is::Tuple, op, val) where {T} =
+    atomic_arrayset(A, Base._to_linear_index(A, Is...), op, convert(T, val))
+
+# native atomics
+@inline atomic_arrayref(A::AbstractArray, I::Integer) = atomic_load_explicit(pointer(A, I))
+@inline atomic_arrayset(A::AbstractArray{T}, I::Integer, ::Nothing, val) where T =
+    atomic_store_explicit(pointer(A, I), convert(T, val))
+for (op,impl,typ) in [(:(+), :(atomic_fetch_add_explicit), [:UInt32,:Int32,:Float32]),
+                      (:(-), :(atomic_fetch_sub_explicit), [:UInt32,:Int32,:Float32]),
+                      (:(&), :(atomic_fetch_and_explicit), [:UInt32,:Int32]),
+                      (:(|), :(atomic_fetch_or_explicit),  [:UInt32,:Int32]),
+                      (:(⊻), :(atomic_fetch_xor_explicit), [:UInt32,:Int32]),
+                      (:max, :(atomic_fetch_max_explicit), [:UInt32,:Int32]),
+                      (:min, :(atomic_fetch_min_explicit), [:UInt32,:Int32])]
+    @eval @inline atomic_arrayset(A::AbstractArray{T}, I::Integer, ::typeof($op),
+                                  val::T) where {T<:Union{$(typ...)}} =
+        $impl(pointer(A, I), val)
+end
+
+# native atomics that are not supported on all devices
+@inline function atomic_arrayset(A::AbstractArray{T}, I::Integer, op::typeof(+),
+                                 val::T) where {T <: AbstractFloat}
+    ptr = pointer(A, I)
+    # XXX: consider falling back to fetch_op here to support Metal < 3.0 (this also requires
+    #      cmpxchg support for Float32, but we should be able to do that using bitcast)
+    atomic_fetch_add_explicit(ptr, val)
+end
+@inline function atomic_arrayset(A::AbstractArray{T}, I::Integer, op::typeof(-),
+                                 val::T) where {T <: AbstractFloat}
+    ptr = pointer(A, I)
+    # XXX: see above
+    atomic_fetch_sub_explicit(ptr, val)
+end
+
+# fallback using compare-and-swap
+@inline atomic_arrayset(A::AbstractArray{T}, I::Integer, op::Function, val) where {T} =
+    atomic_fetch_op_explicit(pointer(A, I), op, val)
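
For reference, a device-side sketch of how the two layers added in `atomics.jl` are meant to
be used (kernel and array names are illustrative; the tests added below exercise the same
calls):

```julia
using Metal

function accumulate_kernel(counts, bins)
    i = thread_position_in_grid_1d()
    # high-level: atomic read-modify-write on an array element
    Metal.@atomic counts[bins[i]] += Int32(1)
    # low-level equivalent, calling the explicit intrinsic directly:
    # Metal.atomic_fetch_add_explicit(pointer(counts, bins[i]), Int32(1))
    return
end

counts = Metal.zeros(Int32, 16)
bins = MtlArray(rand(Int32(1):Int32(16), 128))
@metal threads=128 accumulate_kernel(counts, bins)
```
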
diff --git a/src/device/intrinsics/version.jl b/src/device/intrinsics/version.jl
new file mode 100644
index 000000000..3f86c376b
--- /dev/null
+++ b/src/device/intrinsics/version.jl
@@ -0,0 +1,67 @@
+# device intrinsics for querying the Metal and AIR version
+
+
+## a GPU-compatible version number
+
+# XXX: this is duplicated with CUDA.jl; move it to a common place
+
+export SimpleVersion, @sv_str
+
+struct SimpleVersion
+    major::UInt32
+    minor::UInt32
+
+    SimpleVersion(major, minor=0) = new(major, minor)
+end
+
+function Base.tryparse(::Type{SimpleVersion}, v::AbstractString)
+    parts = split(v, ".")
+    1 <= length(parts) <= 2 || return nothing
+
+    int_parts = map(parts) do part
+        tryparse(Int, part)
+    end
+    any(isnothing, int_parts) && return nothing
+
+    SimpleVersion(int_parts...)
+end
+
+function Base.parse(::Type{SimpleVersion}, v::AbstractString)
+    ver = tryparse(SimpleVersion, v)
+    ver === nothing && throw(ArgumentError("invalid SimpleVersion string: '$v'"))
+    return ver
+end
+
+SimpleVersion(v::AbstractString) = parse(SimpleVersion, v)
+
+@inline function Base.isless(a::SimpleVersion, b::SimpleVersion)
+    (a.major < b.major) && return true
+    (a.major > b.major) && return false
+    (a.minor < b.minor) && return true
+    (a.minor > b.minor) && return false
+    return false
+end
+
+macro sv_str(str)
+    SimpleVersion(str)
+end
+
+
+## accessors for the Metal and AIR version
+
+export metal_version, air_version
+
+for var in ["metal_major", "metal_minor", "air_major", "air_minor"]
+    @eval @inline $(Symbol(var))() =
+        Base.llvmcall(
+            $("""@$var = external global i32
+                 define i32 @entry() #0 {
+                     %val = load i32, i32* @$var
+                     ret i32 %val
+                 }
+                 attributes #0 = { alwaysinline }
+              """, "entry"), UInt32, Tuple{})
+end
+
+@device_override @inline metal_version() = SimpleVersion(metal_major(), metal_minor())
+@device_function @inline air_version() = SimpleVersion(air_major(), air_minor())
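
Because `metal_version()` is `@device_override`d, the same name also works inside kernels,
where it returns the GPU-compatible `SimpleVersion` and can be compared against `sv"..."`
literals. A small illustrative kernel (not part of the patch; names are made up for the
example):

```julia
using Metal

function version_kernel(a)
    # record on the host whether this device compilation targets Metal 3.0 or newer
    a[1] = metal_version() >= sv"3.0" ? Int32(1) : Int32(0)
    return
end

a = Metal.zeros(Int32, 1)
@metal version_kernel(a)
```
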
diff --git a/src/pool.jl b/src/pool.jl
index 365fa78d1..d12d05f25 100644
--- a/src/pool.jl
+++ b/src/pool.jl
@@ -5,13 +5,13 @@ using Printf
 # allocation statistics
 
 mutable struct AllocStats
-    @atomic alloc_count::Int
-    @atomic alloc_bytes::Int
+    Base.@atomic alloc_count::Int
+    Base.@atomic alloc_bytes::Int
 
-    @atomic free_count::Int
-    @atomic free_bytes::Int
+    Base.@atomic free_count::Int
+    Base.@atomic free_bytes::Int
 
-    @atomic total_time::Float64
+    Base.@atomic total_time::Float64
 end
 
 AllocStats() = AllocStats(0, 0, 0, 0, 0.0)
@@ -61,9 +61,9 @@ function alloc(dev::Union{MTLDevice,MTLHeap},
         buf = MTLBuffer(dev, bytesize, args...; storage, kwargs...)
     end
 
-    @atomic alloc_stats.alloc_count += 1
-    @atomic alloc_stats.alloc_bytes += bytesize
-    @atomic alloc_stats.total_time += time
+    Base.@atomic alloc_stats.alloc_count += 1
+    Base.@atomic alloc_stats.alloc_bytes += bytesize
+    Base.@atomic alloc_stats.total_time += time
 
     return buf
 end
@@ -81,9 +81,9 @@ function free(buf::MTLBuffer)
         release(buf)
     end
 
-    @atomic alloc_stats.free_count += 1
-    @atomic alloc_stats.free_bytes += sz
-    @atomic alloc_stats.total_time += time
+    Base.@atomic alloc_stats.free_count += 1
+    Base.@atomic alloc_stats.free_bytes += sz
+    Base.@atomic alloc_stats.total_time += time
 
     return
 end
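
The `Base.` qualification is needed because `atomics.jl` now defines Metal's own
`macro atomic`, so the unqualified name inside the `Metal` module no longer unambiguously
refers to `Base.@atomic`. The two serve different purposes, roughly (illustrative snippet,
not runnable on its own):

```julia
Base.@atomic alloc_stats.alloc_count += 1   # host-side atomic update of a struct field
Metal.@atomic a[i] += 1f0                   # device-side atomic update of an array element
```
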
diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl
index fc66ffabf..5eb4c4296 100644
--- a/test/device/intrinsics.jl
+++ b/test/device/intrinsics.jl
@@ -369,3 +369,327 @@ end
 
 end # End Matrix Functions
 end # End SIMD Intrinsics
+
+
+############################################################################################
+
+@testset "atomics" begin
+
+n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
+
+# JuliaGPU/Metal.jl#217: threadgroup atomics seem to require all-atomic operations
+
+@testset "low-level" begin
+    # TODO: make these tests actually write to the overlapping memory locations
+
+    # XXX: according to the docs, Float32 atomics should also work on threadgroup memory
+
+    @testset "store_explicit" begin
+        function global_kernel(a, val)
+            i = thread_position_in_grid_1d()
+            Metal.atomic_store_explicit(pointer(a, i), val)
+            return
+        end
+
+        types = [Int32]
+        metal_version() >= v"3.0" && push!(types, Float32)
+        @testset for T in types
+            a = Metal.zeros(T, n)
+            @metal threads=n global_kernel(a, T(42))
+            @test all(isequal(42), Array(a))
+        end
+
+        function local_kernel(a, val::T) where T
+            i = thread_position_in_grid_1d()
+            b = MtlThreadGroupArray(T, 128)
+            Metal.atomic_store_explicit(pointer(b, i), val)
+            a[i] = b[i]
+            return
+        end
+
+        @testset for T in [Int32,]
+            a = Metal.zeros(T, n)
+            @metal threads=n local_kernel(a, T(42))
+            @test all(isequal(42), Array(a))
+        end
+    end
+
+    @testset "load_explicit" begin
+        function global_kernel(a, b)
+            i = thread_position_in_grid_1d()
+            val = Metal.atomic_load_explicit(pointer(a, i))
+            b[i] = val
+            return
+        end
+
+        types = [Int32]
+        metal_version() >= v"3.0" && push!(types, Float32)
+        @testset for T in types
+            a = MtlArray(rand(T, n))
+            b = Metal.zeros(T, n)
+            @metal threads=n global_kernel(a, b)
+            @test Array(a) == Array(b)
+        end
+
+        function local_kernel(a::AbstractArray{T}, b::AbstractArray{T}) where T
+            i = thread_position_in_grid_1d()
+            c = MtlThreadGroupArray(T, 128)
+            #c[i] = a[i]
+            val = Metal.atomic_load_explicit(pointer(a, i))
+            Metal.atomic_store_explicit(pointer(c, i), val)
+            val = Metal.atomic_load_explicit(pointer(c, i))
+            #b[i] = val
+            Metal.atomic_store_explicit(pointer(b, i), val)
+            return
+        end
+
+        @testset for T in [Int32,]
+            a = MtlArray(rand(T, n))
+            b = Metal.zeros(T, n)
+            @metal threads=n local_kernel(a, b)
+            @test Array(a) == Array(b)
+        end
+    end
+
+    @testset "exchange_explicit" begin
+        function global_kernel(a, val)
+            i = thread_position_in_grid_1d()
+            Metal.atomic_exchange_explicit(pointer(a, i), val)
+            return
+        end
+
+        types = [Int32]
+        metal_version() >= v"3.0" && push!(types, Float32)
+        @testset for T in types
+            a = MtlArray(rand(T, n))
+            @metal threads=n global_kernel(a, T(42))
+            @test all(isequal(42), Array(a))
+        end
+
+        function local_kernel(a, val::T) where T
+            i = thread_position_in_grid_1d()
+            b = MtlThreadGroupArray(T, 128)
+            Metal.atomic_exchange_explicit(pointer(b, i), val)
+            a[i] = b[i]
+            return
+        end
+
+        @testset for T in [Int32,]
+            a = Metal.zeros(T, n)
+            @metal threads=n local_kernel(a, T(42))
+            @test all(isequal(42), Array(a))
+        end
+    end
+
+    @testset "compare_exchange_weak_explicit" begin
+        function global_kernel(a, expected, desired)
+            i = thread_position_in_grid_1d()
+            while Metal.atomic_compare_exchange_weak_explicit(pointer(a, i), expected[i], desired) != expected[i]
+                # keep on trying
+            end
+            return
+        end
+
+        types = [Int32]
+        metal_version() >= v"3.0" && push!(types, Float32)
+        @testset for T in types
+            a = MtlArray(rand(T, n))
+            expected = copy(a)
+            desired = T(42)
+            @metal threads=length(a) global_kernel(a, expected, desired)
+            @test all(isequal(42), Array(a))
+        end
+
+        function local_kernel(a, expected::AbstractArray{T}, desired::T) where T
+            i = thread_position_in_grid_1d()
+            b = MtlThreadGroupArray(T, 128)
+            #b[i] = a[i]
+            val = Metal.atomic_load_explicit(pointer(a, i))
+            Metal.atomic_store_explicit(pointer(b, i), val)
+            while Metal.atomic_compare_exchange_weak_explicit(pointer(b, i), expected[i], desired) != expected[i]
+                # keep on trying
+            end
+            #a[i] = b[i]
+            val = Metal.atomic_load_explicit(pointer(b, i))
+            Metal.atomic_store_explicit(pointer(a, i), val)
+            return
+        end
+
+        @testset for T in [Int32,]
+            a = Metal.zeros(T, n)
+            expected = copy(a)
+            desired = T(42)
+            @metal threads=n local_kernel(a, expected, desired)
+            @test all(isequal(42), Array(a))
+        end
+    end
+
+    @testset "fetch and modify" begin
+        add_sub_types = [Int32, UInt32]
+        metal_version() >= v"3.0" && push!(add_sub_types, Float32)
+        other_types = [Int32, UInt32]
+        for (jlfun, mtlfun, types) in [(min, Metal.atomic_fetch_min_explicit, other_types),
+                                       (max, Metal.atomic_fetch_max_explicit, other_types),
+                                       (&,   Metal.atomic_fetch_and_explicit, other_types),
+                                       (|,   Metal.atomic_fetch_or_explicit,  other_types),
+                                       (⊻,   Metal.atomic_fetch_xor_explicit, other_types),
+                                       (+,   Metal.atomic_fetch_add_explicit, add_sub_types),
+                                       (-,   Metal.atomic_fetch_sub_explicit, add_sub_types)]
+            function global_kernel(f, a, arg)
+                i = thread_position_in_grid_1d()
+                f(pointer(a, i), arg)
+                return
+            end
+
+            function local_kernel(f, a, arg::T) where T
+                i = thread_position_in_grid_1d()
+                b = MtlThreadGroupArray(T, 128)
+                #b[i] = a[i]
+                val = Metal.atomic_load_explicit(pointer(a, i))
+                Metal.atomic_store_explicit(pointer(b, i), val)
+                f(pointer(b, i), arg)
+                #a[i] = b[i]
+                val = Metal.atomic_load_explicit(pointer(b, i))
+                Metal.atomic_store_explicit(pointer(a, i), val)
+                return
+            end
+
+            @testset "fetch_$(jlfun)_explicit" begin
+                @testset "device $T" for T in types
+                    a = rand(T, n)
+                    b = MtlArray(a)
+                    val = rand(T)
+                    @metal threads=n global_kernel(mtlfun, b, val)
+                    @test jlfun.(a, val) ≈ Array(b)
+                end
+
+                @testset "threadgroup $T" for T in setdiff(types, [Float32])
+                    a = rand(T, n)
+                    b = MtlArray(a)
+                    val = rand(T)
+                    @metal threads=n local_kernel(mtlfun, b, val)
+                    @test jlfun.(a, val) ≈ Array(b)
+                end
+            end
+        end
+    end
+
+    @testset "generic fetch and modify" begin
+        # custom operator that doesn't map onto an atomic intrinsic
+        f(a::T, b::T) where {T} = a + b + one(T)
+
+        function global_kernel(a, op, arg)
+            i = thread_position_in_grid_1d()
+            Metal.atomic_fetch_op_explicit(pointer(a, i), op, arg)
+            return
+        end
+
+        @testset for T in (Int32, UInt32)
+            a = rand(T, n)
+            b = MtlArray(a)
+            val = rand(T)
+            @metal threads=n global_kernel(b, f, val)
+            @test f.(a, val) ≈ Array(b)
+        end
+
+        function local_kernel(a, op, arg::T) where T
+            i = thread_position_in_grid_1d()
+            b = MtlThreadGroupArray(T, 128)
+            #b[i] = a[i]
+            val = Metal.atomic_load_explicit(pointer(a, i))
+            Metal.atomic_store_explicit(pointer(b, i), val)
+            Metal.atomic_fetch_op_explicit(pointer(b, i), op, arg)
+            #a[i] = b[i]
+            val = Metal.atomic_load_explicit(pointer(b, i))
+            Metal.atomic_store_explicit(pointer(a, i), val)
+            return
+        end
+
+        @testset for T in (Int32, UInt32)
+            a = rand(T, n)
+            b = MtlArray(a)
+            val = rand(T)
+            @metal threads=n local_kernel(b, f, val)
+            @test f.(a, val) ≈ Array(b)
+        end
+    end
+end
+
+@testset "high-level" begin
+    # NOTE: this doesn't test threadgroup atomics, as those are assumed to have been
+    #       covered by the low-level tests above, but only the atomic macro functionality.
+
+    @testset "load" begin
+        types = [Int32, UInt32]
+        metal_version() >= v"3.0" && append!(types, [Float32])
+
+        function kernel(a, b)
+            i = thread_position_in_grid_1d()
+            a[i] = Metal.@atomic b[i]
+            return
+        end
+
+        @testset for T in types
+            a = Metal.zeros(T, n)
+            b = MtlArray(rand(T, n))
+            @metal threads=n kernel(a, b)
+            @test Array(a) == Array(b)
+        end
+    end
+
+    @testset "store" begin
+        types = [Int32, UInt32]
+        metal_version() >= v"3.0" && append!(types, [Float32])
+
+        function kernel(a, b)
+            i = thread_position_in_grid_1d()
+            val = b[i]
+            Metal.@atomic a[i] = val
+            return
+        end
+
+        @testset for T in types
+            a = Metal.zeros(T, n)
+            b = MtlArray(rand(T, n))
+            @metal threads=n kernel(a, b)
+            @test Array(a) == Array(b)
+        end
+    end
+
+    @testset "add" begin
+        types = [Int32, UInt32]
+        metal_version() >= v"3.0" && append!(types, [Float32])
+
+        function kernel(a)
+            Metal.@atomic a[1] = a[1] + 1
+            Metal.@atomic a[1] += 1
+            return
+        end
+
+        @testset for T in types
+            a = Metal.zeros(T)
+            @metal threads=n kernel(a)
+            @test Array(a)[1] == 2*n
+        end
+    end
+
+    @testset "sub" begin
+        types = [Int32, UInt32]
+        metal_version() >= v"3.0" && append!(types, [Float32])
+
+        function kernel(a)
+            Metal.@atomic a[1] = a[1] - 1
+            Metal.@atomic a[1] -= 1
+            return
+        end
+
+        @testset for T in types
+            a = MtlArray(T[2n])
+            @metal threads=n kernel(a)
+            @test Array(a)[1] == 0
+        end
+    end
+end
+
+end
diff --git a/test/execution.jl b/test/execution.jl
index 771d60b95..50ad3bd29 100644
--- a/test/execution.jl
+++ b/test/execution.jl
@@ -235,13 +235,13 @@ end
     end
 
     @testset "unused mutable types" begin
-        function kernel(ptr, T)
+        function kernel(T, ptr)
             unsafe_store!(ptr, one(T))
             return
         end
 
         a = MtlArray([0])
-        @metal kernel(pointer(a), Int)
+        @metal kernel(Int, pointer(a))
         @test Array(a)[] == 1
     end
 end