From b421ad1067a401b33723649f126fb7712c2e5699 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sat, 16 Nov 2024 14:07:38 -0500 Subject: [PATCH] [NestedPermutedDimsArrays] Fix setindex --- NDTensors/Project.toml | 2 +- .../src/NestedPermutedDimsArrays.jl | 41 ++++++++++++++++++- .../NestedPermutedDimsArrays/test/runtests.jl | 16 ++++---- 3 files changed, 49 insertions(+), 10 deletions(-) diff --git a/NDTensors/Project.toml b/NDTensors/Project.toml index c367740318..2d2bc68d1c 100644 --- a/NDTensors/Project.toml +++ b/NDTensors/Project.toml @@ -1,7 +1,7 @@ name = "NDTensors" uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" authors = ["Matthew Fishman "] -version = "0.3.71" +version = "0.3.72" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" diff --git a/NDTensors/src/lib/NestedPermutedDimsArrays/src/NestedPermutedDimsArrays.jl b/NDTensors/src/lib/NestedPermutedDimsArrays/src/NestedPermutedDimsArrays.jl index d9749215ff..9234f6aed1 100644 --- a/NDTensors/src/lib/NestedPermutedDimsArrays/src/NestedPermutedDimsArrays.jl +++ b/NDTensors/src/lib/NestedPermutedDimsArrays/src/NestedPermutedDimsArrays.jl @@ -1,6 +1,45 @@ # Mostly copied from https://github.com/JuliaLang/julia/blob/master/base/permuteddimsarray.jl # Like `PermutedDimsArrays` but singly nested, similar to `Adjoint` and `Transpose` # (though those are fully recursive). +#= +TODO: Investigate replacing this with a `PermutedDimsArray` wrapped around a `MappedArrays.MappedArray`. +There are a few issues with that: +1. Just using a type alias leads to type piracy, for example the constructor is type piracy. +2. `setindex!(::NestedPermutedDimsArray, I...)` fails because no conversion is defined between `Array` +and `PermutedDimsArray`. +3. The type alias is tricky to define, ideally it would have similar type parameters to the current +`NestedPermutedDimsArrays.NestedPermutedDimsArray` definition which matches the type parameters +of `PermutedDimsArrays.PermutedDimsArray` but that seems to be difficult to achieve. +```julia +module NestedPermutedDimsArrays + +using MappedArrays: MultiMappedArray, mappedarray +export NestedPermutedDimsArray + +const NestedPermutedDimsArray{TT,T<:AbstractArray{TT},N,perm,iperm,AA<:AbstractArray{T}} = PermutedDimsArray{ + PermutedDimsArray{TT,N,perm,iperm,T}, + N, + perm, + iperm, + MultiMappedArray{ + PermutedDimsArray{TT,N,perm,iperm,T}, + N, + Tuple{AA}, + Type{PermutedDimsArray{TT,N,perm,iperm,T}}, + Type{PermutedDimsArray{TT,N,iperm,perm,T}}, + }, +} + +function NestedPermutedDimsArray(a::AbstractArray, perm) + iperm = invperm(perm) + f = PermutedDimsArray{eltype(eltype(a)),ndims(a),perm,iperm,eltype(a)} + finv = PermutedDimsArray{eltype(eltype(a)),ndims(a),iperm,perm,eltype(a)} + return PermutedDimsArray(mappedarray(f, finv, a), perm) +end + +end +``` +=# module NestedPermutedDimsArrays import Base: permutedims, permutedims! @@ -107,7 +146,7 @@ end A::NestedPermutedDimsArray{T,N,perm,iperm}, val, I::Vararg{Int,N} ) where {T,N,perm,iperm} @boundscheck checkbounds(A, I...) - @inbounds setindex!(A.parent, PermutedDimsArray(val, perm), genperm(I, iperm)...) + @inbounds setindex!(A.parent, PermutedDimsArray(val, iperm), genperm(I, iperm)...) return val end diff --git a/NDTensors/src/lib/NestedPermutedDimsArrays/test/runtests.jl b/NDTensors/src/lib/NestedPermutedDimsArrays/test/runtests.jl index d13b881540..704297fcc2 100644 --- a/NDTensors/src/lib/NestedPermutedDimsArrays/test/runtests.jl +++ b/NDTensors/src/lib/NestedPermutedDimsArrays/test/runtests.jl @@ -5,19 +5,19 @@ using Test: @test, @testset Float32, Float64, Complex{Float32}, Complex{Float64} ) a = map(_ -> randn(elt, 2, 3, 4), CartesianIndices((2, 3, 4))) - perm = (3, 2, 1) + perm = (3, 1, 2) p = NestedPermutedDimsArray(a, perm) T = PermutedDimsArray{elt,3,perm,invperm(perm),eltype(a)} @test typeof(p) === NestedPermutedDimsArray{T,3,perm,invperm(perm),typeof(a)} - @test size(p) == (4, 3, 2) + @test size(p) == (4, 2, 3) @test eltype(p) === T for I in eachindex(p) - @test size(p[I]) == (4, 3, 2) - @test p[I] == permutedims(a[CartesianIndex(reverse(Tuple(I)))], perm) + @test size(p[I]) == (4, 2, 3) + @test p[I] == permutedims(a[CartesianIndex(map(i -> Tuple(I)[i], invperm(perm)))], perm) end - x = randn(elt, 4, 3, 2) - p[3, 2, 1] = x - @test p[3, 2, 1] == x - @test a[1, 2, 3] == permutedims(x, perm) + x = randn(elt, 4, 2, 3) + p[3, 1, 2] = x + @test p[3, 1, 2] == x + @test a[1, 2, 3] == permutedims(x, invperm(perm)) end end