Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
aplavin authored May 10, 2024
2 parents d8d864c + c16d0fc commit a5bdf44
Show file tree
Hide file tree
Showing 23 changed files with 371 additions and 123 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ jobs:
matrix:
version:
- '1.6'
- '1.9'
- '1' # Leave this line unchanged. '1' will automatically expand to the latest stable 1.x release of Julia.
- 'nightly'
os:
Expand Down
39 changes: 34 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "AxisKeys"
uuid = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5"
license = "MIT"
version = "0.2.3"
version = "0.2.14"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand All @@ -17,32 +17,61 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[weakdeps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CovarianceEstimation = "587fd27a-f159-11e8-2dae-1979310e6154"
InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
LazyStack = "1fad7336-0346-5a1a-a56f-a06ba010965b"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[extensions]
AbstractFFTsExt = "AbstractFFTs"
ChainRulesCoreExt = "ChainRulesCore"
CovarianceEstimationExt = "CovarianceEstimation"
InvertedIndicesExt = "InvertedIndices"
LazyStackExt = "LazyStack"
OffsetArraysExt = "OffsetArrays"
StatisticsExt = "Statistics"
StatsBaseExt = "StatsBase"

[compat]
AbstractFFTs = "0.5, 1.0"
BenchmarkTools = "0.5, 1.0"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
CovarianceEstimation = "0.2"
DataFrames = "1"
FiniteDifferences = "0.12"
IntervalSets = "0.5.1, 0.6"
IntervalSets = "0.5.1, 0.6, 0.7"
InvertedIndices = "1.0"
LazyStack = "0.0.7, 0.0.8"
NamedDims = "0.2.46"
NamedDims = "0.2.46, 0.3, 1"
OffsetArrays = "0.10, 0.11, 1.0"
StatsBase = "0.32, 0.33"
StatsBase = "0.32, 0.33, 0.34"
Tables = "0.2, 1"
julia = "1.6"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
CovarianceEstimation = "587fd27a-f159-11e8-2dae-1979310e6154"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
LazyStack = "1fad7336-0346-5a1a-a56f-a06ba010965b"
NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UniqueVectors = "2fbcfb34-fd0c-5fbb-b5d7-e826d8f5b0a9"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["BenchmarkTools", "ChainRulesTestUtils", "Dates", "FiniteDifferences", "FFTW", "NamedArrays", "Test", "UniqueVectors", "Unitful"]
test = ["BenchmarkTools", "CovarianceEstimation", "ChainRulesCore", "ChainRulesTestUtils", "DataFrames", "Dates", "FiniteDifferences", "FFTW", "InvertedIndices", "LazyStack", "NamedArrays", "OffsetArrays", "Test", "Statistics", "StatsBase", "UniqueVectors", "Unitful"]
8 changes: 6 additions & 2 deletions src/fft.jl → ext/AbstractFFTsExt.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
module AbstractFFTsExt

using AxisKeys: KeyedArray, NdaKa, axiskeys, keyless, NamedDims
using AbstractFFTs

#=
Simple support for FFTs using:
Expand All @@ -7,8 +11,6 @@ Does not (yet) cover plan_fft & friends,
because extracting the dimensions from those is tricky
=#

using AbstractFFTs

for fun in [:fft, :ifft, :bfft, :rfft]
@eval function AbstractFFTs.$fun(A::Union{KeyedArray,NdaKa}, dims = ntuple(+,ndims(A)))
numerical_dims = NamedDims.dim(A, dims)
Expand Down Expand Up @@ -80,3 +82,5 @@ function irfft_un_freq(x, len)
s = inv(step(x) * len)
range(zero(s), step = s, length = len)
end

end
5 changes: 5 additions & 0 deletions src/chainrules.jl → ext/ChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
module ChainRulesCoreExt

using AxisKeys: KeyedArray, KaNda, NdaKa, keyless, keyless_unname, axiskeys, named_axiskeys, wrapdims
using ChainRulesCore

function ChainRulesCore.ProjectTo(x::Union{KaNda, NdaKa})
Expand All @@ -19,3 +22,5 @@ function ChainRulesCore.rrule(::typeof(keyless_unname), x)
pb(y) = _KeyedArray_pullback(y, ProjectTo(x))
return keyless_unname(x), pb
end

end
33 changes: 33 additions & 0 deletions ext/CovarianceEstimationExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
module CovarianceEstimationExt

using AxisKeys: KeyedArray, KeyedMatrix, NamedDims, NamedDimsArray, axiskeys, dimnames, keyless_unname, hasnames
using CovarianceEstimation
using CovarianceEstimation: AbstractWeights
using CovarianceEstimation.Statistics

# Since we get ambiguity errors with specific implementations we need to wrap each supported method
# A better approach might be to add `NamedDims` support to CovarianceEstimators.jl in the future.

estimators = [
:SimpleCovariance,
:LinearShrinkage,
:DiagonalUnitVariance,
:DiagonalCommonVariance,
:DiagonalUnequalVariance,
:CommonCovariance,
:PerfectPositiveCorrelation,
:ConstantCorrelation,
:AnalyticalNonlinearShrinkage,
]
for estimator in estimators
@eval function Statistics.cov(ce::$estimator, A::KeyedMatrix, wv::Vararg{AbstractWeights}; dims=1, kwargs...)
d = NamedDims.dim(A, dims)
data = cov(ce, keyless_unname(A), wv...; dims=d, kwargs...)
L1 = dimnames(A, 3 - d)
data2 = hasnames(A) ? NamedDimsArray(data, (L1, L1)) : data
K1 = axiskeys(A, 3 - d)
KeyedArray(data2, (copy(K1), copy(K1)))
end
end

end
10 changes: 10 additions & 0 deletions ext/InvertedIndicesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module InvertedIndicesExt

using AxisKeys
using InvertedIndices

# needs only Base.to_indices in struct.jl to work,
# plus this to work when used in round brackets:
AxisKeys.findindex(not::InvertedIndex, r::AbstractVector) = Base.unalias(r, not)

end
16 changes: 10 additions & 6 deletions src/stack.jl → ext/LazyStackExt.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module LazyStackExt

using LazyStack
using AxisKeys: KeyedArray, NamedDims, NamedDimsArray, axiskeys, hasnames, dimnames, keys_or_axes
import LazyStack

# for stack_iter
LazyStack.no_wraps(a::KeyedArray) = LazyStack.no_wraps(NamedDims.unname(parent(a)))
Expand All @@ -14,23 +16,23 @@ function LazyStack.rewrap_like(A, a::NamedTuple)
end

# tuple of arrays
function LazyStack.stack(x::Tuple{Vararg{<:KeyedArray}})
function LazyStack.stack(x::Tuple{Vararg{KeyedArray}})
KeyedArray(LazyStack.stack(map(parent, x)), stack_keys(x))
end

stack_keys(xs::Tuple{Vararg{<:KeyedArray}}) =
stack_keys(xs::Tuple{Vararg{KeyedArray}}) =
(keys_or_axes(first(xs))..., Base.OneTo(length(xs)))

# array of arrays: first strip off outer containers...
function LazyStack.stack(xs::KeyedArray{<:AbstractArray})
KeyedArray(stack(parent(xs)), stack_keys(xs))
KeyedArray(LazyStack.stack(parent(xs)), stack_keys(xs))
end
function LazyStack.stack(xs::KeyedArray{<:AbstractArray,N,<:NamedDimsArray{L}}) where {L,N}
data = stack(parent(parent(xs)))
data = LazyStack.stack(parent(parent(xs)))
KeyedArray(LazyStack.ensure_named(data, LazyStack.getnames(xs)), stack_keys(xs))
end
function LazyStack.stack(xs::NamedDimsArray{L,<:AbstractArray,N,<:KeyedArray}) where {L,N}
data = stack(parent(parent(xs)))
data = LazyStack.stack(parent(parent(xs)))
LazyStack.ensure_named(KeyedArray(data, stack_keys(xs)), LazyStack.getnames(xs))
end

Expand All @@ -57,3 +59,5 @@ function LazyStack.getnames(xs::AbstractArray{<:KeyedArray{T,N,IT}}) where {T,N,
out_names = hasnames(xs) ? dimnames(xs) : NamedDims.dimnames(xs)
(NamedDims.dimnames(IT)..., out_names...)
end

end
9 changes: 9 additions & 0 deletions ext/OffsetArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module OffsetArraysExt

using AxisKeys
using OffsetArrays

AxisKeys.no_offset(x::OffsetArray) = parent(x)
AxisKeys.shorttype(r::OffsetArray) = "OffsetArray(::" * shorttype(parent(r)) * ",...)"

end
37 changes: 37 additions & 0 deletions ext/StatisticsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
module StatisticsExt

using AxisKeys: KeyedArray, KeyedMatrix, NamedDims, axiskeys
using Statistics

for fun in [:mean, :std, :var] # These don't use mapreduce, but could perhaps be handled better?
@eval function Statistics.$fun(A::KeyedArray; dims=:, kwargs...)
dims === Colon() && return $fun(parent(A); kwargs...)
numerical_dims = NamedDims.dim(A, dims)
data = $fun(parent(A); dims=numerical_dims, kwargs...)
new_keys = ntuple(d -> d in numerical_dims ? Base.OneTo(1) : axiskeys(A,d), ndims(A))
return KeyedArray(data, map(copy, new_keys))#, copy(A.meta))
end
end

# Handle function interface for `mean` only
if VERSION >= v"1.3"
@eval function Statistics.mean(f, A::KeyedArray; dims=:, kwargs...)
dims === Colon() && return mean(f, parent(A); kwargs...)
numerical_dims = NamedDims.dim(A, dims)
data = mean(f, parent(A); dims=numerical_dims, kwargs...)
new_keys = ntuple(d -> d in numerical_dims ? Base.OneTo(1) : axiskeys(A,d), ndims(A))
return KeyedArray(data, map(copy, new_keys))#, copy(A.meta))
end
end

for fun in [:cov, :cor] # Returned the axes work are different for cov and cor
@eval function Statistics.$fun(A::KeyedMatrix; dims=1, kwargs...)
numerical_dim = NamedDims.dim(A, dims)
data = $fun(parent(A); dims=numerical_dim, kwargs...)
# Use same remaining axis for both dimensions of data
rem_key = axiskeys(A, 3-numerical_dim)
KeyedArray(data, (copy(rem_key), copy(rem_key)))
end
end

end
29 changes: 4 additions & 25 deletions src/statsbase.jl → ext/StatsBaseExt.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
module StatsBaseExt

using AxisKeys: KeyedArray, KeyedMatrix, NamedDims, NamedDimsArray, axiskeys, dimnames, keyless_unname, hasnames
using StatsBase
using StatsBase.Statistics

# Support some of the weighted statistics function in StatsBase
# NOTES:
Expand Down Expand Up @@ -48,35 +51,11 @@ end

for fun in (:std, :var, :cov)
full_name = Symbol("mean_and_$fun")
@eval StatsBase.$full_name(A::KeyedMatrix, wv::Vararg{<:AbstractWeights}; dims=:, corrected::Bool=true, kwargs...) =
@eval StatsBase.$full_name(A::KeyedMatrix, wv::Vararg{AbstractWeights}; dims=:, corrected::Bool=true, kwargs...) =
(
mean(A, wv...; dims=dims, kwargs...),
$fun(A, wv...; dims=dims, corrected=corrected, kwargs...)
)
end

# Since we get ambiguity errors with specific implementations we need to wrap each supported method
# A better approach might be to add `NamedDims` support to CovarianceEstimators.jl in the future.
using CovarianceEstimation

estimators = [
:SimpleCovariance,
:LinearShrinkage,
:DiagonalUnitVariance,
:DiagonalCommonVariance,
:DiagonalUnequalVariance,
:CommonCovariance,
:PerfectPositiveCorrelation,
:ConstantCorrelation,
:AnalyticalNonlinearShrinkage,
]
for estimator in estimators
@eval function Statistics.cov(ce::$estimator, A::KeyedMatrix, wv::Vararg{<:AbstractWeights}; dims=1, kwargs...)
d = NamedDims.dim(A, dims)
data = cov(ce, keyless_unname(A), wv...; dims=d, kwargs...)
L1 = dimnames(A, 3 - d)
data2 = hasnames(A) ? NamedDimsArray(data, (L1, L1)) : data
K1 = axiskeys(A, 3 - d)
KeyedArray(data2, (copy(K1), copy(K1)))
end
end
18 changes: 11 additions & 7 deletions src/AxisKeys.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ export KeyedArray, axiskeys
include("lookup.jl")

include("names.jl")
export named_axiskeys
export named_axiskeys, rekey
export NamedDimsArray, dimnames, rename # Reexport key NamedDimsArrays things

include("wrap.jl")
Expand All @@ -26,11 +26,15 @@ include("show.jl")

include("tables.jl") # Tables.jl

include("stack.jl") # LazyStack.jl

include("fft.jl") # AbstractFFTs.jl

include("statsbase.jl") # StatsBase.jl
if !isdefined(Base, :get_extension)
include("../ext/AbstractFFTsExt.jl")
include("../ext/ChainRulesCoreExt.jl")
include("../ext/CovarianceEstimationExt.jl")
include("../ext/InvertedIndicesExt.jl")
include("../ext/LazyStackExt.jl")
include("../ext/OffsetArraysExt.jl")
include("../ext/StatisticsExt.jl")
include("../ext/StatsBaseExt.jl")
end

include("chainrules.jl")
end
17 changes: 14 additions & 3 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,28 @@ Base.BroadcastStyle(::Type{<:KeyedArray{T,N,AT}}) where {T,N,AT} =
Base.BroadcastStyle(::KeyedStyle{A}, ::KeyedStyle{B}) where {A, B} = KeyedStyle(A(), B())
Base.BroadcastStyle(::KeyedStyle{A}, b::B) where {A, B} = KeyedStyle(A(), b)
Base.BroadcastStyle(a::A, ::KeyedStyle{B}) where {A, B} = KeyedStyle(a, B())
Base.BroadcastStyle(::KeyedStyle{A}, b::DefaultArrayStyle) where {A} = KeyedStyle(A(), b)
Base.BroadcastStyle(a::AbstractArrayStyle{M}, ::KeyedStyle{B}) where {B,M} = KeyedStyle(a, B())

using NamedDims: NamedDimsStyle
# this resolves in favour of KeyedArray(NamedDimsArray())
Base.BroadcastStyle(a::NamedDimsStyle, ::KeyedStyle{B}) where {B} = KeyedStyle(a, B())
Base.BroadcastStyle(::KeyedStyle{A}, b::NamedDimsStyle) where {A} = KeyedStyle(A(), b)

# Resolve ambiguities
# for all these cases, we define that we win to be the outer style regardless of order
for B in (
:BroadcastStyle, :DefaultArrayStyle, :AbstractArrayStyle, :(Broadcast.Style{Tuple}),
)
@eval function Base.BroadcastStyle(::KeyedStyle{A}, b::$B) where A
return KeyedStyle(A(), b)
end
@eval function Base.BroadcastStyle(b::$B, ::KeyedStyle{A}) where A
return KeyedStyle(b, A())
end
end

function unwrap_broadcasted(bc::Broadcasted{KeyedStyle{S}}) where {S}
inner_args = map(unwrap_broadcasted, bc.args)
Broadcasted{S}(bc.f, inner_args)
Broadcasted{S}(bc.f, inner_args, axes(bc))
end
unwrap_broadcasted(x) = x
unwrap_broadcasted(x::KeyedArray) = parent(x)
Expand Down
Loading

0 comments on commit a5bdf44

Please sign in to comment.