Skip to content

Commit

Permalink
[GLUTEN-3922][CH] Fix incorrect shuffle hash id value when executing …
Browse files Browse the repository at this point in the history
…modulo

Fix incorrect shuffle hash id value when executing modulo.
In CH Backend, the data type of the shuffle split num is a UInt32 and the returned type of the hash function is a UInt64, when the returned value of the hash function is more than 2^31 - 1, the modulo value of the hash value and the shuffle split num is different from the one of the vanilla spark.

Close apache#3922.
  • Loading branch information
zzcclp authored and loneylee committed Dec 7, 2023
1 parent f31cc82 commit 81855a6
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,27 @@ class GlutenClickHouseTPCHParquetBucketSuite
| CLUSTERED BY (c_custkey) SORTED BY (c_custkey) INTO 2 BUCKETS;
|""".stripMargin)

val customerData1 = bucketTableDataPath + "/customer_6_buckets"
spark.sql(s"DROP TABLE IF EXISTS customer_6_buckets")
spark.sql(s"""
| CREATE EXTERNAL TABLE IF NOT EXISTS customer_6_buckets (
| c_custkey bigint,
| c_name string,
| c_address string,
| c_nationkey bigint,
| c_phone string,
| c_acctbal double,
| c_mktsegment string,
| c_comment string)
| USING PARQUET
| LOCATION '$customerData1'
| CLUSTERED BY (c_custkey) SORTED BY (c_custkey) INTO 6 BUCKETS;
|""".stripMargin)

spark.sql(s"""
|INSERT INTO customer_6_buckets SELECT * FROM customer;
|""".stripMargin)

val lineitemData = bucketTableDataPath + "/lineitem"
spark.sql(s"DROP TABLE IF EXISTS lineitem")
spark.sql(s"""
Expand Down Expand Up @@ -155,6 +176,28 @@ class GlutenClickHouseTPCHParquetBucketSuite
| CLUSTERED BY (o_orderkey) SORTED BY (o_orderkey, o_orderdate) INTO 2 BUCKETS;
|""".stripMargin)

val ordersData1 = bucketTableDataPath + "/orders_6_buckets"
spark.sql(s"DROP TABLE IF EXISTS orders_6_buckets")
spark.sql(s"""
| CREATE EXTERNAL TABLE IF NOT EXISTS orders_6_buckets (
| o_orderkey bigint,
| o_custkey bigint,
| o_orderstatus string,
| o_totalprice double,
| o_orderdate date,
| o_orderpriority string,
| o_clerk string,
| o_shippriority bigint,
| o_comment string)
| USING PARQUET
| LOCATION '$ordersData1'
| CLUSTERED BY (o_orderkey) SORTED BY (o_orderkey, o_orderdate) INTO 6 BUCKETS;
|""".stripMargin)

spark.sql(s"""
|INSERT INTO orders_6_buckets SELECT * FROM orders;
|""".stripMargin)

val partData = bucketTableDataPath + "/part"
spark.sql(s"DROP TABLE IF EXISTS part")
spark.sql(s"""
Expand Down Expand Up @@ -208,7 +251,7 @@ class GlutenClickHouseTPCHParquetBucketSuite
| show tables;
|""".stripMargin)
.collect()
assert(result.length == 8)
assert(result.length == 10)
}

test("TPCH Q1") {
Expand Down Expand Up @@ -498,5 +541,30 @@ class GlutenClickHouseTPCHParquetBucketSuite
}
)
}

test("GLUTEN-3922: Fix incorrect shuffle hash id value when executing modulo") {
val SQL =
"""
|SELECT
| c_custkey, o_custkey, hash(o_custkey), pmod(hash(o_custkey), 12),
| pmod(hash(o_custkey), 4)
|FROM
| customer_6_buckets,
| orders_6_buckets
|WHERE
| c_mktsegment = 'BUILDING'
| AND c_custkey = o_custkey
| AND o_orderdate < date'1995-03-15'
|ORDER BY
| o_custkey DESC,
| c_custkey
|LIMIT 100;
|""".stripMargin
compareResultsAgainstVanillaSpark(
SQL,
true,
df => {}
)
}
}
// scalastyle:on line.size.limit
22 changes: 20 additions & 2 deletions cpp-ch/local-engine/Shuffle/SelectorBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,26 @@ PartitionInfo HashSelectorBuilder::build(DB::Block & block)
}
else
{
for (size_t i = 0; i < rows; i++)
partition_ids.emplace_back(static_cast<UInt64>(hash_column->get64(i) % parts_num));
if (hash_function_name == "sparkMurmurHash3_32")
{
auto parts_num_int32 = static_cast<Int32>(parts_num);
for (size_t i = 0; i < rows; i++)
{
// cast to int32 to be the same as the data type of the vanilla spark
auto hash_int32 = static_cast<Int32>(hash_column->get64(i));
auto res = hash_int32 % parts_num_int32;
if (res < 0)
{
res += parts_num_int32;
}
partition_ids.emplace_back(static_cast<UInt64>(res));
}
}
else
{
for (size_t i = 0; i < rows; i++)
partition_ids.emplace_back(static_cast<UInt64>(hash_column->get64(i) % parts_num));
}
}
return PartitionInfo::fromSelector(std::move(partition_ids), parts_num);
}
Expand Down

0 comments on commit 81855a6

Please sign in to comment.