From 58dbb8b6aabc197857d1c17d6aab4cbf6ad9f984 Mon Sep 17 00:00:00 2001 From: Li Xiang <54926635+Li-Xiang-Ideal@users.noreply.github.com> Date: Sat, 25 Nov 2023 19:13:48 +0800 Subject: [PATCH] Update tensors.py --- mathics/builtin/tensors.py | 72 +++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 40 deletions(-) diff --git a/mathics/builtin/tensors.py b/mathics/builtin/tensors.py index f8a292524..3509920c4 100644 --- a/mathics/builtin/tensors.py +++ b/mathics/builtin/tensors.py @@ -357,48 +357,40 @@ def rec_sparse(item, rest_lists, current): current[1] * _val, ), ) - if head.sameQ(SymbolSparseArray): - dims = [] - val = Integer1 - data = [] # data = [(rules), ...] - for _list in lists: - dims.extend(_list.elements[1]) - val *= _list.elements[2] - if _list.elements[2] == Integer0: # _val==0 - data.append(_list.elements[3].elements) # append (rules) - else: # _val!=0, append (rules, other pos->_val) - other_pos = [] - for pos in itertools.product( - *(range(1, d.value + 1) for d in _list.elements[1]) - ): - other_pos.append( - ListExpression(*(Integer(i) for i in pos)) - ) # generate all pos - rules_pos = set( - rule.elements[0] for rule in _list.elements[3].elements - ) # pos of existing rules - other_pos = ( - set(other_pos) - rules_pos - ) # remove pos of existing rules - other_rules = [] - for pos in other_pos: - other_rules.append( - Expression(SymbolRule, pos, _list.elements[2]) - ) # generate other pos->_val - data.append( - _list.elements[3].elements + tuple(other_rules) - ) # append (rules, other pos->_val) - dims = ListExpression(*dims) - return Expression( - SymbolSparseArray, - SymbolAutomatic, - dims, - val, - ListExpression(*rec_sparse(data[0], data[1:], ((), Integer1))), - ) - else: + + # head != SparseArray + if not head.sameQ(SymbolSparseArray): return rec(lists[0], lists[1:], []) + # head == SparseArray + dims = [] + val = Integer1 + data = [] # data = [(rules), ...] + for _list in lists: + _dims, _val, _rules = _list.elements[1:] + dims.extend(_dims) + val *= _val + if _val == Integer0: # _val==0, append (_rules) + data.append(_rules.elements) + else: # _val!=0, append (_rules, other pos->_val) + other_pos = [] + for pos in itertools.product(*(range(1, d.value + 1) for d in _dims)): + other_pos.append(ListExpression(*(Integer(i) for i in pos))) + rules_pos = set(rule.elements[0] for rule in _rules.elements) + other_pos = set(other_pos) - rules_pos + other_rules = [] + for pos in other_pos: + other_rules.append(Expression(SymbolRule, pos, _val)) + data.append(_list.elements[3].elements + tuple(other_rules)) + dims = ListExpression(*dims) + return Expression( + SymbolSparseArray, + SymbolAutomatic, + dims, + val, + ListExpression(*rec_sparse(data[0], data[1:], ((), Integer1))), + ) + class RotationTransform(Builtin): """