Skip to content

Commit

Permalink
More helpful error when UDF arg type check failed (deephaven#5175)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmao-denver authored Feb 26, 2024
1 parent e29e7c6 commit 759c0f0
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 39 deletions.
77 changes: 43 additions & 34 deletions py/server/deephaven/_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@


@dataclass
class _ParsedParamAnnotation:
class _ParsedParam:
name: Union[str, int] = field(init=True)
orig_types: set[type] = field(default_factory=set)
encoded_types: set[str] = field(default_factory=set)
none_allowed: bool = False
Expand All @@ -54,7 +55,7 @@ class _ParsedReturnAnnotation:
@dataclass
class _ParsedSignature:
fn: Callable = None
params: List[_ParsedParamAnnotation] = field(default_factory=list)
params: List[_ParsedParam] = field(default_factory=list)
ret_annotation: _ParsedReturnAnnotation = None

@property
Expand Down Expand Up @@ -93,22 +94,22 @@ def _encode_param_type(t: type) -> str:
return tc


def _parse_param_annotation(annotation: Any) -> _ParsedParamAnnotation:
def _parse_param(name: str, annotation: Any) -> _ParsedParam:
""" Parse a parameter annotation in a function's signature """
p_annotation = _ParsedParamAnnotation()
p_param = _ParsedParam(name)

if annotation is inspect._empty:
p_annotation.encoded_types.add("O")
p_annotation.none_allowed = True
p_param.encoded_types.add("O")
p_param.none_allowed = True
elif isinstance(annotation, _GenericAlias) and annotation.__origin__ == Union:
for t in annotation.__args__:
_parse_type_no_nested(annotation, p_annotation, t)
_parse_type_no_nested(annotation, p_param, t)
else:
_parse_type_no_nested(annotation, p_annotation, annotation)
return p_annotation
_parse_type_no_nested(annotation, p_param, annotation)
return p_param


def _parse_type_no_nested(annotation: Any, p_annotation: _ParsedParamAnnotation, t: Union[type, str]) -> None:
def _parse_type_no_nested(annotation: Any, p_param: _ParsedParam, t: Union[type, str]) -> None:
""" Parse a specific type (top level or nested in a top-level Union annotation) without handling nested types
(e.g. a nested Union). The result is stored in the given _ParsedAnnotation object.
"""
Expand All @@ -117,25 +118,25 @@ def _parse_type_no_nested(annotation: Any, p_annotation: _ParsedParamAnnotation,
# annotation is already a type, and we can remove this line.
t = eval(t) if isinstance(t, str) else t

p_annotation.orig_types.add(t)
p_param.orig_types.add(t)
tc = _encode_param_type(t)
if "[" in tc:
p_annotation.has_array = True
p_param.has_array = True
if tc in {"N", "O"}:
p_annotation.none_allowed = True
p_param.none_allowed = True
if tc in _NUMPY_INT_TYPE_CODES:
if p_annotation.int_char and p_annotation.int_char != tc:
if p_param.int_char and p_param.int_char != tc:
raise DHError(message=f"multiple integer types in annotation: {annotation}, "
f"types: {p_annotation.int_char}, {tc}. this is not supported because it is not "
f"types: {p_param.int_char}, {tc}. this is not supported because it is not "
f"clear which Deephaven null value to use when checking for nulls in the argument")
p_annotation.int_char = tc
p_param.int_char = tc
if tc in _NUMPY_FLOATING_TYPE_CODES:
if p_annotation.floating_char and p_annotation.floating_char != tc:
if p_param.floating_char and p_param.floating_char != tc:
raise DHError(message=f"multiple floating types in annotation: {annotation}, "
f"types: {p_annotation.floating_char}, {tc}. this is not supported because it is not "
f"types: {p_param.floating_char}, {tc}. this is not supported because it is not "
f"clear which Deephaven null value to use when checking for nulls in the argument")
p_annotation.floating_char = tc
p_annotation.encoded_types.add(tc)
p_param.floating_char = tc
p_param.encoded_types.add(tc)


def _parse_return_annotation(annotation: Any) -> _ParsedReturnAnnotation:
Expand Down Expand Up @@ -182,8 +183,8 @@ def _parse_numba_signature(fn: Union[numba.np.ufunc.gufunc.GUFunc, numba.np.ufun
p_sig.ret_annotation.encoded_type = rt_char

if isinstance(fn, numba.np.ufunc.dufunc.DUFunc):
for p in params:
pa = _ParsedParamAnnotation()
for i, p in enumerate(params):
pa = _ParsedParam(i + 1)
pa.encoded_types.add(p)
if p in _NUMPY_INT_TYPE_CODES:
pa.int_char = p
Expand All @@ -198,8 +199,8 @@ def _parse_numba_signature(fn: Union[numba.np.ufunc.gufunc.GUFunc, numba.np.ufun
input_decl = re.sub("[()]", "", input_decl).split(",")
output_decl = re.sub("[()]", "", output_decl)

for p, d in zip(params, input_decl):
pa = _ParsedParamAnnotation()
for i, (p, d) in enumerate(zip(params, input_decl)):
pa = _ParsedParam(i + 1)
if d:
pa.encoded_types.add("[" + p)
pa.has_array = True
Expand All @@ -225,9 +226,10 @@ def _parse_np_ufunc_signature(fn: numpy.ufunc) -> _ParsedSignature:
# them in the future (https://github.com/deephaven/deephaven-core/issues/4762)
p_sig = _ParsedSignature(fn)
if fn.nin > 0:
pa = _ParsedParamAnnotation()
pa.encoded_types.add("O")
p_sig.params = [pa] * fn.nin
for i in range(fn.nin):
pa = _ParsedParam(i + 1)
pa.encoded_types.add("O")
p_sig.params.append(pa)
p_sig.ret_annotation = _ParsedReturnAnnotation()
p_sig.ret_annotation.encoded_type = "O"
return p_sig
Expand All @@ -249,7 +251,7 @@ def _parse_signature(fn: Callable) -> _ParsedSignature:
else:
sig = inspect.signature(fn)
for n, p in sig.parameters.items():
p_sig.params.append(_parse_param_annotation(p.annotation))
p_sig.params.append(_parse_param(n, p.annotation))

p_sig.ret_annotation = _parse_return_annotation(sig.return_annotation)
return p_sig
Expand All @@ -263,11 +265,11 @@ def _is_from_np_type(param_types: set[type], np_type_char: str) -> bool:
return False


def _convert_arg(param: _ParsedParamAnnotation, arg: Any) -> Any:
def _convert_arg(param: _ParsedParam, arg: Any) -> Any:
""" Convert a single argument to the type specified by the annotation """
if arg is None:
if not param.none_allowed:
raise TypeError(f"Argument {arg} is not compatible with annotation {param.orig_types}")
raise TypeError(f"Argument {param.name!r}: {arg} is not compatible with annotation {param.orig_types}")
else:
return None

Expand All @@ -277,12 +279,17 @@ def _convert_arg(param: _ParsedParamAnnotation, arg: Any) -> Any:
# if it matches one of the encoded types, convert it
if encoded_type in param.encoded_types:
dtype = dtypes.from_np_dtype(np_dtype)
return _j_array_to_numpy_array(dtype, arg, conv_null=True, type_promotion=False)
try:
return _j_array_to_numpy_array(dtype, arg, conv_null=True, type_promotion=False)
except Exception as e:
raise TypeError(f"Argument {param.name!r}: {arg} is not compatible with annotation"
f" {param.encoded_types}"
f"\n{str(e)}") from e
# if the annotation is missing, or it is a generic object type, return the arg
elif "O" in param.encoded_types:
return arg
else:
raise TypeError(f"Argument {arg} is not compatible with annotation {param.encoded_types}")
raise TypeError(f"Argument {param.name!r}: {arg} is not compatible with annotation {param.encoded_types}")
else: # if the arg is not a Java array
specific_types = param.encoded_types - {"N", "O"} # remove NoneType and object type
if specific_types:
Expand All @@ -300,7 +307,8 @@ def _convert_arg(param: _ParsedParamAnnotation, arg: Any) -> Any:
if param.none_allowed:
return None
else:
raise DHError(f"Argument {arg} is not compatible with annotation {param.orig_types}")
raise DHError(f"Argument {param.name!r}: {arg} is not compatible with annotation"
f" {param.orig_types}")
else:
# return a numpy integer instance only if the annotation is a numpy type
if _is_from_np_type(param.orig_types, param.int_char):
Expand Down Expand Up @@ -332,7 +340,8 @@ def _convert_arg(param: _ParsedParamAnnotation, arg: Any) -> Any:
if "O" in param.encoded_types:
return arg
else:
raise TypeError(f"Argument {arg} is not compatible with annotation {param.orig_types}")
raise TypeError(f"Argument {param.name!r}: {arg} is not compatible with annotation"
f" {param.orig_types}")
else: # if no annotation or generic object, return arg
return arg

Expand Down
14 changes: 13 additions & 1 deletion py/server/tests/test_numba_guvectorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from numba import guvectorize, int64, int32

from deephaven import empty_table, dtypes
from deephaven import empty_table, dtypes, DHError
from tests.testbase import BaseTestCase

a = np.arange(5, dtype=np.int64)
Expand Down Expand Up @@ -89,6 +89,18 @@ def g(x, res):
t = empty_table(10).update(["X=i%3", "Y=ii"]).group_by("X").update("Z=g(Y)")
self.assertEqual(t.columns[2].data_type, dtypes.long_array)

def test_type_mismatch_error(self):
# vector input to scalar output function (m)->()
@guvectorize([(int64[:], int64[:])], "(m)->()", nopython=True)
def g(x, res):
res[0] = 0
for xi in x:
res[0] += xi

with self.assertRaises(DHError) as cm:
t = empty_table(10).update(["X=i%3", "Y=(double)ii"]).group_by("X").update("Z=g(Y)")
self.assertIn("Argument 1", str(cm.exception))


if __name__ == '__main__':
unittest.main()
8 changes: 4 additions & 4 deletions py/server/tests/test_udf_numpy_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def f3(p1: np.ndarray[np.bool_], p2=None) -> bool:
self.assertEqual(t1.columns[2].data_type, dtypes.bool_)
with self.assertRaises(DHError) as cm:
t2 = t.update(["X1 = f3(null, Y )"])
self.assertRegex(str(cm.exception), "Argument None is not compatible with annotation")
self.assertRegex(str(cm.exception), "Argument 'p1': None is not compatible with annotation")

def f31(p1: Optional[np.ndarray[bool]], p2=None) -> bool:
return bool(len(p1)) if p1 is not None else False
Expand All @@ -352,7 +352,7 @@ def f1(p1: str, p2=None) -> bool:
t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? `deephaven`: null"])
with self.assertRaises(DHError) as cm:
t1 = t.update(["X1 = f1(Y)"])
self.assertRegex(str(cm.exception), "Argument None is not compatible with annotation")
self.assertRegex(str(cm.exception), "Argument 'p1': None is not compatible with annotation")

def f11(p1: Union[str, None], p2=None) -> bool:
return p1 is None
Expand All @@ -366,7 +366,7 @@ def f2(p1: np.datetime64, p2=None) -> bool:
t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? now() : null"])
with self.assertRaises(DHError) as cm:
t1 = t.update(["X1 = f2(Y)"])
self.assertRegex(str(cm.exception), "Argument None is not compatible with annotation")
self.assertRegex(str(cm.exception), "Argument 'p1': None is not compatible with annotation")

def f21(p1: Union[np.datetime64, None], p2=None) -> bool:
return p1 is None
Expand All @@ -380,7 +380,7 @@ def f3(p1: np.bool_, p2=None) -> bool:
t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? true : null"])
with self.assertRaises(DHError) as cm:
t1 = t.update(["X1 = f3(Y)"])
self.assertRegex(str(cm.exception), "Argument None is not compatible with annotation")
self.assertRegex(str(cm.exception), "Argument 'p1': None is not compatible with annotation")

t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? true : false"])
t1 = t.update(["X1 = f3(Y)"])
Expand Down

0 comments on commit 759c0f0

Please sign in to comment.