diff --git a/filterpy/common/helpers.py b/filterpy/common/helpers.py index 0d54eed..6fb9291 100644 --- a/filterpy/common/helpers.py +++ b/filterpy/common/helpers.py @@ -88,14 +88,16 @@ def setter_scalar(value, dim_x): type which converts to numpy.array (list, np.array, np.matrix, etc), or a scalar, in which case we create a diagonal matrix with each diagonal element == value. + + dim_x is used iff value is scalar, otherwise it is determined from the + shape of value """ if isscalar(value): v = eye(dim_x) * value else: - v = asarray(value, dtype=float) + v = array(value, dtype=float) + dim_x = v.shape[0] - if v is value: - v = value.copy() if v.shape != (dim_x, dim_x): raise Exception('must have shape ({},{})'.format(dim_x, dim_x)) return v