diff --git a/py/server/deephaven/_udf.py b/py/server/deephaven/_udf.py index a980a8631b5..f3ad8ba10db 100644 --- a/py/server/deephaven/_udf.py +++ b/py/server/deephaven/_udf.py @@ -34,7 +34,8 @@ @dataclass -class _ParsedParamAnnotation: +class _ParsedParam: + name: Union[str, int] = field(init=True) orig_types: set[type] = field(default_factory=set) encoded_types: set[str] = field(default_factory=set) none_allowed: bool = False @@ -54,7 +55,7 @@ class _ParsedReturnAnnotation: @dataclass class _ParsedSignature: fn: Callable = None - params: List[_ParsedParamAnnotation] = field(default_factory=list) + params: List[_ParsedParam] = field(default_factory=list) ret_annotation: _ParsedReturnAnnotation = None @property @@ -93,22 +94,22 @@ def _encode_param_type(t: type) -> str: return tc -def _parse_param_annotation(annotation: Any) -> _ParsedParamAnnotation: +def _parse_param(name: str, annotation: Any) -> _ParsedParam: """ Parse a parameter annotation in a function's signature """ - p_annotation = _ParsedParamAnnotation() + p_param = _ParsedParam(name) if annotation is inspect._empty: - p_annotation.encoded_types.add("O") - p_annotation.none_allowed = True + p_param.encoded_types.add("O") + p_param.none_allowed = True elif isinstance(annotation, _GenericAlias) and annotation.__origin__ == Union: for t in annotation.__args__: - _parse_type_no_nested(annotation, p_annotation, t) + _parse_type_no_nested(annotation, p_param, t) else: - _parse_type_no_nested(annotation, p_annotation, annotation) - return p_annotation + _parse_type_no_nested(annotation, p_param, annotation) + return p_param -def _parse_type_no_nested(annotation: Any, p_annotation: _ParsedParamAnnotation, t: Union[type, str]) -> None: +def _parse_type_no_nested(annotation: Any, p_param: _ParsedParam, t: Union[type, str]) -> None: """ Parse a specific type (top level or nested in a top-level Union annotation) without handling nested types (e.g. a nested Union). The result is stored in the given _ParsedAnnotation object. """ @@ -117,25 +118,25 @@ def _parse_type_no_nested(annotation: Any, p_annotation: _ParsedParamAnnotation, # annotation is already a type, and we can remove this line. t = eval(t) if isinstance(t, str) else t - p_annotation.orig_types.add(t) + p_param.orig_types.add(t) tc = _encode_param_type(t) if "[" in tc: - p_annotation.has_array = True + p_param.has_array = True if tc in {"N", "O"}: - p_annotation.none_allowed = True + p_param.none_allowed = True if tc in _NUMPY_INT_TYPE_CODES: - if p_annotation.int_char and p_annotation.int_char != tc: + if p_param.int_char and p_param.int_char != tc: raise DHError(message=f"multiple integer types in annotation: {annotation}, " - f"types: {p_annotation.int_char}, {tc}. this is not supported because it is not " + f"types: {p_param.int_char}, {tc}. this is not supported because it is not " f"clear which Deephaven null value to use when checking for nulls in the argument") - p_annotation.int_char = tc + p_param.int_char = tc if tc in _NUMPY_FLOATING_TYPE_CODES: - if p_annotation.floating_char and p_annotation.floating_char != tc: + if p_param.floating_char and p_param.floating_char != tc: raise DHError(message=f"multiple floating types in annotation: {annotation}, " - f"types: {p_annotation.floating_char}, {tc}. this is not supported because it is not " + f"types: {p_param.floating_char}, {tc}. this is not supported because it is not " f"clear which Deephaven null value to use when checking for nulls in the argument") - p_annotation.floating_char = tc - p_annotation.encoded_types.add(tc) + p_param.floating_char = tc + p_param.encoded_types.add(tc) def _parse_return_annotation(annotation: Any) -> _ParsedReturnAnnotation: @@ -182,8 +183,8 @@ def _parse_numba_signature(fn: Union[numba.np.ufunc.gufunc.GUFunc, numba.np.ufun p_sig.ret_annotation.encoded_type = rt_char if isinstance(fn, numba.np.ufunc.dufunc.DUFunc): - for p in params: - pa = _ParsedParamAnnotation() + for i, p in enumerate(params): + pa = _ParsedParam(i + 1) pa.encoded_types.add(p) if p in _NUMPY_INT_TYPE_CODES: pa.int_char = p @@ -198,8 +199,8 @@ def _parse_numba_signature(fn: Union[numba.np.ufunc.gufunc.GUFunc, numba.np.ufun input_decl = re.sub("[()]", "", input_decl).split(",") output_decl = re.sub("[()]", "", output_decl) - for p, d in zip(params, input_decl): - pa = _ParsedParamAnnotation() + for i, (p, d) in enumerate(zip(params, input_decl)): + pa = _ParsedParam(i + 1) if d: pa.encoded_types.add("[" + p) pa.has_array = True @@ -225,9 +226,10 @@ def _parse_np_ufunc_signature(fn: numpy.ufunc) -> _ParsedSignature: # them in the future (https://github.com/deephaven/deephaven-core/issues/4762) p_sig = _ParsedSignature(fn) if fn.nin > 0: - pa = _ParsedParamAnnotation() - pa.encoded_types.add("O") - p_sig.params = [pa] * fn.nin + for i in range(fn.nin): + pa = _ParsedParam(i + 1) + pa.encoded_types.add("O") + p_sig.params.append(pa) p_sig.ret_annotation = _ParsedReturnAnnotation() p_sig.ret_annotation.encoded_type = "O" return p_sig @@ -249,7 +251,7 @@ def _parse_signature(fn: Callable) -> _ParsedSignature: else: sig = inspect.signature(fn) for n, p in sig.parameters.items(): - p_sig.params.append(_parse_param_annotation(p.annotation)) + p_sig.params.append(_parse_param(n, p.annotation)) p_sig.ret_annotation = _parse_return_annotation(sig.return_annotation) return p_sig @@ -263,11 +265,11 @@ def _is_from_np_type(param_types: set[type], np_type_char: str) -> bool: return False -def _convert_arg(param: _ParsedParamAnnotation, arg: Any) -> Any: +def _convert_arg(param: _ParsedParam, arg: Any) -> Any: """ Convert a single argument to the type specified by the annotation """ if arg is None: if not param.none_allowed: - raise TypeError(f"Argument {arg} is not compatible with annotation {param.orig_types}") + raise TypeError(f"Argument {param.name!r}: {arg} is not compatible with annotation {param.orig_types}") else: return None @@ -277,12 +279,17 @@ def _convert_arg(param: _ParsedParamAnnotation, arg: Any) -> Any: # if it matches one of the encoded types, convert it if encoded_type in param.encoded_types: dtype = dtypes.from_np_dtype(np_dtype) - return _j_array_to_numpy_array(dtype, arg, conv_null=True, type_promotion=False) + try: + return _j_array_to_numpy_array(dtype, arg, conv_null=True, type_promotion=False) + except Exception as e: + raise TypeError(f"Argument {param.name!r}: {arg} is not compatible with annotation" + f" {param.encoded_types}" + f"\n{str(e)}") from e # if the annotation is missing, or it is a generic object type, return the arg elif "O" in param.encoded_types: return arg else: - raise TypeError(f"Argument {arg} is not compatible with annotation {param.encoded_types}") + raise TypeError(f"Argument {param.name!r}: {arg} is not compatible with annotation {param.encoded_types}") else: # if the arg is not a Java array specific_types = param.encoded_types - {"N", "O"} # remove NoneType and object type if specific_types: @@ -300,7 +307,8 @@ def _convert_arg(param: _ParsedParamAnnotation, arg: Any) -> Any: if param.none_allowed: return None else: - raise DHError(f"Argument {arg} is not compatible with annotation {param.orig_types}") + raise DHError(f"Argument {param.name!r}: {arg} is not compatible with annotation" + f" {param.orig_types}") else: # return a numpy integer instance only if the annotation is a numpy type if _is_from_np_type(param.orig_types, param.int_char): @@ -332,7 +340,8 @@ def _convert_arg(param: _ParsedParamAnnotation, arg: Any) -> Any: if "O" in param.encoded_types: return arg else: - raise TypeError(f"Argument {arg} is not compatible with annotation {param.orig_types}") + raise TypeError(f"Argument {param.name!r}: {arg} is not compatible with annotation" + f" {param.orig_types}") else: # if no annotation or generic object, return arg return arg diff --git a/py/server/tests/test_numba_guvectorize.py b/py/server/tests/test_numba_guvectorize.py index 79d9f87241f..b47261fcfd8 100644 --- a/py/server/tests/test_numba_guvectorize.py +++ b/py/server/tests/test_numba_guvectorize.py @@ -7,7 +7,7 @@ import numpy as np from numba import guvectorize, int64, int32 -from deephaven import empty_table, dtypes +from deephaven import empty_table, dtypes, DHError from tests.testbase import BaseTestCase a = np.arange(5, dtype=np.int64) @@ -89,6 +89,18 @@ def g(x, res): t = empty_table(10).update(["X=i%3", "Y=ii"]).group_by("X").update("Z=g(Y)") self.assertEqual(t.columns[2].data_type, dtypes.long_array) + def test_type_mismatch_error(self): + # vector input to scalar output function (m)->() + @guvectorize([(int64[:], int64[:])], "(m)->()", nopython=True) + def g(x, res): + res[0] = 0 + for xi in x: + res[0] += xi + + with self.assertRaises(DHError) as cm: + t = empty_table(10).update(["X=i%3", "Y=(double)ii"]).group_by("X").update("Z=g(Y)") + self.assertIn("Argument 1", str(cm.exception)) + if __name__ == '__main__': unittest.main() diff --git a/py/server/tests/test_udf_numpy_args.py b/py/server/tests/test_udf_numpy_args.py index 78544bf8ec2..0ad89c841e0 100644 --- a/py/server/tests/test_udf_numpy_args.py +++ b/py/server/tests/test_udf_numpy_args.py @@ -337,7 +337,7 @@ def f3(p1: np.ndarray[np.bool_], p2=None) -> bool: self.assertEqual(t1.columns[2].data_type, dtypes.bool_) with self.assertRaises(DHError) as cm: t2 = t.update(["X1 = f3(null, Y )"]) - self.assertRegex(str(cm.exception), "Argument None is not compatible with annotation") + self.assertRegex(str(cm.exception), "Argument 'p1': None is not compatible with annotation") def f31(p1: Optional[np.ndarray[bool]], p2=None) -> bool: return bool(len(p1)) if p1 is not None else False @@ -352,7 +352,7 @@ def f1(p1: str, p2=None) -> bool: t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? `deephaven`: null"]) with self.assertRaises(DHError) as cm: t1 = t.update(["X1 = f1(Y)"]) - self.assertRegex(str(cm.exception), "Argument None is not compatible with annotation") + self.assertRegex(str(cm.exception), "Argument 'p1': None is not compatible with annotation") def f11(p1: Union[str, None], p2=None) -> bool: return p1 is None @@ -366,7 +366,7 @@ def f2(p1: np.datetime64, p2=None) -> bool: t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? now() : null"]) with self.assertRaises(DHError) as cm: t1 = t.update(["X1 = f2(Y)"]) - self.assertRegex(str(cm.exception), "Argument None is not compatible with annotation") + self.assertRegex(str(cm.exception), "Argument 'p1': None is not compatible with annotation") def f21(p1: Union[np.datetime64, None], p2=None) -> bool: return p1 is None @@ -380,7 +380,7 @@ def f3(p1: np.bool_, p2=None) -> bool: t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? true : null"]) with self.assertRaises(DHError) as cm: t1 = t.update(["X1 = f3(Y)"]) - self.assertRegex(str(cm.exception), "Argument None is not compatible with annotation") + self.assertRegex(str(cm.exception), "Argument 'p1': None is not compatible with annotation") t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? true : false"]) t1 = t.update(["X1 = f3(Y)"])