Merge branch 'main' into uniffle9

summaryzb authored Jul 30, 2024
2 parents 46168a8 + 295899c commit 014c82c
Showing 12 changed files with 104 additions and 88 deletions.

@@ -918,6 +918,24 @@ abstract class ScalarFunctionsValidateSuite extends FunctionsValidateTest {
}
}

+  testWithSpecifiedSparkVersion("mask", Some("3.4")) {
+    runQueryAndCompare("SELECT mask(c_comment) FROM customer limit 50") {
+      checkGlutenOperatorMatch[ProjectExecTransformer]
+    }
+    runQueryAndCompare("SELECT mask(c_comment, 'Y') FROM customer limit 50") {
+      checkGlutenOperatorMatch[ProjectExecTransformer]
+    }
+    runQueryAndCompare("SELECT mask(c_comment, 'Y', 'y') FROM customer limit 50") {
+      checkGlutenOperatorMatch[ProjectExecTransformer]
+    }
+    runQueryAndCompare("SELECT mask(c_comment, 'Y', 'y', 'o') FROM customer limit 50") {
+      checkGlutenOperatorMatch[ProjectExecTransformer]
+    }
+    runQueryAndCompare("SELECT mask(c_comment, 'Y', 'y', 'o', '*') FROM customer limit 50") {
+      checkGlutenOperatorMatch[ProjectExecTransformer]
+    }
+  }
+
test("bit_length") {
runQueryAndCompare(
"select bit_length(c_comment), bit_length(cast(c_comment as binary))" +
@@ -22,7 +22,6 @@ import org.apache.gluten.extension.columnar.validator.FallbackInjects
import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
-import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -1141,13 +1140,10 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
StructField("lastUpdated", LongType, true),
StructField("version", LongType, true))),
true)))
-    val df = spark.read.schema(jsonSchema).json(Seq(jsonStr).toDS)
-    df.select(collect_set(col("txn"))).collect
-
-    df.select(min(col("txn"))).collect
-
-    df.select(max(col("txn"))).collect
-
+    spark.read.schema(jsonSchema).json(Seq(jsonStr).toDS).createOrReplaceTempView("t1")
+    runQueryAndCompare("select collect_set(txn), min(txn), max(txn) from t1") {
+      checkGlutenOperatorMatch[HashAggregateExecTransformer]
+    }
}

test("drop redundant partial sort which has pre-project when offload sortAgg") {
10 changes: 6 additions & 4 deletions cpp/velox/shuffle/RadixSort.h
@@ -21,7 +21,9 @@

namespace gluten {

-template <typename Element>
+// Spark radix sort implementation. It is used for shuffle sort only, so the
+// params that shuffle does not use (desc, signed) have been removed.
+// https://github.com/apache/spark/blob/308669fc301916837bacb7c3ec1ecef93190c094/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java#L25
class RadixSort {
public:
/**
@@ -39,7 +41,7 @@ class RadixSort {
* @return The starting index of the sorted data within the given array. We return this instead
* of always copying the data back to position zero for efficiency.
*/
-  static int32_t sort(Element* array, size_t size, int64_t numRecords, int32_t startByteIndex, int32_t endByteIndex) {
+  static int32_t sort(uint64_t* array, size_t size, int64_t numRecords, int32_t startByteIndex, int32_t endByteIndex) {
assert(startByteIndex >= 0 && "startByteIndex should >= 0");
assert(endByteIndex <= 7 && "endByteIndex should <= 7");
assert(endByteIndex > startByteIndex);
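
A note on the byte-range parameters: they exist because callers pack the sort key into a slice of each 64-bit entry. In the sort shuffle writer further down, the partition id occupies the upper bytes (kPartitionIdStartByteIndex / kPartitionIdEndByteIndex), so only those bytes need sorting. The packing below is a hypothetical illustration with an assumed 40-bit shift, not the actual Gluten encoding:

#include <cstdint>

// Hypothetical entry layout: partition id in the high bytes, row offset in
// the low bytes. Sorting only the high bytes then groups entries by
// partition. The shift of 40 is an assumption for illustration.
constexpr int kPidShift = 40;
constexpr uint64_t kRowMask = (1ull << kPidShift) - 1;

inline uint64_t pack(uint32_t partitionId, uint64_t rowOffset) {
  return (static_cast<uint64_t>(partitionId) << kPidShift) | (rowOffset & kRowMask);
}

inline uint32_t partitionOf(uint64_t entry) {
  return static_cast<uint32_t>(entry >> kPidShift);
}
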
@@ -75,7 +77,7 @@ class RadixSort {
* @param outIndex the starting index where sorted output data should be written.
*/
static void sortAtByte(
-      Element* array,
+      uint64_t* array,
int64_t numRecords,
std::vector<int64_t>& counts,
int32_t byteIdx,
@@ -103,7 +105,7 @@
* significant byte. If the byte does not need sorting the vector entry will be empty.
*/
static std::vector<std::vector<int64_t>>
-  getCounts(Element* array, int64_t numRecords, int32_t startByteIndex, int32_t endByteIndex) {
+  getCounts(uint64_t* array, int64_t numRecords, int32_t startByteIndex, int32_t endByteIndex) {
std::vector<std::vector<int64_t>> counts;
counts.resize(8);

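For readers unfamiliar with the algorithm, getCounts and sortAtByte together form a classic least-significant-byte-first counting sort: one 256-bucket histogram per byte, then a stable scatter pass per byte that needs sorting. Below is a minimal, self-contained sketch of that structure; it is an illustration, not the Gluten implementation (which, per the doc comment above, returns the start index of the sorted region instead of copying the data back):

#include <cstddef>
#include <cstdint>
#include <vector>

// LSD radix sort restricted to bytes [startByte, endByte] of each key.
void radixSortByBytes(std::vector<uint64_t>& keys, int startByte, int endByte) {
  std::vector<uint64_t> tmp(keys.size());
  for (int byteIdx = startByte; byteIdx <= endByte; ++byteIdx) {
    // 256-bucket histogram of the current byte across all keys.
    std::size_t counts[256] = {0};
    for (uint64_t k : keys) {
      counts[(k >> (byteIdx * 8)) & 0xff]++;
    }
    // Exclusive prefix sum: counts[i] becomes the first output slot for
    // byte value i. (A production version also skips bytes that are
    // constant across all keys; omitted here for brevity.)
    std::size_t offset = 0;
    for (int i = 0; i < 256; ++i) {
      std::size_t c = counts[i];
      counts[i] = offset;
      offset += c;
    }
    // Stable scatter into the temporary buffer, then adopt it.
    for (uint64_t k : keys) {
      tmp[counts[(k >> (byteIdx * 8)) & 0xff]++] = k;
    }
    keys.swap(tmp);
  }
}

Because each pass is stable, later (more significant) passes preserve the order established by earlier ones, which is what makes the byte-by-byte approach correct.
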
7 changes: 2 additions & 5 deletions cpp/velox/shuffle/VeloxSortShuffleWriter.cc
@@ -220,12 +220,9 @@ arrow::Status VeloxSortShuffleWriter::evictAllPartitions() {
{
ScopedTimer timer(&sortTime_);
if (options_.useRadixSort) {
-      begin = RadixSort<uint64_t>::sort(
-          arrayPtr_, arraySize_, numRecords, kPartitionIdStartByteIndex, kPartitionIdEndByteIndex);
+      begin = RadixSort::sort(arrayPtr_, arraySize_, numRecords, kPartitionIdStartByteIndex, kPartitionIdEndByteIndex);
     } else {
-      auto ptr = arrayPtr_;
-      qsort(ptr, numRecords, sizeof(uint64_t), compare);
-      (void)ptr;
+      std::sort(arrayPtr_, arrayPtr_ + numRecords);
}
}

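On the fallback path the qsort call is replaced with std::sort. Both order the same plain uint64_t entries, but qsort invokes an opaque comparator through a function pointer for every comparison, while std::sort instantiates the comparison inline. A minimal sketch of the difference; compareU64 is a hypothetical stand-in for the removed compare helper:

#include <algorithm>
#include <cstdint>
#include <cstdlib>

// Hypothetical stand-in for the removed qsort comparator.
static int compareU64(const void* a, const void* b) {
  const uint64_t lhs = *static_cast<const uint64_t*>(a);
  const uint64_t rhs = *static_cast<const uint64_t*>(b);
  return lhs < rhs ? -1 : (lhs > rhs ? 1 : 0);
}

void sortEntries(uint64_t* arrayPtr, int64_t numRecords) {
  // Before: comparator called through a function pointer per comparison.
  // qsort(arrayPtr, numRecords, sizeof(uint64_t), compareU64);

  // After: operator< on uint64_t, typically fully inlined.
  std::sort(arrayPtr, arrayPtr + numRecords);
}
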
1 change: 0 additions & 1 deletion cpp/velox/substrait/SubstraitParser.cc
@@ -395,7 +395,6 @@ std::unordered_map<std::string, std::string> SubstraitParser::substraitVeloxFunc
{"xxhash64", "xxhash64_with_seed"},
{"modulus", "remainder"},
{"date_format", "format_datetime"},
{"collect_set", "set_agg"},
{"negative", "unaryminus"},
{"get_array_item", "get"}};

9 changes: 0 additions & 9 deletions cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
@@ -1048,15 +1048,6 @@ bool SubstraitToVeloxPlanValidator::validateAggRelFunctionType(const ::substrait
LOG_VALIDATION_MSG("Validation failed for function " + funcName + " resolve type in AggregateRel.");
return false;
}
-      static const std::unordered_set<std::string> notSupportComplexTypeAggFuncs = {"set_agg", "min", "max"};
-      if (notSupportComplexTypeAggFuncs.find(baseFuncName) != notSupportComplexTypeAggFuncs.end() &&
-          exec::isRawInput(funcStep)) {
-        auto type = binder.tryResolveType(signature->argumentTypes()[0]);
-        if (type->isArray() || type->isMap() || type->isRow()) {
-          LOG_VALIDATION_MSG("Validation failed for function " + baseFuncName + " complex type is not supported.");
-          return false;
-        }
-      }

resolved = true;
break;

0 comments on commit 014c82c
