Skip to content

Commit

Permalink
Merge pull request #945 from Mathics3/rewrite_eval_tensors
Browse files Browse the repository at this point in the history
Combine all ``rec()`` related to ``Outer``
  • Loading branch information
rocky authored Dec 18, 2023
2 parents 896c8f2 + 0264d73 commit 4df1273
Show file tree
Hide file tree
Showing 3 changed files with 542 additions and 55 deletions.
184 changes: 129 additions & 55 deletions mathics/eval/tensors.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import itertools
from typing import Union

from sympy.combinatorics import Permutation
from sympy.utilities.iterables import permutations

from mathics.core.atoms import Integer, Integer0, Integer1, String
from mathics.core.convert.python import from_python
from mathics.core.evaluation import Evaluation
from mathics.core.expression import Expression
from mathics.core.expression import BaseElement, Expression
from mathics.core.list import ListExpression
from mathics.core.symbols import (
Atom,
Expand Down Expand Up @@ -76,6 +76,94 @@ def get_dimensions(expr, head=None):
return [len(expr.elements)] + sub


def to_std_sparse_array(sparse_array, evaluation: Evaluation):
"Get a SparseArray equivalent to input with default value 0."

if sparse_array.elements[2] == Integer0:
return sparse_array
else:
return Expression(
SymbolSparseArray, Expression(SymbolNormal, sparse_array)
).evaluate(evaluation)


def construct_outer(lists, current, const_etc: tuple) -> Union[list, BaseElement]:
"""
Recursively unpacks lists to construct outer product.
------------------------------------
Unlike direct products, outer (tensor) products require traversing the
lowest level of each list, hence we recursively unpacking lists until
the lowest level is reached.
Parameters:
``item``: the current item to be unpacked (if not at lowest level),
or joined to current (if at lowest level)
``rest_lists``: the rest of lists to be unpacked
``current``: the current lowest level elements
``level``: the current level (unused yet, will be used in
``Outer[f_, lists__, n_]`` in the future)
``const_etc``: a tuple of functions used in unpacking, remains constant
throughout the recursion.
Format of ``const_etc``:
```
(
cond_next_list, # return True/False to unpack the next list/this list at next level
get_elements, # get elements of list, tuple, ListExpression, etc.
apply_head, # e.g. lambda elements: Expression(head, *elements)
apply_f, # e.g. lambda current: Expression(f, *current)
join_elem, # join current lowest level elements (i.e. current) with a new one
if_flattened, # True for result as flattened list, False for result as nested list
evaluation, # evaluation: Evaluation
)
```
For those unfamiliar with ``construct_outer``, ``ConstructOuterTest``
in ``test/eval/test_tensors.py`` provides a detailed introduction and
several good examples.
"""
(
cond_next_list, # return True when the next list should be unpacked
get_elements, # get elements of list, tuple, ListExpression, etc.
apply_head, # e.g. lambda elements: Expression(head, *elements)
apply_f, # e.g. lambda current: Expression(f, *current)
join_elem, # join current lowest level elements (i.e. current) with a new one
if_flatten, # True for result as flattened list ({a,b,c,d}), False for result as nested list ({{a,b},{c,d}})
evaluation, # evaluation: Evaluation
) = const_etc

_apply_f = (lambda current: (apply_f(current),)) if if_flatten else apply_f

# Recursive step of unpacking
def _unpack_outer(
item, rest_lists, current, level: int
) -> Union[list, BaseElement]:
evaluation.check_stopped()
if cond_next_list(item, level): # unpack next list
if rest_lists:
return _unpack_outer(
rest_lists[0], rest_lists[1:], join_elem(current, item), 1
) # unpacking of a list always start from level 1
else:
return _apply_f(join_elem(current, item))
else: # unpack this list at next level
elements = []
action = elements.extend if if_flatten else elements.append
# elements.extend flattens the result as list instead of as ListExpression
for element in get_elements(item):
action(_unpack_outer(element, rest_lists, current, level + 1))
return apply_head(elements)

return _unpack_outer(lists[0], lists[1:], current, 1)


def eval_Inner(f, list1, list2, g, evaluation: Evaluation):
"Evaluates recursively the inner product of list1 and list2"

Expand Down Expand Up @@ -124,6 +212,10 @@ def summand(i):
def eval_Outer(f, lists, evaluation: Evaluation):
"Evaluates recursively the outer product of lists"

if isinstance(lists, Atom):
evaluation.message("Outer", "normal")
return

# If f=!=Times, or lists contain both SparseArray and List, then convert all SparseArrays to Lists
lists = lists.get_sequence()
head = None
Expand Down Expand Up @@ -156,74 +248,56 @@ def eval_Outer(f, lists, evaluation: Evaluation):
if sparse_to_list:
lists = new_lists

def rec(item, rest_lists, current):
evaluation.check_stopped()
if isinstance(item, Atom) or not item.head.sameQ(head):
if rest_lists:
return rec(rest_lists[0], rest_lists[1:], current + [item])
else:
return Expression(f, *(current + [item]))
else:
elements = []
for element in item.elements:
elements.append(rec(element, rest_lists, current))
return Expression(head, *elements)

def rec_sparse(item, rest_lists, current):
evaluation.check_stopped()
if isinstance(item, tuple): # (rules)
elements = []
for element in item:
elements.extend(rec_sparse(element, rest_lists, current))
return tuple(elements)
else: # rule
_pos, _val = item.elements
if rest_lists:
return rec_sparse(
rest_lists[0],
rest_lists[1:],
(current[0] + _pos.elements, current[1] * _val),
)
else:
return (
Expression(
SymbolRule,
ListExpression(*(current[0] + _pos.elements)),
current[1] * _val,
),
)

# head != SparseArray
if not head.sameQ(SymbolSparseArray):
return rec(lists[0], lists[1:], [])

def cond_next_list(item, level) -> bool:
return isinstance(item, Atom) or not item.head.sameQ(head)

etc = (
cond_next_list,
(lambda item: item.elements), # get_elements
(lambda elements: Expression(head, *elements)), # apply_head
(lambda current: Expression(f, *current)), # apply_f
(lambda current, item: current + (item,)), # join_elem
False, # if_flatten
evaluation,
)
return construct_outer(lists, (), etc)

# head == SparseArray
dims = []
val = Integer1
data = [] # data = [(rules), ...]
for _list in lists:
_dims, _val, _rules = _list.elements[1:]
_dims, _val = _list.elements[1:3]
dims.extend(_dims)
val *= _val
if _val == Integer0: # _val==0, append (_rules)
data.append(_rules.elements)
else: # _val!=0, append (_rules, other pos->_val)
other_pos = []
for pos in itertools.product(*(range(1, d.value + 1) for d in _dims)):
other_pos.append(ListExpression(*(Integer(i) for i in pos)))
rules_pos = set(rule.elements[0] for rule in _rules.elements)
other_pos = set(other_pos) - rules_pos
other_rules = []
for pos in other_pos:
other_rules.append(Expression(SymbolRule, pos, _val))
data.append(_rules.elements + tuple(other_rules))
dims = ListExpression(*dims)

def sparse_cond_next_list(item, level) -> bool:
return isinstance(item, Atom) or not item.head.sameQ(head)

def sparse_apply_Rule(current) -> tuple:
return Expression(SymbolRule, ListExpression(*current[0]), current[1])

def sparse_join_elem(current, item) -> tuple:
return (current[0] + item.elements[0].elements, current[1] * item.elements[1])

etc = (
sparse_cond_next_list,
(lambda item: to_std_sparse_array(item, evaluation).elements[3].elements),
(lambda elements: elements), # apply_head
sparse_apply_Rule, # apply_f
sparse_join_elem, # join_elem
True, # if_flatten
evaluation,
)
return Expression(
SymbolSparseArray,
SymbolAutomatic,
dims,
val,
ListExpression(*rec_sparse(data[0], data[1:], ((), Integer1))),
ListExpression(*construct_outer(lists, ((), Integer1), etc)),
)


Expand Down
1 change: 1 addition & 0 deletions test/eval/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
Loading

0 comments on commit 4df1273

Please sign in to comment.