Skip to content

Commit

Permalink
fix merge
Browse files Browse the repository at this point in the history
  • Loading branch information
0xfdf committed Aug 13, 2024
1 parent e842e4c commit 3ecb480
Showing 1 changed file with 0 additions and 53 deletions.
53 changes: 0 additions & 53 deletions toraniko/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,56 +238,3 @@ def ledoit_wolf_shrinkage(X: np.ndarray) -> tuple[float | int, np.ndarray]:
shrunk_cov = (1 - shrinkage) * sample_cov + shrinkage * mu * np.eye(m)

return shrinkage, shrunk_cov


# TODO: test
def stfu_shrinkage(X: np.ndarray) -> tuple[tuple[float | int, ...], np.ndarray]:
"""Estimate the covariance matrix of `X` via Specializing the Target to Features and Unlabeled shrinkage (SFTU).
Parameters
----------
X : array-like input data matrix for which to estimate covariance, having shape (n_samples, m_features)
Returns
-------
shrinkage: float estimated shrinkage parameter.
shrunk_cov: numpy ndarray estimated shrunk covariance matrix having shape (n_features, n_features)
"""
n, m = X.shape

# Center the data
X = X - X.mean(axis=0)

# Sample covariance
sample_cov = np.dot(X.T, X) / n

# Diagonal of sample covariance
sample_var = np.diag(sample_cov)

# Mean of sample variances
mean_var = np.mean(sample_var)

# Estimate of tr(sigma^2)
tr_sigma2 = np.sum(sample_cov**2)

# Estimate of tr(sigma^4)
tr_sigma4 = np.sum((X.T @ X / n) ** 2)

# Estimate lambda_star (shrinkage towards diagonal matrix)
lambda_star = (tr_sigma2 - np.sum(sample_var**2)) / (
(n - 1) * (tr_sigma2 - 2 * np.sum(sample_var**2) + m * mean_var**2)
)
lambda_star = max(0, min(1, lambda_star))

# Estimate mu_star (shrinkage towards scalar matrix)
mu_star = (tr_sigma2 - m * mean_var**2) / ((n - 1) * (tr_sigma4 - tr_sigma2**2 / m))
mu_star = max(0, min(1, mu_star))

# Compute shrunk covariance matrix
shrunk_cov = (
(1 - lambda_star) * sample_cov
+ lambda_star * np.diag(sample_var)
+ lambda_star * mu_star * (mean_var - np.diag(sample_var))
)

return (lambda_star, mu_star), shrunk_cov

0 comments on commit 3ecb480

Please sign in to comment.