From e07f2f0f4fb430b0cb7b1d689972bb1a81844734 Mon Sep 17 00:00:00 2001 From: Nathaniel Bauernfeind Date: Mon, 24 Jun 2024 16:58:24 -0600 Subject: [PATCH] Fix type coercion from one numeric to another --- .../engine/table/impl/select/MatchFilter.java | 95 ++++++++++- .../table/impl/QueryTableWhereTest.java | 147 ++++++++++++++++++ 2 files changed, 241 insertions(+), 1 deletion(-) diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/select/MatchFilter.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/select/MatchFilter.java index 2568b9d7d10..7d0f17a297b 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/select/MatchFilter.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/select/MatchFilter.java @@ -15,6 +15,7 @@ import io.deephaven.engine.table.Table; import io.deephaven.engine.table.TableDefinition; import io.deephaven.engine.table.impl.QueryCompilerRequestProcessor; +import io.deephaven.engine.table.impl.lang.QueryLanguageFunctionUtils; import io.deephaven.engine.table.impl.preview.DisplayWrapper; import io.deephaven.engine.table.impl.DependencyStreamProvider; import io.deephaven.engine.table.impl.indexer.DataIndexer; @@ -24,6 +25,7 @@ import io.deephaven.util.SafeCloseable; import io.deephaven.util.datastructures.CachingSupplier; import io.deephaven.util.type.ArrayTypeUtils; +import io.deephaven.util.type.TypeUtils; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; import org.jpy.PyObject; @@ -344,7 +346,14 @@ final boolean convertValue( if (tableDefinition.getColumn(strValue) != null) { // this is also a column name which needs to take precedence, and we can't convert it throw new IllegalArgumentException(String.format( - "Failed to convert literal value <%s> for column \"%s\" of type %s; it is a column name", + "Failed to convert value <%s> for column \"%s\" of type %s; it is a column name", + strValue, column.getName(), column.getDataType().getName())); + } + if (strValue.endsWith("_") + && tableDefinition.getColumn(strValue.substring(0, strValue.length() - 1)) != null) { + // this also a column array name which needs to take precedence, and we can't convert it + throw new IllegalArgumentException(String.format( + "Failed to convert value <%s> for column \"%s\" of type %s; it is a column array access name", strValue, column.getName(), column.getDataType().getName())); } @@ -390,6 +399,18 @@ Object convertStringLiteral(String str) { } return Byte.parseByte(str); } + + @Override + Object convertParamValue(Object paramValue) { + paramValue = super.convertParamValue(paramValue); + if (paramValue instanceof Byte) { + return paramValue; + } + // noinspection unchecked + final TypeUtils.TypeBoxer boxer = + (TypeUtils.TypeBoxer) TypeUtils.getTypeBoxer(paramValue.getClass()); + return QueryLanguageFunctionUtils.byteCast(boxer.get(paramValue)); + } }; } if (cls == short.class) { @@ -401,6 +422,18 @@ Object convertStringLiteral(String str) { } return Short.parseShort(str); } + + @Override + Object convertParamValue(Object paramValue) { + paramValue = super.convertParamValue(paramValue); + if (paramValue instanceof Short) { + return paramValue; + } + // noinspection unchecked + final TypeUtils.TypeBoxer boxer = + (TypeUtils.TypeBoxer) TypeUtils.getTypeBoxer(paramValue.getClass()); + return QueryLanguageFunctionUtils.shortCast(boxer.get(paramValue)); + } }; } if (cls == int.class) { @@ -412,6 +445,18 @@ Object convertStringLiteral(String str) { } return Integer.parseInt(str); } + + @Override + Object convertParamValue(Object paramValue) { + paramValue = super.convertParamValue(paramValue); + if (paramValue instanceof Integer) { + return paramValue; + } + // noinspection unchecked + final TypeUtils.TypeBoxer boxer = + (TypeUtils.TypeBoxer) TypeUtils.getTypeBoxer(paramValue.getClass()); + return QueryLanguageFunctionUtils.intCast(boxer.get(paramValue)); + } }; } if (cls == long.class) { @@ -423,6 +468,18 @@ Object convertStringLiteral(String str) { } return Long.parseLong(str); } + + @Override + Object convertParamValue(Object paramValue) { + paramValue = super.convertParamValue(paramValue); + if (paramValue instanceof Long) { + return paramValue; + } + // noinspection unchecked + final TypeUtils.TypeBoxer boxer = + (TypeUtils.TypeBoxer) TypeUtils.getTypeBoxer(paramValue.getClass()); + return QueryLanguageFunctionUtils.longCast(boxer.get(paramValue)); + } }; } if (cls == float.class) { @@ -434,6 +491,18 @@ Object convertStringLiteral(String str) { } return Float.parseFloat(str); } + + @Override + Object convertParamValue(Object paramValue) { + paramValue = super.convertParamValue(paramValue); + if (paramValue instanceof Float) { + return paramValue; + } + // noinspection unchecked + final TypeUtils.TypeBoxer boxer = + (TypeUtils.TypeBoxer) TypeUtils.getTypeBoxer(paramValue.getClass()); + return QueryLanguageFunctionUtils.floatCast(boxer.get(paramValue)); + } }; } if (cls == double.class) { @@ -445,6 +514,18 @@ Object convertStringLiteral(String str) { } return Double.parseDouble(str); } + + @Override + Object convertParamValue(Object paramValue) { + paramValue = super.convertParamValue(paramValue); + if (paramValue instanceof Double) { + return paramValue; + } + // noinspection unchecked + final TypeUtils.TypeBoxer boxer = + (TypeUtils.TypeBoxer) TypeUtils.getTypeBoxer(paramValue.getClass()); + return QueryLanguageFunctionUtils.doubleCast(boxer.get(paramValue)); + } }; } if (cls == Boolean.class) { @@ -485,6 +566,18 @@ Object convertStringLiteral(String str) { } return str.charAt(0); } + + @Override + Object convertParamValue(Object paramValue) { + paramValue = super.convertParamValue(paramValue); + if (paramValue instanceof Character) { + return paramValue; + } + // noinspection unchecked + final TypeUtils.TypeBoxer boxer = + (TypeUtils.TypeBoxer) TypeUtils.getTypeBoxer(paramValue.getClass()); + return QueryLanguageFunctionUtils.charCast(boxer.get(paramValue)); + } }; } if (cls == BigDecimal.class) { diff --git a/engine/table/src/test/java/io/deephaven/engine/table/impl/QueryTableWhereTest.java b/engine/table/src/test/java/io/deephaven/engine/table/impl/QueryTableWhereTest.java index e800401c839..599df3684b2 100644 --- a/engine/table/src/test/java/io/deephaven/engine/table/impl/QueryTableWhereTest.java +++ b/engine/table/src/test/java/io/deephaven/engine/table/impl/QueryTableWhereTest.java @@ -43,6 +43,7 @@ import junit.framework.TestCase; import org.apache.commons.lang3.mutable.MutableBoolean; import org.apache.commons.lang3.mutable.MutableObject; +import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; @@ -1404,4 +1405,150 @@ public void testRangeFilterFallback() { final WhereFilter realFilter = filter.getRealFilter(); Assert.eqTrue(realFilter instanceof ConditionFilter, "realFilter instanceof ConditionFilter"); } + + @Test + public void testEnsureColumnsTakePrecedence() { + final Table table = emptyTable(10).update("X=i", "Y=i%2"); + ExecutionContext.getContext().getQueryScope().putParam("Y", 5); + + { + final Table r1 = table.where("X == Y"); + final Table r2 = table.where("Y == X"); + Assert.equals(r1.getRowSet(), "r1.getRowSet()", RowSetFactory.flat(2)); + assertTableEquals(r1, r2); + } + + { + final Table r1 = table.where("X >= Y"); + final Table r2 = table.where("Y <= X"); + Assert.equals(r1.getRowSet(), "r1.getRowSet()", RowSetFactory.flat(10)); + assertTableEquals(r1, r2); + } + + { + final Table r1 = table.where("X > Y"); + final Table r2 = table.where("Y < X"); + Assert.equals(r1.getRowSet(), "r1.getRowSet()", RowSetFactory.fromRange(2, 9)); + assertTableEquals(r1, r2); + } + + { + final Table r1 = table.where("X < Y"); + final Table r2 = table.where("Y > X"); + Assert.equals(r1.getRowSet(), "r1.getRowSet()", RowSetFactory.empty()); + assertTableEquals(r1, r2); + } + + { + final Table r1 = table.where("X <= Y"); + final Table r2 = table.where("Y >= X"); + Assert.equals(r1.getRowSet(), "r1.getRowSet()", RowSetFactory.flat(2)); + assertTableEquals(r1, r2); + } + } + + @Test + @Ignore + public void testEnsureColumnArraysTakePrecedence() { + // TODO: column arrays aren't well supported in match arrays and this example's where filter fails to compile + final Table table = emptyTable(10).update("X=i", "Y=new int[]{1, 5, 9}"); + ExecutionContext.getContext().getQueryScope().putParam("Y_", new int[] {0, 4, 8}); + + final Table result = table.where("X == Y_[1]"); + Assert.equals(result.getRowSet(), "result.getRowSet()", RowSetFactory.fromKeys(5)); + + // check that the mirror matches the expected result + final Table mResult = table.where("Y_[1] == X"); + assertTableEquals(result, mResult); + } + + @Test + public void testIntToByteCoercion() { + final Table table = emptyTable(11).update("X = ii % 2 == 0 ? (byte) ii : null"); + final Class colType = table.getDefinition().getColumn("X").getDataType(); + Assert.eq(colType, "colType", byte.class); + + ExecutionContext.getContext().getQueryScope().putParam("val_null", QueryConstants.NULL_INT); + ExecutionContext.getContext().getQueryScope().putParam("val_5", 5); + + final Table null_result = table.where("X == val_null"); + final Table range_result = table.where("X >= val_5"); + Assert.eq(null_result.size(), "null_result.size()", 5); + Assert.eq(range_result.size(), "range_result.size()", 3); + } + + @Test + public void testIntToShortCoercion() { + final Table table = emptyTable(11).update("X= ii % 2 == 0 ? (short) ii : null"); + final Class colType = table.getDefinition().getColumn("X").getDataType(); + Assert.eq(colType, "colType", short.class); + + ExecutionContext.getContext().getQueryScope().putParam("val_null", QueryConstants.NULL_INT); + ExecutionContext.getContext().getQueryScope().putParam("val_5", 5); + + final Table null_result = table.where("X == val_null"); + final Table range_result = table.where("X >= val_5"); + Assert.eq(null_result.size(), "null_result.size()", 5); + Assert.eq(range_result.size(), "range_result.size()", 3); + } + + @Test + public void testLongToIntCoercion() { + final Table table = emptyTable(11).update("X= ii % 2 == 0 ? (int) ii : null"); + final Class colType = table.getDefinition().getColumn("X").getDataType(); + Assert.eq(colType, "colType", int.class); + + ExecutionContext.getContext().getQueryScope().putParam("val_null", QueryConstants.NULL_LONG); + ExecutionContext.getContext().getQueryScope().putParam("val_5", 5L); + + final Table null_result = table.where("X == val_null"); + final Table range_result = table.where("X >= val_5"); + Assert.eq(null_result.size(), "null_result.size()", 5); + Assert.eq(range_result.size(), "range_result.size()", 3); + } + + @Test + public void testIntToLongCoercion() { + final Table table = emptyTable(11).update("X= ii % 2 == 0 ? ii : null"); + final Class colType = table.getDefinition().getColumn("X").getDataType(); + Assert.eq(colType, "colType", long.class); + + ExecutionContext.getContext().getQueryScope().putParam("val_null", QueryConstants.NULL_INT); + ExecutionContext.getContext().getQueryScope().putParam("val_5", 5); + + final Table null_result = table.where("X == val_null"); + final Table range_result = table.where("X >= val_5"); + Assert.eq(null_result.size(), "null_result.size()", 5); + Assert.eq(range_result.size(), "range_result.size()", 3); + } + + @Test + public void testIntToFloatCoercion() { + final Table table = emptyTable(11).update("X= ii % 2 == 0 ? (float) ii : null"); + final Class colType = table.getDefinition().getColumn("X").getDataType(); + Assert.eq(colType, "colType", float.class); + + ExecutionContext.getContext().getQueryScope().putParam("val_null", QueryConstants.NULL_INT); + ExecutionContext.getContext().getQueryScope().putParam("val_5", 5); + + final Table null_result = table.where("X == val_null"); + final Table range_result = table.where("X >= val_5"); + Assert.eq(null_result.size(), "null_result.size()", 5); + Assert.eq(range_result.size(), "range_result.size()", 3); + } + + @Test + public void testIntToDoubleCoercion() { + final Table table = emptyTable(11).update("X= ii % 2 == 0 ? (double) ii : null"); + final Class colType = table.getDefinition().getColumn("X").getDataType(); + Assert.eq(colType, "colType", double.class); + + ExecutionContext.getContext().getQueryScope().putParam("val_null", QueryConstants.NULL_INT); + ExecutionContext.getContext().getQueryScope().putParam("val_5", 5); + + final Table null_result = table.where("X == val_null"); + final Table range_result = table.where("X >= val_5"); + Assert.eq(null_result.size(), "null_result.size()", 5); + Assert.eq(range_result.size(), "range_result.size()", 3); + } }