Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MPSMatrix improvements #157

Merged
merged 5 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion lib/mps/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,17 @@ Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x)

export MPSMatrixDescriptor

@objcwrapper MPSMatrixDescriptor <: NSObject
@objcwrapper immutable=false MPSMatrixDescriptor <: NSObject

@objcproperties MPSMatrixDescriptor begin
@autoproperty rows::NSUInteger setter=setRows
@autoproperty columns::NSUInteger setter=setColumns
@autoproperty matrices::NSUInteger
@autoproperty dataType::MPSDataType setter=setDataType
@autoproperty rowBytes::NSUInteger setter=setRowBytes
@autoproperty matrixBytes::NSUInteger
end


# Mapping from Julia types to the Performance Shader bitfields
const jl_typ_to_mps = Dict{DataType,MPSDataType}(
Expand Down Expand Up @@ -49,6 +59,17 @@ function MPSMatrixDescriptor(rows, columns, rowBytes, dataType)
return obj
end

function MPSMatrixDescriptor(rows, columns, matrices, rowBytes, matrixBytes, dataType)
desc = @objc [MPSMatrixDescriptor matrixDescriptorWithRows:rows::NSUInteger
columns:columns::NSUInteger
matrices:matrices::NSUInteger
rowBytes:rowBytes::NSUInteger
matrixBytes:matrixBytes::NSUInteger
dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSMatrixDescriptor}
obj = MPSMatrixDescriptor(desc)
# XXX: who releases this object?
return obj
end

#
# matrix object
Expand All @@ -58,6 +79,19 @@ export MPSMatrix

@objcwrapper immutable=false MPSMatrix <: NSObject

@objcproperties MPSMatrix begin
@autoproperty device::id{MTLDevice}
@autoproperty rows::NSUInteger
@autoproperty columns::NSUInteger
@autoproperty matrices::NSUInteger
@autoproperty dataType::MPSDataType
@autoproperty rowBytes::NSUInteger
@autoproperty matrixBytes::NSUInteger
@autoproperty offset::NSUInteger
@autoproperty data::id{MTLBuffer}
end


"""
MPSMatrix(arr::MtlMatrix)

Expand All @@ -71,13 +105,37 @@ function MPSMatrix(arr::MtlMatrix{T}) where T
desc = MPSMatrixDescriptor(n_rows, n_cols, sizeof(T)*n_cols, T)
mat = @objc [MPSMatrix alloc]::id{MPSMatrix}
obj = MPSMatrix(mat)
offset = arr.offset * sizeof(T)
finalizer(release, obj)
@objc [obj::id{MPSMatrix} initWithBuffer:arr::id{MTLBuffer}
offset:offset::NSUInteger
descriptor:desc::id{MPSMatrixDescriptor}]::id{MPSMatrix}
return obj
end


"""
MPSMatrix(arr::MtlArray{T,3})

Metal batched matrix representation used in Performance Shaders.

Note that this results in a transposed view of the input,
as Metal stores matrices row-major instead of column-major.
"""
function MPSMatrix(arr::MtlArray{T,3}) where T
n_cols, n_rows, n_matrices = size(arr)
row_bytes = sizeof(T)*n_cols
desc = MPSMatrixDescriptor(n_rows, n_cols, n_matrices, row_bytes, row_bytes * n_rows, T)
mat = @objc [MPSMatrix alloc]::id{MPSMatrix}
obj = MPSMatrix(mat)
offset = arr.offset * sizeof(T)
finalizer(release, obj)
@objc [obj::id{MPSMatrix} initWithBuffer:arr::id{MTLBuffer}
offset:offset::NSUInteger
descriptor:desc::id{MPSMatrixDescriptor}]::id{MPSMatrix}
return obj
end

#
# matrix multiplication
#
Expand Down
2 changes: 2 additions & 0 deletions lib/mps/vector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@ function MPSVector(arr::MtlVector{T}) where T
desc = MPSVectorDescriptor(len, T)
vec = @objc [MPSVector alloc]::id{MPSVector}
obj = MPSVector(vec)
offset = arr.offset * sizeof(T)
finalizer(release, obj)
@objc [obj::id{MPSVector} initWithBuffer:arr::id{MTLBuffer}
offset:offset::NSUInteger
descriptor:desc::id{MPSVectorDescriptor}]::id{MPSVector}
return obj
end
Expand Down
20 changes: 20 additions & 0 deletions test/mps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,26 @@ if MPS.is_supported(current_device())
end
end

@testset "test matrix vector multiplication of views" begin
N = 20
a = rand(Float32, N,N)
b = rand(Float32, N)

mtl_a = mtl(a)
mtl_b = mtl(b)

view_a = @view a[:,10:end]
view_b = @view b[10:end]

mtl_view_a = @view mtl_a[:,10:end]
mtl_view_b = @view mtl_b[10:end]

mtl_c = mtl_view_a * mtl_view_b
c = view_a * view_b

@test mtl_c == mtl(c)
end

@testset "mixed-precision matrix vector multiplication" begin
N = 10
rows = N
Expand Down