diff --git a/pymdp/agent.py b/pymdp/agent.py index 7def788c..f09b559a 100644 --- a/pymdp/agent.py +++ b/pymdp/agent.py @@ -648,7 +648,7 @@ def _validate(self): ), f"Please input an `A_dependencies` whose {m}-th indices correspond to the hidden state factors that line up with lagging dimensions of A[{m}]..." if self.pA != None: assert ( - self.pA[m].shape[2:] == factor_dims + self.pA[m].shape[2:] == factor_dims if self.pA[m] is not None else True, ), f"Please input an `A_dependencies` whose {m}-th indices correspond to the hidden state factors that line up with lagging dimensions of pA[{m}]..." assert max(self.A_dependencies[m]) <= ( self.num_factors - 1