Skip to content

Commit

Permalink
Update tensors.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Li-Xiang-Ideal authored Nov 25, 2023
1 parent 58dbb8b commit 9583f38
Showing 1 changed file with 35 additions and 8 deletions.
43 changes: 35 additions & 8 deletions mathics/builtin/tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,20 @@
from mathics.core.evaluation import Evaluation
from mathics.core.expression import Expression
from mathics.core.list import ListExpression
from mathics.core.symbols import Atom, Symbol, SymbolFalse, SymbolTrue
from mathics.core.systemsymbols import SymbolAutomatic, SymbolRule, SymbolSparseArray
from mathics.core.symbols import (
Atom,
Symbol,
SymbolFalse,
SymbolList,
SymbolTimes,
SymbolTrue,
)
from mathics.core.systemsymbols import (
SymbolAutomatic,
SymbolNormal,
SymbolRule,
SymbolSparseArray,
)
from mathics.eval.parts import get_part


Expand Down Expand Up @@ -300,26 +312,42 @@ class Outer(Builtin):
= {{0, 1, 0}, {1, 0, 1}, {0, ComplexInfinity, 0}}
"""

rules = {
"Outer[f_, a___, b_SparseArray, c___] /; UnsameQ[f, Times]": "Outer[f, a, b // Normal, c]",
}

summary_text = "generalized outer product"

def eval(self, f, lists, evaluation: Evaluation):
"Outer[f_, lists__] /; Or[SameQ[f, Times], Not[MemberQ[{lists}, _SparseArray]]]"
"Outer[f_, lists__]"

# If f=!=Times, or lists contain both SparseArray and List, then convert all SparseArrays to Lists
lists = lists.get_sequence()
head = None
sparse_to_list = f != SymbolTimes
contain_sparse = False
comtain_list = False
for _list in lists:
if _list.head.sameQ(SymbolSparseArray):
contain_sparse = True
if _list.head.sameQ(SymbolList):
comtain_list = True
sparse_to_list = sparse_to_list or (contain_sparse and comtain_list)
if sparse_to_list:
break
if sparse_to_list:
new_lists = []
for _list in lists:
if isinstance(_list, Atom):
evaluation.message("Outer", "normal")
return
if sparse_to_list:
if _list.head.sameQ(SymbolSparseArray):
_list = Expression(SymbolNormal, _list).evaluate(evaluation)
new_lists.append(_list)
if head is None:
head = _list.head
elif not _list.head.sameQ(head):
evaluation.message("Outer", "heads", head, _list.head)
return
if sparse_to_list:
lists = new_lists

def rec(item, rest_lists, current):
evaluation.check_stopped()
Expand Down Expand Up @@ -391,7 +419,6 @@ def rec_sparse(item, rest_lists, current):
ListExpression(*rec_sparse(data[0], data[1:], ((), Integer1))),
)


class RotationTransform(Builtin):
"""
<url>:WMA link: https://reference.wolfram.com/language/ref/RotationTransform.html</url>
Expand Down

0 comments on commit 9583f38

Please sign in to comment.