Skip to content

Commit

Permalink
Merge pull request #12 from greglucas/pred_error
Browse files Browse the repository at this point in the history
Fix calculation of SECS error variance
  • Loading branch information
greglucas authored May 12, 2021
2 parents 383685b + 58bdd92 commit 73ec40c
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 33 deletions.
10 changes: 5 additions & 5 deletions examples/plot_animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@
B_obs[:, :, 0] *= 2*np.sin(np.deg2rad(obs_lat_lon_r[:, 0]))
B_obs[:, :, 1] *= 2*np.sin(np.deg2rad(obs_lat_lon_r[:, 1]))

B_var = np.ones(B_obs.shape)
B_std = np.ones(B_obs.shape)
# Ignore the Z component
B_var[..., 2] = np.inf
# Can modify the variance as a function of time to
B_std[..., 2] = np.inf
# Can modify the standard error as a function of time to
# see how that changes the fits too
# B_var[:, 0, 1] = 1 + ts
# B_std[:, 0, 1] = 1 + ts

# Fit the data, requires observation locations and data
secs.fit(obs_loc=obs_lat_lon_r, obs_B=B_obs, obs_var=B_var)
secs.fit(obs_loc=obs_lat_lon_r, obs_B=B_obs, obs_std=B_std)

# Create prediction points
# Extend it a little beyond the observation points (-11, 11)
Expand Down
44 changes: 23 additions & 21 deletions pysecs/secs.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def nsec(self):
nsec += len(self.sec_cf_loc)
return nsec

def fit(self, obs_loc, obs_B, obs_var=None, epsilon=0.05):
def fit(self, obs_loc, obs_B, obs_std=None, epsilon=0.05):
"""Fits the SECS to the given observations.
Given a number of observation locations and measurements,
Expand All @@ -88,10 +88,10 @@ def fit(self, obs_loc, obs_B, obs_var=None, epsilon=0.05):
obs_B: ndarray (ntimes x nobs x 3 [Bx, By, Bz])
An array containing the measured/observed B-fields.
obs_var : ndarray (ntimes x nobs x 3 [varX, varY, varZ]), optional
Variances in the components at each observation location. Can be used to
weight different observation locations more/less heavily. Infinite variance
effectively eliminates the observation from the fit.
obs_std : ndarray (ntimes x nobs x 3 [varX, varY, varZ]), optional
Standard error of vector components at each observation location.
This can be used to weight different observations more/less heavily.
An infinite value eliminates the observation from the fit.
Default: ones(nobs x 3) equal weights
epsilon : float
Expand All @@ -106,9 +106,9 @@ def fit(self, obs_loc, obs_B, obs_var=None, epsilon=0.05):
# Just a single snapshot given, so expand the dimensionality
obs_B = obs_B[np.newaxis, ...]

# Assume unit variance of all measurements
if obs_var is None:
obs_var = np.ones(obs_B.shape)
# Assume unit standard error of all measurements
if obs_std is None:
obs_std = np.ones(obs_B.shape)

ntimes = len(obs_B)

Expand All @@ -121,39 +121,41 @@ def fit(self, obs_loc, obs_B, obs_var=None, epsilon=0.05):

# Calculate the singular value decomposition (SVD)
# NOTE: T_obs has shape (nobs, 3, nsec), we reshape it
# to (nobs*3, nsec); obs_var has shape (ntimes, nobs, 3),
# to (nobs*3, nsec); obs_std has shape (ntimes, nobs, 3),
# we reshape it to (ntimes, nobs*3), then loop over ntimes
# to solve using (potentially) time-dependent observation
# error variances to weight the observations
# standard errors to weight the observations
for i in range(ntimes):

# Only (re-)calculate SVD when necessary
if i == 0 or not np.all(obs_var[i] == obs_var[i-1]):
if i == 0 or not np.all(obs_std[i] == obs_std[i-1]):

# Weight T_obs with obs_var
# Weight T_obs with obs_std
svd_in = (T_obs.reshape(-1, self.nsec) /
obs_var[i].ravel()[:, np.newaxis])
obs_std[i].ravel()[:, np.newaxis])

# Find singular value decompostion
U, S, Vh = np.linalg.svd(svd_in, full_matrices=False)

# Divide by infinity (1/S) gives zero weights
# Eliminate singular values less than epsilon by setting their
# reciprocal to zero (setting S to infinity firsts avoids
# divide-by-zero warings)
S[S < epsilon * S.max()] = np.inf
W = 1./S

# Eliminate the small singular values (less than epsilon)
# by giving them zero weight
W[S < epsilon*S.max()] = 0.

# Update VWU if obs_var changed
# Update VWU if obs_std changed
VWU = Vh.T @ (np.diag(W) @ U.T)

# Solve for SEC amplitudes and error variances
# shape: ntimes x nsec
self.sec_amps[i, :] = (VWU @ (obs_B[i]/obs_var[i]).reshape(-1).T).T
self.sec_amps[i, :] = (VWU @ (obs_B[i] / obs_std[i]).reshape(-1).T).T

# Maybe we want the variance of the predictions sometime later...?
# shape: ntimes x nsec
self.sec_amps_var[i, :] = np.sum((Vh.T * W)**2, axis=1)
valid = np.isfinite(obs_std[i].reshape(-1))
self.sec_amps_var[i, :] = np.sum(
(VWU[:,valid] * obs_std[i].reshape(-1)[valid])**2,
axis=1)

return self

Expand Down
14 changes: 7 additions & 7 deletions tests/test_secs.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,17 +330,17 @@ def test_fit_multi_time():
assert_allclose(expected, secs.sec_amps)


def test_fit_obs_var():
def test_fit_obs_std():
"""Test that variance on observations changes the results."""
secs = pysecs.SECS(sec_df_loc=[[1., 0., R_EARTH + 1e6],
[-1., 0., R_EARTH + 1e6]])
obs_loc = np.array([[0, 0, R_EARTH]])
obs_B = np.ones((2, 1, 3))
obs_B[1, :, :] *= 2
obs_var = np.ones(obs_B.shape)
obs_std = np.ones(obs_B.shape)
# Remove the z component from the fit of the second timestep
obs_var[1, :, 2] = np.inf
secs.fit(obs_loc, obs_B, obs_var=obs_var)
obs_std[1, :, 2] = np.inf
secs.fit(obs_loc, obs_B, obs_std=obs_std)
expected = np.array([[6.40594202e+13, -7.41421248e+13],
[1.382015e+14, -1.382015e+14]])
assert_allclose(expected, secs.sec_amps, rtol=1e-6)
Expand All @@ -353,10 +353,10 @@ def test_fit_epsilon():
obs_loc = np.array([[0, 0, R_EARTH]])
obs_B = np.ones((2, 1, 3))
obs_B[1, :, :] *= 2
obs_var = np.ones(obs_B.shape)
obs_std = np.ones(obs_B.shape)
# Remove the z component from the fit of the second timestep
obs_var[1, :, 2] = np.inf
secs.fit(obs_loc, obs_B, obs_var=obs_var, epsilon=0.8)
obs_std[1, :, 2] = np.inf
secs.fit(obs_loc, obs_B, obs_std=obs_std, epsilon=0.8)
expected = np.array([[-5.041352e+12, -5.041352e+12],
[1.382015e+14, -1.382015e+14]])
assert_allclose(expected, secs.sec_amps, rtol=1e-6)
Expand Down

0 comments on commit 73ec40c

Please sign in to comment.