From fa4421c6212c01b3e426bbeea4b668567ae3e62f Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Fri, 31 Mar 2023 14:55:28 -0400 Subject: [PATCH] MPSMatrix from SubArray # Conflicts: # lib/mps/linalg.jl --- lib/mps/matrix.jl | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/lib/mps/matrix.jl b/lib/mps/matrix.jl index c281769d0..0366d588b 100644 --- a/lib/mps/matrix.jl +++ b/lib/mps/matrix.jl @@ -136,6 +136,32 @@ function MPSMatrix(arr::MtlArray{T,3}) where T return obj end + +""" + MPSMatrix(arr::MtlMatrix) + +Metal matrix representation used in Performance Shaders. + +Note that this results in a transposed view of the input, +as Metal stores matrices row-major instead of column-major. +""" +function MPSMatrix(arr::SubArray{T,2,MtlArray{T,3}}) where T + n_cols, n_rows = size(arr) + row_bytes = sizeof(T)*n_cols + index = parentindices(arr)[3] + offset = row_bytes * n_cols * (index-1) + desc = MPSMatrixDescriptor(n_rows, n_cols, row_bytes, T) + mat = @objc [MPSMatrix alloc]::id{MPSMatrix} + obj = MPSMatrix(mat) + finalizer(release, obj) + @objc [obj::id{MPSMatrix} initWithBuffer:parent(arr).buffer::id{MTLBuffer} + offset:offset::NSUInteger + descriptor:desc::id{MPSMatrixDescriptor}]::id{MPSMatrix} + return obj +end + +### parentindices(A) + # # matrix multiplication #