Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check arg type against Py UDF signature at query compile time #5254

Merged
merged 10 commits into from
Mar 25, 2024
chipkent marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ public final class QueryLanguageParser extends GenericVisitorAdapter<Class<?>, Q
private final Map<String, Class<?>> staticImportLookupCache = new HashMap<>();

// We need some class to represent null. We know for certain that this one won't be used...
private static final Class<?> NULL_CLASS = QueryLanguageParser.class;
public static final Class<?> NULL_CLASS = QueryLanguageParser.class;

/**
* The result of the QueryLanguageParser for the expression passed given to the constructor.
Expand Down Expand Up @@ -1939,7 +1939,7 @@ private static boolean isAssociativitySafeExpression(Expression expr) {
* @return {@code true} if a conversion from {@code original} to {@code target} is a widening conversion; otherwise,
* {@code false}.
*/
static boolean isWideningPrimitiveConversion(Class<?> original, Class<?> target) {
public static boolean isWideningPrimitiveConversion(Class<?> original, Class<?> target) {
if (original == null || !original.isPrimitive() || target == null || !target.isPrimitive()
|| original.equals(void.class) || target.equals(void.class)) {
throw new IllegalArgumentException("Arguments must be a primitive type (excluding void)!");
Expand Down Expand Up @@ -1968,6 +1968,7 @@ static boolean isWideningPrimitiveConversion(Class<?> original, Class<?> target)
return false;
}


private enum LanguageParserPrimitiveType {
// Including "Enum" (or really, any differentiating string) in these names is important. They're used
// in a switch() statement, which apparently does not support qualified names. And we can't use
Expand Down Expand Up @@ -2498,13 +2499,36 @@ public Class<?> visit(MethodCallExpr n, VisitArgs printer) {

// Attempt python function call vectorization.
if (scopeType != null && PyCallableWrapper.class.isAssignableFrom(scopeType)) {
verifyPyCallableArguments(n, argTypes);
tryVectorizePythonCallable(n, scopeType, convertedArgExpressions, argTypes);
}

return calculateMethodReturnTypeUsingGenerics(scopeType, n.getScope().orElse(null), method, expressionTypes,
typeArguments);
}

private void verifyPyCallableArguments(@NotNull MethodCallExpr n, @NotNull Class<?>[] argTypes) {
final String invokedMethodName = n.getNameAsString();

if (GET_ATTRIBUTE_METHOD_NAME.equals(invokedMethodName)) {
// Only PyCallableWrapper.getAttribute()/PyCallableWrapper.call() may be invoked from the query language.
// UDF type checks are not currently supported for getAttribute() calls.
return;
jmao-denver marked this conversation as resolved.
Show resolved Hide resolved
}
if (!n.containsData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS)) {
return;
}
final PyCallableDetails pyCallableDetails = n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS);
final String pyMethodName = pyCallableDetails.pythonMethodName;
final Object methodVar = queryScopeVariables.get(pyMethodName);
if (!(methodVar instanceof PyCallableWrapper)) {
return;
}
final PyCallableWrapper pyCallableWrapper = (PyCallableWrapper) methodVar;
pyCallableWrapper.parseSignature();
pyCallableWrapper.verifyArguments(argTypes);
}

private Optional<CastExpr> makeCastExpressionForPyCallable(Class<?> retType, MethodCallExpr callMethodCall) {
if (retType.isPrimitive()) {
return Optional.of(new CastExpr(
Expand Down Expand Up @@ -2552,7 +2576,7 @@ private Optional<Class<?>> pyCallableReturnType(@NotNull MethodCallExpr n) {
}
final PyCallableWrapper pyCallableWrapper = (PyCallableWrapper) paramValueRaw;
pyCallableWrapper.parseSignature();
return Optional.ofNullable(pyCallableWrapper.getReturnType());
return Optional.ofNullable(pyCallableWrapper.getSignature().getReturnType());
}

@NotNull
Expand Down Expand Up @@ -2683,7 +2707,8 @@ private void checkVectorizability(@NotNull final MethodCallExpr n,
pyCallableWrapper.parseSignature();
if (!pyCallableWrapper.isVectorizableReturnType()) {
throw new PythonCallVectorizationFailure(
"Python function return type is not supported: " + pyCallableWrapper.getReturnType());
"Python function return type is not supported: "
+ pyCallableWrapper.getSignature().getReturnType());
}

// Python vectorized functions(numba, DH) return arrays of primitive/Object types. This will break the generated
Expand Down Expand Up @@ -2726,11 +2751,10 @@ private void checkVectorizability(@NotNull final MethodCallExpr n,
}
}

List<Class<?>> paramTypes = pyCallableWrapper.getParamTypes();
if (paramTypes.size() != expressions.length) {
if (pyCallableWrapper.getSignature().getParameters().size() != expressions.length) {
// note vectorization doesn't handle Python variadic arguments
throw new PythonCallVectorizationFailure("Python function argument count mismatch: " + n + " "
+ paramTypes.size() + " vs. " + expressions.length);
+ pyCallableWrapper.getSignature().getParameters().size() + " vs. " + expressions.length);
}
}

Expand All @@ -2739,10 +2763,9 @@ private void prepareVectorizationArgs(
Expression[] expressions,
Class<?>[] argTypes,
PyCallableWrapper pyCallableWrapper) {
List<Class<?>> paramTypes = pyCallableWrapper.getParamTypes();
if (paramTypes.size() != expressions.length) {
if (pyCallableWrapper.getSignature().getParameters().size() != expressions.length) {
throw new PythonCallVectorizationFailure("Python function argument count mismatch: " + n + " "
+ paramTypes.size() + " vs. " + expressions.length);
+ pyCallableWrapper.getSignature().getParameters().size() + " vs. " + expressions.length);
}

pyCallableWrapper.initializeChunkArguments();
Expand All @@ -2763,11 +2786,6 @@ private void prepareVectorizationArgs(
} else {
throw new IllegalStateException("Vectorizability check failed: " + n);
}

if (!isSafelyCoerceable(argTypes[i], paramTypes.get(i))) {
throw new PythonCallVectorizationFailure("Python vectorized function argument type mismatch: " + n + " "
+ argTypes[i].getSimpleName() + " -> " + paramTypes.get(i).getSimpleName());
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ private void checkAndInitializeVectorization(QueryLanguageParser.Result result,
final PyCallableWrapperJpyImpl pyCallableWrapper = cws[0];

if (pyCallableWrapper.isVectorizable()) {
checkReturnType(result, pyCallableWrapper.getReturnType());
checkReturnType(result, pyCallableWrapper.getSignature().getReturnType());

for (String variable : result.getVariablesUsed()) {
if (variable.equals("i")) {
Expand All @@ -284,7 +284,8 @@ private void checkAndInitializeVectorization(QueryLanguageParser.Result result,
ArgumentsChunked argumentsChunked = pyCallableWrapper.buildArgumentsChunked(usedColumns);
PyObject vectorized = pyCallableWrapper.vectorizedCallable();
DeephavenCompatibleFunction dcf = DeephavenCompatibleFunction.create(vectorized,
pyCallableWrapper.getReturnType(), usedColumns.toArray(new String[0]), argumentsChunked, true);
pyCallableWrapper.getSignature().getReturnType(), usedColumns.toArray(new String[0]),
argumentsChunked, true);
setFilter(new ConditionFilter.ChunkFilter(
dcf.toFilterKernel(),
dcf.getColumnNames().toArray(new String[0]),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ private void checkAndInitializeVectorization(Map<String, ColumnDefinition<?>> co
PyObject vectorized = pyCallableWrapper.vectorizedCallable();
formulaColumnPython = FormulaColumnPython.create(this.columnName,
DeephavenCompatibleFunction.create(vectorized,
pyCallableWrapper.getReturnType(), this.analyzedFormula.sourceDescriptor.sources,
pyCallableWrapper.getSignature().getReturnType(),
this.analyzedFormula.sourceDescriptor.sources,
argumentsChunked,
true));
formulaColumnPython.initDef(columnDefinitionMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import org.jpy.PyObject;

import java.util.List;
import java.util.Set;

/**
* Created by rbasralian on 8/12/23
Expand All @@ -19,8 +20,6 @@ public interface PyCallableWrapper {

Object call(Object... args);

List<Class<?>> getParamTypes();

boolean isVectorized();

boolean isVectorizable();
Expand All @@ -31,7 +30,46 @@ public interface PyCallableWrapper {

void addChunkArgument(ChunkArgument chunkArgument);

Class<?> getReturnType();
Signature getSignature();

void verifyArguments(Class<?>[] argTypes);

class Parameter {
private final String name;
private final Set<Class<?>> possibleTypes;


public Parameter(String name, Set<Class<?>> possibleTypes) {
this.name = name;
this.possibleTypes = possibleTypes;
}

public Set<Class<?>> getPossibleTypes() {
return possibleTypes;
}

public String getName() {
return name;
}
}

class Signature {
private final List<Parameter> parameters;
private final Class<?> returnType;

public Signature(List<Parameter> parameters, Class<?> returnType) {
this.parameters = parameters;
this.returnType = returnType;
}

public List<Parameter> getParameters() {
return parameters;
}

public Class<?> getReturnType() {
return returnType;
}
}

abstract class ChunkArgument {
private final Class<?> type;
Expand Down Expand Up @@ -88,4 +126,5 @@ public Object getValue() {
}

boolean isVectorizableReturnType();

}
Loading
Loading