diff --git a/src/Bonsai.ML.HiddenMarkovModels.Design/GaussianObservationsStatisticsClustersVisualizer.cs b/src/Bonsai.ML.HiddenMarkovModels.Design/GaussianObservationsStatisticsClustersVisualizer.cs index d5400317..bdd54973 100644 --- a/src/Bonsai.ML.HiddenMarkovModels.Design/GaussianObservationsStatisticsClustersVisualizer.cs +++ b/src/Bonsai.ML.HiddenMarkovModels.Design/GaussianObservationsStatisticsClustersVisualizer.cs @@ -149,7 +149,6 @@ public override void Show(object value) { if (value is Observations.GaussianObservationsStatistics gaussianObservationsStatistics) { - var statesCount = gaussianObservationsStatistics.Means.GetLength(0); var observationDimensions = gaussianObservationsStatistics.Means.GetLength(1); @@ -227,12 +226,14 @@ public override void Show(object value) var batchObservationsCount = gaussianObservationsStatistics.BatchObservations.GetLength(0); var offset = BufferData && batchObservationsCount > BufferCount ? batchObservationsCount - BufferCount : 0; + var predictedStatesCount = gaussianObservationsStatistics.PredictedStates.Length; + for (int i = offset; i < batchObservationsCount; i++) { var dim1 = gaussianObservationsStatistics.BatchObservations[i, dimension1SelectedIndex]; var dim2 = gaussianObservationsStatistics.BatchObservations[i, dimension2SelectedIndex]; - var state = gaussianObservationsStatistics.InferredMostProbableStates[i]; - allScatterSeries[(int)state].Points.Add(new ScatterPoint(dim1, dim2, value: state, tag: state)); + var state = gaussianObservationsStatistics.PredictedStates[i]; + allScatterSeries[Convert.ToInt32(state)].Points.Add(new ScatterPoint(dim1, dim2, value: state, tag: state)); } for (int i = 0; i < statesCount; i++) diff --git a/src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai b/src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai index 67cad75f..3cad8001 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai +++ b/src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai @@ -52,7 +52,7 @@ - hmm.most_likely_states([59.7382107943162,3.99285183724331]) + hmm.infer_state([59.7382107943162,3.99285183724331]) diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservationsStatistics.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservationsStatistics.cs index ffeb6c00..991b7cca 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservationsStatistics.cs +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservationsStatistics.cs @@ -45,11 +45,11 @@ public class GaussianObservationsStatistics public double[,] BatchObservations { get; set; } /// - /// The sequence of inferred most probable states. + /// The predicted state for each observation in the batch of observations. /// - [Description("The sequence of inferred most probable states.")] + [Description("The predicted state for each observation in the batch of observations.")] [XmlIgnore] - public int[] InferredMostProbableStates { get; set; } + public long[] PredictedStates { get; set; } /// /// Transforms an observable sequence of into an observable sequence @@ -64,7 +64,7 @@ public IObservable Process(IObservable var covarianceMatricesPyObj = (double[,,])observationsPyObj.GetArrayAttr("Sigmas"); var stdDevsPyObj = DiagonalSqrt(covarianceMatricesPyObj); var batchObservationsPyObj = (double[,])pyObject.GetArrayAttr("batch_observations"); - var inferredMostProbableStatesPyObj = (int[])pyObject.GetArrayAttr("inferred_most_probable_states"); + var predictedStatesPyObj = (long[])pyObject.GetArrayAttr("predicted_states"); return new GaussianObservationsStatistics { @@ -72,7 +72,7 @@ public IObservable Process(IObservable StdDevs = stdDevsPyObj, CovarianceMatrices = covarianceMatricesPyObj, BatchObservations = batchObservationsPyObj, - InferredMostProbableStates = inferredMostProbableStatesPyObj + PredictedStates = predictedStatesPyObj }; }); } diff --git a/src/Bonsai.ML.HiddenMarkovModels/main.py b/src/Bonsai.ML.HiddenMarkovModels/main.py index aad1386e..acafea6a 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/main.py +++ b/src/Bonsai.ML.HiddenMarkovModels/main.py @@ -78,14 +78,15 @@ def get_nonlinearity_type(func): self.state_probabilities = None self.batch = None - self.batch_observations = np.array([[]], dtype=float) + self.batch_observations = np.array([[]], dtype=float).reshape((0, dimensions)) self.is_running = False self._fit_finished = False self.loop = None self.thread = None self.curr_batch_size = 0 self.flush_data_between_batches = True - self.inferred_most_probable_states = np.array([], dtype=int) + self.predicted_states = np.array([], dtype=int) + self.buffer_count = 250 def update_params(self, initial_state_distribution, transitions_params, observations_params): hmm_params = self.params @@ -124,10 +125,17 @@ def update_params(self, initial_state_distribution, transitions_params, observat def infer_state(self, observation: list[float]): - self.log_alpha = self.compute_log_alpha( - np.expand_dims(np.array(observation), 0), self.log_alpha) + observation = np.expand_dims(np.array(observation), 0) + self.log_alpha = self.compute_log_alpha(observation, self.log_alpha) self.state_probabilities = np.exp(self.log_alpha).astype(np.double) - return self.state_probabilities.argmax() + prediction = self.state_probabilities.argmax() + self.predicted_states = np.append(self.predicted_states, prediction) + if self.predicted_states.shape[0] > self.buffer_count: + self.predicted_states = self.predicted_states[1:] + self.batch_observations = np.vstack([self.batch_observations, observation]) + if self.batch_observations.shape[0] == self.buffer_count: + self.batch_observations = self.batch_observations[1:] + return prediction def compute_log_alpha(self, obs, log_alpha=None): @@ -171,8 +179,6 @@ def fit_async(self, self.batch = np.vstack( [self.batch[1:], np.expand_dims(np.array(observation), 0)]) - self.batch_observations = self.batch - if not self.is_running and self.loop is None and self.thread is None: if self.curr_batch_size >= batch_size: @@ -221,8 +227,6 @@ def on_completion(future): if self.flush_data_between_batches: self.batch = None - self.inferred_most_probable_states = np.array([self.infer_state(obs) for obs in self.batch_observations]).astype(int) - self.is_running = True if self.loop is None or self.loop.is_closed():