diff --git a/Project.toml b/Project.toml index 10a8695..7b74ebb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "AxisKeys" uuid = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" license = "MIT" -version = "0.2.7" +version = "0.2.8" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" @@ -23,6 +23,7 @@ BenchmarkTools = "0.5, 1.0" ChainRulesCore = "1" ChainRulesTestUtils = "1" CovarianceEstimation = "0.2" +DataFrames = "1" FiniteDifferences = "0.12" IntervalSets = "0.5.1, 0.6, 0.7" InvertedIndices = "1.0" @@ -36,6 +37,7 @@ julia = "1.6" [extras] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" @@ -45,4 +47,4 @@ UniqueVectors = "2fbcfb34-fd0c-5fbb-b5d7-e826d8f5b0a9" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["BenchmarkTools", "ChainRulesTestUtils", "Dates", "FiniteDifferences", "FFTW", "NamedArrays", "Test", "UniqueVectors", "Unitful"] +test = ["BenchmarkTools", "ChainRulesTestUtils", "DataFrames", "Dates", "FiniteDifferences", "FFTW", "NamedArrays", "Test", "UniqueVectors", "Unitful"] diff --git a/src/tables.jl b/src/tables.jl index 6e60d84..526d833 100644 --- a/src/tables.jl +++ b/src/tables.jl @@ -138,22 +138,30 @@ function populate!(A, table, value::Symbol; force=false) # Use a BitArray mask to detect duplicates and error instead of overwriting. mask = force ? falses() : falses(size(A)) - for r in Tables.rows(table) - vals = Tuple(Tables.getcolumn(r, c) for c in dimnames(A)) - inds = map(findindex, vals, axiskeys(A)) + cols = Tables.columns(table) + value_column = Tables.getcolumn(cols, value) + axis_key_columns = Tuple(Tables.getcolumn(cols, c) for c in dimnames(A)) + return populate_function_barrier!(A, value_column, axis_key_columns, mask, force) +end + +# eltypes of value and axis_key_columns aren't inferable in `populate!` if the `table` +# doesn't have typed columns, as is the case for DataFrames. By passing them into +# `populate_function_barrier!` once they've been pulled out of a DataFrame ensures +# inference is possible for the loop. +function populate_function_barrier!(A, value_column, axis_key_columns, mask, force) + for (val, keys...) in zip(value_column, axis_key_columns...) + inds = map(AxisKeys.findindex, keys, axiskeys(A)) # Handle duplicate error checking if applicable if !force # Error if mask already set. - mask[inds...] && throw(ArgumentError("Key $vals is not unique")) + mask[inds...] && throw(ArgumentError("Key $keys is not unique")) # Set mask, marking that we've set this index setindex!(mask, true, inds...) end - # Insert our value into the data array - setindex!(A, Tables.getcolumn(r, value), inds...) + setindex!(A, val, inds...) end - return A end diff --git a/test/_packages.jl b/test/_packages.jl index 1f32f9e..50facda 100644 --- a/test/_packages.jl +++ b/test/_packages.jl @@ -1,5 +1,10 @@ using Test, AxisKeys +function count_allocs(f, args...) + stats = @timed f(args...) + return Base.gc_alloc_count(stats.gcstats) +end + @testset "offset" begin using OffsetArrays @@ -38,6 +43,14 @@ end @test dimnames(k) == (:aa,) end end +@testset "DataFrames" begin + using DataFrames: DataFrame + + X = KeyedArray(randn(1000, 1500), a=1:1000, b=1:1500) + df = DataFrame(X) + wrapdims(df, :value, :a, :b) # compile + @test count_allocs(wrapdims, df, :value, :a, :b) < 1_000 +end @testset "tables" begin using Tables