diff --git a/mathics/builtin/tensors.py b/mathics/builtin/tensors.py index 3509920c4..2b9b9d414 100644 --- a/mathics/builtin/tensors.py +++ b/mathics/builtin/tensors.py @@ -30,8 +30,20 @@ from mathics.core.evaluation import Evaluation from mathics.core.expression import Expression from mathics.core.list import ListExpression -from mathics.core.symbols import Atom, Symbol, SymbolFalse, SymbolTrue -from mathics.core.systemsymbols import SymbolAutomatic, SymbolRule, SymbolSparseArray +from mathics.core.symbols import ( + Atom, + Symbol, + SymbolFalse, + SymbolList, + SymbolTimes, + SymbolTrue, +) +from mathics.core.systemsymbols import ( + SymbolAutomatic, + SymbolNormal, + SymbolRule, + SymbolSparseArray, +) from mathics.eval.parts import get_part @@ -300,26 +312,42 @@ class Outer(Builtin): = {{0, 1, 0}, {1, 0, 1}, {0, ComplexInfinity, 0}} """ - rules = { - "Outer[f_, a___, b_SparseArray, c___] /; UnsameQ[f, Times]": "Outer[f, a, b // Normal, c]", - } - summary_text = "generalized outer product" def eval(self, f, lists, evaluation: Evaluation): - "Outer[f_, lists__] /; Or[SameQ[f, Times], Not[MemberQ[{lists}, _SparseArray]]]" + "Outer[f_, lists__]" + # If f=!=Times, or lists contain both SparseArray and List, then convert all SparseArrays to Lists lists = lists.get_sequence() head = None + sparse_to_list = f != SymbolTimes + contain_sparse = False + comtain_list = False + for _list in lists: + if _list.head.sameQ(SymbolSparseArray): + contain_sparse = True + if _list.head.sameQ(SymbolList): + comtain_list = True + sparse_to_list = sparse_to_list or (contain_sparse and comtain_list) + if sparse_to_list: + break + if sparse_to_list: + new_lists = [] for _list in lists: if isinstance(_list, Atom): evaluation.message("Outer", "normal") return + if sparse_to_list: + if _list.head.sameQ(SymbolSparseArray): + _list = Expression(SymbolNormal, _list).evaluate(evaluation) + new_lists.append(_list) if head is None: head = _list.head elif not _list.head.sameQ(head): evaluation.message("Outer", "heads", head, _list.head) return + if sparse_to_list: + lists = new_lists def rec(item, rest_lists, current): evaluation.check_stopped() @@ -391,7 +419,6 @@ def rec_sparse(item, rest_lists, current): ListExpression(*rec_sparse(data[0], data[1:], ((), Integer1))), ) - class RotationTransform(Builtin): """ :WMA link: https://reference.wolfram.com/language/ref/RotationTransform.html