Skip to content

Commit

Permalink
[MPS] Fix torch.mm correctness for large matrices (pytorch#117549)
Browse files Browse the repository at this point in the history
Currently `matrixMultiplicationWithPrimaryTensor:secondaryTensor:` returns incorrect results if one of the matrix dimensions is greater than 32K.
Work around this by providing a very naive matrix-multiplication Metal shader and calling it when any stride is greater than 32768 elements. Slicing inside the MPSGraph does not help either, since `-sliceTensor:starts:ends:strides:` somehow affects matmul as well when tiling is done as follows:
```objc
  NSMutableArray<MPSGraphTensor*>* rows = [NSMutableArray new];
  for (int64_t i = 0; i < M; i += tile_size) {
    const auto i_end = std::min(i + tile_size, M);
    NSMutableArray<MPSGraphTensor*>* row_chunks = [NSMutableArray new];
    for (int64_t j = 0; j < K; j += tile_size) {
      const auto j_end = std::min(j + tile_size, K);
      MPSGraphTensor* tile = nil;
      for (int64_t k = 0; k < N; k += tile_size) {
        const auto k_end = std::min(k + tile_size, N);
        auto selfChunk = [graph sliceTensor:selfTensor
                                     starts:@[ @(i), @(k) ]
                                       ends:@[ @(i_end), @(k_end) ]
                                    strides:@[ @(1), @(1) ]
                                       name:nil];
        auto otherChunk = [graph sliceTensor:otherTensor
                                      starts:@[ @(k), @(j) ]
                                        ends:@[ @(k_end), @(j_end) ]
                                     strides:@[ @(1), @(1) ]
                                        name:nil];
        auto chunkMM = [graph matrixMultiplicationWithPrimaryTensor:selfChunk secondaryTensor:otherChunk name:nil];

        tile = tile ? [graph additionWithPrimaryTensor:tile secondaryTensor:chunkMM name:nil] : chunkMM;
      }
      [row_chunks addObject:tile];
    }
    auto row = row_chunks.count > 1 ? [graph concatTensors:row_chunks dimension:1 name:nil] : row_chunks.firstObject;
    [rows addObject:row];
  }
  return rows.count > 1 ? [graph concatTensors:rows dimension:0 name:nil] : rows.firstObject;
```

One can always force the Metal MM path by defining the `PYTORCH_MPS_PREFER_METAL` environment variable.
Fixes pytorch#116769
Pull Request resolved: pytorch#117549
Approved by: https://github.com/kulinseth
  • Loading branch information
malfet authored and pytorchmergebot committed Jan 17, 2024
1 parent f518cf8 commit 1872834
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 4 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/native/mps/operations/CrossKernel.mm
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
namespace {

static const char* METAL_CROSS = R"CROSS_METAL(
#include <metal_array>
#include <metal_stdlib>
using namespace metal;
Expand Down
134 changes: 130 additions & 4 deletions aten/src/ATen/native/mps/operations/LinearAlgebra.mm
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,119 @@

namespace at::native {
namespace mps {
namespace {
// Metal source for the fallback matrix multiplication, compiled at runtime.
//
// * dot_product<T>: strided dot product of `size` elements; `strides.x` steps
//   through v1 and `strides.y` steps through v2. The sum is accumulated in T.
// * naive_matmul<T>: one thread per output element. The flat thread index is
//   decoded as x = index % sizes.x and y = index / sizes.x, where the host
//   packs sizes as (mat1.size(0), mat1.size(1), output.size(1)). Each thread
//   multiplies row x of mat1 by column y of mat2 and stores the result at
//   (x, y) of the output.
// * strides holds one (dim-0 stride, dim-1 stride) pair per tensor, in the
//   order mat1, mat2, output.
// * The kernel is instantiated for float and half only, under the host names
//   "naive_matmul_float" and "naive_matmul_half".
//
// NOTE(review): the LinearAlgebraOpType enum inside the shader string is not
// referenced by any kernel code here - confirm whether it belongs in this source.
// NOTE(review): for half inputs the accumulator is also half; confirm that the
// resulting precision is acceptable for long inner dimensions.
static const char* METAL_LINALG = R"MATMUL_METAL(
#include <metal_array>
using namespace metal;
template<typename T>
T dot_product(constant T *v1, constant T* v2, ulong2 strides, uint32_t size) {
T rc = 0.0;
for (uint32_t i = 0; i < size; ++i) {
rc += v1[i * strides.x] * v2[i * strides.y];
}
return rc;
}
enum LinearAlgebraOpType { ADDBMM_OP_TYPE, BADDBMM_OP_TYPE };
template<typename T>
kernel void naive_matmul(
constant T * mat1Data [[buffer(0)]],
constant T * mat2Data [[buffer(1)]],
device T * outputData [[buffer(2)]],
constant array<ulong2, 3> & strides [[buffer(3)]],
constant uint3 & sizes [[buffer(4)]],
uint thread_index [[thread_position_in_grid]]) {
uint y = thread_index / sizes.x;
uint x = thread_index % sizes.x;
if (x >= sizes.x || y >= sizes.z) {
return;
}
auto rc = dot_product(mat1Data + x * strides[0].x,
mat2Data + y * strides[1].y,
ulong2(strides[0].y, strides[1].x),
sizes.y);
outputData[x * strides[2].x + y * strides[2].y] = rc;
}
#define INSTANTIATE_NAIVE_MM(DTYPE) \
template \
[[host_name("naive_matmul_" #DTYPE)]] \
kernel void naive_matmul<DTYPE>( \
constant DTYPE * mat1Data [[buffer(0)]], \
constant DTYPE * mat2Data [[buffer(1)]], \
device DTYPE * outputData [[buffer(2)]], \
constant array<ulong2, 3> & strides [[buffer(3)]], \
constant uint3 & sizes [[buffer(4)]], \
uint thread_index [[thread_position_in_grid]])
INSTANTIATE_NAIVE_MM(float);
INSTANTIATE_NAIVE_MM(half);
)MATMUL_METAL";

// Compiles METAL_LINALG into a Metal library on first use and returns the
// cached library on every later call.
//
// device: the Metal device used to compile the shader source. Only the device
//         passed on the first successful call is ever used, because the result
//         is cached in a function-local static.
// Returns the compiled library; raises via TORCH_CHECK if compilation fails.
id<MTLLibrary> compileLinalgOpLibrary(id<MTLDevice> device) {
  static id<MTLLibrary> linalgLibrary = nil;
  if (!linalgLibrary) {
    NSError* compileError = nil;
    // Manual retain/release: the options object is only needed for this call.
    MTLCompileOptions* compileOptions = [[MTLCompileOptions new] autorelease];
    [compileOptions setLanguageVersion:MTLLanguageVersion2_3];
    NSString* shaderSource = [NSString stringWithCString:METAL_LINALG encoding:NSASCIIStringEncoding];
    linalgLibrary = [device newLibraryWithSource:shaderSource options:compileOptions error:&compileError];
    TORCH_CHECK(
        linalgLibrary, "Failed to create metal linalg library, error: ", [[compileError description] UTF8String]);
  }
  return linalgLibrary;
}

// Returns the compute pipeline state for the naive matmul kernel matching
// `scalar_type`, building and caching it on first request.
//
// device:      Metal device used to compile the library and build the pipeline.
// scalar_type: element type; it selects the kernel named
//              "naive_matmul_<metal type string>".
// Returns the cached or newly created pipeline state. Raises via TORCH_CHECK
// when the kernel function is missing from the library or the pipeline cannot
// be created.
id<MTLComputePipelineState> matmulPipelineState(id<MTLDevice> device, ScalarType scalar_type) {
  const std::string kernel = "naive_matmul_" + mps::scalarToMetalTypeString(scalar_type);
  static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
  // Use find() rather than operator[] so a miss does not insert a nil entry.
  const auto cached = psoCache.find(kernel);
  if (cached != psoCache.end() && cached->second) {
    return cached->second;
  }

  NSError* error = nil;
  id<MTLLibrary> linalgLib = compileLinalgOpLibrary(device);
  id<MTLFunction> matmulFunc = [linalgLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
  TORCH_CHECK(matmulFunc, "Failed to create function state object for: ", kernel);
  id<MTLComputePipelineState> pso = [device newComputePipelineStateWithFunction:matmulFunc error:&error];
  // This file uses manual retain/release: `newFunctionWithName:` returns an
  // owned object that is no longer needed once the pipeline has been built
  // (or has failed to build), so release it here instead of leaking it.
  [matmulFunc release];
  TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);

  psoCache[kernel] = pso;
  return pso;
}

static std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* graph,
const Tensor& self,
const Tensor& other) {
// Computes output = self x other with the naive Metal kernel.
//
// self, other: the two input matrices; both are indexed as 2-D tensors.
// output:      destination tensor; its scalar type selects the kernel variant.
// Returns `output`.
Tensor& do_metal_mm(const Tensor& self, const Tensor& other, Tensor& output) {
  auto stream = getCurrentMPSStream();
  auto device = MPSDevice::getInstance()->device();
  auto matmulPSO = matmulPipelineState(device, output.scalar_type());
  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      getMPSProfiler().beginProfileKernel(matmulPSO, "naive_matmul", {self, other});
      auto encoder = stream->commandEncoder();
      [encoder setComputePipelineState:matmulPSO];

      // Buffers 0-2: the two inputs and the output.
      mtl_setBuffer(encoder, self, 0);
      mtl_setBuffer(encoder, other, 1);
      mtl_setBuffer(encoder, output, 2);

      // Buffer 3: (dim-0, dim-1) stride pairs for self, other and output, in that order.
      const std::array<int64_t, 6> packedStrides = {
          self.stride(0), self.stride(1), other.stride(0), other.stride(1), output.stride(0), output.stride(1)};
      [encoder setBytes:packedStrides.data() length:packedStrides.size() * sizeof(uint64_t) atIndex:3];

      // Buffer 4: rows of self, the shared inner dimension, columns of output.
      const std::array<uint32_t, 3> dims = {static_cast<uint32_t>(self.size(0)),
                                            static_cast<uint32_t>(self.size(1)),
                                            static_cast<uint32_t>(output.size(1))};
      [encoder setBytes:dims.data() length:dims.size() * sizeof(uint32_t) atIndex:4];

      // One thread per output element.
      mtl_dispatch1DJob(encoder, matmulPSO, output.numel());
      getMPSProfiler().endProfileKernel(matmulPSO);
    }
  });
  return output;
}

std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* graph,
const Tensor& self,
const Tensor& other) {
if (self.numel() == 0 || other.numel() == 0) {
auto output = [graph constantWithScalar:0.0
shape:getMPSShape({self.size(0), other.size(1)})
Expand All @@ -40,6 +147,15 @@
return {selfTensor, otherTensor, output};
}

// Decides whether matrix multiplication must bypass MPSGraph and run on the
// naive Metal kernel instead.
//
// The fallback is taken when the PYTORCH_MPS_PREFER_METAL environment variable
// is set (read once, on the first call), or when any size or stride of either
// input exceeds 32768. Sizes are checked as well as strides because the MPS
// failure is triggered by large dimensions: a contiguous 32769x1 matrix has
// strides (1, 1), so a stride-only check would miss it.
//
// self, other: the two input matrices; both are indexed as 2-D tensors.
// output:      currently unused; kept so the signature stays stable for callers.
// Returns true when the Metal kernel should be used.
bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output) {
  static bool always_use_metal = std::getenv("PYTORCH_MPS_PREFER_METAL") != nullptr;
  static constexpr int64_t max_stride_size = 32768;
  if (always_use_metal) {
    return true;
  }
  // True when either dimension of `mat` has an extent or stride over the limit.
  const auto exceeds_limit = [](const Tensor& mat) {
    return mat.size(0) > max_stride_size || mat.size(1) > max_stride_size || mat.stride(0) > max_stride_size ||
        mat.stride(1) > max_stride_size;
  };
  return exceeds_limit(self) || exceeds_limit(other);
}

} // anonymous namespace

static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) {
using namespace mps;
using CachedGraph = MPSBinaryCachedGraph;
Expand All @@ -58,6 +174,14 @@
return output;
}

// MPS matmul returns silently incorrect results if one of the matrix dimensions is greater than 2**15,
// and crashes if it's a view of a matrix with dimensions larger than 2**15.
// See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
// In such cases, fall back to a naive but accurate Metal shader.
if (use_metal_mm(self, other, output)) {
return do_metal_mm(self, other, output);
}

@autoreleasepool {
string key = "mm_out_mps_impl" + getTensorsStringKey({self, other});

Expand Down Expand Up @@ -85,6 +209,8 @@
return output;
}

// Tag with one value per batched op (addbmm / baddbmm).
// NOTE(review): presumably consumed by addbmm_or_baddbmm_out_mps_impl below to
// pick the variant - its use is not visible in this excerpt, so confirm.
enum LinearAlgebraOpType { ADDBMM_OP_TYPE, BADDBMM_OP_TYPE };

static Tensor& addbmm_or_baddbmm_out_mps_impl(const Tensor& input,
const Tensor& batch1,
const Tensor& batch2,
Expand Down
16 changes: 16 additions & 0 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6869,6 +6869,22 @@ def test_index_64bit(self):
gc.collect()
torch.mps.empty_cache()

def test_mm_large(self):
    """Check torch.mm on MPS when a matrix dimension or stride exceeds 32K."""
    # Case 1: multiply by a view into a matrix wider than 32K.
    # This used to crash with:
    # error: subRange.start (24576) is not less than length of dimension[0] (16384)
    # See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
    lhs = torch.rand(10, 1, device="mps")
    wide = torch.rand(1, 32769, device="mps")
    product = torch.mm(lhs, wide[:, 16384:32768])
    self.assertNotEqual(product.abs().max().item(), 0.0)

    # Case 2: the full-width product used to be silently wrong, so compare it
    # against the CPU reference.
    rows, inner, cols = 1024, 1, 32769
    lhs = torch.rand(rows, inner, device="mps")
    rhs = torch.rand(inner, cols, device="mps")
    result_mps = torch.mm(lhs, rhs).to("cpu")
    result_cpu = torch.mm(lhs.to("cpu"), rhs.to("cpu"))
    self.assertEqual(result_mps, result_cpu)

# Test flip
def test_flip(self):
def helper(shape, dims):
Expand Down

0 comments on commit 1872834

Please sign in to comment.