Skip to content

Commit

Permalink
Fix input table jpy restriction (#5260)
Browse files Browse the repository at this point in the history
* Fix input table jpy restriction

* Remove unnecessary jpy type

* Wrap java object with the most specific wrapper

* Fix docstring format issue

* Apply suggestions from code review

Co-authored-by: Chip Kent <[email protected]>

* Respond to review comments

* Tiny coding formatting issue

* Update py/server/deephaven/_wrapper.py

Co-authored-by: Chip Kent <[email protected]>

* Add one more test case

* Empty Commit to force rerun of CI

* Respond to latest review comments

---------

Co-authored-by: jianfengmao <[email protected]>
Co-authored-by: Jianfeng Mao <[email protected]>
Co-authored-by: Chip Kent <[email protected]>
  • Loading branch information
4 people authored Mar 21, 2024
1 parent dd6ace7 commit c301ea6
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 18 deletions.
51 changes: 37 additions & 14 deletions py/server/deephaven/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pkgutil
import sys
from abc import ABC, abstractmethod
from typing import Set, Union, Optional, Any
from typing import Set, Union, Optional, Any, List

import jpy

Expand Down Expand Up @@ -102,21 +102,16 @@ def _is_direct_initialisable(cls) -> bool:
return False


def _lookup_wrapped_class(j_obj: jpy.JType) -> Optional[type]:
""" Returns the wrapper class for the specified Java object. """
def _lookup_wrapped_class(j_obj: jpy.JType) -> List[JObjectWrapper]:
""" Returns the wrapper classes for the specified Java object. """
# load every module in the deephaven package so that all the wrapper classes are loaded and available to wrap
# the Java objects returned by calling resolve()
global _has_all_wrappers_imported
if not _has_all_wrappers_imported:
_recursive_import(__package__.partition(".")[0])
_has_all_wrappers_imported = True

for wc in _di_wrapper_classes:
j_clz = wc.j_object_type
if j_clz.jclass.isInstance(j_obj):
return wc

return None
return [wc for wc in _di_wrapper_classes if wc.j_object_type.jclass.isInstance(j_obj)]


def javaify(obj: Any) -> Optional[jpy.JType]:
Expand Down Expand Up @@ -162,15 +157,43 @@ def pythonify(j_obj: Any) -> Optional[Any]:
return wrap_j_object(j_obj)


def wrap_j_object(j_obj: jpy.JType) -> Union[JObjectWrapper, jpy.JType]:
""" Wraps the specified Java object as an instance of a custom wrapper class if one is available, otherwise returns
the raw Java object. """
def _wrap_with_subclass(j_obj: jpy.JType, cls: type) -> Optional[JObjectWrapper]:
""" Returns a wrapper instance for the specified Java object by trying the entire subclasses' hierarchy. The
function employs a Depth First Search strategy to try the most specific subclass first. If no matching wrapper class is found,
returns None.
The premises for this function are as follows:
- The subclasses all share the same class attribute `j_object_type` (guaranteed by subclassing JObjectWrapper)
- The subclasses are all direct initialisable (guaranteed by subclassing JObjectWrapper)
- The subclasses are all distinct from each other and check for their uniqueness in the initializer (__init__), e.g.
InputTable checks for the presence of the INPUT_TABLE_ATTRIBUTE attribute on the Java object.
"""
for subclass in cls.__subclasses__():
try:
if (wrapper := _wrap_with_subclass(j_obj, subclass)) is not None:
return wrapper
return subclass(j_obj)
except:
continue
return None


def wrap_j_object(j_obj: jpy.JType) -> Optional[Union[JObjectWrapper, jpy.JType]]:
""" Wraps the specified Java object as an instance of the most specific custom wrapper class if one is available,
otherwise returns the raw Java object. """
if j_obj is None:
return None

wc = _lookup_wrapped_class(j_obj)
wcs = _lookup_wrapped_class(j_obj)
for wc in wcs:
try:
if (wrapper:= _wrap_with_subclass(j_obj, wc)) is not None:
return wrapper
return wc(j_obj)
except:
continue

return wc(j_obj) if wc else j_obj
return j_obj


def unwrap(obj: Any) -> Union[jpy.JType, Any]:
Expand Down
5 changes: 3 additions & 2 deletions py/server/deephaven/table_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
_JTableFactory = jpy.get_type("io.deephaven.engine.table.TableFactory")
_JTableTools = jpy.get_type("io.deephaven.engine.util.TableTools")
_JDynamicTableWriter = jpy.get_type("io.deephaven.engine.table.impl.util.DynamicTableWriter")
_JBaseArrayBackedInputTable = jpy.get_type("io.deephaven.engine.table.impl.util.BaseArrayBackedInputTable")
_JAppendOnlyArrayBackedInputTable = jpy.get_type(
"io.deephaven.engine.table.impl.util.AppendOnlyArrayBackedInputTable")
_JKeyedArrayBackedInputTable = jpy.get_type("io.deephaven.engine.table.impl.util.KeyedArrayBackedInputTable")
_JTableDefinition = jpy.get_type("io.deephaven.engine.table.TableDefinition")
_JTable = jpy.get_type("io.deephaven.engine.table.Table")
_J_INPUT_TABLE_ATTRIBUTE = _JTable.INPUT_TABLE_ATTRIBUTE
_J_InputTableUpdater = jpy.get_type("io.deephaven.engine.util.input.InputTableUpdater")
_JRingTableTools = jpy.get_type("io.deephaven.engine.table.impl.sources.ring.RingTableTools")
_JSupplier = jpy.get_type('java.util.function.Supplier')
_JFunctionGeneratedTableFactory = jpy.get_type("io.deephaven.engine.table.impl.util.FunctionGeneratedTableFactory")
Expand Down Expand Up @@ -235,13 +235,14 @@ class InputTable(Table):
Users should always create InputTables through factory methods rather than directly from the constructor.
"""
j_object_type = _JBaseArrayBackedInputTable

def __init__(self, j_table: jpy.JType):
super().__init__(j_table)
self.j_input_table = self.j_table.getAttribute(_J_INPUT_TABLE_ATTRIBUTE)
if not self.j_input_table:
raise DHError("the provided table input is not suitable for input tables.")
if not _J_InputTableUpdater.jclass.isInstance(self.j_input_table):
raise DHError("the provided table's InputTable attribute type is not of InputTableUpdater type.")

def add(self, table: Table) -> None:
"""Synchronously writes rows from the provided table to this input table. If this is a keyed input table, added rows with keys
Expand Down
32 changes: 30 additions & 2 deletions py/server/tests/test_table_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,19 @@
import numpy as np

from deephaven import DHError, read_csv, time_table, empty_table, merge, merge_sorted, dtypes, new_table, \
input_table, time
input_table, time, _wrapper
from deephaven.column import byte_col, char_col, short_col, bool_col, int_col, long_col, float_col, double_col, \
string_col, datetime_col, pyobj_col, jobj_col
from deephaven.constants import NULL_DOUBLE, NULL_FLOAT, NULL_LONG, NULL_INT, NULL_SHORT, NULL_BYTE
from deephaven.table_factory import DynamicTableWriter, ring_table
from deephaven.table_factory import DynamicTableWriter, InputTable, ring_table
from tests.testbase import BaseTestCase
from deephaven.table import Table
from deephaven.stream import blink_to_append_only, stream_to_append_only

JArrayList = jpy.get_type("java.util.ArrayList")
_JBlinkTableTools = jpy.get_type("io.deephaven.engine.table.impl.BlinkTableTools")
_JDateTimeUtils = jpy.get_type("io.deephaven.time.DateTimeUtils")
_JTable = jpy.get_type("io.deephaven.engine.table.Table")


@dataclass
Expand Down Expand Up @@ -372,6 +373,16 @@ def test_input_table(self):
keyed_input_table.delete(t.select(["String", "Double"]))
self.assertEqual(keyed_input_table.size, 0)

with self.subTest("custom input table creation"):
place_holder_input_table = empty_table(1).update_view(["Key=`A`", "Value=10"]).with_attributes({_JTable.INPUT_TABLE_ATTRIBUTE: "Placeholder IT"}).j_table

with self.assertRaises(DHError) as cm:
InputTable(place_holder_input_table)
self.assertIn("not of InputTableUpdater type", str(cm.exception))

self.assertTrue(isinstance(_wrapper.wrap_j_object(place_holder_input_table), Table))


def test_ring_table(self):
cols = [
bool_col(name="Boolean", data=[True, False]),
Expand Down Expand Up @@ -450,5 +461,22 @@ def test_input_table_empty_data(self):
it.delete(t)
self.assertEqual(it.size, 0)

def test_j_input_wrapping(self):
cols = [
bool_col(name="Boolean", data=[True, False]),
string_col(name="String", data=["foo", "bar"]),
]
t = new_table(cols=cols)
col_defs = {c.name: c.data_type for c in t.columns}
append_only_input_table = input_table(col_defs=col_defs)

it = _wrapper.wrap_j_object(append_only_input_table.j_table)
self.assertTrue(isinstance(it, InputTable))

t = _wrapper.wrap_j_object(t.j_object)
self.assertFalse(isinstance(t, InputTable))
self.assertTrue(isinstance(t, Table))


if __name__ == '__main__':
unittest.main()

0 comments on commit c301ea6

Please sign in to comment.