From 65ca11f38426b82377bbeda5075780adcf59401c Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 23 Dec 2024 01:02:24 -0800 Subject: [PATCH] Ignore type hints in undecorated parsed function calls (#1833) Fixes cases in which the user has no control over undecorated nested function calls, which previously led to unclear and unrelated `pyobject` type errors. --- dace/frontend/python/parser.py | 8 ++- dace/frontend/python/preprocessing.py | 2 +- .../python_frontend/callee_autodetect_test.py | 59 +++++++++++++++++++ 3 files changed, 66 insertions(+), 3 deletions(-) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 0faa2e36ce..57422e372a 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -156,7 +156,8 @@ def __init__(self, recompile: bool = True, distributed_compilation: bool = False, method: bool = False, - use_explicit_cf: bool = True): + use_explicit_cf: bool = True, + ignore_type_hints: bool = False): from dace.codegen import compiled_sdfg # Avoid import loops self.f = f @@ -178,6 +179,7 @@ def __init__(self, self.recompile = recompile self.use_explicit_cf = use_explicit_cf self.distributed_compilation = distributed_compilation + self.ignore_type_hints = ignore_type_hints self.global_vars = _get_locals_and_globals(f) self.signature = inspect.signature(f) @@ -557,6 +559,8 @@ def _get_type_annotations( continue ann = sig_arg.annotation + if self.ignore_type_hints: + ann = inspect._empty # Variable-length arguments: obtain from the remainder of given_* if sig_arg.kind is sig_arg.VAR_POSITIONAL: @@ -699,7 +703,7 @@ def _get_type_annotations( # Set __return* arrays from return type annotations rettype = self.signature.return_annotation - if not _is_empty(rettype): + if not self.ignore_type_hints and not _is_empty(rettype): if isinstance(rettype, tuple): for i, subrettype in enumerate(rettype): types[f'__return_{i}'] = create_datadescriptor(subrettype) diff --git a/dace/frontend/python/preprocessing.py b/dace/frontend/python/preprocessing.py index f51b67ddb2..5e6d962605 100644 --- a/dace/frontend/python/preprocessing.py +++ b/dace/frontend/python/preprocessing.py @@ -625,7 +625,7 @@ def global_value_to_node(self, if find_disallowed_statements(sast): return newnode - parsed = parser.DaceProgram(value, [], {}, False, dtypes.DeviceType.CPU) + parsed = parser.DaceProgram(value, [], {}, False, dtypes.DeviceType.CPU, ignore_type_hints=True) # If method, add the first argument (which disappears due to # being a bound method) and the method's object if parent_object is not None: diff --git a/tests/python_frontend/callee_autodetect_test.py b/tests/python_frontend/callee_autodetect_test.py index e64e1f5a69..6fb786982a 100644 --- a/tests/python_frontend/callee_autodetect_test.py +++ b/tests/python_frontend/callee_autodetect_test.py @@ -8,6 +8,7 @@ from dataclasses import dataclass import numpy as np import pytest +from typing import List, Tuple @dataclass @@ -276,6 +277,61 @@ def program(A): assert np.allclose(A, expected) +def test_type_hints_in_nested_call(): + """ + Tests that type hints are correctly propagated to nested functions, ignoring + existing type hints if the nested function is not decorated. + """ + + def nested(a: int, b: List[float], c) -> Tuple[float, float]: + return np.sum(b) + a, c + + @dace + def outer(a: dace.float64[20], result: dace.float64[2]): + ret1, ret2 = nested(5, a, 3.0) + result[0] = ret1 + result[1] = ret2 + + A = np.random.rand(20) + res = np.zeros(2) + ref = np.copy(res) + ref[0] = np.sum(A) + 5 + ref[1] = 3.0 + outer(A, res) + assert np.allclose(res, ref) + + +@pytest.mark.parametrize('decorated', (False, True)) +def test_explicit_type_hints_in_nested_call(decorated): + """ + Tests that type hints are not ignored if the nested function is decorated. + """ + + if decorated: + + @dace + def nested(a: dace.float64[20], b: dace.float64[16]): + b += a + else: + # This function is not decorated, so the type hints should be ignored + def nested(a: dace.float64[20], b: dace.float64[16]): + b += a + + @dace + def outer(a: dace.float64[20]): + nested(a, a) + + A = np.random.rand(20) + a_ref = A * 2 + + if decorated: + with pytest.raises(SyntaxError): + outer(A) + else: + outer(A) + assert np.allclose(A, a_ref) + + if __name__ == '__main__': test_autodetect_function() test_autodetect_method() @@ -291,3 +347,6 @@ def program(A): test_error_handling() test_nested_class_error_handling() test_loop_unrolling() + test_type_hints_in_nested_call() + test_explicit_type_hints_in_nested_call(False) + test_explicit_type_hints_in_nested_call(True)