diff --git a/src/popcast/fit.py b/src/popcast/fit.py index fee9c0e..b1b05a5 100644 --- a/src/popcast/fit.py +++ b/src/popcast/fit.py @@ -421,6 +421,11 @@ def _fit(self, coefficients, X, y, use_l1_penalty=True, calculate_optimal_distan # Estimate target values. y_hat = self.predict(X, coefficients) + # Save optimal frequencies by timepoint and strain, if calculating + # optimal distance to the future. + if calculate_optimal_distance: + optimal_frequency_records = [] + # Calculate EMD for each timepoint in the estimated values and sum that # distance across all timepoints. error = 0.0 @@ -471,6 +476,12 @@ def _fit(self, coefficients, X, y, use_l1_penalty=True, calculate_optimal_distan cost=distance_matrix ) + optimal_frequency_records.append(pd.DataFrame({ + "strain": samples_a, + "timepoint": timepoint, + "optimal_projected_frequency": estimated_frequencies, + })) + # Estimate the distance between the model's estimated future and the # observed future populations. model_emd, _, self.model_flow = cv2.EMD( @@ -490,6 +501,9 @@ def _fit(self, coefficients, X, y, use_l1_penalty=True, calculate_optimal_distan else: l1_penalty = 0.0 + if calculate_optimal_distance: + self.optimal_frequencies = pd.concat(optimal_frequency_records) + return error + l1_penalty def _fit_distance(self, coefficients, X, y, use_l1_penalty=True): @@ -861,6 +875,11 @@ def test(model_class, model_kwargs, data, targets, timepoints, coefficients=None # Get the estimated frequencies for test sets to export. test_y_hat = model.predict(test_X) + test_y_hat = test_y_hat.merge( + model.optimal_frequencies, + on=["timepoint", "strain"], + validate="1:1", + ) # Convert timestamps to a serializable format. for df in [test_X, test_y, test_y_hat]: