From 1f9167e1a9bb729987a6356422c3a3251b51e6e8 Mon Sep 17 00:00:00 2001 From: jianfengmao Date: Fri, 22 Mar 2024 11:59:32 -0600 Subject: [PATCH] Code clean up/add more comments/test cases --- .../table/impl/lang/QueryLanguageParser.java | 32 -------------- .../engine/util/PyCallableWrapperJpyImpl.java | 43 ++++++++++++++----- py/server/deephaven/_udf.py | 9 +--- py/server/tests/test_udf_args.py | 9 ++++ .../tests/test_udf_return_java_values.py | 12 +++++- 5 files changed, 53 insertions(+), 52 deletions(-) diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java index 0329488adda..83109869bad 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java @@ -1968,38 +1968,6 @@ public static boolean isWideningPrimitiveConversion(Class original, Class return false; } - public static boolean isLosslessWideningPrimitiveConversion(Class original, Class target) { - if (original == null || !original.isPrimitive() || target == null || !target.isPrimitive() - || original.equals(void.class) || target.equals(void.class)) { - throw new IllegalArgumentException("Arguments must be a primitive type (excluding void)!"); - } - - if (original.equals(target)) { - return true; - } - - LanguageParserPrimitiveType originalEnum = LanguageParserPrimitiveType.getPrimitiveType(original); - - switch (originalEnum) { - case BytePrimitive: - if (target == short.class) - return true; - case ShortPrimitive: - case CharPrimitive: // char is unsigned, so it's a lossless conversion to int - if (target == int.class) // this covers all the smaller integer types - return true; - case IntPrimitive: - if (target == long.class) - return true; - break; - case FloatPrimitive: - if (target == double.class) - return true; - break; - } - - return false; - } private enum LanguageParserPrimitiveType { // Including "Enum" (or really, any differentiating string) in these names is important. They're used diff --git a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java index a420c0eed27..84c75748aed 100644 --- a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java +++ b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java @@ -13,7 +13,6 @@ import java.util.*; import static io.deephaven.engine.table.impl.lang.QueryLanguageParser.NULL_CLASS; -import static io.deephaven.engine.table.impl.lang.QueryLanguageParser.isLosslessWideningPrimitiveConversion; import static io.deephaven.util.type.TypeUtils.getUnboxedType; /** @@ -51,7 +50,7 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper { numpyType2JavaArrayClass.put('l', long[].class); numpyType2JavaArrayClass.put('f', float[].class); numpyType2JavaArrayClass.put('d', double[].class); - numpyType2JavaArrayClass.put('?', boolean[].class); + numpyType2JavaArrayClass.put('?', Boolean[].class); numpyType2JavaArrayClass.put('U', String[].class); numpyType2JavaArrayClass.put('M', Instant[].class); numpyType2JavaArrayClass.put('O', Object[].class); @@ -200,8 +199,8 @@ public void parseSignature() { List parameters = new ArrayList<>(); if (!pyEncodedParamsStr.isEmpty()) { String[] pyEncodedParams = pyEncodedParamsStr.split(","); - for (int i = 0; i < pyEncodedParams.length; i++) { - String[] paramDetail = pyEncodedParams[i].split(":"); + for (String pyEncodedParam : pyEncodedParams) { + String[] paramDetail = pyEncodedParam.split(":"); String paramName = paramDetail[0]; String paramTypeCodes = paramDetail[1]; Set> possibleTypes = new HashSet<>(); @@ -211,9 +210,6 @@ public void parseSignature() { // skip the array type code ti++; possibleTypes.add(numpyType2JavaArrayClass.get(paramTypeCodes.charAt(ti))); - if (paramTypeCodes.charAt(ti) == '?') { - possibleTypes.add(Boolean[].class); - } } else if (typeCode == 'N') { possibleTypes.add(NULL_CLASS); } else { @@ -250,14 +246,40 @@ private boolean isSafelyCastable(Set> types, Class type) { return false; } + public static boolean isLosslessWideningPrimitiveConversion(Class original, Class target) { + if (original == null || !original.isPrimitive() || target == null || !target.isPrimitive() + || original.equals(void.class) || target.equals(void.class)) { + throw new IllegalArgumentException("Arguments must be a primitive type (excluding void)!"); + } + + if (original.equals(target)) { + return true; + } + + if (original.equals(byte.class)) { + return target == short.class || target == int.class || target == long.class; + } else if (original.equals(short.class) || original.equals(char.class)) { // char is unsigned, so it's a + // lossless conversion to int + return target == int.class || target == long.class; + } else if (original.equals(int.class)) { + return target == long.class; + } else if (original.equals(float.class)) { + return target == double.class; + } + + return false; + } public void verifyArguments(Class[] argTypes) { String callableName = pyCallable.getAttribute("__name__").toString(); List parameters = signature.getParameters(); for (int i = 0; i < argTypes.length; i++) { + // if there are more arguments than parameters, we'll need to consider the last parameter as a varargs + // parameter. This is not ideal. We should consider a better way to handle this, i.e. a way to convey that + // the function is variadic. Set> types = - parameters.get(i > parameters.size() - 1 ? parameters.size() - 1 : i).getPossibleTypes(); + parameters.get(Math.min(i, parameters.size() - 1)).getPossibleTypes(); // to prevent the unpacking of an array column when calling a Python function, we prefix the column accessor // with a cast to generic Object type, until we can find a way to convey that info, we'll just skip the @@ -267,15 +289,14 @@ public void verifyArguments(Class[] argTypes) { } Class t = getUnboxedType(argTypes[i]) == null ? argTypes[i] : getUnboxedType(argTypes[i]); - if (!types.contains(t) && !types.contains(Object.class) - && !isSafelyCastable(types, t)) { + if (!types.contains(t) && !types.contains(Object.class) && !isSafelyCastable(types, t)) { throw new IllegalArgumentException( callableName + ": " + "Expected argument (" + parameters.get(i).getName() + ") to be one of " + parameters.get(i).getPossibleTypes() + ", got " + (argTypes[i].equals(NULL_CLASS) ? "null" : argTypes[i])); } } - }; + } // In vectorized mode, we want to call the vectorized function directly. public PyObject vectorizedCallable() { diff --git a/py/server/deephaven/_udf.py b/py/server/deephaven/_udf.py index 617d5398594..61584f32f15 100644 --- a/py/server/deephaven/_udf.py +++ b/py/server/deephaven/_udf.py @@ -113,7 +113,7 @@ def _component_np_dtype_char(t: type) -> Optional[str]: numpy ndarray, otherwise return None. """ component_type = None - if not component_type and sys.version_info.major == 3 and sys.version_info.minor > 8: + if sys.version_info > (3, 8): import types if isinstance(t, types.GenericAlias) and issubclass(t.__origin__, Sequence): component_type = t.__args__[0] @@ -175,10 +175,7 @@ def _is_union_type(t: type) -> bool: if isinstance(t, types.UnionType): return True - if isinstance(t, _GenericAlias) and t.__origin__ == Union: - return True - - return False + return isinstance(t, _GenericAlias) and t.__origin__ == Union def _parse_param(name: str, annotation: Any) -> _ParsedParam: @@ -480,7 +477,6 @@ def _py_udf(fn: Callable): # build a signature string for vectorization by removing NoneType, array char '[', and comma from the encoded types # since vectorization only supports UDFs with a single signature and enforces an exact match, any non-compliant # signature (e.g. Union with more than 1 non-NoneType) will be rejected by the vectorizer. - sig_str_vectorization = re.sub(r"[\[N,]", "", p_sig.encoded) return_array = p_sig.ret_annotation.has_array ret_dtype = dtypes.from_np_dtype(np.dtype(p_sig.ret_annotation.encoded_type[-1])) @@ -505,7 +501,6 @@ def wrapper(*args, **kwargs): j_class = real_ret_dtype.qst_type.clazz() wrapper.return_type = j_class - # wrapper.signature = sig_str_vectorization wrapper.signature = p_sig.encoded return wrapper diff --git a/py/server/tests/test_udf_args.py b/py/server/tests/test_udf_args.py index 8e58050a9d6..9e1c4a1cfc5 100644 --- a/py/server/tests/test_udf_args.py +++ b/py/server/tests/test_udf_args.py @@ -486,6 +486,15 @@ def f(x: bytearray) -> bool: t = empty_table(1).update(["X = i", "Y = (byte)(ii % 128)"]).group_by("X").update(["Z = f(Y.toArray())"]) self.assertEqual(t.columns[2].data_type, dtypes.bool_) + def test_non_common_cases(self): + def f1(x: int) -> float: + ... + + def f2(x: float) -> int: + ... + + t = empty_table(1).update("X = f2(f1(ii))") + self.assertEqual(t.columns[0].data_type, dtypes.int_) if __name__ == "__main__": unittest.main() diff --git a/py/server/tests/test_udf_return_java_values.py b/py/server/tests/test_udf_return_java_values.py index d42b6f8465c..6d897195e21 100644 --- a/py/server/tests/test_udf_return_java_values.py +++ b/py/server/tests/test_udf_return_java_values.py @@ -54,8 +54,7 @@ def test_array_return(self): "np.str_": dtypes.string_array, "np.uint16": dtypes.char_array, } - # container_types = ["List", "Tuple", "list", "tuple", "Sequence", "np.ndarray"] - container_types = ["list"] + container_types = ["List", "Tuple", "list", "tuple", "Sequence", "np.ndarray"] for component_type, dh_dtype in component_types.items(): for container_type in container_types: with self.subTest(component_type=component_type, container_type=container_type): @@ -68,6 +67,15 @@ def test_array_return(self): t = empty_table(10).update(["X = i % 3", "Y = i"]).group_by("X").update(f"Z= fn(Y + 1)") self.assertEqual(t.columns[2].data_type, dh_dtype) + container_types = ["bytes", "bytearray"] + for container_type in container_types: + with self.subTest(container_type=container_type): + func_decl_str = f"""def fn(col) -> {container_type}:""" + func_body_str = f""" return {container_type}(col)""" + exec("\n".join([func_decl_str, func_body_str]), globals()) + t = empty_table(10).update(["X = i % 3", "Y = i"]).group_by("X").update(f"Z= fn(Y + 1)") + self.assertEqual(t.columns[2].data_type, dtypes.byte_array) + def test_scalar_return_class_method_not_supported(self): for dh_dtype, np_dtype in _J_TYPE_NP_DTYPE_MAP.items(): with self.subTest(dh_dtype=dh_dtype, np_dtype=np_dtype):