Skip to content

Commit

Permalink
Spark 3.4: Support distribute by any predefined transform
Browse files Browse the repository at this point in the history
  • Loading branch information
Yxang authored and pan3793 committed Aug 27, 2023
1 parent c025ef3 commit c8b61f5
Show file tree
Hide file tree
Showing 10 changed files with 407 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,6 @@ case class ClusterSpec(
override def toString: String = s"cluster: $name, shards: [${shards.mkString(", ")}]"

@JsonIgnore @transient override lazy val nodes: Array[NodeSpec] = shards.sorted.flatMap(_.nodes)

def totalWeight: Int = shards.map(_.weight).sum
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.clickhouse.cluster

import org.apache.spark.SparkConf
import org.apache.spark.sql.Row

class ClusterShardByTransformSuite extends SparkClickHouseClusterTest {
override protected def sparkConf: SparkConf = {
val _conf = super.sparkConf
.set("spark.clickhouse.write.distributed.convertLocal", "true")
_conf
}

def runTest(func_name: String, func_args: Array[String]): Unit = {
val func_expr = s"$func_name(${func_args.mkString(",")})"
val cluster = "single_replica"
val db = s"db_${func_name}_shard_transform"
val tbl_dist = s"tbl_${func_name}_shard"
val tbl_local = s"${tbl_dist}_local"

try {
runClickHouseSQL(s"CREATE DATABASE IF NOT EXISTS $db ON CLUSTER $cluster")

spark.sql(
s"""CREATE TABLE $db.$tbl_local (
| create_time TIMESTAMP NOT NULL,
| create_date DATE NOT NULL,
| value STRING NOT NULL
|) USING ClickHouse
|TBLPROPERTIES (
| cluster = '$cluster',
| engine = 'MergeTree()',
| order_by = 'create_time'
|)
|""".stripMargin
)

runClickHouseSQL(
s"""CREATE TABLE $db.$tbl_dist ON CLUSTER $cluster
|AS $db.$tbl_local
|ENGINE = Distributed($cluster, '$db', '$tbl_local', $func_expr)
|""".stripMargin
)
spark.sql(
s"""INSERT INTO `$db`.`$tbl_dist`
|VALUES
| (timestamp'2021-01-01 10:10:10', date'2021-01-01', '1'),
| (timestamp'2022-02-02 11:10:10', date'2022-02-02', '2'),
| (timestamp'2023-03-03 12:10:10', date'2023-03-03', '3'),
| (timestamp'2024-04-04 13:10:10', date'2024-04-04', '4')
| AS tab(create_time, create_date, value)
|""".stripMargin
)
// check that data is indeed written
checkAnswer(
spark.table(s"$db.$tbl_dist").select("value").orderBy("create_time"),
Seq(Row("1"), Row("2"), Row("3"), Row("4"))
)

// check same data is sharded in the same server comparing native sharding
runClickHouseSQL(
s"""INSERT INTO `$db`.`$tbl_dist`
|VALUES
| (timestamp'2021-01-01 10:10:10', date'2021-01-01', '1'),
| (timestamp'2022-02-02 11:10:10', date'2022-02-02', '2'),
| (timestamp'2023-03-03 12:10:10', date'2023-03-03', '3'),
| (timestamp'2024-04-04 13:10:10', date'2024-04-04', '4')
|""".stripMargin
)
checkAnswer(
spark.table(s"$db.$tbl_local")
.groupBy("value").count().filter("count != 2"),
Seq.empty
)

} finally {
runClickHouseSQL(s"DROP TABLE IF EXISTS $db.$tbl_dist ON CLUSTER $cluster")
runClickHouseSQL(s"DROP TABLE IF EXISTS $db.$tbl_local ON CLUSTER $cluster")
runClickHouseSQL(s"DROP DATABASE IF EXISTS $db ON CLUSTER $cluster")
}
}

Seq(
// wait for SPARK-44180 to be fixed, then add implicit cast test cases
("toYear", Array("create_date")),
// ("toYear", Array("create_time")),
("toYYYYMM", Array("create_date")),
// ("toYYYYMM", Array("create_time")),
("toYYYYMMDD", Array("create_date")),
// ("toYYYYMMDD", Array("create_time")),
("toHour", Array("create_time")),
("xxHash64", Array("value")),
("murmurHash2_64", Array("value")),
("murmurHash2_32", Array("value")),
("murmurHash3_64", Array("value")),
("murmurHash3_32", Array("value")),
("cityHash64", Array("value")),
("modulo", Array("toYYYYMM(create_date)", "10"))
).foreach {
case (func_name: String, func_args: Array[String]) =>
test(s"shard by $func_name(${func_args.mkString(",")})")(runTest(func_name, func_args))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ object ExprUtils extends SQLConfHelper with Serializable {
functionRegistry: FunctionRegistry
): Array[V2SortOrder] =
toSparkSplits(
shardingKeyIgnoreRand,
shardingKeyIgnoreRand.map(k => ExprUtils.toSplitWithModulo(k, cluster.get.totalWeight)),
partitionKey,
functionRegistry
).map(Expressions.sort(_, SortDirection.ASCENDING)) ++:
Expand Down Expand Up @@ -216,4 +216,7 @@ object ExprUtils extends SQLConfHelper with Serializable {
case bucket: BucketTransform => throw CHClientException(s"Bucket transform not support yet: $bucket")
case other: Transform => throw CHClientException(s"Unsupported transform: $other")
}

def toSplitWithModulo(shardingKey: Expr, weight: Int): FuncExpr =
FuncExpr("modulo", List(shardingKey, StringLiteral(weight.toString)))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package xenon.clickhouse.func

import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.types._

import java.time.LocalDate
import java.time.format.DateTimeFormatter

object Days extends UnboundFunction with ScalarFunction[Int] with ClickhouseEquivFunction {

override def name: String = "clickhouse_days"

override def canonicalName: String = s"clickhouse.$name"

override def toString: String = name

override val ckFuncNames: Array[String] = Array("toYYYYMMDD")

override def description: String = s"$name: (date: Date) => shard_num: int"

override def bind(inputType: StructType): BoundFunction = inputType.fields match {
case Array(StructField(_, DateType, _, _)) => this
// case Array(StructField(_, TimestampType, _, _)) | Array(StructField(_, TimestampNTZType, _, _)) => this
case _ => throw new UnsupportedOperationException(s"Expect 1 DATE argument. $description")
}

override def inputTypes: Array[DataType] = Array(DateType)

override def resultType: DataType = IntegerType

override def isResultNullable: Boolean = false

def invoke(days: Int): Int = {
val date = LocalDate.ofEpochDay(days)
val formatter = DateTimeFormatter.ofPattern("yyyyMMdd")
date.format(formatter).toInt
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@ object StaticFunctionRegistry extends FunctionRegistry {
"clickhouse_murmurHash2_64" -> MurmurHash2_64,
"clickhouse_murmurHash3_32" -> MurmurHash3_32,
"clickhouse_murmurHash3_64" -> MurmurHash3_64,
"clickhouse_cityHash64" -> CityHash64
"clickhouse_cityHash64" -> CityHash64,
"clickhouse_years" -> Years,
"clickhouse_months" -> Months,
"clickhouse_days" -> Days,
"clickhouse_hours" -> Hours,
"sharding_mod" -> Mod
)

override def list: Array[String] = functions.keys.toArray
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package xenon.clickhouse.func

import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.types._

import java.sql.Timestamp
import java.text.SimpleDateFormat

object Hours extends UnboundFunction with ScalarFunction[Int] with ClickhouseEquivFunction {

override def name: String = "clickhouse_hours"

override def canonicalName: String = s"clickhouse.$name"

override def toString: String = name

override val ckFuncNames: Array[String] = Array("toHour", "HOUR")

override def description: String = s"$name: (time: timestamp) => shard_num: int"

override def bind(inputType: StructType): BoundFunction = inputType.fields match {
case Array(StructField(_, TimestampType, _, _)) | Array(StructField(_, TimestampNTZType, _, _)) => this
case _ => throw new UnsupportedOperationException(s"Expect 1 TIMESTAMP argument. $description")
}

override def inputTypes: Array[DataType] = Array(TimestampType)

override def resultType: DataType = IntegerType

override def isResultNullable: Boolean = false

def invoke(time: Long): Int = {
val ts = new Timestamp(time / 1000)
val formatter: SimpleDateFormat = new SimpleDateFormat("hh")
formatter.format(ts).toInt
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package xenon.clickhouse.func

import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.types._

object Mod extends UnboundFunction with ScalarFunction[Long] with ClickhouseEquivFunction {

override def name: String = "sharding_mod"

override def canonicalName: String = s"clickhouse.$name"

override def toString: String = name

// remainder is not a Clickhouse function, but modulo will be parsed to remainder in the connector.
// Added remainder as a synonym.
override val ckFuncNames: Array[String] = Array("modulo", "remainder")

override def description: String = s"$name: (a: long, b: long) => mod: long"

override def bind(inputType: StructType): BoundFunction = inputType.fields match {
case Array(a, b) if
(a match {
case StructField(_, LongType, _, _) => true
case StructField(_, IntegerType, _, _) => true
case StructField(_, ShortType, _, _) => true
case StructField(_, ByteType, _, _) => true
case StructField(_, StringType, _, _) => true
case _ => false
}) &&
(b match {
case StructField(_, LongType, _, _) => true
case StructField(_, IntegerType, _, _) => true
case StructField(_, ShortType, _, _) => true
case StructField(_, ByteType, _, _) => true
case StructField(_, StringType, _, _) => true
case _ => false
}) =>
this
case _ => throw new UnsupportedOperationException(s"Expect 2 integer arguments. $description")
}

override def inputTypes: Array[DataType] = Array(LongType, LongType)

override def resultType: DataType = LongType

override def isResultNullable: Boolean = false

def invoke(a: Long, b: Long): Long = a % b
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package xenon.clickhouse.func

import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.types._

import java.time.LocalDate
import java.time.format.DateTimeFormatter

object Months extends UnboundFunction with ScalarFunction[Int] with ClickhouseEquivFunction {

override def name: String = "clickhouse_months"

override def canonicalName: String = s"clickhouse.$name"

override def toString: String = name

override val ckFuncNames: Array[String] = Array("toYYYYMM")

override def description: String = s"$name: (date: Date) => shard_num: int"

override def bind(inputType: StructType): BoundFunction = inputType.fields match {
case Array(StructField(_, DateType, _, _)) => this
// case Array(StructField(_, TimestampType, _, _)) | Array(StructField(_, TimestampNTZType, _, _)) => this
case _ => throw new UnsupportedOperationException(s"Expect 1 DATE argument. $description")
}

override def inputTypes: Array[DataType] = Array(DateType)

override def resultType: DataType = IntegerType

override def isResultNullable: Boolean = false

def invoke(days: Int): Int = {
val date = LocalDate.ofEpochDay(days)
val formatter = DateTimeFormatter.ofPattern("yyyyMM")
date.format(formatter).toInt
}
}
Loading

0 comments on commit c8b61f5

Please sign in to comment.