Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix input table jpy restriction #5260

Merged
merged 12 commits into from
Mar 21, 2024
51 changes: 37 additions & 14 deletions py/server/deephaven/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pkgutil
import sys
from abc import ABC, abstractmethod
from typing import Set, Union, Optional, Any
from typing import Set, Union, Optional, Any, List

import jpy

Expand Down Expand Up @@ -102,21 +102,16 @@ def _is_direct_initialisable(cls) -> bool:
return False


def _lookup_wrapped_class(j_obj: jpy.JType) -> Optional[type]:
""" Returns the wrapper class for the specified Java object. """
def _lookup_wrapped_class(j_obj: jpy.JType) -> List[JObjectWrapper]:
""" Returns the wrapper classes for the specified Java object. """
# load every module in the deephaven package so that all the wrapper classes are loaded and available to wrap
# the Java objects returned by calling resolve()
global _has_all_wrappers_imported
if not _has_all_wrappers_imported:
_recursive_import(__package__.partition(".")[0])
_has_all_wrappers_imported = True

for wc in _di_wrapper_classes:
j_clz = wc.j_object_type
if j_clz.jclass.isInstance(j_obj):
return wc

return None
return [wc for wc in _di_wrapper_classes if wc.j_object_type.jclass.isInstance(j_obj)]


def javaify(obj: Any) -> Optional[jpy.JType]:
Expand Down Expand Up @@ -162,15 +157,43 @@ def pythonify(j_obj: Any) -> Optional[Any]:
return wrap_j_object(j_obj)


def wrap_j_object(j_obj: jpy.JType) -> Union[JObjectWrapper, jpy.JType]:
""" Wraps the specified Java object as an instance of a custom wrapper class if one is available, otherwise returns
the raw Java object. """
def _wrap_with_subclass(j_obj: jpy.JType, cls: type) -> Optional[JObjectWrapper]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Presumably, this is slower than the existing behavior. The only place I can imagine where that might matter is something like a PartitionedTable.transform with a Python function. Does the Java constituent Table get wrapped automatically as a Python Table wrapper, and if so , is it slow enough to matter?

@jmao-denver you might need to test this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, they are not. When requested (e.g. via. constituent_tables()), the constituent tables are explicitly wrapped in Table, not through this auto wrapping facility.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that addresses my concern. I am talking about when you pass a Python function (adapted to Java) to transform.
deephaven.table.PartitionedTable.transform

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some quick/minimal testing, ~2 - 4 % slower, the more constituent tables, the slower, which is kinda expected (more wrapping):

def transform_func(t: Table) -> Table:
    return t.update("f = a + b")

t = empty_table(100_000_000).update(["a = i", "b = i + 1", "c = i % 10_000", "d = `text`", "e = i % 1000000"])
pt = t.partition_by(by=["c"])
print("num of constituent tables: ", len(pt.constituent_tables))
print("constituent table size: ", pt.constituent_tables[0].size)
with make_user_exec_ctx():
    st = time.process_time_ns()
    transformed_pt = pt.transform(transform_func)
    print("transform time: ", (time.process_time_ns() - st)/10**9)
    self.assertIsNotNone(transformed_pt)

Before:

num of constituent tables:  10000
constituent table size:  10000
transform time:  42.321798386

num of constituent tables:  1000
constituent table size:  100000
transform time:  20.706249172

After:

num of constituent tables:  10000
constituent table size:  10000
transform time:  44.08636249

num of constituent tables:  1000
constituent table size:  100000
transform time:  21.192759225

""" Returns a wrapper instance for the specified Java object by trying the entire subclasses' hierarchy. The
function employs a Depth First Search strategy to try the most specific subclass first. If no matching wrapper class is found,
returns None.

The premises for this function are as follows:
- The subclasses all share the same class attribute `j_object_type` (guaranteed by subclassing JObjectWrapper)
- The subclasses are all direct initialisable (guaranteed by subclassing JObjectWrapper)
- The subclasses are all distinct from each other and check for their uniqueness in the initializer (__init__), e.g.
InputTable checks for the presence of the INPUT_TABLE_ATTRIBUTE attribute on the Java object.
"""
for subclass in cls.__subclasses__():
try:
if (wrapper := _wrap_with_subclass(j_obj, subclass)) is not None:
return wrapper
return subclass(j_obj)
except:
continue
return None


def wrap_j_object(j_obj: jpy.JType) -> Optional[Union[JObjectWrapper, jpy.JType]]:
""" Wraps the specified Java object as an instance of the most specific custom wrapper class if one is available,
otherwise returns the raw Java object. """
chipkent marked this conversation as resolved.
Show resolved Hide resolved
if j_obj is None:
return None

wc = _lookup_wrapped_class(j_obj)
wcs = _lookup_wrapped_class(j_obj)
for wc in wcs:
chipkent marked this conversation as resolved.
Show resolved Hide resolved
try:
if (wrapper:= _wrap_with_subclass(j_obj, wc)) is not None:
return wrapper
return wc(j_obj)
chipkent marked this conversation as resolved.
Show resolved Hide resolved
except:
continue

return wc(j_obj) if wc else j_obj
return j_obj


def unwrap(obj: Any) -> Union[jpy.JType, Any]:
Expand Down
2 changes: 0 additions & 2 deletions py/server/deephaven/table_factory.py
chipkent marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
_JTableFactory = jpy.get_type("io.deephaven.engine.table.TableFactory")
_JTableTools = jpy.get_type("io.deephaven.engine.util.TableTools")
_JDynamicTableWriter = jpy.get_type("io.deephaven.engine.table.impl.util.DynamicTableWriter")
_JBaseArrayBackedInputTable = jpy.get_type("io.deephaven.engine.table.impl.util.BaseArrayBackedInputTable")
_JAppendOnlyArrayBackedInputTable = jpy.get_type(
"io.deephaven.engine.table.impl.util.AppendOnlyArrayBackedInputTable")
_JKeyedArrayBackedInputTable = jpy.get_type("io.deephaven.engine.table.impl.util.KeyedArrayBackedInputTable")
Expand Down Expand Up @@ -235,7 +234,6 @@ class InputTable(Table):

Users should always create InputTables through factory methods rather than directly from the constructor.
"""
j_object_type = _JBaseArrayBackedInputTable

def __init__(self, j_table: jpy.JType):
super().__init__(j_table)
jmao-denver marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
26 changes: 24 additions & 2 deletions py/server/tests/test_table_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,19 @@
import numpy as np

from deephaven import DHError, read_csv, time_table, empty_table, merge, merge_sorted, dtypes, new_table, \
input_table, time
input_table, time, _wrapper
from deephaven.column import byte_col, char_col, short_col, bool_col, int_col, long_col, float_col, double_col, \
string_col, datetime_col, pyobj_col, jobj_col
from deephaven.constants import NULL_DOUBLE, NULL_FLOAT, NULL_LONG, NULL_INT, NULL_SHORT, NULL_BYTE
from deephaven.table_factory import DynamicTableWriter, ring_table
from deephaven.table_factory import DynamicTableWriter, InputTable, ring_table
from tests.testbase import BaseTestCase
from deephaven.table import Table
from deephaven.stream import blink_to_append_only, stream_to_append_only

JArrayList = jpy.get_type("java.util.ArrayList")
_JBlinkTableTools = jpy.get_type("io.deephaven.engine.table.impl.BlinkTableTools")
_JDateTimeUtils = jpy.get_type("io.deephaven.time.DateTimeUtils")
_JTable = jpy.get_type("io.deephaven.engine.table.Table")


@dataclass
Expand Down Expand Up @@ -372,6 +373,14 @@ def test_input_table(self):
keyed_input_table.delete(t.select(["String", "Double"]))
self.assertEqual(keyed_input_table.size, 0)

with self.subTest("custom input table creation"):
place_holder_input_table = empty_table(1).update_view(["Key=`A`", "Value=10"]).with_attributes({_JTable.INPUT_TABLE_ATTRIBUTE: "Placeholder IT"}).j_table
# Confirming no error.
it = InputTable(place_holder_input_table)

self.assertTrue(isinstance(_wrapper.wrap_j_object(place_holder_input_table), InputTable))


def test_ring_table(self):
cols = [
bool_col(name="Boolean", data=[True, False]),
Expand Down Expand Up @@ -450,5 +459,18 @@ def test_input_table_empty_data(self):
it.delete(t)
self.assertEqual(it.size, 0)

def test_j_input_wrapping(self):
jmao-denver marked this conversation as resolved.
Show resolved Hide resolved
cols = [
bool_col(name="Boolean", data=[True, False]),
string_col(name="String", data=["foo", "bar"]),
]
t = new_table(cols=cols)
jmao-denver marked this conversation as resolved.
Show resolved Hide resolved
col_defs = {c.name: c.data_type for c in t.columns}
append_only_input_table = input_table(col_defs=col_defs)
chipkent marked this conversation as resolved.
Show resolved Hide resolved

t = _wrapper.wrap_j_object(append_only_input_table.j_table)
self.assertTrue(isinstance(t, InputTable))


if __name__ == '__main__':
unittest.main()
Loading