-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[NDTensors] Get more block sparse operations working on GPU (#1215)
- Loading branch information
Showing
33 changed files
with
399 additions
and
152 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
function Base.getindex(::Type{<:CuArray}, T::DenseTensor{<:Number}) | ||
return CUDA.@allowscalar data(T)[] | ||
end | ||
|
||
function Base.setindex!(::Type{<:CuArray}, T::DenseTensor{<:Number}, x::Number) | ||
CUDA.@allowscalar data(T)[] = x | ||
return T | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Catches a bug in `copyto!` in Metal backend. | ||
function NDTensors.copyto!( | ||
::Type{<:MtlArray}, dest::AbstractArray, ::Type{<:MtlArray}, src::SubArray | ||
) | ||
return Base.copyto!(dest, copy(src)) | ||
end | ||
|
||
# Catches a bug in `copyto!` in Metal backend. | ||
function NDTensors.copyto!( | ||
::Type{<:MtlArray}, dest::AbstractArray, ::Type{<:MtlArray}, src::Base.ReshapedArray | ||
) | ||
return NDTensors.copyto!(dest, parent(src)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
function Base.getindex(::Type{<:MtlArray}, T::DenseTensor{<:Number}) | ||
return Metal.@allowscalar data(T)[] | ||
end | ||
|
||
function Base.setindex!(::Type{<:MtlArray}, T::DenseTensor{<:Number}, x::Number) | ||
Metal.@allowscalar data(T)[] = x | ||
return T | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
function NDTensors.qr(leaf_parenttype::Type{<:MtlArray}, A::AbstractMatrix) | ||
Q, R = NDTensors.qr(NDTensors.cpu(A)) | ||
return adapt(leaf_parenttype, Matrix(Q)), adapt(leaf_parenttype, R) | ||
end | ||
|
||
function NDTensors.eigen(leaf_parenttype::Type{<:MtlArray}, A::AbstractMatrix) | ||
D, U = NDTensors.eigen(NDTensors.cpu(A)) | ||
return adapt(set_ndims(leaf_parenttype, ndims(D)), D), adapt(leaf_parenttype, U) | ||
end | ||
|
||
function NDTensors.svd(leaf_parenttype::Type{<:MtlArray}, A::AbstractMatrix) | ||
U, S, V = NDTensors.svd(NDTensors.cpu(A)) | ||
return adapt(leaf_parenttype, U), | ||
adapt(set_ndims(leaf_parenttype, ndims(S)), S), | ||
adapt(leaf_parenttype, V) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
function NDTensors.permutedims!( | ||
::Type{<:MtlArray}, | ||
Adest::Base.ReshapedArray{<:Any,<:Any,<:SubArray}, | ||
::Type{<:MtlArray}, | ||
A, | ||
perm, | ||
) | ||
Aperm = permutedims(A, perm) | ||
Adest_parent = parent(Adest) | ||
copyto!(Adest_parent, Aperm) | ||
return Adest | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# NDTensors.copyto! | ||
function copyto!(R::AbstractArray, T::AbstractArray) | ||
copyto!(leaf_parenttype(R), R, leaf_parenttype(T), T) | ||
return R | ||
end | ||
|
||
# NDTensors.copyto! | ||
function copyto!( | ||
::Type{<:AbstractArray}, R::AbstractArray, ::Type{<:AbstractArray}, T::AbstractArray | ||
) | ||
Base.copyto!(R, T) | ||
return R | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# TODO: Make `isgpu`, `ismtl`, etc. | ||
# For `isgpu`, will require a `NDTensorsGPUArrayCoreExt`. | ||
iscu(A::AbstractArray) = iscu(typeof(A)) | ||
function iscu(A::Type{<:AbstractArray}) | ||
return (leaf_parenttype(A) == A ? false : iscu(leaf_parenttype(A))) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Trait indicating if the AbstractArray type is an array wrapper. | ||
# Assumes that it implements `NDTensors.parenttype`. | ||
@traitdef IsWrappedArray{ArrayT} | ||
|
||
#! format: off | ||
@traitimpl IsWrappedArray{ArrayT} <- is_wrapped_array(ArrayT) | ||
#! format: on | ||
|
||
is_wrapped_array(arraytype::Type{<:AbstractArray}) = (parenttype(arraytype) ≠ arraytype) | ||
|
||
# TODO: This is only defined because the current design | ||
# of `Diag` using a `Number` as the data type if it | ||
# is a uniform diagonal type. Delete this when it is | ||
# replaced by `DiagonalArray`. | ||
is_wrapped_array(arraytype::Type{<:Number}) = false | ||
|
||
# For working with instances, not used by | ||
# `SimpleTraits.jl` traits dispatch. | ||
is_wrapped_array(array::AbstractArray) = is_wrapped_array(typeof(array)) | ||
|
||
# By default, the `parentype` of an array type is itself | ||
parenttype(arraytype::Type{<:AbstractArray}) = arraytype | ||
|
||
# TODO: Use `SetParameters` here. | ||
parenttype(::Type{<:ReshapedArray{<:Any,<:Any,P}}) where {P} = P | ||
parenttype(::Type{<:Transpose{<:Any,P}}) where {P} = P | ||
parenttype(::Type{<:Adjoint{<:Any,P}}) where {P} = P | ||
parenttype(::Type{<:Symmetric{<:Any,P}}) where {P} = P | ||
parenttype(::Type{<:Hermitian{<:Any,P}}) where {P} = P | ||
parenttype(::Type{<:UpperTriangular{<:Any,P}}) where {P} = P | ||
parenttype(::Type{<:LowerTriangular{<:Any,P}}) where {P} = P | ||
parenttype(::Type{<:UnitUpperTriangular{<:Any,P}}) where {P} = P | ||
parenttype(::Type{<:UnitLowerTriangular{<:Any,P}}) where {P} = P | ||
parenttype(::Type{<:Diagonal{<:Any,P}}) where {P} = P | ||
parenttype(::Type{<:SubArray{<:Any,<:Any,P}}) where {P} = P | ||
|
||
# For working with instances, not used by | ||
# `SimpleTraits.jl` traits dispatch. | ||
parenttype(array::AbstractArray) = parenttype(typeof(array)) | ||
|
||
@traitfn function leaf_parenttype( | ||
arraytype::Type{ArrayT} | ||
) where {ArrayT; IsWrappedArray{ArrayT}} | ||
return leaf_parenttype(parenttype(arraytype)) | ||
end | ||
|
||
@traitfn function leaf_parenttype( | ||
arraytype::Type{ArrayT} | ||
) where {ArrayT; !IsWrappedArray{ArrayT}} | ||
return arraytype | ||
end | ||
|
||
# For working with instances. | ||
leaf_parenttype(array::AbstractArray) = leaf_parenttype(typeof(array)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# NDTensors.qr | ||
qr(A::AbstractMatrix) = qr(leaf_parenttype(A), A) | ||
qr(::Type{<:AbstractArray}, A::AbstractMatrix) = LinearAlgebra.qr(A) | ||
|
||
# NDTensors.eigen | ||
eigen(A::AbstractMatrix) = eigen(leaf_parenttype(A), A) | ||
eigen(::Type{<:AbstractArray}, A::AbstractMatrix) = LinearAlgebra.eigen(A) | ||
|
||
# NDTensors.svd | ||
svd(A::AbstractMatrix) = svd(leaf_parenttype(A), A) | ||
svd(::Type{<:AbstractArray}, A::AbstractMatrix) = LinearAlgebra.svd(A) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.