Skip to content

Commit

Permalink
Merge pull request #81 from mcabbott/map
Browse files Browse the repository at this point in the history
map + collect
  • Loading branch information
oxinabox authored Nov 15, 2019
2 parents 7b6f001 + 7b9f269 commit 51893cf
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NamedDims"
uuid = "356022a1-0364-5f58-8944-0da4b18d706f"
authors = ["Invenia Technical Computing Corporation"]
version = "0.2.10"
version = "0.2.11"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
55 changes: 54 additions & 1 deletion src/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function nameddimsarray_result(original_nda, reduced_data, reduction_dims::Colon
return reduced_data
end

###################################################################################
################################################
# Overloads

# 1 Arg
Expand Down Expand Up @@ -138,3 +138,56 @@ function Base.append!(A::NamedDimsArray{L,T,1}, B::AbstractVector) where {L,T}
data = append!(parent(A), unname(B))
return NamedDimsArray{newL}(data)
end

################################################
# map, collect

Base.map(f, A::NamedDimsArray) = NamedDimsArray(map(f, parent(A)), names(A))

for (T, S) in [
(:NamedDimsArray, :AbstractArray),
(:AbstractArray, :NamedDimsArray),
(:NamedDimsArray, :NamedDimsArray),
]
for fun in [:map, :map!]

# Here f::F where {F} is needed to avoid ambiguities in Julia 1.0
@eval function Base.$fun(f::F, a::$T, b::$S, cs::AbstractArray...) where {F}
data = $fun(f, unname(a), unname(b), unname.(cs)...)
new_names = unify_names(names(a), names(b), names.(cs)...)
return NamedDimsArray(data, new_names)
end

end

@eval function Base.foreach(f::F, a::$T, b::$S, cs::AbstractArray...) where {F}
data = foreach(f, unname(a), unname(b), unname.(cs)...)
unify_names(names(a), names(b), names.(cs)...)
return nothing
end
end

Base.filter(f, A::NamedDimsArray{L,T,1}) where {L,T} = NamedDimsArray(filter(f, parent(A)), L)
Base.filter(f, A::NamedDimsArray{L,T,N}) where {L,T,N} = filter(f, parent(A))


# We overload collect on various kinds of `Generators` so that that can keep names.
function Base.collect(x::Base.Generator{<:NamedDimsArray{L}}) where {L}
data = collect(Base.Generator(x.f, parent(x.iter)))
return NamedDimsArray(data, L)
end

function Base.collect(x::Base.Generator{<:Iterators.Enumerate{<:NamedDimsArray{L}}}) where {L}
data = collect(Base.Generator(x.f, enumerate(parent(x.iter.itr))))
return NamedDimsArray(data, L)
end

Base.collect(x::Base.Generator{<:Iterators.ProductIterator{<:Tuple{<:NamedDimsArray,Vararg{Any}}}}) = collect_product(x)
Base.collect(x::Base.Generator{<:Iterators.ProductIterator{<:Tuple{<:Any,<:NamedDimsArray,Vararg{Any}}}}) = collect_product(x)
Base.collect(x::Base.Generator{<:Iterators.ProductIterator{<:Tuple{<:NamedDimsArray,<:NamedDimsArray,Vararg{Any}}}}) = collect_product(x)

function collect_product(x)
data = collect(Base.Generator(x.f, Iterators.product(unname.(x.iter.iterators)...)))
all_names = tuple_cat(names.(x.iter.iterators)...)
return NamedDimsArray(data, all_names)
end
14 changes: 14 additions & 0 deletions src/name_core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ function unify_names(names_a, names_b)
return compile_time_return_hack(ret)
end

unify_names(a) = a
unify_names(a, b, cs...) = unify_names(unify_names(a,b), cs...)
# @btime (()->unify_names((:a, :b), (:a, :_), (:_, :b)))()

"""
unify_names_longest(a, b)
Expand Down Expand Up @@ -292,3 +296,13 @@ end
keep_names = [:(getfield(dimnames, $ii)) for ii in 1:N if ii βˆ‰ dropped_dims_vals]
return Expr(:call, :compile_time_return_hack, Expr(:tuple, keep_names...))
end

"""
tuple_cat(x, y, zs...)
This is like `vcat` for tuples, it splats everything into one long tuple.
"""
tuple_cat(x::Tuple, ys::Tuple...) = (x..., tuple_cat(ys...)...)
tuple_cat() = ()
# @btime tuple_cat((1, 2), (3, 4, 5), (6,)) # 0 allocations

66 changes: 66 additions & 0 deletions test/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,72 @@ using Statistics
@test length(ndv) == 0
end

@testset "map, map!" begin
nda = NamedDimsArray([11 12; 21 22], (:x, :y))

@test names(map(+, nda, nda, nda)) == (:x, :y)
@test names(map(+, nda, parent(nda), nda)) == (:x, :y)
@test names(map(+, parent(nda), nda)) == (:x, :y)

# this method only called based on first two arguments:
@test names(map(+, parent(nda), parent(nda), nda)) == (:_, :_)

# one-arg forms work without adding anything... except on 1.0...
@test names(map(sqrt, nda)) == (:x, :y)
@test foreach(sqrt, nda) === nothing

# map! may return a different wrapper of the same data, like sum!
semi = NamedDimsArray(rand(2,2), (:x, :_))
@test names(map!(sqrt, rand(2,2), nda)) == (:x, :y)
@test names(map!(sqrt, semi, nda)) == (:x, :y)

zed = similar(nda, Float64)
@test map!(sqrt, zed, nda) == sqrt.(nda)
@test zed[1,1] == sqrt(nda[1,1])

# mismatching names
@test_throws DimensionMismatch map(+, nda, transpose(nda))
@test_throws DimensionMismatch map(+, nda, parent(nda), nda, transpose(nda))
@test_throws DimensionMismatch map!(sqrt, semi, transpose(nda))

@test foreach(+, semi, nda) === nothing
@test_throws DimensionMismatch foreach(+, semi, transpose(nda))
end

@testset "filter" begin
nda = NamedDimsArray([11 12; 21 22], (:x, :y))
ndv = NamedDimsArray(1:7, (:z,))

@test names(filter(isodd, ndv)) == (:z,)
@test names(filter(isodd, nda)) == (:_,)
end

@testset "collect(generator)" begin
nda = NamedDimsArray([11 12; 21 22], (:x, :y))
ndv = NamedDimsArray([10, 20, 30], (:z,))

@test names([sqrt(x) for x in nda]) == (:x, :y)

@test names([x^i for (i,x) in enumerate(ndv)]) == (:z,)
@test names([x^i for (i,x) in enumerate(nda)]) == (:x, :y)

# Iterators.product -- has all names
@test names([x+y for x in nda, y in ndv]) == (:x, :y, :z)
@test names([x+y for x in nda, y in 1:5]) == (:x, :y, :_)
@test names([x+y for x in 1:5, y in ndv]) == (:_, :z)
four = [x*y/z^p for p in 1:2, x in ndv, y in 1:2, z in nda]
@test names(four) == (:_, :z, :_, :x, :y)

# Iterators.flatten -- no obvious name to use
@test names([x+y for x in nda for y in ndv]) == (:_,)

if VERSION >= v"1.1"
# can't see inside eachslice generators, until:
# https://github.com/JuliaLang/julia/pull/32310
@test names([sum(c) for c in eachcol(nda)]) == (:_,)
end
end

end # Base

@testset "Statistics" begin
Expand Down
18 changes: 18 additions & 0 deletions test/name_core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using NamedDims:
unify_names_shortest,
dim_noerror,
tuple_issubset,
tuple_cat,
order_named_inds,
permute_dimnames,
remaining_dimnames_from_indexing,
Expand Down Expand Up @@ -64,6 +65,11 @@ end
@test_throws DimensionMismatch unify((:a,:b), (:b, :a))
@test_throws DimensionMismatch unify((:a, :b, :c), (:_, :_, :d))
end

# vararg version
@test unify_names((:a, :_), (:a, :b,), (:_, :b)) == (:a, :b)
@test unify_names((:a, :b,)) == (:a, :b)
@test_throws DimensionMismatch unify_names((:a, :_), (:a, :b,), (:_, :c))
end
@testset "allocations: unify_names_*" begin
for unify in (unify_names, unify_names_longest, unify_names_shortest)
Expand All @@ -76,10 +82,13 @@ end
if VERSION >= v"1.1"
@test 0 == @allocated (()->unify_names_longest((:a, :b), (:a, :_, :c)))()
@test 0 == @allocated (()->unify_names_shortest((:a, :b), (:a, :_, :c)))()
@test 0 == @allocated (()->unify_names((:a, :b), (:a, :_), (:_, :b)))()
else
@test_broken 0 == @allocated (()->unify_names_longest((:a, :b), (:a, :_, :c)))()
@test_broken 0 == @allocated (()->unify_names_shortest((:a, :b), (:a, :_, :c)))()
@test_broken 0 == @allocated (()->unify_names((:a, :b), (:a, :_), (:_, :b)))()
end

end


Expand Down Expand Up @@ -156,3 +165,12 @@ end
@test 0 == @allocated tuple_issubset((:a, :c), (:a, :b, :c))
@test 0 == @allocated tuple_issubset((:a, :b, :c), (:a, :c))
end


@testset "tuple_cat" begin
@test tuple_cat((1, 2), (3, 4, 5), (6,)) == (1, 2, 3, 4, 5, 6)
@test tuple_cat((1, 2)) == (1, 2)
end
@testset "allocations: tuple_cat" begin
@test 0 == @allocated tuple_cat((1, 2), (3, 4, 5), (6,))
end

5 comments on commit 51893cf

@oxinabox
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@oxinabox
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@nickrobinson251
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/5450

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.11 -m "<description of version>" 51893cfd36a771061c9fc0cf490d04ad5c7d86a1
git push origin v0.2.11

@nickrobinson251
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

11 non-breaking updates is pretty solid, nice work @oxinabox and @mcabbott

Please sign in to comment.