Skip to content

Commit

Permalink
Ignore type hints in undecorated parsed function calls (#1833)
Browse files Browse the repository at this point in the history
Fixes cases in which the user has no control over undecorated nested
function calls, which previously led to unclear and unrelated `pyobject`
type errors.
  • Loading branch information
tbennun authored Dec 23, 2024
1 parent aaba591 commit 65ca11f
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 3 deletions.
8 changes: 6 additions & 2 deletions dace/frontend/python/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ def __init__(self,
recompile: bool = True,
distributed_compilation: bool = False,
method: bool = False,
use_explicit_cf: bool = True):
use_explicit_cf: bool = True,
ignore_type_hints: bool = False):
from dace.codegen import compiled_sdfg # Avoid import loops

self.f = f
Expand All @@ -178,6 +179,7 @@ def __init__(self,
self.recompile = recompile
self.use_explicit_cf = use_explicit_cf
self.distributed_compilation = distributed_compilation
self.ignore_type_hints = ignore_type_hints

self.global_vars = _get_locals_and_globals(f)
self.signature = inspect.signature(f)
Expand Down Expand Up @@ -557,6 +559,8 @@ def _get_type_annotations(
continue

ann = sig_arg.annotation
if self.ignore_type_hints:
ann = inspect._empty

# Variable-length arguments: obtain from the remainder of given_*
if sig_arg.kind is sig_arg.VAR_POSITIONAL:
Expand Down Expand Up @@ -699,7 +703,7 @@ def _get_type_annotations(

# Set __return* arrays from return type annotations
rettype = self.signature.return_annotation
if not _is_empty(rettype):
if not self.ignore_type_hints and not _is_empty(rettype):
if isinstance(rettype, tuple):
for i, subrettype in enumerate(rettype):
types[f'__return_{i}'] = create_datadescriptor(subrettype)
Expand Down
2 changes: 1 addition & 1 deletion dace/frontend/python/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ def global_value_to_node(self,
if find_disallowed_statements(sast):
return newnode

parsed = parser.DaceProgram(value, [], {}, False, dtypes.DeviceType.CPU)
parsed = parser.DaceProgram(value, [], {}, False, dtypes.DeviceType.CPU, ignore_type_hints=True)
# If method, add the first argument (which disappears due to
# being a bound method) and the method's object
if parent_object is not None:
Expand Down
59 changes: 59 additions & 0 deletions tests/python_frontend/callee_autodetect_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dataclasses import dataclass
import numpy as np
import pytest
from typing import List, Tuple


@dataclass
Expand Down Expand Up @@ -276,6 +277,61 @@ def program(A):
assert np.allclose(A, expected)


def test_type_hints_in_nested_call():
"""
Tests that type hints are correctly propagated to nested functions, ignoring
existing type hints if the nested function is not decorated.
"""

def nested(a: int, b: List[float], c) -> Tuple[float, float]:
return np.sum(b) + a, c

@dace
def outer(a: dace.float64[20], result: dace.float64[2]):
ret1, ret2 = nested(5, a, 3.0)
result[0] = ret1
result[1] = ret2

A = np.random.rand(20)
res = np.zeros(2)
ref = np.copy(res)
ref[0] = np.sum(A) + 5
ref[1] = 3.0
outer(A, res)
assert np.allclose(res, ref)


@pytest.mark.parametrize('decorated', (False, True))
def test_explicit_type_hints_in_nested_call(decorated):
"""
Tests that type hints are not ignored if the nested function is decorated.
"""

if decorated:

@dace
def nested(a: dace.float64[20], b: dace.float64[16]):
b += a
else:
# This function is not decorated, so the type hints should be ignored
def nested(a: dace.float64[20], b: dace.float64[16]):
b += a

@dace
def outer(a: dace.float64[20]):
nested(a, a)

A = np.random.rand(20)
a_ref = A * 2

if decorated:
with pytest.raises(SyntaxError):
outer(A)
else:
outer(A)
assert np.allclose(A, a_ref)


if __name__ == '__main__':
test_autodetect_function()
test_autodetect_method()
Expand All @@ -291,3 +347,6 @@ def program(A):
test_error_handling()
test_nested_class_error_handling()
test_loop_unrolling()
test_type_hints_in_nested_call()
test_explicit_type_hints_in_nested_call(False)
test_explicit_type_hints_in_nested_call(True)

0 comments on commit 65ca11f

Please sign in to comment.