Skip to content

Commit

Permalink
add more case
Browse files Browse the repository at this point in the history
  • Loading branch information
loneylee committed Aug 30, 2024
1 parent e7caab1 commit c74937a
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ trait ClickHouseTableV2Base {
s"$keyName $s can not contain '.' (not support nested column yet)")
}
})
Some(keys.map(s => s.toLowerCase()))
Some(keys)
} else {
None
}
Expand All @@ -102,27 +102,22 @@ trait ClickHouseTableV2Base {

lazy val orderByKeyOption: Option[Seq[String]] = {
if (bucketOption.isDefined && bucketOption.get.sortColumnNames.nonEmpty) {
val orderByKes = bucketOption.get.sortColumnNames
val invalidKeys = orderByKes.intersect(partitionColumns)
val orderByKeys = bucketOption.get.sortColumnNames.map(normalizeColName).toSeq
val invalidKeys = orderByKeys.intersect(partitionColumns)
if (invalidKeys.nonEmpty) {
throw new IllegalStateException(
s"partition cols $invalidKeys can not be in the order by keys.")
}
Some(orderByKes)
Some(orderByKeys)
} else {
val tableProperties = deltaProperties
if (tableProperties.containsKey("orderByKey")) {
if (tableProperties.get("orderByKey").nonEmpty) {
val orderByKes = tableProperties.get("orderByKey").split(",").map(_.trim).toSeq
val invalidKeys = orderByKes.intersect(partitionColumns)
if (invalidKeys.nonEmpty) {
throw new IllegalStateException(
s"partition cols $invalidKeys can not be in the order by keys.")
}
Some(orderByKes)
} else {
None
val orderByKeys = getCommaSeparatedColumns("orderByKey")
if (orderByKeys.nonEmpty) {
val invalidKeys = orderByKeys.get.intersect(partitionColumns)
if (invalidKeys.nonEmpty) {
throw new IllegalStateException(
s"partition cols $invalidKeys can not be in the order by keys.")
}
orderByKeys
} else {
None
}
Expand All @@ -131,27 +126,18 @@ trait ClickHouseTableV2Base {

lazy val primaryKeyOption: Option[Seq[String]] = {
if (orderByKeyOption.isDefined) {
val tableProperties = deltaProperties
if (tableProperties.containsKey("primaryKey")) {
if (tableProperties.get("primaryKey").nonEmpty) {
val primaryKeys = tableProperties.get("primaryKey").split(",").map(_.trim).toSeq
if (!orderByKeyOption.get.mkString(",").startsWith(primaryKeys.mkString(","))) {
throw new IllegalStateException(
s"Primary key $primaryKeys must be a prefix of the sorting key")
}
Some(primaryKeys)
} else {
None
}
} else {
None
val primaryKeys = getCommaSeparatedColumns("primaryKey")
if (!orderByKeyOption.get.mkString(",").startsWith(primaryKeys.get.mkString(","))) {
throw new IllegalStateException(
s"Primary key $primaryKeys must be a prefix of the sorting key")
}
primaryKeys
} else {
None
}
}

lazy val partitionColumns = deltaSnapshot.metadata.partitionColumns
lazy val partitionColumns = deltaSnapshot.metadata.partitionColumns.map(normalizeColName).toSeq

lazy val clickhouseTableConfigs: Map[String, String] = {
val tableProperties = deltaProperties()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ object MergeTreeFileFormatWriter extends Logging {

val writerBucketSpec = bucketSpec.map {
spec =>
val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
val bucketColumns =
spec.bucketColumnNames.map(c => dataColumns.find(_.name.equalsIgnoreCase(c)).get)
// Spark bucketed table: use `HashPartitioning.partitionIdExpression` as bucket id
// expression, so that we can guarantee the data distribution is same between shuffle and
// bucketed data source, which enables us to only shuffle one side when join a bucketed
Expand All @@ -104,7 +105,7 @@ object MergeTreeFileFormatWriter extends Logging {
MergeTreeWriterBucketSpec(bucketIdExpression, (_: Int) => "")
}
val sortColumns = bucketSpec.toSeq.flatMap {
spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
spec => spec.sortColumnNames.map(c => dataColumns.find(_.name.equalsIgnoreCase(c)).get)
}

val caseInsensitiveOptions = CaseInsensitiveMap(options)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,14 @@ class GlutenClickHouseHiveTableSuite
}

private val txt_table_name = "hive_txt_test"
private val txt_upper_table_name = "hive_txt_upper_test"
private val txt_user_define_input = "hive_txt_user_define_input"
private val json_table_name = "hive_json_test"
private val parquet_table_name = "hive_parquet_test"

private val txt_table_create_sql = genTableCreateSql(txt_table_name, "textfile")
private val txt_upper_create_sql = genTableCreateUpperSql(txt_upper_table_name, "textfile")

private val parquet_table_create_sql = genTableCreateSql(parquet_table_name, "parquet")
private val json_table_create_sql = "create table if not exists %s (".format(json_table_name) +
"string_field string," +
Expand Down Expand Up @@ -136,6 +139,24 @@ class GlutenClickHouseHiveTableSuite
"map_field map<int, long>," +
"map_field_with_null map<int, long>) stored as %s".format(fileFormat)

def genTableCreateUpperSql(tableName: String, fileFormat: String): String =
"create table if not exists %s (".format(tableName) +
"STRING_FIELD string," +
"INT_FIELD int," +
"LONG_FIELD long," +
"FLOAT_FIELD float," +
"DOUBLE_FIELD double," +
"SHORT_FIELD short," +
"BYTE_FIELD byte," +
"BOOL_FIELD boolean," +
"DECIMAL_FIELD decimal(23, 12)," +
"DATE_FIELD date," +
"TIMESTAMP_FIELD timestamp," +
"ARRAY_FIELD array<int>," +
"ARRAY_FIELD_WITH_NULL array<int>," +
"MAP_FIELD map<int, long>," +
"MAP_FIELD_WITH_NULL map<int, long>) stored as %s".format(fileFormat)

protected def initializeTable(
table_name: String,
table_create_sql: String,
Expand All @@ -161,6 +182,7 @@ class GlutenClickHouseHiveTableSuite
override def beforeAll(): Unit = {
super.beforeAll()
initializeTable(txt_table_name, txt_table_create_sql, null)
initializeTable(txt_upper_table_name, txt_upper_create_sql, null)
initializeTable(txt_user_define_input, txt_table_user_define_create_sql, null)
initializeTable(
json_table_name,
Expand Down Expand Up @@ -1214,8 +1236,9 @@ class GlutenClickHouseHiveTableSuite
|select
| string_field,
| int_field,
| long_field
| from $txt_user_define_input
| long_field,
| date_field
| from $txt_table_name
|""".stripMargin)

sourceDF.write
Expand All @@ -1224,6 +1247,50 @@ class GlutenClickHouseHiveTableSuite
.option("clickhouse.bucketColumnNames", "STRING_FIELD")
.mode(SaveMode.Overwrite)
.save(dataPath)

assert(new File(dataPath).listFiles().nonEmpty)

val dataPath2 = s"$basePath/lineitem_mergetree_bucket2"
val df2 = spark.sql(s"""
|select
| string_field STRING_FIELD,
| int_field INT_FIELD,
| long_field LONG_FIELD,
| date_field DATE_FIELD
| from $txt_table_name
|""".stripMargin)

df2.write
.format("clickhouse")
.partitionBy("DATE_FIELD")
.option("clickhouse.numBuckets", "1")
.option("clickhouse.bucketColumnNames", "STRING_FIELD")
.option("clickhouse.orderByKey", "INT_FIELD,LONG_FIELD")
.option("clickhouse.primaryKey", "INT_FIELD")
.mode(SaveMode.Overwrite)
.save(dataPath2)
assert(new File(dataPath2).listFiles().nonEmpty)

val dataPath3 = s"$basePath/lineitem_mergetree_bucket3"
val df3 = spark.sql(s"""
|select
| string_field,
| int_field,
| long_field,
| date_field
| from $txt_upper_table_name
|""".stripMargin)

df3.write
.format("clickhouse")
.partitionBy("date_field")
.option("clickhouse.numBuckets", "1")
.option("clickhouse.bucketColumnNames", "string_field")
.option("clickhouse.orderByKey", "int_field,LONG_FIELD")
.option("clickhouse.primaryKey", "INT_FIELD")
.mode(SaveMode.Overwrite)
.save(dataPath3)
assert(new File(dataPath3).listFiles().nonEmpty)
}

test("GLUTEN-6506: Orc read time zone") {
Expand Down

0 comments on commit c74937a

Please sign in to comment.