diff --git a/pymdp/maths.py b/pymdp/maths.py index 59904be5..68000fb4 100644 --- a/pymdp/maths.py +++ b/pymdp/maths.py @@ -12,6 +12,7 @@ from scipy import special from pymdp import utils from itertools import chain +from opt_einsum import contract EPS_VAL = 1e-16 # global constant for use in spm_log() function @@ -105,6 +106,28 @@ def spm_dot_classic(X, x, dims_to_omit=None): return Y +def factor_dot_flex(M, xs, dims, keep_dims=None): + """ Dot product of a multidimensional array with `x`. + + Parameters + ---------- + - `M` [numpy.ndarray] - tensor + - 'xs' [list of numpyr.ndarray] - list of tensors + - 'dims' [list of tuples] - list of dimensions of xs tensors in tensor M + - 'keep_dims' [tuple] - tuple of integers denoting dimesions to keep + Returns + ------- + - `Y` [1D numpy.ndarray] - the result of the dot product + """ + all_dims = tuple(range(M.ndim)) + matrix = [[xs[f], dims[f]] for f in range(len(xs))] + args = [M, all_dims] + for row in matrix: + args.extend(row) + + args += [keep_dims] + return contract(*args, backend='numpy') + def spm_dot_old(X, x, dims_to_omit=None, obs_mode=False): """ Dot product of a multidimensional array with `x`. The dimensions in `dims_to_omit` will not be summed across during the dot product