diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseFunctionSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseFunctionSuite.scala
index 8853dfc77853..1d4d1b6f8afb 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseFunctionSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseFunctionSuite.scala
@@ -226,4 +226,32 @@ class GlutenClickhouseFunctionSuite extends GlutenClickHouseTPCHAbstractSuite {
     spark.sql("drop table t2")
   }
 
+  test("array decimal32 CH column to row") {
+    compareResultsAgainstVanillaSpark("SELECT array(1.0, 2.0)", true, { _ => }, false)
+    compareResultsAgainstVanillaSpark("SELECT map(1.0, '2', 3.0, '4')", true, { _ => }, false)
+  }
+
+  test("array decimal32 spark row to CH column") {
+    withTable("test_array_decimal") {
+      sql("""
+            |create table test_array_decimal(val array<decimal(5,1)>)
+            |using parquet
+            |""".stripMargin)
+      sql("""
+            |insert into test_array_decimal
+            |values array(1.0, 2.0), array(3.0, 4.0),
+            |array(5.0, 6.0), array(7.0, 8.0), array(7.0, 7.0)
+            |""".stripMargin)
+      // Disable native scan so the query exercises the Spark row to CH column conversion
+      withSQLConf(GlutenConfig.COLUMNAR_FILESCAN_ENABLED.key -> "false") {
+        val q = "SELECT max(val) from test_array_decimal"
+        compareResultsAgainstVanillaSpark(q, true, { _ => }, false)
+        val q2 = "SELECT max(val[0]) from test_array_decimal"
+        compareResultsAgainstVanillaSpark(q2, true, { _ => }, false)
+        val q3 = "SELECT max(val[1]) from test_array_decimal"
+        compareResultsAgainstVanillaSpark(q3, true, { _ => }, false)
+      }
+    }
+  }
+
 }
diff --git a/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp b/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp
index 5bb66e4b3f9d..3d5a7731bffb 100644
--- a/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp
+++ b/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp
@@ -586,12 +586,11 @@ int64_t BackingDataLengthCalculator::getArrayElementSize(const DataTypePtr & nes
     else if (nested_which.isUInt16() || nested_which.isInt16() || nested_which.isDate())
         return 2;
     else if (
-        nested_which.isUInt32() || nested_which.isInt32() || nested_which.isFloat32() || nested_which.isDate32()
-        || nested_which.isDecimal32())
+        nested_which.isUInt32() || nested_which.isInt32() || nested_which.isFloat32() || nested_which.isDate32())
         return 4;
     else if (
         nested_which.isUInt64() || nested_which.isInt64() || nested_which.isFloat64() || nested_which.isDateTime64()
-        || nested_which.isDecimal64())
+        || nested_which.isDecimal32() || nested_which.isDecimal64())
         return 8;
     else
         return 8;
@@ -702,6 +701,12 @@ int64_t VariableLengthDataWriter::writeArray(size_t row_idx, const DB::Array & a
                auto v = elem.get<Float32>();
                writer.unsafeWrite(reinterpret_cast<const char *>(&v), buffer_address + offset + start + 8 + len_null_bitmap + i * elem_size);
            }
+           else if (writer.getWhichDataType().isDecimal32())
+           {
+               // We can not use get() directly here to process a Decimal32 field: it would read only 4 bytes,
+               // but Decimal32 occupies 8 bytes in the Spark row format, which would cause a wrong conversion.
+               writer.write(elem, buffer_address + offset + start + 8 + len_null_bitmap + i * elem_size);
+           }
            else
                writer.unsafeWrite(
                    reinterpret_cast<const char *>(&elem.get<char>()),
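
Note (not part of the patch): the element-size change above reflects how Spark lays out small decimals in its unsafe row/array format. Any decimal with precision <= 18 is stored as an 8-byte long holding the unscaled value, even though ClickHouse's Decimal32 is only 4 bytes wide, so getArrayElementSize must report 8 for Decimal32. A minimal Scala sketch against Spark's catalyst classes illustrating that layout; the values and the decimal(5,1) type are illustrative, not taken from the patch:

    import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
    import org.apache.spark.sql.types.Decimal

    object Decimal32LayoutSketch {
      def main(args: Array[String]): Unit = {
        // Unscaled values of 1.0 and 2.0 at scale 1; small decimals are kept as longs.
        val arr = UnsafeArrayData.fromPrimitiveArray(Array(10L, 20L))

        // Layout: 8-byte element count + one 8-byte null-bitmap word + 8 bytes per element,
        // so even a decimal(5,1) element occupies 8 bytes, not 4.
        assert(arr.getSizeInBytes == 8 + 8 + 2 * 8)

        // Reading an element back as decimal(5,1) interprets the 8-byte long as the unscaled value.
        assert(arr.getDecimal(0, 5, 1) == Decimal("1.0"))
      }
    }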