diff --git a/Package.swift b/Package.swift index 688c0877..edfd3347 100644 --- a/Package.swift +++ b/Package.swift @@ -22,24 +22,12 @@ let package = Package( .library(name: "MLXFFT", targets: ["MLXFFT"]), .library(name: "MLXLinalg", targets: ["MLXLinalg"]), .library(name: "MLXFast", targets: ["MLXFast"]), - - // build support & back end - .plugin( - name: "PrepareMetalShaders", - targets: ["PrepareMetalShaders"] - ), ], dependencies: [ // for Complex type .package(url: "https://github.com/apple/swift-numerics", from: "1.0.0") ], targets: [ - // plugin to help build the metal shaders - .plugin( - name: "PrepareMetalShaders", - capability: .buildTool(), - path: "Plugins/PrepareMetalShaders" - ), .target( name: "Cmlx", exclude: [ @@ -133,10 +121,7 @@ let package = Package( .linkedFramework("Foundation"), .linkedFramework("Metal"), .linkedFramework("Accelerate"), - ], - - // run the plugin to build the metal shaders - plugins: [.plugin(name: "PrepareMetalShaders")] + ] ), .testTarget( name: "CmlxTests", diff --git a/Plugins/PrepareMetalShaders/main.swift b/Plugins/PrepareMetalShaders/main.swift deleted file mode 100644 index 107cbaf3..00000000 --- a/Plugins/PrepareMetalShaders/main.swift +++ /dev/null @@ -1,259 +0,0 @@ -// Copyright © 2024 Apple Inc. - -import Foundation -import PackagePlugin - -let debug = false - -private func log(_ message: @autoclosure () -> String) { - if debug { - print(message()) - } -} - -/// Prepare the metal shaders source for compilation. -/// -/// The metal shaders (`mlx/backend/metal/kernels`) include headers as: -/// -/// ``` -/// #include "mlx/backend/metal/kernels/...h" -/// ``` -/// -/// but this doesn't work with the swiftpm build -- there is no way to set a header search path and this -/// absolute path won't work. -/// -/// This plugin makes a copy of the shaders and modifies them to include the files using a relative path. -/// This code is specialized to the mlx/metal code but could be adapted elsewhere with changes. -@main -struct PrepareMetalShaders: BuildToolPlugin { - - /// pattern to rewrite - private let include = try! Regex("#include \"mlx/backend/metal/kernels/([^\"]*)\"") - - // see Source/Cmlx/mlx/mlx/backend/metal/kernels/CMakeLists.txt - let kernels: Set = [ - "arg_reduce.metal", - "conv.metal", - "gemv.metal", - "random.metal", - "rms_norm.metal", - "layer_norm.metal", - "rope.metal", - "scaled_dot_product_attention.metal", - ] - - func transformIncludes(url: URL) throws { - let contents = try String(contentsOf: url, encoding: .utf8) - - let new: String - - // need to transform - // #include "mlx/backend/metal/kernels/steel/gemm/transforms.h" - // - // into - // #include "../../steel/gemm/transforms.h" - - let pathUnderKernels = url.pathComponents.drop { $0 != "output" }.dropLast() - - let rootPath = - Array(repeating: "..", count: pathUnderKernels.count - 1).joined(separator: "/") - + ((pathUnderKernels.count - 1 == 0) ? "" : "/") - - new = - contents - .replacing(include, with: { "#include \"\(rootPath)\($0[1].substring ?? 
"")\"" }) - - try new.write(to: url, atomically: true, encoding: .utf8) - } - - func collectFiles(from directory: URL) throws -> [String: Date] { - var result = [String: Date]() - - let prefixCount = directory.pathComponents.count - - if let enumerator = FileManager.default.enumerator( - at: directory, - includingPropertiesForKeys: [.isRegularFileKey, .contentModificationDateKey], - options: [.skipsHiddenFiles, .skipsPackageDescendants]) - { - - for case let url as URL in enumerator { - let resourceValues = try url.resourceValues(forKeys: [ - .isRegularFileKey, .contentModificationDateKey, - ]) - let isRegularFile = resourceValues.isRegularFile ?? false - - // ignore directories and CMakeLists.txt - guard isRegularFile else { - continue - } - guard url.lastPathComponent != "CMakeLists.txt" else { - continue - } - - if url.pathExtension == "h" || kernels.contains(url.lastPathComponent) { - // ok - } else { - continue - } - - let modDate = resourceValues.contentModificationDate ?? Date() - - // these will be moved to the top level (see below in building) - if url.pathExtension == "metal" { - result[url.lastPathComponent] = modDate - } else { - let path = url.pathComponents.dropFirst(prefixCount).joined(separator: "/") - result[path] = modDate - } - } - } - - return result - } - - func shouldCopy(from source: URL, to destination: URL) throws -> Bool { - // directory does not exist - if !FileManager.default.fileExists(atPath: destination.path(percentEncoded: false)) { - log("\(destination) does not exist -- copy source metal files") - return true - } - - let sourceFiles = try collectFiles(from: source) - if let destinationFiles = try? collectFiles(from: destination) { - - log("source: \(source)") - log("destination: \(destination)") - for (path, date) in sourceFiles { - if let destinationDate = destinationFiles[path] { - log("\(path): \(date) vs \(destinationDate)") - } else { - log("\(path): \(date) vs MISSING") - } - } - - // if there are missing files in the destination - if Set(sourceFiles.keys) != Set(destinationFiles.keys) { - print( - "files in \(source) are different than in \(destination) -- copy source metal files" - ) - return true - } - - // or if there are newer files in the source - for (path, date) in sourceFiles { - if let destinationDate = destinationFiles[path] { - if destinationDate < date { - print("\(path) in \(destination) is out of date") - return true - } - } - } - - print("metal files in \(destination) are up to date") - return false - - } else { - // no destination - return true - } - } - - func createBuildCommands(context: PluginContext, target: Target) throws -> [Command] { - var commands = [Command]() - - let sourcePath = target.directory.appending(["mlx", "mlx", "backend", "metal", "kernels"]) - let source = URL(fileURLWithPath: sourcePath.string) - - let destinationPath = context.pluginWorkDirectory.appending(["output"]) - let destination = URL(fileURLWithPath: destinationPath.string) - - // only do the work if the directory doesn't exist - if try shouldCopy(from: source, to: destination) { - // remove the destination directory first in case files have been removed - try? 
FileManager.default.removeItem(at: destination) - - // copy the files from the source area - try FileManager.default.createDirectory( - at: destination.deletingLastPathComponent(), withIntermediateDirectories: true) - try FileManager.default.copyItem(at: source, to: destination) - - // the builder won't find metal kernels in subdirectories, so move them to the top - if let enumerator = FileManager.default.enumerator( - at: destination, includingPropertiesForKeys: [.isRegularFileKey], - options: [.skipsHiddenFiles, .skipsPackageDescendants]) - { - for case let url as URL in enumerator { - let isRegularFile = - try url.resourceValues(forKeys: [.isRegularFileKey]).isRegularFile ?? false - guard isRegularFile else { - continue - } - - if url.deletingLastPathComponent().lastPathComponent == "output" { - // still in the top directory - continue - } - - if url.pathExtension == "metal" { - try FileManager.default.moveItem( - at: url, to: destination.appending(component: url.lastPathComponent)) - } - } - } - - // remove any kernels that are not in the list - if let enumerator = FileManager.default.enumerator( - at: destination, includingPropertiesForKeys: [.isRegularFileKey], - options: [.skipsHiddenFiles, .skipsPackageDescendants]) - { - for case let url as URL in enumerator { - let isRegularFile = - try url.resourceValues(forKeys: [.isRegularFileKey]).isRegularFile ?? false - guard isRegularFile else { - continue - } - - if url.pathExtension == "h" || kernels.contains(url.lastPathComponent) { - // keep it - print("keeping \(url.lastPathComponent)") - } else { - print("removing \(url.lastPathComponent)") - try FileManager.default.removeItem(at: url) - } - } - } - - // foreach file, transform the #includes - if let enumerator = FileManager.default.enumerator( - at: destination, includingPropertiesForKeys: [.isRegularFileKey], - options: [.skipsHiddenFiles, .skipsPackageDescendants]) - { - for case let url as URL in enumerator { - let isRegularFile = - try url.resourceValues(forKeys: [.isRegularFileKey]).isRegularFile ?? false - guard isRegularFile else { - continue - } - - if url.lastPathComponent == "CMakeLists.txt" { - try FileManager.default.removeItem(at: url) - continue - } - - try transformIncludes(url: url) - } - } - } - - // a prebuild command to inject the output directory so swiftpm knows to pick it up - commands.append( - .prebuildCommand( - displayName: "Install Headers", - executable: .init("/bin/echo"), - arguments: [], - outputFilesDirectory: Path(destination.path(percentEncoded: false)))) - - return commands - } -} diff --git a/Source/Cmlx/mlx-generated/metal/arange.h b/Source/Cmlx/mlx-generated/metal/arange.h new file mode 100644 index 00000000..5448fe9a --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/arange.h @@ -0,0 +1,9 @@ +// Copyright © 2023-2024 Apple Inc. +template +[[kernel]] void arange( + constant const T& start, + constant const T& step, + device T* out, + uint index [[thread_position_in_grid]]) { + out[index] = start + index * step; +} diff --git a/Source/Cmlx/mlx-generated/metal/arg_reduce.metal b/Source/Cmlx/mlx-generated/metal/arg_reduce.metal new file mode 100644 index 00000000..98924c32 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/arg_reduce.metal @@ -0,0 +1,180 @@ +// Copyright © 2023 Apple Inc. 
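//
// Overview of arg_reduce.metal (summary of the kernel defined below): each
// threadgroup computes one argmin/argmax output. Every thread walks the
// reduction axis N_READS values at a time while tracking its best
// (index, value) pair; partial results are then combined with simdgroup
// shuffles and a small threadgroup reduction before lane 0 of simdgroup 0
// writes the winning index.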
+ +#include + +#include "utils.h" + +using namespace metal; + +template +struct IndexValPair { + uint32_t index; + U val; +}; + +template +struct ArgMin { + static constexpr constant U init = Limits::max; + + IndexValPair reduce(IndexValPair best, IndexValPair current) { + if (best.val > current.val || + (best.val == current.val && best.index > current.index)) { + return current; + } else { + return best; + } + } + + template + IndexValPair + reduce_many(IndexValPair best, thread U* vals, uint32_t offset) { + for (int i = 0; i < N; i++) { + if (vals[i] < best.val) { + best.val = vals[i]; + best.index = offset + i; + } + } + return best; + } +}; + +template +struct ArgMax { + static constexpr constant U init = Limits::min; + + IndexValPair reduce(IndexValPair best, IndexValPair current) { + if (best.val < current.val || + (best.val == current.val && best.index > current.index)) { + return current; + } else { + return best; + } + } + + template + IndexValPair + reduce_many(IndexValPair best, thread U* vals, uint32_t offset) { + for (int i = 0; i < N; i++) { + if (vals[i] > best.val) { + best.val = vals[i]; + best.index = offset + i; + } + } + return best; + } +}; + +template +IndexValPair simd_shuffle_down(IndexValPair data, uint16_t delta) { + return IndexValPair{ + simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta)}; +} + +template +[[kernel]] void arg_reduce_general( + const device T* in [[buffer(0)]], + device uint32_t* out [[buffer(1)]], + const constant int* shape [[buffer(2)]], + const constant size_t* in_strides [[buffer(3)]], + const constant size_t* out_strides [[buffer(4)]], + const constant size_t& ndim [[buffer(5)]], + const constant size_t& axis_stride [[buffer(6)]], + const constant size_t& axis_size [[buffer(7)]], + uint gid [[thread_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_size [[threads_per_simdgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + // Shapes and strides *do not* contain the reduction axis. The reduction size + // and stride are provided in axis_stride and axis_size. + // + // Note: in shape == out shape with this convention. + // + // The sketch of the kernel is as follows. + // 1. Launch prod(shape) * thread_group_size threads. + // 2. Loop ceildiv(axis_size / lsize) times + // 3. Read input values + // 4. Reduce among them and go to 3 + // 4. Reduce in each simd_group + // 6. Write in the thread local memory + // 6. Reduce them across thread group + // 7. Write the output without need for atomic + Op op; + + // Compute the input/output index. There is one beginning and one output for + // the whole threadgroup. + auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim); + auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim); + + IndexValPair best{0, Op::init}; + + threadgroup IndexValPair local_data[32]; + + // Loop over the reduction axis in lsize*N_READS buckets + for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) { + // Read the current value + uint32_t current_index = r * lsize * N_READS + lid * N_READS; + uint32_t offset = current_index; + const device T* current_in = in + in_idx + current_index * axis_stride; + T vals[N_READS]; + for (int i = 0; i < N_READS; i++) { + vals[i] = (current_index < axis_size) ? 
*current_in : T(Op::init); + current_index++; + current_in += axis_stride; + } + best = op.template reduce_many(best, vals, offset); + } + // At this point we have reduced the axis into thread group best values so we + // need to reduce across the thread group. + + // First per simd reduction. + for (uint offset = simd_size / 2; offset > 0; offset /= 2) { + IndexValPair neighbor = simd_shuffle_down(best, offset); + best = op.reduce(best, neighbor); + } + + // Write to the threadgroup memory + if (simd_lane_id == 0) { + local_data[simd_group_id] = best; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id != 0) { + return; + } + + // Read the appropriate value from local data and perform one simd reduction + uint simd_groups = ceildiv(lsize, simd_size); + if (simd_lane_id < simd_groups) { + best = local_data[simd_lane_id]; + } + for (uint offset = simd_size / 2; offset > 0; offset /= 2) { + IndexValPair neighbor = simd_shuffle_down(best, offset); + best = op.reduce(best, neighbor); + } + + // Finally write the output + if (lid == 0) { + out[out_idx] = best.index; + } +} + +// clang-format off +#define instantiate_arg_reduce(name, itype) \ + instantiate_kernel( \ + "argmin_" #name, arg_reduce_general, itype, ArgMin) \ + instantiate_kernel( \ + "argmax_" #name, arg_reduce_general, itype, ArgMax) + +instantiate_arg_reduce(bool_, bool) +instantiate_arg_reduce(uint8, uint8_t) +instantiate_arg_reduce(uint16, uint16_t) +instantiate_arg_reduce(uint32, uint32_t) +instantiate_arg_reduce(uint64, uint64_t) +instantiate_arg_reduce(int8, int8_t) +instantiate_arg_reduce(int16, int16_t) +instantiate_arg_reduce(int32, int32_t) +instantiate_arg_reduce(int64, int64_t) +instantiate_arg_reduce(float16, half) +instantiate_arg_reduce(float32, float) +instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on diff --git a/Source/Cmlx/mlx-generated/metal/atomic.h b/Source/Cmlx/mlx-generated/metal/atomic.h new file mode 100644 index 00000000..93952c2c --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/atomic.h @@ -0,0 +1,345 @@ +// Copyright © 2023 Apple Inc. 
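//
// atomic.h: mlx_atomic<T> uses Metal's native atomics for the handful of types
// the hardware supports directly (is_metal_atomic; e.g. int, uint, ulong, float)
// and emulates everything else in the "Custom atomics" section below by packing
// small values into a uint (packing_size = sizeof(uint) / sizeof(T)) and
// updating them with compare-exchange loops.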
+ +#pragma once + +#include +#include + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// Atomic utils +/////////////////////////////////////////////////////////////////////////////// + +#pragma METAL internals : enable +template +constexpr constant bool is_metal_atomic = _disjunction< + is_same, + is_same, + is_same, + is_same>::value; + +#pragma METAL internals : disable + +template +struct mlx_atomic { + atomic val; +}; + +template +struct mlx_atomic>> { + atomic val; +}; + +/////////////////////////////////////////////////////////////////////////////// +// Native metal atomics +/////////////////////////////////////////////////////////////////////////////// + +template , bool> = true> +METAL_FUNC T +mlx_atomic_load_explicit(device mlx_atomic* object, size_t offset) { + return atomic_load_explicit(&(object[offset].val), memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void +mlx_atomic_store_explicit(device mlx_atomic* object, T val, size_t offset) { + atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_and_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_or_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_min_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_max_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_add_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_mul_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + T expected = mlx_atomic_load_explicit(object, offset); + while (!mlx_atomic_compare_exchange_weak_explicit( + object, &expected, val * expected, offset)) { + } +} + +template , bool> = true> +METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit( + device mlx_atomic* object, + thread T* expected, + T val, + size_t offset) { + return atomic_compare_exchange_weak_explicit( + &(object[offset].val), + expected, + val, + memory_order_relaxed, + memory_order_relaxed); +} + +// Specialization for float since it does not atomic_fetch_min_explicit +template <> +METAL_FUNC void mlx_atomic_fetch_min_explicit( + device mlx_atomic* object, + float val, + size_t offset) { + float expected = mlx_atomic_load_explicit(object, offset); + while (val < expected) { + if (mlx_atomic_compare_exchange_weak_explicit( + object, &expected, val, offset)) { + return; + } + } +} + +// Specialization for float since it does not atomic_fetch_max_explicit +template <> +METAL_FUNC void mlx_atomic_fetch_max_explicit( + device mlx_atomic* object, + float val, + size_t offset) { + float expected = mlx_atomic_load_explicit(object, offset); + while (val > expected) { + if (mlx_atomic_compare_exchange_weak_explicit( + object, &expected, 
val, offset)) { + return; + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Custom atomics +/////////////////////////////////////////////////////////////////////////////// + +namespace { + +template +constexpr constant uint packing_size = sizeof(uint) / sizeof(T); + +template +union uint_or_packed { + T val[packing_size]; + uint bits; +}; + +template +struct mlx_atomic_update_helper { + uint operator()(uint_or_packed init, T update, size_t elem_offset) { + Op op; + init.val[elem_offset] = op(update, init.val[elem_offset]); + return init.bits; + } +}; + +template +METAL_FUNC void mlx_atomic_update_and_store( + device mlx_atomic* object, + T update, + size_t offset) { + size_t pack_offset = offset / packing_size; + size_t elem_offset = offset % packing_size; + + mlx_atomic_update_helper helper; + uint_or_packed expected; + expected.bits = + atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed); + + while (Op::condition(update, expected.val[elem_offset]) && + !mlx_atomic_compare_exchange_weak_explicit( + object, + &(expected.bits), + helper(expected, update, elem_offset), + pack_offset)) { + } +} + +template +struct __None { + static bool condition(T a, T b) { +#pragma unused(a) +#pragma unused(b) + return true; + } + + T operator()(T a, T b) { +#pragma unused(b) + return a; + } +}; + +template +struct __Add { + static bool condition(T a, T b) { +#pragma unused(a) +#pragma unused(b) + return true; + } + + T operator()(T a, T b) { + return a + b; + } +}; + +template +struct __Mul { + static bool condition(T a, T b) { +#pragma unused(a) + return b != 0; + } + + T operator()(T a, T b) { + return a * b; + } +}; + +template +struct __Max { + static bool condition(T a, T b) { + return a > b; + } + + T operator()(T a, T b) { + return max(a, b); + } +}; + +template +struct __Min { + static bool condition(T a, T b) { + return a < b; + } + + T operator()(T a, T b) { + return min(a, b); + } +}; + +} // namespace + +template , bool> = true> +METAL_FUNC T +mlx_atomic_load_explicit(device mlx_atomic* object, size_t offset) { + size_t pack_offset = offset / sizeof(T); + size_t elem_offset = offset % sizeof(T); + uint_or_packed packed_val; + packed_val.bits = + atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed); + return packed_val.val[elem_offset]; +} + +template , bool> = true> +METAL_FUNC void +mlx_atomic_store_explicit(device mlx_atomic* object, T val, size_t offset) { + mlx_atomic_update_and_store>(object, val, offset); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_and_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + size_t pack_offset = offset / packing_size; + size_t elem_offset = offset % packing_size; + uint_or_packed identity; + identity.bits = __UINT32_MAX__; + identity.val[elem_offset] = val; + + atomic_fetch_and_explicit( + &(object[pack_offset].val), identity.bits, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_or_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + size_t pack_offset = offset / packing_size; + size_t elem_offset = offset % packing_size; + uint_or_packed identity; + identity.bits = 0; + identity.val[elem_offset] = val; + + atomic_fetch_or_explicit( + &(object[pack_offset].val), identity.bits, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_min_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + mlx_atomic_update_and_store>(object, val, offset); 
+} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_max_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + mlx_atomic_update_and_store>(object, val, offset); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_add_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + mlx_atomic_update_and_store>(object, val, offset); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_mul_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + mlx_atomic_update_and_store>(object, val, offset); +} + +template , bool> = true> +METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit( + device mlx_atomic* object, + thread uint* expected, + uint val, + size_t offset) { + return atomic_compare_exchange_weak_explicit( + &(object[offset].val), + expected, + val, + memory_order_relaxed, + memory_order_relaxed); +} diff --git a/Source/Cmlx/mlx-generated/metal/bf16.h b/Source/Cmlx/mlx-generated/metal/bf16.h new file mode 100644 index 00000000..e8c22afb --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/bf16.h @@ -0,0 +1,317 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +using namespace metal; + +#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310) + +typedef bfloat bfloat16_t; + +#else + +///////////////////////////////////////////////////////////////////////////// +// Helpers +///////////////////////////////////////////////////////////////////////////// + +constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) { + // Check for nan + if ((as_type(x) & ~_fp_encoding_traits::sign_mask) > + _fp_encoding_traits::inf_mask) { + return uint16_t(as_type(0x7FC0)); + } + // Take bits + uint32_t float_bits = as_type(x); + + // Round to nearest even + float_bits += ((float_bits >> 16) & 1) + as_type(0x7FFF); + + // Take upper 16 bits + return float_bits >> 16; +} + +constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) { + // Upper 16 bits are the data and lower 16 bits are 0s + return as_type((uint32_t)x << 16); +} + +struct _MLX_BFloat16; + +template +static constexpr constant bool can_convert_to_bfloat = + !is_same_v && is_convertible_v; + +template +static constexpr constant bool can_convert_from_bfloat = + !is_same_v && is_convertible_v; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat struct +///////////////////////////////////////////////////////////////////////////// + +struct _MLX_BFloat16 { + ///////////////////////////////////////////////////////////////////////////// + // Constructors + uint16_t bits_; + _MLX_BFloat16() thread = default; + _MLX_BFloat16() threadgroup = default; + _MLX_BFloat16() device = default; + _MLX_BFloat16() constant = default; + + struct bits_to_bfloat_struct {}; + static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() { + return bits_to_bfloat_struct(); + } + constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct) + : bits_(bits) {} + + ///////////////////////////////////////////////////////////////////////////// + // Conversions to bfloat + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) thread + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) device + : 
bits_(float_to_bfloat_bits(static_cast(x))) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) constant + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + ///////////////////////////////////////////////////////////////////////////// + // Conversions from bfloat + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const thread { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const threadgroup { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const device { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const constant { + return static_cast(bfloat_bits_to_float(bits_)); + } +}; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat operators +///////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////// +// Unary ops +constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) { + return -static_cast(x); +} + +///////////////////////////////////////////////////////////////////////////// +// Binary operators +#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ + constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +///////////////////////////////////////////////////////////////////////////// +// Arithmetic Operators +#define bfloat_binop(_op_, _operator_) \ + bfloat_binop_base( \ + _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(_op_, _operator_, float, float, float); \ + bfloat_binop_helper(_op_, _operator_, float, half, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); + +bfloat_binop(+, operator+); +bfloat_binop(-, operator-); +bfloat_binop(*, operator*); +bfloat_binop(/, operator/); + +///////////////////////////////////////////////////////////////////////////// +// Comparison ops +#define bfloat_compop(__op__, __operator__) \ + bfloat_binop_base( \ + __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(__op__, __operator__, bool, float, float); \ + bfloat_binop_helper(__op__, __operator__, bool, half, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); + +bfloat_compop(>, operator>); +bfloat_compop(<, 
operator<); +bfloat_compop(>=, operator>=); +bfloat_compop(<=, operator<=); +bfloat_compop(==, operator==); +bfloat_compop(!=, operator!=); + +#undef bfloat_compop +#undef bfloat_binop_base +#undef bfloat_binop_helper +#undef bfloat_binop + +///////////////////////////////////////////////////////////////////////////// +// Inplace Operators +#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ + addr_space _MLX_BFloat16& lhs, itype rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } \ + constexpr METAL_FUNC addr_space itype& __operator__( \ + addr_space itype& lhs, _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } + +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \ + bfloat_inplace_op_helper(__op__, __operator__, itype, device); \ + bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \ + bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup); + +#define bfloat_inplace_op(itype) \ + bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \ + bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \ + bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \ + bfloat_inplace_op_addr_space_helper(/, operator/=, itype); + +bfloat_inplace_op(float); +bfloat_inplace_op(half); +bfloat_inplace_op(int16_t); +bfloat_inplace_op(int32_t); +bfloat_inplace_op(int64_t); +bfloat_inplace_op(uint16_t); +bfloat_inplace_op(uint32_t); +bfloat_inplace_op(uint64_t); + +#undef bfloat_inplace_op_helper +#undef bfloat_inplace_op_addr_space_helper +#undef bfloat_inplace_op + +#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ + addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } + +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \ + bfloat_inplace_op_helper(__op__, __operator__, device); \ + bfloat_inplace_op_helper(__op__, __operator__, thread); \ + bfloat_inplace_op_helper(__op__, __operator__, threadgroup); + +bfloat_inplace_op_addr_space_helper(+, operator+=); +bfloat_inplace_op_addr_space_helper(-, operator-=); +bfloat_inplace_op_addr_space_helper(*, operator*=); +bfloat_inplace_op_addr_space_helper(/, operator/=); + +#undef bfloat_inplace_op_helper +#undef bfloat_inplace_op_addr_space_helper + +///////////////////////////////////////////////////////////////////////////// +// Bfloat typedef +///////////////////////////////////////////////////////////////////////////// + +typedef struct _MLX_BFloat16 bfloat16_t; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat numeric limits +///////////////////////////////////////////////////////////////////////////// + +#pragma METAL internals : enable + +namespace metal { + +template <> +struct _numeric_limits_impl : _fp_numeric_limits_impl_base { + static constexpr constant int digits = 8; + static constexpr constant int digits10 = 2; + static constexpr constant int max_digits10 = 4; + static constexpr constant int radix = 2; + static constexpr constant int min_exponent = -125; + static constexpr constant int min_exponent10 = -37; + static constexpr constant int max_exponent = 128; + static constexpr constant int max_exponent10 = 38; + + static constexpr bfloat16_t min() { + return _MLX_BFloat16(0x0080, _MLX_BFloat16::bits_to_bfloat()); + } + 
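// The raw bit patterns used in this specialization follow the bfloat16 layout
// (1 sign, 8 exponent, 7 mantissa bits): 0x0080 = 2^-126 (smallest normal),
// 0xFF7F / 0x7F7F = most negative / largest finite value (about +/-3.39e38),
// 0x3C00 = 2^-7 (epsilon), 0x3F00 = 0.5 (round_error), 0x7F80 = infinity,
// 0x7FC0 = quiet NaN and 0x0001 = 2^-133 (smallest subnormal).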
static constexpr bfloat16_t lowest() { + return _MLX_BFloat16(0xFF7F, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t max() { + return _MLX_BFloat16(0x7F7F, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t epsilon() { + return _MLX_BFloat16(0x3C00, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t round_error() { + return _MLX_BFloat16(0x3F00, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t infinity() { + return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t quiet_NaN() { + return _MLX_BFloat16(0x7FC0, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t signaling_NaN() { + return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t denorm_min() { + return _MLX_BFloat16(0x0001, _MLX_BFloat16::bits_to_bfloat()); + } +}; + +METAL_FUNC bool isnan(_MLX_BFloat16 x) { + return x != x; +} + +} // namespace metal + +#pragma METAL internals : disable + +#endif + +#include "bf16_math.h" diff --git a/Source/Cmlx/mlx-generated/metal/bf16_math.h b/Source/Cmlx/mlx-generated/metal/bf16_math.h new file mode 100644 index 00000000..79e1ef15 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/bf16_math.h @@ -0,0 +1,394 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include "bf16.h" + +/////////////////////////////////////////////////////////////////////////////// +// Metal math for bfloat16 +/////////////////////////////////////////////////////////////////////////////// + +/* + +Following the Metal Shading Language Specification (Metal 3.1) + +"bfloat is an extended itypeing point type that only allows implicit conversion + to a type of greater itypeing point rank. While bfloat can be implicitly + converted to itype, it cannot be implicitly converted to half, and neither + itype nor half can be implicitly converted to bfloat." 
+ +Further, as far as I can tell, the stdlib math/simd functions are not defined +for bfloat and calling with an argument of type bfloat will result in that +argument getting implicitly converted to itype which then returns an output +that is (likely) a itype which cannot be implicitly converted into a bfloat + +This leads to situations where +bfloat a = 5.0bf; +bfloat b = metal::abs(a); // this will throw an error since abs return itype +bfloat c = static_cast(metal::abs(a)); // this is fine + +For the moment, I will be adding overloaded instantiations of the math +functions to accordingly automatically handle the casting + +*/ + +#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \ + \ + METAL_FUNC otype abs(itype x) { \ + return static_cast(__metal_fabs(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype acos(itype x) { \ + return static_cast(__metal_acos(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype acosh(itype x) { \ + return static_cast(__metal_acosh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype asin(itype x) { \ + return static_cast(__metal_asin(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype asinh(itype x) { \ + return static_cast(__metal_asinh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype atan(itype y_over_x) { \ + return static_cast( \ + __metal_atan(static_cast(y_over_x), mfast)); \ + } \ + METAL_FUNC otype atan2(itype y, itype x) { \ + return static_cast( \ + __metal_atan2(static_cast(y), static_cast(x), mfast)); \ + } \ + METAL_FUNC otype atanh(itype x) { \ + return static_cast(__metal_atanh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype ceil(itype x) { \ + return static_cast(__metal_ceil(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype cos(itype x) { \ + return static_cast(__metal_cos(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype cosh(itype x) { \ + return static_cast(__metal_cosh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype cospi(itype x) { \ + return static_cast(__metal_cospi(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype divide(itype x, itype y) { \ + return static_cast( \ + __metal_divide(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype exp(itype x) { \ + return static_cast(__metal_exp(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype exp10(itype x) { \ + return static_cast(__metal_exp10(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype exp2(itype x) { \ + return static_cast(__metal_exp2(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype fabs(itype x) { \ + return static_cast(__metal_fabs(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype fdim(itype x, itype y) { \ + ctype t = static_cast(x - y); \ + return static_cast(select(t, ctype(0), t < ctype(0) || x == y)); \ + } \ + METAL_FUNC otype floor(itype x) { \ + return static_cast(__metal_floor(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype fma(itype x, itype y, itype z) { \ + return static_cast(__metal_fma( \ + static_cast(x), static_cast(y), static_cast(z))); \ + } \ + METAL_FUNC otype fmax(itype x, itype y) { \ + return static_cast( \ + __metal_fmax(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype fmax3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmax3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmedian3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype fmin(itype x, itype y) { \ + return static_cast( \ + 
__metal_fmin(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype fmin3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmin3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype fmod(itype x, itype y) { \ + return static_cast( \ + __metal_fmod(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype fract(itype x) { \ + return static_cast(__metal_fract(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype frexp(itype x, thread int& exp) { \ + return static_cast(__metal_frexp(static_cast(x), &exp)); \ + } \ + METAL_FUNC otype ldexp(itype x, int k) { \ + return static_cast(__metal_ldexp(static_cast(x), k, mfast)); \ + } \ + METAL_FUNC otype log(itype x) { \ + return static_cast(__metal_log(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype log10(itype x) { \ + return static_cast(__metal_log10(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype log2(itype x) { \ + return static_cast(__metal_log2(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype max(itype x, itype y) { \ + return static_cast( \ + __metal_fmax(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype max3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmax3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype median3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmedian3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype min(itype x, itype y) { \ + return static_cast( \ + __metal_fmin(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype min3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmin3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype nextafter(itype x, itype y) { \ + return static_cast( \ + __metal_nextafter(static_cast(x), static_cast(y))); \ + } \ + METAL_FUNC otype pow(itype x, itype y) { \ + return static_cast( \ + __metal_pow(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype powr(itype x, itype y) { \ + return static_cast( \ + __metal_powr(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype rint(itype x) { \ + return static_cast(__metal_rint(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype round(itype x) { \ + return static_cast(__metal_round(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype rsqrt(itype x) { \ + return static_cast(__metal_rsqrt(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype sin(itype x) { \ + return static_cast(__metal_sin(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype sinh(itype x) { \ + return static_cast(__metal_sinh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype sinpi(itype x) { \ + return static_cast(__metal_sinpi(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype sqrt(itype x) { \ + return static_cast(__metal_sqrt(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype tan(itype x) { \ + return static_cast(__metal_tan(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype tanh(itype x) { \ + return static_cast(__metal_tanh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype tanpi(itype x) { \ + return static_cast(__metal_tanpi(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype trunc(itype x) { \ + return static_cast(__metal_trunc(static_cast(x), mfast)); \ + } + +namespace metal { + +instantiate_metal_math_funcs( + bfloat16_t, + bfloat16_t, + float, + __METAL_MAYBE_FAST_MATH__); + +namespace fast { + +instantiate_metal_math_funcs( + bfloat16_t, + 
bfloat16_t, + float, + __METAL_FAST_MATH__); + +} // namespace fast + +namespace precise { + +instantiate_metal_math_funcs( + bfloat16_t, + bfloat16_t, + float, + __METAL_PRECISE_MATH__); + +} // namespace precise + +} // namespace metal + +/////////////////////////////////////////////////////////////////////////////// +// Metal simd for bfloat16 +/////////////////////////////////////////////////////////////////////////////// + +#define instantiate_metal_simd_comm_funcs( \ + itype, otype, ctype, itype_to_ctype, ctype_to_otype) \ + \ + METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \ + return ctype_to_otype( \ + __metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \ + } \ + \ + METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \ + return ctype_to_otype( \ + __metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_and_fill_down( \ + itype data, itype filling_data, ushort delta, ushort modulo) { \ + return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \ + itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_and_fill_down( \ + itype data, itype filling_data, ushort delta) { \ + return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \ + itype_to_ctype(data), \ + itype_to_ctype(filling_data), \ + delta, \ + __metal_get_simdgroup_size(ushort()))); \ + } \ + \ + METAL_FUNC otype simd_shuffle_and_fill_up( \ + itype data, itype filling_data, ushort delta, ushort modulo) { \ + return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \ + itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_and_fill_up( \ + itype data, itype filling_data, ushort delta) { \ + return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \ + itype_to_ctype(data), \ + itype_to_ctype(filling_data), \ + delta, \ + __metal_get_simdgroup_size(ushort()))); \ + } \ + \ + METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_down(itype_to_ctype(data), delta)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_up(itype_to_ctype(data), delta)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \ + } + +#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \ + \ + METAL_FUNC otype simd_max(itype data) { \ + return static_cast(__metal_simd_max(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_min(itype data) { \ + return static_cast(__metal_simd_min(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \ + return static_cast( \ + __metal_simd_prefix_exclusive_product(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \ + return static_cast( \ + __metal_simd_prefix_exclusive_sum(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \ + return static_cast( \ + 
__metal_simd_prefix_inclusive_product(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \ + return static_cast( \ + __metal_simd_prefix_inclusive_sum(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_product(itype data) { \ + return static_cast(__metal_simd_product(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_sum(itype data) { \ + return static_cast(__metal_simd_sum(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_xor(itype data) { \ + return static_cast(__metal_simd_xor(static_cast(data))); \ + } + +#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310) + +#define bfloat16_to_uint16(x) as_type(x) +#define uint16_to_bfloat16(x) as_type(x) + +#else + +#define bfloat16_to_uint16(x) x.bits_ +#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat()) + +#endif + +namespace metal { + +instantiate_metal_simd_comm_funcs( + bfloat16_t, + bfloat16_t, + uint16_t, + bfloat16_to_uint16, + uint16_to_bfloat16); +instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float); + +} // namespace metal diff --git a/Source/Cmlx/mlx-generated/metal/binary.h b/Source/Cmlx/mlx-generated/metal/binary.h new file mode 100644 index 00000000..d64488e9 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/binary.h @@ -0,0 +1,139 @@ +// Copyright © 2024 Apple Inc. + +template +[[kernel]] void binary_ss( + device const T* a, + device const T* b, + device U* c, + uint index [[thread_position_in_grid]]) { + c[index] = Op()(a[0], b[0]); +} + +template +[[kernel]] void binary_sv( + device const T* a, + device const T* b, + device U* c, + uint index [[thread_position_in_grid]]) { + c[index] = Op()(a[0], b[index]); +} + +template +[[kernel]] void binary_vs( + device const T* a, + device const T* b, + device U* c, + uint index [[thread_position_in_grid]]) { + c[index] = Op()(a[index], b[0]); +} + +template +[[kernel]] void binary_vv( + device const T* a, + device const T* b, + device U* c, + uint index [[thread_position_in_grid]]) { + c[index] = Op()(a[index], b[index]); +} + +template +[[kernel]] void binary_sv2( + device const T* a, + device const T* b, + device U* c, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + size_t offset = index.x + grid_dim.x * size_t(index.y); + c[offset] = Op()(a[0], b[offset]); +} + +template +[[kernel]] void binary_vs2( + device const T* a, + device const T* b, + device U* c, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + size_t offset = index.x + grid_dim.x * size_t(index.y); + c[offset] = Op()(a[offset], b[0]); +} + +template +[[kernel]] void binary_vv2( + device const T* a, + device const T* b, + device U* c, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + size_t offset = index.x + grid_dim.x * size_t(index.y); + c[offset] = Op()(a[offset], b[offset]); +} + +template +[[kernel]] void binary_g_nd1( + device const T* a, + device const T* b, + device U* c, + constant const size_t& a_stride, + constant const size_t& b_stride, + uint index [[thread_position_in_grid]]) { + auto a_idx = elem_to_loc_1(index, a_stride); + auto b_idx = elem_to_loc_1(index, b_stride); + c[index] = Op()(a[a_idx], b[b_idx]); +} + +template +[[kernel]] void binary_g_nd2( + device const T* a, + device const T* b, + device U* c, + constant const size_t a_strides[2], + constant const size_t b_strides[2], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto a_idx = 
elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + size_t out_idx = index.x + size_t(grid_dim.x) * index.y; + c[out_idx] = Op()(a[a_idx], b[b_idx]); +} + +template +[[kernel]] void binary_g_nd3( + device const T* a, + device const T* b, + device U* c, + constant const size_t a_strides[3], + constant const size_t b_strides[3], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + size_t out_idx = + index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z); + c[out_idx] = Op()(a[a_idx], b[b_idx]); +} + +template +[[kernel]] void binary_g( + device const T* a, + device const T* b, + device U* c, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + constant const int& ndim, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = elem_to_loc_2_nd( + {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); + auto xshape = shape[ndim - 1]; + size_t out_idx = + N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); + auto a_xstride = a_strides[ndim - 1]; + auto b_xstride = b_strides[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + c[out_idx++] = Op()(a[idx.x], b[idx.y]); + idx.x += a_xstride; + idx.y += b_xstride; + } +} diff --git a/Source/Cmlx/mlx-generated/metal/binary_ops.h b/Source/Cmlx/mlx-generated/metal/binary_ops.h new file mode 100644 index 00000000..8f961c2c --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/binary_ops.h @@ -0,0 +1,296 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +struct Add { + template + T operator()(T x, T y) { + return x + y; + } +}; + +struct FloorDivide { + template + T operator()(T x, T y) { + return x / y; + } + template <> + float operator()(float x, float y) { + return trunc(x / y); + } + template <> + half operator()(half x, half y) { + return trunc(x / y); + } + template <> + bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { + return trunc(x / y); + } +}; + +struct Divide { + template + T operator()(T x, T y) { + return x / y; + } +}; + +struct Remainder { + template + metal::enable_if_t & !metal::is_signed_v, T> + operator()(T x, T y) { + return x % y; + } + template + metal::enable_if_t & metal::is_signed_v, T> + operator()(T x, T y) { + auto r = x % y; + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } + template + metal::enable_if_t, T> operator()(T x, T y) { + T r = fmod(x, y); + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + return x % y; + } +}; + +struct Equal { + template + bool operator()(T x, T y) { + return x == y; + } +}; + +struct NaNEqual { + template + bool operator()(T x, T y) { + return x == y || (metal::isnan(x) && metal::isnan(y)); + } + template <> + bool operator()(complex64_t x, complex64_t y) { + return x == y || + (metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) && + metal::isnan(y.imag)) || + (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) || + (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag); + } +}; + +struct Greater { + template + bool operator()(T x, T y) { + return x > y; + } +}; + +struct GreaterEqual { + template + bool operator()(T x, T y) { + return x >= y; + } +}; + +struct Less { + template + bool operator()(T x, T y) { 
+ return x < y; + } +}; + +struct LessEqual { + template + bool operator()(T x, T y) { + return x <= y; + } +}; + +struct LogAddExp { + template + T operator()(T x, T y) { + if (metal::isnan(x) || metal::isnan(y)) { + return metal::numeric_limits::quiet_NaN(); + } + constexpr T inf = metal::numeric_limits::infinity(); + T maxval = metal::max(x, y); + T minval = metal::min(x, y); + return (minval == -inf || maxval == inf) + ? maxval + : (maxval + log1p(metal::exp(minval - maxval))); + }; +}; + +struct Maximum { + template + metal::enable_if_t, T> operator()(T x, T y) { + return metal::max(x, y); + } + + template + metal::enable_if_t, T> operator()(T x, T y) { + if (metal::isnan(x)) { + return x; + } + return x > y ? x : y; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + if (metal::isnan(x.real) || metal::isnan(x.imag)) { + return x; + } + return x > y ? x : y; + } +}; + +struct Minimum { + template + metal::enable_if_t, T> operator()(T x, T y) { + return metal::min(x, y); + } + + template + metal::enable_if_t, T> operator()(T x, T y) { + if (metal::isnan(x)) { + return x; + } + return x < y ? x : y; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + if (metal::isnan(x.real) || metal::isnan(x.imag)) { + return x; + } + return x < y ? x : y; + } +}; + +struct Multiply { + template + T operator()(T x, T y) { + return x * y; + } +}; + +struct NotEqual { + template + bool operator()(T x, T y) { + return x != y; + } + template <> + bool operator()(complex64_t x, complex64_t y) { + return x.real != y.real || x.imag != y.imag; + } +}; + +struct Power { + template + metal::enable_if_t, T> operator()(T base, T exp) { + return metal::pow(base, exp); + } + + template + metal::enable_if_t, T> operator()(T base, T exp) { + T res = 1; + while (exp) { + if (exp & 1) { + res *= base; + } + exp >>= 1; + base *= base; + } + return res; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + auto x_theta = metal::atan2(x.imag, x.real); + auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); + auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); + auto phase = y.imag * x_ln_r + y.real * x_theta; + return {mag * metal::cos(phase), mag * metal::sin(phase)}; + } +}; + +struct Subtract { + template + T operator()(T x, T y) { + return x - y; + } +}; + +struct LogicalAnd { + template + T operator()(T x, T y) { + return x && y; + }; +}; + +struct LogicalOr { + template + T operator()(T x, T y) { + return x || y; + }; +}; + +struct BitwiseAnd { + template + T operator()(T x, T y) { + return x & y; + }; +}; + +struct BitwiseOr { + template + T operator()(T x, T y) { + return x | y; + }; +}; + +struct BitwiseXor { + template + T operator()(T x, T y) { + return x ^ y; + }; +}; + +struct LeftShift { + template + T operator()(T x, T y) { + return x << y; + }; +}; + +struct RightShift { + template + T operator()(T x, T y) { + return x >> y; + }; +}; + +struct ArcTan2 { + template + T operator()(T y, T x) { + return metal::precise::atan2(y, x); + } +}; + +struct DivMod { + template + metal::array operator()(T x, T y) { + return {FloorDivide{}(x, y), Remainder{}(x, y)}; + }; +}; diff --git a/Source/Cmlx/mlx-generated/metal/binary_two.h b/Source/Cmlx/mlx-generated/metal/binary_two.h new file mode 100644 index 00000000..a4a3130b --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/binary_two.h @@ -0,0 +1,172 @@ +// Copyright © 2024 Apple Inc. 
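//
// binary_two.h mirrors binary.h, but here Op returns a pair of values that are
// written to two outputs (c and d). This is used by ops such as DivMod in
// binary_ops.h, which yields {FloorDivide, Remainder}.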
+ +template +[[kernel]] void binary_ss( + device const T* a, + device const T* b, + device U* c, + device U* d, + uint index [[thread_position_in_grid]]) { + auto out = Op()(a[0], b[0]); + c[index] = out[0]; + d[index] = out[1]; +} + +template +[[kernel]] void binary_sv( + device const T* a, + device const T* b, + device U* c, + device U* d, + uint index [[thread_position_in_grid]]) { + auto out = Op()(a[0], b[index]); + c[index] = out[0]; + d[index] = out[1]; +} + +template +[[kernel]] void binary_vs( + device const T* a, + device const T* b, + device U* c, + device U* d, + uint index [[thread_position_in_grid]]) { + auto out = Op()(a[index], b[0]); + c[index] = out[0]; + d[index] = out[1]; +} + +template +[[kernel]] void binary_vv( + device const T* a, + device const T* b, + device U* c, + device U* d, + uint index [[thread_position_in_grid]]) { + auto out = Op()(a[index], b[index]); + c[index] = out[0]; + d[index] = out[1]; +} + +template +[[kernel]] void binary_sv2( + device const T* a, + device const T* b, + device U* c, + device U* d, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + size_t offset = index.x + grid_dim.x * size_t(index.y); + auto out = Op()(a[0], b[offset]); + c[offset] = out[0]; + d[offset] = out[1]; +} + +template +[[kernel]] void binary_vs2( + device const T* a, + device const T* b, + device U* c, + device U* d, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + size_t offset = index.x + grid_dim.x * size_t(index.y); + auto out = Op()(a[offset], b[0]); + c[offset] = out[0]; + d[offset] = out[1]; +} + +template +[[kernel]] void binary_vv2( + device const T* a, + device const T* b, + device U* c, + device U* d, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + size_t offset = index.x + grid_dim.x * size_t(index.y); + auto out = Op()(a[offset], b[offset]); + c[offset] = out[0]; + d[offset] = out[1]; +} + +template +[[kernel]] void binary_g_nd1( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant const size_t& a_stride, + constant const size_t& b_stride, + uint index [[thread_position_in_grid]]) { + auto a_idx = elem_to_loc_1(index, a_stride); + auto b_idx = elem_to_loc_1(index, b_stride); + auto out = Op()(a[a_idx], b[b_idx]); + c[index] = out[0]; + d[index] = out[1]; +} + +template +[[kernel]] void binary_g_nd2( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant const size_t a_strides[2], + constant const size_t b_strides[2], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + size_t out_idx = index.x + size_t(grid_dim.x) * index.y; + auto out = Op()(a[a_idx], b[b_idx]); + c[out_idx] = out[0]; + d[out_idx] = out[1]; +} + +template +[[kernel]] void binary_g_nd3( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant const size_t a_strides[3], + constant const size_t b_strides[3], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + size_t out_idx = + index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z); + auto out = Op()(a[a_idx], b[b_idx]); + c[out_idx] = out[0]; + d[out_idx] = out[1]; +} + +template +[[kernel]] void binary_g( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant 
const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + constant const int& ndim, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = elem_to_loc_2_nd( + {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); + auto xshape = shape[ndim - 1]; + size_t out_idx = + N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); + auto a_xstride = a_strides[ndim - 1]; + auto b_xstride = b_strides[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + auto out = Op()(a[idx.x], b[idx.y]); + c[out_idx] = out[0]; + d[out_idx++] = out[1]; + idx.x += a_xstride; + idx.y += b_xstride; + } +} diff --git a/Source/Cmlx/mlx-generated/metal/complex.h b/Source/Cmlx/mlx-generated/metal/complex.h new file mode 100644 index 00000000..fe8ec5c0 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/complex.h @@ -0,0 +1,133 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +using namespace metal; + +struct complex64_t; + +template +static constexpr constant bool can_convert_to_complex64 = + !is_same_v && is_convertible_v; + +template +static constexpr constant bool can_convert_from_complex64 = + !is_same_v && + (is_convertible_v || is_convertible_v); + +struct complex64_t { + float real; + float imag; + + // Constructors + constexpr complex64_t(float real, float imag) : real(real), imag(imag) {}; + constexpr complex64_t() : real(0), imag(0) {}; + constexpr complex64_t() threadgroup : real(0), imag(0) {}; + + // Conversions to complex64_t + template < + typename T, + typename = typename enable_if>::type> + constexpr complex64_t(T x) thread : real(x), imag(0) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr complex64_t(T x) threadgroup : real(x), imag(0) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr complex64_t(T x) device : real(x), imag(0) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr complex64_t(T x) constant : real(x), imag(0) {} + + // Conversions from complex64_t + template < + typename T, + typename = typename enable_if>::type> + constexpr operator T() const thread { + return static_cast(real); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr operator T() const threadgroup { + return static_cast(real); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr operator T() const device { + return static_cast(real); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr operator T() const constant { + return static_cast(real); + } +}; + +constexpr complex64_t operator-(complex64_t x) { + return {-x.real, -x.imag}; +} + +constexpr bool operator>=(complex64_t a, complex64_t b) { + return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag); +} + +constexpr bool operator>(complex64_t a, complex64_t b) { + return (a.real > b.real) || (a.real == b.real && a.imag > b.imag); +} + +constexpr bool operator<=(complex64_t a, complex64_t b) { + return operator>=(b, a); +} + +constexpr bool operator<(complex64_t a, complex64_t b) { + return operator>(b, a); +} + +constexpr bool operator==(complex64_t a, complex64_t b) { + return a.real == b.real && a.imag == b.imag; +} + +constexpr complex64_t operator+(complex64_t a, complex64_t b) { + return {a.real + b.real, a.imag + b.imag}; +} + +constexpr complex64_t operator-(complex64_t a, complex64_t b) { + return {a.real - b.real, a.imag - 
b.imag}; +} + +constexpr complex64_t operator*(complex64_t a, complex64_t b) { + return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real}; +} + +constexpr complex64_t operator/(complex64_t a, complex64_t b) { + auto denom = b.real * b.real + b.imag * b.imag; + auto x = a.real * b.real + a.imag * b.imag; + auto y = a.imag * b.real - a.real * b.imag; + return {x / denom, y / denom}; +} + +constexpr complex64_t operator%(complex64_t a, complex64_t b) { + auto real = a.real - (b.real * static_cast(a.real / b.real)); + auto imag = a.imag - (b.imag * static_cast(a.imag / b.imag)); + if (real != 0 && (real < 0 != b.real < 0)) { + real += b.real; + } + if (imag != 0 && (imag < 0 != b.imag < 0)) { + imag += b.imag; + } + return {real, imag}; +} diff --git a/Source/Cmlx/mlx-generated/metal/conv.metal b/Source/Cmlx/mlx-generated/metal/conv.metal new file mode 100644 index 00000000..c03d09c3 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/conv.metal @@ -0,0 +1,654 @@ +// Copyright © 2023-2024 Apple Inc. + +#include +#include +#include + +#include "bf16.h" +#include "steel/conv/params.h" + +#define MLX_MTL_CONST static constant constexpr const + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +/// Naive unfold with dilation +/////////////////////////////////////////////////////////////////////////////// + +template +[[kernel]] void naive_unfold_Nd( + const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], + const constant MLXConvParams* params [[buffer(2)]], + uint3 gid [[thread_position_in_grid]]) { + int filter_size = params->C; + for (short i = 0; i < N; i++) + filter_size *= params->wS[i]; + + int out_pixels = 1; + for (short i = 0; i < N; i++) + out_pixels *= params->oS[i]; + + // Set out + out += gid.z * filter_size + gid.y * (params->C); + + // Coordinates in input + int is[N] = {0}; + + // gid.z: N oS (Batch and row in unfolded output) + // gid.y: wS (Filter location to unfold input) + // gid.x: C (channel) + + int n = (gid.z) / out_pixels; + int oS = (gid.z) % out_pixels; + int wS = gid.y; + + bool valid = n < params->N; + + // Unroll dimensions + for (int i = N - 1; i >= 0; --i) { + int os_ = (oS % params->oS[i]); + int ws_ = (wS % params->wS[i]); + + ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_; + + int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i]; + int is_max = 1 + params->idil[i] * (params->iS[i] - 1); + + valid &= is_ >= 0 && is_ < is_max && (is_ % params->idil[i] == 0); + + is[i] = is_ / params->idil[i]; + + oS /= params->oS[i]; + wS /= params->wS[i]; + } + + if (valid) { + size_t in_offset = n * params->in_strides[0]; + + for (int i = 0; i < N; ++i) { + in_offset += is[i] * params->in_strides[i + 1]; + } + + out[gid.x] = in[in_offset + gid.x]; + } else { + out[gid.x] = T(0); + } +} + +// This kernel unfolds the input array of size (N, *spatial_dims, C) +// into an array of size (N x *spatial_dims, C x *kernel_dims). 
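+// (naive_unfold_Nd above writes each unfolded patch as (*kernel_dims, C) with
+// the channel contiguous; the transpose variant below writes the patch
+// channel-major, using gid.x to select the channel row and accumulating the
+// per-dimension kernel offset directly into the output pointer.)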
+template +[[kernel]] void naive_unfold_transpose_Nd( + const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], + const constant MLXConvParams* params [[buffer(2)]], + uint3 gid [[thread_position_in_grid]]) { + int filter_size = params->C; + for (short i = 0; i < N; i++) + filter_size *= params->wS[i]; + + int out_pixels = 1; + for (short i = 0; i < N; i++) + out_pixels *= params->oS[i]; + + // Set out + out += gid.z * filter_size + gid.x * (filter_size / params->C); + + // Coordinates in input + int is[N] = {0}; + + // gid.z: N oS (Batch and row in unfolded output) + // gid.y: wS (Filter location to unfold input) + // gid.x: C (channel) + + int n = (gid.z) / out_pixels; + int oS = (gid.z) % out_pixels; + int wS = gid.y; + + bool valid = n < params->N; + + // Unroll dimensions + int kernel_stride = 1; + for (int i = N - 1; i >= 0; --i) { + int os_ = (oS % params->oS[i]); + int ws_ = (wS % params->wS[i]); + out += ws_ * kernel_stride; + + ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_; + + int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i]; + int is_max = 1 + params->idil[i] * (params->iS[i] - 1); + + valid &= is_ >= 0 && is_ < is_max && (is_ % params->idil[i] == 0); + + is[i] = is_ / params->idil[i]; + + oS /= params->oS[i]; + wS /= params->wS[i]; + + kernel_stride *= params->wS[i]; + } + + if (valid) { + size_t in_offset = n * params->in_strides[0]; + + for (int i = 0; i < N; ++i) { + in_offset += is[i] * params->in_strides[i + 1]; + } + + out[0] = in[in_offset + gid.x]; + } else { + out[0] = T(0); + } +} + +#define instantiate_naive_unfold_nd(name, itype, n) \ + template [[host_name("naive_unfold_nd_" #name "_" #n)]] [[kernel]] void \ + naive_unfold_Nd( \ + const device itype* in [[buffer(0)]], \ + device itype* out [[buffer(1)]], \ + const constant MLXConvParams* params [[buffer(2)]], \ + uint3 gid [[thread_position_in_grid]]); \ + template \ + [[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] [[kernel]] void \ + naive_unfold_transpose_Nd( \ + const device itype* in [[buffer(0)]], \ + device itype* out [[buffer(1)]], \ + const constant MLXConvParams* params [[buffer(2)]], \ + uint3 gid [[thread_position_in_grid]]); + +#define instantiate_naive_unfold_nd_dims(name, itype) \ + instantiate_naive_unfold_nd(name, itype, 1) instantiate_naive_unfold_nd( \ + name, itype, 2) instantiate_naive_unfold_nd(name, itype, 3) + +instantiate_naive_unfold_nd_dims(float32, float); +instantiate_naive_unfold_nd_dims(float16, half); +instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t); + +/////////////////////////////////////////////////////////////////////////////// +/// Slow and naive conv2d kernels +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + const int BM, /* Threadgroup rows (in threads) */ + const int BN, /* Threadgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const int BC = 16> +[[kernel]] void naive_conv_2d( + const device T* in [[buffer(0)]], + const device T* wt [[buffer(1)]], + device T* out [[buffer(2)]], + const constant MLXConvParams<2>& params [[buffer(3)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)simd_gid; + (void)simd_lid; + + out += tid.z * params.out_strides[0]; + in += tid.z * params.in_strides[0]; + + int out_o = tid.y * BN * TN + lid.y * TN; + 
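+  // Each thread computes a TM x TN register tile of the output: out_hw indexes
+  // TM consecutive flattened output pixels and out_o indexes TN consecutive
+  // output channels, so one threadgroup covers a (BM * TM) x (BN * TN) block.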
int out_hw = tid.x * BM * TM + lid.x * TM; + + int out_h[TM]; + int out_w[TN]; + + for (int m = 0; m < TM; ++m) { + int mm = (out_hw + m); + out_h[m] = mm / params.oS[1]; + out_w[m] = mm % params.oS[1]; + } + + T in_local[TM]; + T wt_local[TN]; + T out_local[TM * TN] = {T(0)}; + + for (int h = 0; h < params.wS[0]; ++h) { + for (int w = 0; w < params.wS[1]; ++w) { + for (int c = 0; c < params.C; ++c) { + // Local in + for (int m = 0; m < TM; m++) { + int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0]; + int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1]; + + bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1]; + in_local[m] = valid + ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] + : T(0); + } + + // Load weight + for (int n = 0; n < TN; ++n) { + int o = out_o + n; + wt_local[n] = o < params.O + ? wt[o * params.wt_strides[0] + h * params.wt_strides[1] + + w * params.wt_strides[2] + c] + : T(0); + } + + // Accumulate + for (int m = 0; m < TM; ++m) { + for (int n = 0; n < TN; ++n) { + out_local[m * TN + n] += in_local[m] * wt_local[n]; + } + } + } + } + } + + for (int m = 0; m < TM; ++m) { + for (int n = 0; n < TN; ++n) { + if (out_h[m] < params.oS[0] && out_w[m] < params.oS[1] && + (out_o + n) < params.O) + out[out_h[m] * params.out_strides[1] + + out_w[m] * params.out_strides[2] + out_o + n] = + out_local[m * TN + n]; + } + } +} + +// Instantiations + +#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \ + template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm \ + "_tn" #tn)]] [[kernel]] void \ + naive_conv_2d( \ + const device itype* in [[buffer(0)]], \ + const device itype* wt [[buffer(1)]], \ + device itype* out [[buffer(2)]], \ + const constant MLXConvParams<2>& params [[buffer(3)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define instantiate_naive_conv_2d_blocks(name, itype) \ + instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \ + instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4) + +instantiate_naive_conv_2d_blocks(float32, float); +instantiate_naive_conv_2d_blocks(float16, half); +instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t); + +/////////////////////////////////////////////////////////////////////////////// +/// Winograd kernels +/////////////////////////////////////////////////////////////////////////////// + +template +struct WinogradTransforms {}; + +template <> +struct WinogradTransforms<6, 3, 8> { + MLX_MTL_CONST int OUT_TILE_SIZE = 6; + MLX_MTL_CONST int FILTER_SIZE = 3; + MLX_MTL_CONST int IN_TILE_SIZE = OUT_TILE_SIZE + FILTER_SIZE - 1; + MLX_MTL_CONST int SIMD_MATRIX_SIZE = 8; + MLX_MTL_CONST float in_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = { + {1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f}, + {0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f}, + {-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f}, + {0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f}, + {5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f}, + {0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f}, + {-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f}, + {0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f}, + }; + + MLX_MTL_CONST float out_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = { + {1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f}, + {1.00f, 1.00f, 1.00f, 1.00f, 
1.00f, 1.00f}, + {1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f}, + {1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f}, + {1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f}, + {1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f}, + {1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f}, + {0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f}, + }; + + MLX_MTL_CONST float wt_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = { + {1.00, 0.00, 0.00}, + {-2.0 / 9.00, -2.0 / 9.00, -2.0 / 9.00}, + {-2.0 / 9.00, 2.0 / 9.00, -2.0 / 9.00}, + {1.0 / 90.0, 1.0 / 45.0, 2.0 / 45.0}, + {1.0 / 90.0, -1.0 / 45.0, 2.0 / 45.0}, + {32.0 / 45.0, 16.0 / 45.0, 8.0 / 45.0}, + {32.0 / 45.0, -16.0 / 45.0, 8.0 / 45.0}, + {0.00, 0.00, 1.00}, + }; +}; + +constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8]; +constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8]; +constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8]; + +template +[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void +winograd_conv_2d_weight_transform( + const device T* wt_in [[buffer(0)]], + device T* wt_out [[buffer(1)]], + const constant int& C [[buffer(2)]], + const constant int& O [[buffer(3)]], + uint tid [[threadgroup_position_in_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + using WGT = WinogradTransforms; + + // Get lane position in simdgroup + const short qid = simd_lane_id / 4; + const short sm = (qid & 4) + (simd_lane_id / 2) % 4; + const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + + // Initialize G matrix + simdgroup_matrix G; + G.thread_elements()[0] = WGT::wt_transform[sm][sn]; + G.thread_elements()[1] = WGT::wt_transform[sm][sn + 1]; + + // Initialize Gt matrix + simdgroup_matrix Gt; + Gt.thread_elements()[0] = WGT::wt_transform[sn][sm]; + Gt.thread_elements()[1] = WGT::wt_transform[sn + 1][sm]; + + // Move to the correct output filter + size_t ko = BO * tid + simd_group_id; + wt_in += ko * R * R * C; + + // wt_out is stored transposed (A x A x C x O) + short ohw_0 = sm * 8 + sn; + short ohw_1 = sm * 8 + sn + 1; + device T* wt_out_0 = wt_out + ohw_0 * C * O + ko; + device T* wt_out_1 = wt_out + ohw_1 * C * O + ko; + + // Prepare shared memory + threadgroup T Ws[BO][R][R][BC]; + + // Loop over C + for (int bc = 0; bc < C; bc += BC) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Read into shared memory + for (int kh = 0; kh < R; ++kh) { + for (int kw = 0; kw < R; ++kw) { + for (int kc = simd_lane_id; kc < BC; kc += 32) { + Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc]; + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + // Do transform and store the result + for (int c = 0; c < BC; ++c) { + simdgroup_matrix g; + g.thread_elements()[0] = + sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0); + g.thread_elements()[1] = + sm < R && sn + 1 < R ? 
Ws[simd_group_id][sm][sn + 1][c] : T(0); + + simdgroup_matrix g_out = (G * g) * Gt; + wt_out_0[c * O] = static_cast(g_out.thread_elements()[0]); + wt_out_1[c * O] = static_cast(g_out.thread_elements()[1]); + } + + wt_in += BC; + wt_out_0 += BC * O; + wt_out_1 += BC * O; + } +} + +#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \ + template [[host_name("winograd_conv_2d_weight_transform_" #name \ + "_bc" #bc)]] [[kernel]] void \ + winograd_conv_2d_weight_transform( \ + const device itype* wt_in [[buffer(0)]], \ + device itype* wt_out [[buffer(1)]], \ + const constant int& C [[buffer(2)]], \ + const constant int& O [[buffer(3)]], \ + uint tid [[threadgroup_position_in_grid]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]]); + +template +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +winograd_conv_2d_input_transform( + const device T* inp_in [[buffer(0)]], + device T* inp_out [[buffer(1)]], + const constant MLXConvParams<2>& params [[buffer(2)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 tgp_per_grid [[threadgroups_per_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + (void)lid; + + using WGT = WinogradTransforms; + constexpr int A = WGT::IN_TILE_SIZE; + constexpr int N_SIMD_GROUPS = WM * WN; + + // Get lane position in simdgroup + const short qid = simd_lane_id / 4; + const short sm = (qid & 4) + (simd_lane_id / 2) % 4; + const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + + // Initialize B matrix + simdgroup_matrix B; + B.thread_elements()[0] = WGT::in_transform[sm][sn]; + B.thread_elements()[1] = WGT::in_transform[sm][sn + 1]; + + // Initialize Bt matrix + simdgroup_matrix Bt; + Bt.thread_elements()[0] = WGT::in_transform[sn][sm]; + Bt.thread_elements()[1] = WGT::in_transform[sn + 1][sm]; + + // Resolve input tile + constexpr int TH = (A / WM); + constexpr int TW = (A / WN); + int kh = TH * (simd_group_id / WN); + int kw = TW * (simd_group_id % WN); + int bh = M * tid.y + kh; + int bw = M * tid.x + kw; + + // Move to the correct input tile + inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] + + bw * params.in_strides[2]; + + // Pre compute strides + int jump_in[TH][TW]; + + for (int h = 0; h < TH; h++) { + for (int w = 0; w < TW; w++) { + jump_in[h][w] = h * params.in_strides[1] + w * params.in_strides[2]; + } + } + + // inp_out is stored interleaved (A x A x tiles x C) + size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z; + size_t tile_id = + tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x; + size_t ohw_0 = sm * 8 + sn; + size_t ohw_1 = sm * 8 + sn + 1; + device T* inp_out_0 = + inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C; + device T* inp_out_1 = + inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C; + + // Prepare shared memory + threadgroup T Is[A][A][BC]; + + // Loop over C + for (int bc = 0; bc < params.C; bc += BC) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Read into shared memory + for (int h = 0; h < TH; h++) { + for (int w = 0; w < TW; w++) { + const device T* in_ptr = inp_in + jump_in[h][w]; + for (int c = simd_lane_id; c < BC; c += 32) { + Is[kh + h][kw + w][c] = in_ptr[c]; + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + // Do transform and store the result + for (int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) { + 
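+      // Load one 8x8 input tile per channel into a simdgroup matrix and apply
+      // the Winograd input transform as two 8x8 matrix multiplies (Bt * I * B);
+      // each lane then writes its two elements to the interleaved output.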
simdgroup_matrix I; + I.thread_elements()[0] = Is[sm][sn][c]; + I.thread_elements()[1] = Is[sm][sn + 1][c]; + + simdgroup_matrix I_out = (Bt * I) * B; + inp_out_0[c] = static_cast(I_out.thread_elements()[0]); + inp_out_1[c] = static_cast(I_out.thread_elements()[1]); + } + + inp_in += BC; + inp_out_0 += BC; + inp_out_1 += BC; + } +} + +#define instantiate_winograd_conv_2d_input_transform(name, itype, bc) \ + template [[host_name("winograd_conv_2d_input_transform_" #name \ + "_bc" #bc)]] [[kernel]] void \ + winograd_conv_2d_input_transform( \ + const device itype* inp_in [[buffer(0)]], \ + device itype* inp_out [[buffer(1)]], \ + const constant MLXConvParams<2>& params [[buffer(2)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 tgp_per_grid [[threadgroups_per_grid]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]]); + +template +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +winograd_conv_2d_output_transform( + const device T* out_in [[buffer(0)]], + device T* out_out [[buffer(1)]], + const constant MLXConvParams<2>& params [[buffer(2)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 tgp_per_grid [[threadgroups_per_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + (void)lid; + + using WGT = WinogradTransforms; + constexpr int N_SIMD_GROUPS = WM * WN; + + // Get lane position in simdgroup + const short qid = simd_lane_id / 4; + const short sm = (qid & 4) + (simd_lane_id / 2) % 4; + const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + + // Initialize A matrix + simdgroup_matrix B; + B.thread_elements()[0] = WGT::out_transform[sm][sn]; + B.thread_elements()[1] = WGT::out_transform[sm][sn + 1]; + + // Initialize At matrix + simdgroup_matrix Bt; + Bt.thread_elements()[0] = WGT::out_transform[sn][sm]; + Bt.thread_elements()[1] = WGT::out_transform[sn + 1][sm]; + + // Out_in comes in shape (A x A x tiles x O) + // We do transform and then write out to out_out in shape (N, H, W, O) + + // Resolve output tile + constexpr int TH = (M / WM); + constexpr int TW = (M / WN); + int kh = TH * (simd_group_id / WN); + int kw = TW * (simd_group_id % WN); + int bh = M * tid.y + kh; + int bw = M * tid.x + kw; + + // Move to the correct input tile + out_out += tid.z * params.out_strides[0] + bh * params.out_strides[1] + + bw * params.out_strides[2]; + + // Pre compute strides + int jump_in[TH][TW]; + + for (int h = 0; h < TH; h++) { + for (int w = 0; w < TW; w++) { + bool valid = ((bh + h) < params.oS[0]) && ((bw + w) < params.oS[1]); + jump_in[h][w] = + valid ? 
h * params.out_strides[1] + w * params.out_strides[2] : -1; + } + } + + // out_in is stored interleaved (A x A x tiles x O) + size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z; + size_t tile_id = + tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x; + size_t ohw_0 = sm * 8 + sn; + size_t ohw_1 = sm * 8 + sn + 1; + const device T* out_in_0 = + out_in + ohw_0 * N_TILES * params.O + tile_id * params.O; + const device T* out_in_1 = + out_in + ohw_1 * N_TILES * params.O + tile_id * params.O; + + // Prepare shared memory + threadgroup T Os[M][M][BO]; + + // Loop over O + for (int bo = 0; bo < params.O; bo += BO) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Do transform and store the result + for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) { + simdgroup_matrix O_mat; + O_mat.thread_elements()[0] = out_in_0[c]; + O_mat.thread_elements()[1] = out_in_1[c]; + + simdgroup_matrix O_out = (Bt * (O_mat * B)); + if ((sm < M) && (sn < M)) { + Os[sm][sn][c] = static_cast(O_out.thread_elements()[0]); + } + if ((sm < M) && ((sn + 1) < M)) { + Os[sm][sn + 1][c] = static_cast(O_out.thread_elements()[1]); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + // Read out from shared memory + for (int h = 0; h < TH; h++) { + for (int w = 0; w < TW; w++) { + if (jump_in[h][w] >= 0) { + device T* out_ptr = out_out + jump_in[h][w]; + for (int c = simd_lane_id; c < BO; c += 32) { + out_ptr[c] = Os[kh + h][kw + w][c]; + } + } + } + } + + out_out += BO; + out_in_0 += BO; + out_in_1 += BO; + } +} + +#define instantiate_winograd_conv_2d_output_transform(name, itype, bo) \ + template [[host_name("winograd_conv_2d_output_transform_" #name \ + "_bo" #bo)]] [[kernel]] void \ + winograd_conv_2d_output_transform( \ + const device itype* out_in [[buffer(0)]], \ + device itype* out_out [[buffer(1)]], \ + const constant MLXConvParams<2>& params [[buffer(2)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 tgp_per_grid [[threadgroups_per_grid]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]]); + +// clang-format off +#define instantiate_winograd_conv_2d(name, itype) \ + instantiate_winograd_conv_2d_weight_transform_base(name, itype, 32) \ + instantiate_winograd_conv_2d_input_transform(name, itype, 32) \ + instantiate_winograd_conv_2d_output_transform(name, itype, 32) // clang-format on + +// clang-format off +instantiate_winograd_conv_2d(float32, float); +instantiate_winograd_conv_2d(bfloat16, bfloat16_t); +instantiate_winograd_conv_2d(float16, half); // clang-format on diff --git a/Source/Cmlx/mlx-generated/metal/copy.h b/Source/Cmlx/mlx-generated/metal/copy.h new file mode 100644 index 00000000..914aebfd --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/copy.h @@ -0,0 +1,164 @@ +// Copyright © 2024 Apple Inc. 
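+
+// Copy kernel naming: _s broadcasts a single source element, _v copies
+// contiguously, _g gathers through source strides via elem_to_loc_*, and _gg
+// applies independent strides to both source and destination. The *2 variants
+// use a 2D launch grid so the flattened offset can exceed 32-bit range, and
+// the *_nd{1,2,3} variants specialize for 1-3 strided dimensions.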
+ +template +[[kernel]] void copy_s( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + uint index [[thread_position_in_grid]]) { + dst[index] = static_cast(src[0]); +} + +template +[[kernel]] void copy_v( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + uint index [[thread_position_in_grid]]) { + dst[index] = static_cast(src[index]); +} + +template +[[kernel]] void copy_s2( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + size_t offset = index.x + grid_dim.x * size_t(index.y); + dst[offset] = static_cast(src[0]); +} + +template +[[kernel]] void copy_v2( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + size_t offset = index.x + grid_dim.x * size_t(index.y); + dst[offset] = static_cast(src[offset]); +} + +template +[[kernel]] void copy_g_nd1( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t& src_stride [[buffer(3)]], + uint index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_1(index, src_stride); + dst[index] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_g_nd2( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto src_idx = elem_to_loc_2(index, src_strides); + int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y; + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_g_nd3( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto src_idx = elem_to_loc_3(index, src_strides); + int64_t dst_idx = + index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_g( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int& ndim [[buffer(5)]], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto src_idx = elem_to_loc( + {N * index.x, index.y, index.z}, src_shape, src_strides, ndim); + if (N == 1) { + int64_t dst_idx = + index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z); + dst[dst_idx] = static_cast(src[src_idx]); + return; + } + auto xshape = src_shape[ndim - 1]; + int64_t dst_idx = + N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z); + auto src_xstride = src_strides[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + dst[dst_idx + i] = static_cast(src[src_idx]); + src_idx += src_xstride; + } +} + +template +[[kernel]] void copy_gg_nd1( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t& src_stride [[buffer(3)]], + constant const int64_t& dst_stride [[buffer(4)]], + uint index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_1(index, src_stride); + auto dst_idx = elem_to_loc_1(index, dst_stride); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_gg_nd2( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const 
int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + uint2 index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_2(index, src_strides); + auto dst_idx = elem_to_loc_2(index, dst_strides); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_gg_nd3( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + uint3 index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_3(index, src_strides); + auto dst_idx = elem_to_loc_3(index, dst_strides); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_gg( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + constant const int& ndim [[buffer(5)]], + uint3 index [[thread_position_in_grid]]) { + auto idx = elem_to_loc_2_nd( + {N * index.x, index.y, index.z}, + src_shape, + src_strides, + dst_strides, + ndim); + if (N == 1) { + dst[idx.y] = static_cast(src[idx.x]); + return; + } + auto src_xstride = src_strides[ndim - 1]; + auto dst_xstride = dst_strides[ndim - 1]; + auto xshape = src_shape[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + dst[idx.y] = static_cast(src[idx.x]); + idx.x += src_xstride; + idx.y += dst_xstride; + } +} diff --git a/Source/Cmlx/mlx-generated/metal/defines.h b/Source/Cmlx/mlx-generated/metal/defines.h new file mode 100644 index 00000000..c369adb7 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/defines.h @@ -0,0 +1,24 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#if defined __METAL__ || defined MLX_METAL_JIT +#define MTL_CONST constant +#else +#define MTL_CONST +#endif + +static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4; +static MTL_CONST constexpr int REDUCE_N_READS = 4; +static MTL_CONST constexpr int REDUCE_N_WRITES = 4; +static MTL_CONST constexpr int SOFTMAX_N_READS = 4; +static MTL_CONST constexpr int RMS_N_READS = 4; +static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096; + +// Instantiate a templated kernel. +// Extra args are used as template parameters: +// e.g. instantiate_kernel(binary_int, binary, a, b) -> +// [[host_name(binary_int)]] [kernel] binary +#define instantiate_kernel(name, func, ...) \ + template [[host_name( \ + name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>; diff --git a/Source/Cmlx/mlx-generated/metal/erf.h b/Source/Cmlx/mlx-generated/metal/erf.h new file mode 100644 index 00000000..da6c2eac --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/erf.h @@ -0,0 +1,69 @@ +// Copyright © 2023 Apple Inc. + +#pragma once +#include + +/* + * Approximation to the error function. 
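+ * Single-precision polynomial evaluated with metal::fma, split into two
+ * branches at |a| > 0.927734375; the large-argument branch finishes through
+ * metal::exp. erfinv below works on t = log(1 - a*a) with a similar split.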
+ * Based on code from: + * https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199 + */ +float erf(float a) { + float r, s, t, u; + t = metal::abs(a); + s = a * a; + if (t > 0.927734375f) { + // maximum error 0.99527 ulp + r = metal::fma( + -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12 + u = metal::fma( + -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6 + r = metal::fma(r, s, u); + r = metal::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4 + r = metal::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1 + r = metal::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3 + r = metal::fma(r, t, -t); + // TODO, replace with expm1 when implemented + r = 1.0f - metal::exp(r); + r = metal::copysign(r, a); + } else { + // maximum error 0.98929 ulp + r = -5.96761703e-4f; // -0x1.38e000p-11 + r = metal::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8 + r = metal::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6 + r = metal::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4 + r = metal::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2 + r = metal::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3 + r = metal::fma(r, a, a); + } + return r; +} + +float erfinv(float a) { + auto t = metal::fma(a, 0.0f - a, 1.0f); + t = metal::log(t); + float p; + if (metal::abs(t) > 6.125f) { // maximum ulp error = 2.35793 + p = 3.03697567e-10f; // 0x1.4deb44p-32 + p = metal::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26 + p = metal::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20 + p = metal::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16 + p = metal::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12 + p = metal::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9 + p = metal::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8 + p = metal::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2 + p = metal::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1 + } else { // maximum ulp error = 2.35002 + p = 5.43877832e-9f; // 0x1.75c000p-28 + p = metal::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23 + p = metal::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20 + p = metal::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24 + p = metal::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15 + p = metal::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13 + p = metal::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9 + p = metal::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7 + p = metal::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3 + p = metal::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1 + } + return a * p; +} diff --git a/Source/Cmlx/mlx-generated/metal/expm1f.h b/Source/Cmlx/mlx-generated/metal/expm1f.h new file mode 100644 index 00000000..68224e17 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/expm1f.h @@ -0,0 +1,90 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +// Original license copied below: +// Copyright (c) 2015-2023 Norbert Juffa +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* Compute exponential base e minus 1. Maximum ulp error = 0.997458 + + i = rint(a/log(2)), f = a-i*log(2). Then expm1(a) = 2**i * (expm1(f)+1) - 1. + Compute r = expm1(f). Then expm1(a)= 2 * (0.5 * 2**i * r + 0.5 * 2**i - 0.5). + With t = 0.5*2**i, expm1(a) = 2*(r * t + t-0.5). However, for best accuracy, + when i == 1, expm1(a)= 2*(r + 0.5), and when i == 0, expm1(a) = r. + + NOTE: Scale factor b is only applied if i < 0 or i > 1 (should be power of 2) +*/ +float expm1f_scaled_unchecked(float a, float b) { + float f, j, r, s, t, u, v, x, y; + int i; + + // exp(a) = 2**i * exp(f); i = rintf (a / log(2)) + j = fma(1.442695f, a, 12582912.f); // 0x1.715476p0, 0x1.8p23 + j = j - 12582912.0f; // 0x1.8p23 + i = (int)j; + f = fma(j, -6.93145752e-1f, a); + + // approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2] + s = f * f; + if (a == 0.0f) + s = a; // ensure -0 is passed through + // err = 0.997458 ulp1 = 11081805 + r = 1.97350979e-4f; // 0x1.9de000p-13 + r = fma(r, f, 1.39309070e-3f); // 0x1.6d30bcp-10 + r = fma(r, f, 8.33343994e-3f); // 0x1.1111f6p-7 + r = fma(r, f, 4.16668020e-2f); // 0x1.55559ep-5 + r = fma(r, f, 1.66666716e-1f); // 0x1.55555cp-3 + r = fma(r, f, 4.99999970e-1f); // 0x1.fffffep-2 + u = (j == 1) ? (f + 0.5f) : f; + v = fma(r, s, u); + s = 0.5f * b; + t = ldexp(s, i); + y = t - s; + x = (t - y) - s; // double-float canonicalization of difference + r = fma(v, t, x) + y; + r = r + r; + if (j == 0) + r = v; + if (j == 1) + r = v + v; + return r; +} + +/* Compute exponential base e minus 1. max ulp err = 0.99746 */ +float expm1f(float a) { + float r; + + r = expm1f_scaled_unchecked(a, 1.0f); + /* handle severe overflow and underflow */ + if (abs(a - 1.0f) > 88.0f) { + r = pow(2, a); + r = fma(r, r, -1.0f); + } + return r; +} diff --git a/Source/Cmlx/mlx-generated/metal/fft.h b/Source/Cmlx/mlx-generated/metal/fft.h new file mode 100644 index 00000000..911d8d97 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/fft.h @@ -0,0 +1,486 @@ +// Copyright © 2024 Apple Inc. + +// Metal FFT using Stockham's algorithm +// +// References: +// - VkFFT (https://github.com/DTolm/VkFFT) +// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html) + +#include + +#include "fft/radix.h" +#include "fft/readwrite.h" +#include "steel/defines.h" + +using namespace metal; + +#define MAX_RADIX 13 +// Reached when elems_per_thread_ = 6, max_radix = 13 +// and some threads have to do 3 radix 6s requiring 18 float2s. 
+#define MAX_OUTPUT_SIZE 18 + +// Specialize for a particular value of N at runtime +STEEL_CONST bool inv_ [[function_constant(0)]]; +STEEL_CONST bool is_power_of_2_ [[function_constant(1)]]; +STEEL_CONST int elems_per_thread_ [[function_constant(2)]]; +// rader_m = n / rader_n +STEEL_CONST int rader_m_ [[function_constant(3)]]; +// Stockham steps +STEEL_CONST int radix_13_steps_ [[function_constant(4)]]; +STEEL_CONST int radix_11_steps_ [[function_constant(5)]]; +STEEL_CONST int radix_8_steps_ [[function_constant(6)]]; +STEEL_CONST int radix_7_steps_ [[function_constant(7)]]; +STEEL_CONST int radix_6_steps_ [[function_constant(8)]]; +STEEL_CONST int radix_5_steps_ [[function_constant(9)]]; +STEEL_CONST int radix_4_steps_ [[function_constant(10)]]; +STEEL_CONST int radix_3_steps_ [[function_constant(11)]]; +STEEL_CONST int radix_2_steps_ [[function_constant(12)]]; +// Rader steps +STEEL_CONST int rader_13_steps_ [[function_constant(13)]]; +STEEL_CONST int rader_11_steps_ [[function_constant(14)]]; +STEEL_CONST int rader_8_steps_ [[function_constant(15)]]; +STEEL_CONST int rader_7_steps_ [[function_constant(16)]]; +STEEL_CONST int rader_6_steps_ [[function_constant(17)]]; +STEEL_CONST int rader_5_steps_ [[function_constant(18)]]; +STEEL_CONST int rader_4_steps_ [[function_constant(19)]]; +STEEL_CONST int rader_3_steps_ [[function_constant(20)]]; +STEEL_CONST int rader_2_steps_ [[function_constant(21)]]; + +// See "radix.h" for radix codelets +typedef void (*RadixFunc)(thread float2*, thread float2*); + +// Perform a single radix n butterfly with appropriate twiddles +template +METAL_FUNC void radix_butterfly( + int i, + int p, + thread float2* x, + thread short* indices, + thread float2* y) { + // i: the index in the overall DFT that we're processing. + // p: the size of the DFTs we're merging at this step. + // m: how many threads are working on this DFT. + int k, j; + + // Use faster bitwise operations when working with powers of two + constexpr bool radix_p_2 = (radix & (radix - 1)) == 0; + if (radix_p_2 && is_power_of_2_) { + constexpr short power = __builtin_ctz(radix); + k = i & (p - 1); + j = ((i - k) << power) + k; + } else { + k = i % p; + j = (i / p) * radix * p + k; + } + + // Apply twiddles + if (p > 1) { + float2 twiddle_1 = get_twiddle(k, radix * p); + float2 twiddle = twiddle_1; + x[1] = complex_mul(x[1], twiddle); + + STEEL_PRAGMA_UNROLL + for (int t = 2; t < radix; t++) { + twiddle = complex_mul(twiddle, twiddle_1); + x[t] = complex_mul(x[t], twiddle); + } + } + + radix_func(x, y); + + STEEL_PRAGMA_UNROLL + for (int t = 0; t < radix; t++) { + indices[t] = j + t * p; + } +} + +// Perform all the radix steps required for a +// particular radix size n. +template +METAL_FUNC void radix_n_steps( + int i, + thread int* p, + int m, + int n, + int num_steps, + thread float2* inputs, + thread short* indices, + thread float2* values, + threadgroup float2* buf) { + int m_r = n / radix; + // When combining different sized radices, we have to do + // multiple butterflies in a single thread. + // E.g. n = 28 = 4 * 7 + // 4 threads, 7 elems_per_thread + // All threads do 1 radix7 butterfly. + // 3 threads do 2 radix4 butterflies. + // 1 thread does 1 radix4 butterfly. 
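+  // For that example at the radix-4 step: m_r = 28 / 4 = 7 and
+  // max_radices_per_thread = ceil(7 / 4) = 2, so threads 0-2 process indices
+  // {i, i + 4} (both < 7) while thread 3 only processes index 3.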
+ int max_radices_per_thread = (elems_per_thread_ + radix - 1) / radix; + + int index = 0; + int r_index = 0; + for (int s = 0; s < num_steps; s++) { + for (int t = 0; t < max_radices_per_thread; t++) { + index = i + t * m; + if (index < m_r) { + for (int r = 0; r < radix; r++) { + inputs[r] = buf[index + r * m_r]; + } + radix_butterfly( + index, *p, inputs, indices + t * radix, values + t * radix); + } + } + + // Wait until all threads have read their inputs into thread local mem + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int t = 0; t < max_radices_per_thread; t++) { + index = i + t * m; + if (index < m_r) { + for (int r = 0; r < radix; r++) { + r_index = t * radix + r; + buf[indices[r_index]] = values[r_index]; + } + } + } + + // Wait until all threads have written back to threadgroup mem + threadgroup_barrier(mem_flags::mem_threadgroup); + *p *= radix; + } +} + +#define RADIX_STEP(radix, radix_func, num_steps) \ + radix_n_steps( \ + fft_idx, p, m, n, num_steps, inputs, indices, values, buf); + +template +METAL_FUNC void +perform_fft(int fft_idx, thread int* p, int m, int n, threadgroup float2* buf) { + float2 inputs[MAX_RADIX]; + short indices[MAX_OUTPUT_SIZE]; + float2 values[MAX_OUTPUT_SIZE]; + + RADIX_STEP(2, radix2, rader ? rader_2_steps_ : radix_2_steps_); + RADIX_STEP(3, radix3, rader ? rader_3_steps_ : radix_3_steps_); + RADIX_STEP(4, radix4, rader ? rader_4_steps_ : radix_4_steps_); + RADIX_STEP(5, radix5, rader ? rader_5_steps_ : radix_5_steps_); + RADIX_STEP(6, radix6, rader ? rader_6_steps_ : radix_6_steps_); + RADIX_STEP(7, radix7, rader ? rader_7_steps_ : radix_7_steps_); + RADIX_STEP(8, radix8, rader ? rader_8_steps_ : radix_8_steps_); + RADIX_STEP(11, radix11, rader ? rader_11_steps_ : radix_11_steps_); + RADIX_STEP(13, radix13, rader ? rader_13_steps_ : radix_13_steps_); +} + +// Each FFT is computed entirely in shared GPU memory. +// +// N is decomposed into radix-n DFTs: +// e.g. 128 = 2 * 4 * 4 * 4 +template +[[kernel]] void fft( + const device in_T* in [[buffer(0)]], + device out_T* out [[buffer(1)]], + constant const int& n, + constant const int& batch_size, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + threadgroup float2 shared_in[tg_mem_size]; + + thread ReadWriter read_writer = ReadWriter( + in, + &shared_in[0], + out, + n, + batch_size, + elems_per_thread_, + elem, + grid, + inv_); + + if (read_writer.out_of_bounds()) { + return; + }; + read_writer.load(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + int p = 1; + int fft_idx = elem.z; // Thread index in DFT + int m = grid.z; // Threads per DFT + int tg_idx = elem.y * n; // Index of this DFT in threadgroup + threadgroup float2* buf = &shared_in[tg_idx]; + + perform_fft(fft_idx, &p, m, n, buf); + + read_writer.write(); +} + +template +[[kernel]] void rader_fft( + const device in_T* in [[buffer(0)]], + device out_T* out [[buffer(1)]], + const device float2* raders_b_q [[buffer(2)]], + const device short* raders_g_q [[buffer(3)]], + const device short* raders_g_minus_q [[buffer(4)]], + constant const int& n, + constant const int& batch_size, + constant const int& rader_n, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + // Use Rader's algorithm to compute fast FFTs + // when a prime factor `p` of `n` is greater than 13 but + // has `p - 1` Stockham decomposable into to prime factors <= 13. + // + // E.g. n = 102 + // = 2 * 3 * 17 + // . = 2 * 3 * RADER(16) + // . 
= 2 * 3 * RADER(4 * 4) + // + // In numpy: + // x_perm = x[g_q] + // y = np.fft.fft(x_perm) * b_q + // z = np.fft.ifft(y) + x[0] + // out = z[g_minus_q] + // out[0] = x[1:].sum() + // + // Where the g_q and g_minus_q are permutations formed + // by the group under multiplicative modulo N using the + // primitive root of N and b_q is a constant. + // See https://en.wikipedia.org/wiki/Rader%27s_FFT_algorithm + // + // Rader's uses fewer operations than Bluestein's and so + // is more accurate. It's also faster in most cases. + threadgroup float2 shared_in[tg_mem_size]; + + thread ReadWriter read_writer = ReadWriter( + in, + &shared_in[0], + out, + n, + batch_size, + elems_per_thread_, + elem, + grid, + inv_); + + if (read_writer.out_of_bounds()) { + return; + }; + read_writer.load(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // The number of the threads we're using for each DFT + int m = grid.z; + + int fft_idx = elem.z; + int tg_idx = elem.y * n; + threadgroup float2* buf = &shared_in[tg_idx]; + + // rader_m = n / rader_n; + int rader_m = rader_m_; + + // We have to load two x_0s for each thread since sometimes + // elems_per_thread_ crosses a boundary. + // E.g. with n = 34, rader_n = 17, elems_per_thread_ = 4 + // 0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4 5 5 5 5 6 6 6 6 7 7 7 7 8 8 + // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + short x_0_index = + metal::min(fft_idx * elems_per_thread_ / (rader_n - 1), rader_m - 1); + float2 x_0[2] = {buf[x_0_index], buf[x_0_index + 1]}; + + // Do the Rader permutation in shared memory + float2 temp[MAX_RADIX]; + int max_index = n - rader_m - 1; + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); + short g_q = raders_g_q[index / rader_m]; + temp[e] = buf[rader_m + (g_q - 1) * rader_m + index % rader_m]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); + buf[index + rader_m] = temp[e]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Rader FFT on x[rader_m:] + int p = 1; + perform_fft(fft_idx, &p, m, n - rader_m, buf + rader_m); + + // x_1 + ... + x_n is computed for us in the first FFT step so + // we save it in the first rader_m indices of the array for later. 
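+  // (The k = 0 output of each length rader_n - 1 sub-FFT is the sum of its
+  // inputs, i.e. the x[1:].sum() term that out[0] needs in the numpy sketch
+  // above.)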
+ int x_sum_index = metal::min(fft_idx, rader_m - 1); + buf[x_sum_index] = buf[rader_m + x_sum_index * (rader_n - 1)]; + + float2 inv = {1.0f, -1.0f}; + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); + short interleaved_index = + index / rader_m + (index % rader_m) * (rader_n - 1); + temp[e] = complex_mul( + buf[rader_m + interleaved_index], + raders_b_q[interleaved_index % (rader_n - 1)]); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); + buf[rader_m + index] = temp[e] * inv; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Rader IFFT on x[rader_m:] + p = 1; + perform_fft(fft_idx, &p, m, n - rader_m, buf + rader_m); + + float2 rader_inv_factor = {1.0f / (rader_n - 1), -1.0f / (rader_n - 1)}; + + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, n - rader_m - 1); + short diff_index = index / (rader_n - 1) - x_0_index; + temp[e] = buf[rader_m + index] * rader_inv_factor + x_0[diff_index]; + } + + // Use the sum of elements that was computed in the first FFT + float2 x_sum = buf[x_0_index] + x_0[0]; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); + short g_q_index = index % (rader_n - 1); + short g_q = raders_g_minus_q[g_q_index]; + short out_index = index - g_q_index + g_q + (index / (rader_n - 1)); + buf[out_index] = temp[e]; + } + + buf[x_0_index * rader_n] = x_sum; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + p = rader_n; + perform_fft(fft_idx, &p, m, n, buf); + + read_writer.write(); +} + +template +[[kernel]] void bluestein_fft( + const device in_T* in [[buffer(0)]], + device out_T* out [[buffer(1)]], + const device float2* w_q [[buffer(2)]], + const device float2* w_k [[buffer(3)]], + constant const int& length, + constant const int& n, + constant const int& batch_size, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + // Computes arbitrary length FFTs with Bluestein's algorithm + // + // In numpy: + // bluestein_n = next_power_of_2(2*n - 1) + // out = w_k * np.fft.ifft(np.fft.fft(w_k * in, bluestein_n) * w_q) + // + // Where w_k and w_q are precomputed on CPU in high precision as: + // w_k = np.exp(-1j * np.pi / n * (np.arange(-n + 1, n) ** 2)) + // w_q = np.fft.fft(1/w_k[-n:]) + threadgroup float2 shared_in[tg_mem_size]; + + thread ReadWriter read_writer = ReadWriter( + in, + &shared_in[0], + out, + n, + batch_size, + elems_per_thread_, + elem, + grid, + inv_); + + if (read_writer.out_of_bounds()) { + return; + }; + read_writer.load_padded(length, w_k); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + int p = 1; + int fft_idx = elem.z; // Thread index in DFT + int m = grid.z; // Threads per DFT + int tg_idx = elem.y * n; // Index of this DFT in threadgroup + threadgroup float2* buf = &shared_in[tg_idx]; + + // fft + perform_fft(fft_idx, &p, m, n, buf); + + float2 inv = float2(1.0f, -1.0f); + for (int t = 0; t < elems_per_thread_; t++) { + int index = fft_idx + t * m; + buf[index] = complex_mul(buf[index], w_q[index]) * inv; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // ifft + p = 1; + perform_fft(fft_idx, &p, m, n, buf); + + read_writer.write_padded(length, w_k); +} + +template < + int tg_mem_size, + typename in_T, + 
typename out_T, + int step, + bool real = false> +[[kernel]] void four_step_fft( + const device in_T* in [[buffer(0)]], + device out_T* out [[buffer(1)]], + constant const int& n1, + constant const int& n2, + constant const int& batch_size, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + // Fast four step FFT implementation for powers of 2. + int overall_n = n1 * n2; + int n = step == 0 ? n1 : n2; + int stride = step == 0 ? n2 : n1; + + // The number of the threads we're using for each DFT + int m = grid.z; + int fft_idx = elem.z; + + threadgroup float2 shared_in[tg_mem_size]; + threadgroup float2* buf = &shared_in[elem.y * n]; + + using read_writer_t = ReadWriter; + read_writer_t read_writer = read_writer_t( + in, + &shared_in[0], + out, + n, + batch_size, + elems_per_thread_, + elem, + grid, + inv_); + + if (read_writer.out_of_bounds()) { + return; + }; + read_writer.load_strided(stride, overall_n); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + int p = 1; + perform_fft(fft_idx, &p, m, n, buf); + + read_writer.write_strided(stride, overall_n); +} \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/fft/radix.h b/Source/Cmlx/mlx-generated/metal/fft/radix.h new file mode 100644 index 00000000..bd61eef6 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/fft/radix.h @@ -0,0 +1,328 @@ +// Copyright © 2024 Apple Inc. + +/* Radix kernels + +We provide optimized, single threaded Radix codelets +for n=2,3,4,5,6,7,8,10,11,12,13. + +For n=2,3,4,5,6 we hand write the codelets. +For n=8,10,12 we combine smaller codelets. +For n=7,11,13 we use Rader's algorithm which decomposes +them into (n-1)=6,10,12 codelets. */ + +#pragma once + +#include +#include +#include + +METAL_FUNC float2 complex_mul(float2 a, float2 b) { + return float2(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); +} + +// Complex mul followed by conjugate +METAL_FUNC float2 complex_mul_conj(float2 a, float2 b) { + return float2(a.x * b.x - a.y * b.y, -a.x * b.y - a.y * b.x); +} + +// Compute an FFT twiddle factor +METAL_FUNC float2 get_twiddle(int k, int p) { + float theta = -2.0f * k * M_PI_F / p; + + float2 twiddle = {metal::fast::cos(theta), metal::fast::sin(theta)}; + return twiddle; +} + +METAL_FUNC void radix2(thread float2* x, thread float2* y) { + y[0] = x[0] + x[1]; + y[1] = x[0] - x[1]; +} + +METAL_FUNC void radix3(thread float2* x, thread float2* y) { + float pi_2_3 = -0.8660254037844387; + + float2 a_1 = x[1] + x[2]; + float2 a_2 = x[1] - x[2]; + + y[0] = x[0] + a_1; + float2 b_1 = x[0] - 0.5 * a_1; + float2 b_2 = pi_2_3 * a_2; + + float2 b_2_j = {-b_2.y, b_2.x}; + y[1] = b_1 + b_2_j; + y[2] = b_1 - b_2_j; +} + +METAL_FUNC void radix4(thread float2* x, thread float2* y) { + float2 z_0 = x[0] + x[2]; + float2 z_1 = x[0] - x[2]; + float2 z_2 = x[1] + x[3]; + float2 z_3 = x[1] - x[3]; + float2 z_3_i = {z_3.y, -z_3.x}; + + y[0] = z_0 + z_2; + y[1] = z_1 + z_3_i; + y[2] = z_0 - z_2; + y[3] = z_1 - z_3_i; +} + +METAL_FUNC void radix5(thread float2* x, thread float2* y) { + float2 root_5_4 = 0.5590169943749475; + float2 sin_2pi_5 = 0.9510565162951535; + float2 sin_1pi_5 = 0.5877852522924731; + + float2 a_1 = x[1] + x[4]; + float2 a_2 = x[2] + x[3]; + float2 a_3 = x[1] - x[4]; + float2 a_4 = x[2] - x[3]; + + float2 a_5 = a_1 + a_2; + float2 a_6 = root_5_4 * (a_1 - a_2); + float2 a_7 = x[0] - a_5 / 4; + float2 a_8 = a_7 + a_6; + float2 a_9 = a_7 - a_6; + float2 a_10 = sin_2pi_5 * a_3 + sin_1pi_5 * a_4; + float2 a_11 = sin_1pi_5 * a_3 - sin_2pi_5 * a_4; + float2 a_10_j = 
{a_10.y, -a_10.x}; + float2 a_11_j = {a_11.y, -a_11.x}; + + y[0] = x[0] + a_5; + y[1] = a_8 + a_10_j; + y[2] = a_9 + a_11_j; + y[3] = a_9 - a_11_j; + y[4] = a_8 - a_10_j; +} + +METAL_FUNC void radix6(thread float2* x, thread float2* y) { + float sin_pi_3 = 0.8660254037844387; + float2 a_1 = x[2] + x[4]; + float2 a_2 = x[0] - a_1 / 2; + float2 a_3 = sin_pi_3 * (x[2] - x[4]); + float2 a_4 = x[5] + x[1]; + float2 a_5 = x[3] - a_4 / 2; + float2 a_6 = sin_pi_3 * (x[5] - x[1]); + float2 a_7 = x[0] + a_1; + + float2 a_3_i = {a_3.y, -a_3.x}; + float2 a_6_i = {a_6.y, -a_6.x}; + float2 a_8 = a_2 + a_3_i; + float2 a_9 = a_2 - a_3_i; + float2 a_10 = x[3] + a_4; + float2 a_11 = a_5 + a_6_i; + float2 a_12 = a_5 - a_6_i; + + y[0] = a_7 + a_10; + y[1] = a_8 - a_11; + y[2] = a_9 + a_12; + y[3] = a_7 - a_10; + y[4] = a_8 + a_11; + y[5] = a_9 - a_12; +} + +METAL_FUNC void radix7(thread float2* x, thread float2* y) { + // Rader's algorithm + float2 inv = {1 / 6.0, -1 / 6.0}; + + // fft + float2 in1[6] = {x[1], x[3], x[2], x[6], x[4], x[5]}; + radix6(in1, y + 1); + + y[0] = y[1] + x[0]; + + // b_q + y[1] = complex_mul_conj(y[1], float2(-1, 0)); + y[2] = complex_mul_conj(y[2], float2(2.44013336, -1.02261879)); + y[3] = complex_mul_conj(y[3], float2(2.37046941, -1.17510629)); + y[4] = complex_mul_conj(y[4], float2(0, -2.64575131)); + y[5] = complex_mul_conj(y[5], float2(2.37046941, 1.17510629)); + y[6] = complex_mul_conj(y[6], float2(-2.44013336, -1.02261879)); + + // ifft + radix6(y + 1, x + 1); + + y[1] = x[1] * inv + x[0]; + y[5] = x[2] * inv + x[0]; + y[4] = x[3] * inv + x[0]; + y[6] = x[4] * inv + x[0]; + y[2] = x[5] * inv + x[0]; + y[3] = x[6] * inv + x[0]; +} + +METAL_FUNC void radix8(thread float2* x, thread float2* y) { + float cos_pi_4 = 0.7071067811865476; + float2 w_0 = {cos_pi_4, -cos_pi_4}; + float2 w_1 = {-cos_pi_4, -cos_pi_4}; + float2 temp[8] = {x[0], x[2], x[4], x[6], x[1], x[3], x[5], x[7]}; + radix4(temp, x); + radix4(temp + 4, x + 4); + + y[0] = x[0] + x[4]; + y[4] = x[0] - x[4]; + float2 x_5 = complex_mul(x[5], w_0); + y[1] = x[1] + x_5; + y[5] = x[1] - x_5; + float2 x_6 = {x[6].y, -x[6].x}; + y[2] = x[2] + x_6; + y[6] = x[2] - x_6; + float2 x_7 = complex_mul(x[7], w_1); + y[3] = x[3] + x_7; + y[7] = x[3] - x_7; +} + +template +METAL_FUNC void radix10(thread float2* x, thread float2* y) { + float2 w[4]; + w[0] = {0.8090169943749475, -0.5877852522924731}; + w[1] = {0.30901699437494745, -0.9510565162951535}; + w[2] = {-w[1].x, w[1].y}; + w[3] = {-w[0].x, w[0].y}; + + if (raders_perm) { + float2 temp[10] = { + x[0], x[3], x[4], x[8], x[2], x[1], x[7], x[9], x[6], x[5]}; + radix5(temp, x); + radix5(temp + 5, x + 5); + } else { + float2 temp[10] = { + x[0], x[2], x[4], x[6], x[8], x[1], x[3], x[5], x[7], x[9]}; + radix5(temp, x); + radix5(temp + 5, x + 5); + } + + y[0] = x[0] + x[5]; + y[5] = x[0] - x[5]; + for (int t = 1; t < 5; t++) { + float2 a = complex_mul(x[t + 5], w[t - 1]); + y[t] = x[t] + a; + y[t + 5] = x[t] - a; + } +} + +METAL_FUNC void radix11(thread float2* x, thread float2* y) { + // Raders Algorithm + float2 inv = {1 / 10.0, -1 / 10.0}; + + // fft + radix10(x + 1, y + 1); + + y[0] = y[1] + x[0]; + + // b_q + y[1] = complex_mul_conj(y[1], float2(-1, 0)); + y[2] = complex_mul_conj(y[2], float2(0.955301878, -3.17606649)); + y[3] = complex_mul_conj(y[3], float2(2.63610556, 2.01269656)); + y[4] = complex_mul_conj(y[4], float2(2.54127802, 2.13117479)); + y[5] = complex_mul_conj(y[5], float2(2.07016210, 2.59122150)); + y[6] = complex_mul_conj(y[6], float2(0, -3.31662479)); + y[7] = 
complex_mul_conj(y[7], float2(2.07016210, -2.59122150)); + y[8] = complex_mul_conj(y[8], float2(-2.54127802, 2.13117479)); + y[9] = complex_mul_conj(y[9], float2(2.63610556, -2.01269656)); + y[10] = complex_mul_conj(y[10], float2(-0.955301878, -3.17606649)); + + // ifft + radix10(y + 1, x + 1); + + y[1] = x[1] * inv + x[0]; + y[6] = x[2] * inv + x[0]; + y[3] = x[3] * inv + x[0]; + y[7] = x[4] * inv + x[0]; + y[9] = x[5] * inv + x[0]; + y[10] = x[6] * inv + x[0]; + y[5] = x[7] * inv + x[0]; + y[8] = x[8] * inv + x[0]; + y[4] = x[9] * inv + x[0]; + y[2] = x[10] * inv + x[0]; +} + +template +METAL_FUNC void radix12(thread float2* x, thread float2* y) { + float2 w[6]; + float sin_pi_3 = 0.8660254037844387; + w[0] = {sin_pi_3, -0.5}; + w[1] = {0.5, -sin_pi_3}; + w[2] = {0, -1}; + w[3] = {-0.5, -sin_pi_3}; + w[4] = {-sin_pi_3, -0.5}; + + if (raders_perm) { + float2 temp[12] = { + x[0], + x[3], + x[2], + x[11], + x[8], + x[9], + x[1], + x[7], + x[5], + x[10], + x[4], + x[6]}; + radix6(temp, x); + radix6(temp + 6, x + 6); + } else { + float2 temp[12] = { + x[0], + x[2], + x[4], + x[6], + x[8], + x[10], + x[1], + x[3], + x[5], + x[7], + x[9], + x[11]}; + radix6(temp, x); + radix6(temp + 6, x + 6); + } + + y[0] = x[0] + x[6]; + y[6] = x[0] - x[6]; + for (int t = 1; t < 6; t++) { + float2 a = complex_mul(x[t + 6], w[t - 1]); + y[t] = x[t] + a; + y[t + 6] = x[t] - a; + } +} + +METAL_FUNC void radix13(thread float2* x, thread float2* y) { + // Raders Algorithm + float2 inv = {1 / 12.0, -1 / 12.0}; + + // fft + radix12(x + 1, y + 1); + + y[0] = y[1] + x[0]; + + // b_q + y[1] = complex_mul_conj(y[1], float2(-1, 0)); + y[2] = complex_mul_conj(y[2], float2(3.07497206, -1.88269669)); + y[3] = complex_mul_conj(y[3], float2(3.09912468, 1.84266823)); + y[4] = complex_mul_conj(y[4], float2(3.45084438, -1.04483161)); + y[5] = complex_mul_conj(y[5], float2(0.91083583, 3.48860690)); + y[6] = complex_mul_conj(y[6], float2(-3.60286363, 0.139189267)); + y[7] = complex_mul_conj(y[7], float2(3.60555128, 0)); + y[8] = complex_mul_conj(y[8], float2(3.60286363, 0.139189267)); + y[9] = complex_mul_conj(y[9], float2(0.91083583, -3.48860690)); + y[10] = complex_mul_conj(y[10], float2(-3.45084438, -1.04483161)); + y[11] = complex_mul_conj(y[11], float2(3.09912468, -1.84266823)); + y[12] = complex_mul_conj(y[12], float2(-3.07497206, -1.88269669)); + + // ifft + radix12(y + 1, x + 1); + + y[1] = x[1] * inv + x[0]; + y[7] = x[2] * inv + x[0]; + y[10] = x[3] * inv + x[0]; + y[5] = x[4] * inv + x[0]; + y[9] = x[5] * inv + x[0]; + y[11] = x[6] * inv + x[0]; + y[12] = x[7] * inv + x[0]; + y[6] = x[8] * inv + x[0]; + y[3] = x[9] * inv + x[0]; + y[8] = x[10] * inv + x[0]; + y[4] = x[11] * inv + x[0]; + y[2] = x[12] * inv + x[0]; +} \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/fft/readwrite.h b/Source/Cmlx/mlx-generated/metal/fft/readwrite.h new file mode 100644 index 00000000..23231946 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/fft/readwrite.h @@ -0,0 +1,622 @@ +// Copyright © 2024 Apple Inc. + +#include + +#include "../fft/radix.h" + +/* FFT helpers for reading and writing from/to device memory. + +For many sizes, GPU FFTs are memory bandwidth bound so +read/write performance is important. + +Where possible, we read 128 bits sequentially in each thread, +coalesced with accesses from adajcent threads for optimal performance. 
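A quick note on the radix codelets above: each hand-written butterfly (radix2 through radix6, plus the Rader-based radix7/11/13 built from them) is an unrolled DFT of that length, so it can be checked directly against the defining DFT sum. The following is a host-side C++ sketch for the radix-3 case, not part of the generated Metal sources; `naive_dft` and `radix3_ref` are names invented here, and `radix3_ref` simply transcribes the radix3 codelet with std::complex in place of float2.

```cpp
#include <complex>
#include <cstdio>
#include <vector>

using cf = std::complex<float>;

// Naive O(n^2) DFT used only as a reference.
static std::vector<cf> naive_dft(const std::vector<cf>& x) {
  const float pi = 3.14159265358979323846f;
  std::vector<cf> y(x.size());
  for (size_t k = 0; k < x.size(); ++k)
    for (size_t j = 0; j < x.size(); ++j)
      y[k] += x[j] * std::polar(1.0f, -2.0f * pi * k * j / x.size());
  return y;
}

// Mirrors the radix3 codelet: one scale by -sqrt(3)/2 and a rotation by j.
static void radix3_ref(const cf x[3], cf y[3]) {
  const float pi_2_3 = -0.8660254037844387f;
  cf a1 = x[1] + x[2];
  cf a2 = x[1] - x[2];
  cf b1 = x[0] - 0.5f * a1;
  cf b2 = pi_2_3 * a2;
  cf b2_j(-b2.imag(), b2.real()); // multiply by j
  y[0] = x[0] + a1;
  y[1] = b1 + b2_j;
  y[2] = b1 - b2_j;
}

int main() {
  std::vector<cf> x = {{1, 2}, {3, -1}, {0.5f, 4}};
  cf y[3];
  radix3_ref(x.data(), y);
  auto ref = naive_dft(x);
  for (int k = 0; k < 3; ++k)
    std::printf("k=%d  codelet=(%6.3f,%6.3f)  dft=(%6.3f,%6.3f)\n",
                k, y[k].real(), y[k].imag(), ref[k].real(), ref[k].imag());
}
```

The same harness extends to the other lengths; for n = 7, 11, 13 it exercises Rader's reduction of a prime-length DFT to a length n-1 transform of a permuted input, as described in the radix.h header comment.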
+ +We implement specialized reading/writing for: + - FFT + - RFFT + - IRFFT + +Each with support for: + - Contiguous reads + - Padded reads + - Strided reads +*/ + +#define MAX_RADIX 13 + +using namespace metal; + +template < + typename in_T, + typename out_T, + int step = 0, + bool four_step_real = false> +struct ReadWriter { + const device in_T* in; + threadgroup float2* buf; + device out_T* out; + int n; + int batch_size; + int elems_per_thread; + uint3 elem; + uint3 grid; + int threads_per_tg; + bool inv; + + // Used for strided access + int strided_device_idx = 0; + int strided_shared_idx = 0; + + METAL_FUNC ReadWriter( + const device in_T* in_, + threadgroup float2* buf_, + device out_T* out_, + const short n_, + const int batch_size_, + const short elems_per_thread_, + const uint3 elem_, + const uint3 grid_, + const bool inv_) + : in(in_), + buf(buf_), + out(out_), + n(n_), + batch_size(batch_size_), + elems_per_thread(elems_per_thread_), + elem(elem_), + grid(grid_), + inv(inv_) { + // Account for padding on last threadgroup + threads_per_tg = elem.x == grid.x - 1 + ? (batch_size - (grid.x - 1) * grid.y) * grid.z + : grid.y * grid.z; + } + + // ifft(x) = 1/n * conj(fft(conj(x))) + METAL_FUNC float2 post_in(float2 elem) const { + return inv ? float2(elem.x, -elem.y) : elem; + } + + // Handle float case for generic RFFT alg + METAL_FUNC float2 post_in(float elem) const { + return float2(elem, 0); + } + + METAL_FUNC float2 pre_out(float2 elem) const { + return inv ? float2(elem.x / n, -elem.y / n) : elem; + } + + METAL_FUNC float2 pre_out(float2 elem, int length) const { + return inv ? float2(elem.x / length, -elem.y / length) : elem; + } + + METAL_FUNC bool out_of_bounds() const { + // Account for possible extra threadgroups + int grid_index = elem.x * grid.y + elem.y; + return grid_index >= batch_size; + } + + METAL_FUNC void load() const { + int batch_idx = elem.x * grid.y * n; + short tg_idx = elem.y * grid.z + elem.z; + short max_index = grid.y * n - 2; + + // 2 complex64s = 128 bits + constexpr int read_width = 2; + for (short e = 0; e < (elems_per_thread / read_width); e++) { + short index = read_width * tg_idx + read_width * threads_per_tg * e; + index = metal::min(index, max_index); + // vectorized reads + buf[index] = post_in(in[batch_idx + index]); + buf[index + 1] = post_in(in[batch_idx + index + 1]); + } + max_index += 1; + if (elems_per_thread % 2 != 0) { + short index = tg_idx + + read_width * threads_per_tg * (elems_per_thread / read_width); + index = metal::min(index, max_index); + buf[index] = post_in(in[batch_idx + index]); + } + } + + METAL_FUNC void write() const { + int batch_idx = elem.x * grid.y * n; + short tg_idx = elem.y * grid.z + elem.z; + short max_index = grid.y * n - 2; + + constexpr int read_width = 2; + for (short e = 0; e < (elems_per_thread / read_width); e++) { + short index = read_width * tg_idx + read_width * threads_per_tg * e; + index = metal::min(index, max_index); + // vectorized reads + out[batch_idx + index] = pre_out(buf[index]); + out[batch_idx + index + 1] = pre_out(buf[index + 1]); + } + max_index += 1; + if (elems_per_thread % 2 != 0) { + short index = tg_idx + + read_width * threads_per_tg * (elems_per_thread / read_width); + index = metal::min(index, max_index); + out[batch_idx + index] = pre_out(buf[index]); + } + } + + // Padded IO for Bluestein's algorithm + METAL_FUNC void load_padded(int length, const device float2* w_k) const { + int batch_idx = elem.x * grid.y * length + elem.y * length; + int fft_idx = elem.z; + int m = 
grid.z; + + threadgroup float2* seq_buf = buf + elem.y * n; + for (int e = 0; e < elems_per_thread; e++) { + int index = metal::min(fft_idx + e * m, n - 1); + if (index < length) { + float2 elem = post_in(in[batch_idx + index]); + seq_buf[index] = complex_mul(elem, w_k[index]); + } else { + seq_buf[index] = 0.0; + } + } + } + + METAL_FUNC void write_padded(int length, const device float2* w_k) const { + int batch_idx = elem.x * grid.y * length + elem.y * length; + int fft_idx = elem.z; + int m = grid.z; + float2 inv_factor = {1.0f / n, -1.0f / n}; + + threadgroup float2* seq_buf = buf + elem.y * n; + for (int e = 0; e < elems_per_thread; e++) { + int index = metal::min(fft_idx + e * m, n - 1); + if (index < length) { + float2 elem = seq_buf[index + length - 1] * inv_factor; + out[batch_idx + index] = pre_out(complex_mul(elem, w_k[index]), length); + } + } + } + + // Strided IO for four step FFT + METAL_FUNC void compute_strided_indices(int stride, int overall_n) { + // Use the batch threadgroup dimension to coalesce memory accesses: + // e.g. stride = 12 + // device | shared mem + // 0 1 2 3 | 0 12 - - + // - - - - | 1 13 - - + // - - - - | 2 14 - - + // 12 13 14 15 | 3 15 - - + int coalesce_width = grid.y; + int tg_idx = elem.y * grid.z + elem.z; + int outer_batch_size = stride / coalesce_width; + + int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width + + overall_n * (elem.x / outer_batch_size); + strided_device_idx = strided_batch_idx + + tg_idx / coalesce_width * elems_per_thread * stride + + tg_idx % coalesce_width; + strided_shared_idx = (tg_idx % coalesce_width) * n + + tg_idx / coalesce_width * elems_per_thread; + } + + // Four Step FFT First Step + METAL_FUNC void load_strided(int stride, int overall_n) { + compute_strided_indices(stride, overall_n); + for (int e = 0; e < elems_per_thread; e++) { + buf[strided_shared_idx + e] = + post_in(in[strided_device_idx + e * stride]); + } + } + + METAL_FUNC void write_strided(int stride, int overall_n) { + for (int e = 0; e < elems_per_thread; e++) { + float2 output = buf[strided_shared_idx + e]; + int combined_idx = (strided_device_idx + e * stride) % overall_n; + int ij = (combined_idx / stride) * (combined_idx % stride); + // Apply four step twiddles at end of first step + float2 twiddle = get_twiddle(ij, overall_n); + out[strided_device_idx + e * stride] = complex_mul(output, twiddle); + } + } +}; + +// Four Step FFT Second Step +template <> +METAL_FUNC void ReadWriter::load_strided( + int stride, + int overall_n) { + // Silence compiler warnings + (void)stride; + (void)overall_n; + // Don't invert between steps + bool default_inv = inv; + inv = false; + load(); + inv = default_inv; +} + +template <> +METAL_FUNC void ReadWriter::write_strided( + int stride, + int overall_n) { + compute_strided_indices(stride, overall_n); + for (int e = 0; e < elems_per_thread; e++) { + float2 output = buf[strided_shared_idx + e]; + out[strided_device_idx + e * stride] = pre_out(output, overall_n); + } +} + +// For RFFT, we interleave batches of two real sequences into one complex one: +// +// z_k = x_k + j.y_k +// X_k = (Z_k + Z_(N-k)*) / 2 +// Y_k = -j * ((Z_k - Z_(N-k)*) / 2) +// +// This roughly doubles the throughput over the regular FFT. 
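The packing identities in the comment above are straightforward to verify numerically. Below is a host-side C++ illustration, not part of the generated Metal sources: it packs two real sequences as z = x + j·y, takes one complex DFT, and recovers both spectra with X_k = (Z_k + Z*_(N-k))/2 and Y_k = -j·(Z_k - Z*_(N-k))/2, which is the split the float-to-float2 write() specialization performs. `naive_dft` is a name invented here, and the loop keeps only the n/2 + 1 outputs an RFFT stores.

```cpp
#include <complex>
#include <cstdio>
#include <vector>

using cf = std::complex<float>;

// Naive O(n^2) DFT used only as a reference.
static std::vector<cf> naive_dft(const std::vector<cf>& x) {
  const float pi = 3.14159265358979323846f;
  std::vector<cf> y(x.size());
  for (size_t k = 0; k < x.size(); ++k)
    for (size_t j = 0; j < x.size(); ++j)
      y[k] += x[j] * std::polar(1.0f, -2.0f * pi * k * j / x.size());
  return y;
}

int main() {
  std::vector<float> x = {1, -2, 3, 0.5f}, y = {4, 0, -1, 2};
  const size_t n = x.size();

  // Pack both real sequences into one complex sequence and take a single DFT.
  std::vector<cf> z(n);
  for (size_t i = 0; i < n; ++i) z[i] = cf(x[i], y[i]); // z = x + j*y
  auto Z = naive_dft(z);

  // Reference spectra from two separate DFTs of the real inputs.
  std::vector<cf> xc(x.begin(), x.end()), yc(y.begin(), y.end());
  auto Xref = naive_dft(xc), Yref = naive_dft(yc);

  const cf minus_j(0, -1);
  for (size_t k = 0; k <= n / 2; ++k) { // RFFT keeps only n/2 + 1 outputs
    cf Znk = std::conj(Z[(n - k) % n]);
    cf Xk = (Z[k] + Znk) * 0.5f;
    cf Yk = minus_j * ((Z[k] - Znk) * 0.5f);
    std::printf("k=%zu  X=(%6.3f,%6.3f) ref=(%6.3f,%6.3f)  Y=(%6.3f,%6.3f) ref=(%6.3f,%6.3f)\n",
                k, Xk.real(), Xk.imag(), Xref[k].real(), Xref[k].imag(),
                Yk.real(), Yk.imag(), Yref[k].real(), Yref[k].imag());
  }
}
```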
+template <> +METAL_FUNC bool ReadWriter::out_of_bounds() const { + int grid_index = elem.x * grid.y + elem.y; + // We pack two sequences into one for RFFTs + return grid_index * 2 >= batch_size; +} + +template <> +METAL_FUNC void ReadWriter::load() const { + int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + // No out of bounds accesses on odd batch sizes + int grid_index = elem.x * grid.y + elem.y; + short next_in = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n; + + short m = grid.z; + short fft_idx = elem.z; + + for (int e = 0; e < elems_per_thread; e++) { + int index = metal::min(fft_idx + e * m, n - 1); + seq_buf[index].x = in[batch_idx + index]; + seq_buf[index].y = in[batch_idx + index + next_in]; + } +} + +template <> +METAL_FUNC void ReadWriter::write() const { + short n_over_2 = (n / 2) + 1; + + int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + int grid_index = elem.x * grid.y + elem.y; + short next_out = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2; + + float2 conj = {1, -1}; + float2 minus_j = {0, -1}; + + short m = grid.z; + short fft_idx = elem.z; + + for (int e = 0; e < elems_per_thread / 2 + 1; e++) { + int index = metal::min(fft_idx + e * m, n_over_2 - 1); + // x_0 = z_0.real + // y_0 = z_0.imag + if (index == 0) { + out[batch_idx + index] = {seq_buf[index].x, 0}; + out[batch_idx + index + next_out] = {seq_buf[index].y, 0}; + } else { + float2 x_k = seq_buf[index]; + float2 x_n_minus_k = seq_buf[n - index] * conj; + out[batch_idx + index] = (x_k + x_n_minus_k) / 2; + out[batch_idx + index + next_out] = + complex_mul(((x_k - x_n_minus_k) / 2), minus_j); + } + } +} + +template <> +METAL_FUNC void ReadWriter::load_padded( + int length, + const device float2* w_k) const { + int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + // No out of bounds accesses on odd batch sizes + int grid_index = elem.x * grid.y + elem.y; + short next_in = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length; + + short m = grid.z; + short fft_idx = elem.z; + + for (int e = 0; e < elems_per_thread; e++) { + int index = metal::min(fft_idx + e * m, n - 1); + if (index < length) { + float2 elem = + float2(in[batch_idx + index], in[batch_idx + index + next_in]); + seq_buf[index] = complex_mul(elem, w_k[index]); + } else { + seq_buf[index] = 0; + } + } +} + +template <> +METAL_FUNC void ReadWriter::write_padded( + int length, + const device float2* w_k) const { + int length_over_2 = (length / 2) + 1; + int batch_idx = + elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; + threadgroup float2* seq_buf = buf + elem.y * n + length - 1; + + int grid_index = elem.x * grid.y + elem.y; + short next_out = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 + ? 
0 + : length_over_2; + + float2 conj = {1, -1}; + float2 inv_factor = {1.0f / n, -1.0f / n}; + float2 minus_j = {0, -1}; + + short m = grid.z; + short fft_idx = elem.z; + + for (int e = 0; e < elems_per_thread / 2 + 1; e++) { + int index = metal::min(fft_idx + e * m, length_over_2 - 1); + // x_0 = z_0.real + // y_0 = z_0.imag + if (index == 0) { + float2 elem = complex_mul(w_k[index], seq_buf[index] * inv_factor); + out[batch_idx + index] = float2(elem.x, 0); + out[batch_idx + index + next_out] = float2(elem.y, 0); + } else { + float2 x_k = complex_mul(w_k[index], seq_buf[index] * inv_factor); + float2 x_n_minus_k = complex_mul( + w_k[length - index], seq_buf[length - index] * inv_factor); + x_n_minus_k *= conj; + // w_k should happen before this extraction + out[batch_idx + index] = (x_k + x_n_minus_k) / 2; + out[batch_idx + index + next_out] = + complex_mul(((x_k - x_n_minus_k) / 2), minus_j); + } + } +} + +// For IRFFT, we do the opposite +// +// Z_k = X_k + j.Y_k +// x_k = Re(Z_k) +// Y_k = Imag(Z_k) +template <> +METAL_FUNC bool ReadWriter::out_of_bounds() const { + int grid_index = elem.x * grid.y + elem.y; + // We pack two sequences into one for IRFFTs + return grid_index * 2 >= batch_size; +} + +template <> +METAL_FUNC void ReadWriter::load() const { + short n_over_2 = (n / 2) + 1; + int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + // No out of bounds accesses on odd batch sizes + int grid_index = elem.x * grid.y + elem.y; + short next_in = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2; + + short m = grid.z; + short fft_idx = elem.z; + + float2 conj = {1, -1}; + float2 plus_j = {0, 1}; + + for (int t = 0; t < elems_per_thread / 2 + 1; t++) { + int index = metal::min(fft_idx + t * m, n_over_2 - 1); + float2 x = in[batch_idx + index]; + float2 y = in[batch_idx + index + next_in]; + // NumPy forces first input to be real + bool first_val = index == 0; + // NumPy forces last input on even irffts to be real + bool last_val = n % 2 == 0 && index == n_over_2 - 1; + if (first_val || last_val) { + x = float2(x.x, 0); + y = float2(y.x, 0); + } + seq_buf[index] = x + complex_mul(y, plus_j); + seq_buf[index].y = -seq_buf[index].y; + if (index > 0 && !last_val) { + seq_buf[n - index] = (x * conj) + complex_mul(y * conj, plus_j); + seq_buf[n - index].y = -seq_buf[n - index].y; + } + } +} + +template <> +METAL_FUNC void ReadWriter::write() const { + int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + int grid_index = elem.x * grid.y + elem.y; + short next_out = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n; + + short m = grid.z; + short fft_idx = elem.z; + + for (int e = 0; e < elems_per_thread; e++) { + int index = metal::min(fft_idx + e * m, n - 1); + out[batch_idx + index] = seq_buf[index].x / n; + out[batch_idx + index + next_out] = seq_buf[index].y / -n; + } +} + +template <> +METAL_FUNC void ReadWriter::load_padded( + int length, + const device float2* w_k) const { + int n_over_2 = (n / 2) + 1; + int length_over_2 = (length / 2) + 1; + + int batch_idx = + elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + // No out of bounds accesses on odd batch sizes + int grid_index = elem.x * grid.y + elem.y; + short next_in = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 + ? 
0 + : length_over_2; + + short m = grid.z; + short fft_idx = elem.z; + + float2 conj = {1, -1}; + float2 plus_j = {0, 1}; + + for (int t = 0; t < elems_per_thread / 2 + 1; t++) { + int index = metal::min(fft_idx + t * m, n_over_2 - 1); + float2 x = in[batch_idx + index]; + float2 y = in[batch_idx + index + next_in]; + if (index < length_over_2) { + bool last_val = length % 2 == 0 && index == length_over_2 - 1; + if (last_val) { + x = float2(x.x, 0); + y = float2(y.x, 0); + } + float2 elem1 = x + complex_mul(y, plus_j); + seq_buf[index] = complex_mul(elem1 * conj, w_k[index]); + if (index > 0 && !last_val) { + float2 elem2 = (x * conj) + complex_mul(y * conj, plus_j); + seq_buf[length - index] = + complex_mul(elem2 * conj, w_k[length - index]); + } + } else { + short pad_index = metal::min(length + (index - length_over_2) * 2, n - 2); + seq_buf[pad_index] = 0; + seq_buf[pad_index + 1] = 0; + } + } +} + +template <> +METAL_FUNC void ReadWriter::write_padded( + int length, + const device float2* w_k) const { + int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; + threadgroup float2* seq_buf = buf + elem.y * n + length - 1; + + int grid_index = elem.x * grid.y + elem.y; + short next_out = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length; + + short m = grid.z; + short fft_idx = elem.z; + + float2 inv_factor = {1.0f / n, -1.0f / n}; + for (int e = 0; e < elems_per_thread; e++) { + int index = fft_idx + e * m; + if (index < length) { + float2 output = complex_mul(seq_buf[index] * inv_factor, w_k[index]); + out[batch_idx + index] = output.x / length; + out[batch_idx + index + next_out] = output.y / -length; + } + } +} + +// Four Step RFFT +template <> +METAL_FUNC void +ReadWriter::load_strided( + int stride, + int overall_n) { + // Silence compiler warnings + (void)stride; + (void)overall_n; + // Don't invert between steps + bool default_inv = inv; + inv = false; + load(); + inv = default_inv; +} + +template <> +METAL_FUNC void +ReadWriter::write_strided( + int stride, + int overall_n) { + int overall_n_over_2 = overall_n / 2 + 1; + int coalesce_width = grid.y; + int tg_idx = elem.y * grid.z + elem.z; + int outer_batch_size = stride / coalesce_width; + + int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width + + overall_n_over_2 * (elem.x / outer_batch_size); + strided_device_idx = strided_batch_idx + + tg_idx / coalesce_width * elems_per_thread / 2 * stride + + tg_idx % coalesce_width; + strided_shared_idx = (tg_idx % coalesce_width) * n + + tg_idx / coalesce_width * elems_per_thread / 2; + for (int e = 0; e < elems_per_thread / 2; e++) { + float2 output = buf[strided_shared_idx + e]; + out[strided_device_idx + e * stride] = output; + } + + // Add on n/2 + 1 element + if (tg_idx == 0 && elem.x % outer_batch_size == 0) { + out[strided_batch_idx + overall_n / 2] = buf[n / 2]; + } +} + +// Four Step IRFFT +template <> +METAL_FUNC void +ReadWriter::load_strided( + int stride, + int overall_n) { + int overall_n_over_2 = overall_n / 2 + 1; + auto conj = float2(1, -1); + + compute_strided_indices(stride, overall_n); + // Translate indices in terms of N - k + for (int e = 0; e < elems_per_thread; e++) { + int device_idx = strided_device_idx + e * stride; + int overall_batch = device_idx / overall_n; + int overall_index = device_idx % overall_n; + if (overall_index < overall_n_over_2) { + device_idx -= overall_batch * (overall_n - overall_n_over_2); + buf[strided_shared_idx + e] = in[device_idx] * conj; + } else { + int conj_idx = overall_n - 
overall_index; + device_idx = overall_batch * overall_n_over_2 + conj_idx; + buf[strided_shared_idx + e] = in[device_idx]; + } + } +} + +template <> +METAL_FUNC void +ReadWriter::load_strided( + int stride, + int overall_n) { + // Silence compiler warnings + (void)stride; + (void)overall_n; + bool default_inv = inv; + inv = false; + load(); + inv = default_inv; +} + +template <> +METAL_FUNC void +ReadWriter::write_strided( + int stride, + int overall_n) { + compute_strided_indices(stride, overall_n); + + for (int e = 0; e < elems_per_thread; e++) { + out[strided_device_idx + e * stride] = + pre_out(buf[strided_shared_idx + e], overall_n).x; + } +} diff --git a/Source/Cmlx/mlx-generated/metal/gather.h b/Source/Cmlx/mlx-generated/metal/gather.h new file mode 100644 index 00000000..8063c6f6 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/gather.h @@ -0,0 +1,49 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "indexing.h" + +template +METAL_FUNC void gather_impl( + const device T* src [[buffer(0)]], + device T* out [[buffer(1)]], + const constant int* src_shape [[buffer(2)]], + const constant size_t* src_strides [[buffer(3)]], + const constant size_t& src_ndim [[buffer(4)]], + const constant int* slice_sizes [[buffer(5)]], + const constant int* axes [[buffer(6)]], + const thread Indices& indices, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + size_t src_idx = 0; + for (int i = 0; i < NIDX; ++i) { + size_t idx_loc; + if (IDX_NDIM == 0) { + idx_loc = 0; + } else if (IDX_NDIM == 1) { + idx_loc = index.x * indices.strides[indices.ndim * i]; + } else { + idx_loc = index.x * indices.strides[indices.ndim * i]; + idx_loc += elem_to_loc( + index.y, + &indices.shapes[indices.ndim * i + 1], + &indices.strides[indices.ndim * i + 1], + indices.ndim - 1); + } + auto ax = axes[i]; + auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]); + src_idx += idx_val * src_strides[ax]; + } + + auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim); + + size_t out_idx = index.z; + if (IDX_NDIM == 1) { + out_idx += static_cast(grid_dim.z) * index.x; + } else if (IDX_NDIM >= 2) { + out_idx += + grid_dim.z * (index.x * static_cast(grid_dim.y) + index.y); + } + out[out_idx] = src[src_offset + src_idx]; +} diff --git a/Source/Cmlx/mlx-generated/metal/gemv.metal b/Source/Cmlx/mlx-generated/metal/gemv.metal new file mode 100644 index 00000000..00e0704e --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/gemv.metal @@ -0,0 +1,915 @@ +// Copyright © 2023-2024 Apple Inc. 
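Stepping back before the GEMV kernels: the four_step_fft kernel and the strided ReadWriter specializations above implement the usual n = n1·n2 decomposition, with a single twiddle pass applied between the two rounds of smaller FFTs (the get_twiddle(ij, overall_n) factor in write_strided). The sketch below is a host-side C++ reference of that decomposition, not part of the generated Metal sources and not a line-for-line model of the kernel's strided indexing or threadgroup layout; `four_step_dft` and `naive_dft` are names invented here.

```cpp
#include <complex>
#include <cstdio>
#include <vector>

using cf = std::complex<float>;

// Naive O(n^2) DFT used only as a reference.
static std::vector<cf> naive_dft(const std::vector<cf>& x) {
  const float pi = 3.14159265358979323846f;
  std::vector<cf> y(x.size());
  for (size_t k = 0; k < x.size(); ++k)
    for (size_t j = 0; j < x.size(); ++j)
      y[k] += x[j] * std::polar(1.0f, -2.0f * pi * k * j / x.size());
  return y;
}

// n = n1 * n2: n2-point DFTs over stride-n1 subsequences, one twiddle pass,
// then n1-point DFTs across the first axis.
static std::vector<cf> four_step_dft(const std::vector<cf>& x, int n1, int n2) {
  const float pi = 3.14159265358979323846f;
  const int n = n1 * n2;
  std::vector<cf> mid(n), X(n);

  for (int j1 = 0; j1 < n1; ++j1) {
    std::vector<cf> col(n2);
    for (int j2 = 0; j2 < n2; ++j2) col[j2] = x[j1 + n1 * j2];
    auto C = naive_dft(col);
    // Inter-step twiddle W_n^(j1 * k2), the product-of-coordinates exponent
    // that the kernel's write_strided applies via get_twiddle(ij, overall_n).
    for (int k2 = 0; k2 < n2; ++k2)
      mid[j1 + n1 * k2] = C[k2] * std::polar(1.0f, -2.0f * pi * j1 * k2 / n);
  }

  for (int k2 = 0; k2 < n2; ++k2) {
    std::vector<cf> row(n1);
    for (int j1 = 0; j1 < n1; ++j1) row[j1] = mid[j1 + n1 * k2];
    auto R = naive_dft(row);
    for (int k1 = 0; k1 < n1; ++k1) X[k2 + n2 * k1] = R[k1];
  }
  return X;
}

int main() {
  const int n1 = 4, n2 = 3;
  std::vector<cf> x(n1 * n2);
  for (int i = 0; i < n1 * n2; ++i) x[i] = cf(0.5f * i, -1.0f * i);

  auto a = four_step_dft(x, n1, n2);
  auto b = naive_dft(x);
  for (int k = 0; k < n1 * n2; ++k)
    std::printf("k=%2d  four-step=(%7.3f,%7.3f)  direct=(%7.3f,%7.3f)\n",
                k, a[k].real(), a[k].imag(), b[k].real(), b[k].imag());
}
```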
+ +#include +#include + +#include "bf16.h" +#include "defines.h" +#include "utils.h" + +#include "steel/utils.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +/// Matrix vector multiplication +/////////////////////////////////////////////////////////////////////////////// + +#define MLX_MTL_CONST static constant constexpr const + +template < + typename T, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ +struct GEMVKernel { + MLX_MTL_CONST int threadsM = BM * SM; + MLX_MTL_CONST int threadsN = BN * SN; + + MLX_MTL_CONST int blockM = threadsM * TM; + MLX_MTL_CONST int blockN = threadsN * TN; + + static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); + + static_assert( + SN == 8 || SN == 16 || SN == 32, + "gemv block must have a width of 8, 16, or 32"); + + // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up + // into blocks of (blockM, blockN) divided among threadgroups + // - Every thread works on a block of (TM, TN) + // - We assume each threadgroup has (threadsN, threadsM, 1) threads + // + // 1. A thread loads TN elements each from mat along TM rows + // and the corresponding scalar from the vector + // 2. The thread then multiplies and adds to accumulate its local result for + // the block + // 3. At the end, each thread has accumulated results over all blocks across + // the rows. These are then summed up across the threadgroup + // 4. Each threadgroup writes its accumulated blockM outputs + // + // Edge case handling: + // - The threadgroup with the largest tid has blocks that exceed the matrix + // * The blocks that start outside the matrix are never read (thread results + // remain zero) + // * The last thread that partially overlaps with the matrix is shifted + // inwards such that the thread block fits exactly in the matrix + + MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0; + MLX_MTL_CONST bool needs_tgp_reduction = BN > 1; + + static METAL_FUNC void + load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = src[src_offset + tn]; + } + } + + static METAL_FUNC void load_safe( + const device T* src, + thread T dst[TN], + const int src_offset = 0, + const int src_size = TN) { + if (src_offset + TN <= src_size) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = src[src_offset + tn]; + } + } else { // Edgecase + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = src_offset + tn < src_size ? 
src[src_offset + tn] : 0; + } + } + } + + static METAL_FUNC void run( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + const device T* bias [[buffer(2)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& matrix_ld [[buffer(6)]], + const constant float& alpha [[buffer(7)]], + const constant float& beta [[buffer(8)]], + const constant int& bias_stride [[buffer(14)]], + threadgroup T* tgp_memory [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + // Appease compiler + (void)lid; + + // Thread local accumulation results + thread T result[TM] = {0}; + thread T inter[TN]; + thread T v_coeff[TN]; + + const int thrM = SN != 32 ? simd_lid / SN : 0; + const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); + + const int sgN = BN != 1 ? (simd_gid % BN) : 0; + + const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid); + const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0; + + int bm = (simdM + thrM) * TM; + int bn = (simdN + thrN) * TN; + + // Block position + int out_row = tid.x * blockM + bm; + + // Exit simdgroup if rows out of bound + if (out_row >= out_vec_size) + return; + + // Adjust tail simdgroup to ensure in bound reads + out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM; + + // Advance matrix + mat += out_row * matrix_ld; + + constexpr const uniform loop_stride = make_uniform(blockN); + const uniform in_size = make_uniform(in_vec_size); + const uniform n_iter = in_size / loop_stride; + const uniform last_iter = loop_stride * n_iter; + const uniform leftover = in_size - last_iter; + + // Loop over in_vec in blocks of blockN + for (int i = 0; i < n_iter; ++i) { + load_unsafe(in_vec, v_coeff, bn); + + // Per thread work loop + int mat_offset = 0; + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + // Load for the row + load_unsafe(mat, inter, mat_offset + bn); + + // Accumulate results + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } + + mat_offset += matrix_ld; + } + + bn += blockN; + } + + if (leftover > 0) { + load_safe(in_vec, v_coeff, bn, in_size); + + // Per thread work loop + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + // Load for the row + load_safe(&mat[tm * matrix_ld], inter, bn, in_size); + + // Accumulate results + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } + } + } + + // Simdgroup accumulations + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + MLX_MTL_PRAGMA_UNROLL + for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) { + result[tm] += simd_shuffle_down(result[tm], sn); + } + } + + // Threadgroup accumulation results + if (needs_tgp_reduction) { + threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm; + if (thrN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + tgp_results[tm] = result[tm]; + } + + threadgroup_barrier(mem_flags::mem_none); + + if (sgN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int sgn = 1; sgn < BN; sgn++) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + result[tm] += tgp_results[sgn * (blockM + TM) + tm]; + } + } + } + } + } + + // Write outputs + if (simdN == 0 && thrN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + if 
(kDoAxpby) { + out_vec[out_row + tm] = static_cast(alpha) * result[tm] + + static_cast(beta) * bias[(out_row + tm) * bias_stride]; + } else { + out_vec[out_row + tm] = result[tm]; + } + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +/// Vector matrix multiplication +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ +struct GEMVTKernel { + MLX_MTL_CONST int threadsM = BM * SM; + MLX_MTL_CONST int threadsN = BN * SN; + + MLX_MTL_CONST int blockM = threadsM * TM; + MLX_MTL_CONST int blockN = threadsN * TN; + + static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); + + // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up + // into blocks of (blockM, blockN) divided among threadgroups + // - Every thread works on a block of (TM, TN) + // - We assume each threadgroup has (threadsN, threadsM, 1) threads + // + // 1. A thread loads TN elements each from mat along TM contiguous rows + // and the corresponding scalar from the vector + // 2. The thread then accumulates its local result for the block + // 3. At the end, each thread has accumulated results over all blocks across + // the rows. These are then summed up across the threadgroup + // 4. Each threadgroup writes its accumulated BN * TN outputs + // + // Edge case handling: + // - The threadgroup with the largest tid has blocks that exceed the matrix + // * The blocks that start outside the matrix are never read (thread results + // remain zero) + // * The last thread that partially overlaps with the matrix is shifted + // inwards such that the thread block fits exactly in the matrix + + MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0; + MLX_MTL_CONST bool needs_tgp_reduction = BM > 1; + + static METAL_FUNC void run( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + const device T* bias [[buffer(2)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant float& alpha [[buffer(7)]], + const constant float& beta [[buffer(8)]], + const constant int& bias_stride [[buffer(14)]], + threadgroup T* tgp_memory [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + // Appease compiler + (void)lid; + + // Thread local accumulation results + T result[TN] = {0}; + T inter[TN]; + T v_coeff[TM]; + + const int thrM = SN != 32 ? simd_lid / SN : 0; + const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); + + const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid); + const int sgN = BN != 1 ? 
(simd_gid % BN) : 0; + + const int simdM = SM * sgM; + const int simdN = SN * sgN; + + int cm = (simdM + thrM); + int cn = (simdN + thrN); + + int bm = cm * TM; + int bn = cn * TN; + + int out_col = tid.x * blockN + bn; + + constexpr const uniform loop_stride = make_uniform(blockM); + const uniform in_size = make_uniform(in_vec_size); + const uniform n_iter = in_size / loop_stride; + const uniform last_iter = loop_stride * n_iter; + const uniform leftover = in_size - last_iter; + + // Edgecase handling + if (out_col < out_vec_size) { + out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN; + + // Per thread accumulation main loop + for (int i = 0; i < n_iter; ++i) { + // Adding a threadgroup_barrier improves performance slightly + // This is possibly it may help exploit cache better + threadgroup_barrier(mem_flags::mem_none); + + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + v_coeff[tm] = in_vec[bm + tm]; + } + + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + for (int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; + } + for (int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } + } + + bm += blockM; + } + + if (leftover > 0) { + for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) { + v_coeff[tm] = in_vec[bm + tm]; + + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; + } + + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } + } + } + } + + // Simdgroup accumulations + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + MLX_MTL_PRAGMA_UNROLL + for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) { + result[tn] += simd_shuffle_down(result[tn], SN * sm); + } + } + + // Threadgroup accumulation results + if (needs_tgp_reduction) { + threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn; + if (thrM == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + tgp_results[tn] = result[tn]; + } + + threadgroup_barrier(mem_flags::mem_none); + + if (sgM == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int sgm = 1; sgm < BM; sgm++) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] += tgp_results[sgm * (blockN + TN) + tn]; + } + } + } + } + } + + // Threadgroup accumulation and writing out results + if (cm == 0 && out_col < out_vec_size) { + MLX_MTL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + if (kDoAxpby) { + out_vec[out_col + j] = static_cast(alpha) * result[j] + + static_cast(beta) * bias[(out_col + j) * bias_stride]; + } else { + out_vec[out_col + j] = result[j]; + } + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +/// Matrix vector multiplication +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoNCBatch, /* Batch ndim > 1 */ + const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ +[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + const device T* bias [[buffer(2)]], + device T* out_vec [[buffer(3)]], + const 
constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant float& alpha [[buffer(7)]], + const constant float& beta [[buffer(8)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* vector_batch_stride [[buffer(11)]], + const constant size_t* matrix_batch_stride [[buffer(12)]], + const constant size_t* bias_batch_stride [[buffer(13)]], + const constant int& bias_stride [[buffer(14)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using gemv_kernel = GEMVKernel; + threadgroup T tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; + + // Update batch offsets + if (kDoNCBatch) { + in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); + mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); + + if (kDoAxpby) { + bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim); + } + + } else { + in_vec += tid.z * vector_batch_stride[0]; + mat += tid.z * matrix_batch_stride[0]; + + if (kDoAxpby) { + bias += tid.z * bias_batch_stride[0]; + } + } + + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + bias, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + alpha, + beta, + bias_stride, + gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} + +#define instantiate_gemv_helper( \ + name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \ + template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \ + "_tm" #tm "_tn" #tn "_nc" #nc \ + "_axpby" #axpby)]] [[kernel]] void \ + gemv( \ + const device itype* mat [[buffer(0)]], \ + const device itype* in_vec [[buffer(1)]], \ + const device itype* bias [[buffer(2)]], \ + device itype* out_vec [[buffer(3)]], \ + const constant int& in_vec_size [[buffer(4)]], \ + const constant int& out_vec_size [[buffer(5)]], \ + const constant int& marix_ld [[buffer(6)]], \ + const constant float& alpha [[buffer(7)]], \ + const constant float& beta [[buffer(8)]], \ + const constant int& batch_ndim [[buffer(9)]], \ + const constant int* batch_shape [[buffer(10)]], \ + const constant size_t* vector_batch_stride [[buffer(11)]], \ + const constant size_t* matrix_batch_stride [[buffer(12)]], \ + const constant size_t* bias_batch_stride [[buffer(13)]], \ + const constant int& bias_stride [[buffer(14)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +// clang-format off +#define instantiate_gemv(name, itype, bm, bn, tm, tn) \ + instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 0, 0) \ + instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 0, 1) \ + instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 1, 0) \ + instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 1, 1) // clang-format on + +// clang-format off +#define instantiate_gemv_blocks(name, itype) \ + instantiate_gemv(name, itype, 4, 32, 1, 4) \ + instantiate_gemv(name, itype, 4, 32, 4, 4) \ + instantiate_gemv(name, itype, 8, 32, 4, 4) // clang-format on + +instantiate_gemv_blocks(float32, float); +instantiate_gemv_blocks(float16, half); +instantiate_gemv_blocks(bfloat16, 
bfloat16_t); + +template < + typename T, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN> /* Thread cols (in elements) */ +[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_gather( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + const device T* bias [[buffer(2)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant float& alpha [[buffer(7)]], + const constant float& beta [[buffer(8)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* index_batch_strides [[buffer(11)]], + const constant int& vector_batch_ndim [[buffer(12)]], + const constant int* vector_batch_shape [[buffer(13)]], + const constant size_t* vector_batch_stride [[buffer(14)]], + const constant int& matrix_batch_ndim [[buffer(15)]], + const constant int* matrix_batch_shape [[buffer(16)]], + const constant size_t* matrix_batch_stride [[buffer(17)]], + const constant uint32_t* vec_indices [[buffer(18)]], + const constant uint32_t* mat_indices [[buffer(19)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using gemv_kernel = GEMVKernel; + threadgroup T tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; + + uint32_t indx_vec; + uint32_t indx_mat; + + // Update batch offsets + if (batch_ndim > 1) { + const constant size_t* veci_bstrides = index_batch_strides; + const constant size_t* mati_bstrides = index_batch_strides + batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim); + + indx_vec = vec_indices[batch_offsets.x]; + indx_mat = mat_indices[batch_offsets.y]; + + } else { + indx_vec = vec_indices[index_batch_strides[0] * tid.z]; + indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z]; + } + + if (vector_batch_ndim > 1) { + in_vec += elem_to_loc( + indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim); + } else { + in_vec += indx_vec * vector_batch_stride[0]; + } + + if (matrix_batch_ndim > 1) { + mat += elem_to_loc( + indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim); + } else { + mat += indx_mat * matrix_batch_stride[0]; + } + + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + bias, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + alpha, + beta, + batch_ndim, // Not used + gemv_kernel::tgp_mem_size == 0 ? 
nullptr : tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} + +#define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \ + template [[host_name("gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \ + "_sn" #sn "_tm" #tm "_tn" #tn)]] [[kernel]] void \ + gemv_gather( \ + const device itype* mat [[buffer(0)]], \ + const device itype* in_vec [[buffer(1)]], \ + const device itype* bias [[buffer(2)]], \ + device itype* out_vec [[buffer(3)]], \ + const constant int& in_vec_size [[buffer(4)]], \ + const constant int& out_vec_size [[buffer(5)]], \ + const constant int& marix_ld [[buffer(6)]], \ + const constant float& alpha [[buffer(7)]], \ + const constant float& beta [[buffer(8)]], \ + const constant int& batch_ndim [[buffer(9)]], \ + const constant int* batch_shape [[buffer(10)]], \ + const constant size_t* index_batch_strides [[buffer(11)]], \ + const constant int& vector_batch_ndim [[buffer(12)]], \ + const constant int* vector_batch_shape [[buffer(13)]], \ + const constant size_t* vector_batch_stride [[buffer(14)]], \ + const constant int& matrix_batch_ndim [[buffer(15)]], \ + const constant int* matrix_batch_shape [[buffer(16)]], \ + const constant size_t* matrix_batch_stride [[buffer(17)]], \ + const constant uint32_t* vec_indices [[buffer(18)]], \ + const constant uint32_t* mat_indices [[buffer(19)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +// clang-format off +#define instantiate_gemv_bs_blocks(name, itype) \ + instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 1, 4) \ + instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 4, 4) \ + instantiate_gemv_bs_helper(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on + +instantiate_gemv_bs_blocks(float32, float); +instantiate_gemv_bs_blocks(float16, half); +instantiate_gemv_bs_blocks(bfloat16, bfloat16_t); + +/////////////////////////////////////////////////////////////////////////////// +/// Vector matrix multiplication +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoNCBatch, /* Batch ndim > 1 */ + const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ +[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + const device T* bias [[buffer(2)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant float& alpha [[buffer(7)]], + const constant float& beta [[buffer(8)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* vector_batch_stride [[buffer(11)]], + const constant size_t* matrix_batch_stride [[buffer(12)]], + const constant size_t* bias_batch_stride [[buffer(13)]], + const constant int& bias_stride [[buffer(14)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid 
[[thread_index_in_simdgroup]]) { + using gemv_kernel = GEMVTKernel; + threadgroup T tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; + + // Update batch offsets + if (kDoNCBatch) { + in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); + mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); + + if (kDoAxpby) { + bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim); + } + + } else { + in_vec += tid.z * vector_batch_stride[0]; + mat += tid.z * matrix_batch_stride[0]; + + if (kDoAxpby) { + bias += tid.z * bias_batch_stride[0]; + } + } + + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + bias, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + alpha, + beta, + bias_stride, + gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} + +#define instantiate_gemv_t_helper( \ + name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \ + template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \ + "_tm" #tm "_tn" #tn "_nc" #nc \ + "_axpby" #axpby)]] [[kernel]] void \ + gemv_t( \ + const device itype* mat [[buffer(0)]], \ + const device itype* in_vec [[buffer(1)]], \ + const device itype* bias [[buffer(2)]], \ + device itype* out_vec [[buffer(3)]], \ + const constant int& in_vec_size [[buffer(4)]], \ + const constant int& out_vec_size [[buffer(5)]], \ + const constant int& marix_ld [[buffer(6)]], \ + const constant float& alpha [[buffer(7)]], \ + const constant float& beta [[buffer(8)]], \ + const constant int& batch_ndim [[buffer(9)]], \ + const constant int* batch_shape [[buffer(10)]], \ + const constant size_t* vector_batch_stride [[buffer(11)]], \ + const constant size_t* matrix_batch_stride [[buffer(12)]], \ + const constant size_t* bias_batch_stride [[buffer(13)]], \ + const constant int& bias_stride [[buffer(14)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +// clang-format off +#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \ + instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \ + instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \ + instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 0) \ + instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 1) // clang-format on + +// clang-format off +#define instantiate_gemv_t_blocks(name, itype) \ + instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 1) \ + instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 4) \ + instantiate_gemv_t(name, itype, 1, 4, 8, 4, 4, 4) \ + instantiate_gemv_t(name, itype, 1, 16, 8, 4, 4, 4) \ + instantiate_gemv_t(name, itype, 1, 16, 4, 8, 4, 4) // clang-format on + +// clang-format off +instantiate_gemv_t_blocks(float32, float); +instantiate_gemv_t_blocks(float16, half); +instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on + +template < + typename T, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN> /* Thread cols (in elements) */ +[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_gather( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + const 
device T* bias [[buffer(2)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant float& alpha [[buffer(7)]], + const constant float& beta [[buffer(8)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* index_batch_strides [[buffer(11)]], + const constant int& vector_batch_ndim [[buffer(12)]], + const constant int* vector_batch_shape [[buffer(13)]], + const constant size_t* vector_batch_stride [[buffer(14)]], + const constant int& matrix_batch_ndim [[buffer(15)]], + const constant int* matrix_batch_shape [[buffer(16)]], + const constant size_t* matrix_batch_stride [[buffer(17)]], + const constant uint32_t* vec_indices [[buffer(18)]], + const constant uint32_t* mat_indices [[buffer(19)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using gemv_kernel = GEMVTKernel; + threadgroup T tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; + + uint32_t indx_vec; + uint32_t indx_mat; + + // Update batch offsets + if (batch_ndim > 1) { + const constant size_t* veci_bstrides = index_batch_strides; + const constant size_t* mati_bstrides = index_batch_strides + batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim); + + indx_vec = vec_indices[batch_offsets.x]; + indx_mat = mat_indices[batch_offsets.y]; + + } else { + indx_vec = vec_indices[index_batch_strides[0] * tid.z]; + indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z]; + } + + if (vector_batch_ndim > 1) { + in_vec += elem_to_loc( + indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim); + } else { + in_vec += indx_vec * vector_batch_stride[0]; + } + + if (matrix_batch_ndim > 1) { + mat += elem_to_loc( + indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim); + } else { + mat += indx_mat * matrix_batch_stride[0]; + } + + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + bias, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + alpha, + beta, + batch_ndim, // Not used, + gemv_kernel::tgp_mem_size == 0 ? 
nullptr : tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} + +#define instantiate_gemv_t_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \ + template [[host_name("gemv_t_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \ + "_sn" #sn "_tm" #tm "_tn" #tn)]] [[kernel]] void \ + gemv_t_gather( \ + const device itype* mat [[buffer(0)]], \ + const device itype* in_vec [[buffer(1)]], \ + const device itype* bias [[buffer(2)]], \ + device itype* out_vec [[buffer(3)]], \ + const constant int& in_vec_size [[buffer(4)]], \ + const constant int& out_vec_size [[buffer(5)]], \ + const constant int& marix_ld [[buffer(6)]], \ + const constant float& alpha [[buffer(7)]], \ + const constant float& beta [[buffer(8)]], \ + const constant int& batch_ndim [[buffer(9)]], \ + const constant int* batch_shape [[buffer(10)]], \ + const constant size_t* index_batch_strides [[buffer(11)]], \ + const constant int& vector_batch_ndim [[buffer(12)]], \ + const constant int* vector_batch_shape [[buffer(13)]], \ + const constant size_t* vector_batch_stride [[buffer(14)]], \ + const constant int& matrix_batch_ndim [[buffer(15)]], \ + const constant int* matrix_batch_shape [[buffer(16)]], \ + const constant size_t* matrix_batch_stride [[buffer(17)]], \ + const constant uint32_t* vec_indices [[buffer(18)]], \ + const constant uint32_t* mat_indices [[buffer(19)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +// clang-format off +#define instantiate_gemv_t_bs_blocks(name, itype) \ + instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 1) \ + instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 4) \ + instantiate_gemv_t_bs_helper(name, itype, 1, 4, 8, 4, 4, 4) \ + instantiate_gemv_t_bs_helper(name, itype, 1, 16, 8, 4, 4, 4) \ + instantiate_gemv_t_bs_helper(name, itype, 1, 16, 4, 8, 4, 4) // clang-format on + +// clang-format off +instantiate_gemv_t_bs_blocks(float32, float); +instantiate_gemv_t_bs_blocks(float16, half); +instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/gemv_masked.h b/Source/Cmlx/mlx-generated/metal/gemv_masked.h new file mode 100644 index 00000000..25491658 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/gemv_masked.h @@ -0,0 +1,819 @@ +// Copyright © 2023-2024 Apple Inc. 
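For reference, each gemv launch computes a plain matrix-vector product per batch element, with an optional axpby epilogue: when kDoAxpby is set, out = alpha · (mat · in_vec) + beta · bias, with the bias read at bias_stride; gemv_t handles the transposed (vector times matrix) case. The sketch below is a scalar host-side C++ reference, not part of the generated Metal sources, that ignores the simdgroup/threadgroup blocking entirely; `gemv_ref` is a name invented here and the fixed row-major leading dimension is a simplification of the kernel's matrix_ld.

```cpp
#include <cstdio>
#include <vector>

// Scalar reference: out = alpha * (mat . vec) + beta * bias, with the bias
// read at bias_stride. mat is row-major with leading dimension ld = in_size.
static void gemv_ref(
    const std::vector<float>& mat,  // out_size x in_size
    const std::vector<float>& vec,  // in_size
    const std::vector<float>& bias,
    std::vector<float>& out,        // out_size
    int in_size, int out_size,
    float alpha, float beta, int bias_stride) {
  for (int m = 0; m < out_size; ++m) {
    float acc = 0;
    for (int k = 0; k < in_size; ++k)
      acc += mat[m * in_size + k] * vec[k];
    out[m] = alpha * acc + beta * bias[m * bias_stride];
  }
}

int main() {
  const int in_size = 3, out_size = 2;
  std::vector<float> mat = {1, 2, 3, 4, 5, 6}; // [[1,2,3],[4,5,6]]
  std::vector<float> vec = {1, 0, -1};
  std::vector<float> bias = {10, 20};
  std::vector<float> out(out_size);
  gemv_ref(mat, vec, bias, out, in_size, out_size,
           /*alpha=*/2.0f, /*beta=*/0.5f, /*bias_stride=*/1);
  for (int m = 0; m < out_size; ++m)
    std::printf("out[%d] = %g\n", m, out[m]); // expect 1 and 6
}
```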
+ +#include "steel/utils.h" + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const +#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") + +struct _NoMask { + char x; + + constexpr METAL_FUNC operator bool() { + return true; + } + constexpr METAL_FUNC operator bool() const threadgroup { + return true; + } + constexpr METAL_FUNC operator bool() const device { + return true; + } + constexpr METAL_FUNC operator bool() const constant { + return true; + } +}; + +typedef struct _NoMask nomask_t; + +template +struct ScaleOp { + OutT scale; + + METAL_FUNC OutT apply(InT x) const { + return static_cast(x) * scale; + } +}; + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN> /* Thread cols (in elements) */ +struct GEMVKernel { + MLX_MTL_CONST int threadsM = BM * SM; + MLX_MTL_CONST int threadsN = BN * SN; + + MLX_MTL_CONST int blockM = threadsM * TM; + MLX_MTL_CONST int blockN = threadsN * TN; + + static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); + + static_assert( + SN == 8 || SN == 16 || SN == 32, + "gemv block must have a width of 8, 16, or 32"); + + static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM"); + + MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v; + MLX_MTL_CONST bool has_output_mask = !metal::is_same_v; + + MLX_MTL_CONST bool has_mul_operand_mask = + has_operand_mask && !metal::is_same_v; + MLX_MTL_CONST bool has_mul_output_mask = + has_output_mask && !metal::is_same_v; + + // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up + // into blocks of (blockM, blockN) divided among threadgroups + // - Every thread works on a block of (TM, TN) + // - We assume each threadgroup has (threadsN, threadsM, 1) threads + // + // 1. A thread loads TN elements each from mat along TM rows + // and the corresponding scalar from the vector + // 2. The thread then multiplies and adds to accumulate its local result for + // the block + // 3. At the end, each thread has accumulated results over all blocks across + // the rows. These are then summed up across the threadgroup + // 4. Each threadgroup writes its accumulated blockM outputs + // + // Edge case handling: + // - The threadgroup with the largest tid has blocks that exceed the matrix + // * The blocks that start outside the matrix are never read (thread results + // remain zero) + // * The last thread that partially overlaps with the matrix is shifted + // inwards such that the thread block fits exactly in the matrix + + MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0; + MLX_MTL_CONST bool needs_tgp_reduction = BN > 1; + + static METAL_FUNC void + load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = src[src_offset + tn]; + } + } + + static METAL_FUNC void load_safe( + const device T* src, + thread T dst[TN], + const int src_offset = 0, + const int src_size = TN) { + if (src_offset + TN <= src_size) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = src[src_offset + tn]; + } + } else { // Edgecase + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = src_offset + tn < src_size ? 
src[src_offset + tn] : 0; + } + } + } + + static METAL_FUNC void run( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& matrix_ld [[buffer(6)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + threadgroup T* tgp_memory [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + // Appease compiler + (void)lid; + + // Thread local accumulation results + thread T result[TM] = {0}; + thread T inter[TN]; + thread T v_coeff[TN]; + + const int thrM = SN != 32 ? simd_lid / SN : 0; + const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); + + const int sgN = BN != 1 ? (simd_gid % BN) : 0; + + const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid); + const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0; + + int bm = (simdM + thrM) * TM; + int bn = (simdN + thrN) * TN; + + // Block position + int out_row = tid.x * blockM + bm; + + // Exit simdgroup if rows out of bound + if (out_row >= out_vec_size) + return; + + // Adjust tail simdgroup to ensure in bound reads + out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM; + + // Prepare mask offsets + const constant int* out_mask_strides = mask_strides; + const constant int* mat_mask_strides = + mask_strides + (has_output_mask ? 2 : 0); + const constant int* vec_mask_strides = + mat_mask_strides + (has_operand_mask ? 2 : 0); + + const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x); + + const int out_mask_offset = + !has_output_mask ? 0 : m_block_idx * out_mask_strides[1]; + + int mat_mask_offset = + !has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1]; + int vec_mask_offset = 0; + const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0]; + const int vec_mask_step = !has_operand_mask ? 
0 : vec_mask_strides[1]; + + T out_scale{1}; + + // Check output mask + if (has_output_mask) { + auto mask_out = out_mask[out_mask_offset]; + + // Write zeros and return if mask is 0 + if (!mask_out) { + if (simdN == 0 && thrN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + out_vec[out_row + tm] = T(0.); + } + } + + return; + } + + // Store scalar if multiplicative mask + if (has_mul_output_mask) { + out_scale = T(mask_out); + } + } + + // Advance matrix + mat += out_row * matrix_ld; + + // Prepare for loop + constexpr const uniform loop_stride = make_uniform(blockN); + const uniform in_size = make_uniform(in_vec_size); + const uniform n_iter = in_size / loop_stride; + const uniform last_iter = loop_stride * n_iter; + const uniform leftover = in_size - last_iter; + + // Loop over in_vec in blocks of blockN + for (int i = 0; i < n_iter; ++i) { + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + + load_unsafe(in_vec, v_coeff, bn); + + // Apply scale + if (has_mul_operand_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + v_coeff[tn] *= block_scale; + } + } + + // Per thread work loop + int mat_offset = 0; + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + // Load for the row + load_unsafe(mat, inter, mat_offset + bn); + + // Accumulate results + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } + + mat_offset += matrix_ld; + } + } + + bn += blockN; + mat_mask_offset += mat_mask_step; + vec_mask_offset += vec_mask_step; + } + + if (leftover > 0 && + (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset])))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + + load_safe(in_vec, v_coeff, bn, in_size); + + // Apply scale + if (has_mul_operand_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + v_coeff[tn] *= block_scale; + } + } + + // Per thread work loop + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + // Load for the row + load_safe(&mat[tm * matrix_ld], inter, bn, in_size); + + // Accumulate results + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } + } + } + + // Apply out scale + if (has_mul_output_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + result[tm] *= out_scale; + } + } + + // Simdgroup accumulations + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + MLX_MTL_PRAGMA_UNROLL + for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) { + result[tm] += simd_shuffle_down(result[tm], sn); + } + } + + // Threadgroup accumulation results + if (needs_tgp_reduction) { + threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm; + if (thrN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + tgp_results[tm] = result[tm]; + } + + threadgroup_barrier(mem_flags::mem_none); + + if (sgN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int sgn = 1; sgn < BN; sgn++) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + result[tm] += tgp_results[sgn * (blockM + TM) + tm]; + } + } + } + } + } + + // Write outputs + if (simdN == 0 && thrN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + out_vec[out_row + tm] = result[tm]; + } + } + } +}; + 
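// GEMVKernel above computes a block-masked matrix-vector product: output rows
// are tiled in blocks of blockM and the reduction axis in blocks of blockN;
// out_mask gates or scales each output block, while mat_mask and vec_mask gate
// or scale each (row-block, k-block) and k-block contribution. A serial
// reference with the same semantics, assuming multiplicative float masks
// (boolean masks are the 0/1 special case) and one mask entry per block with
// the strides simplified to a dense row-major layout:
#include <vector>

std::vector<float> masked_gemv_ref(
    const std::vector<float>& mat, int M, int K, // row-major M x K
    const std::vector<float>& vec,               // length K
    const std::vector<float>& out_mask,          // ceil(M/blockM) entries
    const std::vector<float>& mat_mask,          // [row_block][k_block]
    const std::vector<float>& vec_mask,          // ceil(K/blockN) entries
    int blockM, int blockN) {
  const int k_blocks = (K + blockN - 1) / blockN;
  std::vector<float> y(M, 0.0f);
  for (int m = 0; m < M; ++m) {
    const float om = out_mask[m / blockM];
    if (om == 0.0f) continue;                    // masked-out block stays zero
    float acc = 0.0f;
    for (int k = 0; k < K; ++k) {
      const float block_scale =
          mat_mask[(m / blockM) * k_blocks + k / blockN] * vec_mask[k / blockN];
      acc += block_scale * mat[m * K + k] * vec[k];
    }
    y[m] = om * acc;
  }
  return y;
}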
+/////////////////////////////////////////////////////////////////////////////// +/// Vector matrix multiplication +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN> /* Thread cols (in elements) */ +struct GEMVTKernel { + MLX_MTL_CONST int threadsM = BM * SM; + MLX_MTL_CONST int threadsN = BN * SN; + + MLX_MTL_CONST int blockM = threadsM * TM; + MLX_MTL_CONST int blockN = threadsN * TN; + + static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); + + MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v; + MLX_MTL_CONST bool has_output_mask = !metal::is_same_v; + + MLX_MTL_CONST bool has_mul_operand_mask = + has_operand_mask && !metal::is_same_v; + MLX_MTL_CONST bool has_mul_output_mask = + has_output_mask && !metal::is_same_v; + + // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up + // into blocks of (blockM, blockN) divided among threadgroups + // - Every thread works on a block of (TM, TN) + // - We assume each threadgroup has (threadsN, threadsM, 1) threads + // + // 1. A thread loads TN elements each from mat along TM contiguous rows + // and the corresponding scalar from the vector + // 2. The thread then accumulates its local result for the block + // 3. At the end, each thread has accumulated results over all blocks across + // the rows. These are then summed up across the threadgroup + // 4. Each threadgroup writes its accumulated BN * TN outputs + // + // Edge case handling: + // - The threadgroup with the largest tid has blocks that exceed the matrix + // * The blocks that start outside the matrix are never read (thread results + // remain zero) + // * The last thread that partially overlaps with the matrix is shifted + // inwards such that the thread block fits exactly in the matrix + + MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0; + MLX_MTL_CONST bool needs_tgp_reduction = BM > 1; + + static METAL_FUNC void run( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + threadgroup T* tgp_memory [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + // Appease compiler + (void)lid; + + // Thread local accumulation results + T result[TN] = {0}; + T inter[TN]; + T v_coeff[TM]; + + const int thrM = SN != 32 ? simd_lid / SN : 0; + const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); + + const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid); + const int sgN = BN != 1 ? 
(simd_gid % BN) : 0; + + const int simdM = SM * sgM; + const int simdN = SN * sgN; + + int cm = (simdM + thrM); + int cn = (simdN + thrN); + + int bm = cm * TM; + int bn = cn * TN; + + int out_col = tid.x * blockN + bn; + + // Prepare mask offsets + const constant int* out_mask_strides = mask_strides; + const constant int* mat_mask_strides = + out_mask_strides + (has_output_mask ? 2 : 0); + const constant int* vec_mask_strides = + mat_mask_strides + (has_operand_mask ? 2 : 0); + + const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x); + + const int out_mask_offset = + !has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0]; + + int mat_mask_offset = + !has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0]; + int vec_mask_offset = 0; + const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1]; + const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0]; + + T out_scale{1}; + + // Check output mask + if (has_output_mask) { + auto mask_out = out_mask[out_mask_offset]; + + // Write zeros and return if mask is 0 + if (!mask_out) { + if (cm == 0 && out_col < out_vec_size) { + if (out_col + TN <= out_vec_size) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + out_vec[out_col + tn] = T(0.); + } + } else { + for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) { + out_vec[out_col + tn] = T(0.); + } + } + } + + return; + } + + // Store scalar if multiplicative mask + if (has_mul_output_mask) { + out_scale = T(mask_out); + } + } + + // Prepare for loop + constexpr const uniform loop_stride = make_uniform(blockM); + const uniform in_size = make_uniform(in_vec_size); + const uniform n_iter = in_size / loop_stride; + const uniform last_iter = loop_stride * n_iter; + const uniform leftover = in_size - last_iter; + + // Edgecase handling + if (out_col < out_vec_size) { + out_col = (out_col + TN) <= out_vec_size ? 
out_col : out_vec_size - TN; + + // Per thread accumulation main loop + for (int i = 0; i < n_iter; ++i) { + // Adding a threadgroup_barrier improves performance slightly + // This is possibly it may help exploit cache better + threadgroup_barrier(mem_flags::mem_none); + + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + v_coeff[tm] = in_vec[bm + tm]; + } + + // Apply scale + if (has_mul_operand_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + v_coeff[tm] *= block_scale; + } + } + + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + for (int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; + } + for (int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } + } + } + + bm += blockM; + mat_mask_offset += mat_mask_step; + vec_mask_offset += vec_mask_step; + } + + if (leftover > 0 && + (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset])))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + + for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) { + v_coeff[tm] = in_vec[bm + tm]; + + if (has_mul_operand_mask) { + v_coeff[tm] *= block_scale; + } + + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; + } + + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } + } + } + } + + // Apply out scale + if (has_mul_output_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] *= out_scale; + } + } + + // Simdgroup accumulations + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + MLX_MTL_PRAGMA_UNROLL + for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) { + result[tn] += simd_shuffle_down(result[tn], SN * sm); + } + } + + // Threadgroup accumulation results + if (needs_tgp_reduction) { + threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn; + if (thrM == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + tgp_results[tn] = result[tn]; + } + + threadgroup_barrier(mem_flags::mem_none); + + if (sgM == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int sgm = 1; sgm < BM; sgm++) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] += tgp_results[sgm * (blockN + TN) + tn]; + } + } + } + } + } + + // Threadgroup accumulation and writing out results + if (cm == 0 && out_col < out_vec_size) { + MLX_MTL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + out_vec[out_col + j] = result[j]; + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +/// Matrix vector multiplication +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoNCBatch> /* Batch ndim > 1 */ +[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void 
gemv_masked( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* vector_batch_stride [[buffer(11)]], + const constant size_t* matrix_batch_stride [[buffer(12)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + const constant size_t* mask_batch_strides [[buffer(24)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using gemv_kernel = + GEMVKernel; + threadgroup T tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; + + constexpr bool has_operand_mask = !metal::is_same_v; + constexpr bool has_output_mask = !metal::is_same_v; + + // Update batch offsets + if (kDoNCBatch) { + in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); + mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); + + if (has_output_mask) { + out_mask += + elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim); + mask_batch_strides += batch_ndim; + } + + if (has_operand_mask) { + const constant size_t* mask_strides_mat = mask_batch_strides; + const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim); + + mat_mask += batch_offsets.x; + vec_mask += batch_offsets.y; + } + + } else { + in_vec += tid.z * vector_batch_stride[0]; + mat += tid.z * matrix_batch_stride[0]; + + if (has_output_mask) { + out_mask += tid.z * mask_batch_strides[0]; + mask_batch_strides += batch_ndim; + } + + if (has_operand_mask) { + mat_mask += tid.z * mask_batch_strides[0]; + vec_mask += tid.z * mask_batch_strides[batch_ndim]; + } + } + + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + out_mask, + mat_mask, + vec_mask, + mask_strides, + gemv_kernel::tgp_mem_size == 0 ? 
nullptr : tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} + +/////////////////////////////////////////////////////////////////////////////// +/// Vector matrix multiplication +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoNCBatch> /* Batch ndim > 1 */ +[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_masked( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* vector_batch_stride [[buffer(11)]], + const constant size_t* matrix_batch_stride [[buffer(12)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + const constant size_t* mask_batch_strides [[buffer(24)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using gemv_kernel = + GEMVTKernel; + threadgroup T tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; + + constexpr bool has_operand_mask = !metal::is_same_v; + constexpr bool has_output_mask = !metal::is_same_v; + + // Update batch offsets + if (kDoNCBatch) { + in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); + mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); + + if (has_output_mask) { + out_mask += + elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim); + mask_batch_strides += batch_ndim; + } + + if (has_operand_mask) { + const constant size_t* mask_strides_mat = mask_batch_strides; + const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim); + + mat_mask += batch_offsets.x; + vec_mask += batch_offsets.y; + } + + } else { + in_vec += tid.z * vector_batch_stride[0]; + mat += tid.z * matrix_batch_stride[0]; + + if (has_output_mask) { + out_mask += tid.z * mask_batch_strides[0]; + mask_batch_strides += batch_ndim; + } + + if (has_operand_mask) { + mat_mask += tid.z * mask_batch_strides[0]; + vec_mask += tid.z * mask_batch_strides[batch_ndim]; + } + } + + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + out_mask, + mat_mask, + vec_mask, + mask_strides, + gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} diff --git a/Source/Cmlx/mlx-generated/metal/hadamard.h b/Source/Cmlx/mlx-generated/metal/hadamard.h new file mode 100644 index 00000000..8f2d8cc1 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/hadamard.h @@ -0,0 +1,167 @@ +// Copyright © 2024 Apple Inc. 
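// Both masked kernels above advance their pointers per batch element with
// elem_to_loc(tid.z, batch_shape, strides, batch_ndim). A host-side sketch of
// that shape/stride arithmetic, assumed to follow the usual convention of
// walking dimensions from the innermost one outwards:
#include <cstddef>

size_t elem_to_loc_ref(
    size_t elem, const int* shape, const size_t* strides, int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}
// elem_to_loc_broadcast (used for the two operand masks) applies the same walk
// once per stride array, so mat_mask and vec_mask can be offset together.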
+#include +#include + +#include "steel/defines.h" + +using namespace metal; + +// Thread local Hadamard transform for 2^R +template +METAL_FUNC void radix_func(thread float* x) { + constexpr short logR = __builtin_ctz(R); + short h = 1; + STEEL_PRAGMA_UNROLL + for (short s = 0; s < logR; s++) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < R / 2; i++) { + short k = i & (h - 1); + short j = ((i - k) << 1) + k; + float a = x[j]; + float b = x[j + h]; + x[j] = a + b; + x[j + h] = a - b; + } + h <<= 1; + } +} + +template +[[kernel]] void hadamard_n( + const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], + constant const float& scale, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + // Compute a Hadamard transform of size N = 2^k + // + // Equivalent to: + // from scipy.linalg import hadamard + // y = hadamard(len(x)) @ x + + constexpr short num_threads = N / max_radix; + constexpr short logN = __builtin_ctz(N); + constexpr short logR = __builtin_ctz(max_radix); + constexpr short num_steps = logN / logR; + constexpr short logFinal = logN % logR; + constexpr short final_radix = 1 << (logFinal); + + int batch_idx = elem.x * N; + short i = elem.y; + + threadgroup T buf[N]; + + // Read values from device + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + buf[index + r] = in[batch_idx + index + r]; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + float x[max_radix]; + short h = 1; + + STEEL_PRAGMA_UNROLL + for (short s = 0; s < num_steps; s++) { + short k = i & (h - 1); + short j = ((i - k) << logR) + k; + + STEEL_PRAGMA_UNROLL + for (short r = 0; r < max_radix; r++) { + x[r] = buf[j + h * r]; + } + + radix_func(x); + + STEEL_PRAGMA_UNROLL + for (short r = 0; r < max_radix; r++) { + buf[j + h * r] = T(x[r]); + } + + h <<= logR; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Do the final radix + // e.g. max_radix = 16 + // N = 1024 = 16 * 16 * 4 + if (final_radix > 1) { + // Each thread does multiple butterflies + STEEL_PRAGMA_UNROLL + for (int t = 0; t < max_radix / final_radix; t++) { + short index = i + t * num_threads; + short k = index & (h - 1); + short j = ((index - k) << logFinal) + k; + STEEL_PRAGMA_UNROLL + for (short r = 0; r < final_radix; r++) { + x[r] = buf[j + h * r]; + } + + radix_func(x); + + STEEL_PRAGMA_UNROLL + for (short r = 0; r < final_radix; r++) { + buf[j + h * r] = T(x[r]); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Write values to device + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + out[batch_idx + index + r] = T(buf[index + r] * scale); + } + } +} + +template +[[kernel]] void hadamard_m( + const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], + constant const float& scale, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + // Compute a Hadamard transform of size M + // using a naive O(M^2) codelet. + // + // This kernel is the second stage in the computation + // of a Hadamard transform of size M*N where N = 2^k. 
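// hadamard_n above computes y = hadamard(N) @ x (scipy convention, N = 2^k) by
// running radix-R butterflies in registers and threadgroup memory, and
// hadamard_m applies a small dense codelet for the non-power-of-two factor.
// A serial radix-2 reference of the same unnormalized transform, which the
// radix_func / final_radix staging decomposes across threads:
#include <vector>

void fwht_ref(std::vector<float>& x) {  // x.size() must be a power of two
  const size_t n = x.size();
  for (size_t h = 1; h < n; h <<= 1) {  // butterfly stride doubles each stage
    for (size_t j = 0; j < n; j += 2 * h) {
      for (size_t k = j; k < j + h; ++k) {
        const float a = x[k];
        const float b = x[k + h];
        x[k] = a + b;                   // unnormalized, like the kernel
        x[k + h] = a - b;               // (the kernel folds `scale` in on store)
      }
    }
  }
}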
+ + int index = elem.x * grid.y + elem.y; + short i = index % (N / read_width); + int batch_idx = index / (N / read_width) * M * N; + + float x[read_width][M]; + STEEL_PRAGMA_UNROLL + for (short c = 0; c < M; c++) { + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + x[r][c] = in[batch_idx + c * N + i * read_width + r]; + } + } + + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + // This function is JIT compiled for M + // using the Hadamard matrix strings in `metal/hadamard.cpp` + hadamard_radix_m(x[r]); + } + + // Write back to device + STEEL_PRAGMA_UNROLL + for (short c = 0; c < M; c++) { + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + out[batch_idx + c * N + i * read_width + r] = T(x[r][c] * scale); + } + } +} diff --git a/Source/Cmlx/mlx-generated/metal/indexing.h b/Source/Cmlx/mlx-generated/metal/indexing.h new file mode 100644 index 00000000..9f76e477 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/indexing.h @@ -0,0 +1,22 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include + +template +struct Indices { + const array buffers; + const constant int* shapes; + const constant size_t* strides; + const int ndim; +}; + +template +METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) { + if (is_unsigned_v) { + return idx; + } else { + return (idx < 0) ? idx + size : idx; + } +} diff --git a/Source/Cmlx/mlx-generated/metal/layer_norm.metal b/Source/Cmlx/mlx-generated/metal/layer_norm.metal new file mode 100644 index 00000000..79f04d7b --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/layer_norm.metal @@ -0,0 +1,556 @@ +// Copyright © 2024 Apple Inc. + +#include +#include + +#include "bf16.h" +#include "defines.h" +#include "utils.h" + +using namespace metal; + +template +[[kernel]] void layer_norm_single_row( + const device T* x, + const device T* w, + const device T* b, + device T* out, + constant float& eps, + constant uint& axis_size, + constant uint& w_stride, + constant uint& b_stride, + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + float sumx = 0; + float sumx2 = 0; + float thread_x[N_READS]; + + constexpr int SIMD_SIZE = 32; + + threadgroup float local_sumx[SIMD_SIZE]; + threadgroup float local_sumx2[SIMD_SIZE]; + threadgroup float local_mean[1]; + threadgroup float local_normalizer[1]; + + x += gid * size_t(axis_size) + lid * N_READS; + w += w_stride * lid * N_READS; + b += b_stride * lid * N_READS; + + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + thread_x[i] = x[i]; + sumx2 += thread_x[i] * thread_x[i]; + sumx += thread_x[i]; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((lid * N_READS + i) < axis_size) { + thread_x[i] = x[i]; + sumx2 += thread_x[i] * thread_x[i]; + sumx += thread_x[i]; + } + } + } + + sumx = simd_sum(sumx); + sumx2 = simd_sum(sumx2); + + // Initialize shared memory + if (simd_group_id == 0) { + local_sumx[simd_lane_id] = 0; + local_sumx2[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write simd accumulations into shared memory + if (simd_lane_id == 0) { + local_sumx[simd_group_id] = sumx; + local_sumx2[simd_group_id] = sumx2; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Accumulate over simd groups + if (simd_group_id == 0) { + sumx = simd_sum(local_sumx[simd_lane_id]); + sumx2 = simd_sum(local_sumx2[simd_lane_id]); + if (simd_lane_id == 
0) { + float mean = sumx / axis_size; + float variance = sumx2 / axis_size - mean * mean; + + local_mean[0] = mean; + local_normalizer[0] = metal::precise::rsqrt(variance + eps); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + float mean = local_mean[0]; + float normalizer = local_normalizer[0]; + + // Write the outputs + out += gid * size_t(axis_size) + lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + thread_x[i] = (thread_x[i] - mean) * normalizer; + out[i] = w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((lid * N_READS + i) < axis_size) { + thread_x[i] = (thread_x[i] - mean) * normalizer; + out[i] = + w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; + } + } + } +} + +template +[[kernel]] void layer_norm_looped( + const device T* x, + const device T* w, + const device T* b, + device T* out, + constant float& eps, + constant uint& axis_size, + constant uint& w_stride, + constant uint& b_stride, + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + float sumx = 0; + float sumx2 = 0; + + constexpr int SIMD_SIZE = 32; + + threadgroup float local_sumx[SIMD_SIZE]; + threadgroup float local_sumx2[SIMD_SIZE]; + threadgroup float local_mean[1]; + threadgroup float local_normalizer[1]; + + x += gid * size_t(axis_size) + lid * N_READS; + w += w_stride * lid * N_READS; + b += b_stride * lid * N_READS; + + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float xi = x[i + r]; + sumx2 += xi * xi; + sumx += xi; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + float xi = x[i + r]; + sumx2 += xi * xi; + sumx += xi; + } + } + } + } + + sumx = simd_sum(sumx); + sumx2 = simd_sum(sumx2); + + // Initialize shared memory + if (simd_group_id == 0) { + local_sumx[simd_lane_id] = 0; + local_sumx2[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write simd accumulations into shared memory + if (simd_lane_id == 0) { + local_sumx[simd_group_id] = sumx; + local_sumx2[simd_group_id] = sumx2; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Accumulate over simd groups + if (simd_group_id == 0) { + sumx = simd_sum(local_sumx[simd_lane_id]); + sumx2 = simd_sum(local_sumx2[simd_lane_id]); + if (simd_lane_id == 0) { + float mean = sumx / axis_size; + float variance = sumx2 / axis_size - mean * mean; + + local_mean[0] = mean; + local_normalizer[0] = metal::precise::rsqrt(variance + eps); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + float mean = local_mean[0]; + float normalizer = local_normalizer[0]; + + // Write the outputs + out += gid * size_t(axis_size) + lid * N_READS; + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float xi = (x[r + i] - mean) * normalizer; + out[r + i] = + w[w_stride * (i + r)] * static_cast(xi) + b[b_stride * (i + r)]; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + float xi = (x[r + i] - mean) * normalizer; + out[r + i] = w[w_stride * (i + r)] * static_cast(xi) + + b[b_stride * (i + r)]; + } + } + } + } 
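// layer_norm_single_row and layer_norm_looped accumulate sum(x) and sum(x^2)
// in one pass, derive mean and variance as E[x^2] - E[x]^2, then apply
// (x - mean) * rsqrt(variance + eps) * w + b. A serial reference of that
// forward pass for a single row, assuming unit w/b strides:
#include <cmath>
#include <vector>

std::vector<float> layer_norm_ref(
    const std::vector<float>& x,
    const std::vector<float>& w,
    const std::vector<float>& b,
    float eps) {
  const size_t n = x.size();
  float sumx = 0.0f, sumx2 = 0.0f;
  for (float v : x) { sumx += v; sumx2 += v * v; }
  const float mean = sumx / n;
  const float variance = sumx2 / n - mean * mean;
  const float normalizer = 1.0f / std::sqrt(variance + eps);
  std::vector<float> out(n);
  for (size_t i = 0; i < n; ++i)
    out[i] = w[i] * ((x[i] - mean) * normalizer) + b[i];
  return out;
}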
+} + +template +[[kernel]] void vjp_layer_norm_single_row( + const device T* x, + const device T* w, + const device T* g, + device T* gx, + device T* gw, + constant float& eps, + constant uint& axis_size, + constant uint& w_stride, + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + // Advance the input pointers + x += gid * size_t(axis_size) + lid * N_READS; + g += gid * size_t(axis_size) + lid * N_READS; + w += w_stride * lid * N_READS; + + // Allocate registers for the computation and accumulators + float thread_x[N_READS]; + float thread_w[N_READS]; + float thread_g[N_READS]; + float sumx = 0; + float sumx2 = 0; + float sumwg = 0; + float sumwgx = 0; + + constexpr int SIMD_SIZE = 32; + + threadgroup float local_sumx[SIMD_SIZE]; + threadgroup float local_sumx2[SIMD_SIZE]; + threadgroup float local_sumwg[SIMD_SIZE]; + threadgroup float local_sumwgx[SIMD_SIZE]; + threadgroup float local_mean[1]; + threadgroup float local_normalizer[1]; + threadgroup float local_meanwg[1]; + threadgroup float local_meanwgx[1]; + + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + thread_x[i] = x[i]; + thread_w[i] = w[i * w_stride]; + thread_g[i] = g[i]; + float wg = thread_w[i] * thread_g[i]; + sumx += thread_x[i]; + sumx2 += thread_x[i] * thread_x[i]; + sumwg += wg; + sumwgx += wg * thread_x[i]; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((lid * N_READS + i) < axis_size) { + thread_x[i] = x[i]; + thread_w[i] = w[i * w_stride]; + thread_g[i] = g[i]; + float wg = thread_w[i] * thread_g[i]; + sumx += thread_x[i]; + sumx2 += thread_x[i] * thread_x[i]; + sumwg += wg; + sumwgx += wg * thread_x[i]; + } + } + } + + sumx = simd_sum(sumx); + sumx2 = simd_sum(sumx2); + sumwg = simd_sum(sumwg); + sumwgx = simd_sum(sumwgx); + + // Initialize shared memory + if (simd_group_id == 0) { + local_sumx[simd_lane_id] = 0; + local_sumx2[simd_lane_id] = 0; + local_sumwg[simd_lane_id] = 0; + local_sumwgx[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write simd accumulations into shared memory + if (simd_lane_id == 0) { + local_sumx[simd_group_id] = sumx; + local_sumx2[simd_group_id] = sumx2; + local_sumwg[simd_group_id] = sumwg; + local_sumwgx[simd_group_id] = sumwgx; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Accumulate over simd groups + if (simd_group_id == 0) { + sumx = simd_sum(local_sumx[simd_lane_id]); + sumx2 = simd_sum(local_sumx2[simd_lane_id]); + sumwg = simd_sum(local_sumwg[simd_lane_id]); + sumwgx = simd_sum(local_sumwgx[simd_lane_id]); + if (simd_lane_id == 0) { + float mean = sumx / axis_size; + float variance = sumx2 / axis_size - mean * mean; + + local_mean[0] = mean; + local_normalizer[0] = metal::precise::rsqrt(variance + eps); + local_meanwg[0] = sumwg / axis_size; + local_meanwgx[0] = sumwgx / axis_size; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + float mean = local_mean[0]; + float normalizer = local_normalizer[0]; + float meanwg = local_meanwg[0]; + float meanwgxc = local_meanwgx[0] - meanwg * mean; + float normalizer2 = normalizer * normalizer; + + // Write the outputs + gx += gid * size_t(axis_size) + lid * N_READS; + gw += gid * size_t(axis_size) + lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + thread_x[i] = (thread_x[i] - mean) * normalizer; + gx[i] = static_cast( + normalizer * 
(thread_w[i] * thread_g[i] - meanwg) - + thread_x[i] * meanwgxc * normalizer2); + gw[i] = static_cast(thread_g[i] * thread_x[i]); + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((lid * N_READS + i) < axis_size) { + thread_x[i] = (thread_x[i] - mean) * normalizer; + gx[i] = static_cast( + normalizer * (thread_w[i] * thread_g[i] - meanwg) - + thread_x[i] * meanwgxc * normalizer2); + gw[i] = static_cast(thread_g[i] * thread_x[i]); + } + } + } +} + +template +[[kernel]] void vjp_layer_norm_looped( + const device T* x, + const device T* w, + const device T* g, + device T* gx, + device T* gw, + constant float& eps, + constant uint& axis_size, + constant uint& w_stride, + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + // Advance the input pointers + x += gid * size_t(axis_size) + lid * N_READS; + g += gid * size_t(axis_size) + lid * N_READS; + w += w_stride * lid * N_READS; + + // Allocate registers for the accumulators + float sumx = 0; + float sumx2 = 0; + float sumwg = 0; + float sumwgx = 0; + + constexpr int SIMD_SIZE = 32; + + threadgroup float local_sumx[SIMD_SIZE]; + threadgroup float local_sumx2[SIMD_SIZE]; + threadgroup float local_sumwg[SIMD_SIZE]; + threadgroup float local_sumwgx[SIMD_SIZE]; + threadgroup float local_mean[1]; + threadgroup float local_normalizer[1]; + threadgroup float local_meanwg[1]; + threadgroup float local_meanwgx[1]; + + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float xi = x[i + r]; + float wi = w[(i + r) * w_stride]; + float gi = g[i + r]; + float wg = wi * gi; + sumx += xi; + sumx2 += xi * xi; + sumwg += wg; + sumwgx += wg * xi; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + float xi = x[i + r]; + float wi = w[(i + r) * w_stride]; + float gi = g[i + r]; + float wg = wi * gi; + sumx += xi; + sumx2 += xi * xi; + sumwg += wg; + sumwgx += wg * xi; + } + } + } + } + + sumx = simd_sum(sumx); + sumx2 = simd_sum(sumx2); + sumwg = simd_sum(sumwg); + sumwgx = simd_sum(sumwgx); + + // Initialize shared memory + if (simd_group_id == 0) { + local_sumx[simd_lane_id] = 0; + local_sumx2[simd_lane_id] = 0; + local_sumwg[simd_lane_id] = 0; + local_sumwgx[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write simd accumulations into shared memory + if (simd_lane_id == 0) { + local_sumx[simd_group_id] = sumx; + local_sumx2[simd_group_id] = sumx2; + local_sumwg[simd_group_id] = sumwg; + local_sumwgx[simd_group_id] = sumwgx; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Accumulate over simd groups + if (simd_group_id == 0) { + sumx = simd_sum(local_sumx[simd_lane_id]); + sumx2 = simd_sum(local_sumx2[simd_lane_id]); + sumwg = simd_sum(local_sumwg[simd_lane_id]); + sumwgx = simd_sum(local_sumwgx[simd_lane_id]); + if (simd_lane_id == 0) { + float mean = sumx / axis_size; + float variance = sumx2 / axis_size - mean * mean; + + local_mean[0] = mean; + local_normalizer[0] = metal::precise::rsqrt(variance + eps); + local_meanwg[0] = sumwg / axis_size; + local_meanwgx[0] = sumwgx / axis_size; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + float mean = local_mean[0]; + float normalizer = local_normalizer[0]; + float meanwg = local_meanwg[0]; + float meanwgxc = 
local_meanwgx[0] - meanwg * mean; + float normalizer2 = normalizer * normalizer; + + // Write the outputs + gx += gid * size_t(axis_size) + lid * N_READS; + gw += gid * size_t(axis_size) + lid * N_READS; + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float xi = (x[i + r] - mean) * normalizer; + float wi = w[(i + r) * w_stride]; + float gi = g[i + r]; + gx[i + r] = static_cast( + normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2); + gw[i + r] = static_cast(gi * xi); + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + float xi = (x[i + r] - mean) * normalizer; + float wi = w[(i + r) * w_stride]; + float gi = g[i + r]; + gx[i + r] = static_cast( + normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2); + gw[i + r] = static_cast(gi * xi); + } + } + } + } +} + +// clang-format off +#define instantiate_layer_norm_single_row(name, itype) \ + template [[host_name("layer_norm" #name)]] [[kernel]] void \ + layer_norm_single_row( \ + const device itype* x, \ + const device itype* w, \ + const device itype* b, \ + device itype* out, \ + constant float& eps, \ + constant uint& axis_size, \ + constant uint& w_stride, \ + constant uint& b_stride, \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ + template [[host_name("vjp_layer_norm" #name)]] [[kernel]] void \ + vjp_layer_norm_single_row( \ + const device itype* x, \ + const device itype* w, \ + const device itype* g, \ + device itype* gx, \ + device itype* gw, \ + constant float& eps, \ + constant uint& axis_size, \ + constant uint& w_stride, \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + +#define instantiate_layer_norm_looped(name, itype) \ + template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \ + layer_norm_looped( \ + const device itype* x, \ + const device itype* w, \ + const device itype* b, \ + device itype* out, \ + constant float& eps, \ + constant uint& axis_size, \ + constant uint& w_stride, \ + constant uint& b_stride, \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint lsize [[threads_per_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ + template [[host_name("vjp_layer_norm_looped" #name)]] [[kernel]] void \ + vjp_layer_norm_looped( \ + const device itype* x, \ + const device itype* w, \ + const device itype* g, \ + device itype* gx, \ + device itype* gb, \ + constant float& eps, \ + constant uint& axis_size, \ + constant uint& w_stride, \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint lsize [[threads_per_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + +#define instantiate_layer_norm(name, itype) \ + instantiate_layer_norm_single_row(name, itype) \ + instantiate_layer_norm_looped(name, itype) + +instantiate_layer_norm(float32, float) +instantiate_layer_norm(float16, half) +instantiate_layer_norm(bfloat16, bfloat16_t) // clang-format on diff --git a/Source/Cmlx/mlx-generated/metal/quantized.h 
b/Source/Cmlx/mlx-generated/metal/quantized.h new file mode 100644 index 00000000..4f388b9f --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/quantized.h @@ -0,0 +1,1603 @@ +// Copyright © 2023-2024 Apple Inc. + +#include +#include + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const + +MLX_MTL_CONST int SIMD_SIZE = 32; + +template +inline U load_vector(const device T* x, thread U* x_thread) { + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); + + U sum = 0; + + if (bits == 2) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 4.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 64.0f; + } + } + + else if (bits == 4) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + sum += x[i]; + x_thread[i] = x[i]; + } + } + + return sum; +} + +template +inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); + + U sum = 0; + + if (bits == 2) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 4.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 64.0f; + } + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; + } + } + + else if (bits == 4) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; + } + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + sum += x[i]; + x_thread[i] = x[i]; + } + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; + } + } + + return sum; +} + +template +inline U qdot( + const device uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum) { + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); + + U accum = 0; + + if (bits == 2) { + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * (w[i] & 0x0c) + + x_thread[4 * i + 2] * (w[i] & 0x30) + + x_thread[4 * i + 3] * (w[i] & 0xc0)); + } + } + + else if (bits == 4) { + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + accum += x_thread[i] * w[i]; + } + } + + return scale * accum + sum * bias; +} + +template +inline U qdot_safe( + const device uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum, + int N) { + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); + + U accum = 0; + + if (bits == 2) { + for (int i = 0; i < (N / 4); i++) { + 
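// load_vector and qdot above avoid per-element bit shifts by pre-dividing x by
// (2^bits)^k for the k-th packed value and then multiplying with the masked
// but unshifted weight bits; the per-group affine dequantization
// (w_q * scale + bias) is folded in at the end as scale * accum + bias * sum(x).
// A plain 4-bit reference that computes the same dot product the slow way,
// assuming low-nibble-first packing as the bit masks above imply:
#include <cstdint>
#include <vector>

float qdot4_ref(
    const std::vector<uint8_t>& packed, // 2 weights per byte
    const std::vector<float>& x,        // x.size() == 2 * packed.size()
    float scale, float bias) {
  float acc = 0.0f;
  for (size_t i = 0; i < packed.size(); ++i) {
    const int w0 = packed[i] & 0x0f;        // earlier element in the low nibble
    const int w1 = (packed[i] >> 4) & 0x0f; // later element in the high nibble
    acc += x[2 * i] * (scale * w0 + bias);
    acc += x[2 * i + 1] * (scale * w1 + bias);
  }
  return acc;
}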
accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * (w[i] & 0x0c) + + x_thread[4 * i + 2] * (w[i] & 0x30) + + x_thread[4 * i + 3] * (w[i] & 0xc0)); + } + } + + else if (bits == 4) { + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + accum += x_thread[i] * w[i]; + } + } + + return scale * accum + sum * bias; +} + +template +inline void +qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); + + if (bits == 2) { + U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; + for (int i = 0; i < (values_per_thread / 4); i++) { + result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias); + result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias); + result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias); + result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias); + } + } + + else if (bits == 4) { + U s[2] = {scale, scale / 16.0f}; + for (int i = 0; i < (values_per_thread / 2); i++) { + result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); + result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + result[i] += x * (scale * w[i] + bias); + } + } +} + +template +inline void +dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); + + if (bits == 2) { + U s[4] = { + scale, + scale / static_cast(4.0f), + scale / static_cast(16.0f), + scale / static_cast(64.0f)}; + for (int i = 0; i < (N / 4); i++) { + w_local[4 * i] = s[0] * (w[i] & 0x03) + bias; + w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias; + w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias; + w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias; + } + } + + else if (bits == 4) { + U s[2] = {scale, scale / static_cast(16.0f)}; + for (int i = 0; i < (N / 2); i++) { + w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias; + w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + w_local[i] = scale * w[i] + bias; + } + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short group_size, + short bits> +struct QuantizedBlockLoader { + static_assert( + BCOLS <= group_size, + "The group size should be larger than the columns"); + static_assert( + group_size % BCOLS == 0, + "The group size should be divisible by the columns"); + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); + + MLX_MTL_CONST short pack_factor = 32 / bits; + MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; + MLX_MTL_CONST short n_reads = + (BCOLS_PACKED * BROWS < tgp_size) ? 
1 : (BCOLS_PACKED * BROWS) / tgp_size; + MLX_MTL_CONST short group_steps = group_size / BCOLS; + + const int src_ld; + const int tile_stride; + short group_step_cnt; + const int group_stride; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint32_t* src; + const device T* scales; + const device T* biases; + + QuantizedBlockLoader( + const device uint32_t* src_, + const device T* scales_, + const device T* biases_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride( + reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor), + group_step_cnt(0), + group_stride(BROWS * src_ld / group_size), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(n_reads * thread_idx / BCOLS_PACKED), + bj((n_reads * thread_idx) % BCOLS_PACKED), + dst(dst_ + bi * dst_ld + bj * pack_factor), + src(src_ + bi * src_ld / pack_factor + bj), + scales(scales_ + bi * src_ld / group_size), + biases(biases_ + bi * src_ld / group_size) {} + + void load_unsafe() const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); + } + } + + void load_safe(short2 src_tile_dim) const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + if (reduction_dim == 1 && bi >= src_tile_dim.y) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + if (reduction_dim == 0 && bi >= src_tile_dim.x) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); + } + } + + void next() { + src += tile_stride; + if (reduction_dim == 1) { + if (group_steps > 1) { + group_step_cnt++; + if (group_step_cnt == group_steps) { + group_step_cnt = 0; + scales++; + biases++; + } + } else { + scales++; + biases++; + } + } else { + scales += group_stride; + biases += group_stride; + } + } +}; + +template +METAL_FUNC void qmv_fast_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int packs_per_thread = bits > 2 ? 
2 : 1; + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int pack_factor = 32 / bits; + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + w += out_row * in_vec_size_w + simd_lid * packs_per_thread; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.y * in_vec_size + simd_lid * values_per_thread; + y += tid.y * out_vec_size + out_row; + + for (int k = 0; k < in_vec_size; k += block_size) { + U sum = load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + const device uint8_t* wl = + (const device uint8_t*)(w + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += qdot(wl, x_thread, s, b, sum); + } + + w += block_size / pack_factor; + scales += block_size / group_size; + biases += block_size / group_size; + x += block_size; + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void qmv_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int packs_per_thread = 1; + constexpr int pack_factor = 32 / bits; + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); + + if (out_row >= out_vec_size) { + return; + } + + // In this case we need to properly guard all our reads because there isn't + // even 1 tile in the matrix + if (out_vec_size < (num_simdgroups * results_per_simdgroup)) { + w += out_row * in_vec_size_w + simd_lid * packs_per_thread; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.y * in_vec_size + simd_lid * values_per_thread; + y += tid.y * out_vec_size + out_row; + + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + U sum = load_vector(x, x_thread); + + for (int row = 0; out_row + row < 
out_vec_size; row++) { + const device uint8_t* wl = + (const device uint8_t*)(w + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += + qdot(wl, x_thread, s, b, sum); + } + + w += block_size / pack_factor; + scales += block_size / group_size; + biases += block_size / group_size; + x += block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + U sum = + load_vector_safe(x, x_thread, remaining); + + for (int row = 0; out_row + row < out_vec_size; row++) { + const device uint8_t* wl = + (const device uint8_t*)(w + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += qdot(wl, x_thread, s, b, sum); + } + + for (int row = 0; out_row + row < out_vec_size; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } + + // In this case the last tile is moved back to redo some output values + else { + w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread; + scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.y * in_vec_size + simd_lid * values_per_thread; + y += tid.y * out_vec_size + used_out_row; + + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + U sum = load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + const device uint8_t* wl = + (const device uint8_t*)(w + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += + qdot(wl, x_thread, s, b, sum); + } + + w += block_size / pack_factor; + scales += block_size / group_size; + biases += block_size / group_size; + x += block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + U sum = + load_vector_safe(x, x_thread, remaining); + + for (int row = 0; row < results_per_simdgroup; row++) { + const device uint8_t* wl = + (const device uint8_t*)(w + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += qdot_safe( + wl, x_thread, s, b, sum, remaining); + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } +} + +template +METAL_FUNC void qvm_impl( + const device T* x, + const device uint32_t* w, + const device T* scales, + const device T* biases, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int num_simdgroups = 2; + constexpr int pack_factor = 32 / bits; + constexpr int tn = 32 / pack_factor; + constexpr int blocksize = SIMD_SIZE; + + typedef float U; + typedef struct { + uint32_t wi[tn]; + } vec_w; + + thread vec_w w_local; + thread U result[tn * pack_factor] = {0}; + thread U scale = 1; + thread U bias = 0; + thread U x_local = 0; + + // Adjust positions + const 
int out_vec_size_w = out_vec_size / pack_factor; + const int out_vec_size_g = out_vec_size / group_size; + int out_col = + tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn; + w += out_col / pack_factor + simd_lid * out_vec_size_w; + scales += out_col / group_size + simd_lid * out_vec_size_g; + biases += out_col / group_size + simd_lid * out_vec_size_g; + x += tid.y * in_vec_size + simd_lid; + y += tid.y * out_vec_size + out_col; + + if (out_col >= out_vec_size) { + return; + } + + // Loop over in_vec in blocks of blocksize + int remaining = in_vec_size % blocksize; + if (remaining == 0) { + for (int i = 0; i < in_vec_size; i += blocksize) { + x_local = *x; + scale = *scales; + bias = *biases; + w_local = *((device vec_w*)w); + + qouter( + (thread uint8_t*)&w_local, x_local, scale, bias, result); + + x += blocksize; + scales += blocksize * out_vec_size_g; + biases += blocksize * out_vec_size_g; + w += blocksize * out_vec_size_w; + } + } else { + for (int i = blocksize; i < in_vec_size; i += blocksize) { + x_local = *x; + scale = *scales; + bias = *biases; + w_local = *((device vec_w*)w); + + qouter( + (thread uint8_t*)&w_local, x_local, scale, bias, result); + + x += blocksize; + scales += blocksize * out_vec_size_g; + biases += blocksize * out_vec_size_g; + w += blocksize * out_vec_size_w; + } + if (static_cast(simd_lid) < remaining) { + x_local = *x; + scale = *scales; + bias = *biases; + w_local = *((device vec_w*)w); + } else { + x_local = 0; + scale = 0; + bias = 0; + } + qouter( + (thread uint8_t*)&w_local, x_local, scale, bias, result); + } + +// Accumulate in the simdgroup +#pragma clang loop unroll(full) + for (int k = 0; k < tn * pack_factor; k++) { + result[k] = simd_sum(result[k]); + } + + // Store the result + if (simd_lid == 0) { +#pragma clang loop unroll(full) + for (int k = 0; k < tn * pack_factor; k++) { + y[k] = static_cast(result[k]); + } + } +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 32, + const int BK = 32, + const int BN = 32> +METAL_FUNC void qmm_t_impl( + const device T* x, + const device uint32_t* w, + const device T* scales, + const device T* biases, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& M, + const constant int& N, + const constant int& K, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int pack_factor = 32 / bits; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + // Instantiate the appropriate BlockMMA and Loader + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BN, + BK, + BK_padded, + 1, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + // Set the block + const int K_w = K / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + x += y_row * K; + w += y_col * K_w; + scales += y_col * K_g; + biases += y_col * K_g; + y += y_row * N + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + const short num_outs = min(BN, N - y_col); + loader_x_t loader_x(x, K, Xs, simd_gid, 
simd_lid); + loader_w_t loader_w(w, scales, biases, K, Ws, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM || num_outs < BN) { + mma_op.store_result_safe(y, N, short2(num_outs, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 32, + const int BK = 32, + const int BN = 32> +METAL_FUNC void qmm_n_impl( + const device T* x, + const device uint32_t* w, + const device T* scales, + const device T* biases, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& M, + const constant int& N, + const constant int& K, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int pack_factor = 32 / bits; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + // Instantiate the appropriate BlockMMA and Loader + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = mlx::steel:: + BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BK, + BN, + BN_padded, + 0, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + // Set the block + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + x += y_row * K; + w += y_col / pack_factor; + scales += y_col / group_size; + biases += y_col / group_size; + y += y_row * N + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + loader_w_t loader_w(w, scales, biases, N, Ws, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + 
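// Reference sketch (plain C++), not part of the generated kernels: the value the tiled
// qmm_t/qmm_n paths above compute once the threadgroup staging is peeled away. W is
// stored packed along K with a per-group scale and bias, and each weight tile is
// dequantized into threadgroup memory before the BlockMMA accumulate. The 4-bit
// packing and the group size below are illustrative assumptions, not fixed by the kernels.
#include <cstdint>
#include <vector>

std::vector<float> qmm_t_reference(const std::vector<float>& X,      // [M, K], row major
                                   const std::vector<uint8_t>& Wq,   // [N, K], two 4-bit values per byte
                                   const std::vector<float>& scales, // [N, K / group]
                                   const std::vector<float>& biases, // [N, K / group]
                                   int M, int N, int K, int group) {
  std::vector<float> Y(size_t(M) * N, 0.0f);
  for (int m = 0; m < M; m++) {
    for (int n = 0; n < N; n++) {
      float acc = 0.0f;
      for (int k = 0; k < K; k++) {
        uint8_t byte = Wq[size_t(n) * (K / 2) + k / 2];
        float q = float((k % 2 == 0) ? (byte & 0x0f) : (byte >> 4));
        float s = scales[size_t(n) * (K / group) + k / group];
        float b = biases[size_t(n) * (K / group) + k / group];
        acc += X[size_t(m) * K + k] * (s * q + b);  // dequantize, then multiply-accumulate
      }
      Y[size_t(m) * N + n] = acc;  // qmm_n is the analogous non-transposed case, W stored [K, N]
    }
  }
  return Y;
}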
mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, num_els)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, BM)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM) { + mma_op.store_result_safe(y, N, short2(BN, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device T*& scales, + const device T*& biases, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T*& y, + int output_stride, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant size_t* lhs_strides, + const constant size_t* rhs_strides, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant size_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant size_t* w_strides, + const constant size_t* s_strides, + const constant size_t* b_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx; + uint32_t w_idx; + if (batch_ndims == 1) { + x_idx = lhs_indices[tid.z * lhs_strides[0]]; + w_idx = rhs_indices[tid.z * rhs_strides[0]]; + } else { + ulong2 idx = elem_to_loc_broadcast( + tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); + x_idx = lhs_indices[idx.x]; + w_idx = rhs_indices[idx.y]; + } + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + biases += w_idx * b_strides[0]; + } else { + ulong3 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + biases += idx.z; + } + y += tid.z * output_stride; +} + +template +[[kernel]] void qmv_fast( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + 
uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + qmv_fast_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void qmv( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + qmv_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void qvm( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + qvm_impl( + x, + w, + scales, + biases, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void qmm_t( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& M [[buffer(5)]], + const constant int& N [[buffer(6)]], + const constant int& K [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + + qmm_t_impl( + x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void qmm_n( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& M [[buffer(5)]], + const constant int& N [[buffer(6)]], + const constant int& K [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + + qmm_n_impl( + x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void bs_qmv_fast( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices 
[[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& batch_ndims [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* lhs_strides [[buffer(11)]], + const constant size_t* rhs_strides [[buffer(12)]], + const constant int& x_batch_ndims [[buffer(13)]], + const constant int* x_shape [[buffer(14)]], + const constant size_t* x_strides [[buffer(15)]], + const constant int& w_batch_ndims [[buffer(16)]], + const constant int* w_shape [[buffer(17)]], + const constant size_t* w_strides [[buffer(18)]], + const constant size_t* s_strides [[buffer(19)]], + const constant size_t* b_strides [[buffer(20)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + out_vec_size, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmv_fast_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void bs_qmv( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& batch_ndims [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* lhs_strides [[buffer(11)]], + const constant size_t* rhs_strides [[buffer(12)]], + const constant int& x_batch_ndims [[buffer(13)]], + const constant int* x_shape [[buffer(14)]], + const constant size_t* x_strides [[buffer(15)]], + const constant int& w_batch_ndims [[buffer(16)]], + const constant int* w_shape [[buffer(17)]], + const constant size_t* w_strides [[buffer(18)]], + const constant size_t* s_strides [[buffer(19)]], + const constant size_t* b_strides [[buffer(20)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + out_vec_size, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmv_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void bs_qvm( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& batch_ndims [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* lhs_strides [[buffer(11)]], + const constant size_t* rhs_strides [[buffer(12)]], + const constant int& x_batch_ndims [[buffer(13)]], + 
const constant int* x_shape [[buffer(14)]], + const constant size_t* x_strides [[buffer(15)]], + const constant int& w_batch_ndims [[buffer(16)]], + const constant int* w_shape [[buffer(17)]], + const constant size_t* w_strides [[buffer(18)]], + const constant size_t* s_strides [[buffer(19)]], + const constant size_t* b_strides [[buffer(20)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + out_vec_size, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qvm_impl( + x, + w, + scales, + biases, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void bs_qmm_t( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& M [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& K [[buffer(9)]], + const constant int& batch_ndims [[buffer(10)]], + const constant int* batch_shape [[buffer(11)]], + const constant size_t* lhs_strides [[buffer(12)]], + const constant size_t* rhs_strides [[buffer(13)]], + const constant int& x_batch_ndims [[buffer(14)]], + const constant int* x_shape [[buffer(15)]], + const constant size_t* x_strides [[buffer(16)]], + const constant int& w_batch_ndims [[buffer(17)]], + const constant int* w_shape [[buffer(18)]], + const constant size_t* w_strides [[buffer(19)]], + const constant size_t* s_strides [[buffer(20)]], + const constant size_t* b_strides [[buffer(21)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmm_t_impl( + x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void bs_qmm_n( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& M [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& K [[buffer(9)]], + const constant int& batch_ndims [[buffer(10)]], + const constant int* batch_shape [[buffer(11)]], + const constant size_t* lhs_strides [[buffer(12)]], + const constant size_t* rhs_strides 
[[buffer(13)]], + const constant int& x_batch_ndims [[buffer(14)]], + const constant int* x_shape [[buffer(15)]], + const constant size_t* x_strides [[buffer(16)]], + const constant int& w_batch_ndims [[buffer(17)]], + const constant int* w_shape [[buffer(18)]], + const constant size_t* w_strides [[buffer(19)]], + const constant size_t* s_strides [[buffer(20)]], + const constant size_t* b_strides [[buffer(21)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmm_n_impl( + x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void affine_quantize( + const device T* w [[buffer(0)]], + device uint8_t* out [[buffer(1)]], + device T* scales [[buffer(2)]], + device T* biases [[buffer(3)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + constexpr T eps = T(1e-7); + constexpr int simd_size = 32; + constexpr int uint8_bits = 8; + constexpr T n_bins = (1 << bits) - 1; + constexpr int packs_per_int = uint8_bits / bits; + constexpr int values_per_reduce = group_size / simd_size; + constexpr int writes_per_reduce = packs_per_int / values_per_reduce; + constexpr int writes_per_pack = + writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int; + + static_assert( + group_size % simd_size == 0, + "Group size must be divisible by simd size."); + + size_t offset = index.x + grid_dim.x * size_t(index.y); + size_t in_index = offset * values_per_reduce; + size_t out_index = offset * writes_per_pack; + + T w_thread[values_per_reduce]; + T w_min = Limits::max; + T w_max = 0; + +#pragma clang loop unroll(full) + for (int i = 0; i < values_per_reduce; i++) { + T val = w[in_index + i]; + w_thread[i] = val; + w_min = min(w_min, val); + w_max = max(w_max, val); + } + + w_min = simd_min(w_min); + w_max = simd_max(w_max); + + T scale = max((w_max - w_min) / n_bins, eps); + bool side = abs(w_min) > abs(w_max); + scale = side ? scale : -scale; + T edge = side ? w_min : w_max; + T q0 = round(edge / scale); + bool at_zero = q0 == 0.0f; + scale = at_zero ? scale : edge / q0; + T bias = at_zero ? 
T(0) : edge; + + // Write out the scales and biases + size_t gindex = in_index / group_size; + if (in_index % group_size == 0) { + scales[gindex] = scale; + biases[gindex] = bias; + } + + uint8_t output = 0; +#pragma clang loop unroll(full) + for (int i = 0; i < values_per_reduce; i++) { + uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins); + if (bits == 8) { + output = val; + } else { + output += val << (bits * (i % packs_per_int)); + } + + if (packs_per_int < values_per_reduce && + i % packs_per_int == packs_per_int - 1) { + out[out_index + i / packs_per_int] = output; + output = 0; + } else { +#pragma clang loop unroll(full) + for (int j = 0; j < writes_per_reduce - 1; j++) { + uint8_t sval = simd_shuffle_down(val, j + 1); + output += sval << (bits * (values_per_reduce + j + i)); + } + } + } + if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) { + out[out_index / writes_per_reduce] = output; + } +} + +template +[[kernel]] void affine_quantize_scales_biases( + const device T* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + device uint8_t* out [[buffer(3)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + constexpr int uint8_bits = 8; + constexpr int packs_per_int = uint8_bits / bits; + constexpr T n_bins = (1 << bits) - 1; + + size_t offset = index.x + grid_dim.x * size_t(index.y); + size_t in_index = offset * packs_per_int; + size_t gindex = in_index / group_size; + + T scale = scales[gindex]; + T bias = biases[gindex]; + + uint8_t output = 0; +#pragma clang loop unroll(full) + for (int i = 0; i < packs_per_int; i++) { + uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins); + if (bits == 8) { + output = val; + } else { + output += val << (bits * i); + } + } + out[offset] = output; +} + +template +[[kernel]] void affine_dequantize( + const device uint8_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + device T* out [[buffer(3)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + constexpr int uint8_bits = 8; + constexpr int packs_per_int = uint8_bits / bits; + + size_t offset = index.x + grid_dim.x * size_t(index.y); + size_t oindex = offset * packs_per_int; + size_t gindex = oindex / group_size; + T scale = scales[gindex]; + T bias = biases[gindex]; + uint val = w[offset]; + +#pragma clang loop unroll(full) + for (int i = 0; i < packs_per_int; i++) { + uint8_t d; + if (bits == 2) { + d = (val >> (bits * i)) & 0x03; + } else if (bits == 4) { + d = (val >> (bits * i)) & 0x0f; + } else if (bits == 8) { + d = val; + } + out[oindex + i] = scale * d + bias; + } +} diff --git a/Source/Cmlx/mlx-generated/metal/random.metal b/Source/Cmlx/mlx-generated/metal/random.metal new file mode 100644 index 00000000..5a1704d2 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/random.metal @@ -0,0 +1,103 @@ +// Copyright © 2023 Apple Inc. 
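// Reference sketch (plain C++), not part of the generated kernels: a group-wise round
// trip through the affine scheme used above, where a weight is reconstructed as
// scale * q + bias, plus the identity sum_i(w_i * x_i) = scale * sum_i(q_i * x_i)
// + bias * sum_i(x_i), which is consistent with the qmv/qvm kernels passing a
// precomputed sum of x into qdot. The kernel additionally snaps scale/bias so the
// group's larger-magnitude extreme lands exactly on a level (the edge / round(edge / scale)
// step); that refinement is omitted here, and the min-as-bias choice is a simplification.
#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  const int bits = 4;
  const float n_bins = float((1 << bits) - 1);        // 15
  const float w[4] = {-0.30f, 0.10f, 0.45f, 0.00f};   // one group of 4 weights
  const float x[4] = {0.50f, -1.00f, 2.00f, 0.25f};

  float w_min = *std::min_element(w, w + 4);
  float w_max = *std::max_element(w, w + 4);
  float scale = std::max((w_max - w_min) / n_bins, 1e-7f);
  float bias = w_min;

  float qx = 0.0f, x_sum = 0.0f, exact = 0.0f;
  for (int i = 0; i < 4; i++) {
    float q = std::min(std::round((w[i] - bias) / scale), n_bins);
    std::printf("w=% .3f  q=%2.0f  dequant=% .3f\n", w[i], q, scale * q + bias);
    qx += q * x[i];
    x_sum += x[i];
    exact += (scale * q + bias) * x[i];
  }
  // The "qdot" form: one multiply by scale and one by bias per group.
  std::printf("dot via identity = %f, elementwise = %f\n", scale * qx + bias * x_sum, exact);
}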
+ +#include "utils.h" + +static constexpr constant uint32_t rotations[2][4] = { + {13, 15, 26, 6}, + {17, 29, 16, 24}}; + +union rbits { + uint2 val; + uchar4 bytes[2]; +}; + +rbits threefry2x32_hash(const thread uint2& key, uint2 count) { + uint4 ks = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA}; + + rbits v; + v.val.x = count.x + ks[0]; + v.val.y = count.y + ks[1]; + + for (int i = 0; i < 5; ++i) { + for (auto r : rotations[i % 2]) { + v.val.x += v.val.y; + v.val.y = (v.val.y << r) | (v.val.y >> (32 - r)); + v.val.y ^= v.val.x; + } + v.val.x += ks[(i + 1) % 3]; + v.val.y += ks[(i + 2) % 3] + i + 1; + } + + return v; +} + +[[kernel]] void rbitsc( + device const uint32_t* keys, + device char* out, + device const bool& odd, + device const uint& bytes_per_key, + uint2 grid_dim [[threads_per_grid]], + uint2 index [[thread_position_in_grid]]) { + auto kidx = 2 * index.x; + auto key = uint2(keys[kidx], keys[kidx + 1]); + auto half_size = grid_dim.y - odd; + out += index.x * bytes_per_key; + bool drop_last = odd && (index.y == half_size); + auto bits = threefry2x32_hash( + key, uint2(index.y, drop_last ? 0 : index.y + grid_dim.y)); + size_t idx = size_t(index.y) << 2; + for (int i = 0; i < 4; ++i) { + out[idx + i] = bits.bytes[0][i]; + } + if (!drop_last) { + idx = (drop_last ? 0 : size_t(index.y) + grid_dim.y) << 2; + if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { + int edge_bytes = (bytes_per_key % 4); + for (int i = 0; i < edge_bytes; ++i) { + out[idx + i] = bits.bytes[1][i]; + } + } else { + for (int i = 0; i < 4; ++i) { + out[idx + i] = bits.bytes[1][i]; + } + } + } +} + +[[kernel]] void rbits( + device const uint32_t* keys, + device char* out, + device const bool& odd, + device const uint& bytes_per_key, + constant const int& ndim, + constant const int* key_shape, + constant const size_t* key_strides, + uint2 grid_dim [[threads_per_grid]], + uint2 index [[thread_position_in_grid]]) { + auto kidx = 2 * index.x; + auto k1_elem = elem_to_loc(kidx, key_shape, key_strides, ndim); + auto k2_elem = elem_to_loc(kidx + 1, key_shape, key_strides, ndim); + auto key = uint2(keys[k1_elem], keys[k2_elem]); + auto half_size = grid_dim.y - odd; + out += size_t(index.x) * bytes_per_key; + bool drop_last = odd && (index.y == half_size); + auto bits = threefry2x32_hash( + key, uint2(index.y, drop_last ? 0 : index.y + grid_dim.y)); + size_t idx = size_t(index.y) << 2; + for (int i = 0; i < 4; ++i) { + out[idx + i] = bits.bytes[0][i]; + } + if (!drop_last) { + idx = (drop_last ? 0 : size_t(index.y) + grid_dim.y) << 2; + if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { + int edge_bytes = (bytes_per_key % 4); + for (int i = 0; i < edge_bytes; ++i) { + out[idx + i] = bits.bytes[1][i]; + } + } else { + for (int i = 0; i < 4; ++i) { + out[idx + i] = bits.bytes[1][i]; + } + } + } +} diff --git a/Source/Cmlx/mlx-generated/metal/reduce.h b/Source/Cmlx/mlx-generated/metal/reduce.h new file mode 100644 index 00000000..8d1f609d --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/reduce.h @@ -0,0 +1,5 @@ +#pragma once +#include "reduction/reduce_all.h" +#include "reduction/reduce_col.h" +#include "reduction/reduce_init.h" +#include "reduction/reduce_row.h" diff --git a/Source/Cmlx/mlx-generated/metal/reduce_utils.h b/Source/Cmlx/mlx-generated/metal/reduce_utils.h new file mode 100644 index 00000000..f5ccc3f1 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/reduce_utils.h @@ -0,0 +1,6 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#include "atomic.h" +#include "reduction/ops.h" diff --git a/Source/Cmlx/mlx-generated/metal/reduction/ops.h b/Source/Cmlx/mlx-generated/metal/reduction/ops.h new file mode 100644 index 00000000..68ed1198 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/reduction/ops.h @@ -0,0 +1,204 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +#define DEFINE_SIMD_REDUCE() \ + template = true> \ + T simd_reduce(T val) { \ + return simd_reduce_impl(val); \ + } \ + \ + template = true> \ + T simd_reduce(T val) { \ + for (short i = simd_size / 2; i > 0; i /= 2) { \ + val = operator()(val, simd_shuffle_down(val, i)); \ + } \ + return val; \ + } + +static constant constexpr const uint8_t simd_size = 32; + +union bool4_or_uint { + bool4 b; + unsigned int i; +}; + +struct None { + template + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { + mlx_atomic_store_explicit(out, val, offset); + } +}; + +template +struct And { + DEFINE_SIMD_REDUCE() + + bool simd_reduce_impl(bool val) { + return simd_all(val); + } + + static constexpr constant bool init = true; + + void atomic_update( + device mlx_atomic* out, + bool val, + int elem_idx, + size_t offset = 0) { + if (!val) { + bool4_or_uint update; + update.b = {true, true, true, true}; + update.b[elem_idx] = false; + mlx_atomic_fetch_and_explicit(out, update.i, offset); + } + } + + void + atomic_update(device mlx_atomic* out, bool val, size_t offset = 0) { + if (!val) { + mlx_atomic_store_explicit(out, val, offset); + } + } + + // Non atomic update + void update(device bool* out, bool val) { + *out &= val; + } + + // Operator + bool operator()(bool a, bool b) { + return a && b; + } +}; + +template +struct Or { + DEFINE_SIMD_REDUCE() + + bool simd_reduce_impl(bool val) { + return simd_any(val); + } + + static constexpr constant bool init = false; + + void atomic_update( + device mlx_atomic* out, + bool val, + int elem_idx, + size_t offset = 0) { + if (val) { + bool4_or_uint update; + update.b = {false, false, false, false}; + update.b[elem_idx] = true; + mlx_atomic_fetch_or_explicit(out, update.i, offset); + } + } + + void + atomic_update(device mlx_atomic* out, bool val, size_t offset = 0) { + if (val) { + mlx_atomic_store_explicit(out, val, offset); + } + } + + // Non atomic update + void update(device bool* out, bool val) { + *out |= val; + } + + // Operator + bool operator()(bool a, bool b) { + return a || b; + } +}; + +template +struct Sum { + DEFINE_SIMD_REDUCE() + + template + T simd_reduce_impl(T val) { + return simd_sum(val); + } + + static constexpr constant U init = U(0); + + template + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { + mlx_atomic_fetch_add_explicit(out, val, offset); + } + + // Operator + U operator()(U a, U b) { + return a + b; + } +}; + +template +struct Prod { + DEFINE_SIMD_REDUCE() + + template + T simd_reduce_impl(T val) { + return simd_product(val); + } + + static constexpr constant U init = U(1); + + template + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { + mlx_atomic_fetch_mul_explicit(out, val, offset); + } + + // Operator + U operator()(U a, U b) { + return a * b; + } +}; + +template +struct Min { + DEFINE_SIMD_REDUCE() + + template + T simd_reduce_impl(T val) { + return simd_min(val); + } + + static constexpr constant U init = Limits::max; + + template + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { + mlx_atomic_fetch_min_explicit(out, val, offset); + } + + // Operator + U 
operator()(U a, U b) { + return a < b ? a : b; + } +}; + +template +struct Max { + DEFINE_SIMD_REDUCE() + + template + T simd_reduce_impl(T val) { + return simd_max(val); + } + + static constexpr constant U init = Limits::min; + + template + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { + mlx_atomic_fetch_max_explicit(out, val, offset); + } + + // Operator + U operator()(U a, U b) { + return a > b ? a : b; + } +}; diff --git a/Source/Cmlx/mlx-generated/metal/reduction/reduce_all.h b/Source/Cmlx/mlx-generated/metal/reduction/reduce_all.h new file mode 100644 index 00000000..381d5e20 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/reduction/reduce_all.h @@ -0,0 +1,61 @@ +// Copyright © 2023-2024 Apple Inc. + +template +[[kernel]] void all_reduce( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& in_size [[buffer(2)]], + const constant size_t& row_size [[buffer(3)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + Op op; + threadgroup U shared_vals[simd_size]; + + U total = Op::init; + int64_t start_idx = gid.y * row_size; + int64_t actual_row = + (start_idx + row_size <= in_size) ? row_size : in_size - start_idx; + int64_t blocks = actual_row / (lsize.x * N_READS); + int extra = actual_row - blocks * (lsize.x * N_READS); + extra -= lid.x * N_READS; + start_idx += lid.x * N_READS; + in += start_idx; + + if (extra >= N_READS) { + blocks++; + extra = 0; + } + + for (int64_t b = 0; b < blocks; b++) { + for (int i = 0; i < N_READS; i++) { + total = op(static_cast(in[i]), total); + } + in += lsize.x * N_READS; + } + if (extra > 0) { + for (int i = 0; i < extra; i++) { + total = op(static_cast(in[i]), total); + } + } + + // Reduction within simd group + total = op.simd_reduce(total); + if (simd_per_group > 1) { + if (simd_lane_id == 0) { + shared_vals[simd_group_id] = total; + } + + // Reduction within thread group + threadgroup_barrier(mem_flags::mem_threadgroup); + total = lid.x < simd_per_group ? shared_vals[lid.x] : op.init; + total = op.simd_reduce(total); + } + + if (lid.x == 0) { + out[gid.y] = total; + } +} diff --git a/Source/Cmlx/mlx-generated/metal/reduction/reduce_col.h b/Source/Cmlx/mlx-generated/metal/reduction/reduce_col.h new file mode 100644 index 00000000..52e763dd --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/reduction/reduce_col.h @@ -0,0 +1,331 @@ +// Copyright © 2023-2024 Apple Inc. 
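// Reference sketch (plain C++), not part of the generated kernels: the generic branch of
// DEFINE_SIMD_REDUCE() above folds 32 lanes with simd_shuffle_down(val, offset), halving
// the offset each step; all_reduce then repeats the same trick once more across simdgroups
// through threadgroup memory. Modeled here with an array standing in for the 32 lanes of
// one simdgroup, using Sum's operator().
#include <cstdio>

int main() {
  const int simd_size = 32;
  float lane[simd_size];
  for (int i = 0; i < simd_size; i++) lane[i] = float(i + 1);  // 1..32, total 528

  for (int offset = simd_size / 2; offset > 0; offset /= 2) {
    for (int i = 0; i < simd_size; i++) {
      // simd_shuffle_down(val, offset): read the lane `offset` positions above.
      // Out-of-range lanes never feed into lane 0's result, so 0 stands in here.
      float above = (i + offset < simd_size) ? lane[i + offset] : 0.0f;
      lane[i] += above;
    }
  }
  std::printf("lane 0 holds %g\n", lane[0]);  // 528: the full reduction in log2(32) steps
}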
+ +template < + typename T, + typename U, + typename Op, + int NDIMS, + int N_READS = REDUCE_N_READS> +[[kernel]] void col_reduce_small( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& reduction_stride [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant size_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant size_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + const constant size_t& non_col_reductions [[buffer(10)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[thread_position_in_grid]], + uint3 tsize [[threads_per_grid]]) { + Op op; + looped_elem_to_loc loop; + const device T* row; + + // Case 1: Small row small column + if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) { + U totals[31]; + for (int i = 0; i < 31; i++) { + totals[i] = Op::init; + } + + short stride = reduction_stride; + short size = reduction_size; + short blocks = stride / N_READS; + short extra = stride - blocks * N_READS; + + size_t out_idx = tid.x + tsize.y * size_t(tid.y); + in += elem_to_loc(out_idx, shape, strides, ndim); + + for (uint r = 0; r < non_col_reductions; r++) { + row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + + for (short i = 0; i < size; i++) { + for (short j = 0; j < blocks; j++) { + for (short k = 0; k < N_READS; k++) { + totals[j * N_READS + k] = + op(totals[j * N_READS + k], + static_cast(row[i * stride + j * N_READS + k])); + } + } + for (short k = 0; k < extra; k++) { + totals[blocks * N_READS + k] = + op(totals[blocks * N_READS + k], + static_cast(row[i * stride + blocks * N_READS + k])); + } + } + + loop.next(reduce_shape, reduce_strides); + } + out += out_idx * reduction_stride; + for (short j = 0; j < stride; j++) { + out[j] = totals[j]; + } + } + + // Case 2: Long row small column + else if (reduction_size * non_col_reductions < 32) { + U totals[N_READS]; + for (int i = 0; i < N_READS; i++) { + totals[i] = Op::init; + } + + short size = reduction_size; + size_t offset = size_t(tid.x) * N_READS; + bool safe = offset + N_READS <= reduction_stride; + short extra = reduction_stride - offset; + + size_t out_idx = tid.y + tsize.z * size_t(tid.z); + in += elem_to_loc(out_idx, shape, strides, ndim) + offset; + + for (uint r = 0; r < non_col_reductions; r++) { + row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + + if (safe) { + for (short i = 0; i < size; i++) { + for (short j = 0; j < N_READS; j++) { + totals[j] = + op(static_cast(row[i * reduction_stride + j]), totals[j]); + } + } + } else { + for (short i = 0; i < size; i++) { + for (short j = 0; j < extra; j++) { + totals[j] = + op(static_cast(row[i * reduction_stride + j]), totals[j]); + } + } + } + + loop.next(reduce_shape, reduce_strides); + } + out += out_idx * reduction_stride + offset; + if (safe) { + for (short i = 0; i < N_READS; i++) { + out[i] = totals[i]; + } + } else { + for (short i = 0; i < extra; i++) { + out[i] = totals[i]; + } + } + } + + // Case 3: Long row medium column + else { + threadgroup U shared_vals[1024]; + U totals[N_READS]; + for (int i = 0; i < N_READS; i++) { + totals[i] = Op::init; + } + + short stride = reduction_stride; + short lid = 
simd_group_id * simd_size + simd_lane_id; + short2 tile((stride + N_READS - 1) / N_READS, 32); + short2 offset((lid % tile.x) * N_READS, lid / tile.x); + short sm_stride = tile.x * N_READS; + bool safe = offset.x + N_READS <= stride; + + size_t out_idx = gid.y + gsize.y * size_t(gid.z); + in += elem_to_loc(out_idx, shape, strides, ndim) + offset.x; + + // Read cooperatively and contiguously and aggregate the partial results. + size_t total = non_col_reductions * reduction_size; + loop.next(offset.y, reduce_shape, reduce_strides); + for (size_t r = offset.y; r < total; r += simd_size) { + row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + + if (safe) { + for (int i = 0; i < N_READS; i++) { + totals[i] = op(static_cast(row[i]), totals[i]); + } + } else { + U vals[N_READS]; + for (int i = 0; i < N_READS; i++) { + vals[i] = (offset.x + i < stride) ? static_cast(row[i]) : op.init; + } + for (int i = 0; i < N_READS; i++) { + totals[i] = op(vals[i], totals[i]); + } + } + + loop.next(simd_size, reduce_shape, reduce_strides); + } + + // Each thread holds N_READS partial results but the simdgroups are not + // aligned to do the reduction across the simdgroup so we write our results + // in the shared memory and read them back according to the simdgroup. + for (int i = 0; i < N_READS; i++) { + shared_vals[offset.y * sm_stride + offset.x + i] = totals[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_READS; i++) { + totals[i] = op.simd_reduce( + shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]); + } + + // Write the output. + if (simd_lane_id == 0) { + short column = simd_group_id * N_READS; + out += out_idx * reduction_stride + column; + if (column + N_READS <= stride) { + for (int i = 0; i < N_READS; i++) { + out[i] = totals[i]; + } + } else { + for (int i = 0; column + i < stride; i++) { + out[i] = totals[i]; + } + } + } + } +} + +/** + * Our approach is the following simple looped approach: + * 1. Each thread keeps running totals for BN / n_simdgroups outputs. + * 2. Load a tile BM, BN in registers and accumulate in the running totals + * 3. Move ahead by BM steps until the column axis and the non column + * reductions are exhausted. + * 6. If BM == 32 then transpose in SM and simd reduce the running totals. + * Otherwise write in shared memory and BN threads accumulate the running + * totals with a loop. + * 7. 
Write them to the output + */ +template +[[kernel]] void col_reduce_looped( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& reduction_stride [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant size_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant size_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + const constant size_t& non_col_reductions [[buffer(10)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + Op op; + constexpr int n_simdgroups = 4; + constexpr short tgp_size = n_simdgroups * simd_size; + constexpr short n_reads = (BM * BN) / tgp_size; + constexpr short n_read_blocks = BN / n_reads; + + threadgroup U shared_vals[BN * BM]; + U totals[n_reads]; + looped_elem_to_loc loop; + const device T* row; + + for (int i = 0; i < n_reads; i++) { + totals[i] = Op::init; + } + + short lid = simd_group_id * simd_size + simd_lane_id; + short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks); + size_t column = BN * gid.x + offset.x; + bool safe = column + n_reads <= reduction_stride; + + size_t out_idx = gid.y + gsize.y * size_t(gid.z); + size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim); + in += in_idx + column; + + size_t total = non_col_reductions * reduction_size; + loop.next(offset.y, reduce_shape, reduce_strides); + for (size_t r = offset.y; r < total; r += BM) { + row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + + if (safe) { + for (int i = 0; i < n_reads; i++) { + totals[i] = op(static_cast(row[i]), totals[i]); + } + } else { + U vals[n_reads]; + for (int i = 0; i < n_reads; i++) { + vals[i] = + (column + i < reduction_stride) ? static_cast(row[i]) : op.init; + } + for (int i = 0; i < n_reads; i++) { + totals[i] = op(vals[i], totals[i]); + } + } + + loop.next(BM, reduce_shape, reduce_strides); + } + + // We can use a simd reduction to accumulate across BM so each thread writes + // the partial output to SM and then each simdgroup does BN / n_simdgroups + // accumulations. + if (BM == 32) { + constexpr int n_outputs = BN / n_simdgroups; + static_assert( + BM != 32 || n_outputs == n_reads, + "The tile should be selected such that n_outputs == n_reads"); + for (int i = 0; i < n_reads; i++) { + shared_vals[offset.y * BN + offset.x + i] = totals[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + short2 out_offset(simd_group_id * n_outputs, simd_lane_id); + for (int i = 0; i < n_outputs; i++) { + totals[i] = + op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]); + } + + // Write the output. + if (simd_lane_id == 0) { + size_t out_column = BN * gid.x + out_offset.x; + out += out_idx * reduction_stride + out_column; + if (out_column + n_outputs <= reduction_stride) { + for (int i = 0; i < n_outputs; i++) { + out[i] = totals[i]; + } + } else { + for (int i = 0; out_column + i < reduction_stride; i++) { + out[i] = totals[i]; + } + } + } + } + + // Each thread holds n_reads partial results. We write them all out to shared + // memory and threads with offset.y == 0 aggregate the columns and write the + // outputs. 
+ else { + short x_block = offset.x / n_reads; + for (int i = 0; i < n_reads; i++) { + shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (offset.y == 0) { + for (int i = 0; i < n_reads; i++) { + for (int j = 1; j < BM; j++) { + totals[i] = + op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]); + } + } + } + + // Write the output. + if (offset.y == 0) { + out += out_idx * reduction_stride + column; + if (safe) { + for (int i = 0; i < n_reads; i++) { + out[i] = totals[i]; + } + } else { + for (int i = 0; column + i < reduction_stride; i++) { + out[i] = totals[i]; + } + } + } + } +} diff --git a/Source/Cmlx/mlx-generated/metal/reduction/reduce_init.h b/Source/Cmlx/mlx-generated/metal/reduction/reduce_init.h new file mode 100644 index 00000000..604efa78 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/reduction/reduce_init.h @@ -0,0 +1,8 @@ +// Copyright © 2023-2024 Apple Inc. + +template +[[kernel]] void init_reduce( + device T* out [[buffer(0)]], + uint tid [[thread_position_in_grid]]) { + out[tid] = Op::init; +} diff --git a/Source/Cmlx/mlx-generated/metal/reduction/reduce_row.h b/Source/Cmlx/mlx-generated/metal/reduction/reduce_row.h new file mode 100644 index 00000000..af8a01da --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/reduction/reduce_row.h @@ -0,0 +1,366 @@ +// Copyright © 2023-2024 Apple Inc. + +// Row reduction utilities +// - `per_thread_row_reduce` collaborative partial reduction in the threadgroup +// - `threadgroup_reduce` collaborative reduction in the threadgroup such that +// lid.x == 0 holds the reduced value +// - `thread_reduce` simple loop and reduce the row + +/** + * The thread group collaboratively reduces across the rows with bounds + * checking. In the end each thread holds a part of the reduction. + */ +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +METAL_FUNC void per_thread_row_reduce( + thread U totals[N_WRITES], + const device T* inputs[N_WRITES], + int blocks, + int extra, + uint lsize_x, + uint lid_x) { + Op op; + + // Set up the accumulator registers + for (int i = 0; i < N_WRITES; i++) { + totals[i] = Op::init; + } + + // Loop over the reduction size within thread group + for (int i = 0; i < blocks; i++) { + for (int j = 0; j < N_WRITES; j++) { + for (int i = 0; i < N_READS; i++) { + totals[j] = op(static_cast(inputs[j][i]), totals[j]); + } + + inputs[j] += lsize_x * N_READS; + } + } + + // Separate case for the last set as we close the reduction size + int index = lid_x * N_READS; + if (index + N_READS <= extra) { + for (int j = 0; j < N_WRITES; j++) { + for (int i = 0; i < N_READS; i++) { + totals[j] = op(static_cast(inputs[j][i]), totals[j]); + } + } + } else { + for (int j = 0; j < N_WRITES; j++) { + for (int i = 0; index + i < extra; i++) { + totals[j] = op(static_cast(inputs[j][i]), totals[j]); + } + } + } +} + +/** + * Consecutive rows in a contiguous array. 
+ */ +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +METAL_FUNC void per_thread_row_reduce( + thread U totals[N_WRITES], + const device T* in, + const constant size_t& reduction_size, + int blocks, + int extra, + uint lsize_x, + uint lid_x) { + // Set up the input pointers + const device T* inputs[N_WRITES]; + inputs[0] = in + lid_x * N_READS; + for (int i = 1; i < N_READS; i++) { + inputs[i] = inputs[i - 1] + reduction_size; + } + + per_thread_row_reduce( + totals, inputs, blocks, extra, lsize_x, lid_x); +} + +/** + * Consecutive rows in an arbitrarily ordered array. + */ +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +METAL_FUNC void per_thread_row_reduce( + thread U totals[N_WRITES], + const device T* in, + const size_t row_idx, + int blocks, + int extra, + const constant int* shape, + const constant size_t* strides, + const constant int& ndim, + uint lsize_x, + uint lid_x) { + // Set up the input pointers + const device T* inputs[N_WRITES]; + in += lid_x * N_READS; + for (int i = 0; i < N_READS; i++) { + inputs[i] = in + elem_to_loc(row_idx + i, shape, strides, ndim); + } + + per_thread_row_reduce( + totals, inputs, blocks, extra, lsize_x, lid_x); +} + +/** + * Reduce within the threadgroup. + */ +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +METAL_FUNC void threadgroup_reduce( + thread U totals[N_WRITES], + threadgroup U* shared_vals, + uint3 lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + Op op; + + // Simdgroup first + for (int i = 0; i < N_WRITES; i++) { + totals[i] = op.simd_reduce(totals[i]); + } + + // Across simdgroups + if (simd_per_group > 1) { + if (simd_lane_id == 0) { + for (int i = 0; i < N_WRITES; i++) { + shared_vals[simd_group_id * N_WRITES + i] = totals[i]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + U values[N_WRITES]; + for (int i = 0; i < N_WRITES; i++) { + values[i] = (lid.x < simd_per_group) ? shared_vals[lid.x * N_WRITES + i] + : op.init; + } + + for (int i = 0; i < N_WRITES; i++) { + totals[i] = op.simd_reduce(values[i]); + } + } +} + +template +METAL_FUNC void +thread_reduce(thread U& total, const device T* row, int blocks, int extra) { + Op op; + for (int i = 0; i < blocks; i++) { + U vals[N_READS]; + for (int j = 0; j < N_READS; j++) { + vals[j] = row[j]; + } + for (int j = 0; j < N_READS; j++) { + total = op(vals[j], total); + } + row += N_READS; + } + for (int i = 0; i < extra; i++) { + total = op(*row++, total); + } +} + +// Reduction kernels +// - `row_reduce_small` depending on the non-row reductions and row size it +// either just loops over everything or a simd collaboratively reduces the +// non_row reductions. In the first case one thread is responsible for one +// output on the 2nd one simd is responsible for one output. +// - `row_reduce_simple` simple contiguous row reduction +// - `row_reduce_looped` simply loop and reduce each row for each non-row +// reduction. One threadgroup is responsible for one output. 
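// Reference sketch (plain C++), not part of the generated kernels: the read pattern a
// single thread follows inside per_thread_row_reduce. Thread lid starts at lid * N_READS,
// strides by lsize_x * N_READS for `blocks` full iterations, then handles its share of the
// ragged tail guarded by `extra`. Shown for a sum; the kernels' Op supplies the init value
// and the combining operator.
#include <cstddef>
#include <vector>

float per_thread_row_reduce_reference(const std::vector<float>& row, int blocks, int extra,
                                      int lsize_x, int lid_x, int n_reads) {
  float total = 0.0f;                                    // Op::init for Sum
  size_t idx = size_t(lid_x) * n_reads;                  // this thread's first element
  for (int b = 0; b < blocks; b++) {
    for (int i = 0; i < n_reads; i++) total += row[idx + i];
    idx += size_t(lsize_x) * n_reads;                    // all threads advance together
  }
  for (int i = 0; i < n_reads && lid_x * n_reads + i < extra; i++) {
    total += row[idx + i];                               // tail: only threads inside `extra` read
  }
  return total;                                          // threadgroup_reduce combines these totals
}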
+ +template < + typename T, + typename U, + typename Op, + int NDIMS, + int N_READS = REDUCE_N_READS> +[[kernel]] void row_reduce_small( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& row_size [[buffer(2)]], + const constant size_t& non_row_reductions [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant size_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant size_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 tid [[thread_position_in_grid]], + uint3 tsize [[threads_per_grid]]) { + Op op; + + U total_val = Op::init; + looped_elem_to_loc loop; + + // Precompute some row reduction numbers + const device T* row; + int blocks = row_size / N_READS; + int extra = row_size % N_READS; + + if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) { + // Simple loop over non_row_reductions and reduce the row in the thread. + size_t out_idx = tid.x + tsize.y * size_t(tid.y); + in += elem_to_loc(out_idx, shape, strides, ndim); + + for (uint r = 0; r < non_row_reductions; r++) { + row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + thread_reduce(total_val, row, blocks, extra); + loop.next(reduce_shape, reduce_strides); + } + + out[out_idx] = total_val; + } else { + // Collaboratively reduce over non_row_reductions in the simdgroup. Each + // thread reduces every 32nd row and then a simple simd reduce. + size_t out_idx = gid.y + gsize.y * size_t(gid.z); + in += elem_to_loc(out_idx, shape, strides, ndim); + + loop.next(simd_lane_id, reduce_shape, reduce_strides); + + for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) { + row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + thread_reduce(total_val, row, blocks, extra); + loop.next(simd_size, reduce_shape, reduce_strides); + } + + total_val = op.simd_reduce(total_val); + + if (simd_lane_id == 0) { + out[out_idx] = total_val; + } + } +} + +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +[[kernel]] void row_reduce_simple( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& out_size [[buffer(3)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + threadgroup U shared_vals[simd_size * N_WRITES]; + U totals[N_WRITES]; + + // Move to the row + size_t out_idx = N_WRITES * (gid.y + gsize.y * size_t(gid.z)); + if (out_idx + N_WRITES > out_size) { + out_idx = out_size - N_WRITES; + } + in += out_idx * reduction_size; + out += out_idx; + + // Each thread reduces across the row + int blocks = reduction_size / (lsize.x * N_READS); + int extra = reduction_size - blocks * (lsize.x * N_READS); + per_thread_row_reduce( + totals, in, reduction_size, blocks, extra, lsize.x, lid.x); + + // Reduce across the threadgroup + threadgroup_reduce( + totals, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id); + + // 
Write the output + if (lid.x == 0) { + for (int i = 0; i < N_WRITES; i++) { + out[i] = totals[i]; + } + } +} + +template < + typename T, + typename U, + typename Op, + int NDIMS, + int N_READS = REDUCE_N_READS> +[[kernel]] void row_reduce_looped( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& row_size [[buffer(2)]], + const constant size_t& non_row_reductions [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant size_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant size_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + Op op; + threadgroup U shared_vals[simd_size]; + U total = Op::init; + + size_t out_idx = gid.y + gsize.y * size_t(gid.z); + + // lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it + // needs a small refactor. + in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS; + + looped_elem_to_loc loop; + const device T* row; + int blocks = row_size / (lsize.x * N_READS); + int extra = row_size - blocks * (lsize.x * N_READS); + + for (size_t i = 0; i < non_row_reductions; i++) { + row = in + loop.location(i, reduce_shape, reduce_strides, reduce_ndim); + + // Each thread reduces across the row + U row_total; + per_thread_row_reduce( + &row_total, &row, blocks, extra, lsize.x, lid.x); + + // Aggregate across rows + total = op(total, row_total); + + loop.next(reduce_shape, reduce_strides); + } + + // Reduce across the threadgroup + threadgroup_reduce( + &total, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id); + + // Write the output + if (lid.x == 0) { + out[out_idx] = total; + } +} diff --git a/Source/Cmlx/mlx-generated/metal/rms_norm.metal b/Source/Cmlx/mlx-generated/metal/rms_norm.metal new file mode 100644 index 00000000..9b52a986 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/rms_norm.metal @@ -0,0 +1,440 @@ +// Copyright © 2024 Apple Inc. 
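// Reference sketch (plain C++), not part of the generated kernels: the tail handling in
// row_reduce_simple above. Rather than masking, the last threadgroup slides its output
// window back so it always produces exactly N_WRITES outputs; the overlapping outputs
// near the end are computed more than once, but every writer produces the same value.
#include <cstdio>

int main() {
  const int N_WRITES = 4;
  const int out_size = 10;  // not a multiple of N_WRITES
  for (int group = 0; group * N_WRITES < out_size; group++) {
    int out_idx = group * N_WRITES;
    if (out_idx + N_WRITES > out_size) out_idx = out_size - N_WRITES;  // slide back
    std::printf("threadgroup %d writes outputs [%d, %d)\n", group, out_idx, out_idx + N_WRITES);
  }
}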
+ +#include +#include + +#include "bf16.h" +#include "defines.h" +#include "utils.h" + +using namespace metal; + +template +[[kernel]] void rms_single_row( + const device T* x, + const device T* w, + device T* out, + constant float& eps, + constant uint& axis_size, + constant uint& w_stride, + threadgroup float* local_inv_mean [[threadgroup(0)]], + threadgroup float* local_sums [[threadgroup(1)]], + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + float acc = 0; + x += gid * size_t(axis_size) + lid * N_READS; + w += w_stride * lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float xi = x[i]; + acc += xi * xi; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((lid * N_READS + i) < axis_size) { + float xi = x[i]; + acc += xi * xi; + } + } + } + acc = simd_sum(acc); + // Initialize shared memory + if (simd_group_id == 0) { + local_sums[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write simd accumulations into shared memory + if (simd_lane_id == 0) { + local_sums[simd_group_id] = acc; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Accumulate over simd groups + if (simd_group_id == 0) { + acc = simd_sum(local_sums[simd_lane_id]); + if (simd_lane_id == 0) { + local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + eps); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write the outputs + out += gid * size_t(axis_size) + lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + out[i] = w[w_stride * i] * static_cast(x[i] * local_inv_mean[0]); + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((lid * N_READS + i) < axis_size) { + out[i] = w[w_stride * i] * static_cast(x[i] * local_inv_mean[0]); + } + } + } +} + +template +[[kernel]] void rms_looped( + const device T* x, + const device T* w, + device T* out, + constant float& eps, + constant uint& axis_size, + constant uint& w_stride, + threadgroup float* local_inv_mean [[threadgroup(0)]], + threadgroup float* local_sums [[threadgroup(1)]], + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + float acc = 0; + x += gid * size_t(axis_size) + lid * N_READS; + w += w_stride * lid * N_READS; + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float xi = x[i + r]; + acc += xi * xi; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + float xi = x[i + r]; + acc += xi * xi; + } + } + } + } + acc = simd_sum(acc); + // Initialize shared memory + if (simd_group_id == 0) { + local_sums[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write simd accumulations into shared memory + if (simd_lane_id == 0) { + local_sums[simd_group_id] = acc; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Accumulate over simd groups + if (simd_group_id == 0) { + acc = simd_sum(local_sums[simd_lane_id]); + if (simd_lane_id == 0) { + local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + eps); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write the 
outputs + out += gid * size_t(axis_size) + lid * N_READS; + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + out[r + i] = w[w_stride * (i + r)] * + static_cast(x[r + i] * local_inv_mean[0]); + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + out[r + i] = w[w_stride * (i + r)] * + static_cast(x[r + i] * local_inv_mean[0]); + } + } + } + } +} + +template +[[kernel]] void vjp_rms_single_row( + const device T* x, + const device T* w, + const device T* g, + device T* gx, + device T* gw, + constant float& eps, + constant uint& axis_size, + constant uint& w_stride, + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + // Advance the input pointers + x += gid * size_t(axis_size) + lid * N_READS; + g += gid * size_t(axis_size) + lid * N_READS; + w += w_stride * lid * N_READS; + + // Allocate registers for the computation and accumulators + float thread_x[N_READS]; + float thread_w[N_READS]; + float thread_g[N_READS]; + float sumx2 = 0; + float sumgwx = 0; + + // Allocate shared memory to implement the reduction + constexpr int SIMD_SIZE = 32; + threadgroup float local_sumx2[SIMD_SIZE]; + threadgroup float local_sumgwx[SIMD_SIZE]; + threadgroup float local_normalizer[1]; + threadgroup float local_meangwx[1]; + + // Read and accumulate locally + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + thread_x[i] = x[i]; + thread_w[i] = w[w_stride * i]; + thread_g[i] = g[i]; + + sumx2 += thread_x[i] * thread_x[i]; + sumgwx += thread_x[i] * thread_w[i] * thread_g[i]; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((lid * N_READS + i) < axis_size) { + thread_x[i] = x[i]; + thread_w[i] = w[w_stride * i]; + thread_g[i] = g[i]; + + sumx2 += thread_x[i] * thread_x[i]; + sumgwx += thread_x[i] * thread_w[i] * thread_g[i]; + } + } + } + + // Accumulate across threads + sumx2 = simd_sum(sumx2); + sumgwx = simd_sum(sumgwx); + if (simd_group_id == 0) { + local_sumx2[simd_lane_id] = 0; + local_sumgwx[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_lane_id == 0) { + local_sumx2[simd_group_id] = sumx2; + local_sumgwx[simd_group_id] = sumgwx; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + sumx2 = simd_sum(local_sumx2[simd_lane_id]); + sumgwx = simd_sum(local_sumgwx[simd_lane_id]); + if (simd_lane_id == 0) { + local_meangwx[0] = sumgwx / axis_size; + local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + float meangwx = local_meangwx[0]; + float normalizer = local_normalizer[0]; + float normalizer3 = normalizer * normalizer * normalizer; + + // Write the outputs + gx += gid * size_t(axis_size) + lid * N_READS; + gw += gid * size_t(axis_size) + lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + gx[i] = static_cast( + thread_g[i] * thread_w[i] * normalizer - + thread_x[i] * meangwx * normalizer3); + gw[i] = static_cast(thread_g[i] * thread_x[i] * normalizer); + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((lid * N_READS + i) < axis_size) { + gx[i] = static_cast( + thread_g[i] * thread_w[i] * normalizer - + thread_x[i] * meangwx * normalizer3); + gw[i] = static_cast(thread_g[i] * 
thread_x[i] * normalizer); + } + } + } +} + +template +[[kernel]] void vjp_rms_looped( + const device T* x, + const device T* w, + const device T* g, + device T* gx, + device T* gw, + constant float& eps, + constant uint& axis_size, + constant uint& w_stride, + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + // Advance the input pointers + x += gid * size_t(axis_size) + lid * N_READS; + g += gid * size_t(axis_size) + lid * N_READS; + w += w_stride * lid * N_READS; + + // Allocate registers for the accumulators + float sumx2 = 0; + float sumgwx = 0; + + // Allocate shared memory to implement the reduction + constexpr int SIMD_SIZE = 32; + threadgroup float local_sumx2[SIMD_SIZE]; + threadgroup float local_sumgwx[SIMD_SIZE]; + threadgroup float local_normalizer[1]; + threadgroup float local_meangwx[1]; + + // Read and accumulate locally + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float xi = x[i + r]; + float wi = w[w_stride * (i + r)]; + float gi = g[i + r]; + + sumx2 += xi * xi; + sumgwx += xi * wi * gi; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + float xi = x[i + r]; + float wi = w[w_stride * (i + r)]; + float gi = g[i + r]; + + sumx2 += xi * xi; + sumgwx += xi * wi * gi; + } + } + } + } + + // Accumulate across threads + sumx2 = simd_sum(sumx2); + sumgwx = simd_sum(sumgwx); + if (simd_group_id == 0) { + local_sumx2[simd_lane_id] = 0; + local_sumgwx[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_lane_id == 0) { + local_sumx2[simd_group_id] = sumx2; + local_sumgwx[simd_group_id] = sumgwx; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + sumx2 = simd_sum(local_sumx2[simd_lane_id]); + sumgwx = simd_sum(local_sumgwx[simd_lane_id]); + if (simd_lane_id == 0) { + local_meangwx[0] = sumgwx / axis_size; + local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + float meangwx = local_meangwx[0]; + float normalizer = local_normalizer[0]; + float normalizer3 = normalizer * normalizer * normalizer; + + // Write the outputs + gx += gid * size_t(axis_size) + lid * N_READS; + gw += gid * size_t(axis_size) + lid * N_READS; + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float xi = x[i + r]; + float wi = w[w_stride * (i + r)]; + float gi = g[i + r]; + + gx[i + r] = + static_cast(gi * wi * normalizer - xi * meangwx * normalizer3); + gw[i + r] = static_cast(gi * xi * normalizer); + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + float xi = x[i + r]; + float wi = w[w_stride * (i + r)]; + float gi = g[i + r]; + + gx[i + r] = + static_cast(gi * wi * normalizer - xi * meangwx * normalizer3); + gw[i + r] = static_cast(gi * xi * normalizer); + } + } + } + } +} + +// clang-format off +#define instantiate_rms_single_row(name, itype) \ + template [[host_name("rms" #name)]] [[kernel]] void \ + rms_single_row( \ + const device itype* x, \ + const device itype* w, \ + device itype* out, \ + constant float& eps, \ + constant uint& axis_size, \ + constant uint& w_stride, \ + 
threadgroup float* local_inv_mean [[threadgroup(0)]], \ + threadgroup float* local_sums [[threadgroup(1)]], \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ + \ + template [[host_name("vjp_rms" #name)]] [[kernel]] void \ + vjp_rms_single_row( \ + const device itype* x, \ + const device itype* w, \ + const device itype* g, \ + device itype* gx, \ + device itype* gw, \ + constant float& eps, \ + constant uint& axis_size, \ + constant uint& w_stride, \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + +#define instantiate_rms_looped(name, itype) \ + template [[host_name("rms_looped" #name)]] [[kernel]] void \ + rms_looped( \ + const device itype* x, \ + const device itype* w, \ + device itype* out, \ + constant float& eps, \ + constant uint& axis_size, \ + constant uint& w_stride, \ + threadgroup float* local_inv_mean [[threadgroup(0)]], \ + threadgroup float* local_sums [[threadgroup(1)]], \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint lsize [[threads_per_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ + \ + template [[host_name("vjp_rms_looped" #name)]] [[kernel]] void \ + vjp_rms_looped( \ + const device itype* x, \ + const device itype* w, \ + const device itype* g, \ + device itype* gx, \ + device itype* gw, \ + constant float& eps, \ + constant uint& axis_size, \ + constant uint& w_stride, \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint lsize [[threads_per_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + +#define instantiate_rms(name, itype) \ + instantiate_rms_single_row(name, itype) \ + instantiate_rms_looped(name, itype) + +instantiate_rms(float32, float) +instantiate_rms(float16, half) +instantiate_rms(bfloat16, bfloat16_t) // clang-format on diff --git a/Source/Cmlx/mlx-generated/metal/rope.metal b/Source/Cmlx/mlx-generated/metal/rope.metal new file mode 100644 index 00000000..cc9a4648 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/rope.metal @@ -0,0 +1,261 @@ +// Copyright © 2023-2024 Apple Inc. 
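[Editor's note] rope.metal applies rotary position embeddings: each feature pair (x1, x2) is rotated by an angle `theta = scale * position * inv_freq`, with the pair taken either as adjacent elements (traditional layout) or as elements split across the two halves of the feature dimension, and the VJP kernels apply the inverse rotation. The CPU sketch below shows that pairing and rotation for one token; the name, the caller-supplied `inv_freqs` table, and the in-place interface are assumptions for illustration, not the kernel signature.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Rotate one token's feature vector in place.
// x: feature vector of even length D; inv_freqs: D/2 inverse frequencies
// (the kernels derive these from `base` or from a `freqs` buffer).
// `traditional` selects interleaved pairs (2i, 2i+1) vs. split halves (i, i+D/2).
// `forward = false` applies the inverse rotation (the VJP path).
void rope_reference(std::vector<float>& x,
                    const std::vector<float>& inv_freqs,
                    int position,
                    float scale,
                    bool traditional,
                    bool forward) {
  const std::size_t half = x.size() / 2;
  for (std::size_t i = 0; i < half; ++i) {
    const float theta = scale * static_cast<float>(position) * inv_freqs[i];
    const float c = std::cos(theta);
    const float s = std::sin(theta);
    const std::size_t i1 = traditional ? 2 * i : i;
    const std::size_t i2 = traditional ? 2 * i + 1 : i + half;
    const float x1 = x[i1];
    const float x2 = x[i2];
    if (forward) {
      x[i1] = x1 * c - x2 * s;
      x[i2] = x1 * s + x2 * c;
    } else {
      x[i1] = x2 * s + x1 * c;
      x[i2] = x2 * c - x1 * s;
    }
  }
}
```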
+ +#include + +#include "bf16.h" +#include "utils.h" +template +void rope_single_impl( + const device T* in, + device T* out, + constant const int& offset, + const float inv_freq, + constant const float& scale, + constant const size_t& stride, + uint2 pos, + uint2 grid) { + float L = scale * static_cast(offset); + + // Compute costheta, sintheta + float theta = L * inv_freq; + float costheta = metal::fast::cos(theta); + float sintheta = metal::fast::sin(theta); + + // Compute the input and output indices + uint index_1, index_2; + if (traditional) { + index_1 = 2 * pos.x + pos.y * stride; + index_2 = index_1 + 1; + } else { + index_1 = pos.x + pos.y * stride; + index_2 = index_1 + grid.x; + } + + // Read and write the output + float x1 = static_cast(in[index_1]); + float x2 = static_cast(in[index_2]); + float rx1; + float rx2; + if (forward) { + rx1 = x1 * costheta - x2 * sintheta; + rx2 = x1 * sintheta + x2 * costheta; + } else { + rx1 = x2 * sintheta + x1 * costheta; + rx2 = x2 * costheta - x1 * sintheta; + } + out[index_1] = static_cast(rx1); + out[index_2] = static_cast(rx2); +} + +template +[[kernel]] void rope_single( + const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], + constant const int& offset, + constant const float& scale, + constant const size_t& stride, + constant const float& base [[buffer(10)]], + uint2 pos [[thread_position_in_grid]], + uint2 grid [[threads_per_grid]]) { + float d = static_cast(pos.x) / static_cast(grid.x); + float inv_freq = metal::exp2(-d * base); + rope_single_impl( + in, out, offset, inv_freq, scale, stride, pos, grid); +} + +template +[[kernel]] void rope_single_freqs( + const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], + constant const int& offset, + constant const float& scale, + constant const size_t& stride, + const device float* freqs [[buffer(10)]], + constant const size_t& freq_stride [[buffer(11)]], + uint2 pos [[thread_position_in_grid]], + uint2 grid [[threads_per_grid]]) { + float inv_freq = 1.0 / (freqs[freq_stride * pos.x]); + rope_single_impl( + in, out, offset, inv_freq, scale, stride, pos, grid); +} + +template +void rope_impl( + const device T* in, + device T* out, + constant const int& offset, + const float inv_freq, + constant const float& scale, + constant const size_t strides[3], + constant const size_t out_strides[3], + constant const size_t& n_batch, + uint3 pos, + uint3 grid) { + float L = scale * static_cast(pos.y + offset); + + // Compute costheta, sintheta + float theta = L * inv_freq; + float costheta = metal::fast::cos(theta); + float sintheta = metal::fast::sin(theta); + + // Compute the input and output indices + size_t in_index_1, in_index_2; + size_t out_index_1, out_index_2; + if (traditional) { + out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + + N * pos.z * out_strides[0]; + out_index_2 = out_index_1 + 1; + in_index_1 = + 2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0]; + in_index_2 = in_index_1 + strides[2]; + } else { + out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + + N * pos.z * out_strides[0]; + out_index_2 = out_index_1 + grid.x * out_strides[2]; + in_index_1 = + pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0]; + in_index_2 = in_index_1 + grid.x * strides[2]; + } + for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) { + // Read and write the output + float x1 = static_cast(in[in_index_1]); + float x2 = static_cast(in[in_index_2]); + float rx1; + float rx2; + if (forward) { + rx1 = x1 * costheta - x2 * 
sintheta; + rx2 = x1 * sintheta + x2 * costheta; + } else { + rx1 = x2 * sintheta + x1 * costheta; + rx2 = x2 * costheta - x1 * sintheta; + } + out[out_index_1] = static_cast(rx1); + out[out_index_2] = static_cast(rx2); + in_index_1 += strides[0]; + in_index_2 += strides[0]; + out_index_1 += out_strides[0]; + out_index_2 += out_strides[0]; + } +} + +template +[[kernel]] void rope( + const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], + constant const int& offset, + constant const float& scale, + constant const size_t strides[3], + constant const size_t out_strides[3], + constant const size_t& n_batch, + constant const float& base [[buffer(10)]], + uint3 pos [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + float d = static_cast(pos.x) / static_cast(grid.x); + float inv_freq = metal::exp2(-d * base); + rope_impl( + in, + out, + offset, + inv_freq, + scale, + strides, + out_strides, + n_batch, + pos, + grid); +} + +template +[[kernel]] void rope_freqs( + const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], + constant const int& offset, + constant const float& scale, + constant const size_t strides[3], + constant const size_t out_strides[3], + constant const size_t& n_batch, + const device float* freqs [[buffer(10)]], + constant const size_t& freq_stride [[buffer(11)]], + uint3 pos [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + float inv_freq = 1.0 / (freqs[freq_stride * pos.x]); + rope_impl( + in, + out, + offset, + inv_freq, + scale, + strides, + out_strides, + n_batch, + pos, + grid); +} + +// clang-format off +#define instantiate_rope_g(name, type, traditional, forward) \ + template [[host_name("rope_" #name)]] [[kernel]] void \ + rope( \ + const device type* in [[buffer(0)]], \ + device type* out [[buffer(1)]], \ + constant const int& offset, \ + constant const float& scale, \ + constant const size_t strides[3], \ + constant const size_t out_strides[3], \ + constant const size_t& n_batch, \ + constant const float& base [[buffer(10)]], \ + uint3 pos [[thread_position_in_grid]], \ + uint3 grid [[threads_per_grid]]); \ + template [[host_name("rope_freqs_" #name)]] \ + [[kernel]] void rope_freqs( \ + const device type* in [[buffer(0)]], \ + device type* out [[buffer(1)]], \ + constant const int& offset, \ + constant const float& scale, \ + constant const size_t strides[3], \ + constant const size_t out_strides[3], \ + constant const size_t& n_batch, \ + const device float* freqs [[buffer(10)]], \ + constant const size_t& freq_stride [[buffer(11)]], \ + uint3 pos [[thread_position_in_grid]], \ + uint3 grid [[threads_per_grid]]); + +#define instantiate_rope_s(name, type, traditional, forward) \ + template [[host_name("rope_single_" #name)]] [[kernel]] void \ + rope_single( \ + const device type* in [[buffer(0)]], \ + device type* out [[buffer(1)]], \ + constant const int& offset, \ + constant const float& scale, \ + constant const size_t& stride, \ + constant const float& base [[buffer(10)]], \ + uint2 pos [[thread_position_in_grid]], \ + uint2 grid [[threads_per_grid]]); \ + template [[host_name("rope_single_freqs_" #name)]] \ + [[kernel]] void rope_single_freqs( \ + const device type* in [[buffer(0)]], \ + device type* out [[buffer(1)]], \ + constant const int& offset, \ + constant const float& scale, \ + constant const size_t& stride, \ + const device float* freqs [[buffer(10)]], \ + constant const size_t& freq_stride [[buffer(11)]], \ + uint2 pos [[thread_position_in_grid]], \ + uint2 grid [[threads_per_grid]]); + +#define 
instantiate_rope(name, type, traditional, forward) \ + instantiate_rope_s(name, type, traditional, forward) \ + instantiate_rope_g(name, type, traditional, forward) + +instantiate_rope(traditional_float16, half, true, true) +instantiate_rope(traditional_bfloat16, bfloat16_t, true, true) +instantiate_rope(traditional_float32, float, true, true) +instantiate_rope(float16, half, false, true) +instantiate_rope(bfloat16, bfloat16_t, false, true) +instantiate_rope(float32, float, false, true) +instantiate_rope(vjp_traditional_float16, half, true, false) +instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false) +instantiate_rope(vjp_traditional_float32, float, true, false) +instantiate_rope(vjp_float16, half, false, false) +instantiate_rope(vjp_bfloat16, bfloat16_t, false, false) +instantiate_rope(vjp_float32, float, false, false) // clang-format on diff --git a/Source/Cmlx/mlx-generated/metal/scaled_dot_product_attention.metal b/Source/Cmlx/mlx-generated/metal/scaled_dot_product_attention.metal new file mode 100644 index 00000000..8ec0a579 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/scaled_dot_product_attention.metal @@ -0,0 +1,1469 @@ +#include +#include + +#include "steel/defines.h" +#include "steel/gemm/transforms.h" +#include "steel/utils.h" + +#include "scaled_dot_product_attention_params.h" +using namespace metal; + +using namespace mlx::steel; + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short alignment = 1, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoaderFA { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; + + // Leading dimension for src + const int src_ld; + const int tile_stride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + struct alignas(alignment * sizeof(T)) ReadVector { + uint8_t v[sizeof(T) * vec_size]; + }; + + /* Constructor */ + METAL_FUNC BlockLoaderFA( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride(reduction_dim ? 
BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj) {} + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + *((threadgroup ReadVector*)(&dst[i * dst_ld])) = + *((const device ReadVector*)(&src[i * src_ld])); + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); + + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + return; + } + + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } + + // Zero out uneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } + METAL_FUNC void next(short n) { + src += n * tile_stride; + } +}; + +template +struct LoopAlignment {}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + short lda_tgp, + short ldb_tgp, + typename AccumType = float, + typename Epilogue = TransformNone> +struct BlockMMAFA { + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TM_stride = 8 * WM; + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TN_stride = 8 * WN; + + // Warp tile size along M + STEEL_CONST short TM = BM / TM_stride; + // Warp tile size along N + STEEL_CONST short TN = BN / TN_stride; + + // Strides of A, B along reduction axis + STEEL_CONST short simd_stride_a = { + transpose_a ? TM_stride : TM_stride * lda_tgp}; + STEEL_CONST short simd_stride_b = { + transpose_b ? TN_stride * ldb_tgp : TN_stride}; + + // Jump between elements + STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1}; + STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1}; + + STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8}; + STEEL_CONST short tile_stride_b = {transpose_b ? 
8 : 8 * ldb_tgp}; + + // Simdgroup matrices + simdgroup_matrix Asimd[TM]; + simdgroup_matrix Bsimd[TN]; + simdgroup_matrix results[TM * TN] = { + simdgroup_matrix(0)}; + + // Offsets within threadgroup + const short tm; + const short tn; + + short sm; + short sn; + + ushort sid; + ushort slid; + + short As_offset; + short Bs_offset; + + /* Constructor */ + METAL_FUNC BlockMMAFA( + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { + // Determine thread position in simdgroup matrix + short qid = simd_lane_id / 4; + slid = simd_lane_id; + sid = simd_group_id; + + sm = (qid & 4) + (simd_lane_id / 2) % 4; + sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + + // Determine thread and simdgroup offset + As_offset = + transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp); + Bs_offset = + transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn)); + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { + // Adjust for simdgroup and thread location + As += As_offset; + Bs += Bs_offset; + + // Iterate over BK in blocks of 8 + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < BK; kk += 8) { + simdgroup_barrier(mem_flags::mem_none); + + // Load elements from threadgroup A as simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + Asimd[i].thread_elements()[0] = + static_cast(As[i * simd_stride_a + 0]); + Asimd[i].thread_elements()[1] = + static_cast(As[i * simd_stride_a + jump_a]); + } + + simdgroup_barrier(mem_flags::mem_none); + + // Load elements from threadgroup B as simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + Bsimd[j].thread_elements()[0] = + static_cast(Bs[j * simd_stride_b + 0]); + Bsimd[j].thread_elements()[1] = + static_cast(Bs[j * simd_stride_b + jump_b]); + } + + simdgroup_barrier(mem_flags::mem_none); + + // Multiply and accumulate into result simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + short j_serp = (i % 2) ? 
(TN - 1 - j) : j; + + simdgroup_multiply_accumulate( + results[i * TN + j_serp], + Asimd[i], + Bsimd[j_serp], + results[i * TN + j_serp]); + } + } + + // Progress to next simdgroup tile + As += tile_stride_a; + Bs += tile_stride_b; + } + } + + METAL_FUNC void rescale_output(const threadgroup float* Corrections) { + // Loop over all simdgroup tiles + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + short row = sm + tm + i * TM_stride; + float scale_value = Corrections[row]; + + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = results[i * TN + j].thread_elements(); + // int offset = (i * TM_stride) * ldc + (j * TN_stride); + accum[0] *= scale_value; + accum[1] *= scale_value; + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device U* C, const int ldc) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + tn + sn; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset = (i * TM_stride) * ldc + (j * TN_stride); + + // Apply epilogue + U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])}; + + // Write out C + C[offset] = outs[0]; + C[offset + 1] = outs[1]; + } + } + } + + METAL_FUNC void store_result_to_tgp_memory( + threadgroup U* C, + const int ldc, + short2 dst_tile_dims) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn); + dst_tile_dims -= short2(tn + sn, sm + tm); + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset = (i * TM_stride) * ldc + (j * TN_stride); + + // Apply epilogue and output C + if (j * TN_stride < dst_tile_dims.x) { + C[offset] = Epilogue::apply(accum[0]); + } + + if (j * TN_stride + 1 < dst_tile_dims.x) { + C[offset + 1] = Epilogue::apply(accum[1]); + } + } + } + } + } + + METAL_FUNC void + store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn); + dst_tile_dims -= short2(tn + sn, sm + tm); + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset = (i * TM_stride) * ldc + (j * TN_stride); + + // Apply epilogue and output C + if (j * TN_stride < dst_tile_dims.x) { + C[offset] = Epilogue::apply(accum[0]); + } + + if (j * TN_stride + 1 < dst_tile_dims.x) { + C[offset + 1] = Epilogue::apply(accum[1]); + } + } + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + D += (sm + tm) * ldd + tn + sn; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 
TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + U outs[2] = { + epilogue_op.apply(accum[0], C[offset_c]), + epilogue_op.apply(accum[1], C[offset_c + fdc])}; + + // Write out D + D[offset_d] = outs[0]; + D[offset_d + 1] = outs[1]; + } + } + } + + METAL_FUNC void store_result_safe( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + D += (sm + tm) * ldd + tn + sn; + dst_tile_dims -= short2(tn + sn, sm + tm); + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue and output C + if (j * TN_stride < dst_tile_dims.x) { + D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]); + } + + if (j * TN_stride + 1 < dst_tile_dims.x) { + D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); + } + } + } + } + } + + METAL_FUNC void clear_results() { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + results[i * TN + j] = simdgroup_matrix(0); + } + } + } +}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_q, + bool transpose_k, + bool transpose_v, + bool MN_aligned, + bool K_aligned, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct FastAttentionKernel { + STEEL_CONST short tgp_padding = 16 / sizeof(T); + STEEL_CONST short float_padding = 16 / sizeof(float); + STEEL_CONST short tgp_mem_size_q = + transpose_q ? BK * (BM + tgp_padding) : BM * (BK + tgp_padding); + STEEL_CONST short tgp_mem_size_k = + transpose_k ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding); + STEEL_CONST short tgp_mem_size_v = + transpose_v ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding); + STEEL_CONST short tgp_mem_size_s = BM * (BN + tgp_padding); + + // maxes, rowsums, rescale + STEEL_CONST short tgp_mem_size_corrections = + 4 * (BM * sizeof(float) + float_padding); + + STEEL_CONST bool share_kv_smem = transpose_k != transpose_v; + + STEEL_CONST short tgp_mem_size = share_kv_smem + ? tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s + + tgp_mem_size_corrections + : tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s + + tgp_mem_size_corrections + tgp_mem_size_v; + + STEEL_CONST short tgp_size = WM * WN * 32; + + static_assert(transpose_q == false, "Expected Q not transposed."); + static_assert(transpose_k == true, "Expected K transposed."); + static_assert(transpose_v == false, "Expected V not transposed."); + static_assert(tgp_mem_size <= 32768, "Excessive tgp memory requested."); + + using loader_q_t = BlockLoaderFA< + T, + transpose_q ? BK : BM, + transpose_q ? BM : BK, + transpose_q ? BM + tgp_padding : BK + tgp_padding, + !transpose_q, + tgp_size>; + + using loader_k_t = BlockLoaderFA< + T, + transpose_k ? 
BN : BK, + transpose_k ? BK : BN, + transpose_k ? BK + tgp_padding : BN + tgp_padding, + transpose_k, + tgp_size>; + + using loader_v_t = BlockLoaderFA< + T, + transpose_v ? BK : BN, + transpose_v ? BN : BK, + transpose_v ? BN + tgp_padding : BK + tgp_padding, + transpose_v, + tgp_size>; + + using mma_qk_t = BlockMMAFA< + T, + U, + BM, + BN, + BK, + WM, + WN, + transpose_q, + transpose_k, + transpose_q ? BM + tgp_padding : BK + tgp_padding, + transpose_k ? BK + tgp_padding : BN + tgp_padding, + AccumType, + Epilogue>; + + using mma_sv_t = BlockMMAFA< + T, + U, + BM, + BK, + BN, + WM, + WN, + false, + transpose_v, + BN + tgp_padding, + BK + tgp_padding, + AccumType, + Epilogue>; + + /* Main kernel function */ + template + static METAL_FUNC void gemm_loop( + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + const int gemm_k_iterations, + thread loader_k_t& loader_b, + thread mma_qk_t& mma_op, + thread const short& tgp_bm, + thread const short& tgp_bn, + LoopAlignment l = {}) { + // Appease the compiler + (void)l; + (void)tgp_bm; + + short2 tile_dims_B = transpose_k ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + // not valid for gemm_k_iterations > 1 (so, BK == d_k) + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + } + } + + static METAL_FUNC void initialize_corrections( + threadgroup float* C, + uint simd_lane_id, + uint simd_group_id) { + if (simd_group_id == 0) { + threadgroup float* maxes = C; + threadgroup float* sums = C + (BM + float_padding); + threadgroup float* o_rescale = sums + (BM + float_padding); + threadgroup float* output_rescale = o_rescale + (BM + float_padding); + + if (simd_lane_id < BM) { + maxes[simd_lane_id] = -INFINITY; // m_i + sums[simd_lane_id] = 0.f; // l_i + o_rescale[simd_lane_id] = 1.f; // li * exp(mi - mi_new) + output_rescale[simd_lane_id] = 1.f; // 1.0 / l_i + } + } + } + + static METAL_FUNC void rescale_ss( + threadgroup T* Ss, + threadgroup float* Corrections, + uint simd_group_id, + uint simd_lane_id, + short2 local_blocks, + float alpha) { + if (simd_group_id == 0) { + short row_offset = BM + float_padding; + threadgroup float* maxes = Corrections; + threadgroup float* sums = Corrections + row_offset; + threadgroup float* o_rescale = sums + row_offset; + threadgroup float* output_scales = o_rescale + row_offset; + + if (simd_lane_id < uint(local_blocks.y)) { + float m_i_old = maxes[simd_lane_id]; + float l_i_old = sums[simd_lane_id]; + + float m_i_new = m_i_old; + float l_i_new = l_i_old; + + short offset = simd_lane_id * (BN + tgp_padding); + + float m_ij = -INFINITY; + + for (short j = 0; j < local_blocks.x; j++) { + float val = alpha * float(Ss[offset + j]); + m_ij = max(m_ij, val); + } + + m_i_new = max(m_ij, m_i_new); + + float rowsum = 0.f; // lij + + for (short j = 0; j < local_blocks.x; j++) { + float val = alpha * float(Ss[offset + j]); + float P_i_j = exp(val - m_ij); + rowsum += P_i_j; + P_i_j = P_i_j * exp(m_ij - m_i_new); + Ss[offset + j] = T(P_i_j); + } + + l_i_new = + exp(m_i_old - m_i_new) * l_i_old + exp(m_ij - m_i_new) * rowsum; + maxes[simd_lane_id] = m_i_new; + sums[simd_lane_id] = l_i_new; + float rescale = l_i_old * exp(m_i_old - m_i_new); + o_rescale[simd_lane_id] = rescale; + output_scales[simd_lane_id] = 1.0 / l_i_new; + } + } + } 
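[Editor's note] `initialize_corrections` and `rescale_ss` above keep, per query row, a running max `m_i` and running denominator `l_i`, and emit two correction factors: one to rescale the already-accumulated output when the max changes, and `1/l_i` to normalize at the end. The scalar sketch below shows mathematically equivalent streaming-softmax bookkeeping for a single row processed tile by tile (it normalizes once at the end, whereas the kernel keeps the accumulator normalized each iteration); the names and tile layout are illustrative only, not the kernel's interface.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Streaming softmax(scores) * V for one query row, processed in score tiles.
// Maintains a running max m and running denominator l, rescaling the partial
// output whenever the max grows, as the correction buffers above do.
std::vector<float> streaming_attention_row(
    const std::vector<std::vector<float>>& score_tiles,  // alpha-scaled scores
    const std::vector<std::vector<std::vector<float>>>& value_tiles,
    std::size_t d_v) {
  float m = -INFINITY;
  float l = 0.f;
  std::vector<float> acc(d_v, 0.f);

  for (std::size_t t = 0; t < score_tiles.size(); ++t) {
    float m_tile = -INFINITY;
    for (float s : score_tiles[t]) m_tile = std::max(m_tile, s);
    const float m_new = std::max(m, m_tile);

    // Rescale what has been accumulated so far (the o_rescale role).
    const float correction = std::exp(m - m_new);
    for (float& a : acc) a *= correction;

    // Accumulate this tile with probabilities shifted by the new max.
    float rowsum = 0.f;
    for (std::size_t j = 0; j < score_tiles[t].size(); ++j) {
      const float p = std::exp(score_tiles[t][j] - m_new);
      rowsum += p;
      for (std::size_t k = 0; k < d_v; ++k) acc[k] += p * value_tiles[t][j][k];
    }
    l = l * std::exp(m - m_new) + rowsum;
    m = m_new;
  }

  for (float& a : acc) a /= l;  // the output_rescale = 1 / l step
  return acc;
}
```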
+ + /* Main kernel function */ + static METAL_FUNC void run( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + device U* O [[buffer(3)]], + const constant MLXFastAttentionParams* params [[buffer(4)]], + threadgroup T* Qs [[threadgroup(0)]], + threadgroup T* Ks [[threadgroup(1)]], + threadgroup T* Ss [[threadgroup(2)]], + threadgroup T* Vs [[threadgroup(3)]], + threadgroup float* Corrections [[threadgroup(4)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Pacifying compiler + (void)lid; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in Q, O; and head in K, V. + const int c_row = tid_y * BM; + + Q += transpose_q ? c_row : c_row * params->ldq; + thread loader_q_t loader_q(Q, params->ldq, Qs, simd_group_id, simd_lane_id); + + short tgp_bm = min(BM, params->M - c_row); + short2 tile_dims_Q = transpose_q ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + + loader_q.load_safe(tile_dims_Q); + + initialize_corrections(Corrections, simd_lane_id, simd_group_id); + + O += c_row * params->ldo; + + // Prepare threadgroup mma operation + thread mma_qk_t mma_qk_op(simd_group_id, simd_lane_id); + thread mma_sv_t mma_softmax_sv_op(simd_group_id, simd_lane_id); + thread loader_k_t loader_k(K, params->ldk, Ks, simd_group_id, simd_lane_id); + thread loader_v_t loader_v(V, params->ldv, Vs, simd_group_id, simd_lane_id); + + for (short n_block = 0; n_block < params->gemm_n_iterations_aligned; + n_block++) { + short c_col = BN; + + // Prepare threadgroup loading operations + short gemm_k_iterations = params->gemm_k_iterations_aligned; + short tgp_bn_qk = min(BN, params->N - c_col * n_block); + threadgroup_barrier(mem_flags::mem_none); + + /////////////////////////////////////////////////////////////////////////////// + { // Loop over K - unaligned case + + if (tgp_bm == BM && tgp_bn_qk == BN) { + gemm_loop( + Qs, + Ks, + gemm_k_iterations, + loader_k, + mma_qk_op, + tgp_bm, + tgp_bn_qk); + } else if (tgp_bn_qk == BN) { + gemm_loop( + Qs, + Ks, + gemm_k_iterations, + loader_k, + mma_qk_op, + tgp_bm, + tgp_bn_qk); + + } else if (tgp_bm == BM) { + gemm_loop( + Qs, + Ks, + gemm_k_iterations, + loader_k, + mma_qk_op, + tgp_bm, + tgp_bn_qk); + + } else { + gemm_loop( + Qs, + Ks, + gemm_k_iterations, + loader_k, + mma_qk_op, + tgp_bm, + tgp_bn_qk); + } + } + + mma_qk_op.store_result_to_tgp_memory( + Ss, BN + tgp_padding, short2(BN, BM)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + rescale_ss( + Ss, + Corrections, + simd_group_id, + simd_lane_id, + short2(tgp_bn_qk, tgp_bm), + params->alpha); + + loader_v.load_safe(short2(BK, tgp_bn_qk)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + threadgroup float* o_scales = Corrections + 2 * (BM + float_padding); + mma_softmax_sv_op.rescale_output(o_scales); + + mma_softmax_sv_op.mma(Ss, Vs); + + threadgroup float* final_output_scales = + Corrections + 3 * (BM + float_padding); + + mma_softmax_sv_op.rescale_output(final_output_scales); + + loader_v.next(); + loader_k.next(BN); + + mma_qk_op.clear_results(); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + 
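+      // Note: the accumulator is already softmax-normalized at this point;
+      // each n_block iteration rescaled it by l_old * exp(m_old - m_new)
+      // before accumulating S*V and by 1/l_new afterwards, so the guarded
+      // store below writes the final output rows for this query tile.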
mma_softmax_sv_op.store_result_safe(O, params->ldo, short2(BK, tgp_bm)); + } +}; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_q, + bool transpose_k, + bool transpose_v, + bool MN_aligned, + bool K_aligned> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void attention( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + device T* O [[buffer(3)]], + const constant MLXFastAttentionParams* params [[buffer(4)]], + const constant int* batch_shape [[buffer(6)]], + const constant size_t* batch_strides [[buffer(7)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using attention_kernel = FastAttentionKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_q, + transpose_k, + transpose_v, + MN_aligned, + K_aligned>; + + // Adjust for batch + if (params->batch_ndim > 1) { + const constant size_t* Q_bstrides = batch_strides; + const constant size_t* KV_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, Q_bstrides, KV_bstrides, params->batch_ndim); + + Q += batch_offsets.x; + K += batch_offsets.y; + V += batch_offsets.y; + + } else { + Q += params->batch_stride_q * tid.z; + K += params->batch_stride_k * tid.z; + V += params->batch_stride_v * tid.z; + } + + // same shape as input + O += params->batch_stride_o * tid.z; + threadgroup T Qs[attention_kernel::tgp_mem_size_q]; + threadgroup T Ss[attention_kernel::tgp_mem_size_s]; + threadgroup float Corrections[attention_kernel::tgp_mem_size_corrections]; + + if (attention_kernel::share_kv_smem) { + threadgroup T Ks[attention_kernel::tgp_mem_size_k]; + threadgroup T* Vs = Ks; //[attention_kernel::tgp_mem_size_v]; + attention_kernel::run( + Q, + K, + V, + O, + params, + Qs, + Ks, + Ss, + Vs, + Corrections, + simd_lane_id, + simd_group_id, + tid, + lid); + } else { + threadgroup T Ks[attention_kernel::tgp_mem_size_k]; + threadgroup T Vs[attention_kernel::tgp_mem_size_v]; + attention_kernel::run( + Q, + K, + V, + O, + params, + Qs, + Ks, + Ss, + Vs, + Corrections, + simd_lane_id, + simd_group_id, + tid, + lid); + } +} + +#define instantiate_fast_inference_self_attention_kernel( \ + itype, otype, bm, bn, bk, wm, wn) \ + template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \ + "_itype_" #itype)]] [[kernel]] void \ + attention( \ + const device itype* Q [[buffer(0)]], \ + const device itype* K [[buffer(1)]], \ + const device itype* V [[buffer(2)]], \ + device otype* O [[buffer(3)]], \ + const constant MLXFastAttentionParams* params [[buffer(4)]], \ + const constant int* batch_shape [[buffer(6)]], \ + const constant size_t* batch_strides [[buffer(7)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); + +instantiate_fast_inference_self_attention_kernel( + float, + float, + 16, + 16, + 64, + 2, + 2); +instantiate_fast_inference_self_attention_kernel( + float, + float, + 16, + 16, + 128, + 2, + 2); +instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2); +instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2); + +template < + typename T, + typename T2, + typename T4, + uint16_t TILE_SIZE_CONST, + 
uint16_t NSIMDGROUPS> +[[kernel]] void fast_inference_sdpa_compute_partials_template( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + const device uint64_t& L [[buffer(3)]], + const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], + device float* O_partials [[buffer(5)]], + device float* p_lse [[buffer(6)]], + device float* p_maxes [[buffer(7)]], + threadgroup T* threadgroup_block [[threadgroup(0)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + constexpr const size_t DK = 128; + constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8; + constexpr const size_t THREADS_PER_SIMDGROUP = 32; + constexpr const uint iter_offset = NSIMDGROUPS * 4; + const bool is_gqa = params.N_KV_HEADS != params.N_Q_HEADS; + uint kv_head_offset_factor = tid.x; + if (is_gqa) { + int q_kv_head_ratio = params.N_Q_HEADS / params.N_KV_HEADS; + kv_head_offset_factor = tid.x / q_kv_head_ratio; + } + constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4; + constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP = + TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS); + constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR; + constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR * + SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) * + NSIMDGROUPS; + + threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block; +#pragma clang loop unroll(full) + for (uint i = 0; i < 8; i++) { + smemFlush + [simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP + + i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + // TODO: multiple query sequence length for speculative decoding + const uint tgroup_query_head_offset = + tid.x * DK + tid.z * (params.N_Q_HEADS * DK); + + const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L; + const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK; + const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK; + + const device T* baseK = + K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset; + const device T* baseQ = Q + tgroup_query_head_offset; + + device T4* simdgroupQueryData = (device T4*)baseQ; + + constexpr const size_t ACCUM_PER_GROUP = TILE_SIZE_CONST / NSIMDGROUPS; + float threadAccum[ACCUM_PER_GROUP]; + +#pragma clang loop unroll(full) + for (size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP; + threadAccumIndex++) { + threadAccum[threadAccumIndex] = -INFINITY; + } + + uint KROW_ACCUM_INDEX = 0; + + const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST; + const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L; + const bool LAST_TILE_ALIGNED = + (SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST)); + + T4 thread_data_x4; + T4 thread_data_y4; + if (!LAST_TILE || LAST_TILE_ALIGNED) { + thread_data_x4 = *(simdgroupQueryData + simd_lane_id); +#pragma clang loop unroll(full) + for (size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST; + KROW += NSIMDGROUPS) { + const uint KROW_OFFSET = KROW * DK; + const device T* baseKRow = baseK + KROW_OFFSET; + device T4* keysData = (device T4*)baseKRow; + thread_data_y4 = *(keysData + simd_lane_id); + T kq_scalar = dot(thread_data_x4, thread_data_y4); + threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar); + KROW_ACCUM_INDEX++; + } + } else { + thread_data_x4 = 
*(simdgroupQueryData + simd_lane_id); + const uint START_ROW = tid.y * TILE_SIZE_CONST; + const device T* baseKThisHead = + K + tgroup_k_batch_offset + tgroup_k_head_offset; + + for (size_t KROW = START_ROW + simd_group_id; KROW < L; + KROW += NSIMDGROUPS) { + const uint KROW_OFFSET = KROW * DK; + const device T* baseKRow = baseKThisHead + KROW_OFFSET; + device T4* keysData = (device T4*)baseKRow; + thread_data_y4 = *(keysData + simd_lane_id); + T kq_scalar = dot(thread_data_x4, thread_data_y4); + threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar); + KROW_ACCUM_INDEX++; + } + } + threadgroup float* smemP = (threadgroup float*)threadgroup_block; + +#pragma clang loop unroll(full) + for (size_t i = 0; i < P_VEC4; i++) { + thread_data_x4 = + T4(threadAccum[4 * i], + threadAccum[4 * i + 1], + threadAccum[4 * i + 2], + threadAccum[4 * i + 3]); + simdgroup_barrier(mem_flags::mem_none); + thread_data_y4 = simd_sum(thread_data_x4); + if (simd_lane_id == 0) { + const uint base_smem_p_offset = i * iter_offset + simd_group_id; + smemP[base_smem_p_offset + NSIMDGROUPS * 0] = float(thread_data_y4.x); + smemP[base_smem_p_offset + NSIMDGROUPS * 1] = float(thread_data_y4.y); + smemP[base_smem_p_offset + NSIMDGROUPS * 2] = float(thread_data_y4.z); + smemP[base_smem_p_offset + NSIMDGROUPS * 3] = float(thread_data_y4.w); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + float groupMax; + float lse = 0.f; + + constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32; + constexpr const size_t ACCUM_ARRAY_LENGTH = + TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1; + float4 pvals[ACCUM_ARRAY_LENGTH]; + +#pragma clang loop unroll(full) + for (uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH; + accum_array_iter++) { + pvals[accum_array_iter] = float4(-INFINITY); + } + + if (TILE_SIZE_CONST == 64) { + threadgroup float2* smemPtrFlt2 = (threadgroup float2*)threadgroup_block; + float2 vals = smemPtrFlt2[simd_lane_id]; + vals *= params.INV_ALPHA; + float maxval = max(vals.x, vals.y); + simdgroup_barrier(mem_flags::mem_none); + groupMax = simd_max(maxval); + + float2 expf_shifted = exp(vals - groupMax); + float sumExpLocal = expf_shifted.x + expf_shifted.y; + simdgroup_barrier(mem_flags::mem_none); + float tgroupExpSum = simd_sum(sumExpLocal); + + lse = log(tgroupExpSum); + float2 local_p_hat = expf_shifted / tgroupExpSum; + pvals[0].x = local_p_hat.x; + pvals[0].y = local_p_hat.y; + smemPtrFlt2[simd_lane_id] = float2(0.f); + } + constexpr const bool TILE_SIZE_LARGER_THAN_64 = TILE_SIZE_CONST > 64; + constexpr const int TILE_SIZE_ITERS_128 = TILE_SIZE_CONST / 128; + + if (TILE_SIZE_LARGER_THAN_64) { + float maxval = -INFINITY; + threadgroup float4* smemPtrFlt4 = (threadgroup float4*)threadgroup_block; +#pragma clang loop unroll(full) + for (int i = 0; i < TILE_SIZE_ITERS_128; i++) { + float4 vals = smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP]; + vals *= params.INV_ALPHA; + pvals[i] = vals; + maxval = fmax3(vals.x, vals.y, maxval); + maxval = fmax3(vals.z, vals.w, maxval); + } + simdgroup_barrier(mem_flags::mem_none); + groupMax = simd_max(maxval); + + float sumExpLocal = 0.f; +#pragma clang loop unroll(full) + for (int i = 0; i < TILE_SIZE_ITERS_128; i++) { + pvals[i] = exp(pvals[i] - groupMax); + sumExpLocal += pvals[i].x + pvals[i].y + pvals[i].z + pvals[i].w; + } + simdgroup_barrier(mem_flags::mem_none); + float tgroupExpSum = simd_sum(sumExpLocal); + lse = log(tgroupExpSum); +#pragma clang loop unroll(full) + for (int i = 0; i < TILE_SIZE_ITERS_128; i++) { + 
pvals[i] = pvals[i] / tgroupExpSum; + smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP] = float4(0.f); + } + } + + threadgroup T* smemV = (threadgroup T*)threadgroup_block; + + const size_t v_batch_offset = tid.z * params.N_KV_HEADS * L * DK; + const size_t v_head_offset = kv_head_offset_factor * L * DK; + + const size_t v_tile_offset = tid.y * TILE_SIZE_CONST * DK; + const size_t v_offset = v_batch_offset + v_head_offset + v_tile_offset; + device T* baseV = (device T*)V + v_offset; + + threadgroup float* smemOpartial = (threadgroup float*)(smemV + totalSmemV); + + if (!LAST_TILE || LAST_TILE_ALIGNED) { +#pragma clang loop unroll(full) + for (size_t col = 0; col < MATRIX_COLS; col++) { + uint matrix_load_loop_iter = 0; + constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8; + + for (size_t tile_start = simd_group_id; + tile_start < TILE_SIZE_CONST_DIV_8; + tile_start += NSIMDGROUPS) { + simdgroup_matrix tmp; + ulong simdgroup_matrix_offset = + matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR + + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; + ulong2 matrixOrigin = + ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset); + simdgroup_load(tmp, baseV, DK, matrixOrigin, true); + const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0); + const ulong elemsPerRowSmem = TILE_SIZE_CONST; + simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, false); + matrix_load_loop_iter++; + }; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (TILE_SIZE_CONST == 64) { + T2 local_p_hat = T2(pvals[0].x, pvals[0].y); + uint loop_iter = 0; + threadgroup float* oPartialSmem = + smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; + +#pragma clang loop unroll(full) + for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; + row += NSIMDGROUPS) { + threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row); + threadgroup T2* smemV2 = (threadgroup T2*)smemV_row; + T2 v_local = *(smemV2 + simd_lane_id); + + T val = dot(local_p_hat, v_local); + simdgroup_barrier(mem_flags::mem_none); + + T row_sum = simd_sum(val); + oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = + float(row_sum); + loop_iter++; + } + } + + if (TILE_SIZE_CONST > 64) { + constexpr const size_t TILE_SIZE_CONST_DIV_128 = + (TILE_SIZE_CONST + 1) / 128; + threadgroup float* oPartialSmem = + smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; + uint loop_iter = 0; + for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; + row += NSIMDGROUPS) { + threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row); + + T row_sum = 0.f; + for (size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) { + threadgroup T4* smemV2 = (threadgroup T4*)smemV_row; + T4 v_local = *(smemV2 + simd_lane_id + i * THREADS_PER_SIMDGROUP); + T4 p_local = T4(pvals[i]); + T val = dot(p_local, v_local); + row_sum += val; + } + simdgroup_barrier(mem_flags::mem_none); + row_sum = simd_sum(row_sum); + oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = + float(row_sum); + loop_iter++; + } + } + } + } else { + const int32_t START_ROW = tid.y * TILE_SIZE_CONST; + const int32_t MAX_START_ROW = L - SIMDGROUP_MATRIX_LOAD_FACTOR + 1; + const device T* baseVThisHead = V + v_batch_offset + v_head_offset; + constexpr const int ROWS_PER_ITER = 8; +#pragma clang loop unroll(full) + for (size_t col = 0; col < MATRIX_COLS; col++) { + uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; + int32_t tile_start; + for (tile_start = + START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; 
+ tile_start < MAX_START_ROW; + tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) { + simdgroup_matrix tmp; + ulong2 matrixOrigin = + ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start); + simdgroup_load( + tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true); + const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0); + constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST; + simdgroup_store( + tmp, + smemV, + elemsPerRowSmem, + matrixOriginSmem, + /* transpose */ false); + smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR; + }; + + tile_start = + ((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR); + + const int32_t INT_L = int32_t(L); + for (int row_index = tile_start + simd_group_id; row_index < INT_L; + row_index += NSIMDGROUPS) { + if (simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) { + const uint elems_per_row_gmem = DK; + const uint col_index_v_gmem = + col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id; + const uint row_index_v_gmem = row_index; + + const uint elems_per_row_smem = TILE_SIZE_CONST; + const uint col_index_v_smem = row_index % TILE_SIZE_CONST; + const uint row_index_v_smem = simd_lane_id; + + const uint scalar_offset_gmem = + row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem; + const uint scalar_offset_smem = + row_index_v_smem * elems_per_row_smem + col_index_v_smem; + T vdata = T(*(baseVThisHead + scalar_offset_gmem)); + smemV[scalar_offset_smem] = vdata; + smem_col_index += NSIMDGROUPS; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (TILE_SIZE_CONST == 64) { + T2 local_p_hat = T2(pvals[0].x, pvals[0].y); + threadgroup float* oPartialSmem = + smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; + for (size_t smem_row_index = simd_group_id; + smem_row_index < ROWS_PER_ITER; + smem_row_index += NSIMDGROUPS) { + threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index); + threadgroup T2* smemV2 = (threadgroup T2*)smemV_row; + T2 v_local = *(smemV2 + simd_lane_id); + T val = dot(local_p_hat, v_local); + simdgroup_barrier(mem_flags::mem_none); + T row_sum = simd_sum(val); + oPartialSmem[smem_row_index] = float(row_sum); + } + } + + if (TILE_SIZE_CONST > 64) { + threadgroup float* oPartialSmem = + smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; + uint loop_count = 0; + for (size_t row_index = simd_group_id; row_index < ROWS_PER_ITER; + row_index += NSIMDGROUPS) { + T row_sum = 0.f; + for (size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128; + tile_iters++) { + threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index); + threadgroup T4* smemV2 = (threadgroup T4*)smemV_row; + T4 v_local = + *(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP); + T4 p_local = T4(pvals[tile_iters]); + row_sum += dot(p_local, v_local); + } + simdgroup_barrier(mem_flags::mem_none); + row_sum = simd_sum(row_sum); + oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] = + float(row_sum); + loop_count++; + } + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (simd_group_id == 0) { + threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial; + float4 vals = *(oPartialVec4 + simd_lane_id); + device float* oPartialGmem = + O_partials + tid.x * DK * params.KV_TILES + tid.y * DK; + device float4* oPartialGmemVec4 = (device float4*)oPartialGmem; + oPartialGmemVec4[simd_lane_id] = vals; + } + + if (simd_group_id == 0 && simd_lane_id == 0) { + const uint tileIndex = tid.y; + const uint gmem_partial_scalar_offset = + tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * 
params.KV_TILES + + tileIndex; + p_lse[gmem_partial_scalar_offset] = lse; + p_maxes[gmem_partial_scalar_offset] = groupMax; + } +} + +#define instantiate_fast_inference_sdpa_to_partials_kernel( \ + itype, itype2, itype4, tile_size, nsimdgroups) \ + template [[host_name("fast_inference_sdpa_compute_partials_" #itype \ + "_" #tile_size "_" #nsimdgroups)]] [[kernel]] void \ + fast_inference_sdpa_compute_partials_template< \ + itype, \ + itype2, \ + itype4, \ + tile_size, \ + nsimdgroups>( \ + const device itype* Q [[buffer(0)]], \ + const device itype* K [[buffer(1)]], \ + const device itype* V [[buffer(2)]], \ + const device uint64_t& L [[buffer(3)]], \ + const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], \ + device float* O_partials [[buffer(5)]], \ + device float* p_lse [[buffer(6)]], \ + device float* p_maxes [[buffer(7)]], \ + threadgroup itype* threadgroup_block [[threadgroup(0)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]]); + +// clang-format off +#define instantiate_fast_inference_sdpa_to_partials_shapes_helper( \ + itype, itype2, itype4, tile_size) \ + instantiate_fast_inference_sdpa_to_partials_kernel( \ + itype, itype2, itype4, tile_size, 4) \ + instantiate_fast_inference_sdpa_to_partials_kernel( \ + itype, itype2, itype4, tile_size, 8) // clang-format on + +instantiate_fast_inference_sdpa_to_partials_shapes_helper( + float, + float2, + float4, + 64); +instantiate_fast_inference_sdpa_to_partials_shapes_helper( + float, + float2, + float4, + 128); +instantiate_fast_inference_sdpa_to_partials_shapes_helper( + float, + float2, + float4, + 256); +instantiate_fast_inference_sdpa_to_partials_shapes_helper( + float, + float2, + float4, + 512); + +instantiate_fast_inference_sdpa_to_partials_shapes_helper( + half, + half2, + half4, + 64); +instantiate_fast_inference_sdpa_to_partials_shapes_helper( + half, + half2, + half4, + 128); +instantiate_fast_inference_sdpa_to_partials_shapes_helper( + half, + half2, + half4, + 256); +instantiate_fast_inference_sdpa_to_partials_shapes_helper( + half, + half2, + half4, + 512); + +template +void fast_inference_sdpa_reduce_tiles_template( + const device float* O_partials [[buffer(0)]], + const device float* p_lse [[buffer(1)]], + const device float* p_maxes [[buffer(2)]], + const device MLXScaledDotProductAttentionParams& params [[buffer(3)]], + device T* O [[buffer(4)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + constexpr const int DK = 128; + const ulong offset_rows = + tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES; + const device float* p_lse_row = p_lse + offset_rows; + const device float* p_rowmax_row = p_maxes + offset_rows; + // reserve some number of registers. this constitutes an assumption on max + // value of KV TILES. 
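+  // A sketch of the recombination performed below (assuming p_lse holds each
+  // tile's log-sum-exp computed relative to that tile's own row max, and
+  // p_maxes holds the per-tile row max):
+  //   true_max = max over tiles i of max_i
+  //   w_i      = exp(lse_i) * exp(max_i - true_max)
+  //   O        = sum_i O_partial_i * (w_i / sum_j w_j)
+  // Rescaling every partial against the global max before averaging keeps the
+  // combination numerically stable.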
+ constexpr const uint8_t reserve = 128; + float p_lse_regs[reserve]; + float p_rowmax_regs[reserve]; + float weights[reserve]; + + float true_max = -INFINITY; + for (size_t i = 0; i < params.KV_TILES; i++) { + p_lse_regs[i] = float(*(p_lse_row + i)); + p_rowmax_regs[i] = float(*(p_rowmax_row + i)); + true_max = fmax(p_rowmax_regs[i], true_max); + weights[i] = exp(p_lse_regs[i]); + } + + float denom = 0.f; + for (size_t i = 0; i < params.KV_TILES; i++) { + weights[i] *= exp(p_rowmax_regs[i] - true_max); + denom += weights[i]; + } + + const device float* O_partials_with_offset = O_partials + + tid.z * params.N_Q_HEADS * DK * params.KV_TILES + + tid.x * DK * params.KV_TILES; + + float o_value = 0.f; + for (size_t i = 0; i < params.KV_TILES; i++) { + float val = *(O_partials_with_offset + i * DK + lid.x); + o_value += val * weights[i] / denom; + } + device T* O_gmem = O + tid.z * params.N_Q_HEADS * DK + tid.x * DK; + O_gmem[lid.x] = T(o_value); + return; +} + +kernel void fast_inference_sdpa_reduce_tiles_float( + const device float* O_partials [[buffer(0)]], + const device float* p_lse [[buffer(1)]], + const device float* p_maxes [[buffer(2)]], + const device MLXScaledDotProductAttentionParams& params [[buffer(3)]], + device float* O [[buffer(4)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + fast_inference_sdpa_reduce_tiles_template( + O_partials, p_lse, p_maxes, params, O, tid, lid); +} + +kernel void fast_inference_sdpa_reduce_tiles_half( + const device float* O_partials [[buffer(0)]], + const device float* p_lse [[buffer(1)]], + const device float* p_maxes [[buffer(2)]], + const device MLXScaledDotProductAttentionParams& params [[buffer(3)]], + device half* O [[buffer(4)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + fast_inference_sdpa_reduce_tiles_template( + O_partials, p_lse, p_maxes, params, O, tid, lid); +} diff --git a/Source/Cmlx/mlx-generated/metal/scaled_dot_product_attention_params.h b/Source/Cmlx/mlx-generated/metal/scaled_dot_product_attention_params.h new file mode 100644 index 00000000..a77dad26 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/scaled_dot_product_attention_params.h @@ -0,0 +1,42 @@ +// +// scaled_dot_product_attention_params.h +// mlx + +#pragma once + +struct MLXFastAttentionParams { + const int M; + const int N; + const int K; + + const int ldq; // ldq == ldo + const int ldk; + const int ldv; + const int lds; + const int ldo; + + const int tiles_n; + const int tiles_m; + + const int batch_stride_q; + const int batch_stride_k; + const int batch_stride_v; + const int batch_stride_o; + + const int swizzle_log; + const int gemm_n_iterations_aligned; + const int gemm_k_iterations_aligned; + const int gemm_sv_m_block_iterations; + + const int batch_ndim; + const float alpha; +}; + +struct MLXScaledDotProductAttentionParams { + // Associated dimensions & transposition information + const uint QUERY_SEQUENCE_LENGTH = 1; + const uint N_Q_HEADS = 32; + const uint N_KV_HEADS = 32; + const uint KV_TILES = 1; + const float INV_ALPHA = 0.08838834764831843f; +}; diff --git a/Source/Cmlx/mlx-generated/metal/scan.h b/Source/Cmlx/mlx-generated/metal/scan.h new file mode 100644 index 00000000..67b27ba8 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/scan.h @@ -0,0 +1,441 @@ +// Copyright © 2023-2024 Apple Inc. 
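+
+// The scan kernels below are written against a small operator interface
+// (CumSum, CumProd, CumMax, CumMin). Roughly -- a sketch of the contract as
+// used in this file -- each op provides:
+//
+//   static constexpr constant U init;   // identity element of the scan
+//   U operator()(U a, T b);             // combine an accumulator with a value
+//   U simd_scan(U x);                   // inclusive prefix scan within a simdgroup
+//   U simd_exclusive_scan(U x);         // exclusive prefix scan within a simdgroup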
+ +template +struct CumSum { + static constexpr constant U init = static_cast(0); + + template + U operator()(U a, T b) { + return a + b; + } + + U simd_scan(U x) { + return simd_prefix_inclusive_sum(x); + } + + U simd_exclusive_scan(U x) { + return simd_prefix_exclusive_sum(x); + } +}; + +template +struct CumProd { + static constexpr constant U init = static_cast(1.0f); + + template + U operator()(U a, T b) { + return a * b; + } + + U simd_scan(U x) { + return simd_prefix_inclusive_product(x); + } + + U simd_exclusive_scan(U x) { + return simd_prefix_exclusive_product(x); + } +}; + +template <> +struct CumProd { + static constexpr constant bool init = true; + + template + bool operator()(bool a, T b) { + return a & static_cast(b); + } + + bool simd_scan(bool x) { + for (int i = 1; i <= 16; i *= 2) { + bool other = simd_shuffle_up(x, i); + x &= other; + } + return x; + } + + bool simd_exclusive_scan(bool x) { + x = simd_scan(x); + return simd_shuffle_and_fill_up(x, init, 1); + } +}; + +template +struct CumMax { + static constexpr constant U init = Limits::min; + + template + U operator()(U a, T b) { + return (a >= b) ? a : b; + } + + U simd_scan(U x) { + for (int i = 1; i <= 16; i *= 2) { + U other = simd_shuffle_up(x, i); + x = (x >= other) ? x : other; + } + return x; + } + + U simd_exclusive_scan(U x) { + x = simd_scan(x); + return simd_shuffle_and_fill_up(x, init, 1); + } +}; + +template +struct CumMin { + static constexpr constant U init = Limits::max; + + template + U operator()(U a, T b) { + return (a <= b) ? a : b; + } + + U simd_scan(U x) { + for (int i = 1; i <= 16; i *= 2) { + U other = simd_shuffle_up(x, i); + x = (x <= other) ? x : other; + } + return x; + } + + U simd_exclusive_scan(U x) { + x = simd_scan(x); + return simd_shuffle_and_fill_up(x, init, 1); + } +}; + +template +inline void load_unsafe(U values[N_READS], const device T* input) { + if (reverse) { + for (int i = 0; i < N_READS; i++) { + values[N_READS - i - 1] = input[i]; + } + } else { + for (int i = 0; i < N_READS; i++) { + values[i] = input[i]; + } + } +} + +template +inline void load_safe( + U values[N_READS], + const device T* input, + int start, + int total, + U init) { + if (reverse) { + for (int i = 0; i < N_READS; i++) { + values[N_READS - i - 1] = + (start + N_READS - i - 1 < total) ? input[i] : init; + } + } else { + for (int i = 0; i < N_READS; i++) { + values[i] = (start + i < total) ? 
input[i] : init; + } + } +} + +template +inline void write_unsafe(U values[N_READS], device U* out) { + if (reverse) { + for (int i = 0; i < N_READS; i++) { + out[i] = values[N_READS - i - 1]; + } + } else { + for (int i = 0; i < N_READS; i++) { + out[i] = values[i]; + } + } +} + +template +inline void write_safe(U values[N_READS], device U* out, int start, int total) { + if (reverse) { + for (int i = 0; i < N_READS; i++) { + if (start + N_READS - i - 1 < total) { + out[i] = values[N_READS - i - 1]; + } + } + } else { + for (int i = 0; i < N_READS; i++) { + if (start + i < total) { + out[i] = values[i]; + } + } + } +} + +template < + typename T, + typename U, + typename Op, + int N_READS, + bool inclusive, + bool reverse> +[[kernel]] void contiguous_scan( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& axis_size [[buffer(2)]], + uint gid [[thread_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_size [[threads_per_simdgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + Op op; + + // Position the pointers + in += (gid / lsize) * axis_size; + out += (gid / lsize) * axis_size; + + // Compute the number of simd_groups + uint simd_groups = lsize / simd_size; + + // Allocate memory + U prefix = Op::init; + U values[N_READS]; + threadgroup U simdgroup_sums[32]; + + // Loop over the reduced axis in blocks of size ceildiv(axis_size, + // N_READS*lsize) + // Read block + // Compute inclusive scan of the block + // Compute inclusive scan per thread + // Compute exclusive scan of thread sums in simdgroup + // Write simdgroup sums in SM + // Compute exclusive scan of simdgroup sums + // Compute the output by scanning prefix, prev_simdgroup, prev_thread, + // value + // Write block + + for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) { + // Compute the block offset + uint offset = r * lsize * N_READS + lid * N_READS; + + // Read the values + if (reverse) { + if ((offset + N_READS) < axis_size) { + load_unsafe( + values, in + axis_size - offset - N_READS); + } else { + load_safe( + values, + in + axis_size - offset - N_READS, + offset, + axis_size, + Op::init); + } + } else { + if ((offset + N_READS) < axis_size) { + load_unsafe(values, in + offset); + } else { + load_safe( + values, in + offset, offset, axis_size, Op::init); + } + } + + // Compute an inclusive scan per thread + for (int i = 1; i < N_READS; i++) { + values[i] = op(values[i], values[i - 1]); + } + + // Compute exclusive scan of thread sums + U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]); + + // Write simdgroup_sums to SM + if (simd_lane_id == simd_size - 1) { + simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Compute exclusive scan of simdgroup_sums + if (simd_group_id == 0) { + U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]); + simdgroup_sums[simd_lane_id] = prev_simdgroup; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Compute the output + for (int i = 0; i < N_READS; i++) { + values[i] = op(values[i], prefix); + values[i] = op(values[i], simdgroup_sums[simd_group_id]); + values[i] = op(values[i], prev_thread); + } + + // Write the values + if (reverse) { + if (inclusive) { + if ((offset + N_READS) < axis_size) { + write_unsafe( + values, out + axis_size - offset - N_READS); + } else { + write_safe( + values, 
out + axis_size - offset - N_READS, offset, axis_size); + } + } else { + if (lid == 0 && offset == 0) { + out[axis_size - 1] = Op::init; + } + if ((offset + N_READS + 1) < axis_size) { + write_unsafe( + values, out + axis_size - offset - 1 - N_READS); + } else { + write_safe( + values, + out + axis_size - offset - 1 - N_READS, + offset + 1, + axis_size); + } + } + } else { + if (inclusive) { + if ((offset + N_READS) < axis_size) { + write_unsafe(values, out + offset); + } else { + write_safe( + values, out + offset, offset, axis_size); + } + } else { + if (lid == 0 && offset == 0) { + out[0] = Op::init; + } + if ((offset + N_READS + 1) < axis_size) { + write_unsafe(values, out + offset + 1); + } else { + write_safe( + values, out + offset + 1, offset + 1, axis_size); + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Share the prefix + if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) { + simdgroup_sums[0] = values[N_READS - 1]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + prefix = simdgroup_sums[0]; + } +} + +template < + typename T, + typename U, + typename Op, + int N_READS, + bool inclusive, + bool reverse> +[[kernel]] void strided_scan( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& axis_size [[buffer(2)]], + const constant size_t& stride [[buffer(3)]], + uint2 gid [[threadgroup_position_in_grid]], + uint2 lid [[thread_position_in_threadgroup]], + uint2 lsize [[threads_per_threadgroup]], + uint simd_size [[threads_per_simdgroup]]) { + Op op; + + // Allocate memory + threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32]; + U values[N_READS]; + U prefix[N_READS]; + for (int i = 0; i < N_READS; i++) { + prefix[i] = Op::init; + } + + // Compute offsets + int offset = gid.y * axis_size * stride; + int global_index_x = gid.x * lsize.y * N_READS; + + for (uint j = 0; j < axis_size; j += simd_size) { + // Calculate the indices for the current thread + uint index_y = j + lid.y; + uint check_index_y = index_y; + uint index_x = global_index_x + lid.x * N_READS; + if (reverse) { + index_y = axis_size - 1 - index_y; + } + + // Read in SM + if (check_index_y < axis_size && (index_x + N_READS) < stride) { + for (int i = 0; i < N_READS; i++) { + read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = + in[offset + index_y * stride + index_x + i]; + } + } else { + for (int i = 0; i < N_READS; i++) { + if (check_index_y < axis_size && (index_x + i) < stride) { + read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = + in[offset + index_y * stride + index_x + i]; + } else { + read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = + Op::init; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Read strided into registers + for (int i = 0; i < N_READS; i++) { + values[i] = + read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i]; + } + // Do we need the following barrier? Shouldn't all simd threads execute + // simultaneously? 
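+    // Keeping it is the conservative choice: lockstep execution within a
+    // simdgroup is a hardware detail rather than something the Metal memory
+    // model guarantees, so the barrier makes the visibility of the
+    // threadgroup-memory accesses above explicit instead of implicit.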
+ simdgroup_barrier(mem_flags::mem_threadgroup); + + // Perform the scan + for (int i = 0; i < N_READS; i++) { + values[i] = op.simd_scan(values[i]); + values[i] = op(values[i], prefix[i]); + prefix[i] = simd_shuffle(values[i], simd_size - 1); + } + + // Write to SM + for (int i = 0; i < N_READS; i++) { + read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] = + values[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write to device memory + if (!inclusive) { + if (check_index_y == 0) { + if ((index_x + N_READS) < stride) { + for (int i = 0; i < N_READS; i++) { + out[offset + index_y * stride + index_x + i] = Op::init; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((index_x + i) < stride) { + out[offset + index_y * stride + index_x + i] = Op::init; + } + } + } + } + if (reverse) { + index_y -= 1; + check_index_y += 1; + } else { + index_y += 1; + check_index_y += 1; + } + } + if (check_index_y < axis_size && (index_x + N_READS) < stride) { + for (int i = 0; i < N_READS; i++) { + out[offset + index_y * stride + index_x + i] = + read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i]; + } + } else { + for (int i = 0; i < N_READS; i++) { + if (check_index_y < axis_size && (index_x + i) < stride) { + out[offset + index_y * stride + index_x + i] = + read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i]; + } + } + } + } +} diff --git a/Source/Cmlx/mlx-generated/metal/scatter.h b/Source/Cmlx/mlx-generated/metal/scatter.h new file mode 100644 index 00000000..6c9f84e8 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/scatter.h @@ -0,0 +1,76 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "indexing.h" + +template +METAL_FUNC void scatter_1d_index_impl( + const device T* updates [[buffer(1)]], + device mlx_atomic* out [[buffer(2)]], + const constant int* out_shape [[buffer(3)]], + const constant size_t* out_strides [[buffer(4)]], + const constant size_t& out_ndim [[buffer(5)]], + const constant int* upd_shape [[buffer(6)]], + const constant size_t& upd_ndim [[buffer(7)]], + const constant size_t& upd_size [[buffer(8)]], + const thread array& idx_buffers, + uint2 gid [[thread_position_in_grid]]) { + Op op; + + size_t out_idx = 0; + for (int i = 0; i < NIDX; i++) { + auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]); + out_idx += idx_val * out_strides[i]; + } + + if (upd_ndim > 1) { + auto out_offset = elem_to_loc(gid.x, upd_shape + 1, out_strides, out_ndim); + out_idx += out_offset; + } else { + out_idx += gid.x; + } + + op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx); +} + +template +METAL_FUNC void scatter_impl( + const device T* updates [[buffer(1)]], + device mlx_atomic* out [[buffer(2)]], + const constant int* upd_shape [[buffer(3)]], + const constant size_t* upd_strides [[buffer(4)]], + const constant size_t& upd_ndim [[buffer(5)]], + const constant size_t& upd_size [[buffer(6)]], + const constant int* out_shape [[buffer(7)]], + const constant size_t* out_strides [[buffer(8)]], + const constant size_t& out_ndim [[buffer(9)]], + const constant int* axes [[buffer(10)]], + const thread Indices& indices, + uint2 gid [[thread_position_in_grid]]) { + Op op; + auto ind_idx = gid.y; + auto ind_offset = gid.x; + + size_t out_idx = 0; + for (int i = 0; i < NIDX; ++i) { + auto idx_loc = elem_to_loc( + ind_idx, + &indices.shapes[indices.ndim * i], + &indices.strides[indices.ndim * i], + indices.ndim); + auto ax = axes[i]; + auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]); + 
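+    // Each index array contributes along its own axis: the (possibly
+    // negative) index is wrapped into range by offset_neg_idx and then
+    // converted to a linear offset using that axis' output stride.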
out_idx += idx_val * out_strides[ax]; + } + + if (upd_size > 1) { + auto out_offset = elem_to_loc( + ind_offset, upd_shape + indices.ndim, out_strides, out_ndim); + out_idx += out_offset; + } + + auto upd_idx = + elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim); + op.atomic_update(out, updates[upd_idx], out_idx); +} diff --git a/Source/Cmlx/mlx-generated/metal/softmax.h b/Source/Cmlx/mlx-generated/metal/softmax.h new file mode 100644 index 00000000..b36b73bd --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/softmax.h @@ -0,0 +1,190 @@ +// Copyright © 2023-2024 Apple Inc. + +template +inline T softmax_exp(T x) { + // Softmax doesn't need high precision exponential cause x is gonna be in + // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)). + return fast::exp(x); +} + +template +[[kernel]] void softmax_single_row( + const device T* in, + device T* out, + constant int& axis_size, + uint gid [[threadgroup_position_in_grid]], + uint _lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + int lid = _lid; + + constexpr int SIMD_SIZE = 32; + + threadgroup AccT local_max[SIMD_SIZE]; + threadgroup AccT local_normalizer[SIMD_SIZE]; + + AccT ld[N_READS]; + + in += gid * size_t(axis_size) + lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + ld[i] = AccT(in[i]); + } + } else { + for (int i = 0; i < N_READS; i++) { + ld[i] = + ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits::min; + } + } + if (simd_group_id == 0) { + local_max[simd_lane_id] = Limits::min; + local_normalizer[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Get the max + AccT maxval = Limits::finite_min; + for (int i = 0; i < N_READS; i++) { + maxval = (maxval < ld[i]) ? 
ld[i] : maxval; + } + maxval = simd_max(maxval); + if (simd_lane_id == 0) { + local_max[simd_group_id] = maxval; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + maxval = simd_max(local_max[simd_lane_id]); + if (simd_lane_id == 0) { + local_max[0] = maxval; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + maxval = local_max[0]; + + // Compute exp(x_i - maxval) and store the partial sums in local_normalizer + AccT normalizer = 0; + for (int i = 0; i < N_READS; i++) { + AccT exp_x = softmax_exp(ld[i] - maxval); + ld[i] = exp_x; + normalizer += exp_x; + } + normalizer = simd_sum(normalizer); + if (simd_lane_id == 0) { + local_normalizer[simd_group_id] = normalizer; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + normalizer = simd_sum(local_normalizer[simd_lane_id]); + if (simd_lane_id == 0) { + local_normalizer[0] = normalizer; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + normalizer = 1 / local_normalizer[0]; + + // Normalize and write to the output + out += gid * size_t(axis_size) + lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + out[i] = T(ld[i] * normalizer); + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((lid * N_READS + i) < axis_size) { + out[i] = T(ld[i] * normalizer); + } + } + } +} + +template +[[kernel]] void softmax_looped( + const device T* in, + device T* out, + constant int& axis_size, + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + in += gid * size_t(axis_size); + + constexpr int SIMD_SIZE = 32; + + threadgroup AccT local_max[SIMD_SIZE]; + threadgroup AccT local_normalizer[SIMD_SIZE]; + + // Get the max and the normalizer in one go + AccT prevmax; + AccT maxval = Limits::finite_min; + AccT normalizer = 0; + for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); + r++) { + int offset = r * lsize * N_READS + lid * N_READS; + AccT vals[N_READS]; + if (offset + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + vals[i] = AccT(in[offset + i]); + } + } else { + for (int i = 0; i < N_READS; i++) { + vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) + : Limits::finite_min; + } + } + prevmax = maxval; + for (int i = 0; i < N_READS; i++) { + maxval = (maxval < vals[i]) ? vals[i] : maxval; + } + normalizer *= softmax_exp(prevmax - maxval); + for (int i = 0; i < N_READS; i++) { + normalizer += softmax_exp(vals[i] - maxval); + } + } + // Now we got partial normalizer of N_READS * ceildiv(axis_size, N_READS * + // lsize) parts. We need to combine them. + // 1. We start by finding the max across simd groups + // 2. We then change the partial normalizers to account for a possible + // change in max + // 3. We sum all normalizers + prevmax = maxval; + maxval = simd_max(maxval); + normalizer *= softmax_exp(prevmax - maxval); + normalizer = simd_sum(normalizer); + + // Now the normalizer and max value is correct for each simdgroup. We write + // them shared memory and combine them. 
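+  // This is the usual online-softmax correction: if the running max moves
+  // from m_old to m_new, a partial normalizer n = sum_i exp(x_i - m_old)
+  // becomes n * exp(m_old - m_new), after which partial sums that share the
+  // same max can simply be added.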
+ prevmax = maxval; + if (simd_lane_id == 0) { + local_max[simd_group_id] = maxval; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + maxval = simd_max(local_max[simd_lane_id]); + normalizer *= softmax_exp(prevmax - maxval); + if (simd_lane_id == 0) { + local_normalizer[simd_group_id] = normalizer; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + normalizer = simd_sum(local_normalizer[simd_lane_id]); + normalizer = 1 / normalizer; + + // Finally given the normalizer and max value we can directly write the + // softmax output + out += gid * size_t(axis_size); + for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); + r++) { + int offset = r * lsize * N_READS + lid * N_READS; + if (offset + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer); + } + } else { + for (int i = 0; i < N_READS; i++) { + if (offset + i < axis_size) { + out[offset + i] = + T(softmax_exp(in[offset + i] - maxval) * normalizer); + } + } + } + } +} diff --git a/Source/Cmlx/mlx-generated/metal/sort.h b/Source/Cmlx/mlx-generated/metal/sort.h new file mode 100644 index 00000000..bfa4d98e --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/sort.h @@ -0,0 +1,695 @@ +// Copyright © 2023-2024 Apple Inc. + +#define MLX_MTL_CONST static constant constexpr const +#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)") + +using namespace metal; + +// Based on GPU merge sort algorithm at +// https://github.com/NVIDIA/cccl/tree/main/cub/cub + +/////////////////////////////////////////////////////////////////////////////// +// Thread-level sort +/////////////////////////////////////////////////////////////////////////////// + +template +METAL_FUNC void thread_swap(thread T& a, thread T& b) { + T w = a; + a = b; + b = w; +} + +template +struct LessThan { + static constexpr constant T init = Limits::max; + + METAL_FUNC bool operator()(T a, T b) { + return a < b; + } +}; + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short N_PER_THREAD, + typename CompareOp> +struct ThreadSort { + static METAL_FUNC void sort( + thread val_t (&vals)[N_PER_THREAD], + thread idx_t (&idxs)[N_PER_THREAD]) { + CompareOp op; + + MLX_MTL_LOOP_UNROLL + for (short i = 0; i < N_PER_THREAD; ++i) { + MLX_MTL_LOOP_UNROLL + for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) { + if (op(vals[j + 1], vals[j])) { + thread_swap(vals[j + 1], vals[j]); + thread_swap(idxs[j + 1], idxs[j]); + } + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Threadgroup-level sort +/////////////////////////////////////////////////////////////////////////////// + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp> +struct BlockMergeSort { + using thread_sort_t = + ThreadSort; + static METAL_FUNC int merge_partition( + const threadgroup val_t* As, + const threadgroup val_t* Bs, + short A_sz, + short B_sz, + short sort_md) { + CompareOp op; + + short A_st = max(0, sort_md - B_sz); + short A_ed = min(sort_md, A_sz); + + while (A_st < A_ed) { + short md = A_st + (A_ed - A_st) / 2; + auto a = As[md]; + auto b = Bs[sort_md - 1 - md]; + + if (op(b, a)) { + A_ed = md; + } else { + A_st = md + 1; + } + } + + return A_ed; + } + + static METAL_FUNC void merge_step( + const threadgroup val_t* As, + const threadgroup val_t* Bs, + const threadgroup idx_t* As_idx, + const threadgroup idx_t* Bs_idx, + short A_sz, + short B_sz, + thread 
val_t (&vals)[N_PER_THREAD], + thread idx_t (&idxs)[N_PER_THREAD]) { + CompareOp op; + short a_idx = 0; + short b_idx = 0; + + for (int i = 0; i < N_PER_THREAD; ++i) { + auto a = As[a_idx]; + auto b = Bs[b_idx]; + bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a)); + + vals[i] = pred ? b : a; + idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx]; + + b_idx += short(pred); + a_idx += short(!pred); + } + } + + static METAL_FUNC void sort( + threadgroup val_t* tgp_vals [[threadgroup(0)]], + threadgroup idx_t* tgp_idxs [[threadgroup(1)]], + int size_sorted_axis, + uint3 lid [[thread_position_in_threadgroup]]) { + // Get thread location + int idx = lid.x * N_PER_THREAD; + + // Load from shared memory + thread val_t thread_vals[N_PER_THREAD]; + thread idx_t thread_idxs[N_PER_THREAD]; + for (int i = 0; i < N_PER_THREAD; ++i) { + thread_vals[i] = tgp_vals[idx + i]; + if (ARG_SORT) { + thread_idxs[i] = tgp_idxs[idx + i]; + } + } + + // Per thread sort + if (idx < size_sorted_axis) { + thread_sort_t::sort(thread_vals, thread_idxs); + } + + // Do merges using threadgroup memory + for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; + merge_threads *= 2) { + // Update threadgroup memory + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_PER_THREAD; ++i) { + tgp_vals[idx + i] = thread_vals[i]; + if (ARG_SORT) { + tgp_idxs[idx + i] = thread_idxs[i]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Find location in merge step + int merge_group = lid.x / merge_threads; + int merge_lane = lid.x % merge_threads; + + int sort_sz = N_PER_THREAD * merge_threads; + int sort_st = N_PER_THREAD * merge_threads * merge_group; + + // As = tgp_vals[A_st:A_ed] is sorted + // Bs = tgp_vals[B_st:B_ed] is sorted + int A_st = sort_st; + int A_ed = sort_st + sort_sz / 2; + int B_st = sort_st + sort_sz / 2; + int B_ed = sort_st + sort_sz; + + const threadgroup val_t* As = tgp_vals + A_st; + const threadgroup val_t* Bs = tgp_vals + B_st; + int A_sz = A_ed - A_st; + int B_sz = B_ed - B_st; + + // Find a partition of merge elements + // Ci = merge(As[partition:], Bs[sort_md - partition:]) + // of size N_PER_THREAD for each merge lane i + // C = [Ci] is sorted + int sort_md = N_PER_THREAD * merge_lane; + int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md); + + As += partition; + Bs += sort_md - partition; + + A_sz -= partition; + B_sz -= sort_md - partition; + + const threadgroup idx_t* As_idx = + ARG_SORT ? tgp_idxs + A_st + partition : nullptr; + const threadgroup idx_t* Bs_idx = + ARG_SORT ? 
tgp_idxs + B_st + sort_md - partition : nullptr; + + // Merge starting at the partition and store results in thread registers + merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs); + } + + // Write out to shared memory + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_PER_THREAD; ++i) { + tgp_vals[idx + i] = thread_vals[i]; + if (ARG_SORT) { + tgp_idxs[idx + i] = thread_idxs[i]; + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Kernel sort +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename U, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp = LessThan> +struct KernelMergeSort { + using val_t = T; + using idx_t = uint; + using block_merge_sort_t = BlockMergeSort< + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD, + CompareOp>; + + MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; + + static METAL_FUNC void block_sort( + const device T* inp, + device U* out, + const constant int& size_sorted_axis, + const constant int& in_stride_sorted_axis, + const constant int& out_stride_sorted_axis, + const constant int& in_stride_segment_axis, + const constant int& out_stride_segment_axis, + threadgroup val_t* tgp_vals, + threadgroup idx_t* tgp_idxs, + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // tid.y tells us the segment index + inp += tid.y * in_stride_segment_axis; + out += tid.y * out_stride_segment_axis; + + // Copy into threadgroup memory + for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { + tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis] + : val_t(CompareOp::init); + if (ARG_SORT) { + tgp_idxs[i] = i; + } + } + + // Sort elements within the block + threadgroup_barrier(mem_flags::mem_threadgroup); + + block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write output + for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) { + if (ARG_SORT) { + out[i * out_stride_sorted_axis] = tgp_idxs[i]; + } else { + out[i * out_stride_sorted_axis] = tgp_vals[i]; + } + } + } +}; + +template < + typename T, + typename U, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort( + const device T* inp [[buffer(0)]], + device U* out [[buffer(1)]], + const constant int& size_sorted_axis [[buffer(2)]], + const constant int& in_stride_sorted_axis [[buffer(3)]], + const constant int& out_stride_sorted_axis [[buffer(4)]], + const constant int& in_stride_segment_axis [[buffer(5)]], + const constant int& out_stride_segment_axis [[buffer(6)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using sort_kernel = + KernelMergeSort; + using val_t = typename sort_kernel::val_t; + using idx_t = typename sort_kernel::idx_t; + + if (ARG_SORT) { + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, + tgp_vals, + tgp_idxs, + tid, + lid); + } else { + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + in_stride_sorted_axis, + 
out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, + tgp_vals, + nullptr, + tid, + lid); + } +} + +constant constexpr const int zero_helper = 0; + +template < + typename T, + typename U, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc( + const device T* inp [[buffer(0)]], + device U* out [[buffer(1)]], + const constant int& size_sorted_axis [[buffer(2)]], + const constant int& in_stride_sorted_axis [[buffer(3)]], + const constant int& out_stride_sorted_axis [[buffer(4)]], + const constant int& nc_dim [[buffer(5)]], + const constant int* nc_shape [[buffer(6)]], + const constant size_t* in_nc_strides [[buffer(7)]], + const constant size_t* out_nc_strides [[buffer(8)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using sort_kernel = + KernelMergeSort; + using val_t = typename sort_kernel::val_t; + using idx_t = typename sort_kernel::idx_t; + + auto in_block_idx = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim); + auto out_block_idx = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim); + inp += in_block_idx; + out += out_block_idx; + + if (ARG_SORT) { + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + zero_helper, + zero_helper, + tgp_vals, + tgp_idxs, + tid, + lid); + } else { + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + zero_helper, + zero_helper, + tgp_vals, + nullptr, + tid, + lid); + } +} + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp = LessThan> +struct KernelMultiBlockMergeSort { + using block_merge_sort_t = BlockMergeSort< + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD, + CompareOp>; + + MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; + + static METAL_FUNC void block_sort( + const device val_t* inp, + device val_t* out_vals, + device idx_t* out_idxs, + const constant int& size_sorted_axis, + const constant int& stride_sorted_axis, + threadgroup val_t* tgp_vals, + threadgroup idx_t* tgp_idxs, + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // tid.y tells us the segment index + int base_idx = tid.x * N_PER_BLOCK; + + // Copy into threadgroup memory + for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { + int idx = base_idx + i; + tgp_vals[i] = idx < size_sorted_axis ? 
inp[idx * stride_sorted_axis] + : val_t(CompareOp::init); + tgp_idxs[i] = idx; + } + + // Sort elements within the block + threadgroup_barrier(mem_flags::mem_threadgroup); + + block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write output + for (int i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { + int idx = base_idx + i; + if (idx < size_sorted_axis) { + out_vals[idx] = tgp_vals[i]; + out_idxs[idx] = tgp_idxs[i]; + } + } + } + + static METAL_FUNC int merge_partition( + const device val_t* As, + const device val_t* Bs, + int A_sz, + int B_sz, + int sort_md) { + CompareOp op; + + int A_st = max(0, sort_md - B_sz); + int A_ed = min(sort_md, A_sz); + + while (A_st < A_ed) { + int md = A_st + (A_ed - A_st) / 2; + auto a = As[md]; + auto b = Bs[sort_md - 1 - md]; + + if (op(b, a)) { + A_ed = md; + } else { + A_st = md + 1; + } + } + + return A_ed; + } +}; + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort( + const device val_t* inp [[buffer(0)]], + device val_t* out_vals [[buffer(1)]], + device idx_t* out_idxs [[buffer(2)]], + const constant int& size_sorted_axis [[buffer(3)]], + const constant int& stride_sorted_axis [[buffer(4)]], + const constant int& nc_dim [[buffer(5)]], + const constant int* nc_shape [[buffer(6)]], + const constant size_t* nc_strides [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using sort_kernel = KernelMultiBlockMergeSort< + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD>; + + auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim); + inp += block_idx; + out_vals += tid.y * size_sorted_axis; + out_idxs += tid.y * size_sorted_axis; + + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; + + sort_kernel::block_sort( + inp, + out_vals, + out_idxs, + size_sorted_axis, + stride_sorted_axis, + tgp_vals, + tgp_idxs, + tid, + lid); +} + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel]] void mb_block_partition( + device idx_t* block_partitions [[buffer(0)]], + const device val_t* dev_vals [[buffer(1)]], + const device idx_t* dev_idxs [[buffer(2)]], + const constant int& size_sorted_axis [[buffer(3)]], + const constant int& merge_tiles [[buffer(4)]], + const constant int& n_blocks [[buffer(5)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 tgp_dims [[threads_per_threadgroup]]) { + using sort_kernel = KernelMultiBlockMergeSort< + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD>; + + block_partitions += tid.y * tgp_dims.x; + dev_vals += tid.y * size_sorted_axis; + dev_idxs += tid.y * size_sorted_axis; + + for (int i = lid.x; i <= n_blocks; i += tgp_dims.x) { + // Find location in merge step + int merge_group = i / merge_tiles; + int merge_lane = i % merge_tiles; + + int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles; + int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group; + + int A_st = min(size_sorted_axis, sort_st); + int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2); + int B_st = A_ed; + int B_ed = min(size_sorted_axis, B_st + sort_sz / 2); + + int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane); + int partition = 
sort_kernel::merge_partition( + dev_vals + A_st, + dev_vals + B_st, + A_ed - A_st, + B_ed - B_st, + partition_at); + + block_partitions[i] = A_st + partition; + } +} + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp = LessThan> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void +mb_block_merge( + const device idx_t* block_partitions [[buffer(0)]], + const device val_t* dev_vals_in [[buffer(1)]], + const device idx_t* dev_idxs_in [[buffer(2)]], + device val_t* dev_vals_out [[buffer(3)]], + device idx_t* dev_idxs_out [[buffer(4)]], + const constant int& size_sorted_axis [[buffer(5)]], + const constant int& merge_tiles [[buffer(6)]], + const constant int& num_tiles [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using sort_kernel = KernelMultiBlockMergeSort< + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD, + CompareOp>; + + using block_sort_t = typename sort_kernel::block_merge_sort_t; + + block_partitions += tid.y * (num_tiles + 1); + dev_vals_in += tid.y * size_sorted_axis; + dev_idxs_in += tid.y * size_sorted_axis; + dev_vals_out += tid.y * size_sorted_axis; + dev_idxs_out += tid.y * size_sorted_axis; + + int block_idx = tid.x; + int merge_group = block_idx / merge_tiles; + int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group; + int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles; + int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st; + + int A_st = block_partitions[block_idx + 0]; + int A_ed = block_partitions[block_idx + 1]; + int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st); + int B_ed = min( + size_sorted_axis, + 2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed); + + if ((block_idx % merge_tiles) == merge_tiles - 1) { + A_ed = min(size_sorted_axis, sort_st + sort_sz / 2); + B_ed = min(size_sorted_axis, sort_st + sort_sz); + } + + int A_sz = A_ed - A_st; + int B_sz = B_ed - B_st; + + // Load from global memory + thread val_t thread_vals[N_PER_THREAD]; + thread idx_t thread_idxs[N_PER_THREAD]; + for (int i = 0; i < N_PER_THREAD; i++) { + int idx = BLOCK_THREADS * i + lid.x; + if (idx < (A_sz + B_sz)) { + thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx] + : dev_vals_in[B_st + idx - A_sz]; + thread_idxs[i] = (idx < A_sz) ? 
dev_idxs_in[A_st + idx] + : dev_idxs_in[B_st + idx - A_sz]; + } else { + thread_vals[i] = CompareOp::init; + thread_idxs[i] = 0; + } + } + + // Write to shared memory + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_PER_THREAD; i++) { + int idx = BLOCK_THREADS * i + lid.x; + tgp_vals[idx] = thread_vals[i]; + tgp_idxs[idx] = thread_idxs[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Merge + int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x)); + + int A_st_local = block_sort_t::merge_partition( + tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local); + int A_ed_local = A_sz; + + int B_st_local = sort_md_local - A_st_local; + int B_ed_local = B_sz; + + int A_sz_local = A_ed_local - A_st_local; + int B_sz_local = B_ed_local - B_st_local; + + // Do merge + block_sort_t::merge_step( + tgp_vals + A_st_local, + tgp_vals + A_ed_local + B_st_local, + tgp_idxs + A_st_local, + tgp_idxs + A_ed_local + B_st_local, + A_sz_local, + B_sz_local, + thread_vals, + thread_idxs); + + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_PER_THREAD; ++i) { + int idx = lid.x * N_PER_THREAD; + tgp_vals[idx + i] = thread_vals[i]; + tgp_idxs[idx + i] = thread_idxs[i]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + // Write output + int base_idx = tid.x * sort_kernel::N_PER_BLOCK; + for (int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) { + int idx = base_idx + i; + if (idx < size_sorted_axis) { + dev_vals_out[idx] = tgp_vals[i]; + dev_idxs_out[idx] = tgp_idxs[i]; + } + } +} diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/conv.h b/Source/Cmlx/mlx-generated/metal/steel/conv/conv.h new file mode 100644 index 00000000..0845f521 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/conv/conv.h @@ -0,0 +1,13 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "../../steel/defines.h" +#include "../../steel/utils.h" + +#include "../../steel/conv/loader.h" +#include "../../steel/conv/params.h" +#include "../../steel/gemm/mma.h" + +using namespace metal; +using namespace mlx::steel; diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv.h b/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv.h new file mode 100644 index 00000000..6f822c1d --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv.h @@ -0,0 +1,176 @@ +// Copyright © 2024 Apple Inc. + +#include + +using namespace metal; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + int N_CHANNELS = 0, + bool SMALL_FILTER = false> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +implicit_gemm_conv_2d( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* C [[buffer(2)]], + const constant MLXConvParams<2>* params [[buffer(3)]], + const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using namespace mlx::steel; + + (void)lid; + + constexpr bool transpose_a = false; + constexpr bool transpose_b = true; + constexpr short tgp_padding_a = 16 / sizeof(T); + constexpr short tgp_padding_b = 16 / sizeof(T); + + constexpr short shape_a_cols = (transpose_a ? 
BM : BK) + tgp_padding_a; + constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b; + constexpr short shape_a_rows = (transpose_a ? BK : BM); + constexpr short shape_b_rows = (transpose_b ? BN : BK); + constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows; + constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows; + + constexpr short tgp_size = WM * WN * 32; + + // Input loader + + using loader_a_t = typename metal::conditional_t< + // Check for small channel specialization + N_CHANNELS != 0 && N_CHANNELS <= 4, + + // Go to small channel specialization + Conv2DInputBlockLoaderSmallChannels< + T, + BM, + BN, + BK, + tgp_size, + N_CHANNELS, + tgp_padding_a>, + + // Else go to general loader + typename metal::conditional_t< + // Check if filter size is small enough + SMALL_FILTER, + + // Go to small filter specialization + Conv2DInputBlockLoaderSmallFilter< + T, + BM, + BN, + BK, + tgp_size, + tgp_padding_a>, + + // Else go to large filter generalization + Conv2DInputBlockLoaderLargeFilter< + T, + BM, + BN, + BK, + tgp_size, + tgp_padding_a>>>; + + // Weight loader + using loader_b_t = typename metal::conditional_t< + // Check for small channel specialization + N_CHANNELS != 0 && N_CHANNELS <= 4, + + // Go to small channel specialization + Conv2DWeightBlockLoaderSmallChannels< + T, + BM, + BN, + BK, + tgp_size, + N_CHANNELS, + tgp_padding_b>, + + // Else go to general loader + Conv2DWeightBlockLoader>; + + using mma_t = BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + shape_a_cols, + shape_b_cols>; + + threadgroup T As[tgp_mem_size_a]; + threadgroup T Bs[tgp_mem_size_b]; + + const int tid_y = ((tid.y) << gemm_params->swizzle_log) + + ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> gemm_params->swizzle_log; + + if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) { + return; + } + + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const int K = gemm_params->K; + const int N = gemm_params->N; + const int C_per_group = params->C / params->groups; + + // Groups + A += tid.z * C_per_group; + B += tid.z * N * K; + C += tid.z * N; + + B += c_col * K; + C += c_row * (N * params->groups) + c_col; + + const int2 offsets_a(0, c_row); + const int2 offsets_b(0, c_col); + + // Prepare threadgroup loading operations + loader_a_t loader_a( + A, As, offsets_a, params, gemm_params, simd_gid, simd_lid); + loader_b_t loader_b( + B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid); + + // Prepare threadgroup mma operation + mma_t mma_op(simd_gid, simd_lid); + + int gemm_k_iterations = gemm_params->gemm_k_iterations; + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Store results to device memory + short tgp_bm = min(BM, gemm_params->M - c_row); + short tgp_bn = min(BN, gemm_params->N - c_col); + const int ldc = N * params->groups; + mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm)); +} diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h b/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h new file mode 100644 index 00000000..d52b2654 --- 
/dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h @@ -0,0 +1,188 @@ +// Copyright © 2024 Apple Inc. + +#include "../../../steel/conv/loaders/loader_general.h" + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + typename AccumType = float, + typename Epilogue = TransformNone> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +implicit_gemm_conv_2d_general( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* C [[buffer(2)]], + const constant MLXConvParams<2>* params [[buffer(3)]], + const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], + const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], + const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], + const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr bool transpose_a = false; + constexpr bool transpose_b = true; + constexpr short tgp_padding_a = 16 / sizeof(T); + constexpr short tgp_padding_b = 16 / sizeof(T); + + constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a; + constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b; + constexpr short shape_a_rows = (transpose_a ? BK : BM); + constexpr short shape_b_rows = (transpose_b ? BN : BK); + constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows; + constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows; + + constexpr short tgp_size = WM * WN * 32; + + // Input loader + using loader_a_t = + Conv2DInputBlockLoaderGeneral; + + // Weight loader + using loader_b_t = + Conv2DWeightBlockLoaderGeneral; + + using mma_t = BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + shape_a_cols, + shape_b_cols>; + + threadgroup T As[tgp_mem_size_a]; + threadgroup T Bs[tgp_mem_size_b]; + + const int tid_y = ((tid.y) << gemm_params->swizzle_log) + + ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> gemm_params->swizzle_log; + + if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) { + return; + } + + const int tid_z = tid.z; + + const int base_oh = tid_z / jump_params->f_out_jump_w; + const int base_ow = tid_z % jump_params->f_out_jump_w; + + const int base_wh = base_h[base_oh].weight_base; + const int base_ww = base_w[base_ow].weight_base; + + const int base_wh_size = base_h[base_oh].weight_size; + const int base_ww_size = base_w[base_ow].weight_size; + + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const int K = gemm_params->K; + + B += c_col * K; + + const int4 offsets_a(0, c_row, base_oh, base_ow); + const int2 offsets_b(0, c_col); + + // Prepare threadgroup loading operations + loader_a_t loader_a( + A, + As, + offsets_a, + params, + jump_params, + base_wh, + base_ww, + simd_gid, + simd_lid); + loader_b_t loader_b( + B, + Bs, + offsets_b, + params, + jump_params, + base_wh, + base_ww, + simd_gid, + simd_lid); + + // Prepare threadgroup mma operation + mma_t mma_op(simd_gid, simd_lid); + + int gemm_k_iterations = + base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + 
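+    // A second barrier follows so every simdgroup sees the fully written As
+    // and Bs tiles before the MMA below starts reading them; the barrier at
+    // the top of the loop plays the matching role for the next iteration,
+    // keeping the loads from overwriting tiles that are still being read.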
threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Store results to device memory + { + // Adjust for simdgroup and thread locatio + int offset_m = c_row + mma_op.sm + mma_op.tm; + int offset_n = c_col + mma_op.sn + mma_op.tn; + C += offset_n; + + if (offset_n >= gemm_params->N) + return; + + short diff = gemm_params->N - offset_n; + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < mma_t::TM; i++) { + int cm = offset_m + i * mma_t::TM_stride; + + int n = cm / jump_params->adj_out_hw; + int hw = cm % jump_params->adj_out_hw; + int oh = + (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh; + int ow = + (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow; + + if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) { + int offset_cm = n * params->out_strides[0] + + oh * params->out_strides[1] + ow * params->out_strides[2]; + + STEEL_PRAGMA_UNROLL + for (int j = 0; j < mma_t::TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = + mma_op.results[i * mma_t::TN + j].thread_elements(); + int offset = offset_cm + (j * mma_t::TN_stride); + + // Apply epilogue and output C + if (j * mma_t::TN_stride < diff) { + C[offset] = Epilogue::apply(accum[0]); + } + + if (j * mma_t::TN_stride + 1 < diff) { + C[offset + 1] = Epilogue::apply(accum[1]); + } + } + } + } + } +} diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/loader.h b/Source/Cmlx/mlx-generated/metal/steel/conv/loader.h new file mode 100644 index 00000000..bb9b3926 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/conv/loader.h @@ -0,0 +1,6 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "../../steel/conv/loaders/loader_channel_l.h" +#include "../../steel/conv/loaders/loader_channel_n.h" \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_l.h b/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_l.h new file mode 100644 index 00000000..85a6d134 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_l.h @@ -0,0 +1,449 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "../../../steel/utils.h" + +#include "../../../steel/conv/params.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv2DInputBlockLoaderLargeFilter { + // Destination dimensions + STEEL_CONST short BROWS = BM; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 
8 : 4; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + + const constant MLXConvParams<2>* params; + const constant ImplicitGemmConv2DParams* gemm_params; + + short weight_h; + short weight_w; + + const device T* src[n_rows]; + + int read_n[n_rows]; + int read_ih[n_rows]; + int read_iw[n_rows]; + + /* Constructor */ + METAL_FUNC Conv2DInputBlockLoaderLargeFilter( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant ImplicitGemmConv2DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + params(params_), + gemm_params(gemm_params_), + weight_h(0), + weight_w(0) { + int out_n_pixels = params->oS[0] * params->oS[1]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int offset_nhw = offsets.y + bi + i * TROWS; + int n = offset_nhw / out_n_pixels; + int hw = offset_nhw % out_n_pixels; + int oh = hw / params->oS[1]; + int ow = hw % params->oS[1]; + + int ih = oh * params->str[0] - params->pad[0]; + int iw = ow * params->str[1] - params->pad[1]; + + read_n[i] = n; + read_ih[i] = ih; + read_iw[i] = iw; + + // Adjust for flip + if (params->flip) { + ih += (params->wS[0] - 1) * params->kdil[0]; + iw += (params->wS[1] - 1) * params->kdil[1]; + } + + // Read from input if in bounds + src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] + + iw * params->in_strides[2] + bj; + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Find bounds + int n = read_n[i]; + int ih = read_ih[i] + weight_h * params->kdil[0]; + int iw = read_iw[i] + weight_w * params->kdil[1]; + + // Read from input if in bounds + if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) && + (iw >= 0 && iw < params->iS[1])) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = src[i][j]; + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + if (++weight_w < params->wS[1]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_w; + } + + return; + } + + weight_w = 0; + + if (++weight_h < params->wS[0]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_h; + } + + return; + } + + weight_h = 0; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_c; + } + } +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv2DInputBlockLoaderSmallFilter { + // Destination dimensions + STEEL_CONST short BROWS = BM; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) 
>= 8 ? 8 : 4; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + using mask_t = short; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + + const constant MLXConvParams<2>* params; + const constant ImplicitGemmConv2DParams* gemm_params; + + short weight_h; + short weight_w; + + const device T* src[n_rows]; + + mask_t mask_h[n_rows]; + mask_t mask_w[n_rows]; + + /* Constructor */ + METAL_FUNC Conv2DInputBlockLoaderSmallFilter( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant ImplicitGemmConv2DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + params(params_), + gemm_params(gemm_params_), + weight_h(0), + weight_w(0) { + int out_n_pixels = params->oS[0] * params->oS[1]; + + int read_n[n_rows]; + int read_ih[n_rows]; + int read_iw[n_rows]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int offset_nhw = offsets.y + bi + i * TROWS; + int n = offset_nhw / out_n_pixels; + int hw = offset_nhw % out_n_pixels; + int oh = hw / params->oS[1]; + int ow = hw % params->oS[1]; + + int ih = oh * params->str[0] - params->pad[0]; + int iw = ow * params->str[1] - params->pad[1]; + + read_n[i] = n; + read_ih[i] = ih; + read_iw[i] = iw; + + // Adjust for flip + if (params->flip) { + ih += (params->wS[0] - 1) * params->kdil[0]; + iw += (params->wS[1] - 1) * params->kdil[1]; + } + + // Read from input if in bounds + src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] + + iw * params->in_strides[2] + bj; + } + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + mask_h[i] = 0; + mask_w[i] = 0; + } + + for (short kh = 0; kh < params->wS[0]; kh++) { + short flip_h = params->flip ? params->wS[0] - kh - 1 : kh; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int n = read_n[i]; + int ih = read_ih[i] + flip_h * params->kdil[0]; + + bool in_bounds = n < params->N && ih >= 0 && ih < params->iS[0]; + + mask_h[i] |= (in_bounds << kh); + } + } + + for (short kw = 0; kw < params->wS[1]; kw++) { + short flip_w = params->flip ? 
params->wS[1] - kw - 1 : kw; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int iw = read_iw[i] + flip_w * params->kdil[1]; + + bool in_bounds = iw >= 0 && iw < params->iS[1]; + + mask_w[i] |= (in_bounds << kw); + } + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + mask_t h_mask = mask_t(1) << weight_h; + mask_t w_mask = mask_t(1) << weight_w; + + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Read from input if in bounds + if ((mask_h[i] & h_mask) && (mask_w[i] & w_mask)) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = src[i][j]; + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + if (++weight_w < params->wS[1]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_w; + } + + return; + } + + weight_w = 0; + + if (++weight_h < params->wS[0]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_h; + } + + return; + } + + weight_h = 0; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_c; + } + } +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv2DWeightBlockLoader { + // Destination dimensions + STEEL_CONST short BROWS = BN; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = + (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4); + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Leading dimension for src + const int src_ld; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + const constant MLXConvParams<2>* params; + + int weight_hw; + + const int read_n; + const bool do_read; + + /* Constructor */ + METAL_FUNC Conv2DWeightBlockLoader( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant ImplicitGemmConv2DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(params_->wt_strides[0]), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj), + params(params_), + weight_hw(0), + read_n(offsets.y + bi), + do_read(read_n + n_rows * TROWS <= gemm_params_->N) {} + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + if (BN != 8 || do_read) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BN; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = src[i * src_ld + j]; + } + } + } else { + for (short i = 0; i < BN; i += TROWS) { + if ((read_n + i) < params->O) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = src[i * src_ld + j]; + } + } else { + STEEL_PRAGMA_UNROLL 
+ for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + if (++weight_hw < (params->wS[1] * params->wS[0])) { + src += params->wt_strides[2]; + return; + } + + weight_hw = 0; + + src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2]; + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_n.h b/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_n.h new file mode 100644 index 00000000..2f12535f --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_n.h @@ -0,0 +1,319 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "../../../steel/utils.h" + +#include "../../../steel/conv/params.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct ChannelHelper { + STEEL_CONST short n_channels = n_channels_; + STEEL_CONST short vec_size = n_channels_ <= 4 ? 4 : 8; + STEEL_CONST short excess = vec_size - n_channels_; +}; + +template <> +struct ChannelHelper<1> { + STEEL_CONST short n_channels = 1; + STEEL_CONST short vec_size = 1; + STEEL_CONST short excess = 0; +}; + +template <> +struct ChannelHelper<2> { + STEEL_CONST short n_channels = 2; + STEEL_CONST short vec_size = 2; + STEEL_CONST short excess = 0; +}; + +template <> +struct ChannelHelper<3> { + STEEL_CONST short n_channels = 3; + STEEL_CONST short vec_size = 4; + STEEL_CONST short excess = 1; +}; + +template <> +struct ChannelHelper<4> { + STEEL_CONST short n_channels = 4; + STEEL_CONST short vec_size = 4; + STEEL_CONST short excess = 0; +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short n_channels, + short tgp_padding = 0> +struct Conv2DInputBlockLoaderSmallChannels { + // Destination dimensions + STEEL_CONST short BROWS = BM; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = ChannelHelper::vec_size; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + + const constant MLXConvParams<2>* params; + const constant ImplicitGemmConv2DParams* gemm_params; + + short weight_hw; + + const device T* src[n_rows]; + + int read_n[n_rows]; + int read_ih[n_rows]; + int read_iw[n_rows]; + + /* Constructor */ + METAL_FUNC Conv2DInputBlockLoaderSmallChannels( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant ImplicitGemmConv2DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + params(params_), + gemm_params(gemm_params_), + weight_hw(thread_idx % TCOLS) { + int out_n_pixels = params->oS[0] * params->oS[1]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int 
offset_nhw = offsets.y + bi + i * TROWS; + int n = offset_nhw / out_n_pixels; + int hw = offset_nhw % out_n_pixels; + int oh = hw / params->oS[1]; + int ow = hw % params->oS[1]; + + int ih = oh * params->str[0] - params->pad[0]; + int iw = ow * params->str[1] - params->pad[1]; + + // Read from input if in bounds + src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] + + iw * params->in_strides[2]; + + read_n[i] = n; + read_ih[i] = ih; + read_iw[i] = iw; + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + if (weight_hw >= params->wS[1] * params->wS[0]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + return; + } + + int wh = (weight_hw / params->wS[1]); + int ww = (weight_hw % params->wS[1]); + + int flip_h = params->flip ? params->wS[0] - wh - 1 : wh; + int flip_w = params->flip ? params->wS[1] - ww - 1 : ww; + + int weight_h = flip_h * params->kdil[0]; + int weight_w = flip_w * params->kdil[1]; + + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Find bounds + int n = read_n[i]; + int ih = read_ih[i] + weight_h; + int iw = read_iw[i] + weight_w; + + // Read from input if in bounds + if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) && + (iw >= 0 && iw < params->iS[1])) { + const device T* curr_src = src[i] + weight_h * params->in_strides[1] + + weight_w * params->in_strides[2]; + + STEEL_PRAGMA_UNROLL + for (short j = 0; j < n_channels; ++j) { + dst[is * dst_ld + j] = curr_src[j]; + } + + STEEL_PRAGMA_UNROLL + for (short j = n_channels; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + weight_hw += TCOLS; + } +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short n_channels, + short tgp_padding = 0> +struct Conv2DWeightBlockLoaderSmallChannels { + // Destination dimensions + STEEL_CONST short BROWS = BN; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = ChannelHelper::vec_size; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Leading dimension for src + const int src_ld; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + const constant MLXConvParams<2>* params; + + int weight_hw; + + const int read_n; + const bool do_read; + + /* Constructor */ + METAL_FUNC Conv2DWeightBlockLoaderSmallChannels( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant ImplicitGemmConv2DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(params_->wt_strides[0]), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld), + params(params_), + weight_hw(thread_idx % 
TCOLS), + read_n(offsets.y + bi), + do_read(read_n + BN <= gemm_params_->N) {} + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + if (bi >= BROWS || bj >= BCOLS) + return; + + if (read_n >= params->O || weight_hw >= params->wS[1] * params->wS[0]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + + return; + } + + const device T* curr_src = src + weight_hw * params->wt_strides[2]; + + if (BN != 8 || do_read) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < n_channels; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + + STEEL_PRAGMA_UNROLL + for (short j = n_channels; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } else { + for (short i = 0; i < BROWS; i += TROWS) { + if (((read_n + i) < params->O)) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < n_channels; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + + STEEL_PRAGMA_UNROLL + for (short j = n_channels; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + weight_hw += TCOLS; + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_general.h b/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_general.h new file mode 100644 index 00000000..3f5be762 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_general.h @@ -0,0 +1,286 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "../../../steel/defines.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv2DInputBlockLoaderGeneral { + // Destination dimensions + STEEL_CONST short BROWS = BM; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 
8 : 4; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + + const constant MLXConvParams<2>* params; + const constant Conv2DGeneralJumpParams* jump_params; + + const short base_wh; + const short base_ww; + + short weight_h; + short weight_w; + + const device T* src[n_rows]; + + int read_n[n_rows]; + int read_ih[n_rows]; + int read_iw[n_rows]; + + /* Constructor */ + METAL_FUNC Conv2DInputBlockLoaderGeneral( + const device T* src_, + threadgroup T* dst_, + const int4 offsets, + const constant MLXConvParams<2>* params_, + const constant Conv2DGeneralJumpParams* jump_params_, + const short base_wh_, + const short base_ww_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + params(params_), + jump_params(jump_params_), + base_wh(base_wh_), + base_ww(base_ww_), + weight_h(base_wh_), + weight_w(base_ww_) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int offset_nhw = offsets.y + bi + i * TROWS; + int n = offset_nhw / jump_params->adj_out_hw; + int hw = offset_nhw % jump_params->adj_out_hw; + int oh = + (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + offsets.z; + int ow = + (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + offsets.w; + + int ih = oh * params->str[0] - params->pad[0]; + int iw = ow * params->str[1] - params->pad[1]; + + read_n[i] = n; + read_ih[i] = ih; + read_iw[i] = iw; + + // Read from input if in bounds + src[i] = src_ + n * params->in_strides[0] + bj; + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Find bounds + int n = read_n[i]; + + int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h; + int w_flip = params->flip ? 
params->wS[1] - weight_w - 1 : weight_w; + + int ih_dil = read_ih[i] + h_flip * params->kdil[0]; + int iw_dil = read_iw[i] + w_flip * params->kdil[1]; + + int ih = ih_dil / params->idil[0]; + int iw = iw_dil / params->idil[1]; + + size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2]; + + // Read from input if in bounds + if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) && + (iw_dil >= 0 && iw < params->iS[1])) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = (src[i])[offset + j]; + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + weight_w += jump_params->f_wgt_jump_w; + if (weight_w < params->wS[1]) { + return; + } + + weight_w = base_ww; + + weight_h += jump_params->f_wgt_jump_h; + if (weight_h < params->wS[0]) { + return; + } + + weight_h = base_wh; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += BK; + } + } +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv2DWeightBlockLoaderGeneral { + // Destination dimensions + STEEL_CONST short BROWS = BN; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = + (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4); + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Leading dimension for src + const int src_ld; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + const constant MLXConvParams<2>* params; + const constant Conv2DGeneralJumpParams* jump_params; + + const short base_wh; + const short base_ww; + + short weight_h; + short weight_w; + + const int start_row; + + /* Constructor */ + METAL_FUNC Conv2DWeightBlockLoaderGeneral( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant Conv2DGeneralJumpParams* jump_params_, + const short base_wh_, + const short base_ww_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(params_->wt_strides[0]), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj), + params(params_), + jump_params(jump_params_), + base_wh(base_wh_), + base_ww(base_ww_), + weight_h(base_wh_), + weight_w(base_ww_), + start_row(offsets.y + bi) {} + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + const device T* curr_src = src + weight_h * params->wt_strides[1] + + weight_w * params->wt_strides[2]; + + if ((start_row + BN <= params->O)) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BN; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } + } else { + for (short i = 0; i < BN; i += TROWS) { + if ((start_row + i) < params->O) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld 
+ j]; + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + weight_w += jump_params->f_wgt_jump_w; + if (weight_w < params->wS[1]) { + return; + } + + weight_w = base_ww; + + weight_h += jump_params->f_wgt_jump_h; + if (weight_h < params->wS[0]) { + return; + } + + weight_h = base_wh; + + src += BK; + } +}; + +} // namespace steel +} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/params.h b/Source/Cmlx/mlx-generated/metal/steel/conv/params.h new file mode 100644 index 00000000..f75851dc --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/conv/params.h @@ -0,0 +1,62 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +template +struct MLXConvParams { + const int N; // Batch size + const int C; // In channels + const int O; // Out channels + const int iS[NDIM]; // Input spatial dim + const int wS[NDIM]; // Weight spatial dim + const int oS[NDIM]; // Output spatial dim + const int str[NDIM]; // Kernel strides + const int pad[NDIM]; // Input padding + const int kdil[NDIM]; // Kernel dilation + const int idil[NDIM]; // Input dilation + const size_t in_strides[NDIM + 2]; // In strides + const size_t wt_strides[NDIM + 2]; // Wt strides + const size_t out_strides[NDIM + 2]; // Out strides + const int groups; // Input channel groups + const bool flip; +}; + +namespace mlx { +namespace steel { + +struct ImplicitGemmConv2DParams { + const int M; + const int N; + const int K; + + const int gemm_k_iterations; + + const int inp_jump_w; + const int inp_jump_h; + const int inp_jump_c; + + const int tiles_n; + const int tiles_m; + const int swizzle_log; +}; + +struct Conv2DGeneralJumpParams { + const int f_wgt_jump_h; + const int f_wgt_jump_w; + + const int f_out_jump_h; + const int f_out_jump_w; + + const int adj_out_h; + const int adj_out_w; + const int adj_out_hw; + const int adj_implicit_m; +}; + +struct Conv2DGeneralBaseInfo { + int weight_base; + int weight_size; +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/steel/defines.h b/Source/Cmlx/mlx-generated/metal/steel/defines.h new file mode 100644 index 00000000..6c3bfcf4 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/defines.h @@ -0,0 +1,4 @@ +// Copyright © 2024 Apple Inc. + +#define STEEL_CONST static constant constexpr const +#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/gemm.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/gemm.h new file mode 100644 index 00000000..697a8b56 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/gemm/gemm.h @@ -0,0 +1,295 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#include "../../steel/gemm/loader.h" +#include "../../steel/gemm/mma.h" +#include "../../steel/gemm/params.h" +#include "../../steel/gemm/transforms.h" +#include "../../steel/utils.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernel class +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct LoopAlignment {}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct GEMMKernel { + STEEL_CONST short tgp_padding_a = 16 / sizeof(T); + STEEL_CONST short tgp_padding_b = 16 / sizeof(T); + STEEL_CONST short tgp_mem_size_a = + transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); + STEEL_CONST short tgp_mem_size_b = + transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); + STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; + + STEEL_CONST short tgp_size = WM * WN * 32; + + using loader_a_t = BlockLoader< + T, + transpose_a ? BK : BM, + transpose_a ? BM : BK, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + !transpose_a, + tgp_size>; + using loader_b_t = BlockLoader< + T, + transpose_b ? BN : BK, + transpose_b ? BK : BN, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + transpose_b, + tgp_size>; + using mma_t = BlockMMA< + T, + U, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + AccumType, + Epilogue>; + + /* Main kernel function */ + template + static METAL_FUNC void gemm_loop( + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + const int gemm_k_iterations, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + thread mma_t& mma_op, + thread const short& tgp_bm, + thread const short& tgp_bn, + thread const short& lbk, + LoopAlignment l = {}) { + // Appease the compiler + (void)l; + + short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + + short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + if (M_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(tile_dims_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + if (!K_aligned_) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + short2 tile_dims_A_last = + transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); + short2 tile_dims_B_last = + transpose_b ? 
short2(lbk, tgp_bn) : short2(tgp_bn, lbk); + + loader_a.load_safe(tile_dims_A_last); + loader_b.load_safe(tile_dims_B_last); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + } + + /* Main kernel function */ + static METAL_FUNC void run( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device U* D [[buffer(2)]], + const constant GEMMParams* params [[buffer(3)]], + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Pacifying compiler + (void)lid; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Loop tail + if (!K_aligned) { + int lbk = params->K - params->gemm_k_iterations_aligned * BK; + short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); + short2 tile_dims_B = transpose_b ? 
short2(lbk, BN) : short2(BN, lbk); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + + // Store results to device memory + mma_op.store_result(D, params->ldd); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; + + if (tgp_bm == BM && tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result(D, params->ldd); + return; + + } else if (tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + + } else if (tgp_bm == BM) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + + } else { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + } + } + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused.h new file mode 100644 index 00000000..5e1d2f23 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused.h @@ -0,0 +1,415 @@ +// Copyright © 2024 Apple Inc. 
+ +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +constant bool has_batch [[function_constant(10)]]; + +constant bool use_out_source [[function_constant(100)]]; +constant bool do_axpby [[function_constant(110)]]; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +constant bool do_gather [[function_constant(300)]]; + +constant bool gather_bias = do_gather && use_out_source; + +// clang-format off +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device T* C [[buffer(2), function_constant(use_out_source)]], + device T* D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], + const constant int* batch_shape [[buffer(6)]], + const constant size_t* batch_strides [[buffer(7)]], + const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], + const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], + const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], + const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], + const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]], + const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on + // Pacifying compiler + (void)lid; + + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + // Find block + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + // Exit early if out of bounds + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + // Adjust for batch + + // Handle gather + if (do_gather) { + // Read indices + uint32_t indx_A, indx_B, indx_C; + + if (has_batch) { + const constant size_t* indx_A_bstrides = batch_strides; + const constant size_t* indx_B_bstrides = + batch_strides + params->batch_ndim; + + ulong2 indx_offsets = elem_to_loc_broadcast( + tid.z, + batch_shape, + indx_A_bstrides, + indx_B_bstrides, + params->batch_ndim); + indx_A = lhs_indices[indx_offsets.x]; + indx_B = rhs_indices[indx_offsets.y]; + + if (use_out_source) { + const constant size_t* indx_C_bstrides = + indx_B_bstrides + params->batch_ndim; + auto indx_offset_C = elem_to_loc( + tid.z, batch_shape, indx_C_bstrides, params->batch_ndim); + indx_C = C_indices[indx_offset_C]; + } + } else { + indx_A = lhs_indices[params->batch_stride_a * tid.z]; + indx_B = rhs_indices[params->batch_stride_b * tid.z]; + + if (use_out_source) { + 
indx_C = C_indices[addmm_params->batch_stride_c * tid.z]; + } + } + + // Translate indices to offsets + int batch_ndim_A = operand_batch_ndim.x; + const constant int* batch_shape_A = operand_shape; + const constant size_t* batch_strides_A = operand_strides; + A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A); + + int batch_ndim_B = operand_batch_ndim.y; + const constant int* batch_shape_B = batch_shape_A + batch_ndim_A; + const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A; + B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B); + + if (use_out_source) { + int batch_ndim_C = operand_batch_ndim.z; + const constant int* batch_shape_C = batch_shape_B + batch_ndim_B; + const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B; + C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C); + } + + } + + // Handle regular batch + else { + if (has_batch) { + const constant size_t* A_bstrides = batch_strides; + const constant size_t* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + + if (use_out_source) { + const constant size_t* C_bstrides = B_bstrides + params->batch_ndim; + C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim); + } + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + + if (use_out_source) { + C += addmm_params->batch_stride_c * tid.z; + } + } + } + + D += params->batch_stride_d * tid.z; + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + if (use_out_source) { + C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc; + } + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + // Prepare iterations + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + // Do unaligned K iterations first + if (!align_K) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + // Move loader source ahead to end + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + + const TransformAdd epilogue_op_add( + addmm_params->alpha, addmm_params->beta); + const TransformAxpby epilogue_op_axpby( + addmm_params->alpha, addmm_params->beta); + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (align_M && align_N) { + // Do gemm + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby); + } else { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result(D, params->ldd); + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + const int leftover_bk = 0; + + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + // Do gemm + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby); + } else { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result(D, params->ldd); + + } else if (align_N || tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_axpby); + } else { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + + } else if (align_M || tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_axpby); + } else { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + + } else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + 
LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_axpby); + } else { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_masked.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_masked.h new file mode 100644 index 00000000..46e91bb2 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_masked.h @@ -0,0 +1,719 @@ +// Copyright © 2024 Apple Inc. + +#include "../../../steel/defines.h" +using namespace metal; +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +struct _NoMask { + char x; + + constexpr METAL_FUNC operator bool() { + return true; + } + constexpr METAL_FUNC operator bool() const threadgroup { + return true; + } + constexpr METAL_FUNC operator bool() const device { + return true; + } + constexpr METAL_FUNC operator bool() const constant { + return true; + } +}; + +template +struct ScaleOp { + OutT scale; + + METAL_FUNC OutT apply(InT x) const { + return static_cast(x) * scale; + } +}; + +typedef struct _NoMask nomask_t; + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +block_masked_gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant int* batch_shape [[buffer(6)]], + const constant size_t* batch_strides [[buffer(7)]], + const device out_mask_t* out_mask [[buffer(10)]], + const device op_mask_t* lhs_mask [[buffer(11)]], + const device op_mask_t* rhs_mask [[buffer(12)]], + const constant int* mask_strides [[buffer(13)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Appease the compiler + (void)lid; + + static_assert( + BM == BN, + "block_masked_gemm must have the same block M and block N size"); + static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0"); + + constexpr bool has_operand_mask = !metal::is_same_v; + constexpr bool has_output_mask = !metal::is_same_v; + + constexpr bool has_mul_operand_mask = + has_operand_mask && !metal::is_same_v; + constexpr bool has_mul_output_mask = + has_output_mask && !metal::is_same_v; + + constexpr short k_mask_factor = short(BM / BK); + + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + MN_aligned, + K_aligned>; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + const constant size_t* mask_batch_strides = + batch_strides + 2 * params->batch_ndim; + + if (params->batch_ndim > 1) { + if 
(has_output_mask) { + out_mask += elem_to_loc( + tid.z, batch_shape, mask_batch_strides, params->batch_ndim); + + mask_batch_strides += params->batch_ndim; + } + + if (has_operand_mask) { + const constant size_t* mask_strides_lhs = mask_batch_strides; + const constant size_t* mask_strides_rhs = + mask_strides_lhs + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, + batch_shape, + mask_strides_lhs, + mask_strides_rhs, + params->batch_ndim); + + lhs_mask += batch_offsets.x; + rhs_mask += batch_offsets.y; + } + } else { + if (has_output_mask) { + out_mask += tid.z * mask_batch_strides[0]; + mask_batch_strides += params->batch_ndim; + } + + if (has_operand_mask) { + lhs_mask += tid.z * mask_batch_strides[0]; + rhs_mask += tid.z * mask_batch_strides[params->batch_ndim]; + } + } + + // Adjust for batch + if (params->batch_ndim > 1) { + const constant size_t* A_bstrides = batch_strides; + const constant size_t* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + } + + D += params->batch_stride_d * tid.z; + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + const constant int* out_mask_strides = mask_strides; + const constant int* lhs_mask_strides = + mask_strides + (has_output_mask ? 2 : 0); + const constant int* rhs_mask_strides = + lhs_mask_strides + (has_operand_mask ? 2 : 0); + + const int out_mask_offset = !has_output_mask + ? 0 + : tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0]; + int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1]; + int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0]; + const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0]; + const int rhs_mask_step = !has_operand_mask ? 0 : rhs_mask_strides[1]; + short k_factor_cnt = k_mask_factor; + + ScaleOp out_mask_op; + ScaleOp lhs_mask_op; + ScaleOp rhs_mask_op; + + if (has_output_mask) { + auto mask_out = out_mask[out_mask_offset]; + + if (has_mul_output_mask) { + out_mask_op.scale = float(mask_out); + } + + // Write zeros and return + if (!mask_out) { + constexpr short tgp_size = WM * WN * 32; + constexpr short vec_size = 4; + + // Tile threads in threadgroup + constexpr short TN = BN / vec_size; + constexpr short TM = tgp_size / TN; + + const short thread_idx = simd_group_id * 32 + simd_lane_id; + const short bi = thread_idx / TN; + const short bj = vec_size * (thread_idx % TN); + + D += bi * params->ldd + bj; + + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + + if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + for (short ti = 0; ti < BM; ti += TM) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + D[ti * params->ldd + j] = T(0.); + } + } + } else { + short jmax = tgp_bn - bj; + jmax = jmax < vec_size ? 
jmax : vec_size; + for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) { + for (short j = 0; j < jmax; j++) { + D[ti * params->ldd + j] = T(0.); + } + } + } + + return; + } + } + + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id); + + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Prepare threadgroup loading operations + thread typename gemm_kernel::loader_a_t loader_a( + A, params->lda, As, simd_group_id, simd_lane_id); + thread typename gemm_kernel::loader_b_t loader_b( + B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup bounds + const short tgp_bm = + MN_aligned ? short(BM) : short(min(BM, params->M - c_row)); + const short tgp_bn = + MN_aligned ? short(BN) : short(min(BN, params->N - c_col)); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + /////////////////////////////////////////////////////////////////////////////// + // Do unaligned K iterations first + if (!K_aligned) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int mask_idx_last = k_last / BM; + + if (!has_operand_mask || + (bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) && + bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) { + if (has_mul_operand_mask) { + lhs_mask_op.scale = + lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]; + rhs_mask_op.scale = + rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]; + } + + // Move loader source ahead to end + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + if (has_mul_operand_mask) { + loader_a.apply_inplace_op(lhs_mask_op); + loader_b.apply_inplace_op(rhs_mask_op); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + } + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (; gemm_k_iterations > 0; gemm_k_iterations--) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (!has_operand_mask || + (bool(lhs_mask[lhs_mask_offset]) && + bool(rhs_mask[rhs_mask_offset]))) { + if (has_mul_operand_mask) { + lhs_mask_op.scale = lhs_mask[lhs_mask_offset]; + rhs_mask_op.scale = rhs_mask[rhs_mask_offset]; + } + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + if (has_mul_operand_mask) { + loader_a.apply_inplace_op(lhs_mask_op); + loader_b.apply_inplace_op(rhs_mask_op); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + } + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + + k_factor_cnt--; + lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0; + rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0; + k_factor_cnt = k_factor_cnt == 0 ? 
k_mask_factor : k_factor_cnt; + } + + if (has_mul_output_mask) { + mma_op.apply_epilogue(out_mask_op); + } + + // Store results to device memory + mma_op.store_result(D, params->ldd); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { + const bool M_aligned = (tgp_bm == BM); + const bool N_aligned = (tgp_bn == BN); + + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + for (; gemm_k_iterations > 0; gemm_k_iterations--) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (!has_operand_mask || + (bool(lhs_mask[lhs_mask_offset]) && + bool(rhs_mask[rhs_mask_offset]))) { + if (has_mul_operand_mask) { + lhs_mask_op.scale = lhs_mask[lhs_mask_offset]; + rhs_mask_op.scale = rhs_mask[rhs_mask_offset]; + } + + // Load elements into threadgroup + if (M_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(tile_dims_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + if (has_mul_operand_mask) { + loader_a.apply_inplace_op(lhs_mask_op); + loader_b.apply_inplace_op(rhs_mask_op); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + } + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + + k_factor_cnt--; + lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0; + rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0; + k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt; + } + + if (has_mul_output_mask) { + mma_op.apply_epilogue(out_mask_op); + } + + if (M_aligned && N_aligned) { + mma_op.store_result(D, params->ldd); + } else { + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned, + bool has_operand_mask = false> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +block_masked_gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant int* batch_shape [[buffer(6)]], + const constant size_t* batch_strides [[buffer(7)]], + const device bool* out_mask [[buffer(10)]], + const device bool* lhs_mask [[buffer(11)]], + const device bool* rhs_mask [[buffer(12)]], + const constant int* mask_strides [[buffer(13)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Appease the compiler + (void)lid; + + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + MN_aligned, + K_aligned>; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + if (params->batch_ndim > 1) { + const constant size_t* mask_batch_strides = + batch_strides + 2 * params->batch_ndim; + out_mask += + elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim); + + if (has_operand_mask) { + const constant size_t* mask_strides_lhs = + mask_batch_strides + 
params->batch_ndim; + const constant size_t* mask_strides_rhs = + mask_strides_lhs + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, + batch_shape, + mask_strides_lhs, + mask_strides_rhs, + params->batch_ndim); + + lhs_mask += batch_offsets.x; + rhs_mask += batch_offsets.y; + } + } else { + out_mask += tid.z * batch_strides[2 * params->batch_ndim]; + if (has_operand_mask) { + lhs_mask += tid.z * batch_strides[3 * params->batch_ndim]; + rhs_mask += tid.z * batch_strides[4 * params->batch_ndim]; + } + } + + // Adjust for batch + if (params->batch_ndim > 1) { + const constant size_t* A_bstrides = batch_strides; + const constant size_t* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + } + + D += params->batch_stride_d * tid.z; + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]]; + + // Write zeros and return + if (!mask_out) { + constexpr short tgp_size = WM * WN * 32; + constexpr short vec_size = 4; + + // Tile threads in threadgroup + constexpr short TN = BN / vec_size; + constexpr short TM = tgp_size / TN; + + const short thread_idx = simd_group_id * 32 + simd_lane_id; + const short bi = thread_idx / TN; + const short bj = vec_size * (thread_idx % TN); + + D += bi * params->ldd + bj; + + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + + if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + for (short ti = 0; ti < BM; ti += TM) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + D[ti * params->ldd + j] = T(0.); + } + } + } else { + short jmax = tgp_bn - bj; + jmax = jmax < vec_size ? 
jmax : vec_size; + for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) { + for (short j = 0; j < jmax; j++) { + D[ti * params->ldd + j] = T(0.); + } + } + } + + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Prepare threadgroup loading operations + thread typename gemm_kernel::loader_a_t loader_a( + A, params->lda, As, simd_group_id, simd_lane_id); + thread typename gemm_kernel::loader_b_t loader_b( + B, params->ldb, Bs, simd_group_id, simd_lane_id); + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (!has_operand_mask || + (lhs_mask + [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] && + rhs_mask + [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + } + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Loop tail + if (!K_aligned) { + if (!has_operand_mask || + (lhs_mask + [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && + rhs_mask + [(params->K / BM) * mask_strides[5] + + tid_x * mask_strides[4]])) { + int lbk = params->K - params->gemm_k_iterations_aligned * BK; + short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); + short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + } + + // Store results to device memory + mma_op.store_result(D, params->ldd); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short lbk = params->K - params->gemm_k_iterations_aligned * BK; + + bool M_aligned = (tgp_bm == BM); + bool N_aligned = (tgp_bn == BN); + + short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + short2 tile_dims_B = transpose_b ? 
short2(BK, tgp_bn) : short2(tgp_bn, BK); + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (!has_operand_mask || + (lhs_mask + [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] && + rhs_mask + [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { + // Load elements into threadgroup + if (M_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(tile_dims_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + } + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + if (!K_aligned) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (!has_operand_mask || + (lhs_mask + [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && + rhs_mask + [(params->K / BM) * mask_strides[5] + + tid_x * mask_strides[4]])) { + short2 tile_dims_A_last = + transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); + short2 tile_dims_B_last = + transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk); + + loader_a.load_safe(tile_dims_A_last); + loader_b.load_safe(tile_dims_B_last); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + } + + if (M_aligned && N_aligned) { + mma_op.store_result(D, params->ldd); + } else { + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_splitk.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_splitk.h new file mode 100644 index 00000000..1ff97ea4 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_splitk.h @@ -0,0 +1,227 @@ +// Copyright © 2024 Apple Inc. 
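+//
+// Split-K GEMM: the K dimension is partitioned across the grid's z axis.
+// Each threadgroup computes a partial product over its slice starting at
+// k_start = split_k_partition_size * tid.z and writes it into its own
+// partition of C (offset by split_k_partition_stride * tid.z). The
+// gemm_splitk_accum kernels below then reduce the partitions into D,
+// roughly (illustrative sketch of the accumulation loop, per output element):
+//
+//   AccT out = 0;
+//   for (int i = 0; i < k_partitions; i++)
+//     out += C_split[i * partition_stride];
+//   D[0] = Epilogue::apply(out);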
+ +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device U* C [[buffer(2)]], + const constant GEMMSpiltKParams* params [[buffer(3)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + (void)lid; + + using gemm_kernel = GEMMKernel< + T, + U, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + MN_aligned, + K_aligned>; + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + const int tid_x = tid.x; + const int tid_y = tid.y; + const int tid_z = tid.z; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const int k_start = params->split_k_partition_size * tid_z; + + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + const size_t k_start_long = size_t(k_start); + + A += transpose_a ? (c_row_long + k_start_long * params->lda) + : (k_start_long + c_row_long * params->lda); + B += transpose_b ? 
(k_start_long + c_col_long * params->ldb) + : (c_col_long + k_start_long * params->ldb); + C += (size_t(params->split_k_partition_stride) * tid_z) + + (c_row_long * params->ldc + c_col_long); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K % BK; + + if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } else if (tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } else if (tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if ((tid_z + 1) == (params->split_k_partitions)) { + int gemm_k_iter_remaining = + (params->K - (k_start + params->split_k_partition_size)) / BK; + if (!K_aligned || gemm_k_iter_remaining > 0) + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iter_remaining, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } + + if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + mma_op.store_result(C, params->ldc); + } else { + mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Split k accumulation kernel +/////////////////////////////////////////////////////////////////////////////// + +template < + typename AccT, + typename OutT, + typename Epilogue = TransformNone> +[[kernel]] void gemm_splitk_accum( + const device AccT* C_split [[buffer(0)]], + device OutT* D [[buffer(1)]], + const constant int& k_partitions [[buffer(2)]], + const constant int& partition_stride [[buffer(3)]], + const constant int& ldd [[buffer(4)]], + uint2 gid [[thread_position_in_grid]]) { + // Ajust D and C + D += gid.x + gid.y * size_t(ldd); + C_split += gid.x + gid.y * size_t(ldd); + + size_t offset = 0; + AccT out = 0; + + for (int i = 0; i < k_partitions; i++) { + out += C_split[offset]; + offset += partition_stride; + } + + // Write output + D[0] = Epilogue::apply(out); +} + +template < + typename AccT, + typename OutT, + typename Epilogue = TransformAxpby> +[[kernel]] void gemm_splitk_accum_axpby( + const device AccT* C_split [[buffer(0)]], + device OutT* D [[buffer(1)]], + const constant int& k_partitions [[buffer(2)]], + const constant int& partition_stride [[buffer(3)]], + const constant int& ldd [[buffer(4)]], + const device OutT* C [[buffer(5)]], + const constant int& ldc [[buffer(6)]], + const constant int& fdc [[buffer(7)]], + const constant float& alpha [[buffer(8)]], + const constant float& beta [[buffer(9)]], + uint2 gid [[thread_position_in_grid]]) { + // Ajust D and C + C += gid.x * size_t(fdc) + gid.y * size_t(ldc); + D += gid.x + 
gid.y * size_t(ldd); + C_split += gid.x + gid.y * size_t(ldd); + + size_t offset = 0; + AccT out = 0; + + for (int i = 0; i < k_partitions; i++) { + out += C_split[offset]; + offset += partition_stride; + } + + // Write output + Epilogue op(alpha, beta); + D[0] = op.apply(out, *C); +} diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/loader.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/loader.h new file mode 100644 index 00000000..1846e26d --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/gemm/loader.h @@ -0,0 +1,137 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "../../steel/defines.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short alignment = 1, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoader { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; + + // Leading dimension for src + const int src_ld; + const int tile_stride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + struct alignas(alignment * sizeof(T)) ReadVector { + uint8_t v[sizeof(T) * vec_size]; + }; + + /* Constructor */ + METAL_FUNC BlockLoader( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride(reduction_dim ? 
BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj) {} + + /* Apply operation to threadgroup without bound checking */ + template + METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]); + } + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + *((threadgroup ReadVector*)(&dst[i * dst_ld])) = + *((const device ReadVector*)(&src[i * src_ld])); + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); + + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + return; + } + + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } + + // Zero out uneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } +}; + +} // namespace steel +} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/mma.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/mma.h new file mode 100644 index 00000000..948bbc61 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/gemm/mma.h @@ -0,0 +1,361 @@ +// Copyright © 2024 Apple Inc. 
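+//
+// BlockMMA: each simdgroup accumulates a TM x TN grid of 8x8
+// simdgroup_matrix tiles of the (BM, BN) output block. mma() steps through
+// the threadgroup tiles As/Bs along K in chunks of 8, loading A and B
+// fragments and issuing simdgroup_multiply_accumulate; the store_result and
+// apply_epilogue variants transform the accumulators and write them out.
+// Typical use from the steel GEMM kernels (illustrative, not a complete
+// kernel):
+//
+//   thread mma_t mma_op(simd_group_id, simd_lane_id);
+//   for (int k = 0; k < gemm_k_iterations; k++) {
+//     /* loaders fill As and Bs, threadgroup barrier */
+//     mma_op.mma(As, Bs);
+//     /* loaders advance to the next K tile */
+//   }
+//   mma_op.store_result(D, params->ldd);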
+ +#pragma once + +#include +#include +#include + +#include "../../steel/defines.h" +#include "../../steel/gemm/transforms.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + short lda_tgp, + short ldb_tgp, + typename AccumType = float, + typename Epilogue = TransformNone> +struct BlockMMA { + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TM_stride = 8 * WM; + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TN_stride = 8 * WN; + + // Warp tile size along M + STEEL_CONST short TM = BM / TM_stride; + // Warp tile size along N + STEEL_CONST short TN = BN / TN_stride; + + // Strides of A, B along reduction axis + STEEL_CONST short simd_stride_a = { + transpose_a ? TM_stride : TM_stride * lda_tgp}; + STEEL_CONST short simd_stride_b = { + transpose_b ? TN_stride * ldb_tgp : TN_stride}; + + // Jump between elements + STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1}; + STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1}; + + STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8}; + STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp}; + + // Simdgroup matrices + simdgroup_matrix Asimd[TM]; + simdgroup_matrix Bsimd[TN]; + simdgroup_matrix results[TM * TN] = { + simdgroup_matrix(0)}; + + // Offsets within threadgroup + const short tm; + const short tn; + + short sm; + short sn; + + short As_offset; + short Bs_offset; + + /* Constructor */ + METAL_FUNC BlockMMA( + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { + // Determine thread position in simdgroup matrix + short qid = simd_lane_id / 4; + sm = (qid & 4) + (simd_lane_id / 2) % 4; + sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + + // Determine thread and simdgroup offset + As_offset = + transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp); + Bs_offset = + transpose_b ? 
((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn)); + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { + // Adjust for simdgroup and thread location + As += As_offset; + Bs += Bs_offset; + + // Iterate over BK in blocks of 8 + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < BK; kk += 8) { + simdgroup_barrier(mem_flags::mem_none); + + // Load elements from threadgroup A as simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + Asimd[i].thread_elements()[0] = + static_cast(As[i * simd_stride_a + 0]); + Asimd[i].thread_elements()[1] = + static_cast(As[i * simd_stride_a + jump_a]); + } + + simdgroup_barrier(mem_flags::mem_none); + + // Load elements from threadgroup B as simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + Bsimd[j].thread_elements()[0] = + static_cast(Bs[j * simd_stride_b + 0]); + Bsimd[j].thread_elements()[1] = + static_cast(Bs[j * simd_stride_b + jump_b]); + } + + simdgroup_barrier(mem_flags::mem_none); + + // Multiply and accumulate into result simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + short j_serp = (i % 2) ? (TN - 1 - j) : j; + + simdgroup_multiply_accumulate( + results[i * TN + j_serp], + Asimd[i], + Bsimd[j_serp], + results[i * TN + j_serp]); + } + } + + // Progress to next simdgroup tile + As += tile_stride_a; + Bs += tile_stride_b; + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device U* D, const int ldd) const { + // Adjust for simdgroup and thread location + D += (sm + tm) * ldd + tn + sn; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])}; + + // Write out D + D[offset] = outs[0]; + D[offset + 1] = outs[1]; + } + } + } + + METAL_FUNC void + store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) const { + // Adjust for simdgroup and thread location + D += (sm + tm) * ldd + (tn + sn); + dst_tile_dims -= short2(tn + sn, sm + tm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue and output C + if (j * TN_stride < dst_tile_dims.x) { + D[offset] = Epilogue::apply(accum[0]); + } + + if (j * TN_stride + 1 < dst_tile_dims.x) { + D[offset + 1] = Epilogue::apply(accum[1]); + } + } + } + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = results[i * TN + j].thread_elements(); + + // Apply epilogue + accum[0] = epilogue_op.apply(accum[0]); + accum[1] = 
epilogue_op.apply(accum[1]); + } + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue( + const device U* C, + const int ldc, + const int fdc, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + // Apply epilogue + accum[0] = epilogue_op.apply(accum[0], C[offset_c]); + accum[1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); + } + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue_safe( + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + dst_tile_dims -= short2(tn + sn, sm + tm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + // Read C + U c_elems[2] = {0}; + + if ((j * TN_stride + 1) < dst_tile_dims.x) { + c_elems[0] = C[offset_c]; + c_elems[1] = C[offset_c + fdc]; + } else if ((j * TN_stride) < dst_tile_dims.x) { + c_elems[0] = C[offset_c]; + } + + // Apply epilogue + accum[0] = epilogue_op.apply(accum[0], c_elems[0]); + accum[1] = epilogue_op.apply(accum[1], c_elems[1]); + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + D += (sm + tm) * ldd + tn + sn; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + U outs[2] = { + epilogue_op.apply(accum[0], C[offset_c]), + epilogue_op.apply(accum[1], C[offset_c + fdc])}; + + // Write out D + D[offset_d] = outs[0]; + D[offset_d + 1] = outs[1]; + } + } + } + + METAL_FUNC void store_result_safe( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + D += (sm + tm) * ldd + tn + sn; + dst_tile_dims -= short2(tn + sn, sm + tm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + 
j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue and output C + if (j * TN_stride < dst_tile_dims.x) { + D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]); + } + + if (j * TN_stride + 1 < dst_tile_dims.x) { + D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); + } + } + } + } + } +}; + +} // namespace steel +} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/params.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/params.h new file mode 100644 index 00000000..e8bcb221 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/gemm/params.h @@ -0,0 +1,64 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +/////////////////////////////////////////////////////////////////////////////// +// GEMM param classes +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +struct GEMMParams { + const int M; + const int N; + const int K; + + const int lda; + const int ldb; + const int ldd; + + const int tiles_n; + const int tiles_m; + + const size_t batch_stride_a; + const size_t batch_stride_b; + const size_t batch_stride_d; + + const int swizzle_log; + const int gemm_k_iterations_aligned; + + const int batch_ndim; +}; + +struct GEMMSpiltKParams { + const int M; + const int N; + const int K; + + const int lda; + const int ldb; + const int ldc; + + const int tiles_n; + const int tiles_m; + + const int split_k_partitions; + const int split_k_partition_stride; + const int split_k_partition_size; + + const int gemm_k_iterations_aligned; +}; + +struct GEMMAddMMParams { + const int ldc; + const int fdc; + + const size_t batch_stride_c; + + const float alpha; + const float beta; +}; + +} // namespace steel +} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/transforms.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/transforms.h new file mode 100644 index 00000000..3d8ca054 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/gemm/transforms.h @@ -0,0 +1,71 @@ +// Copyright © 2024 Apple Inc. 
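+//
+// Output epilogues used as the Epilogue parameter of BlockMMA and the GEMM
+// kernels: TransformNone casts the accumulator to the output type,
+// TransformAdd adds the C operand, and TransformAxpby computes
+// alpha * accum + beta * c (see GEMMAddMMParams and the split-k axpby
+// accumulation). For example, TransformAxpby with alpha = 1 and beta = 0
+// reduces to a plain cast of the accumulator.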
+ +#pragma once + +#include "../../steel/utils.h" + +/////////////////////////////////////////////////////////////////////////////// +// Transforms and Epilogues +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct TransformNone { + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT) { + return static_cast(x); + } +}; + +template +struct TransformAdd { + TransformAdd(const float, const float) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT c) { + return static_cast(x) + c; + } +}; + +template +struct TransformAxpby { + const float alpha; + const float beta; + + TransformAxpby(const float alpha_, const float beta_) + : alpha(alpha_), beta(beta_) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + METAL_FUNC OutT apply(InT x, OutT c) const { + return static_cast(x * alpha + (beta * c)); + } +}; + +template +struct AccumHelper { + typedef float accum_type; +}; + +struct BlockSwizzle { + static METAL_FUNC int2 + swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { + const int tid_x = (tid.x) >> swizzle_log; + const int tid_y = + ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); + return int2(tid_x, tid_y); + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/steel/utils.h b/Source/Cmlx/mlx-generated/metal/steel/utils.h new file mode 100644 index 00000000..322b2250 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/utils.h @@ -0,0 +1,42 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include + +METAL_FUNC ulong2 elem_to_loc_broadcast( + uint elem, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + int ndim) { + ulong loc_a{0}; + ulong loc_b{0}; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + int pos_in_dim = (elem % shape[i]); + elem /= shape[i]; + loc_a += pos_in_dim * a_strides[i]; + loc_b += pos_in_dim * b_strides[i]; + } + return ulong2(loc_a, loc_b); +} + +METAL_FUNC ulong3 elem_to_loc_broadcast( + uint elem, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + constant const size_t* c_strides, + int ndim) { + ulong loc_a{0}; + ulong loc_b{0}; + ulong loc_c{0}; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + int pos_in_dim = (elem % shape[i]); + elem /= shape[i]; + loc_a += pos_in_dim * a_strides[i]; + loc_b += pos_in_dim * b_strides[i]; + loc_c += pos_in_dim * c_strides[i]; + } + return ulong3(loc_a, loc_b, loc_c); +} diff --git a/Source/Cmlx/mlx-generated/metal/ternary.h b/Source/Cmlx/mlx-generated/metal/ternary.h new file mode 100644 index 00000000..2bd1242c --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/ternary.h @@ -0,0 +1,110 @@ +// Copyright © 2024 Apple Inc. 
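+//
+// Element-wise ternary kernels; Op is a functor such as Select from
+// ternary_ops.h. ternary_v covers contiguous inputs with one element per
+// thread, ternary_v2 uses a 2D grid with size_t offsets for larger arrays,
+// ternary_g_nd1/2/3 cover strided inputs of fixed rank, and ternary_g covers
+// arbitrary rank, with each thread handling N elements along the innermost
+// axis. (Which variant runs is chosen on the host side.)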
+ +template +[[kernel]] void ternary_v( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + uint index [[thread_position_in_grid]]) { + d[index] = Op()(a[index], b[index], c[index]); +} + +template +[[kernel]] void ternary_v2( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + size_t offset = index.x + grid_dim.x * size_t(index.y); + d[offset] = Op()(a[offset], b[offset], c[offset]); +} + +template +[[kernel]] void ternary_g_nd1( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant const size_t& a_strides, + constant const size_t& b_strides, + constant const size_t& c_strides, + uint index [[thread_position_in_grid]]) { + auto a_idx = elem_to_loc_1(index, a_strides); + auto b_idx = elem_to_loc_1(index, b_strides); + auto c_idx = elem_to_loc_1(index, c_strides); + d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]); +} + +template +[[kernel]] void ternary_g_nd2( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant const size_t a_strides[2], + constant const size_t b_strides[2], + constant const size_t c_strides[2], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + auto c_idx = elem_to_loc_2(index, c_strides); + size_t out_idx = index.x + size_t(grid_dim.x) * index.y; + d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); +} + +template +[[kernel]] void ternary_g_nd3( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant const size_t a_strides[3], + constant const size_t b_strides[3], + constant const size_t c_strides[3], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + auto c_idx = elem_to_loc_3(index, c_strides); + size_t out_idx = + index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z); + d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); +} + +template +[[kernel]] void ternary_g( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + constant const size_t* c_strides, + constant const int& ndim, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = elem_to_loc_3_nd( + {N * index.x, index.y, index.z}, + shape, + a_strides, + b_strides, + c_strides, + ndim); + auto xshape = shape[ndim - 1]; + size_t out_idx = + N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); + auto a_xstride = a_strides[ndim - 1]; + auto b_xstride = b_strides[ndim - 1]; + auto c_xstride = c_strides[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]); + idx.x += a_xstride; + idx.y += b_xstride; + idx.z += c_xstride; + } +} diff --git a/Source/Cmlx/mlx-generated/metal/ternary_ops.h b/Source/Cmlx/mlx-generated/metal/ternary_ops.h new file mode 100644 index 00000000..e0235d9d --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/ternary_ops.h @@ -0,0 +1,10 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +struct Select { + template + T operator()(bool condition, T x, T y) { + return condition ? 
x : y; + } +}; diff --git a/Source/Cmlx/mlx-generated/metal/unary.h b/Source/Cmlx/mlx-generated/metal/unary.h new file mode 100644 index 00000000..8d404ae2 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/unary.h @@ -0,0 +1,40 @@ +// Copyright © 2024 Apple Inc. + +template +[[kernel]] void unary_v( + device const T* in, + device T* out, + uint index [[thread_position_in_grid]]) { + out[index] = Op()(in[index]); +} + +template +[[kernel]] void unary_v2( + device const T* in, + device T* out, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + size_t offset = index.x + grid_dim.x * size_t(index.y); + out[offset] = Op()(in[offset]); +} + +template +[[kernel]] void unary_g( + device const T* in, + device T* out, + constant const int* in_shape, + constant const size_t* in_strides, + device const int& ndim, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = + elem_to_loc({N * index.x, index.y, index.z}, in_shape, in_strides, ndim); + auto xshape = in_shape[ndim - 1]; + auto xstride = in_strides[ndim - 1]; + size_t out_idx = + N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + out[out_idx++] = Op()(in[idx]); + idx += xstride; + } +} diff --git a/Source/Cmlx/mlx-generated/metal/unary_ops.h b/Source/Cmlx/mlx-generated/metal/unary_ops.h new file mode 100644 index 00000000..b7346405 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/unary_ops.h @@ -0,0 +1,400 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +#include "erf.h" +#include "expm1f.h" + +namespace { +constant float inf = metal::numeric_limits::infinity(); +} + +struct Abs { + template + T operator()(T x) { + return metal::abs(x); + }; + template <> + uint8_t operator()(uint8_t x) { + return x; + }; + template <> + uint16_t operator()(uint16_t x) { + return x; + }; + template <> + uint32_t operator()(uint32_t x) { + return x; + }; + template <> + uint64_t operator()(uint64_t x) { + return x; + }; + template <> + bool operator()(bool x) { + return x; + }; + template <> + complex64_t operator()(complex64_t x) { + return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0}; + }; +}; + +struct ArcCos { + template + T operator()(T x) { + return metal::precise::acos(x); + }; +}; + +struct ArcCosh { + template + T operator()(T x) { + return metal::precise::acosh(x); + }; +}; + +struct ArcSin { + template + T operator()(T x) { + return metal::precise::asin(x); + }; +}; + +struct ArcSinh { + template + T operator()(T x) { + return metal::precise::asinh(x); + }; +}; + +struct ArcTan { + template + T operator()(T x) { + return metal::precise::atan(x); + }; +}; + +struct ArcTanh { + template + T operator()(T x) { + return metal::precise::atanh(x); + }; +}; + +struct Ceil { + template + T operator()(T x) { + return metal::ceil(x); + }; + template <> + int8_t operator()(int8_t x) { + return x; + }; + template <> + int16_t operator()(int16_t x) { + return x; + }; + template <> + int32_t operator()(int32_t x) { + return x; + }; + template <> + int64_t operator()(int64_t x) { + return x; + }; + template <> + uint8_t operator()(uint8_t x) { + return x; + }; + template <> + uint16_t operator()(uint16_t x) { + return x; + }; + template <> + uint32_t operator()(uint32_t x) { + return x; + }; + template <> + uint64_t operator()(uint64_t x) { + return x; + }; + template <> + bool operator()(bool x) { + return x; + }; +}; + +struct Cos { + template + T 
operator()(T x) { + return metal::precise::cos(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::cos(x.real) * metal::precise::cosh(x.imag), + -metal::precise::sin(x.real) * metal::precise::sinh(x.imag)}; + }; +}; + +struct Cosh { + template + T operator()(T x) { + return metal::precise::cosh(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::cosh(x.real) * metal::precise::cos(x.imag), + metal::precise::sinh(x.real) * metal::precise::sin(x.imag)}; + }; +}; + +struct Conjugate { + complex64_t operator()(complex64_t x) { + return complex64_t{x.real, -x.imag}; + } +}; + +struct Erf { + template + T operator()(T x) { + return static_cast(erf(static_cast(x))); + }; +}; + +struct ErfInv { + template + T operator()(T x) { + return static_cast(erfinv(static_cast(x))); + }; +}; + +struct Exp { + template + T operator()(T x) { + return metal::precise::exp(x); + }; + template <> + complex64_t operator()(complex64_t x) { + auto m = metal::precise::exp(x.real); + return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)}; + } +}; + +struct Expm1 { + template + T operator()(T x) { + return static_cast(expm1f(static_cast(x))); + }; +}; + +struct Floor { + template + T operator()(T x) { + return metal::floor(x); + }; + template <> + int8_t operator()(int8_t x) { + return x; + }; + template <> + int16_t operator()(int16_t x) { + return x; + }; + template <> + int32_t operator()(int32_t x) { + return x; + }; + template <> + int64_t operator()(int64_t x) { + return x; + }; + template <> + uint8_t operator()(uint8_t x) { + return x; + }; + template <> + uint16_t operator()(uint16_t x) { + return x; + }; + template <> + uint32_t operator()(uint32_t x) { + return x; + }; + template <> + uint64_t operator()(uint64_t x) { + return x; + }; + template <> + bool operator()(bool x) { + return x; + }; +}; + +struct Log { + template + T operator()(T x) { + return metal::precise::log(x); + }; +}; + +struct Log2 { + template + T operator()(T x) { + return metal::precise::log2(x); + }; +}; + +struct Log10 { + template + T operator()(T x) { + return metal::precise::log10(x); + }; +}; + +struct Log1p { + template + T operator()(T x) { + return log1p(x); + }; +}; + +struct LogicalNot { + template + T operator()(T x) { + return !x; + }; +}; + +struct Negative { + template + T operator()(T x) { + return -x; + }; +}; + +struct Round { + template + T operator()(T x) { + return metal::rint(x); + }; + template <> + complex64_t operator()(complex64_t x) { + return {metal::rint(x.real), metal::rint(x.imag)}; + }; +}; + +struct Sigmoid { + template + T operator()(T x) { + auto y = 1 / (1 + metal::exp(-metal::abs(x))); + return (x < 0) ? 
1 - y : y; + } +}; + +struct Sign { + template + T operator()(T x) { + return (x > T(0)) - (x < T(0)); + }; + template <> + uint32_t operator()(uint32_t x) { + return x != 0; + }; + template <> + complex64_t operator()(complex64_t x) { + if (x == complex64_t(0)) { + return x; + } + return x / + (complex64_t)metal::precise::sqrt(x.real * x.real + x.imag * x.imag); + }; +}; + +struct Sin { + template + T operator()(T x) { + return metal::precise::sin(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::sin(x.real) * metal::precise::cosh(x.imag), + metal::precise::cos(x.real) * metal::precise::sinh(x.imag)}; + }; +}; + +struct Sinh { + template + T operator()(T x) { + return metal::precise::sinh(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::sinh(x.real) * metal::precise::cos(x.imag), + metal::precise::cosh(x.real) * metal::precise::sin(x.imag)}; + }; +}; + +struct Square { + template + T operator()(T x) { + return x * x; + }; +}; + +struct Sqrt { + template + T operator()(T x) { + return metal::precise::sqrt(x); + }; +}; + +struct Rsqrt { + template + T operator()(T x) { + return metal::precise::rsqrt(x); + }; +}; + +struct Tan { + template + T operator()(T x) { + return metal::precise::tan(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + float tan_a = metal::precise::tan(x.real); + float tanh_b = metal::precise::tanh(x.imag); + float t1 = tan_a * tanh_b; + float denom = 1. + t1 * t1; + return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom}; + }; +}; + +struct Tanh { + template + T operator()(T x) { + return metal::precise::tanh(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + float tanh_a = metal::precise::tanh(x.real); + float tan_b = metal::precise::tan(x.imag); + float t1 = tanh_a * tan_b; + float denom = 1. + t1 * t1; + return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom}; + }; +}; diff --git a/Source/Cmlx/mlx-generated/metal/utils.h b/Source/Cmlx/mlx-generated/metal/utils.h new file mode 100644 index 00000000..e94901c9 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/utils.h @@ -0,0 +1,322 @@ +// Copyright © 2023-2024 Apple Inc. 
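+//
+// Shared kernel helpers: Limits<T> exposes per-type max/min (infinities for
+// floating point types, numeric_limits for integers), and the elem_to_loc
+// family maps a row-major element index to a strided memory location by
+// repeatedly taking the index modulo the shape and scaling by the strides.
+// Illustrative example: shape = {3, 4}, strides = {1, 3} (a transposed
+// layout), elem = 5 -> (5 % 4) * 3 + (1 % 3) * 1 = 4.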
+ +#pragma once + +#include +#include "bf16.h" +#include "complex.h" +#include "defines.h" + +typedef half float16_t; + +/////////////////////////////////////////////////////////////////////////////// +// Type limits utils +/////////////////////////////////////////////////////////////////////////////// + +template +struct Limits { + static const constant U max = metal::numeric_limits::max(); + static const constant U min = metal::numeric_limits::min(); + static const constant U finite_max = metal::numeric_limits::max(); + static const constant U finite_min = metal::numeric_limits::min(); +}; + +#define instantiate_default_limit(type) \ + template <> \ + struct Limits { \ + static constexpr constant type max = metal::numeric_limits::max(); \ + static constexpr constant type min = metal::numeric_limits::min(); \ + static constexpr constant type finite_max = \ + metal::numeric_limits::max(); \ + static constexpr constant type finite_min = \ + metal::numeric_limits::min(); \ + }; + +instantiate_default_limit(uint8_t); +instantiate_default_limit(uint16_t); +instantiate_default_limit(uint32_t); +instantiate_default_limit(uint64_t); +instantiate_default_limit(int8_t); +instantiate_default_limit(int16_t); +instantiate_default_limit(int32_t); +instantiate_default_limit(int64_t); + +#define instantiate_float_limit(type) \ + template <> \ + struct Limits { \ + static constexpr constant type max = \ + metal::numeric_limits::infinity(); \ + static constexpr constant type min = \ + -metal::numeric_limits::infinity(); \ + static constexpr constant type finite_max = \ + metal::numeric_limits::max(); \ + static constexpr constant type finite_min = \ + -metal::numeric_limits::max(); \ + }; + +instantiate_float_limit(half); +instantiate_float_limit(float); +instantiate_float_limit(bfloat16_t); + +template <> +struct Limits { + static constexpr constant bool max = true; + static constexpr constant bool min = false; +}; + +template <> +struct Limits { + static constexpr constant complex64_t max = complex64_t( + metal::numeric_limits::infinity(), + metal::numeric_limits::infinity()); + static constexpr constant complex64_t min = complex64_t( + -metal::numeric_limits::infinity(), + -metal::numeric_limits::infinity()); +}; + +/////////////////////////////////////////////////////////////////////////////// +// Indexing utils +/////////////////////////////////////////////////////////////////////////////// + +#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") + +/////////////////////////////////////////////////////////////////////////////// +// Single Array with generic dims + +template +METAL_FUNC stride_t elem_to_loc( + uint elem, + constant const int* shape, + constant const stride_t* strides, + int ndim) { + stride_t loc = 0; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + loc += (elem % shape[i]) * strides[i]; + elem /= shape[i]; + } + return loc; +} + +template +METAL_FUNC stride_t elem_to_loc( + stride_t elem, + constant const int* shape, + constant const stride_t* strides, + int ndim) { + stride_t loc = 0; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + loc += (elem % shape[i]) * strides[i]; + elem /= shape[i]; + } + return loc; +} + +// Non templated version to handle arbitrary dims +template +METAL_FUNC stride_t elem_to_loc( + uint3 elem, + constant const int* shape, + constant const stride_t* strides, + int ndim) { + stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2]; + for (int d = ndim - 3; d >= 0; --d) { + loc += (elem.z % shape[d]) * strides[d]; + elem.z /= 
shape[d]; + } + return loc; +} + +/////////////////////////////////////////////////////////////////////////////// +// Single Array with fixed N dims + +template +METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) { + return elem * stride; +} + +template +METAL_FUNC stride_t +elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) { + return elem.x * strides[1] + elem.y * strides[0]; +} + +template +METAL_FUNC stride_t +elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) { + return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0]; +} + +/////////////////////////////////////////////////////////////////////////////// +// Multiple Arrays with generic dims + +template +METAL_FUNC ulong2 elem_to_loc_2_nd( + uint3 elem, + constant const int* shape, + constant const stride_t* a_strides, + constant const stride_t* b_strides, + int ndim) { + ulong2 loc = { + ulong(elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]), + ulong(elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])}; + for (int d = ndim - 3; d >= 0; --d) { + uint l = elem.z % shape[d]; + loc.x += l * a_strides[d]; + loc.y += l * b_strides[d]; + elem.z /= shape[d]; + } + return loc; +} + +METAL_FUNC ulong3 elem_to_loc_3_nd( + uint3 elem, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + constant const size_t* c_strides, + int ndim) { + ulong3 loc = { + elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2], + elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2], + elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2]}; + for (int d = ndim - 3; d >= 0; --d) { + uint l = elem.z % shape[d]; + loc.x += l * a_strides[d]; + loc.y += l * b_strides[d]; + loc.z += l * c_strides[d]; + elem.z /= shape[d]; + } + return loc; +} + +/////////////////////////////////////////////////////////////////////////////// +// Elem to loc in a loop utils +/////////////////////////////////////////////////////////////////////////////// + +template +struct looped_elem_to_loc { + looped_elem_to_loc inner_looper; + offset_t offset{0}; + int index{0}; + + void next(const constant int* shape, const constant size_t* strides) { + index++; + offset += strides[dim - 1]; + + if (index >= shape[dim - 1]) { + index = 0; + inner_looper.next(shape, strides); + offset = inner_looper.offset; + } + } + + void next(int n, const constant int* shape, const constant size_t* strides) { + index += n; + offset += n * strides[dim - 1]; + + if (index >= shape[dim - 1]) { + int extra = index - shape[dim - 1]; + index = 0; + inner_looper.next(shape, strides); + offset = inner_looper.offset; + if (extra > 0) { + next(extra, shape, strides); + } + } + } + + offset_t + location(offset_t, const constant int*, const constant size_t*, int) { + return offset; + } +}; + +template +struct looped_elem_to_loc<1, offset_t> { + offset_t offset{0}; + + void next(const constant int*, const constant size_t* strides) { + offset += strides[0]; + } + + void next(int n, const constant int*, const constant size_t* strides) { + offset += n * strides[0]; + } + + offset_t + location(offset_t, const constant int*, const constant size_t*, int) { + return offset; + } +}; + +template +struct looped_elem_to_loc<0, offset_t> { + void next(const constant int*, const constant size_t*) {} + void next(int, const constant int*, const constant size_t*) {} + + offset_t location( + offset_t idx, + const constant int* shape, + const constant size_t* strides, + int ndim) { + return 
elem_to_loc(idx, shape, strides, ndim); + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Calculation utils +/////////////////////////////////////////////////////////////////////////////// + +/** Compute ceil((float)N/(float)M) */ +template +inline T ceildiv(T N, U M) { + return (N + M - 1) / M; +} + +// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202 +inline float log1p(float x) { + float xp1 = 1.0f + x; + if (xp1 == Limits::max) { + return Limits::max; + } + if (xp1 == 1.0f) { + return x; + } + + return x * (metal::log(xp1) / (xp1 - 1.0f)); +} + +inline bfloat16_t log1p(bfloat16_t x) { + float xp1 = 1.0f + static_cast(x); + if (xp1 == Limits::max) { + return Limits::max; + } + if (xp1 == 1.0f) { + return x; + } + + return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f))); +} + +/////////////////////////////////////////////////////////////////////////////// +// SIMD shuffle ops +/////////////////////////////////////////////////////////////////////////////// + +inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) { + return as_type( + metal::simd_shuffle_down(as_type(data), delta)); +} + +inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) { + return as_type( + metal::simd_shuffle_down(as_type(data), delta)); +} + +inline bool simd_shuffle_down(bool data, uint16_t delta) { + return simd_shuffle_down(static_cast(data), delta); +} + +inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) { + return complex64_t( + simd_shuffle_down(data.real, delta), simd_shuffle_down(data.imag, delta)); +} diff --git a/tools/fix-metal-includes.sh b/tools/fix-metal-includes.sh new file mode 100755 index 00000000..b5cced9e --- /dev/null +++ b/tools/fix-metal-includes.sh @@ -0,0 +1,108 @@ +#!/bin/bash +# Fixing include path for mlx-swift metal headers + +set -euo pipefail + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(realpath "${SCRIPT_DIR}/..") + +# Where the files end up +OUTPUT_DIR="${ROOT_DIR}/Source/Cmlx/mlx-generated/metal" + +# The Cmlx source dir +CMLX_MLX_DIR="${ROOT_DIR}/Source/Cmlx/mlx" + +# sub-directory of Cmlx source containing the kernels +KERNELS_INCLUDE_PATH="mlx/backend/metal/kernels" + +KERNELS_DIR="${CMLX_MLX_DIR}/${KERNELS_INCLUDE_PATH}" + +# list of kernels files to process +# see Source/Cmlx/mlx/mlx/backend/metal/kernels/CMakeLists.txt +KERNEL_LIST=" \ +arg_reduce.metal \ +conv.metal \ +gemv.metal \ +random.metal \ +rms_norm.metal \ +layer_norm.metal \ +rope.metal \ +scaled_dot_product_attention.metal" + +# We fixup all the header files AND the listed kernel files +HEADERS=$(find "${KERNELS_DIR}" -name "*.h") +KERNELS=$(for file in ${KERNEL_LIST}; do echo "${KERNELS_DIR}/${file}"; done) + +# Regular expression to replace include directives +PATTERN="^#include \"${KERNELS_INCLUDE_PATH}/([^\"]*)\"" + +mkdir -p "${OUTPUT_DIR}" + +# Mimic the original logic in PrepareMetalShaders::transformIncludes +# Returns rootPath, a string containing a sequence of "../../" to prefix the +# include path +function replaceIncludePrefix { + #Extract components up to the output dir and drop the last one + #swift: let pathUnderKernels = url.pathComponents.drop { $0 != "output" }.dropLast() + + absolutePath=$(realpath "${1}") + absoluteOut=$(realpath "${OUTPUT_DIR}") + remainingPath=${absolutePath#"$absoluteOut"/} + + # Doing the `dropLast` with `dirname`, handling the case where it returns `.`` + remainingPath=$(dirname "${remainingPath}" | sed -E 's|^\.$||') 
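+  # Illustrative example: a header copied to
+  #   ${OUTPUT_DIR}/steel/gemm/loader.h
+  # has remainingPath "steel/gemm", so the two components below yield a
+  # root_path of "../../" and an include of
+  #   "mlx/backend/metal/kernels/steel/defines.h"
+  # is rewritten to "../../steel/defines.h". Files copied to the top of
+  # OUTPUT_DIR get an empty prefix.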
+ + # Build the root path + # swift: let rootPath =Array(repeating: "..", count: pathUnderKernels.count - 1).joined(separator: "/") + # + ((pathUnderKernels.count - 1 == 0) ? "" : "/") + IFS='/' read -r -a path <<< "${remainingPath}" + count=${#path[@]} + + if [ "$count" -le 0 ]; then + root_path="" + else + root_path=$(printf "../%.0s" $(seq 1 "${count}")) + fi + echo "${root_path}" +} + +# First pass : copy the files if needed +for src in ${HEADERS} ${KERNELS}; do + + relative_path=${src#"$KERNELS_DIR"/} + dest=${OUTPUT_DIR}/${relative_path} + + # If destination file doesn't exist or if it's older than the source + # copy from source and replace the #include directives + if [ ! -e "$dest" ] || [ "$src" -nt "$dest" ]; then + echo "${src} -> ${dest}" + mkdir -p "$(dirname "${dest}")" + cp -p "${src}" "${dest}" + else + echo "Skipping $src (more recent destination)" + fi + +done + +# second pass: update the include lines +# iterating on src to only process the list of files we copied +# (in case the destination directory has other unrelated files) +for src in ${HEADERS} ${KERNELS}; do + + relative_path=${src#"$KERNELS_DIR"/} + dest=${OUTPUT_DIR}/${relative_path} + prefix=$(replaceIncludePrefix "${dest}") + + # for each matching input line, compute the relative path, then replace the line + while read -r includeLine; do + includePath=$(echo "${includeLine}" | sed -E -n "s|${PATTERN}|\1|p") + + # Note the absence of "/" between the prefix and the path + replace="${prefix}${includePath}" + + # Replace the include line with the new one + echo sed -i '' -e "s|${KERNELS_INCLUDE_PATH}/${includePath}|${replace}|" "${dest}" + sed -i '' -e "s|${KERNELS_INCLUDE_PATH}/${includePath}|${replace}|" "${dest}" + + done < <(grep -E -o "${PATTERN}" "${dest}") +done diff --git a/tools/update-mlx.sh b/tools/update-mlx.sh index bdf7941d..b24c9c55 100755 --- a/tools/update-mlx.sh +++ b/tools/update-mlx.sh @@ -20,8 +20,6 @@ cmake ../Source/Cmlx/mlx -DMLX_METAL_JIT=ON -DMACOS_VERSION=14.0 # until mlx supports overriding the METAL_VERSION you will need to edit # Source/Cmlx/mlx/mlx/backend/metal/CMakeLists.txt and manually set the METAL_VERSION # to "3.0" -# -# Also Plugins/PrepareMetalShaders/main.swift kernels needs to be in sync. # run the cmake build to generate the source files cd mlx/backend/metal @@ -60,7 +58,8 @@ make cpu_compiled_preamble cd ../../../.. -rm Source/Cmlx/mlx-generated/* +rm -rf Source/Cmlx/mlx-generated/metal +rm -f Source/Cmlx/mlx-generated/* cp build/mlx/backend/metal/jit/* Source/Cmlx/mlx-generated cp build/mlx/backend/common/compiled_preamble.cpp Source/Cmlx/mlx-generated @@ -72,3 +71,6 @@ for x in Source/Cmlx/mlx-generated/*.cpp ; do \ sed -i .tmp -e "s:`pwd`/::g" $x done; rm Source/Cmlx/mlx-generated/*.tmp + +# Update the headers +./tools/fix-metal-includes.sh \ No newline at end of file