Skip to content

Commit

Permalink
[GLUTEN-4316][CH] Fix crash on dynamic partition inserting (#4317)
Browse files Browse the repository at this point in the history
What changes were proposed in this pull request?
(Please fill in changes proposed in this fix)

(Fixes: #4316)

How was this patch tested?
Add UT
  • Loading branch information
exmy authored Jan 16, 2024
1 parent 2588a89 commit 46527ca
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -972,4 +972,24 @@ class GlutenClickHouseNativeWriteTableSuite
}
}

test("GLUTEN-4316: fix crash on dynamic partition inserting") {
withSQLConf(
("spark.gluten.sql.native.writer.enabled", "true"),
("spark.gluten.enabled", "true")) {
formats.foreach(
format => {
val tbl = "t_" + format
spark.sql(s"drop table IF EXISTS $tbl")
val sql1 =
s"create table $tbl(a int, b map<string, string>, c struct<d:string, e:string>) " +
s"partitioned by (day string) stored as $format"
val sql2 = s"insert overwrite $tbl partition (day) " +
s"select id as a, str_to_map(concat('t1:','a','&t2:','b'),'&',':'), " +
s"struct('1', null) as c, '2024-01-08' as day from range(10)"
spark.sql(sql1)
spark.sql(sql2)
})
}
}

}
92 changes: 86 additions & 6 deletions cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,25 @@
* limitations under the License.
*/
#include "SourceFromJavaIter.h"
#include <Columns/ColumnConst.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnMap.h>
#include <Columns/ColumnTuple.h>
#include <Columns/IColumn.h>
#include <Core/ColumnsWithTypeAndName.h>
#include <DataTypes/DataTypesNumber.h>
#include <Processors/Transforms/AggregatingTransform.h>
#include <jni/jni_common.h>
#include <Common/assert_cast.h>
#include <Common/CHUtil.h>
#include <Common/DebugUtils.h>
#include <Common/Exception.h>
#include <Common/JNIUtils.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypeMap.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/IDataType.h>

namespace local_engine
{
Expand Down Expand Up @@ -57,7 +67,7 @@ DB::Chunk SourceFromJavaIter::generate()
{
jbyteArray block = static_cast<jbyteArray>(safeCallObjectMethod(env, java_iter, serialized_record_batch_iterator_next));
DB::Block * data = reinterpret_cast<DB::Block *>(byteArrayToLong(env, block));
if(materialize_input)
if (materialize_input)
materializeBlockInplace(*data);
if (data->rows() > 0)
{
Expand All @@ -82,12 +92,14 @@ DB::Chunk SourceFromJavaIter::generate()
CLEAN_JNIENV
return result;
}

SourceFromJavaIter::~SourceFromJavaIter()
{
GET_JNIENV(env)
env->DeleteGlobalRef(java_iter);
CLEAN_JNIENV
}

Int64 SourceFromJavaIter::byteArrayToLong(JNIEnv * env, jbyteArray arr)
{
jsize len = env->GetArrayLength(arr);
Expand All @@ -99,19 +111,87 @@ Int64 SourceFromJavaIter::byteArrayToLong(JNIEnv * env, jbyteArray arr)
delete[] c_arr;
return result;
}

void SourceFromJavaIter::convertNullable(DB::Chunk & chunk)
{
auto output = this->getOutputs().front().getHeader();
auto rows = chunk.getNumRows();
auto columns = chunk.detachColumns();
for (size_t i = 0; i < columns.size(); ++i)
{
DB::WhichDataType which(columns.at(i)->getDataType());
if (output.getByPosition(i).type->isNullable() && !which.isNullable() && !which.isAggregateFunction())
{
columns[i] = DB::makeNullable(columns.at(i));
}
const auto & column = columns.at(i);
const auto & type = output.getByPosition(i).type;
columns[i] = convertNestedNullable(column, type);
}
chunk.setColumns(columns, rows);
}

DB::ColumnPtr SourceFromJavaIter::convertNestedNullable(const DB::ColumnPtr & column, const DB::DataTypePtr & target_type)
{
DB::WhichDataType column_type(column->getDataType());
if (column_type.isAggregateFunction())
return column;

if (DB::isColumnConst(*column))
{
const auto & data_column = assert_cast<const DB::ColumnConst &>(*column).getDataColumnPtr();
const auto & result_column = convertNestedNullable(data_column, target_type);
return DB::ColumnConst::create(result_column, column->size());
}

// if target type is non-nullable, the column type must be also non-nullable, recursively converting it's nested type
// if target type is nullable, the column type may be nullable or non-nullable, converting it and then recursively converting it's nested type
DB::ColumnPtr new_column = column;
if (!column_type.isNullable() && target_type->isNullable())
new_column = DB::makeNullable(column);

DB::ColumnPtr nested_column = new_column;
DB::DataTypePtr nested_target_type = removeNullable(target_type);
if (new_column->isNullable())
{
const auto & nullable_col = typeid_cast<const DB::ColumnNullable *>(new_column->getPtr().get());
nested_column = nullable_col->getNestedColumnPtr();
const auto & result_column = convertNestedNullable(nested_column, nested_target_type);
return DB::ColumnNullable::create(result_column, nullable_col->getNullMapColumnPtr());
}

DB::WhichDataType nested_column_type(nested_column->getDataType());
if (nested_column_type.isMap())
{
// header: Map(String, Nullable(String))
// chunk: Map(String, String)
const auto & array_column = assert_cast<const DB::ColumnMap &>(*nested_column).getNestedColumn();
const auto & map_type = assert_cast<const DB::DataTypeMap &>(*nested_target_type);
auto tuple_columns = assert_cast<const DB::ColumnTuple *>(array_column.getDataPtr().get())->getColumns();
// only convert for value column as key is always non-nullable
const auto & value_column = convertNestedNullable(tuple_columns[1], map_type.getValueType());
auto result_column = DB::ColumnArray::create(DB::ColumnTuple::create(DB::Columns{tuple_columns[0], value_column}), array_column.getOffsetsPtr());
return DB::ColumnMap::create(std::move(result_column));
}

if (nested_column_type.isArray())
{
// header: Array(Nullable(String))
// chunk: Array(String)
const auto & list_column = assert_cast<const DB::ColumnArray &>(*nested_column);
auto nested_type = assert_cast<const DB::DataTypeArray &>(*nested_target_type).getNestedType();
const auto & result_column = convertNestedNullable(list_column.getDataPtr(), nested_type);
return DB::ColumnArray::create(result_column, list_column.getOffsetsPtr());
}

if (nested_column_type.isTuple())
{
// header: Tuple(Nullable(String), Nullable(String))
// chunk: Tuple(String, Nullable(String))
const auto & tuple_column = assert_cast<const DB::ColumnTuple &>(*nested_column);
auto nested_types = assert_cast<const DB::DataTypeTuple &>(*nested_target_type).getElements();
DB::Columns columns;
for (size_t i = 0; i != tuple_column.tupleSize(); ++i)
columns.push_back(convertNestedNullable(tuple_column.getColumnPtr(i), nested_types[i]));
return DB::ColumnTuple::create(std::move(columns));
}

return new_column;
}

}
2 changes: 2 additions & 0 deletions cpp-ch/local-engine/Storages/SourceFromJavaIter.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <jni.h>
#include <Processors/ISource.h>
#include <Interpreters/Context.h>
#include <Columns/IColumn.h>

namespace local_engine
{
Expand All @@ -38,6 +39,7 @@ class SourceFromJavaIter : public DB::ISource
private:
DB::Chunk generate() override;
void convertNullable(DB::Chunk & chunk);
DB::ColumnPtr convertNestedNullable(const DB::ColumnPtr & column, const DB::DataTypePtr & target_type);

jobject java_iter;
bool materialize_input;
Expand Down

0 comments on commit 46527ca

Please sign in to comment.