Skip to content

Commit

Permalink
added the numpy version of @dimarkov's factor_dot_flex into numpy `…
Browse files Browse the repository at this point in the history
…maths.py` library
  • Loading branch information
conorheins committed Jun 5, 2024
1 parent 55fcc5f commit e97d6c7
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions pymdp/maths.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from scipy import special
from pymdp import utils
from itertools import chain
from opt_einsum import contract

EPS_VAL = 1e-16 # global constant for use in spm_log() function

Expand Down Expand Up @@ -105,6 +106,28 @@ def spm_dot_classic(X, x, dims_to_omit=None):

return Y

def factor_dot_flex(M, xs, dims, keep_dims=None):
""" Dot product of a multidimensional array with `x`.
Parameters
----------
- `M` [numpy.ndarray] - tensor
- 'xs' [list of numpyr.ndarray] - list of tensors
- 'dims' [list of tuples] - list of dimensions of xs tensors in tensor M
- 'keep_dims' [tuple] - tuple of integers denoting dimesions to keep
Returns
-------
- `Y` [1D numpy.ndarray] - the result of the dot product
"""
all_dims = tuple(range(M.ndim))
matrix = [[xs[f], dims[f]] for f in range(len(xs))]
args = [M, all_dims]
for row in matrix:
args.extend(row)

args += [keep_dims]
return contract(*args, backend='numpy')

def spm_dot_old(X, x, dims_to_omit=None, obs_mode=False):
""" Dot product of a multidimensional array with `x`. The dimensions in `dims_to_omit`
will not be summed across during the dot product
Expand Down

0 comments on commit e97d6c7

Please sign in to comment.