From 1872834247084d76544c95e04fe2ec3370bac87d Mon Sep 17 00:00:00 2001
From: Nikita Shulga
Date: Wed, 17 Jan 2024 01:33:08 +0000
Subject: [PATCH] [MPS] Fix `torch.mm` correctness for large matrices (#117549)

Currently `matrixMultiplicationWithPrimaryTensor:secondaryTensor:` returns
incorrect results if one of the matrix dimensions is greater than 32K.

Solve it by providing a very naive matrix-multiplication Metal shader and
calling it whenever a stride is larger than 32768 elements. Slicing inside
the MPSGraph does not work around the problem either, since
`-sliceTensor:starts:ends:strides:` is affected by the same bug, if tiling
is done as follows:
```objc
NSMutableArray<MPSGraphTensor*>* rows = [NSMutableArray new];
for (int64_t i = 0; i < M; i += tile_size) {
  const auto i_end = std::min(i + tile_size, M);
  NSMutableArray<MPSGraphTensor*>* row_chunks = [NSMutableArray new];
  for (int64_t j = 0; j < K; j += tile_size) {
    const auto j_end = std::min(j + tile_size, K);
    MPSGraphTensor* tile = nil;
    for (int64_t k = 0; k < N; k += tile_size) {
      const auto k_end = std::min(k + tile_size, N);
      auto selfChunk = [graph sliceTensor:selfTensor
                                   starts:@[ @(i), @(k) ]
                                     ends:@[ @(i_end), @(k_end) ]
                                  strides:@[ @(1), @(1) ]
                                     name:nil];
      auto otherChunk = [graph sliceTensor:otherTensor
                                    starts:@[ @(k), @(j) ]
                                      ends:@[ @(k_end), @(j_end) ]
                                   strides:@[ @(1), @(1) ]
                                      name:nil];
      auto chunkMM = [graph matrixMultiplicationWithPrimaryTensor:selfChunk
                                                  secondaryTensor:otherChunk
                                                             name:nil];
      tile = tile ? [graph additionWithPrimaryTensor:tile secondaryTensor:chunkMM name:nil] : chunkMM;
    }
    [row_chunks addObject:tile];
  }
  auto row = row_chunks.count > 1 ? [graph concatTensors:row_chunks dimension:1 name:nil] : row_chunks.firstObject;
  [rows addObject:row];
}
return rows.count > 1 ? [graph concatTensors:rows dimension:0 name:nil] : rows.firstObject;
```

One can always use the Metal MM by defining the `PYTORCH_MPS_PREFER_METAL`
environment variable.

Fixes https://github.com/pytorch/pytorch/issues/116769

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117549
Approved by: https://github.com/kulinseth
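For reference, the incorrect-results case can be verified against the CPU with
a snippet along the lines of the test added below (a repro sketch, not part of
the patch itself); running it with `PYTORCH_MPS_PREFER_METAL=1` set in the
environment exercises the Metal kernel regardless of stride sizes:

```python
import torch

# An inner dimension above 32768 used to yield silently wrong results on MPS.
m, n, k = 1024, 1, 32769
x = torch.rand(m, n, device="mps")
y = torch.rand(n, k, device="mps")
z_mps = torch.mm(x, y).cpu()
z_cpu = torch.mm(x.cpu(), y.cpu())
print((z_mps - z_cpu).abs().max().item())  # expected to be ~0 with this fix
```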
---
 .../ATen/native/mps/operations/CrossKernel.mm |   1 +
 .../native/mps/operations/LinearAlgebra.mm    | 134 +++++++++++++++++-
 test/test_mps.py                              |  16 +++
 3 files changed, 147 insertions(+), 4 deletions(-)

diff --git a/aten/src/ATen/native/mps/operations/CrossKernel.mm b/aten/src/ATen/native/mps/operations/CrossKernel.mm
index 1e04a7633f1aa..afabf047ccd71 100644
--- a/aten/src/ATen/native/mps/operations/CrossKernel.mm
+++ b/aten/src/ATen/native/mps/operations/CrossKernel.mm
@@ -9,6 +9,7 @@
 namespace {
 static const char* METAL_CROSS = R"CROSS_METAL(
+#include <metal_array>
 #include <metal_stdlib>
 
 using namespace metal;
diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
index 8aad3adef4e2e..66813cf350924 100644
--- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
+++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -22,12 +22,119 @@
 namespace at::native {
 namespace mps {
+namespace {
+static const char* METAL_LINALG = R"MATMUL_METAL(
+#include <metal_stdlib>
+
+using namespace metal;
+template<typename T>
+T dot_product(constant T* v1, constant T* v2, ulong2 strides, uint32_t size) {
+  T rc = 0.0;
+  for (uint32_t i = 0; i < size; ++i) {
+    rc += v1[i * strides.x] * v2[i * strides.y];
+  }
+  return rc;
+}
 
-enum LinearAlgebraOpType { ADDBMM_OP_TYPE, BADDBMM_OP_TYPE };
+template<typename T>
+kernel void naive_matmul(
+    constant T* mat1Data [[buffer(0)]],
+    constant T* mat2Data [[buffer(1)]],
+    device T* outputData [[buffer(2)]],
+    constant array<ulong2, 3>& strides [[buffer(3)]],
+    constant uint3& sizes [[buffer(4)]],
+    uint thread_index [[thread_position_in_grid]]) {
+  uint y = thread_index / sizes.x;
+  uint x = thread_index % sizes.x;
+  if (x >= sizes.x || y >= sizes.z) {
+    return;
+  }
+  auto rc = dot_product(mat1Data + x * strides[0].x,
+                        mat2Data + y * strides[1].y,
+                        ulong2(strides[0].y, strides[1].x),
+                        sizes.y);
+  outputData[x * strides[2].x + y * strides[2].y] = rc;
+}
+
+#define INSTANTIATE_NAIVE_MM(DTYPE)                       \
+template                                                  \
+[[host_name("naive_matmul_" #DTYPE)]]                     \
+kernel void naive_matmul<DTYPE>(                          \
+    constant DTYPE* mat1Data [[buffer(0)]],               \
+    constant DTYPE* mat2Data [[buffer(1)]],               \
+    device DTYPE* outputData [[buffer(2)]],               \
+    constant array<ulong2, 3>& strides [[buffer(3)]],     \
+    constant uint3& sizes [[buffer(4)]],                  \
+    uint thread_index [[thread_position_in_grid]])
+
+INSTANTIATE_NAIVE_MM(float);
+INSTANTIATE_NAIVE_MM(half);
+)MATMUL_METAL";
+
+id<MTLLibrary> compileLinalgOpLibrary(id<MTLDevice> device) {
+  static id<MTLLibrary> linalgLibrary = nil;
+  if (linalgLibrary) {
+    return linalgLibrary;
+  }
+
+  NSError* error = nil;
+  MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
+  [options setLanguageVersion:MTLLanguageVersion2_3];
+  linalgLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_LINALG encoding:NSASCIIStringEncoding]
+                                       options:options
+                                         error:&error];
+  TORCH_CHECK(linalgLibrary, "Failed to create metal linalg library, error: ", [[error description] UTF8String]);
+  return linalgLibrary;
+}
+
+id<MTLComputePipelineState> matmulPipelineState(id<MTLDevice> device, ScalarType scalar_type) {
+  std::string kernel = "naive_matmul_" + mps::scalarToMetalTypeString(scalar_type);
+  static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
+  id<MTLComputePipelineState> pso = psoCache[kernel];
+  if (pso) {
+    return pso;
+  }
+
+  NSError* error = nil;
+  id<MTLLibrary> linalgLib = compileLinalgOpLibrary(device);
+  id<MTLFunction> matmulFunc = [linalgLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
+  TORCH_CHECK(matmulFunc, "Failed to create function state object for: ", kernel);
+  pso = [device newComputePipelineStateWithFunction:matmulFunc error:&error];
+  TORCH_CHECK(pso, "Failed to create pipeline state object, error: ", [[error description] UTF8String]);
+
+  psoCache[kernel] = pso;
+  return pso;
+}
 
-static std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* graph,
-                                                                           const Tensor& self,
-                                                                           const Tensor& other) {
+Tensor& do_metal_mm(const Tensor& self, const Tensor& other, Tensor& output) {
+  auto stream = getCurrentMPSStream();
+  auto device = MPSDevice::getInstance()->device();
+  auto matmulPSO = matmulPipelineState(device, output.scalar_type());
+  dispatch_sync_with_rethrow(stream->queue(), ^() {
+    @autoreleasepool {
+      getMPSProfiler().beginProfileKernel(matmulPSO, "naive_matmul", {self, other});
+      auto computeEncoder = stream->commandEncoder();
+      [computeEncoder setComputePipelineState:matmulPSO];
+      std::array<uint32_t, 3> sizes = {static_cast<uint32_t>(self.size(0)),
+                                       static_cast<uint32_t>(self.size(1)),
+                                       static_cast<uint32_t>(output.size(1))};
+      std::array<uint64_t, 6> strides = {
+          self.stride(0), self.stride(1), other.stride(0), other.stride(1), output.stride(0), output.stride(1)};
+      mtl_setBuffer(computeEncoder, self, 0);
+      mtl_setBuffer(computeEncoder, other, 1);
+      mtl_setBuffer(computeEncoder, output, 2);
+      [computeEncoder setBytes:strides.data() length:sizeof(uint64_t) * strides.size() atIndex:3];
+      [computeEncoder setBytes:sizes.data() length:sizeof(uint32_t) * sizes.size() atIndex:4];
+      mtl_dispatch1DJob(computeEncoder, matmulPSO, output.numel());
+      getMPSProfiler().endProfileKernel(matmulPSO);
+    }
+  });
+  return output;
+}
+
+std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* graph,
+                                                                    const Tensor& self,
+                                                                    const Tensor& other) {
   if (self.numel() == 0 || other.numel() == 0) {
     auto output = [graph constantWithScalar:0.0
                                       shape:getMPSShape({self.size(0), other.size(1)})
@@ -40,6 +147,15 @@
   return {selfTensor, otherTensor, output};
 }
 
+bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output) {
+  static bool always_use_metal = std::getenv("PYTORCH_MPS_PREFER_METAL") != nullptr;
+  constexpr auto max_stride_size = 32768;
+  return always_use_metal || self.stride(0) > max_stride_size || self.stride(1) > max_stride_size ||
+      other.stride(0) > max_stride_size || other.stride(1) > max_stride_size;
+}
+
+} // anonymous namespace
+
 static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) {
   using namespace mps;
   using CachedGraph = MPSBinaryCachedGraph;
@@ -58,6 +174,14 @@
     return output;
   }
 
+  // MPS matmul returns silently incorrect results if one of the matrix dimensions is greater than 2**15
+  // and crashes if it is a view of a matrix with dimensions larger than 2**15.
+  // See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
+  // In such cases, fall back to a naive but accurate Metal shader.
+  if (use_metal_mm(self, other, output)) {
+    return do_metal_mm(self, other, output);
+  }
+
   @autoreleasepool {
     string key = "mm_out_mps_impl" + getTensorsStringKey({self, other});
@@ -85,6 +209,8 @@
   return output;
 }
 
+enum LinearAlgebraOpType { ADDBMM_OP_TYPE, BADDBMM_OP_TYPE };
+
 static Tensor& addbmm_or_baddbmm_out_mps_impl(const Tensor& input,
                                               const Tensor& batch1,
                                               const Tensor& batch2,
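The shader above computes one output element per thread. As a minimal Python
model of its indexing (an illustrative sketch, not part of the patch;
`naive_mm_reference` is a hypothetical name), with `sizes = (M, K, N)`:

```python
import torch

def naive_mm_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # One "thread" per output element, indexed the way the shader does it:
    # x = thread_index % M is the output row, y = thread_index / M the column.
    M, K = a.shape
    N = b.shape[1]
    out = torch.empty(M, N, dtype=a.dtype)
    for thread_index in range(M * N):
        y, x = divmod(thread_index, M)
        # dot_product walks a row of `a` (stride a.stride(1)) against a
        # column of `b` (stride b.stride(0)) for K elements.
        out[x, y] = sum(a[x, kk] * b[kk, y] for kk in range(K))
    return out

# Sanity check against torch.mm on small inputs
a, b = torch.rand(3, 5), torch.rand(5, 4)
assert torch.allclose(naive_mm_reference(a, b), a @ b)
```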
diff --git a/test/test_mps.py b/test/test_mps.py
index 741adc1641d39..62a72e43d48eb 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -6869,6 +6869,22 @@ def test_index_64bit(self):
         gc.collect()
         torch.mps.empty_cache()
 
+    def test_mm_large(self):
+        """ Test that MM works for matrices with dimensions larger than 32K """
+        x = torch.rand(10, 1, device="mps")
+        y = torch.rand(1, 32769, device="mps")
+        # This used to crash with:
+        # error: subRange.start (24576) is not less than length of dimension[0] (16384)
+        # See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
+        self.assertNotEqual(torch.mm(x, y[:, 16384:32768]).abs().max().item(), 0.0)
+        # And below used to produce incorrect results
+        m, n, k = 1024, 1, 32769
+        x = torch.rand(m, n, device="mps")
+        y = torch.rand(n, k, device="mps")
+        z = torch.mm(x, y).to("cpu")
+        z_cpu = torch.mm(x.to("cpu"), y.to("cpu"))
+        self.assertEqual(z, z_cpu)
+
     # Test flip
     def test_flip(self):
         def helper(shape, dims):
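A closing usage note (illustrative sketch, not part of the patch):
`use_metal_mm` keys off strides rather than shapes because a narrow view of a
wide matrix inherits the parent's strides, which is precisely the case where
slicing inside MPSGraph misbehaves:

```python
import torch

# Strides are device-independent, so the heuristic can be illustrated on CPU:
base = torch.rand(10, 40000)   # contiguous, stride (40000, 1)
view = base[:, 16384:32768]    # shape (10, 16384), but stride is still (40000, 1)
print(base.stride(), view.stride())
# stride(0) exceeds max_stride_size = 32768 for both tensors, so on MPS both
# would be routed to the naive Metal kernel instead of MPSGraph matmul.
```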