Skip to content

Commit

Permalink
Remove branching in UDF wrappers via code gen (#5487)
Browse files Browse the repository at this point in the history
* Reduce the number of checks in vectorization case

* Remove branching by code-gen

* Refactor the code gen code

* More code cleanup

* Minor code tidy-up and some comments
  • Loading branch information
jmao-denver authored May 24, 2024
1 parent 4446605 commit 2929e94
Showing 1 changed file with 93 additions and 75 deletions.
168 changes: 93 additions & 75 deletions py/server/deephaven/_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,6 @@ def _parse_signature(fn: Callable) -> _ParsedSignature:
p_sig.ret_annotation = _parse_return_annotation(t)
return p_sig


def _udf_parser(fn: Callable):
"""A decorator that acts as a transparent translator for Python UDFs used in Deephaven query formulas between
Python and Java. This decorator is intended for internal use by the Deephaven query engine and should not be used by
Expand All @@ -528,90 +527,31 @@ def _udf_parser(fn: Callable):
ret_dtype = dtypes.from_np_dtype(np.dtype(ret_np_char if ret_np_char != "X" else "O"))

@wraps(fn)
def _udf_decorator(encoded_arg_types: str, for_vectorization: bool):
def _udf_decorator(encoded_arg_types: str, for_vectorization: bool) -> Callable:
"""The actual decorator that wraps the Python UDF and converts the arguments and return values.
It is called by the query engine with the runtime argument types to create a wrapper that can efficiently
convert the arguments and return values based on the provided argument types and the parsed parameters of the
UDF.
"""
arg_conv_needed = p_sig.prepare_auto_arg_conv(encoded_arg_types)
p_sig.ret_annotation.setup_return_converter()
ret_converter = p_sig.ret_annotation.ret_converter
nonlocal ret_dtype # used in converting array-type return values, bring it into the local scope before exec()

if not for_vectorization:
if not arg_conv_needed and p_sig.ret_annotation.encoded_type == "O":
return fn

def _wrapper(*args, **kwargs):
if arg_conv_needed:
converted_args = [param.arg_converter(arg) if param.arg_converter else arg
for param, arg in zip(p_sig.params, args)]

# if the number of arguments is more than the number of parameters, treat the last parameter as a
# vararg and use its arg_converter to convert the rest of the arguments
if len(args) > len(p_sig.params):
arg_converter = p_sig.params[-1].arg_converter
converted_args.extend([arg_converter(arg) if arg_converter else arg
for arg in args[len(converted_args):]])
else:
converted_args = args
# kwargs are not converted because they are not used in the UDFs
ret = fn(*converted_args, **kwargs)
if return_array:
return dtypes.array(ret_dtype, ret)
else:
return p_sig.ret_annotation.ret_converter(ret) if p_sig.ret_annotation.ret_converter else ret

return _wrapper
else: # for vectorization
def _vectorization_wrapper(*args):
if len(args) != len(p_sig.params) + 2:
raise ValueError(
f"The number of arguments doesn't match the function ({p_sig.fn.__name__}) signature. "
f"{len(args) - 2}, {p_sig.encoded}")
if args[0] <= 0:
raise ValueError(
f"The chunk size argument must be a positive integer for vectorized function ("
f"{p_sig.fn.__name__}). {args[0]}")

chunk_size = args[0]
chunk_result = args[1]
if args[2:]:
vectorized_args = zip(*args[2:])
for i in range(chunk_size):
scalar_args = next(vectorized_args)
if arg_conv_needed:
converted_args = [param.arg_converter(arg) if param.arg_converter else arg
for param, arg in zip(p_sig.params, scalar_args)]

# if the number of arguments is more than the number of parameters, treat the last parameter
# as a vararg and use its arg_converter to convert the rest of the arguments
if len(args) > len(p_sig.params):
arg_converter = p_sig.params[-1].arg_converter
converted_args.extend([arg_converter(arg) if arg_converter else arg
for arg in scalar_args[len(converted_args):]])
else:
converted_args = scalar_args
if not for_vectorization and not arg_conv_needed and not ret_converter and not return_array:
# no wrapper needed
return fn

ret = fn(*converted_args)
if return_array:
chunk_result[i] = dtypes.array(ret_dtype, ret)
else:
chunk_result[i] = p_sig.ret_annotation.ret_converter(
ret) if p_sig.ret_annotation.ret_converter else ret
else:
for i in range(chunk_size):
ret = fn()
if return_array:
chunk_result[i] = dtypes.array(ret_dtype, ret)
else:
chunk_result[i] = p_sig.ret_annotation.ret_converter(
ret) if p_sig.ret_annotation.ret_converter else ret
return chunk_result
_wrapper_str = _gen_wrapper_code(p_sig, for_vectorization, arg_conv_needed, return_array)
scope = {**globals(), **locals()}
exec(_wrapper_str, scope)

if for_vectorization and test_vectorization:
global vectorized_count
vectorized_count += 1

return scope["_wrapper"]

if test_vectorization:
global vectorized_count
vectorized_count += 1
return _vectorization_wrapper

_udf_decorator.j_name = ret_dtype.j_name
real_ret_dtype = _BUILDABLE_ARRAY_DTYPE_MAP.get(ret_dtype, dtypes.PyObject) if return_array else ret_dtype
Expand All @@ -625,3 +565,81 @@ def _vectorization_wrapper(*args):
_udf_decorator.signature = p_sig.encoded

return _udf_decorator

# region Wrapper Code Generation
# The snippets below are fragments of Python source that _gen_wrapper_code() stitches together into the text of a
# `_wrapper` function, which the UDF decorator then compiles with exec(). The generated code refers to the names
# `fn`, `p_sig`, `ret_converter`, `ret_dtype` and `dtypes`, so those must be present in the scope handed to exec().
#
# Multi-line snippets carry their own *relative* indentation (bodies of `if`/`def` are indented by one level inside
# the literal); the generator only prefixes every line with the indentation of the place where the snippet is
# inserted.

# for non-vectorize-able UDFs
INDENT_STR = " " * 4
WRAPPER_HEADER = """def _wrapper(*args):"""
ARG_CONV = """
converted_args = [param.arg_converter(arg) if param.arg_converter else arg
                  for param, arg in zip(p_sig.params, args)]
# if the number of arguments is more than the number of parameters, treat the last parameter as a
# vararg and use its arg_converter to convert the rest of the arguments
if len(args) > len(p_sig.params):
    arg_converter = p_sig.params[-1].arg_converter
    converted_args.extend([arg_converter(arg) if arg_converter else arg for arg in args[len(converted_args):]])
ret = fn(*converted_args)"""
NO_ARG_CONV = """ret = fn(*args)"""
ARRAY_RET = """return dtypes.array(ret_dtype, ret)"""
SCALAR_RET = """return ret_converter(ret) if ret_converter else ret"""

# for vectorize-able UDFs
# In the vectorized wrapper args[0] is the chunk size, args[1] is the result buffer and args[2:] are the
# per-parameter chunks that are zipped into one tuple of scalar arguments per row.
V_WRAPPER_HEADER = """
def _wrapper(*args):
    chunk_size = args[0]
    chunk_result = args[1]"""
V_ARRAY_RET = """chunk_result[i] = dtypes.array(ret_dtype, ret)"""
V_SCALAR_RET = """chunk_result[i] = ret_converter(ret) if ret_converter else ret"""
V_ZIP_CHUNK_ARGS = """vectorized_args = zip(*args[2:])"""
# The vararg check compares the per-row arguments (scalar_args), not the wrapper's own args: the latter also
# contains the chunk size and the result buffer, so testing it would take the vararg path on every row.
V_ARG_CONV = """
scalar_args = next(vectorized_args)
converted_args = [param.arg_converter(arg) if param.arg_converter else arg
                  for param, arg in zip(p_sig.params, scalar_args)]
# if the number of arguments is more than the number of parameters, treat the last
# parameter as a vararg and use its arg_converter to convert the rest of the arguments
if len(scalar_args) > len(p_sig.params):
    arg_converter = p_sig.params[-1].arg_converter
    converted_args.extend([arg_converter(arg) if arg_converter else arg
                           for arg in scalar_args[len(converted_args):]])
ret = fn(*converted_args)"""
V_NO_ARG_CONV = """ret = fn(*next(vectorized_args))"""
V_NO_ARG = """ret = fn()"""
V_LOOP_CHUNK = """for i in range(chunk_size):"""
V_RET_CHUNK = "return chunk_result"


def _gen_wrapper_code(p_sig: _ParsedSignature, for_vectorization: bool, arg_conv_needed: bool, return_array: bool) -> str:
    """Assemble the source text of the ``_wrapper`` function for a UDF.

    The text is built from the module-level code templates; which templates are picked depends only on the flags
    below and on whether the UDF has any parameters, so the resulting wrapper contains no runtime branching on them.

    Args:
        p_sig: the parsed signature of the UDF; only the number of its ``params`` is consulted here
        for_vectorization: True to generate the chunk-oriented (vectorized) wrapper, False for the scalar wrapper
        arg_conv_needed: True to emit the argument conversion code before calling the UDF
        return_array: True to emit the array-return code, False to emit the scalar-return code

    Returns:
        the Python source of a function named ``_wrapper``
    """

    def _indented(snippet: str, level: int) -> str:
        # prefix the first line and every continuation line of a (possibly multi-line) snippet
        pad = INDENT_STR * level
        return pad + snippet.replace("\n", "\n" + pad)

    if for_vectorization:
        body_pad = INDENT_STR * 2
        store_line = body_pad + (V_ARRAY_RET if return_array else V_SCALAR_RET)
        loop_line = INDENT_STR + V_LOOP_CHUNK
        return_line = INDENT_STR + V_RET_CHUNK

        if len(p_sig.params) == 0:
            # nothing to zip or convert, just call the UDF once per row of the chunk
            pieces = [V_WRAPPER_HEADER, loop_line, body_pad + V_NO_ARG, store_line, return_line, ""]
        else:
            call_snippet = V_ARG_CONV if arg_conv_needed else V_NO_ARG_CONV
            pieces = [
                V_WRAPPER_HEADER,
                INDENT_STR + V_ZIP_CHUNK_ARGS,
                loop_line,
                _indented(call_snippet, 2),
                store_line,
                "",  # blank line between the loop and the final return
                return_line,
                "",
            ]
        return "\n".join(pieces)

    # scalar (non-vectorized) wrapper
    call_snippet = ARG_CONV if arg_conv_needed else NO_ARG_CONV
    return_snippet = ARRAY_RET if return_array else SCALAR_RET
    # the two trailing empty pieces reproduce the two newlines that terminate the generated code
    return "\n".join([WRAPPER_HEADER, _indented(call_snippet, 1), _indented(return_snippet, 1), "", ""])
# endregion

0 comments on commit 2929e94

Please sign in to comment.