Skip to content

Commit

Permalink
SymPy 1.13 fixes (#1620)
Browse files Browse the repository at this point in the history
  • Loading branch information
BenWeber42 authored Aug 14, 2024
1 parent 54b2fa1 commit 5a773ea
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 27 deletions.
2 changes: 1 addition & 1 deletion dace/codegen/tools/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def infer_expr_type(code, symbols=None):
if isinstance(code, (str, float, int, complex)):
parsed_ast = ast.parse(str(code))
elif isinstance(code, sympy.Basic):
parsed_ast = ast.parse(sympy.printing.pycode(code))
parsed_ast = ast.parse(sympy.printing.pycode(code, allow_unknown_functions=True))
elif isinstance(code, SymExpr):
parsed_ast = ast.parse(sympy.printing.pycode(code.expr))
else:
Expand Down
4 changes: 3 additions & 1 deletion dace/frontend/common/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,9 @@ def _create_einsum_internal(sdfg: SDFG,
if not is_conflicted and init_output is None:
to_init = False

if einsum.is_reduce() and alpha == 1 and (beta == 0 or beta == 1):
if einsum.is_reduce() and symbolic.equal_valued(1, alpha) and (
symbolic.equal_valued(0, beta) or symbolic.equal_valued(1, beta)
):
from dace.libraries.standard.nodes.reduce import Reduce
# Get reduce axes
axes = tuple(i for i, s in enumerate(einsum.inputs[0]) if s not in einsum.output)
Expand Down
37 changes: 27 additions & 10 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ def _arange(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args, **kwargs):
if any(not isinstance(s, Number) for s in [start, stop, step]):
shape = (symbolic.int_ceil(stop - start, step), )
else:
shape = (np.ceil((stop - start) / step), )
shape = (np.int64(np.ceil((stop - start) / step)), )

if not isinstance(shape[0], Number) and ('dtype' not in kwargs or kwargs['dtype'] == None):
raise NotImplementedError("The current implementation of numpy.arange requires that the output dtype is given "
Expand All @@ -583,7 +583,12 @@ def _arange(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args, **kwargs):
dtype = dtypes.dtype_to_typeclass(dtype)
outname, outarr = sdfg.add_temp_transient(shape, dtype)
else:
dtype = dtypes.dtype_to_typeclass(type(shape[0]))
# infer dtype based on args's dtype
# (since the `dtype` keyword argument isn't given, none of the arguments can be symbolic)
if any(isinstance(arg, (float, np.float32, np.float64)) for arg in args):
dtype = dtypes.float64
else:
dtype = dtypes.int64
outname, outarr = sdfg.add_temp_transient(shape, dtype)

state.add_mapped_tasklet(name="_numpy_arange_",
Expand Down Expand Up @@ -4143,30 +4148,42 @@ def view(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, dtype, type

desc = sdfg.arrays[arr]

# Change size of array based on the differences in bytes
bytemult = desc.dtype.bytes / dtype.bytes
bytediv = dtype.bytes / desc.dtype.bytes
orig_bytes = desc.dtype.bytes
view_bytes = dtype.bytes

if view_bytes < orig_bytes and orig_bytes % view_bytes != 0:
raise ValueError("When changing to a smaller dtype, its size must be a divisor of "
"the size of original dtype")

contigdim = next(i for i, s in enumerate(desc.strides) if s == 1)

# For cases that can be recognized, if contiguous dimension is too small
# raise an exception similar to numpy
if (not issymbolic(desc.shape[contigdim], sdfg.constants) and bytemult < 1
and desc.shape[contigdim] % bytediv != 0):
if (not issymbolic(desc.shape[contigdim], sdfg.constants) and orig_bytes < view_bytes
and desc.shape[contigdim] * orig_bytes % view_bytes != 0):
raise ValueError('When changing to a larger dtype, its size must be a divisor of '
'the total size in bytes of the last axis of the array.')

# Create new shape and strides for view
# NOTE: we change sizes by using `(old_size * orig_bytes) // view_bytes`
# Thus, the changed size will be an integer due to integer division.
# If the division created a fraction, the view wouldn't be valid in the first place.
# So, we assume the division will always yield an integer, and, hence,
# the integer division is correct.
# Also, keep in mind that `old_size * (orig_bytes // view_bytes)` is different.
# E.g., if `orig_bytes == 1 and view_bytes == 2`: `old_size * (1 // 2) == old_size * 0`.
newshape = list(desc.shape)
newstrides = [s * bytemult if i != contigdim else s for i, s in enumerate(desc.strides)]
newshape[contigdim] *= bytemult
newstrides = [(s * orig_bytes) // view_bytes if i != contigdim else s for i, s in enumerate(desc.strides)]
# don't use `*=`, because it will break the bracket
newshape[contigdim] = (newshape[contigdim] * orig_bytes) // view_bytes

newarr, _ = sdfg.add_view(arr,
newshape,
dtype,
storage=desc.storage,
strides=newstrides,
allow_conflicts=desc.allow_conflicts,
total_size=desc.total_size * bytemult,
total_size=(desc.total_size * orig_bytes) // view_bytes,
may_alias=desc.may_alias,
alignment=desc.alignment,
find_new_name=True)
Expand Down
10 changes: 5 additions & 5 deletions dace/libraries/blas/nodes/gemm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
from copy import deepcopy as dc
from dace import dtypes, memlet as mm, properties, data as dt
from dace.symbolic import symstr, equal
from dace.symbolic import symstr, equal, equal_valued
import dace.library
from dace import SDFG, SDFGState
from dace.frontend.common import op_repository as oprepo
Expand Down Expand Up @@ -81,12 +81,12 @@ def make_sdfg(node, parent_state, parent_sdfg):
_, array_b = sdfg.add_array("_b", shape_b, dtype_b, strides=strides_b, storage=outer_array_b.storage)
_, array_c = sdfg.add_array("_c", shape_c, dtype_c, strides=cdata[-1], storage=cdata[1].storage)

if node.alpha == 1.0:
if equal_valued(1, node.alpha):
mul_program = "__out = __a * __b"
else:
mul_program = "__out = {} * __a * __b".format(_cast_to_dtype_str(node.alpha, dtype_a))

if node.beta == 1:
if equal_valued(1, node.beta):
state = sdfg.add_state(node.label + "_state")
else:
init_state = sdfg.add_state(node.label + "_initstate")
Expand All @@ -99,13 +99,13 @@ def make_sdfg(node, parent_state, parent_sdfg):
output_nodes = None

# Initialization / beta map
if node.beta == 0:
if equal_valued(0, node.beta):
init_state.add_mapped_tasklet(
'gemm_init', {'_o%d' % i: '0:%s' % symstr(d)
for i, d in enumerate(shape_c)}, {},
'out = 0', {'out': dace.Memlet.simple(mul_out, ','.join(['_o%d' % i for i in range(len(shape_c))]))},
external_edges=True)
elif node.beta == 1:
elif equal_valued(1, node.beta):
# Do nothing for initialization, only update the values
pass
else:
Expand Down
16 changes: 15 additions & 1 deletion dace/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
import pickle
import re
from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple, Union
import warnings
import numpy

import sympy.abc
import sympy.printing.str

import packaging.version as packaging_version

from dace import dtypes

DEFAULT_SYMBOL_TYPE = dtypes.int32
Expand All @@ -23,6 +24,19 @@
_sympy_clash = {k: v if v else getattr(sympy.abc, k) for k, v in sympy.abc._clash.items()}


# SymPy 1.13 changes the behavior of `==` such that floats with different precisions
# are always different.
# For DaCe, mostly the comparison of value (ignoring precision) is relevant which
# can be done with `equal_valued`. However, `equal_valued` was only introduced in
# SymPy 1.12, so we fall back to `==` in that case (which ignores precision in those versions).
# For convenience, we provide this functionality in our own SymPy layer.
if packaging_version.Version(sympy.__version__) < packaging_version.Version("1.12"):
def equal_valued(x, y):
return x == y
else:
equal_valued = sympy.core.numbers.equal_valued


class symbol(sympy.Symbol):
""" Defines a symbolic expression. Extends SymPy symbols with DaCe-related
information. """
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
install_requires=[
'numpy < 2.0', 'networkx >= 2.5', 'astunparse', 'sympy >= 1.9', 'pyyaml', 'ply', 'websockets', 'jinja2',
'fparser >= 0.1.3', 'aenum >= 3.1', 'dataclasses; python_version < "3.7"', 'dill',
'pyreadline;platform_system=="Windows"', 'typing-compat; python_version < "3.8"'
'pyreadline;platform_system=="Windows"', 'typing-compat; python_version < "3.8"', 'packaging'
] + cmake_requires,
extras_require={
'testing': ['coverage', 'pytest-cov', 'scipy', 'absl-py', 'opt_einsum', 'pymlir', 'click'],
Expand Down
5 changes: 3 additions & 2 deletions tests/numpy/einsum_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
import pytest
import dace
from dace import symbolic
import numpy as np

M = dace.symbol('M')
Expand Down Expand Up @@ -261,8 +262,8 @@ def tester(A, B):
for node, _ in sdfg.all_nodes_recursive():
if isinstance(node, Einsum):
assert node.einsum_str == 'ij,jk->ik'
assert node.alpha == 1.0
assert node.beta == 1.0
assert symbolic.equal_valued(1, node.alpha)
assert symbolic.equal_valued(1, node.beta)

assert np.allclose(sdfg(A, B), C)

Expand Down
29 changes: 23 additions & 6 deletions tests/numpy/reshape_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_reshape_subset_explicit():
assert np.allclose(expected, B)


def test_reinterpret():
def test_reinterpret_smaller():
@dace.program
def reint(A: dace.int32[N]):
C = A.view(dace.int16)
Expand All @@ -161,18 +161,34 @@ def reint(A: dace.int32[N]):
assert np.allclose(expected, A)


def test_reinterpret_larger():
@dace.program
def reint(A: dace.int16[N]):
C = A.view(dace.int32)
C[:] += 1

A = np.random.randint(0, 32767, size=[10], dtype=np.int16)
expected = np.copy(A)
B = expected.view(np.int32)
B[:] += 1

reint(A)
assert np.allclose(expected, A)


def test_reinterpret_invalid():
@dace.program
def reint_invalid(A: dace.float32[5]):
C = A.view(dace.float64)
C[:] += 1

A = np.random.rand(5).astype(np.float32)
try:
with pytest.raises(
ValueError,
match="When changing to a larger dtype, its size must be a divisor of the total size "
"in bytes of the last axis of the array."
):
reint_invalid(A)
raise AssertionError('Program should not be compilable')
except ValueError:
pass


if __name__ == "__main__":
Expand All @@ -184,5 +200,6 @@ def reint_invalid(A: dace.float32[5]):
test_reshape_copy_scoped()
test_reshape_subset()
test_reshape_subset_explicit()
test_reinterpret()
test_reinterpret_smaller()
test_reinterpret_larger()
test_reinterpret_invalid()

0 comments on commit 5a773ea

Please sign in to comment.