Skip to content

Commit

Permalink
Revert "auto convert Python return val to Java val (#4579)" (#4638)
Browse files Browse the repository at this point in the history
This reverts commit 77587f4.
  • Loading branch information
devinrsmith authored Oct 13, 2023
1 parent 9275306 commit b14b349
Show file tree
Hide file tree
Showing 11 changed files with 82 additions and 623 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.github.javaparser.ast.comments.BlockComment;
import com.github.javaparser.ast.comments.JavadocComment;
import com.github.javaparser.ast.comments.LineComment;
import com.github.javaparser.ast.comments.Comment;
import com.github.javaparser.ast.expr.ArrayAccessExpr;
import com.github.javaparser.ast.expr.ArrayCreationExpr;
import com.github.javaparser.ast.expr.ArrayInitializerExpr;
Expand Down Expand Up @@ -122,7 +123,6 @@
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.time.Instant;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -1214,7 +1214,7 @@ private Expression[] convertParameters(final Executable executable,
// as a single argument
if (ObjectVector.class.isAssignableFrom(expressionTypes[ei])) {
expressions[ei] = new CastExpr(
StaticJavaParser.parseClassOrInterfaceType("java.lang.Object"),
new ClassOrInterfaceType("java.lang.Object"),
expressions[ei]);
expressionTypes[ei] = Object.class;
} else {
Expand Down Expand Up @@ -2023,7 +2023,7 @@ public Class<?> visit(ConditionalExpr n, VisitArgs printer) {
if (classA == boolean.class && classB == Boolean.class) {
// a little hacky, but this handles the null case where it unboxes. very weird stuff
final Expression uncastExpr = n.getThenExpr();
final CastExpr castExpr = new CastExpr(StaticJavaParser.parseClassOrInterfaceType("Boolean"), uncastExpr);
final CastExpr castExpr = new CastExpr(new ClassOrInterfaceType(null, "Boolean"), uncastExpr);
n.setThenExpr(castExpr);
// fix parent in uncastExpr (it is cleared when it is replaced with the CastExpr)
uncastExpr.setParentNode(castExpr);
Expand All @@ -2032,7 +2032,7 @@ public Class<?> visit(ConditionalExpr n, VisitArgs printer) {
if (classA == Boolean.class && classB == boolean.class) {
// a little hacky, but this handles the null case where it unboxes. very weird stuff
final Expression uncastExpr = n.getElseExpr();
final CastExpr castExpr = new CastExpr(StaticJavaParser.parseClassOrInterfaceType("Boolean"), uncastExpr);
final CastExpr castExpr = new CastExpr(new ClassOrInterfaceType(null, "Boolean"), uncastExpr);
n.setElseExpr(castExpr);
// fix parent in uncastExpr (it is cleared when it is replaced with the CastExpr)
uncastExpr.setParentNode(castExpr);
Expand Down Expand Up @@ -2159,8 +2159,7 @@ public Class<?> visit(FieldAccessExpr n, VisitArgs printer) {
printer.append(", " + clsName + ".class");

final ClassExpr targetType =
new ClassExpr(
StaticJavaParser.parseClassOrInterfaceType(printer.pythonCastContext.getSimpleName()));
new ClassExpr(new ClassOrInterfaceType(null, printer.pythonCastContext.getSimpleName()));
getAttributeArgs.add(targetType);

// Let's advertise to the caller the cast context type
Expand Down Expand Up @@ -2338,35 +2337,6 @@ public Class<?> visit(MethodCallExpr n, VisitArgs printer) {
callMethodCall.setData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS,
new PyCallableDetails(null, methodName));

if (PyCallableWrapper.class.isAssignableFrom(method.getDeclaringClass())) {
final Optional<Class<?>> optionalRetType = pyCallableReturnType(callMethodCall);
if (optionalRetType.isPresent()) {
Class<?> retType = optionalRetType.get();
final Optional<CastExpr> optionalCastExpr =
makeCastExpressionForPyCallable(retType, callMethodCall);
if (optionalCastExpr.isPresent()) {
final CastExpr castExpr = optionalCastExpr.get();
replaceChildExpression(
n.getParentNode().orElseThrow(),
n,
castExpr);

callMethodCall.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS).setCasted(true);
try {
return castExpr.accept(this, printer);
} catch (Exception e) {
// exceptions could be thrown by {@link #tryVectorizePythonCallable}
replaceChildExpression(
castExpr.getParentNode().orElseThrow(),
castExpr,
callMethodCall);
callMethodCall.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS)
.setCasted(false);
return callMethodCall.accept(this, printer);
}
}
}
}
replaceChildExpression(
n.getParentNode().orElseThrow(),
n,
Expand Down Expand Up @@ -2406,7 +2376,7 @@ public Class<?> visit(MethodCallExpr n, VisitArgs printer) {

final ObjectCreationExpr newPyCallableExpr = new ObjectCreationExpr(
null,
StaticJavaParser.parseClassOrInterfaceType(pyCallableWrapperImplName),
new ClassOrInterfaceType(null, pyCallableWrapperImplName),
NodeList.nodeList(getAttributeCall));

final MethodCallExpr callMethodCall = new MethodCallExpr(
Expand Down Expand Up @@ -2460,60 +2430,6 @@ public Class<?> visit(MethodCallExpr n, VisitArgs printer) {
typeArguments);
}

private Optional<CastExpr> makeCastExpressionForPyCallable(Class<?> retType, MethodCallExpr callMethodCall) {
if (retType.isPrimitive()) {
return Optional.of(new CastExpr(
new PrimitiveType(PrimitiveType.Primitive
.valueOf(retType.getSimpleName().toUpperCase())),
callMethodCall));
} else if (retType.getComponentType() != null) {
final Class<?> componentType = retType.getComponentType();
if (componentType.isPrimitive()) {
ArrayType arrayType;
if (componentType == boolean.class) {
arrayType = new ArrayType(StaticJavaParser.parseClassOrInterfaceType("java.lang.Boolean"));
} else {
arrayType = new ArrayType(new PrimitiveType(PrimitiveType.Primitive
.valueOf(retType.getComponentType().getSimpleName().toUpperCase())));
}
return Optional.of(new CastExpr(arrayType, callMethodCall));
} else if (retType.getComponentType() == String.class || retType.getComponentType() == Boolean.class
|| retType.getComponentType() == Instant.class) {
ArrayType arrayType =
new ArrayType(
StaticJavaParser.parseClassOrInterfaceType(retType.getComponentType().getSimpleName()));
return Optional.of(new CastExpr(arrayType, callMethodCall));
}
} else if (retType == Boolean.class) {
return Optional
.of(new CastExpr(StaticJavaParser.parseClassOrInterfaceType("java.lang.Boolean"), callMethodCall));
} else if (retType == String.class) {
return Optional
.of(new CastExpr(StaticJavaParser.parseClassOrInterfaceType("java.lang.String"), callMethodCall));
} else if (retType == Instant.class) {
return Optional
.of(new CastExpr(StaticJavaParser.parseClassOrInterfaceType("java.time.Instant"), callMethodCall));
}

return Optional.empty();
}

private Optional<Class<?>> pyCallableReturnType(@NotNull MethodCallExpr n) {
final PyCallableDetails pyCallableDetails = n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS);
final String pyMethodName = pyCallableDetails.pythonMethodName;
final QueryScope queryScope = ExecutionContext.getContext().getQueryScope();
final Object paramValueRaw = queryScope.readParamValue(pyMethodName, null);
if (paramValueRaw == null) {
return Optional.empty();
}
if (!(paramValueRaw instanceof PyCallableWrapper)) {
return Optional.empty();
}
final PyCallableWrapper pyCallableWrapper = (PyCallableWrapper) paramValueRaw;
pyCallableWrapper.parseSignature();
return Optional.ofNullable(pyCallableWrapper.getReturnType());
}

@NotNull
private static Expression[] getExpressionsArray(final NodeList<Expression> exprNodeList) {
return exprNodeList == null ? new Expression[0]
Expand Down Expand Up @@ -2647,27 +2563,13 @@ private void checkVectorizability(@NotNull final MethodCallExpr n,
// expression evaluation code that expects singular values. This check makes sure that numba/dh vectorized
// functions must be used alone as the entire expression after removing the enclosing parentheses.
Node n1 = n;
boolean autoCastChecked = false;
while (n1.hasParentNode()) {
n1 = n1.getParentNode().orElseThrow();
Class<?> cls = n1.getClass();

if (cls == CastExpr.class) {
if (!autoCastChecked && n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS).isCasted()) {
autoCastChecked = true;
} else {
throw new PythonCallVectorizationFailure(
"The return values of Python vectorized functions can't be cast: " + n1);
}
} else if (cls == MethodCallExpr.class) {
String methodName = ((MethodCallExpr) n1).getNameAsString();
if (!autoCastChecked && n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS).isCasted()
&& methodName.endsWith("Cast")) {
autoCastChecked = true;
} else {
throw new PythonCallVectorizationFailure(
"Python vectorized function can't be used in another expression: " + n1);
}
throw new PythonCallVectorizationFailure(
"The return values of Python vectorized function can't be cast: " + n1);
} else if (cls != EnclosedExpr.class && cls != WrapperNode.class) {
throw new PythonCallVectorizationFailure(
"Python vectorized function can't be used in another expression: " + n1);
Expand Down Expand Up @@ -3331,17 +3233,6 @@ private static class PyCallableDetails {
@NotNull
private final String pythonMethodName;

@NotNull
private boolean isCasted = false;

public boolean isCasted() {
return isCasted;
}

public void setCasted(boolean casted) {
isCasted = casted;
}

private PyCallableDetails(@Nullable String pythonScopeExpr, @NotNull String pythonMethodName) {
this.pythonScopeExpr = pythonScopeExpr;
this.pythonMethodName = pythonMethodName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import org.jpy.PyModule;
import org.jpy.PyObject;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
Expand All @@ -21,7 +20,6 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper {
private static final Logger log = LoggerFactory.getLogger(PyCallableWrapperJpyImpl.class);

private static final PyObject NUMBA_VECTORIZED_FUNC_TYPE = getNumbaVectorizedFuncType();
private static final PyObject NUMBA_GUVECTORIZED_FUNC_TYPE = getNumbaGUVectorizedFuncType();

private static final PyModule dh_table_module = PyModule.importModule("deephaven.table");

Expand All @@ -36,7 +34,6 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper {
numpyType2JavaClass.put('b', byte.class);
numpyType2JavaClass.put('?', boolean.class);
numpyType2JavaClass.put('U', String.class);
numpyType2JavaClass.put('M', Instant.class);
numpyType2JavaClass.put('O', Object.class);
}

Expand All @@ -50,7 +47,6 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper {
private Collection<ChunkArgument> chunkArguments;
private boolean numbaVectorized;
private PyObject unwrapped;
private PyObject pyUdfDecoratedCallable;

public PyCallableWrapperJpyImpl(PyObject pyCallable) {
this.pyCallable = pyCallable;
Expand Down Expand Up @@ -95,37 +91,23 @@ private static PyObject getNumbaVectorizedFuncType() {
}
}

private static PyObject getNumbaGUVectorizedFuncType() {
try {
return PyModule.importModule("numba.np.ufunc.gufunc").getAttribute("GUFunc");
} catch (Exception e) {
if (log.isDebugEnabled()) {
log.debug("Numba isn't installed in the Python environment.");
}
return null;
}
}

private void prepareSignature() {
boolean isNumbaVectorized = pyCallable.getType().equals(NUMBA_VECTORIZED_FUNC_TYPE);
boolean isNumbaGUVectorized = pyCallable.equals(NUMBA_GUVECTORIZED_FUNC_TYPE);
if (isNumbaGUVectorized || isNumbaVectorized) {
if (pyCallable.getType().equals(NUMBA_VECTORIZED_FUNC_TYPE)) {
List<PyObject> params = pyCallable.getAttribute("types").asList();
if (params.isEmpty()) {
throw new IllegalArgumentException(
"numba vectorized/guvectorized function must have an explicit signature: " + pyCallable);
"numba vectorized function must have an explicit signature: " + pyCallable);
}
// numba allows a vectorized function to have multiple signatures
if (params.size() > 1) {
throw new UnsupportedOperationException(
pyCallable
+ " has multiple signatures; this is not currently supported for numba vectorized/guvectorized functions");
+ " has multiple signatures; this is not currently supported for numba vectorized functions");
}
signature = params.get(0).getStringValue();
unwrapped = pyCallable;
// since vectorization doesn't support array type parameters, don't flag numba guvectorized as vectorized
numbaVectorized = isNumbaVectorized;
vectorized = isNumbaVectorized;
unwrapped = null;
numbaVectorized = true;
vectorized = true;
} else if (pyCallable.hasAttribute("dh_vectorized")) {
signature = pyCallable.getAttribute("signature").toString();
unwrapped = pyCallable.getAttribute("callable");
Expand All @@ -137,7 +119,6 @@ private void prepareSignature() {
numbaVectorized = false;
vectorized = false;
}
pyUdfDecoratedCallable = dh_table_module.call("_py_udf", unwrapped);
}

@Override
Expand All @@ -154,6 +135,14 @@ public void parseSignature() {
throw new IllegalStateException("Signature should always be available.");
}

char numpyTypeCode = signature.charAt(signature.length() - 1);
Class<?> returnType = numpyType2JavaClass.get(numpyTypeCode);
if (returnType == null) {
throw new IllegalStateException(
"Vectorized functions should always have an integral, floating point, boolean, String, or Object return type: "
+ numpyTypeCode);
}

List<Class<?>> paramTypes = new ArrayList<>();
for (char numpyTypeChar : signature.toCharArray()) {
if (numpyTypeChar != '-') {
Expand All @@ -170,31 +159,25 @@ public void parseSignature() {
}

this.paramTypes = paramTypes;

returnType = pyUdfDecoratedCallable.getAttribute("return_type", null);
if (returnType == null) {
throw new IllegalStateException(
"Python functions should always have an integral, floating point, boolean, String, arrays, or Object return type");
}

if (returnType == boolean.class) {
if (returnType == Object.class) {
this.returnType = PyObject.class;
} else if (returnType == boolean.class) {
this.returnType = Boolean.class;
} else {
this.returnType = returnType;
}
}

// In vectorized mode, we want to call the vectorized function directly.
public PyObject vectorizedCallable() {
if (numbaVectorized || vectorized) {
if (numbaVectorized) {
return pyCallable;
} else {
return dh_table_module.call("dh_vectorize", unwrapped);
}
}

// In non-vectorized mode, we want to call the udf decorated function or the original function.
@Override
public Object call(Object... args) {
PyObject pyCallable = this.pyUdfDecoratedCallable != null ? this.pyUdfDecoratedCallable : this.pyCallable;
return PythonScopeJpyImpl.convert(pyCallable.callMethod("__call__", args));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,6 @@ public void addChunkArgument(ChunkArgument ignored) {}

@Override
public Class<?> getReturnType() {
return Object.class;
throw new UnsupportedOperationException();
}
}
28 changes: 25 additions & 3 deletions py/server/deephaven/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
from typing import Sequence, Any

import jpy
import numpy as np
import pandas as pd

import deephaven.dtypes as dtypes
from deephaven import DHError
from deephaven import DHError, time
from deephaven.dtypes import DType
from deephaven.dtypes import _instant_array
from deephaven.time import to_j_instant

_JColumnHeader = jpy.get_type("io.deephaven.qst.column.header.ColumnHeader")
_JColumn = jpy.get_type("io.deephaven.qst.column.Column")
Expand Down Expand Up @@ -204,7 +206,27 @@ def datetime_col(name: str, data: Sequence) -> InputColumn:
Returns:
a new input column
"""
data = _instant_array(data)

# try to convert to numpy array of datetime64 if not already, so that we can call translateArrayLongToInstant on
# it to reduce the number of round trips to the JVM
if not isinstance(data, np.ndarray):
try:
data = np.array([pd.Timestamp(dt).to_numpy() for dt in data], dtype=np.datetime64)
except Exception as e:
...

if isinstance(data, np.ndarray) and data.dtype.kind in ('M', 'i', 'U'):
if data.dtype.kind == 'M':
longs = jpy.array('long', data.astype('datetime64[ns]').astype('int64'))
elif data.dtype.kind == 'i':
longs = jpy.array('long', data.astype('int64'))
else: # data.dtype.kind == 'U'
longs = jpy.array('long', [pd.Timestamp(str(dt)).to_numpy().astype('int64') for dt in data])
data = _JPrimitiveArrayConversionUtility.translateArrayLongToInstant(longs)

if not isinstance(data, dtypes.instant_array.j_type):
data = [to_j_instant(d) for d in data]

return InputColumn(name=name, data_type=dtypes.Instant, input_data=data)


Expand Down
Loading

0 comments on commit b14b349

Please sign in to comment.