Skip to content

Commit

Permalink
Add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Li-Xiang-Ideal committed Dec 16, 2023
1 parent 120b189 commit d3ebad4
Showing 1 changed file with 142 additions and 4 deletions.
146 changes: 142 additions & 4 deletions test/eval/test_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from mathics.core.atoms import Integer
from mathics.core.definitions import Definitions
from mathics.core.evaluation import Evaluation
from mathics.core.expression import Expression
from mathics.core.expression import BaseElement, Expression
from mathics.core.list import ListExpression
from mathics.core.symbols import Atom, Symbol, SymbolList, SymbolPlus, SymbolTimes
from mathics.core.symbols import Atom, Symbol, SymbolList
from mathics.eval.scoping import dynamic_scoping
from mathics.eval.tensors import construct_outer

definitions = Definitions(add_builtin=True)
Expand Down Expand Up @@ -110,13 +111,27 @@ def testCartesianProduct(self):
list6 = ListExpression(Integer(6), Integer(7), Integer(8))

expected_result_3 = Expression(
Symbol("System`Outer"), SymbolList, list4, list5, list6
).evaluate(evaluation)

expected_result_4 = Expression(
Symbol("System`Tuples"), ListExpression(list4, list5, list6)
).evaluate(evaluation)

def cond_next_list(item, level) -> bool:
return isinstance(item, Atom) or not item.head.sameQ(SymbolList)

etc_4 = (
cond_next_list, # equals to (lambda item, level: level > 1)
(lambda item: item.elements),
(lambda elements: ListExpression(*elements)), # apply_head
(lambda current: ListExpression(*current)), # apply_f
(lambda current, item: current + (item,)),
False,
evaluation,
)

etc_5 = (
cond_next_list,
(lambda item: item.elements),
(lambda elements: elements), # apply_head
Expand All @@ -126,11 +141,134 @@ def cond_next_list(item, level) -> bool:
evaluation,
)

assert construct_outer([list4, list5, list6], (), etc_4) == expected_result_3
assert (
ListExpression(*construct_outer([list4, list5, list6], (), etc_4))
== expected_result_3
ListExpression(*construct_outer([list4, list5, list6], (), etc_5))
== expected_result_4
)

def testTable(self):
"""
Table can be implemented by construct_outer.
"""
iter1 = [2] # {i, 2}
iter2 = [3, 4] # {j, 3, 4}
iter3 = [5, 1, -2] # {k, 5, 1, -2}

list1 = [1, 2] # {i, {1, 2}}
list2 = [3, 4] # {j, {3, 4}}
list3 = [5, 3, 1] # {k, {5, 3, 1}}

def get_range_1(_iter: list) -> range:
if len(_iter) == 1:
return range(1, _iter[0] + 1)
elif len(_iter) == 2:
return range(_iter[0], _iter[1] + 1)
elif len(_iter) == 3:
pm = 1 if _iter[2] >= 0 else -1
return range(_iter[0], _iter[1] + pm, _iter[2])
else:
raise ValueError("Invalid iterator")

expected_result_1 = [
[[18, 2, -6], [11, -5, -13]],
[[20, 4, -4], [13, -3, -11]],
] # Table[2*i - j^2 + k^2, {i, 2}, {j, 3, 4}, {k, 5, 1, -2}]
# Table[2*i - j^2 + k^2, {{i, {1, 2}}, {j, {3, 4}}, {k, {5, 3, 1}}]

etc_1 = (
(lambda item, level: level > 1), # range always has depth 1
get_range_1,
(lambda elements: elements),
(lambda current: 2 * current[0] - current[1] ** 2 + current[2] ** 2),
(lambda current, item: current + (item,)),
False,
evaluation,
)

etc_2 = (
(lambda item, level: level > 1),
(lambda item: item),
(lambda elements: elements),
(lambda current: 2 * current[0] - current[1] ** 2 + current[2] ** 2),
(lambda current, item: current + (item,)),
False,
evaluation,
)

assert construct_outer([iter1, iter2, iter3], (), etc_1) == expected_result_1
assert construct_outer([list1, list2, list3], (), etc_2) == expected_result_1

# Flattened result

etc_3 = (
(lambda item, level: level > 1),
(lambda item: item),
(lambda elements: elements),
(lambda current: 2 * current[0] - current[1] ** 2 + current[2] ** 2),
(lambda current, item: current + (item,)),
True,
evaluation,
)

expected_result_2 = [18, 2, -6, 11, -5, -13, 20, 4, -4, 13, -3, -11]

assert construct_outer([list1, list2, list3], (), etc_3) == expected_result_2

# M-Expression

iter4 = ListExpression(Symbol("i"), Integer(2))
iter5 = ListExpression(Symbol("j"), Integer(3), Integer(4))
iter6 = ListExpression(Symbol("k"), Integer(5), Integer(1), Integer(-2))

list4 = ListExpression(Symbol("i"), ListExpression(Integer(1), Integer(2)))
list5 = ListExpression(Symbol("j"), ListExpression(Integer(3), Integer(4)))
list6 = ListExpression(
Symbol("k"), ListExpression(Integer(5), Integer(3), Integer(1))
)

expr_to_evaluate = (
Integer(2) * Symbol("i")
- Symbol("j") ** Integer(2)
+ Symbol("k") ** Integer(2)
) # 2*i - j^2 + k^2

expected_result_3 = Expression(
Symbol("System`Table"),
expr_to_evaluate,
iter4,
iter5,
iter6,
).evaluate(evaluation)
# Table[2*i - j^2 + k^2, {i, 2}, {j, 3, 4}, {k, 5, 1, -2}]

def get_range_2(_iter: BaseElement) -> BaseElement:
if isinstance(_iter.elements[1], Atom): # {i, 2}, etc.
_list = (
Expression(Symbol("System`Range"), *_iter.elements[1:])
.evaluate(evaluation)
.elements
)
else: # {i, {1, 2}}, etc.
_list = _iter.elements[1].elements
return ({_iter.elements[0].name: item} for item in _list)

def evaluate_current(current: dict) -> BaseElement:
return dynamic_scoping(expr_to_evaluate.evaluate, current, evaluation)

etc_4 = (
(lambda item, level: level > 1),
get_range_2,
(lambda elements: ListExpression(*elements)), # apply_head
evaluate_current,
(lambda current, item: {**current, **item}),
False,
evaluation,
)

assert construct_outer([iter4, iter5, iter6], {}, etc_4) == expected_result_3
assert construct_outer([list4, list5, list6], {}, etc_4) == expected_result_3


if __name__ == "__main__":
unittest.main()

0 comments on commit d3ebad4

Please sign in to comment.