Skip to content

Commit

Permalink
Add Java char support, fix a vectorization issue
Browse files Browse the repository at this point in the history
Also add more test cases
  • Loading branch information
jmao-denver committed Nov 25, 2023
1 parent a4016cc commit e3bf58b
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2687,12 +2687,7 @@ private void checkVectorizability(@NotNull final MethodCallExpr n,
}
}

List<Class<?>> paramTypes = pyCallableWrapper.getParamTypes();
if (paramTypes.size() != expressions.length) {
// note vectorization doesn't handle Python variadic arguments
throw new PythonCallVectorizationFailure("Python function argument count mismatch: " + n + " "
+ paramTypes.size() + " vs. " + expressions.length);
}

}

private void prepareVectorizationArgs(MethodCallExpr n, QueryScope queryScope, Expression[] expressions,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper {
private static final Map<Character, Class<?>> numpyType2JavaClass = new HashMap<>();

static {
numpyType2JavaClass.put('b', byte.class);
numpyType2JavaClass.put('h', short.class);
numpyType2JavaClass.put('H', char.class);
numpyType2JavaClass.put('i', int.class);
numpyType2JavaClass.put('l', long.class);
numpyType2JavaClass.put('h', short.class);
numpyType2JavaClass.put('f', float.class);
numpyType2JavaClass.put('d', double.class);
numpyType2JavaClass.put('b', byte.class);
numpyType2JavaClass.put('?', boolean.class);
numpyType2JavaClass.put('U', String.class);
numpyType2JavaClass.put('M', Instant.class);
Expand Down
25 changes: 19 additions & 6 deletions py/server/deephaven/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def __call__(self, *args, **kwargs):
"""Double-precision floating-point number type"""
string = DType(j_name="java.lang.String", qst_type=_JQstType.stringType(), np_type=np.str_)
"""String type"""
Character = DType(j_name="java.lang.Character")
"""Character type"""
BigDecimal = DType(j_name="java.math.BigDecimal")
"""Java BigDecimal type"""
StringSet = DType(j_name="io.deephaven.stringset.StringSet")
Expand Down Expand Up @@ -339,8 +341,19 @@ def from_np_dtype(np_dtype: Union[np.dtype, pd.api.extensions.ExtensionDtype]) -
return PyObject


_NUMPY_INT_TYPE_CODES = ["i", "l", "h", "b"]
_NUMPY_FLOATING_TYPE_CODES = ["f", "d"]
_NUMPY_INT_TYPE_CODES = {"b", "h", "H", "i", "l"}
_NUMPY_FLOATING_TYPE_CODES = {"f", "d"}


def _is_py_null(x: Any) -> bool:
"""Checks if the value is a Python null value, i.e. None or NaN, or Pandas.NA."""
if x is None:
return True

try:
return bool(pd.isna(x))
except (TypeError, ValueError):
return False


def _scalar(x: Any, dtype: DType) -> Any:
Expand All @@ -350,12 +363,14 @@ def _scalar(x: Any, dtype: DType) -> Any:

# NULL_BOOL will appear in Java as a byte value which causes a cast error. We just let JPY converts it to Java null
# and the engine has casting logic to handle it.
if x is None and dtype != bool_ and _PRIMITIVE_DTYPE_NULL_MAP.get(dtype):
if _is_py_null(x) and dtype not in (bool_, char) and _PRIMITIVE_DTYPE_NULL_MAP.get(dtype):
return _PRIMITIVE_DTYPE_NULL_MAP[dtype]

try:
if hasattr(x, "dtype"):
if x.dtype.char in _NUMPY_INT_TYPE_CODES:
if x.dtype.char == 'H': # np.uint16 maps to Java char
return Character(int(x))
elif x.dtype.char in _NUMPY_INT_TYPE_CODES:
return int(x)
elif x.dtype.char in _NUMPY_FLOATING_TYPE_CODES:
return float(x)
Expand All @@ -368,8 +383,6 @@ def _scalar(x: Any, dtype: DType) -> Any:
elif x.dtype.char == 'M':
from deephaven.time import to_j_instant
return to_j_instant(x)
elif x.dtype.char == 'H': # np.uint16
return jpy.get_type("java.lang.Character")(int(x))
elif isinstance(x, (datetime.datetime, pd.Timestamp)):
from deephaven.time import to_j_instant
return to_j_instant(x)
Expand Down
31 changes: 0 additions & 31 deletions py/server/deephaven/jcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,34 +283,3 @@ def _j_array_to_series(dtype: DType, j_array: jpy.JType, conv_null: bool) -> pd.
s = pd.Series(data=np_array, copy=False)

return s


def _convert_udf_args(args: Tuple[Any], fn_signature: str, null_value: Literal[np.nan, pd.NA, None]) -> List[Any]:
converted_args = []
for arg, np_dtype_char in zip(args, fn_signature):
if np_dtype_char == 'O':
converted_args.append(arg)
elif src_np_dtype := _J_ARRAY_NP_TYPE_MAP.get(type(arg)):
# array types
np_dtype = np.dtype(np_dtype_char)
if src_np_dtype != np_dtype and np_dtype != np.object_:
raise DHError(f"Cannot convert Java array of type {src_np_dtype} to numpy array of type {np_dtype}")
dtype = dtypes.from_np_dtype(np_dtype)
if null_value is pd.NA:
converted_args.append(_j_array_to_series(dtype, arg, conv_null=True))
else: # np.nan or None
converted_args.append(_j_array_to_numpy_array(dtype, arg, conv_null=bool(null_value)))
else: # scalar type or array types that don't need conversion
try:
np_dtype = np.dtype(np_dtype_char)
except TypeError:
converted_args.append(arg)
else:
dtype = dtypes.from_np_dtype(np_dtype)
if dtype is dtypes.bool_:
converted_args.append(null_value if arg is None else arg)
elif dh_null := _PRIMITIVE_DTYPE_NULL_MAP.get(dtype):
converted_args.append(null_value if arg == dh_null else arg)
else:
converted_args.append(arg)
return converted_args
5 changes: 2 additions & 3 deletions py/server/deephaven/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

""" This module supports the conversion between Deephaven tables and numpy arrays. """
import re
from functools import wraps
from typing import List

import jpy
Expand All @@ -13,8 +12,8 @@
from deephaven import DHError, dtypes, new_table
from deephaven.column import Column, InputColumn
from deephaven.dtypes import DType
from deephaven.jcompat import _j_array_to_numpy_array, _convert_udf_args
from deephaven.table import Table, _encode_signature
from deephaven.jcompat import _j_array_to_numpy_array
from deephaven.table import Table

_JDataAccessHelpers = jpy.get_type("io.deephaven.engine.table.impl.DataAccessHelpers")

Expand Down
16 changes: 8 additions & 8 deletions py/server/deephaven/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@
from deephaven.agg import Aggregation
from deephaven.column import Column, ColumnType
from deephaven.filters import Filter, and_, or_
from deephaven.jcompat import j_unary_operator, j_binary_operator, j_map_to_dict, j_hashmap, _convert_udf_args, \
_j_array_to_numpy_array
from deephaven.jcompat import j_unary_operator, j_binary_operator, j_map_to_dict, j_hashmap, _j_array_to_numpy_array
from deephaven.jcompat import to_sequence, j_array_list
from deephaven.time import to_np_datetime64
from deephaven.update_graph import auto_locking_ctx, UpdateGraph
from deephaven.updateby import UpdateByOperation
from deephaven.dtypes import _BUILDABLE_ARRAY_DTYPE_MAP, _scalar, _np_dtype_char, _component_np_dtype_char, DType, \
_np_ndarray_component_type, _J_ARRAY_NP_TYPE_MAP, _PRIMITIVE_DTYPE_NULL_MAP
_np_ndarray_component_type, _J_ARRAY_NP_TYPE_MAP, _PRIMITIVE_DTYPE_NULL_MAP, _NUMPY_INT_TYPE_CODES, \
_NUMPY_FLOATING_TYPE_CODES

# Table
_J_Table = jpy.get_type("io.deephaven.engine.table.Table")
Expand Down Expand Up @@ -368,7 +368,7 @@ def _j_py_script_session() -> _JPythonScriptSession:
return None


_SUPPORTED_NP_TYPE_CODES = {"b", "h", "i", "l", "f", "d", "?", "U", "M", "O"}
_SUPPORTED_NP_TYPE_CODES = {"b", "h", "H", "i", "l", "f", "d", "?", "U", "M", "O"}


@dataclass
Expand Down Expand Up @@ -444,7 +444,7 @@ def _parse_type_no_nested(annotation, p_annotation, t):
p_annotation.is_array = True
if tc in {"N", "O", "?", "U", "M"}:
p_annotation.is_none_legal = True
if tc in {"b", "h", "i", "l"}:
if tc in {"b", "h", "H", "i", "l"}:
if p_annotation.int_char and p_annotation.int_char != tc:
raise DHError(message=f"ambiguity detected: multiple integer types in annotation: {annotation}")
p_annotation.int_char = tc
Expand Down Expand Up @@ -499,9 +499,9 @@ def _parse_numba_signature(fn: Union[numba.np.ufunc.gufunc.GUFunc, numba.np.ufun
for p in params:
pa = ParsedAnnotation()
pa.encoded_types.add(p)
if p in {"b", "h", "i", "l"}:
if p in _NUMPY_INT_TYPE_CODES:
pa.int_char = p
if p in {"f", "d"}:
if p in _NUMPY_FLOATING_TYPE_CODES:
pa.floating_char = p
p_annotations.append(pa)
else: # GUFunc
Expand All @@ -517,7 +517,7 @@ def _parse_numba_signature(fn: Union[numba.np.ufunc.gufunc.GUFunc, numba.np.ufun
pa.is_array = True
else:
pa.encoded_types.add(p)
if p in {"b", "h", "i", "l"}:
if p in {"b", "h", "H", "i", "l"}:
pa.int_char = p
if p in {"f", "d"}:
pa.floating_char = p
Expand Down
89 changes: 76 additions & 13 deletions py/server/tests/test_udf_numpy_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,33 @@
from tests.testbase import BaseTestCase

_J_TYPE_NULL_MAP = {
"double": "NULL_DOUBLE",
"float": "NULL_FLOAT",
"byte": "NULL_BYTE",
"short": "NULL_SHORT",
"char": "NULL_CHAR",
"int": "NULL_INT",
"long": "NULL_LONG",
"short": "NULL_SHORT",
"byte": "NULL_BYTE",
"float": "NULL_FLOAT",
"double": "NULL_DOUBLE",
}

_J_TYPE_NP_DTYPE_MAP = {
"double": "np.float64",
"float": "np.float32",
"byte": "np.int8",
"short": "np.int16",
"char": "np.uint16",
"int": "np.int32",
"long": "np.int64",
"short": "np.int16",
"byte": "np.int8",
"float": "np.float32",
"double": "np.float64",
}

_J_TYPE_J_ARRAY_TYPE_MAP = {
"double": double_array,
"float": float32_array,
"byte": int8_array,
"short": int16_array,
"char": char_array,
"int": int32_array,
"long": long_array,
"short": int16_array,
"byte": int8_array,
"float": float32_array,
"double": double_array,
}


Expand Down Expand Up @@ -254,7 +257,19 @@ def f51(col1, col2: Optional[np.ndarray[np.int32]]) -> bool:
with self.assertRaises(DHError) as cm:
t = t.update(["X1 = f51(X, null)"])

def test_str_bool_datetime(self):
t = empty_table(10).update(["X = i % 3", "Y = i"]).group_by("X")

def f6(*args: np.int32, col2: np.ndarray[np.int32]) -> bool:
return np.nanmean(col2) == np.mean(col2)
with self.assertRaises(DHError) as cm:
t1 = t.update(["X1 = f6(X, Y)"])
self.assertIn("missing 1 required keyword-only argument", str(cm.exception))

with self.assertRaises(DHError) as cm:
t1 = t.update(["X1 = f6(X, Y=null)"])
self.assertIn("not compatible with annotation", str(cm.exception))

def test_str_bool_datetime_array(self):
with self.subTest("str"):
def f1(p1: np.ndarray[str], p2=None) -> bool:
return bool(len(p1))
Expand Down Expand Up @@ -304,6 +319,54 @@ def f31(p1: Optional[np.ndarray[bool]], p2=None) -> bool:
t2 = t.update(["X1 = f31(null, Y)"])
self.assertEqual(3, t2.to_string("X1").count("false"))

def test_str_bool_datetime_scalar(self):
with self.subTest("str"):
def f1(p1: str, p2=None) -> bool:
return p1 is None

t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? `deephaven`: null"])
t1 = t.update(["X1 = f1(Y)"])
self.assertEqual(t1.columns[2].data_type, dtypes.bool_)
self.assertEqual(5, t1.to_string().count("true"))

def f11(p1: Union[str, None], p2=None) -> bool:
return p1 is None
t2 = t.update(["X1 = f11(Y)"])
self.assertEqual(5, t2.to_string().count("false"))

with self.subTest("datetime"):
def f2(p1: np.datetime64, p2=None) -> bool:
return p1 is None

t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? now() : null"])
t1 = t.update(["X1 = f2(Y)"])
self.assertEqual(t1.columns[2].data_type, dtypes.bool_)
self.assertEqual(5, t1.to_string().count("true"))

def f21(p1: Union[np.datetime64, None], p2=None) -> bool:
return p1 is None
t2 = t.update(["X1 = f21(Y)"])
self.assertEqual(5, t2.to_string().count("false"))

with self.subTest("boolean"):
def f3(p1: np.bool_, p2=None) -> bool:
return p1 is None

t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? true : null"])
t1 = t.update(["X1 = f3(Y)"])
self.assertEqual(t1.columns[2].data_type, dtypes.bool_)
self.assertEqual(5, t1.to_string("X1").count("false"))

t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? true : false"])
t1 = t.update(["X1 = f3(Y)"])
self.assertEqual(t1.columns[2].data_type, dtypes.bool_)
self.assertEqual(0, t1.to_string("X1").count("true"))

def f31(p1: Optional[np.bool_], p2=None) -> bool:
return p1 is None
t2 = t.update(["X1 = f31(null, Y)"])
self.assertEqual(10, t2.to_string("X1").count("true"))


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
dtypes.byte: "np.int8",
dtypes.bool_: "np.bool_",
dtypes.string: "np.str_",
# dtypes.char: "np.uint16",
dtypes.char: "np.uint16",
}


Expand Down Expand Up @@ -52,7 +52,7 @@ def test_array_return(self):
"np.float64": dtypes.double_array,
"bool": dtypes.boolean_array,
"np.str_": dtypes.string_array,
# "np.uint16": dtypes.char_array,
"np.uint16": dtypes.char_array,
}
container_types = ["List", "Tuple", "list", "tuple", "Sequence", "np.ndarray"]
for component_type, dh_dtype in component_types.items():
Expand Down

0 comments on commit e3bf58b

Please sign in to comment.