diff --git a/src/TypeSortedCollections.jl b/src/TypeSortedCollections.jl index b0c1947..ec22f24 100644 --- a/src/TypeSortedCollections.jl +++ b/src/TypeSortedCollections.jl @@ -228,36 +228,43 @@ end end ## broadcast! +@generated function _broadcast!(f, dest, A, Bs...) + T = first_tsc_type(A, Bs...) + N = num_types(T) + expr = Expr(:block) + push!(expr.args, :(Base.@_inline_meta)) # TODO: good idea? + push!(expr.args, :(leading_tsc = first_tsc(A, Bs...))) + push!(expr.args, :(@boundscheck lengths_match(dest, A, Bs...) || lengths_match_fail())) + for i = 1 : N + vali = Val(i) + push!(expr.args, quote + let inds = leading_tsc.indices[$i] + @boundscheck indices_match($vali, inds, A, Bs...) || indices_match_fail() + @inbounds for j in linearindices(inds) + vecindex = inds[j] + _setindex!($vali, j, vecindex, dest, f(_getindex_all($vali, j, vecindex, A, Bs...)...)) + end + end + end) + end + quote + Base.@_inline_meta + $expr + dest + end +end + @static if VERSION >= v"0.7.0-DEV.3181" struct TypeSortedStyle <: Broadcast.BroadcastStyle end Base.BroadcastStyle(::Type{<:TypeSortedCollection}) = TypeSortedStyle() Base.BroadcastStyle(::Broadcast.AbstractArrayStyle{1}, ::TypeSortedStyle) = TypeSortedStyle() Base.BroadcastStyle(::Broadcast.Scalar, ::TypeSortedStyle) = TypeSortedStyle() - @generated function Base.broadcast!(f, dest, ::TypeSortedStyle, A, Bs...) - T = first_tsc_type(A, Bs...) - N = num_types(T) - expr = Expr(:block) - push!(expr.args, :(Base.@_inline_meta)) # TODO: good idea? - push!(expr.args, :(leading_tsc = first_tsc(A, Bs...))) - push!(expr.args, :(@boundscheck lengths_match(dest, A, Bs...) || lengths_match_fail())) - for i = 1 : N - vali = Val(i) - push!(expr.args, quote - let inds = leading_tsc.indices[$i] - @boundscheck indices_match($vali, inds, A, Bs...) || indices_match_fail() - @inbounds for j in linearindices(inds) - vecindex = inds[j] - _setindex!($vali, j, vecindex, dest, f(_getindex_all($vali, j, vecindex, A, Bs...)...)) - end - end - end) - end - quote - $expr - dest - end + @inline function Base.broadcast!(f, dest, ::TypeSortedStyle, A, Bs...) + _broadcast!(f, dest, A, Bs...) end + + @inline Base.broadcast!(f, A::TypeSortedCollection, ::Nothing, Bs...) = _broadcast!(f, A, Bs...) else Base.Broadcast._containertype(::Type{<:TypeSortedCollection}) = TypeSortedCollection Base.Broadcast.promote_containertype(::Type{TypeSortedCollection}, _) = TypeSortedCollection @@ -266,30 +273,12 @@ else Base.Broadcast.promote_containertype(::Type{TypeSortedCollection}, ::Type{Array}) = TypeSortedCollection # handle ambiguities with `Array` Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TypeSortedCollection}) = TypeSortedCollection # handle ambiguities with `Array` - @generated function Base.Broadcast.broadcast_c!(f, ::Type, ::Type{TypeSortedCollection}, dest::AbstractVector, A, Bs...) - T = first_tsc_type(A, Bs...) - N = num_types(T) - expr = Expr(:block) - push!(expr.args, :(Base.@_inline_meta)) # TODO: good idea? - push!(expr.args, :(leading_tsc = first_tsc(A, Bs...))) - push!(expr.args, :(@boundscheck lengths_match(dest, A, Bs...) || lengths_match_fail())) - for i = 1 : N - vali = Val(i) - push!(expr.args, quote - let inds = leading_tsc.indices[$i] - @boundscheck indices_match($vali, inds, A, Bs...) || indices_match_fail() - @inbounds for j in linearindices(inds) - vecindex = inds[j] - _setindex!($vali, j, vecindex, dest, f(_getindex_all($vali, j, vecindex, A, Bs...)...)) - end - end - end) - end - quote - $expr - dest - end + @inline function Base.Broadcast.broadcast_c!(f, ::Type, ::Type{TypeSortedCollection}, dest::AbstractVector, A, Bs...) + _broadcast!(f, dest, A, Bs...) end + + @inline Base.broadcast!(f, A::TypeSortedCollection, Bs...) = _broadcast!(f, A, Bs...) end + end # module diff --git a/test/runtests.jl b/test/runtests.jl index 9ec6e60..6ec83af 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -258,6 +258,16 @@ end @test (@allocated broadcast!(M.g, results, sortedx, y1, sortedy2)) == 0 end +@testset "broadcast! TSC destination" begin + x = Number[3.; 4; 5] + sortedx = TypeSortedCollection(x) + results = typeof(sortedx)(indices(sortedx)) + results .= 3 .* sortedx + @test results.data[1][1] === 3 * 3. + @test results.data[2][1] === 3 * 4 + @test results.data[2][2] === 3 * 5 +end + @testset "eltype" begin x = [4.; 5; 3.; Int32(2); Int16(1); "foo"] let sortedx = TypeSortedCollection(x)