[NDTensors] Get more block sparse operations working on GPU (#1215)
mtfishman authored Oct 24, 2023
1 parent 3f043f5 commit 871e59d
Showing 33 changed files with 399 additions and 152 deletions.
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
@@ -36,7 +36,7 @@ NDTensorsOctavianExt = "Octavian"
 NDTensorsTBLISExt = "TBLIS"

 [compat]
-Adapt = "3.5"
+Adapt = "3.7"
 BlockArrays = "0.16"
 Compat = "4.9"
 Dictionaries = "0.3.5"
1 change: 1 addition & 0 deletions NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl
@@ -20,5 +20,6 @@ include("imports.jl")
 include("set_types.jl")
 include("iscu.jl")
 include("adapt.jl")
+include("indexing.jl")
 include("linearalgebra.jl")
 end
8 changes: 8 additions & 0 deletions NDTensors/ext/NDTensorsCUDAExt/indexing.jl
@@ -0,0 +1,8 @@
+function Base.getindex(::Type{<:CuArray}, T::DenseTensor{<:Number})
+  return CUDA.@allowscalar data(T)[]
+end
+
+function Base.setindex!(::Type{<:CuArray}, T::DenseTensor{<:Number}, x::Number)
+  CUDA.@allowscalar data(T)[] = x
+  return T
+end
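These methods route zero-dimensional (scalar) reads and writes of a GPU-resident `DenseTensor` through `CUDA.@allowscalar`, since plain scalar indexing of a `CuArray` throws by default. A standalone sketch of the pattern, not the NDTensors API (hypothetical `scalar_read`; assumes CUDA.jl and a CUDA device):

    using CUDA

    # Dispatch on the storage type and opt into scalar indexing only for
    # CuArray, mirroring the @allowscalar pattern in the diff above.
    scalar_read(data::CuArray) = CUDA.@allowscalar data[1]
    scalar_read(data::AbstractArray) = data[1]

    v = CUDA.zeros(Float64, 1)
    scalar_read(v)  # returns 0.0 instead of throwing a scalar-indexing error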
9 changes: 7 additions & 2 deletions NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl
@@ -1,9 +1,10 @@
 module NDTensorsMetalExt

+using Adapt
+using Functors
+using LinearAlgebra: LinearAlgebra
 using NDTensors
 using NDTensors.SetParameters
-using Functors
-using Adapt

 if isdefined(Base, :get_extension)
   using Metal
@@ -14,4 +15,8 @@ end
 include("imports.jl")
 include("adapt.jl")
 include("set_types.jl")
+include("indexing.jl")
+include("linearalgebra.jl")
+include("copyto.jl")
+include("permutedims.jl")
 end
13 changes: 13 additions & 0 deletions NDTensors/ext/NDTensorsMetalExt/copyto.jl
@@ -0,0 +1,13 @@
+# Works around a bug in `copyto!` in the Metal backend.
+function NDTensors.copyto!(
+  ::Type{<:MtlArray}, dest::AbstractArray, ::Type{<:MtlArray}, src::SubArray
+)
+  return Base.copyto!(dest, copy(src))
+end
+
+# Works around a bug in `copyto!` in the Metal backend.
+function NDTensors.copyto!(
+  ::Type{<:MtlArray}, dest::AbstractArray, ::Type{<:MtlArray}, src::Base.ReshapedArray
+)
+  return NDTensors.copyto!(dest, parent(src))
+end
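Both methods sidestep the bug by materializing or unwrapping the wrapped source before copying. A minimal sketch of the same idea (assumes macOS with Metal.jl; `Metal.zeros` is assumed to mirror `CUDA.zeros`):

    using Metal

    # Copying out of a strided view of an MtlArray can hit the backend bug,
    # so materialize the view into a contiguous MtlArray first.
    a = Metal.zeros(Float32, 4, 4)
    v = view(a, 1:2, 1:2)               # SubArray wrapping an MtlArray
    dest = Metal.zeros(Float32, 2, 2)
    copyto!(dest, copy(v))              # copy(v) is contiguous, so this is safe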
8 changes: 8 additions & 0 deletions NDTensors/ext/NDTensorsMetalExt/indexing.jl
@@ -0,0 +1,8 @@
+function Base.getindex(::Type{<:MtlArray}, T::DenseTensor{<:Number})
+  return Metal.@allowscalar data(T)[]
+end
+
+function Base.setindex!(::Type{<:MtlArray}, T::DenseTensor{<:Number}, x::Number)
+  Metal.@allowscalar data(T)[] = x
+  return T
+end
16 changes: 16 additions & 0 deletions NDTensors/ext/NDTensorsMetalExt/linearalgebra.jl
@@ -0,0 +1,16 @@
+function NDTensors.qr(leaf_parenttype::Type{<:MtlArray}, A::AbstractMatrix)
+  Q, R = NDTensors.qr(NDTensors.cpu(A))
+  return adapt(leaf_parenttype, Matrix(Q)), adapt(leaf_parenttype, R)
+end
+
+function NDTensors.eigen(leaf_parenttype::Type{<:MtlArray}, A::AbstractMatrix)
+  D, U = NDTensors.eigen(NDTensors.cpu(A))
+  return adapt(set_ndims(leaf_parenttype, ndims(D)), D), adapt(leaf_parenttype, U)
+end
+
+function NDTensors.svd(leaf_parenttype::Type{<:MtlArray}, A::AbstractMatrix)
+  U, S, V = NDTensors.svd(NDTensors.cpu(A))
+  return adapt(leaf_parenttype, U),
+  adapt(set_ndims(leaf_parenttype, ndims(S)), S),
+  adapt(leaf_parenttype, V)
+end
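Metal lacks native dense QR/eigen/SVD here, so each factorization round-trips through the CPU and the factors are adapted back to the device. A device-agnostic sketch of that fallback, runnable on plain `Array`s (`qr_via_cpu`, `cpu`, and `to_device` are hypothetical stand-ins for `NDTensors.cpu` and `Adapt.adapt`):

    using LinearAlgebra

    cpu(A) = Array(A)                # stand-in for NDTensors.cpu
    to_device(T, A) = convert(T, A)  # stand-in for Adapt.adapt

    # Factorize on the CPU, then move the factors back to the device type.
    function qr_via_cpu(devicetype::Type, A::AbstractMatrix)
      Q, R = qr(cpu(A))
      return to_device(devicetype, Matrix(Q)), to_device(devicetype, R)
    end

    A = rand(4, 3)
    Q, R = qr_via_cpu(Matrix{Float64}, A)
    Q * R ≈ A  # true up to roundoff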
12 changes: 12 additions & 0 deletions NDTensors/ext/NDTensorsMetalExt/permutedims.jl
@@ -0,0 +1,12 @@
+function NDTensors.permutedims!(
+  ::Type{<:MtlArray},
+  Adest::Base.ReshapedArray{<:Any,<:Any,<:SubArray},
+  ::Type{<:MtlArray},
+  A,
+  perm,
+)
+  Aperm = permutedims(A, perm)
+  Adest_parent = parent(Adest)
+  copyto!(Adest_parent, Aperm)
+  return Adest
+end
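(The Metal backend's `permutedims!` into a reshaped view misbehaves, so the permutation is materialized into a fresh array first and then copied into the view's parent.)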
5 changes: 5 additions & 0 deletions NDTensors/src/NDTensors.jl
@@ -48,11 +48,15 @@ include("algorithm.jl")
 include("aliasstyle.jl")
 include("abstractarray/set_types.jl")
 include("abstractarray/to_shape.jl")
+include("abstractarray/iswrappedarray.jl")
+include("abstractarray/iscu.jl")
 include("abstractarray/similar.jl")
 include("abstractarray/ndims.jl")
+include("abstractarray/copyto.jl")
 include("abstractarray/permutedims.jl")
 include("abstractarray/fill.jl")
 include("abstractarray/mul.jl")
+include("abstractarray/linearalgebra.jl")
 include("array/set_types.jl")
 include("array/permutedims.jl")
 include("array/mul.jl")
@@ -68,6 +72,7 @@ include("dims.jl")
 include("tensor/set_types.jl")
 include("tensor/similar.jl")
 include("tensor/permutedims.jl")
+include("tensor/linearalgebra.jl")
 include("adapt.jl")
 include("tensoralgebra/generic_tensor_operations.jl")
 include("tensoralgebra/contraction_logic.jl")
13 changes: 13 additions & 0 deletions NDTensors/src/abstractarray/copyto.jl
@@ -0,0 +1,13 @@
+# NDTensors.copyto!
+function copyto!(R::AbstractArray, T::AbstractArray)
+  copyto!(leaf_parenttype(R), R, leaf_parenttype(T), T)
+  return R
+end
+
+# NDTensors.copyto!
+function copyto!(
+  ::Type{<:AbstractArray}, R::AbstractArray, ::Type{<:AbstractArray}, T::AbstractArray
+)
+  Base.copyto!(R, T)
+  return R
+end
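This two-layer design keeps `NDTensors.copyto!` generic while letting backend extensions intercept problem cases by storage type. A self-contained sketch of the pattern (hypothetical `copyto2!`, using `typeof` in place of `leaf_parenttype`):

    # Public entry point forwards to a method keyed on both storage types.
    copyto2!(dest::AbstractArray, src::AbstractArray) =
      copyto2!(typeof(dest), dest, typeof(src), src)

    # Generic fallback; a GPU extension would add methods like
    # copyto2!(::Type{<:MtlArray}, dest, ::Type{<:MtlArray}, src::SubArray).
    function copyto2!(::Type{<:AbstractArray}, dest, ::Type{<:AbstractArray}, src)
      Base.copyto!(dest, src)
      return dest
    end

    copyto2!(zeros(3), ones(3))  # returns [1.0, 1.0, 1.0]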
6 changes: 6 additions & 0 deletions NDTensors/src/abstractarray/iscu.jl
@@ -0,0 +1,6 @@
+# TODO: Make `isgpu`, `ismtl`, etc.
+# For `isgpu`, this will require an `NDTensorsGPUArrayCoreExt`.
+iscu(A::AbstractArray) = iscu(typeof(A))
+function iscu(A::Type{<:AbstractArray})
+  return (leaf_parenttype(A) == A ? false : iscu(leaf_parenttype(A)))
+end
54 changes: 54 additions & 0 deletions NDTensors/src/abstractarray/iswrappedarray.jl
@@ -0,0 +1,54 @@
+# Trait indicating if the AbstractArray type is an array wrapper.
+# Assumes that it implements `NDTensors.parenttype`.
+@traitdef IsWrappedArray{ArrayT}
+
+#! format: off
+@traitimpl IsWrappedArray{ArrayT} <- is_wrapped_array(ArrayT)
+#! format: on
+
+is_wrapped_array(arraytype::Type{<:AbstractArray}) = (parenttype(arraytype) ≠ arraytype)
+
+# TODO: This is only defined because of the current design
+# of `Diag`, which uses a `Number` as the data type when it
+# is a uniform diagonal type. Delete this when it is
+# replaced by `DiagonalArray`.
+is_wrapped_array(arraytype::Type{<:Number}) = false
+
+# For working with instances, not used by
+# `SimpleTraits.jl` traits dispatch.
+is_wrapped_array(array::AbstractArray) = is_wrapped_array(typeof(array))
+
+# By default, the `parenttype` of an array type is itself.
+parenttype(arraytype::Type{<:AbstractArray}) = arraytype
+
+# TODO: Use `SetParameters` here.
+parenttype(::Type{<:ReshapedArray{<:Any,<:Any,P}}) where {P} = P
+parenttype(::Type{<:Transpose{<:Any,P}}) where {P} = P
+parenttype(::Type{<:Adjoint{<:Any,P}}) where {P} = P
+parenttype(::Type{<:Symmetric{<:Any,P}}) where {P} = P
+parenttype(::Type{<:Hermitian{<:Any,P}}) where {P} = P
+parenttype(::Type{<:UpperTriangular{<:Any,P}}) where {P} = P
+parenttype(::Type{<:LowerTriangular{<:Any,P}}) where {P} = P
+parenttype(::Type{<:UnitUpperTriangular{<:Any,P}}) where {P} = P
+parenttype(::Type{<:UnitLowerTriangular{<:Any,P}}) where {P} = P
+parenttype(::Type{<:Diagonal{<:Any,P}}) where {P} = P
+parenttype(::Type{<:SubArray{<:Any,<:Any,P}}) where {P} = P
+
+# For working with instances, not used by
+# `SimpleTraits.jl` traits dispatch.
+parenttype(array::AbstractArray) = parenttype(typeof(array))
+
+@traitfn function leaf_parenttype(
+  arraytype::Type{ArrayT}
+) where {ArrayT; IsWrappedArray{ArrayT}}
+  return leaf_parenttype(parenttype(arraytype))
+end
+
+@traitfn function leaf_parenttype(
+  arraytype::Type{ArrayT}
+) where {ArrayT; !IsWrappedArray{ArrayT}}
+  return arraytype
+end
+
+# For working with instances.
+leaf_parenttype(array::AbstractArray) = leaf_parenttype(typeof(array))
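This file is the trait machinery moved out of `similar.jl` (see the deletions below). As a standalone illustration of the recursion, without the SimpleTraits.jl layer (hypothetical `mini_*` names, not the NDTensors API):

    using LinearAlgebra

    # Peel one wrapper layer at a time until a type is its own parent.
    mini_parenttype(::Type{<:Transpose{<:Any,P}}) where {P} = P
    mini_parenttype(::Type{<:SubArray{<:Any,<:Any,P}}) where {P} = P
    mini_parenttype(T::Type{<:AbstractArray}) = T
    mini_leaf(T::Type) = mini_parenttype(T) === T ? T : mini_leaf(mini_parenttype(T))

    A = transpose(view(rand(4, 4), 1:2, 1:2))
    mini_leaf(typeof(A))  # Matrix{Float64}: the leaf storage behind both wrappers

This is what lets `iscu` above answer `true` for, say, a `Transpose` of a `SubArray` of a `CuArray`.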
11 changes: 11 additions & 0 deletions NDTensors/src/abstractarray/linearalgebra.jl
@@ -0,0 +1,11 @@
+# NDTensors.qr
+qr(A::AbstractMatrix) = qr(leaf_parenttype(A), A)
+qr(::Type{<:AbstractArray}, A::AbstractMatrix) = LinearAlgebra.qr(A)
+
+# NDTensors.eigen
+eigen(A::AbstractMatrix) = eigen(leaf_parenttype(A), A)
+eigen(::Type{<:AbstractArray}, A::AbstractMatrix) = LinearAlgebra.eigen(A)
+
+# NDTensors.svd
+svd(A::AbstractMatrix) = svd(leaf_parenttype(A), A)
+svd(::Type{<:AbstractArray}, A::AbstractMatrix) = LinearAlgebra.svd(A)
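As with `copyto!`, each factorization has a generic fallback keyed on the leaf storage type, which the Metal methods above override for `MtlArray`-backed storage. A standalone sketch (hypothetical `my_svd`):

    using LinearAlgebra

    my_svd(A::AbstractMatrix) = my_svd(typeof(A), A)           # public entry point
    my_svd(::Type{<:AbstractArray}, A) = LinearAlgebra.svd(A)  # generic fallback
    # A backend could add: my_svd(::Type{<:MtlArray}, A) = ... CPU round-trip ...

    A = rand(3, 3)
    F = my_svd(A)
    F.U * Diagonal(F.S) * F.Vt ≈ A  # true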
17 changes: 15 additions & 2 deletions NDTensors/src/abstractarray/permutedims.jl
@@ -12,12 +12,25 @@ end

 # NDTensors.permutedims!
 function permutedims!(Mdest::AbstractArray, M::AbstractArray, perm)
-  return permutedims!(leaf_parenttype(Mdest), Mdest, leaf_parenttype(M), M, perm)
+  permutedims!(leaf_parenttype(Mdest), Mdest, leaf_parenttype(M), M, perm)
+  return Mdest
 end

 # NDTensors.permutedims!
 function permutedims!(::Type{<:AbstractArray}, Mdest, ::Type{<:AbstractArray}, M, perm)
-  return Base.permutedims!(Mdest, M, perm)
+  Base.permutedims!(Mdest, M, perm)
+  return Mdest
 end
+
+function permutedims!!(B::AbstractArray, A::AbstractArray, perm)
+  return permutedims!!(leaf_parenttype(B), B, leaf_parenttype(A), A, perm)
+end
+
+function permutedims!!(
+  Bleaftype::Type{<:AbstractArray}, B, Aleaftype::Type{<:AbstractArray}, A, perm
+)
+  permutedims!(leaf_parenttype(B), B, leaf_parenttype(A), A, perm)
+  return B
+end

 function permutedims!!(B::AbstractArray, A::AbstractArray, perm, f)
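The `permutedims!` methods now return the destination explicitly instead of whatever the inner call returns, and the new `permutedims!!` methods follow the maybe-in-place `!!` convention: mutate the destination when possible, but have callers use the return value. A tiny sketch of that convention (hypothetical name):

    # Mutate `dest` if we can, and always hand back the array holding the result.
    function permutedims_maybe!!(dest::AbstractArray, src::AbstractArray, perm)
      permutedims!(dest, src, perm)
      return dest
    end

    B = permutedims_maybe!!(zeros(Int, 3, 2), reshape(1:6, 2, 3), (2, 1))
    B == permutedims(reshape(1:6, 2, 3))  # true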
53 changes: 0 additions & 53 deletions NDTensors/src/abstractarray/similar.jl
@@ -1,66 +1,13 @@
 ## Custom `NDTensors.similar` implementation.
 ## More extensive than `Base.similar`.

-# Trait indicating if the AbstractArray type is an array wrapper.
-# Assumes that it implements `NDTensors.parenttype`.
-@traitdef IsWrappedArray{ArrayT}
-
-#! format: off
-@traitimpl IsWrappedArray{ArrayT} <- is_wrapped_array(ArrayT)
-#! format: on
-
-is_wrapped_array(arraytype::Type{<:AbstractArray}) = (parenttype(arraytype) ≠ arraytype)
-
-# For working with instances, not used by
-# `SimpleTraits.jl` traits dispatch.
-is_wrapped_array(array::AbstractArray) = is_wrapped_array(typeof(array))
-
-# By default, the `parentype` of an array type is itself
-parenttype(arraytype::Type{<:AbstractArray}) = arraytype
-
-parenttype(::Type{<:ReshapedArray{<:Any,<:Any,P}}) where {P} = P
-parenttype(::Type{<:Transpose{<:Any,P}}) where {P} = P
-parenttype(::Type{<:Adjoint{<:Any,P}}) where {P} = P
-parenttype(::Type{<:Symmetric{<:Any,P}}) where {P} = P
-parenttype(::Type{<:Hermitian{<:Any,P}}) where {P} = P
-parenttype(::Type{<:UpperTriangular{<:Any,P}}) where {P} = P
-parenttype(::Type{<:LowerTriangular{<:Any,P}}) where {P} = P
-parenttype(::Type{<:UnitUpperTriangular{<:Any,P}}) where {P} = P
-parenttype(::Type{<:UnitLowerTriangular{<:Any,P}}) where {P} = P
-parenttype(::Type{<:Diagonal{<:Any,P}}) where {P} = P
-parenttype(::Type{<:SubArray{<:Any,<:Any,P}}) where {P} = P
-
-# For working with instances, not used by
-# `SimpleTraits.jl` traits dispatch.
-parenttype(array::AbstractArray) = parenttype(typeof(array))
-
-@traitfn function leaf_parenttype(
-  arraytype::Type{ArrayT}
-) where {ArrayT; IsWrappedArray{ArrayT}}
-  return leaf_parenttype(parenttype(arraytype))
-end
-
-@traitfn function leaf_parenttype(
-  arraytype::Type{ArrayT}
-) where {ArrayT; !IsWrappedArray{ArrayT}}
-  return arraytype
-end
-
-# For working with instances.
-leaf_parenttype(array::AbstractArray) = leaf_parenttype(typeof(array))
-
 # This function actually allocates the data.
 # NDTensors.similar
 function similar(arraytype::Type{<:AbstractArray}, dims::Tuple)
   shape = NDTensors.to_shape(arraytype, dims)
   return similartype(arraytype, shape)(undef, NDTensors.to_shape(arraytype, shape))
 end

-# For when there are CUArray specific issues inline
-iscu(A::AbstractArray) = iscu(typeof(A))
-function iscu(A::Type{<:AbstractArray})
-  return (leaf_parenttype(A) == A ? false : iscu(leaf_parenttype(A)))
-end
 # This function actually allocates the data.
 # Catches conversions of dimensions specified by ranges
 # dimensions specified by integers with `Base.to_shape`.
10 changes: 10 additions & 0 deletions NDTensors/src/blocksparse/blocksparse.jl
@@ -54,6 +54,16 @@ function BlockSparse(
   return BlockSparse(Vector{ElT}(undef, dim), blockoffsets; vargs...)
 end

+function BlockSparse(
+  datatype::Type{<:AbstractArray},
+  ::UndefInitializer,
+  blockoffsets::BlockOffsets,
+  dim::Integer;
+  vargs...,
+)
+  return BlockSparse(datatype(undef, dim), blockoffsets; vargs...)
+end
+
 function BlockSparse(blockoffsets::BlockOffsets, dim::Integer; vargs...)
   return BlockSparse(Float64, blockoffsets, dim; vargs...)
 end
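The new method threads an arbitrary storage type through to the allocation, so block sparse data can live in a `Vector`, `CuVector`, `MtlVector`, and so on. The core trick in isolation (hypothetical `allocate`):

    # Any AbstractArray type with an `undef` constructor can back the storage.
    allocate(datatype::Type{<:AbstractArray}, dim::Integer) = datatype(undef, dim)

    allocate(Vector{Float64}, 10)      # CPU buffer
    # allocate(CuVector{Float64}, 10)  # GPU buffer, with CUDA.jl loaded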