diff --git a/dace/codegen/tools/type_inference.py b/dace/codegen/tools/type_inference.py index f159088461..893866522f 100644 --- a/dace/codegen/tools/type_inference.py +++ b/dace/codegen/tools/type_inference.py @@ -60,7 +60,7 @@ def infer_expr_type(code, symbols=None): if isinstance(code, (str, float, int, complex)): parsed_ast = ast.parse(str(code)) elif isinstance(code, sympy.Basic): - parsed_ast = ast.parse(sympy.printing.pycode(code)) + parsed_ast = ast.parse(sympy.printing.pycode(code, allow_unknown_functions=True)) elif isinstance(code, SymExpr): parsed_ast = ast.parse(sympy.printing.pycode(code.expr)) else: diff --git a/dace/frontend/common/einsum.py b/dace/frontend/common/einsum.py index f678cdea58..18e40d57f0 100644 --- a/dace/frontend/common/einsum.py +++ b/dace/frontend/common/einsum.py @@ -275,7 +275,9 @@ def _create_einsum_internal(sdfg: SDFG, if not is_conflicted and init_output is None: to_init = False - if einsum.is_reduce() and alpha == 1 and (beta == 0 or beta == 1): + if einsum.is_reduce() and symbolic.equal_valued(1, alpha) and ( + symbolic.equal_valued(0, beta) or symbolic.equal_valued(1, beta) + ): from dace.libraries.standard.nodes.reduce import Reduce # Get reduce axes axes = tuple(i for i, s in enumerate(einsum.inputs[0]) if s not in einsum.output) diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index 8c123f6bfe..ce35d7c9a1 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -568,7 +568,7 @@ def _arange(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args, **kwargs): if any(not isinstance(s, Number) for s in [start, stop, step]): shape = (symbolic.int_ceil(stop - start, step), ) else: - shape = (np.ceil((stop - start) / step), ) + shape = (np.int64(np.ceil((stop - start) / step)), ) if not isinstance(shape[0], Number) and ('dtype' not in kwargs or kwargs['dtype'] == None): raise NotImplementedError("The current implementation of numpy.arange requires that the output dtype is given " @@ -583,7 +583,12 @@ def _arange(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args, **kwargs): dtype = dtypes.dtype_to_typeclass(dtype) outname, outarr = sdfg.add_temp_transient(shape, dtype) else: - dtype = dtypes.dtype_to_typeclass(type(shape[0])) + # infer dtype based on args's dtype + # (since the `dtype` keyword argument isn't given, none of the arguments can be symbolic) + if any(isinstance(arg, (float, np.float32, np.float64)) for arg in args): + dtype = dtypes.float64 + else: + dtype = dtypes.int64 outname, outarr = sdfg.add_temp_transient(shape, dtype) state.add_mapped_tasklet(name="_numpy_arange_", @@ -4143,22 +4148,34 @@ def view(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, dtype, type desc = sdfg.arrays[arr] - # Change size of array based on the differences in bytes - bytemult = desc.dtype.bytes / dtype.bytes - bytediv = dtype.bytes / desc.dtype.bytes + orig_bytes = desc.dtype.bytes + view_bytes = dtype.bytes + + if view_bytes < orig_bytes and orig_bytes % view_bytes != 0: + raise ValueError("When changing to a smaller dtype, its size must be a divisor of " + "the size of original dtype") + contigdim = next(i for i, s in enumerate(desc.strides) if s == 1) # For cases that can be recognized, if contiguous dimension is too small # raise an exception similar to numpy - if (not issymbolic(desc.shape[contigdim], sdfg.constants) and bytemult < 1 - and desc.shape[contigdim] % bytediv != 0): + if (not issymbolic(desc.shape[contigdim], sdfg.constants) and orig_bytes < view_bytes + and desc.shape[contigdim] * orig_bytes % view_bytes != 0): raise ValueError('When changing to a larger dtype, its size must be a divisor of ' 'the total size in bytes of the last axis of the array.') # Create new shape and strides for view + # NOTE: we change sizes by using `(old_size * orig_bytes) // view_bytes` + # Thus, the changed size will be an integer due to integer division. + # If the division created a fraction, the view wouldn't be valid in the first place. + # So, we assume the division will always yield an integer, and, hence, + # the integer division is correct. + # Also, keep in mind that `old_size * (orig_bytes // view_bytes)` is different. + # E.g., if `orig_bytes == 1 and view_bytes == 2`: `old_size * (1 // 2) == old_size * 0`. newshape = list(desc.shape) - newstrides = [s * bytemult if i != contigdim else s for i, s in enumerate(desc.strides)] - newshape[contigdim] *= bytemult + newstrides = [(s * orig_bytes) // view_bytes if i != contigdim else s for i, s in enumerate(desc.strides)] + # don't use `*=`, because it will break the bracket + newshape[contigdim] = (newshape[contigdim] * orig_bytes) // view_bytes newarr, _ = sdfg.add_view(arr, newshape, @@ -4166,7 +4183,7 @@ def view(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, dtype, type storage=desc.storage, strides=newstrides, allow_conflicts=desc.allow_conflicts, - total_size=desc.total_size * bytemult, + total_size=(desc.total_size * orig_bytes) // view_bytes, may_alias=desc.may_alias, alignment=desc.alignment, find_new_name=True) diff --git a/dace/libraries/blas/nodes/gemm.py b/dace/libraries/blas/nodes/gemm.py index d78e54eb6e..1f11c5dc17 100644 --- a/dace/libraries/blas/nodes/gemm.py +++ b/dace/libraries/blas/nodes/gemm.py @@ -1,7 +1,7 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. from copy import deepcopy as dc from dace import dtypes, memlet as mm, properties, data as dt -from dace.symbolic import symstr, equal +from dace.symbolic import symstr, equal, equal_valued import dace.library from dace import SDFG, SDFGState from dace.frontend.common import op_repository as oprepo @@ -81,12 +81,12 @@ def make_sdfg(node, parent_state, parent_sdfg): _, array_b = sdfg.add_array("_b", shape_b, dtype_b, strides=strides_b, storage=outer_array_b.storage) _, array_c = sdfg.add_array("_c", shape_c, dtype_c, strides=cdata[-1], storage=cdata[1].storage) - if node.alpha == 1.0: + if equal_valued(1, node.alpha): mul_program = "__out = __a * __b" else: mul_program = "__out = {} * __a * __b".format(_cast_to_dtype_str(node.alpha, dtype_a)) - if node.beta == 1: + if equal_valued(1, node.beta): state = sdfg.add_state(node.label + "_state") else: init_state = sdfg.add_state(node.label + "_initstate") @@ -99,13 +99,13 @@ def make_sdfg(node, parent_state, parent_sdfg): output_nodes = None # Initialization / beta map - if node.beta == 0: + if equal_valued(0, node.beta): init_state.add_mapped_tasklet( 'gemm_init', {'_o%d' % i: '0:%s' % symstr(d) for i, d in enumerate(shape_c)}, {}, 'out = 0', {'out': dace.Memlet.simple(mul_out, ','.join(['_o%d' % i for i in range(len(shape_c))]))}, external_edges=True) - elif node.beta == 1: + elif equal_valued(1, node.beta): # Do nothing for initialization, only update the values pass else: diff --git a/dace/symbolic.py b/dace/symbolic.py index 7fefade69b..6218bbe715 100644 --- a/dace/symbolic.py +++ b/dace/symbolic.py @@ -5,12 +5,13 @@ import pickle import re from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple, Union -import warnings import numpy import sympy.abc import sympy.printing.str +import packaging.version as packaging_version + from dace import dtypes DEFAULT_SYMBOL_TYPE = dtypes.int32 @@ -23,6 +24,19 @@ _sympy_clash = {k: v if v else getattr(sympy.abc, k) for k, v in sympy.abc._clash.items()} +# SymPy 1.13 changes the behavior of `==` such that floats with different precisions +# are always different. +# For DaCe, mostly the comparison of value (ignoring precision) is relevant which +# can be done with `equal_valued`. However, `equal_valued` was only introduced in +# SymPy 1.12, so we fall back to `==` in that case (which ignores precision in those versions). +# For convenience, we provide this functionality in our own SymPy layer. +if packaging_version.Version(sympy.__version__) < packaging_version.Version("1.12"): + def equal_valued(x, y): + return x == y +else: + equal_valued = sympy.core.numbers.equal_valued + + class symbol(sympy.Symbol): """ Defines a symbolic expression. Extends SymPy symbols with DaCe-related information. """ diff --git a/setup.py b/setup.py index d385abb9e1..614d168c41 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,7 @@ install_requires=[ 'numpy < 2.0', 'networkx >= 2.5', 'astunparse', 'sympy >= 1.9', 'pyyaml', 'ply', 'websockets', 'jinja2', 'fparser >= 0.1.3', 'aenum >= 3.1', 'dataclasses; python_version < "3.7"', 'dill', - 'pyreadline;platform_system=="Windows"', 'typing-compat; python_version < "3.8"' + 'pyreadline;platform_system=="Windows"', 'typing-compat; python_version < "3.8"', 'packaging' ] + cmake_requires, extras_require={ 'testing': ['coverage', 'pytest-cov', 'scipy', 'absl-py', 'opt_einsum', 'pymlir', 'click'], diff --git a/tests/numpy/einsum_test.py b/tests/numpy/einsum_test.py index 2128d26565..89ab253fd2 100644 --- a/tests/numpy/einsum_test.py +++ b/tests/numpy/einsum_test.py @@ -1,6 +1,7 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. import pytest import dace +from dace import symbolic import numpy as np M = dace.symbol('M') @@ -261,8 +262,8 @@ def tester(A, B): for node, _ in sdfg.all_nodes_recursive(): if isinstance(node, Einsum): assert node.einsum_str == 'ij,jk->ik' - assert node.alpha == 1.0 - assert node.beta == 1.0 + assert symbolic.equal_valued(1, node.alpha) + assert symbolic.equal_valued(1, node.beta) assert np.allclose(sdfg(A, B), C) diff --git a/tests/numpy/reshape_test.py b/tests/numpy/reshape_test.py index f93e38c0fd..5e880f7cf7 100644 --- a/tests/numpy/reshape_test.py +++ b/tests/numpy/reshape_test.py @@ -146,7 +146,7 @@ def test_reshape_subset_explicit(): assert np.allclose(expected, B) -def test_reinterpret(): +def test_reinterpret_smaller(): @dace.program def reint(A: dace.int32[N]): C = A.view(dace.int16) @@ -161,6 +161,21 @@ def reint(A: dace.int32[N]): assert np.allclose(expected, A) +def test_reinterpret_larger(): + @dace.program + def reint(A: dace.int16[N]): + C = A.view(dace.int32) + C[:] += 1 + + A = np.random.randint(0, 32767, size=[10], dtype=np.int16) + expected = np.copy(A) + B = expected.view(np.int32) + B[:] += 1 + + reint(A) + assert np.allclose(expected, A) + + def test_reinterpret_invalid(): @dace.program def reint_invalid(A: dace.float32[5]): @@ -168,11 +183,12 @@ def reint_invalid(A: dace.float32[5]): C[:] += 1 A = np.random.rand(5).astype(np.float32) - try: + with pytest.raises( + ValueError, + match="When changing to a larger dtype, its size must be a divisor of the total size " + "in bytes of the last axis of the array." + ): reint_invalid(A) - raise AssertionError('Program should not be compilable') - except ValueError: - pass if __name__ == "__main__": @@ -184,5 +200,6 @@ def reint_invalid(A: dace.float32[5]): test_reshape_copy_scoped() test_reshape_subset() test_reshape_subset_explicit() - test_reinterpret() + test_reinterpret_smaller() + test_reinterpret_larger() test_reinterpret_invalid()