diff --git a/pi.jl b/pi.jl index ad94a44..8bd1b6c 100644 --- a/pi.jl +++ b/pi.jl @@ -1,16 +1,27 @@ using Base.Cartesian +using ResumableFunctions -function pi2(dims::Int, tensize::Int) - multinds = Tuple{fill(Int, dims)...,}[] +struct PyramidIndices{N} + size::Int +end + +@generated function _all_indices(p::PyramidIndices{N}) where {N} quote - @nloops $dims i x -> (x==dims) ? (1:tensize) : (i_{x+1}:tensize) begin - @inbounds multind = @ntuple $dims x -> i_{dims-x+1} - push!($multinds, multind) + # multinds = Tuple{fill(Int, $N)...,}[] + tensize = p.size + @nloops $N i x -> (x==$N) ? (1:tensize) : (i_{x+1}:tensize) begin + @inbounds multind = @ntuple $N x -> i_{$N-x+1} + @yield multind + # push!(multinds, multind) end + # multinds end - multinds end +function pi2(dims, tensize) + p = PyramidIndices{dims}(tensize) + return _all_indices(p) +end a = rand(2,2,2) dims, tensize = ndims(a), size(a, 1) @@ -23,11 +34,14 @@ dims, tensize = ndims(b), size(b, 1) @time pi2(dims, tensize) @time pi2(dims, tensize) +@show pi2(dims, tensize) + c = rand(4,4,4,4) dims, tensize = ndims(c), size(c, 1) @time pi2(dims, tensize) @time pi2(dims, tensize) +@show "asdasd" @time d = rand(10, 100, 100, 100, 100) dims, tensize = ndims(d), size(d, 1) @time pi2(dims, tensize) @@ -39,4 +53,3 @@ dims, tensize = ndims(e), size(e, 1) @time pi2(dims, tensize) -@show pi2(dims, tensize) \ No newline at end of file diff --git a/src/symmetrictensor.jl b/src/symmetrictensor.jl index e43fd57..ccb54e6 100644 --- a/src/symmetrictensor.jl +++ b/src/symmetrictensor.jl @@ -29,6 +29,22 @@ mutable struct SymmetricTensor{T <: AbstractFloat, N} end end +struct PyramidIndices{N} + size::Int +end + +@generated function _all_indices(p::PyramidIndices{N}) where {N} + quote + multinds = Tuple{fill(Int, $N)...,}[] + tensize = p.size + @nloops $N i x -> (x==$N) ? (1:tensize) : (i_{x+1}:tensize) begin + @inbounds multind = @ntuple $N x -> i_{$N-x+1} + push!(multinds, multind) + end + multinds + end +end + """ unfold(ar::Array{T,N}, mode::Int) @@ -123,14 +139,8 @@ julia> pyramidindices(2,3) ``` """ function pyramidindices(dims::Int, tensize::Int) - quote - multinds = Tuple{fill(Int, dims)...,}[] - @nloops $dims i x -> (x==$dims) ? (1:tensize) : (i_{x+1}:tensize) begin - @inbounds multind = @ntuple $dims x -> i_{$dims-x+1} - push!(multinds, multind) - end - multinds - end + p = PyramidIndices{dims}(tensize) + return _all_indices(p) end """