Skip to content

Commit

Permalink
Merge pull request #15 from iitis/lp/pi-fix
Browse files Browse the repository at this point in the history
Fix `pyramidindices` performance
  • Loading branch information
kdomino authored May 17, 2023
2 parents 46d5ced + d915d2c commit 65e2127
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 22 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
name = "SymmetricTensors"
uuid = "1ab33d94-6c6c-50cc-93f0-e3f623a46aa0"
authors = ["Krzysztof Domino <[email protected]>", "Łukasz Pawela <[email protected]>", "Piotr Gawron <[email protected]>"]
version = "1.0.6"
version = "1.0.7"

[deps]
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
CompilerSupportLibraries_jll = "e66e0078-7015-5450-92f7-15fbd957f2ae"
Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73"


[compat]
julia = "1"
Combinatorics = "1"
StatsBase = "0.33"
Memoization = "0.2"


[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
1 change: 0 additions & 1 deletion src/SymmetricTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ module SymmetricTensors
using StatsBase
using Random
using LinearAlgebra
using Memoization
import Base: +, -, *, /, size, getindex, rand, setindex!
if VERSION >= v"1.3"
using CompilerSupportLibraries_jll
Expand Down
2 changes: 1 addition & 1 deletion src/randgendat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ function randsymarray(::Type{T}, dim::Int, N::Int = 4) where T<:Real
t = zeros(fill(dim, N)...,)
for i in pyramidindices(N,dim)
n = rand(T)
for j in collect(permutations(i))
for j in permutations(i)
@inbounds t[j...] = n
end
end
Expand Down
35 changes: 19 additions & 16 deletions src/symmetrictensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,22 @@ mutable struct SymmetricTensor{T <: AbstractFloat, N}
end
end

struct PyramidIndices{N}
size::Int
end

@generated function _all_indices(p::PyramidIndices{N}) where {N}
quote
multinds = Tuple{fill(Int, $N)...,}[]
tensize = p.size
@nloops $N i x -> (x==$N) ? (1:tensize) : (i_{x+1}:tensize) begin
@inbounds multind = @ntuple $N x -> i_{$N-x+1}
push!(multinds, multind)
end
multinds
end
end

"""
unfold(ar::Array{T,N}, mode::Int)
Expand Down Expand Up @@ -122,24 +138,11 @@ julia> pyramidindices(2,3)
(3,3)
```
"""
@memoize function pyramidindices(dims::Int, tensize::Int)
multinds = Tuple{fill(Int,dims)...,}[]
@eval begin
@nloops $dims i x -> (x==$dims) ? (1:$tensize) : (i_{x+1}:$tensize) begin
@inbounds multind = @ntuple $dims x -> i_{$dims-x+1}
push!($multinds, multind)
end
end
multinds
function pyramidindices(dims::Int, tensize::Int)
p = PyramidIndices{dims}(tensize)
return _all_indices(p)
end

"""
pyramidindices(st::SymmetricTensor)
Return the indices of the unique elements of the given symmetric tensor.
"""
pyramidindices(st::SymmetricTensor{<:Any, N}) where N = pyramidindices(N, st.dats)

"""
sizetest(dats::Int, bls::Int)
Expand Down
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ t1 = randsymarray(7, 3)
end
end
@test found == length(pinds)
@test pyramidindices(st) == pinds
end
end

Expand Down

0 comments on commit 65e2127

Please sign in to comment.