-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
23 changed files
with
371 additions
and
123 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
module CovarianceEstimationExt | ||
|
||
using AxisKeys: KeyedArray, KeyedMatrix, NamedDims, NamedDimsArray, axiskeys, dimnames, keyless_unname, hasnames | ||
using CovarianceEstimation | ||
using CovarianceEstimation: AbstractWeights | ||
using CovarianceEstimation.Statistics | ||
|
||
# Since we get ambiguity errors with specific implementations we need to wrap each supported method | ||
# A better approach might be to add `NamedDims` support to CovarianceEstimators.jl in the future. | ||
|
||
estimators = [ | ||
:SimpleCovariance, | ||
:LinearShrinkage, | ||
:DiagonalUnitVariance, | ||
:DiagonalCommonVariance, | ||
:DiagonalUnequalVariance, | ||
:CommonCovariance, | ||
:PerfectPositiveCorrelation, | ||
:ConstantCorrelation, | ||
:AnalyticalNonlinearShrinkage, | ||
] | ||
for estimator in estimators | ||
@eval function Statistics.cov(ce::$estimator, A::KeyedMatrix, wv::Vararg{AbstractWeights}; dims=1, kwargs...) | ||
d = NamedDims.dim(A, dims) | ||
data = cov(ce, keyless_unname(A), wv...; dims=d, kwargs...) | ||
L1 = dimnames(A, 3 - d) | ||
data2 = hasnames(A) ? NamedDimsArray(data, (L1, L1)) : data | ||
K1 = axiskeys(A, 3 - d) | ||
KeyedArray(data2, (copy(K1), copy(K1))) | ||
end | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
module InvertedIndicesExt | ||
|
||
using AxisKeys | ||
using InvertedIndices | ||
|
||
# needs only Base.to_indices in struct.jl to work, | ||
# plus this to work when used in round brackets: | ||
AxisKeys.findindex(not::InvertedIndex, r::AbstractVector) = Base.unalias(r, not) | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
module OffsetArraysExt | ||
|
||
using AxisKeys | ||
using OffsetArrays | ||
|
||
AxisKeys.no_offset(x::OffsetArray) = parent(x) | ||
AxisKeys.shorttype(r::OffsetArray) = "OffsetArray(::" * shorttype(parent(r)) * ",...)" | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
module StatisticsExt | ||
|
||
using AxisKeys: KeyedArray, KeyedMatrix, NamedDims, axiskeys | ||
using Statistics | ||
|
||
for fun in [:mean, :std, :var] # These don't use mapreduce, but could perhaps be handled better? | ||
@eval function Statistics.$fun(A::KeyedArray; dims=:, kwargs...) | ||
dims === Colon() && return $fun(parent(A); kwargs...) | ||
numerical_dims = NamedDims.dim(A, dims) | ||
data = $fun(parent(A); dims=numerical_dims, kwargs...) | ||
new_keys = ntuple(d -> d in numerical_dims ? Base.OneTo(1) : axiskeys(A,d), ndims(A)) | ||
return KeyedArray(data, map(copy, new_keys))#, copy(A.meta)) | ||
end | ||
end | ||
|
||
# Handle function interface for `mean` only | ||
if VERSION >= v"1.3" | ||
@eval function Statistics.mean(f, A::KeyedArray; dims=:, kwargs...) | ||
dims === Colon() && return mean(f, parent(A); kwargs...) | ||
numerical_dims = NamedDims.dim(A, dims) | ||
data = mean(f, parent(A); dims=numerical_dims, kwargs...) | ||
new_keys = ntuple(d -> d in numerical_dims ? Base.OneTo(1) : axiskeys(A,d), ndims(A)) | ||
return KeyedArray(data, map(copy, new_keys))#, copy(A.meta)) | ||
end | ||
end | ||
|
||
for fun in [:cov, :cor] # Returned the axes work are different for cov and cor | ||
@eval function Statistics.$fun(A::KeyedMatrix; dims=1, kwargs...) | ||
numerical_dim = NamedDims.dim(A, dims) | ||
data = $fun(parent(A); dims=numerical_dim, kwargs...) | ||
# Use same remaining axis for both dimensions of data | ||
rem_key = axiskeys(A, 3-numerical_dim) | ||
KeyedArray(data, (copy(rem_key), copy(rem_key))) | ||
end | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.