Skip to content

Commit

Permalink
Update tensors.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Li-Xiang-Ideal authored Nov 25, 2023
1 parent 4159b69 commit 58dbb8b
Showing 1 changed file with 32 additions and 40 deletions.
72 changes: 32 additions & 40 deletions mathics/builtin/tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,48 +357,40 @@ def rec_sparse(item, rest_lists, current):
current[1] * _val,
),
)
if head.sameQ(SymbolSparseArray):
dims = []
val = Integer1
data = [] # data = [(rules), ...]
for _list in lists:
dims.extend(_list.elements[1])
val *= _list.elements[2]
if _list.elements[2] == Integer0: # _val==0
data.append(_list.elements[3].elements) # append (rules)
else: # _val!=0, append (rules, other pos->_val)
other_pos = []
for pos in itertools.product(
*(range(1, d.value + 1) for d in _list.elements[1])
):
other_pos.append(
ListExpression(*(Integer(i) for i in pos))
) # generate all pos
rules_pos = set(
rule.elements[0] for rule in _list.elements[3].elements
) # pos of existing rules
other_pos = (
set(other_pos) - rules_pos
) # remove pos of existing rules
other_rules = []
for pos in other_pos:
other_rules.append(
Expression(SymbolRule, pos, _list.elements[2])
) # generate other pos->_val
data.append(
_list.elements[3].elements + tuple(other_rules)
) # append (rules, other pos->_val)
dims = ListExpression(*dims)
return Expression(
SymbolSparseArray,
SymbolAutomatic,
dims,
val,
ListExpression(*rec_sparse(data[0], data[1:], ((), Integer1))),
)
else:

# head != SparseArray
if not head.sameQ(SymbolSparseArray):
return rec(lists[0], lists[1:], [])

# head == SparseArray
dims = []
val = Integer1
data = [] # data = [(rules), ...]
for _list in lists:
_dims, _val, _rules = _list.elements[1:]
dims.extend(_dims)
val *= _val
if _val == Integer0: # _val==0, append (_rules)
data.append(_rules.elements)
else: # _val!=0, append (_rules, other pos->_val)
other_pos = []
for pos in itertools.product(*(range(1, d.value + 1) for d in _dims)):
other_pos.append(ListExpression(*(Integer(i) for i in pos)))
rules_pos = set(rule.elements[0] for rule in _rules.elements)
other_pos = set(other_pos) - rules_pos
other_rules = []
for pos in other_pos:
other_rules.append(Expression(SymbolRule, pos, _val))
data.append(_list.elements[3].elements + tuple(other_rules))
dims = ListExpression(*dims)
return Expression(
SymbolSparseArray,
SymbolAutomatic,
dims,
val,
ListExpression(*rec_sparse(data[0], data[1:], ((), Integer1))),
)


class RotationTransform(Builtin):
"""
Expand Down

0 comments on commit 58dbb8b

Please sign in to comment.