Skip to content

Commit

Permalink
Add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Li-Xiang-Ideal committed Dec 17, 2023
1 parent a3d3107 commit 27dfd69
Showing 1 changed file with 134 additions and 0 deletions.
134 changes: 134 additions & 0 deletions test/eval/test_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,140 @@ def evaluate_current(current: dict) -> BaseElement:
assert construct_outer([iter4, iter5, iter6], {}, etc_4) == expected_result_3
assert construct_outer([list4, list5, list6], {}, etc_4) == expected_result_3

def testTensorProduct(self):
"""
Tensor Product can be implemented by construct_outer.
"""
list1 = [[4, 5], [8, 10], [12, 15]]
list2 = [6, 7, 8]

expected_result_1 = [
[[24, 28, 32], [30, 35, 40]],
[[48, 56, 64], [60, 70, 80]],
[[72, 84, 96], [90, 105, 120]],
]

def product_of_list(_list):
result = 1
for item in _list:
result *= item
return result

etc_1 = (
(lambda item, level: not isinstance(item, list)),
(lambda item: item),
(lambda elements: elements),
product_of_list,
(lambda current, item: current + (item,)),
False,
evaluation,
)

etc_2 = (
(lambda item, level: not isinstance(item, list)),
(lambda item: item),
(lambda elements: elements),
(lambda current: current),
(lambda current, item: current * item),
False,
evaluation,
)

assert construct_outer([list1, list2], (), etc_1) == expected_result_1
assert construct_outer([list1, list2], 1, etc_2) == expected_result_1

# M-Expression

list3 = ListExpression(
ListExpression(Integer(4), Integer(5)),
ListExpression(Integer(8), Integer(10)),
ListExpression(Integer(12), Integer(15)),
)
list4 = ListExpression(Integer(6), Integer(7), Integer(8))

expected_result_2 = Expression(
Symbol("System`Outer"), Symbol("System`Times"), list3, list4
).evaluate(evaluation)

def cond_next_list(item, level) -> bool:
return isinstance(item, Atom) or not item.head.sameQ(SymbolList)

etc_3 = (
cond_next_list,
(lambda item: item.elements),
(lambda elements: ListExpression(*elements)),
(lambda current: Expression(Symbol("System`Times"), *current)),
(lambda current, item: current + (item,)),
False,
evaluation,
)

etc_4 = (
cond_next_list,
(lambda item: item.elements),
(lambda elements: ListExpression(*elements)),
(lambda current: current),
(lambda current, item: current * item),
False,
evaluation,
)

assert (
construct_outer([list3, list4], (), etc_3).evaluate(evaluation)
== expected_result_2
)
assert (
construct_outer([list3, list4], Integer(1), etc_4).evaluate(evaluation)
== expected_result_2
)

def testOthers(self):
"""
construct_outer can be used in other cases.
"""
list1 = [[4, 5], [8, [10, 12]], 15] # ragged
list2 = [6, 7, 8]
list3 = [] # empty

expected_result_1 = [
[[24, 28, 32], [30, 35, 40]],
[[48, 56, 64], [[60, 70, 80], [72, 84, 96]]],
[90, 105, 120],
]

expected_result_2 = [
[[(4, 6), (4, 7), (4, 8)], [(5, 6), (5, 7), (5, 8)]],
[[(8, 6), (8, 7), (8, 8)], [([10, 12], 6), ([10, 12], 7), ([10, 12], 8)]],
[(15, 6), (15, 7), (15, 8)],
]

expected_result_3 = [[[[], [], []], [[], [], []]], [[[], [], []], [[], [], []]], [[], [], []]]

etc_1 = (
(lambda item, level: not isinstance(item, list)),
(lambda item: item),
(lambda elements: elements),
(lambda current: current),
(lambda current, item: current * item),
False,
evaluation,
)

etc_2 = (
(lambda item, level: not isinstance(item, list) or level > 2),
(lambda item: item),
(lambda elements: elements),
(lambda current: current),
(lambda current, item: current + (item,)),
False,
evaluation,
)

assert construct_outer([list1, list2], 1, etc_1) == expected_result_1
assert construct_outer([list1, list2], (), etc_2) == expected_result_2
assert construct_outer([list1, list2, list3], (), etc_2) == expected_result_3
assert construct_outer([list3, list1, list2], (), etc_2) == []


if __name__ == "__main__":
unittest.main()

0 comments on commit 27dfd69

Please sign in to comment.