forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MPS] Fix
torch.mm
correctness for large matrices (pytorch#117549)
Currently `matrixMultiplicationWithPrimaryTensor:secondaryTensor:` returns incorrect results if one of the matrix dimensions is greater than 32K Solve it by providing a very naive matrix multiplication metal shader and call it if stride size is greater than 32768 elements, as slicing inside the MPSGraph doesn't work either, since `-sliceTensor:starts:ends:strides:` somehow affects matmul as well, if tiling is done as follows: ```objc NSMutableArray<MPSGraphTensor*>* rows = [NSMutableArray new]; for (int64_t i = 0; i < M; i += tile_size) { const auto i_end = std::min(i + tile_size, M); NSMutableArray<MPSGraphTensor*>* row_chunks = [NSMutableArray new]; for (int64_t j = 0; j < K; j += tile_size) { const auto j_end = std::min(j + tile_size, K); MPSGraphTensor* tile = nil; for (int64_t k = 0; k < N; k += tile_size) { const auto k_end = std::min(k + tile_size, N); auto selfChunk = [graph sliceTensor:selfTensor starts:@[ @(i), @(k) ] ends:@[ @(i_end), @(k_end) ] strides:@[ @(1), @(1) ] name:nil]; auto otherChunk = [graph sliceTensor:otherTensor starts:@[ @(k), @(j) ] ends:@[ @(k_end), @(j_end) ] strides:@[ @(1), @(1) ] name:nil]; auto chunkMM = [graph matrixMultiplicationWithPrimaryTensor:selfChunk secondaryTensor:otherChunk name:nil]; tile = tile ? [graph additionWithPrimaryTensor:tile secondaryTensor:chunkMM name:nil] : chunkMM; } [row_chunks addObject:tile]; } auto row = row_chunks.count > 1 ? [graph concatTensors:row_chunks dimension:1 name:nil] : row_chunks.firstObject; [rows addObject:row]; } return rows.count > 1 ? [graph concatTensors:rows dimension:0 name:nil] : rows.firstObject; ``` One can always use metal MM by defining `PYTORCH_MPS_PREFER_METAL` environment variable Fixes pytorch#116769 Pull Request resolved: pytorch#117549 Approved by: https://github.com/kulinseth
- Loading branch information
1 parent
f518cf8
commit 1872834
Showing
3 changed files
with
147 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters