diff --git a/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowConvertorRule.scala b/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowConvertorRule.scala index 2778710155bf9..db5e75458f301 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowConvertorRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowConvertorRule.scala @@ -87,7 +87,7 @@ case class ArrowConvertorRule(session: SparkSession) extends Rule[LogicalPlan] { options, columnPruning = session.sessionState.conf.csvColumnPruning, session.sessionState.conf.sessionLocalTimeZone) - checkSchema(dataSchema) && + SparkSchemaUtil.checkSchema(dataSchema) && checkCsvOptions(csvOptions, session.sessionState.conf.sessionLocalTimeZone) && dataSchema.nonEmpty } @@ -105,14 +105,4 @@ case class ArrowConvertorRule(session: SparkSession) extends Rule[LogicalPlan] { csvOptions.columnPruning && SparkShimLoader.getSparkShims.dateTimestampFormatInReadIsDefaultValue(csvOptions, timeZone) } - - private def checkSchema(schema: StructType): Boolean = { - try { - SparkSchemaUtil.toArrowSchema(schema) - true - } catch { - case _: Exception => - false - } - } } diff --git a/cpp/velox/tests/VeloxRowToColumnarTest.cc b/cpp/velox/tests/VeloxRowToColumnarTest.cc index c784dbd59c346..0d11dd4acbc98 100644 --- a/cpp/velox/tests/VeloxRowToColumnarTest.cc +++ b/cpp/velox/tests/VeloxRowToColumnarTest.cc @@ -87,10 +87,58 @@ TEST_F(VeloxRowToColumnarTest, allTypes) { makeNullableFlatVector<bool>( {std::nullopt, true, false, std::nullopt, true, true, false, true, std::nullopt, std::nullopt}), makeFlatVector<velox::StringView>( - {"alice0", "bob1", "alice2", "bob3", "Alice4", "Bob5", "AlicE6", "boB7", "ALICE8", "BOB9"}), + {"alice0", + "bob1", + "alice2", + "bob3", + "Alice4", + "Bob5123456789098766notinline", + "AlicE6", + "boB7", + "ALICE8", + "BOB9"}), makeNullableFlatVector<velox::StringView>( {"alice", "bob", std::nullopt, std::nullopt, "Alice", "Bob", std::nullopt, "alicE", std::nullopt, "boB"}), }); testRowVectorEqual(vector); } + +TEST_F(VeloxRowToColumnarTest, bigint) { + auto vector = makeRowVector({ + makeNullableFlatVector<int64_t>({1, 2, 3, std::nullopt, 4, std::nullopt, 5, 6, std::nullopt, 7}), + }); + testRowVectorEqual(vector); +} + +TEST_F(VeloxRowToColumnarTest, decimal) { + auto vector = makeRowVector({ + makeNullableFlatVector<int128_t>( + {123456, HugeInt::build(1045, 1789), 3678, std::nullopt, 4, std::nullopt, 5, 687987, std::nullopt, 7}, + DECIMAL(38, 2)), + makeNullableFlatVector<int64_t>( + {178987, 2, 3, std::nullopt, 4, std::nullopt, 5, 6, std::nullopt, 7}, DECIMAL(12, 3)), + }); + testRowVectorEqual(vector); +} + +TEST_F(VeloxRowToColumnarTest, timestamp) { + auto vector = makeRowVector({ + makeNullableFlatVector<Timestamp>( + {Timestamp(-946684800, 0), + Timestamp(-7266, 0), + Timestamp(0, 0), + Timestamp(946684800, 0), + Timestamp(9466848000, 0), + Timestamp(94668480000, 0), + Timestamp(946729316, 0), + Timestamp(946729316, 0), + Timestamp(946729316, 0), + Timestamp(7266, 0), + Timestamp(-50049331200, 0), + Timestamp(253405036800, 0), + Timestamp(-62480037600, 0), + std::nullopt}), + }); + testRowVectorEqual(vector); +} } // namespace gluten diff --git a/gluten-data/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java b/gluten-data/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java index dfd570debc0a5..da741af0f01ca 100644 --- a/gluten-data/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java +++ b/gluten-data/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java @@ -718,6 +718,17 @@ public Decimal getDecimal(int rowId, int precision, int scale) { return accessor.getDecimal(rowId, precision, scale); } + @Override + public void putDecimal(int rowId, Decimal value, int precision) { + if (precision <= Decimal.MAX_INT_DIGITS()) { + putInt(rowId, (int) value.toUnscaledLong()); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + putLong(rowId, value.toUnscaledLong()); + } else { + writer.setBytes(rowId, value.toJavaBigDecimal()); + } + } + @Override public UTF8String getUTF8String(int rowId) { if (isNullAt(rowId)) { @@ -1255,9 +1266,8 @@ void setNull(int rowId) { throw new UnsupportedOperationException(); } - void setNotNull(int rowId) { - throw new UnsupportedOperationException(); - } + // Arrow not need to setNotNull, set the valus is enough. + void setNotNull(int rowId) {} void setNulls(int rowId, int count) { throw new UnsupportedOperationException(); diff --git a/gluten-data/src/main/scala/org/apache/spark/sql/utils/SparkSchemaUtil.scala b/gluten-data/src/main/scala/org/apache/spark/sql/utils/SparkSchemaUtil.scala index b49077bd27403..8e66981ac72f7 100644 --- a/gluten-data/src/main/scala/org/apache/spark/sql/utils/SparkSchemaUtil.scala +++ b/gluten-data/src/main/scala/org/apache/spark/sql/utils/SparkSchemaUtil.scala @@ -37,6 +37,16 @@ object SparkSchemaUtil { SparkArrowUtil.toArrowSchema(schema, timeZoneId) } + def checkSchema(schema: StructType): Boolean = { + try { + SparkSchemaUtil.toArrowSchema(schema) + true + } catch { + case _: Exception => + false + } + } + def isTimeZoneIDEquivalentToUTC(zoneId: String): Boolean = { getTimeZoneIDOffset(zoneId) == 0 }