diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala index d3e910191101..aaa336e9bfa7 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala @@ -972,4 +972,24 @@ class GlutenClickHouseNativeWriteTableSuite } } + test("GLUTEN-4316: fix crash on dynamic partition inserting") { + withSQLConf( + ("spark.gluten.sql.native.writer.enabled", "true"), + ("spark.gluten.enabled", "true")) { + formats.foreach( + format => { + val tbl = "t_" + format + spark.sql(s"drop table IF EXISTS $tbl") + val sql1 = + s"create table $tbl(a int, b map, c struct) " + + s"partitioned by (day string) stored as $format" + val sql2 = s"insert overwrite $tbl partition (day) " + + s"select id as a, str_to_map(concat('t1:','a','&t2:','b'),'&',':'), " + + s"struct('1', null) as c, '2024-01-08' as day from range(10)" + spark.sql(sql1) + spark.sql(sql2) + }) + } + } + } diff --git a/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp b/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp index 0bd7ca905351..d67161623af2 100644 --- a/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp +++ b/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp @@ -15,15 +15,25 @@ * limitations under the License. */ #include "SourceFromJavaIter.h" +#include #include +#include +#include +#include #include #include #include #include +#include #include #include #include #include +#include +#include +#include +#include +#include namespace local_engine { @@ -57,7 +67,7 @@ DB::Chunk SourceFromJavaIter::generate() { jbyteArray block = static_cast(safeCallObjectMethod(env, java_iter, serialized_record_batch_iterator_next)); DB::Block * data = reinterpret_cast(byteArrayToLong(env, block)); - if(materialize_input) + if (materialize_input) materializeBlockInplace(*data); if (data->rows() > 0) { @@ -82,12 +92,14 @@ DB::Chunk SourceFromJavaIter::generate() CLEAN_JNIENV return result; } + SourceFromJavaIter::~SourceFromJavaIter() { GET_JNIENV(env) env->DeleteGlobalRef(java_iter); CLEAN_JNIENV } + Int64 SourceFromJavaIter::byteArrayToLong(JNIEnv * env, jbyteArray arr) { jsize len = env->GetArrayLength(arr); @@ -99,6 +111,7 @@ Int64 SourceFromJavaIter::byteArrayToLong(JNIEnv * env, jbyteArray arr) delete[] c_arr; return result; } + void SourceFromJavaIter::convertNullable(DB::Chunk & chunk) { auto output = this->getOutputs().front().getHeader(); @@ -106,12 +119,79 @@ void SourceFromJavaIter::convertNullable(DB::Chunk & chunk) auto columns = chunk.detachColumns(); for (size_t i = 0; i < columns.size(); ++i) { - DB::WhichDataType which(columns.at(i)->getDataType()); - if (output.getByPosition(i).type->isNullable() && !which.isNullable() && !which.isAggregateFunction()) - { - columns[i] = DB::makeNullable(columns.at(i)); - } + const auto & column = columns.at(i); + const auto & type = output.getByPosition(i).type; + columns[i] = convertNestedNullable(column, type); } chunk.setColumns(columns, rows); } + +DB::ColumnPtr SourceFromJavaIter::convertNestedNullable(const DB::ColumnPtr & column, const DB::DataTypePtr & target_type) +{ + DB::WhichDataType column_type(column->getDataType()); + if (column_type.isAggregateFunction()) + return column; + + if (DB::isColumnConst(*column)) + { + const auto & data_column = assert_cast(*column).getDataColumnPtr(); + const auto & result_column = convertNestedNullable(data_column, target_type); + return DB::ColumnConst::create(result_column, column->size()); + } + + // if target type is non-nullable, the column type must be also non-nullable, recursively converting it's nested type + // if target type is nullable, the column type may be nullable or non-nullable, converting it and then recursively converting it's nested type + DB::ColumnPtr new_column = column; + if (!column_type.isNullable() && target_type->isNullable()) + new_column = DB::makeNullable(column); + + DB::ColumnPtr nested_column = new_column; + DB::DataTypePtr nested_target_type = removeNullable(target_type); + if (new_column->isNullable()) + { + const auto & nullable_col = typeid_cast(new_column->getPtr().get()); + nested_column = nullable_col->getNestedColumnPtr(); + const auto & result_column = convertNestedNullable(nested_column, nested_target_type); + return DB::ColumnNullable::create(result_column, nullable_col->getNullMapColumnPtr()); + } + + DB::WhichDataType nested_column_type(nested_column->getDataType()); + if (nested_column_type.isMap()) + { + // header: Map(String, Nullable(String)) + // chunk: Map(String, String) + const auto & array_column = assert_cast(*nested_column).getNestedColumn(); + const auto & map_type = assert_cast(*nested_target_type); + auto tuple_columns = assert_cast(array_column.getDataPtr().get())->getColumns(); + // only convert for value column as key is always non-nullable + const auto & value_column = convertNestedNullable(tuple_columns[1], map_type.getValueType()); + auto result_column = DB::ColumnArray::create(DB::ColumnTuple::create(DB::Columns{tuple_columns[0], value_column}), array_column.getOffsetsPtr()); + return DB::ColumnMap::create(std::move(result_column)); + } + + if (nested_column_type.isArray()) + { + // header: Array(Nullable(String)) + // chunk: Array(String) + const auto & list_column = assert_cast(*nested_column); + auto nested_type = assert_cast(*nested_target_type).getNestedType(); + const auto & result_column = convertNestedNullable(list_column.getDataPtr(), nested_type); + return DB::ColumnArray::create(result_column, list_column.getOffsetsPtr()); + } + + if (nested_column_type.isTuple()) + { + // header: Tuple(Nullable(String), Nullable(String)) + // chunk: Tuple(String, Nullable(String)) + const auto & tuple_column = assert_cast(*nested_column); + auto nested_types = assert_cast(*nested_target_type).getElements(); + DB::Columns columns; + for (size_t i = 0; i != tuple_column.tupleSize(); ++i) + columns.push_back(convertNestedNullable(tuple_column.getColumnPtr(i), nested_types[i])); + return DB::ColumnTuple::create(std::move(columns)); + } + + return new_column; +} + } diff --git a/cpp-ch/local-engine/Storages/SourceFromJavaIter.h b/cpp-ch/local-engine/Storages/SourceFromJavaIter.h index 98fd1fbb75df..e5cc601cff1e 100644 --- a/cpp-ch/local-engine/Storages/SourceFromJavaIter.h +++ b/cpp-ch/local-engine/Storages/SourceFromJavaIter.h @@ -18,6 +18,7 @@ #include #include #include +#include namespace local_engine { @@ -38,6 +39,7 @@ class SourceFromJavaIter : public DB::ISource private: DB::Chunk generate() override; void convertNullable(DB::Chunk & chunk); + DB::ColumnPtr convertNestedNullable(const DB::ColumnPtr & column, const DB::DataTypePtr & target_type); jobject java_iter; bool materialize_input;