diff --git a/py/server/deephaven/_udf.py b/py/server/deephaven/_udf.py index 57fb3828adf..42d44163063 100644 --- a/py/server/deephaven/_udf.py +++ b/py/server/deephaven/_udf.py @@ -504,7 +504,6 @@ def _parse_signature(fn: Callable) -> _ParsedSignature: p_sig.ret_annotation = _parse_return_annotation(t) return p_sig - def _udf_parser(fn: Callable): """A decorator that acts as a transparent translator for Python UDFs used in Deephaven query formulas between Python and Java. This decorator is intended for internal use by the Deephaven query engine and should not be used by @@ -528,7 +527,7 @@ def _udf_parser(fn: Callable): ret_dtype = dtypes.from_np_dtype(np.dtype(ret_np_char if ret_np_char != "X" else "O")) @wraps(fn) - def _udf_decorator(encoded_arg_types: str, for_vectorization: bool): + def _udf_decorator(encoded_arg_types: str, for_vectorization: bool) -> Callable: """The actual decorator that wraps the Python UDF and converts the arguments and return values. It is called by the query engine with the runtime argument types to create a wrapper that can efficiently convert the arguments and return values based on the provided argument types and the parsed parameters of the @@ -536,82 +535,23 @@ def _udf_decorator(encoded_arg_types: str, for_vectorization: bool): """ arg_conv_needed = p_sig.prepare_auto_arg_conv(encoded_arg_types) p_sig.ret_annotation.setup_return_converter() + ret_converter = p_sig.ret_annotation.ret_converter + nonlocal ret_dtype # used in converting array-type return values, bring it into the local scope before exec() - if not for_vectorization: - if not arg_conv_needed and p_sig.ret_annotation.encoded_type == "O": - return fn - - def _wrapper(*args, **kwargs): - if arg_conv_needed: - converted_args = [param.arg_converter(arg) if param.arg_converter else arg - for param, arg in zip(p_sig.params, args)] - - # if the number of arguments is more than the number of parameters, treat the last parameter as a - # vararg and use its arg_converter to convert the rest of the arguments - if len(args) > len(p_sig.params): - arg_converter = p_sig.params[-1].arg_converter - converted_args.extend([arg_converter(arg) if arg_converter else arg - for arg in args[len(converted_args):]]) - else: - converted_args = args - # kwargs are not converted because they are not used in the UDFs - ret = fn(*converted_args, **kwargs) - if return_array: - return dtypes.array(ret_dtype, ret) - else: - return p_sig.ret_annotation.ret_converter(ret) if p_sig.ret_annotation.ret_converter else ret - - return _wrapper - else: # for vectorization - def _vectorization_wrapper(*args): - if len(args) != len(p_sig.params) + 2: - raise ValueError( - f"The number of arguments doesn't match the function ({p_sig.fn.__name__}) signature. " - f"{len(args) - 2}, {p_sig.encoded}") - if args[0] <= 0: - raise ValueError( - f"The chunk size argument must be a positive integer for vectorized function (" - f"{p_sig.fn.__name__}). {args[0]}") - - chunk_size = args[0] - chunk_result = args[1] - if args[2:]: - vectorized_args = zip(*args[2:]) - for i in range(chunk_size): - scalar_args = next(vectorized_args) - if arg_conv_needed: - converted_args = [param.arg_converter(arg) if param.arg_converter else arg - for param, arg in zip(p_sig.params, scalar_args)] - - # if the number of arguments is more than the number of parameters, treat the last parameter - # as a vararg and use its arg_converter to convert the rest of the arguments - if len(args) > len(p_sig.params): - arg_converter = p_sig.params[-1].arg_converter - converted_args.extend([arg_converter(arg) if arg_converter else arg - for arg in scalar_args[len(converted_args):]]) - else: - converted_args = scalar_args + if not for_vectorization and not arg_conv_needed and not ret_converter and not return_array: + # no wrapper needed + return fn - ret = fn(*converted_args) - if return_array: - chunk_result[i] = dtypes.array(ret_dtype, ret) - else: - chunk_result[i] = p_sig.ret_annotation.ret_converter( - ret) if p_sig.ret_annotation.ret_converter else ret - else: - for i in range(chunk_size): - ret = fn() - if return_array: - chunk_result[i] = dtypes.array(ret_dtype, ret) - else: - chunk_result[i] = p_sig.ret_annotation.ret_converter( - ret) if p_sig.ret_annotation.ret_converter else ret - return chunk_result + _wrapper_str = _gen_wrapper_code(p_sig, for_vectorization, arg_conv_needed, return_array) + scope = {**globals(), **locals()} + exec(_wrapper_str, scope) + + if for_vectorization and test_vectorization: + global vectorized_count + vectorized_count += 1 + + return scope["_wrapper"] - if test_vectorization: - global vectorized_count - vectorized_count += 1 - return _vectorization_wrapper _udf_decorator.j_name = ret_dtype.j_name real_ret_dtype = _BUILDABLE_ARRAY_DTYPE_MAP.get(ret_dtype, dtypes.PyObject) if return_array else ret_dtype @@ -625,3 +565,81 @@ def _vectorization_wrapper(*args): _udf_decorator.signature = p_sig.encoded return _udf_decorator + +# region Wrapper Code Generation +# for non-vectorize-able UDFs +INDENT_STR= " " * 4 +WRAPPER_HEADER = """def _wrapper(*args):""" +ARG_CONV = """ +converted_args = [param.arg_converter(arg) if param.arg_converter else arg + for param, arg in zip(p_sig.params, args)] + +# if the number of arguments is more than the number of parameters, treat the last parameter as a +# vararg and use its arg_converter to convert the rest of the arguments +if len(args) > len(p_sig.params): + arg_converter = p_sig.params[-1].arg_converter + converted_args.extend([arg_converter(arg) if arg_converter else arg for arg in args[len(converted_args):]]) +ret = fn(*converted_args)""" +NO_ARG_CONV = """ret = fn(*args)""" +ARRAY_RET = """return dtypes.array(ret_dtype, ret)""" +SCALAR_RET = """return ret_converter(ret) if ret_converter else ret""" + +# for vectorize-able UDFs +V_WRAPPER_HEADER = """ +def _wrapper(*args): + chunk_size = args[0] + chunk_result = args[1]""" +V_ARRAY_RET = """chunk_result[i] = dtypes.array(ret_dtype, ret)""" +V_SCALAR_RET = """chunk_result[i] = ret_converter(ret) if ret_converter else ret""" +V_ZIP_CHUNK_ARGS = """vectorized_args = zip(*args[2:])""" +V_ARG_CONV = """ +scalar_args = next(vectorized_args) +converted_args = [param.arg_converter(arg) if param.arg_converter else arg + for param, arg in zip(p_sig.params, scalar_args)] +# if the number of arguments is more than the number of parameters, treat the last +# parameter as a vararg and use its arg_converter to convert the rest of the arguments +if len(args) > len(p_sig.params): + arg_converter = p_sig.params[-1].arg_converter + converted_args.extend([arg_converter(arg) if arg_converter else arg + for arg in scalar_args[len(converted_args):]]) +ret = fn(*converted_args)""" +V_NO_ARG_CONV = """ret = fn(*next(vectorized_args))""" +V_NO_ARG = """ret = fn()""" +V_LOOP_CHUNK = """for i in range(chunk_size):""" +V_RET_CHUNK = "return chunk_result" + + +def _gen_wrapper_code(p_sig: _ParsedSignature, for_vectorization: bool, arg_conv_needed: bool, return_array: bool) -> str: + """ Generate the wrapper code for the UDF based on the parsed signature and the context of the UDF usage.""" + if not for_vectorization: + conv_str = ARG_CONV if arg_conv_needed else NO_ARG_CONV + ret_str = ARRAY_RET if return_array else SCALAR_RET + wrapper_str = (WRAPPER_HEADER + "\n" + + INDENT_STR + conv_str.replace("\n", "\n" + INDENT_STR) + "\n" + + INDENT_STR + ret_str.replace("\n", "\n" + INDENT_STR) + "\n") + return wrapper_str + "\n" + else: + ret_str = V_ARRAY_RET if return_array else V_SCALAR_RET + + if len(p_sig.params) == 0: + wrapper_str = (V_WRAPPER_HEADER + "\n" + + INDENT_STR + V_LOOP_CHUNK + "\n" + + INDENT_STR * 2 + V_NO_ARG + "\n" + + INDENT_STR * 2 + ret_str) + else: + wrapper_str = (V_WRAPPER_HEADER + "\n" + + INDENT_STR + V_ZIP_CHUNK_ARGS + "\n" + + INDENT_STR + V_LOOP_CHUNK + "\n") + if arg_conv_needed: + wrapper_str = (wrapper_str + + INDENT_STR * 2 + V_ARG_CONV.replace("\n", "\n" + INDENT_STR * 2) + "\n" + + INDENT_STR * 2 + ret_str + "\n") + else: + wrapper_str = (wrapper_str + + INDENT_STR * 2 + V_NO_ARG_CONV.replace("\n", "\n" + INDENT_STR * 2) + "\n" + + INDENT_STR * 2 + ret_str + "\n") + + return (wrapper_str + "\n" + + INDENT_STR + V_RET_CHUNK + "\n") +# endregion +