diff --git a/lib/mps/matrix.jl b/lib/mps/matrix.jl
index 830be5a29..c281769d0 100644
--- a/lib/mps/matrix.jl
+++ b/lib/mps/matrix.jl
@@ -18,7 +18,17 @@ Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x)
 
 export MPSMatrixDescriptor
 
-@objcwrapper MPSMatrixDescriptor <: NSObject
+@objcwrapper immutable=false MPSMatrixDescriptor <: NSObject
+
+@objcproperties MPSMatrixDescriptor begin
+    @autoproperty rows::NSUInteger setter=setRows
+    @autoproperty columns::NSUInteger setter=setColumns
+    @autoproperty matrices::NSUInteger
+    @autoproperty dataType::MPSDataType setter=setDataType
+    @autoproperty rowBytes::NSUInteger setter=setRowBytes
+    @autoproperty matrixBytes::NSUInteger
+end
+
 
 # Mapping from Julia types to the Performance Shader bitfields
 const jl_typ_to_mps = Dict{DataType,MPSDataType}(
@@ -49,6 +59,17 @@ function MPSMatrixDescriptor(rows, columns, rowBytes, dataType)
     return obj
 end
 
+function MPSMatrixDescriptor(rows, columns, matrices, rowBytes, matrixBytes, dataType)
+    desc = @objc [MPSMatrixDescriptor matrixDescriptorWithRows:rows::NSUInteger
+                                      columns:columns::NSUInteger
+                                      matrices:matrices::NSUInteger
+                                      rowBytes:rowBytes::NSUInteger
+                                      matrixBytes:matrixBytes::NSUInteger
+                                      dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSMatrixDescriptor}
+    obj = MPSMatrixDescriptor(desc)
+    # XXX: who releases this object?
+    return obj
+end
 
 #
 # matrix object
@@ -58,6 +79,19 @@ export MPSMatrix
 
 @objcwrapper immutable=false MPSMatrix <: NSObject
 
+@objcproperties MPSMatrix begin
+    @autoproperty device::id{MTLDevice}
+    @autoproperty rows::NSUInteger
+    @autoproperty columns::NSUInteger
+    @autoproperty matrices::NSUInteger
+    @autoproperty dataType::MPSDataType
+    @autoproperty rowBytes::NSUInteger
+    @autoproperty matrixBytes::NSUInteger
+    @autoproperty offset::NSUInteger
+    @autoproperty data::id{MTLBuffer}
+end
+
+
 """
     MPSMatrix(arr::MtlMatrix)
 
@@ -71,13 +105,37 @@ function MPSMatrix(arr::MtlMatrix{T}) where T
     desc = MPSMatrixDescriptor(n_rows, n_cols, sizeof(T)*n_cols, T)
     mat = @objc [MPSMatrix alloc]::id{MPSMatrix}
     obj = MPSMatrix(mat)
+    offset = arr.offset * sizeof(T)
     finalizer(release, obj)
     @objc [obj::id{MPSMatrix} initWithBuffer:arr::id{MTLBuffer}
+                              offset:offset::NSUInteger
                               descriptor:desc::id{MPSMatrixDescriptor}]::id{MPSMatrix}
     return obj
 end
 
 
+"""
+    MPSMatrix(arr::MtlArray{T,3})
+
+Metal batched matrix representation used in Performance Shaders.
+
+Note that this results in a transposed view of the input,
+as Metal stores matrices row-major instead of column-major.
+"""
+function MPSMatrix(arr::MtlArray{T,3}) where T
+    n_cols, n_rows, n_matrices = size(arr)
+    row_bytes = sizeof(T)*n_cols
+    desc = MPSMatrixDescriptor(n_rows, n_cols, n_matrices, row_bytes, row_bytes * n_rows, T)
+    mat = @objc [MPSMatrix alloc]::id{MPSMatrix}
+    obj = MPSMatrix(mat)
+    offset = arr.offset * sizeof(T)
+    finalizer(release, obj)
+    @objc [obj::id{MPSMatrix} initWithBuffer:arr::id{MTLBuffer}
+                              offset:offset::NSUInteger
+                              descriptor:desc::id{MPSMatrixDescriptor}]::id{MPSMatrix}
+    return obj
+end
+
 #
 # matrix multiplication
 #
diff --git a/lib/mps/vector.jl b/lib/mps/vector.jl
index 2128bda04..e2ad87efb 100644
--- a/lib/mps/vector.jl
+++ b/lib/mps/vector.jl
@@ -53,8 +53,10 @@ function MPSVector(arr::MtlVector{T}) where T
     desc = MPSVectorDescriptor(len, T)
     vec = @objc [MPSVector alloc]::id{MPSVector}
     obj = MPSVector(vec)
+    offset = arr.offset * sizeof(T)
     finalizer(release, obj)
     @objc [obj::id{MPSVector} initWithBuffer:arr::id{MTLBuffer}
+                              offset:offset::NSUInteger
                               descriptor:desc::id{MPSVectorDescriptor}]::id{MPSVector}
     return obj
 end
diff --git a/test/mps.jl b/test/mps.jl
index 545e09710..1e28807f5 100644
--- a/test/mps.jl
+++ b/test/mps.jl
@@ -35,6 +35,26 @@ if MPS.is_supported(current_device())
     end
 end
 
+@testset "test matrix vector multiplication of views" begin
+    N = 20
+    a = rand(Float32, N,N)
+    b = rand(Float32, N)
+
+    mtl_a = mtl(a)
+    mtl_b = mtl(b)
+
+    view_a = @view a[:,10:end]
+    view_b = @view b[10:end]
+
+    mtl_view_a = @view mtl_a[:,10:end]
+    mtl_view_b = @view mtl_b[10:end]
+
+    mtl_c = mtl_view_a * mtl_view_b
+    c = view_a * view_b
+
+    @test Array(mtl_c) ≈ c
+end
+
 @testset "mixed-precision matrix vector multiplication" begin
     N = 10
     rows = N