diff --git a/src/Bonsai.ML.HiddenMarkovModels.Design/GaussianObservationsStatisticsClustersVisualizer.cs b/src/Bonsai.ML.HiddenMarkovModels.Design/GaussianObservationsStatisticsClustersVisualizer.cs
index d5400317..bdd54973 100644
--- a/src/Bonsai.ML.HiddenMarkovModels.Design/GaussianObservationsStatisticsClustersVisualizer.cs
+++ b/src/Bonsai.ML.HiddenMarkovModels.Design/GaussianObservationsStatisticsClustersVisualizer.cs
@@ -149,7 +149,6 @@ public override void Show(object value)
{
if (value is Observations.GaussianObservationsStatistics gaussianObservationsStatistics)
{
-
var statesCount = gaussianObservationsStatistics.Means.GetLength(0);
var observationDimensions = gaussianObservationsStatistics.Means.GetLength(1);
@@ -227,12 +226,14 @@ public override void Show(object value)
var batchObservationsCount = gaussianObservationsStatistics.BatchObservations.GetLength(0);
var offset = BufferData && batchObservationsCount > BufferCount ? batchObservationsCount - BufferCount : 0;
+ var predictedStatesCount = gaussianObservationsStatistics.PredictedStates.Length;
+
for (int i = offset; i < batchObservationsCount; i++)
{
var dim1 = gaussianObservationsStatistics.BatchObservations[i, dimension1SelectedIndex];
var dim2 = gaussianObservationsStatistics.BatchObservations[i, dimension2SelectedIndex];
- var state = gaussianObservationsStatistics.InferredMostProbableStates[i];
- allScatterSeries[(int)state].Points.Add(new ScatterPoint(dim1, dim2, value: state, tag: state));
+ var state = gaussianObservationsStatistics.PredictedStates[i];
+ allScatterSeries[Convert.ToInt32(state)].Points.Add(new ScatterPoint(dim1, dim2, value: state, tag: state));
}
for (int i = 0; i < statesCount; i++)
diff --git a/src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai b/src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai
index 67cad75f..3cad8001 100644
--- a/src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai
+++ b/src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai
@@ -52,7 +52,7 @@
- hmm.most_likely_states([59.7382107943162,3.99285183724331])
+ hmm.infer_state([59.7382107943162,3.99285183724331])
diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservationsStatistics.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservationsStatistics.cs
index ffeb6c00..991b7cca 100644
--- a/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservationsStatistics.cs
+++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservationsStatistics.cs
@@ -45,11 +45,11 @@ public class GaussianObservationsStatistics
public double[,] BatchObservations { get; set; }
///
- /// The sequence of inferred most probable states.
+ /// The predicted state for each observation in the batch of observations.
///
- [Description("The sequence of inferred most probable states.")]
+ [Description("The predicted state for each observation in the batch of observations.")]
[XmlIgnore]
- public int[] InferredMostProbableStates { get; set; }
+ public long[] PredictedStates { get; set; }
///
/// Transforms an observable sequence of into an observable sequence
@@ -64,7 +64,7 @@ public IObservable Process(IObservable
var covarianceMatricesPyObj = (double[,,])observationsPyObj.GetArrayAttr("Sigmas");
var stdDevsPyObj = DiagonalSqrt(covarianceMatricesPyObj);
var batchObservationsPyObj = (double[,])pyObject.GetArrayAttr("batch_observations");
- var inferredMostProbableStatesPyObj = (int[])pyObject.GetArrayAttr("inferred_most_probable_states");
+ var predictedStatesPyObj = (long[])pyObject.GetArrayAttr("predicted_states");
return new GaussianObservationsStatistics
{
@@ -72,7 +72,7 @@ public IObservable Process(IObservable
StdDevs = stdDevsPyObj,
CovarianceMatrices = covarianceMatricesPyObj,
BatchObservations = batchObservationsPyObj,
- InferredMostProbableStates = inferredMostProbableStatesPyObj
+ PredictedStates = predictedStatesPyObj
};
});
}
diff --git a/src/Bonsai.ML.HiddenMarkovModels/main.py b/src/Bonsai.ML.HiddenMarkovModels/main.py
index aad1386e..acafea6a 100644
--- a/src/Bonsai.ML.HiddenMarkovModels/main.py
+++ b/src/Bonsai.ML.HiddenMarkovModels/main.py
@@ -78,14 +78,15 @@ def get_nonlinearity_type(func):
self.state_probabilities = None
self.batch = None
- self.batch_observations = np.array([[]], dtype=float)
+ self.batch_observations = np.array([[]], dtype=float).reshape((0, dimensions))
self.is_running = False
self._fit_finished = False
self.loop = None
self.thread = None
self.curr_batch_size = 0
self.flush_data_between_batches = True
- self.inferred_most_probable_states = np.array([], dtype=int)
+ self.predicted_states = np.array([], dtype=int)
+ self.buffer_count = 250
def update_params(self, initial_state_distribution, transitions_params, observations_params):
hmm_params = self.params
@@ -124,10 +125,17 @@ def update_params(self, initial_state_distribution, transitions_params, observat
def infer_state(self, observation: list[float]):
- self.log_alpha = self.compute_log_alpha(
- np.expand_dims(np.array(observation), 0), self.log_alpha)
+ observation = np.expand_dims(np.array(observation), 0)
+ self.log_alpha = self.compute_log_alpha(observation, self.log_alpha)
self.state_probabilities = np.exp(self.log_alpha).astype(np.double)
- return self.state_probabilities.argmax()
+ prediction = self.state_probabilities.argmax()
+ self.predicted_states = np.append(self.predicted_states, prediction)
+ if self.predicted_states.shape[0] > self.buffer_count:
+ self.predicted_states = self.predicted_states[1:]
+ self.batch_observations = np.vstack([self.batch_observations, observation])
+ if self.batch_observations.shape[0] == self.buffer_count:
+ self.batch_observations = self.batch_observations[1:]
+ return prediction
def compute_log_alpha(self, obs, log_alpha=None):
@@ -171,8 +179,6 @@ def fit_async(self,
self.batch = np.vstack(
[self.batch[1:], np.expand_dims(np.array(observation), 0)])
- self.batch_observations = self.batch
-
if not self.is_running and self.loop is None and self.thread is None:
if self.curr_batch_size >= batch_size:
@@ -221,8 +227,6 @@ def on_completion(future):
if self.flush_data_between_batches:
self.batch = None
- self.inferred_most_probable_states = np.array([self.infer_state(obs) for obs in self.batch_observations]).astype(int)
-
self.is_running = True
if self.loop is None or self.loop.is_closed():