Skip to content

Commit

Permalink
[GLUTEN-3922][CH] Fix incorrect shuffle hash id value when executing …
Browse files Browse the repository at this point in the history
…modulo (apache#3923)

Fix incorrect shuffle hash id value when executing modulo.
In CH Backend, the data type of the shuffle split num is a UInt32 and the returned type of the hash function is a UInt64, when the returned value of the hash function is more than 2^31 - 1, the modulo value of the hash value and the shuffle split num is different from the one of the vanilla spark.

Close apache#3922.
  • Loading branch information
zzcclp authored and loneylee committed Dec 7, 2023
1 parent 144df96 commit f98ebeb
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,27 @@ class GlutenClickHouseTPCHParquetBucketSuite
| CLUSTERED BY (c_custkey) SORTED BY (c_custkey) INTO 2 BUCKETS;
|""".stripMargin)

val customerData1 = bucketTableDataPath + "/customer_6_buckets"
spark.sql(s"DROP TABLE IF EXISTS customer_6_buckets")
spark.sql(s"""
| CREATE EXTERNAL TABLE IF NOT EXISTS customer_6_buckets (
| c_custkey bigint,
| c_name string,
| c_address string,
| c_nationkey bigint,
| c_phone string,
| c_acctbal double,
| c_mktsegment string,
| c_comment string)
| USING PARQUET
| LOCATION '$customerData1'
| CLUSTERED BY (c_custkey) SORTED BY (c_custkey) INTO 6 BUCKETS;
|""".stripMargin)

spark.sql(s"""
|INSERT INTO customer_6_buckets SELECT * FROM customer;
|""".stripMargin)

val lineitemData = bucketTableDataPath + "/lineitem"
spark.sql(s"DROP TABLE IF EXISTS lineitem")
spark.sql(s"""
Expand Down Expand Up @@ -155,6 +176,28 @@ class GlutenClickHouseTPCHParquetBucketSuite
| CLUSTERED BY (o_orderkey) SORTED BY (o_orderkey, o_orderdate) INTO 2 BUCKETS;
|""".stripMargin)

val ordersData1 = bucketTableDataPath + "/orders_6_buckets"
spark.sql(s"DROP TABLE IF EXISTS orders_6_buckets")
spark.sql(s"""
| CREATE EXTERNAL TABLE IF NOT EXISTS orders_6_buckets (
| o_orderkey bigint,
| o_custkey bigint,
| o_orderstatus string,
| o_totalprice double,
| o_orderdate date,
| o_orderpriority string,
| o_clerk string,
| o_shippriority bigint,
| o_comment string)
| USING PARQUET
| LOCATION '$ordersData1'
| CLUSTERED BY (o_orderkey) SORTED BY (o_orderkey, o_orderdate) INTO 6 BUCKETS;
|""".stripMargin)

spark.sql(s"""
|INSERT INTO orders_6_buckets SELECT * FROM orders;
|""".stripMargin)

val partData = bucketTableDataPath + "/part"
spark.sql(s"DROP TABLE IF EXISTS part")
spark.sql(s"""
Expand Down Expand Up @@ -208,7 +251,7 @@ class GlutenClickHouseTPCHParquetBucketSuite
| show tables;
|""".stripMargin)
.collect()
assert(result.length == 8)
assert(result.length == 10)
}

test("TPCH Q1") {
Expand Down Expand Up @@ -498,5 +541,30 @@ class GlutenClickHouseTPCHParquetBucketSuite
}
)
}

test("GLUTEN-3922: Fix incorrect shuffle hash id value when executing modulo") {
val SQL =
"""
|SELECT
| c_custkey, o_custkey, hash(o_custkey), pmod(hash(o_custkey), 12),
| pmod(hash(o_custkey), 4)
|FROM
| customer_6_buckets,
| orders_6_buckets
|WHERE
| c_mktsegment = 'BUILDING'
| AND c_custkey = o_custkey
| AND o_orderdate < date'1995-03-15'
|ORDER BY
| o_custkey DESC,
| c_custkey
|LIMIT 100;
|""".stripMargin
compareResultsAgainstVanillaSpark(
SQL,
true,
df => {}
)
}
}
// scalastyle:on line.size.limit
22 changes: 20 additions & 2 deletions cpp-ch/local-engine/Shuffle/SelectorBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,26 @@ PartitionInfo HashSelectorBuilder::build(DB::Block & block)
}
else
{
for (size_t i = 0; i < rows; i++)
partition_ids.emplace_back(static_cast<UInt64>(hash_column->get64(i) % parts_num));
if (hash_function_name == "sparkMurmurHash3_32")
{
auto parts_num_int32 = static_cast<Int32>(parts_num);
for (size_t i = 0; i < rows; i++)
{
// cast to int32 to be the same as the data type of the vanilla spark
auto hash_int32 = static_cast<Int32>(hash_column->get64(i));
auto res = hash_int32 % parts_num_int32;
if (res < 0)
{
res += parts_num_int32;
}
partition_ids.emplace_back(static_cast<UInt64>(res));
}
}
else
{
for (size_t i = 0; i < rows; i++)
partition_ids.emplace_back(static_cast<UInt64>(hash_column->get64(i) % parts_num));
}
}
return PartitionInfo::fromSelector(std::move(partition_ids), parts_num);
}
Expand Down

0 comments on commit f98ebeb

Please sign in to comment.