Skip to content

Commit

Permalink
Save optimal projected frequencies for fixed model
Browse files Browse the repository at this point in the history
Saves optimal projected frequencies per timepoint and strain from the
optimal earth mover's distance calculation between each timepoint into
the same scores data structure that stores the "projected_frequency"
values.
  • Loading branch information
huddlej committed Sep 4, 2024
1 parent 1406e14 commit c58dbdc
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/popcast/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,11 @@ def _fit(self, coefficients, X, y, use_l1_penalty=True, calculate_optimal_distan
# Estimate target values.
y_hat = self.predict(X, coefficients)

# Save optimal frequencies by timepoint and strain, if calculating
# optimal distance to the future.
if calculate_optimal_distance:
optimal_frequency_records = []

# Calculate EMD for each timepoint in the estimated values and sum that
# distance across all timepoints.
error = 0.0
Expand Down Expand Up @@ -471,6 +476,12 @@ def _fit(self, coefficients, X, y, use_l1_penalty=True, calculate_optimal_distan
cost=distance_matrix
)

optimal_frequency_records.append(pd.DataFrame({
"strain": samples_a,
"timepoint": timepoint,
"optimal_projected_frequency": estimated_frequencies,
}))

# Estimate the distance between the model's estimated future and the
# observed future populations.
model_emd, _, self.model_flow = cv2.EMD(
Expand All @@ -490,6 +501,9 @@ def _fit(self, coefficients, X, y, use_l1_penalty=True, calculate_optimal_distan
else:
l1_penalty = 0.0

if calculate_optimal_distance:
self.optimal_frequencies = pd.concat(optimal_frequency_records)

return error + l1_penalty

def _fit_distance(self, coefficients, X, y, use_l1_penalty=True):
Expand Down Expand Up @@ -861,6 +875,11 @@ def test(model_class, model_kwargs, data, targets, timepoints, coefficients=None

# Get the estimated frequencies for test sets to export.
test_y_hat = model.predict(test_X)
test_y_hat = test_y_hat.merge(
model.optimal_frequencies,
on=["timepoint", "strain"],
validate="1:1",
)

# Convert timestamps to a serializable format.
for df in [test_X, test_y, test_y_hat]:
Expand Down

0 comments on commit c58dbdc

Please sign in to comment.