[VL] Fix RoundRobinPartitioner by setting start partition id (apache#…
zhli1142015 authored Nov 27, 2023
1 parent f5d62cb commit e6dd56e
Showing 20 changed files with 74 additions and 88 deletions.
@@ -27,7 +27,7 @@ import org.apache.spark.ShuffleDependency
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ColumnarShuffleDependency
import org.apache.spark.shuffle.{ColumnarShuffleDependency, GlutenShuffleUtils}
import org.apache.spark.shuffle.utils.RangePartitionerBoundsGenerator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, BoundReference, UnsafeProjection, UnsafeRow}
@@ -198,11 +198,12 @@ object CHExecUtil extends Logging {
}

new NativePartitioning(
"hash",
GlutenShuffleUtils.HashPartitioningShortName,
partitoining.numPartitions,
Array.empty[Byte],
hashFields.mkString(",").getBytes(),
outputFields.mkString(",").getBytes())
outputFields.mkString(",").getBytes()
)
}

private def buildPartitioningOptions(nativePartitioning: NativePartitioning): IteratorOptions = {
@@ -244,9 +245,19 @@ object CHExecUtil extends Logging {
}
val nativePartitioning: NativePartitioning = newPartitioning match {
case SinglePartition =>
new NativePartitioning("single", 1, Array.empty[Byte], Array.empty[Byte], requiredFields)
new NativePartitioning(
GlutenShuffleUtils.SinglePartitioningShortName,
1,
Array.empty[Byte],
Array.empty[Byte],
requiredFields)
case RoundRobinPartitioning(n) =>
new NativePartitioning("rr", n, Array.empty[Byte], Array.empty[Byte], requiredFields)
new NativePartitioning(
GlutenShuffleUtils.RoundRobinPartitioningShortName,
n,
Array.empty[Byte],
Array.empty[Byte],
requiredFields)
case HashPartitioning(_, _) =>
buildHashPartitioning(
newPartitioning.asInstanceOf[HashPartitioning],
@@ -282,7 +293,7 @@ object CHExecUtil extends Logging {
Seq[Int]()
}
new NativePartitioning(
"range",
GlutenShuffleUtils.RangePartitioningShortName,
numPartitions,
Array.empty[Byte],
orderingAndRangeBounds.getBytes(),
2 changes: 2 additions & 0 deletions cpp/core/jni/JniWrapper.cc
@@ -783,6 +783,7 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_vectorized_ShuffleWriterJniWrapper
jdouble reallocThreshold,
jlong firstBatchHandle,
jlong taskAttemptId,
jint startPartitionId,
jint pushBufferMaxSize,
jobject partitionPusher,
jstring partitionWriterTypeJstr) {
@@ -825,6 +826,7 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_vectorized_ShuffleWriterJniWrapper
}

shuffleWriterOptions.task_attempt_id = (int64_t)taskAttemptId;
shuffleWriterOptions.start_partition_id = startPartitionId;
shuffleWriterOptions.compression_threshold = bufferCompressThreshold;

auto partitionWriterTypeC = env->GetStringUTFChars(partitionWriterTypeJstr, JNI_FALSE);
1 change: 1 addition & 0 deletions cpp/core/shuffle/Options.h
@@ -58,6 +58,7 @@ struct ShuffleWriterOptions {

int64_t thread_id = -1;
int64_t task_attempt_id = -1;
int32_t start_partition_id = 0;

arrow::ipc::IpcWriteOptions ipc_write_options = arrow::ipc::IpcWriteOptions::Defaults();

5 changes: 3 additions & 2 deletions cpp/core/shuffle/Partitioner.cc
@@ -23,12 +23,13 @@

namespace gluten {

arrow::Result<std::shared_ptr<Partitioner>> Partitioner::make(Partitioning partitioning, int32_t numPartitions) {
arrow::Result<std::shared_ptr<Partitioner>>
Partitioner::make(Partitioning partitioning, int32_t numPartitions, int32_t startPartitionId) {
switch (partitioning) {
case Partitioning::kHash:
return std::make_shared<HashPartitioner>(numPartitions);
case Partitioning::kRoundRobin:
return std::make_shared<RoundRobinPartitioner>(numPartitions);
return std::make_shared<RoundRobinPartitioner>(numPartitions, startPartitionId);
case Partitioning::kSingle:
return std::make_shared<SinglePartitioner>();
case Partitioning::kRange:
3 changes: 2 additions & 1 deletion cpp/core/shuffle/Partitioner.h
@@ -26,7 +26,8 @@ namespace gluten {

class Partitioner {
public:
static arrow::Result<std::shared_ptr<Partitioner>> make(Partitioning partitioning, int32_t numPartitions);
static arrow::Result<std::shared_ptr<Partitioner>>
make(Partitioning partitioning, int32_t numPartitions, int32_t startPartitionId);

// Whether the first column is partition key.
bool hasPid() const {
18 changes: 3 additions & 15 deletions cpp/core/shuffle/RoundRobinPartitioner.cc
@@ -27,21 +27,9 @@ arrow::Status gluten::RoundRobinPartitioner::compute(
std::fill(std::begin(partition2RowCount), std::end(partition2RowCount), 0);
row2Partition.resize(numRows);

int32_t pidSelection = pidSelection_;
for (int32_t i = 0; i < numRows;) {
int32_t low = i;
int32_t up = std::min((int64_t)(i + (numPartitions_ - pidSelection)), numRows);
for (; low != up;) {
row2Partition[low++] = pidSelection++;
}

pidSelection_ = pidSelection;
pidSelection = 0;
i = up;
}

if (pidSelection_ >= numPartitions_) {
pidSelection_ -= numPartitions_;
for (int32_t i = 0; i < numRows; ++i) {
row2Partition[i] = pidSelection_;
pidSelection_ = (pidSelection_ + 1) % numPartitions_;
}

for (auto& pid : row2Partition) {
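The rewritten loop above is the heart of the fix: a single cursor (pidSelection_) now starts from the configured start partition id and advances modulo numPartitions_ for every row, carrying its value over to the next batch. A minimal Scala sketch of the same behavior, with hypothetical names and no claim to match the C++ class beyond this loop:

    // Sketch of the round-robin cursor: begin at startPartitionId % numPartitions,
    // assign each row the current cursor value, then advance modulo numPartitions.
    // The cursor is a field, so it persists across successive compute() calls.
    final class RoundRobinSketch(numPartitions: Int, startPartitionId: Int) {
      private var pidSelection: Int = startPartitionId % numPartitions

      def compute(numRows: Int): Array[Int] = {
        val row2Partition = new Array[Int](numRows)
        var i = 0
        while (i < numRows) {
          row2Partition(i) = pidSelection
          pidSelection = (pidSelection + 1) % numPartitions
          i += 1
        }
        row2Partition
      }
    }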
3 changes: 2 additions & 1 deletion cpp/core/shuffle/RoundRobinPartitioner.h
@@ -23,7 +23,8 @@ namespace gluten {

class RoundRobinPartitioner final : public Partitioner {
public:
RoundRobinPartitioner(int32_t numPartitions) : Partitioner(numPartitions, false) {}
RoundRobinPartitioner(int32_t numPartitions, int32_t startPartitionId)
: Partitioner(numPartitions, false), pidSelection_(startPartitionId % numPartitions) {}

arrow::Status compute(
const int32_t* pidArr,
20 changes: 10 additions & 10 deletions cpp/core/tests/RoundRobinPartitionerTest.cc
@@ -21,8 +21,8 @@
namespace gluten {
class RoundRobinPartitionerTest : public ::testing::Test {
protected:
void prepareData(int numPart) {
partitioner_ = std::make_shared<RoundRobinPartitioner>(numPart);
void prepareData(int numPart, int seed) {
partitioner_ = std::make_shared<RoundRobinPartitioner>(numPart, seed);
row2Partition_.clear();
partition2RowCount_.clear();
partition2RowCount_.resize(numPart);
@@ -62,18 +62,18 @@ class RoundRobinPartitionerTest : public ::testing::Test {
};

TEST_F(RoundRobinPartitionerTest, TestInit) {
int numPart = 0;
prepareData(numPart);
int numPart = 2;
prepareData(numPart, 3);
ASSERT_NE(partitioner_, nullptr);
int32_t pidSelection = getPidSelection();
ASSERT_EQ(pidSelection, 0);
ASSERT_EQ(pidSelection, 1);
}

TEST_F(RoundRobinPartitionerTest, TestComoputeNormal) {
// numRows equal numPart
{
int numPart = 10;
prepareData(numPart);
prepareData(numPart, 0);
int numRows = 10;
ASSERT_TRUE(partitioner_->compute(nullptr, numRows, row2Partition_, partition2RowCount_).ok());
ASSERT_EQ(getPidSelection(), 0);
@@ -85,7 +85,7 @@ TEST_F(RoundRobinPartitionerTest, TestComoputeNormal) {
// numRows less than numPart
{
int numPart = 10;
prepareData(numPart);
prepareData(numPart, 0);
int numRows = 8;
ASSERT_TRUE(partitioner_->compute(nullptr, numRows, row2Partition_, partition2RowCount_).ok());
ASSERT_EQ(getPidSelection(), 8);
@@ -99,7 +99,7 @@ TEST_F(RoundRobinPartitionerTest, TestComoputeNormal) {
// numRows greater than numPart
{
int numPart = 10;
prepareData(numPart);
prepareData(numPart, 0);
int numRows = 12;
ASSERT_TRUE(partitioner_->compute(nullptr, numRows, row2Partition_, partition2RowCount_).ok());
ASSERT_EQ(getPidSelection(), 2);
@@ -113,7 +113,7 @@ TEST_F(RoundRobinPartitionerTest, TestComoputeNormal) {
// numRows greater than 2*numPart
{
int numPart = 10;
prepareData(numPart);
prepareData(numPart, 0);
int numRows = 22;
ASSERT_TRUE(partitioner_->compute(nullptr, numRows, row2Partition_, partition2RowCount_).ok());
ASSERT_EQ(getPidSelection(), 2);
@@ -127,7 +127,7 @@ TEST_F(RoundRobinPartitionerTest, TestComoputeNormal) {

TEST_F(RoundRobinPartitionerTest, TestComoputeContinuous) {
int numPart = 10;
prepareData(numPart);
prepareData(numPart, 0);

{
int numRows = 8;
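The expectations in the updated tests follow directly from the modulo arithmetic: after assigning numRows rows round-robin over numPart partitions, the cursor ends at (start + numRows) % numPart. A quick worked check with a hypothetical helper, not part of the test file:

    // Worked check of the assertions above, assuming a start id of 0 unless noted.
    def cursorAfter(numPart: Int, numRows: Int, start: Int = 0): Int =
      (start + numRows) % numPart

    assert(cursorAfter(numPart = 10, numRows = 10) == 0) // numRows == numPart
    assert(cursorAfter(numPart = 10, numRows = 8) == 8)  // numRows <  numPart
    assert(cursorAfter(numPart = 10, numRows = 12) == 2) // numRows >  numPart
    assert(cursorAfter(numPart = 10, numRows = 22) == 2) // numRows >  2 * numPart
    assert(3 % 2 == 1) // TestInit: start id 3 with 2 partitions begins at partition 1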
3 changes: 2 additions & 1 deletion cpp/velox/shuffle/VeloxShuffleWriter.cc
@@ -425,7 +425,8 @@ arrow::Status VeloxShuffleWriter::init() {
VELOX_CHECK_NOT_NULL(options_.memory_pool);

ARROW_ASSIGN_OR_RAISE(partitionWriter_, partitionWriterCreator_->make(this));
ARROW_ASSIGN_OR_RAISE(partitioner_, Partitioner::make(options_.partitioning, numPartitions_));
ARROW_ASSIGN_OR_RAISE(
partitioner_, Partitioner::make(options_.partitioning, numPartitions_, options_.start_partition_id));

// pre-allocated buffer size for each partition, unit is row count
// when partitioner is SinglePart, partial variables don`t need init
@@ -104,6 +104,7 @@ class VeloxCelebornHashBasedColumnarShuffleWriter[K, V](
.getNativeInstanceHandle,
handle,
context.taskAttemptId(),
GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning, context.partitionId),
"celeborn",
GlutenConfig.getConf.columnarShuffleReallocThreshold
)
@@ -18,13 +18,28 @@ package org.apache.spark.shuffle

import io.glutenproject.GlutenConfig
import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.vectorized.NativePartitioning

import org.apache.spark.SparkConf
import org.apache.spark.internal.config._
import org.apache.spark.util.random.XORShiftRandom

import java.util.Locale

object GlutenShuffleUtils {
val SinglePartitioningShortName = "single"
val RoundRobinPartitioningShortName = "rr"
val HashPartitioningShortName = "hash"
val RangePartitioningShortName = "range"

def getStartPartitionId(partition: NativePartitioning, partitionId: Int): Int = {
partition.getShortName match {
case RoundRobinPartitioningShortName =>
new XORShiftRandom(partitionId).nextInt(partition.getNumPartitions)
case _ => 0
}
}

def checkCodecValues(codecConf: String, codec: String, validValues: Set[String]): Unit = {
if (!validValues.contains(codec)) {
throw new IllegalArgumentException(
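getStartPartitionId above makes the native round-robin writer behave like vanilla Spark, which seeds each map task's starting round-robin position from the task's partition index; for every other partitioning scheme the start stays 0. Without it, every task would begin at partition 0 and the lowest-numbered output partitions would tend to receive more rows. A hedged usage sketch, not part of this commit; the two-argument NativePartitioning constructor is the one used in ExecUtil further down, and the concrete start values depend on XORShiftRandom so they are not asserted here:

    import io.glutenproject.vectorized.NativePartitioning
    import org.apache.spark.shuffle.GlutenShuffleUtils

    val rr = new NativePartitioning(GlutenShuffleUtils.RoundRobinPartitioningShortName, 8)
    val hash = new NativePartitioning(GlutenShuffleUtils.HashPartitioningShortName, 8)

    // Each map task seeds the RNG with its own partition index, so different
    // tasks begin their round-robin cycle at different target partitions.
    (0 until 3).foreach { mapPartitionId =>
      val start = GlutenShuffleUtils.getStartPartitionId(rr, mapPartitionId) // value in [0, 8)
      println(s"map task $mapPartitionId starts its round-robin cycle at partition $start")
    }

    assert(GlutenShuffleUtils.getStartPartitionId(hash, 5) == 0) // non-round-robin: always 0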
@@ -64,7 +64,8 @@ public long make(
boolean writeEOS,
double reallocThreshold,
long handle,
long taskAttemptId) {
long taskAttemptId,
int startPartitionId) {
return nativeMake(
part.getShortName(),
part.getNumPartitions(),
@@ -81,6 +82,7 @@ public long make(
reallocThreshold,
handle,
taskAttemptId,
startPartitionId,
0,
null,
"local");
@@ -105,6 +107,7 @@ public long makeForRSS(
long memoryManagerHandle,
long handle,
long taskAttemptId,
int startPartitionId,
String partitionWriterType,
double reallocThreshold) {
return nativeMake(
@@ -123,6 +126,7 @@ public long makeForRSS(
reallocThreshold,
handle,
taskAttemptId,
startPartitionId,
pushBufferMaxSize,
pusher,
partitionWriterType);
@@ -144,6 +148,7 @@ public native long nativeMake(
double reallocThreshold,
long handle,
long taskAttemptId,
int startPartitionId,
int pushBufferMaxSize,
Object pusher,
String partitionWriterType);
@@ -155,7 +155,8 @@ class ColumnarShuffleWriter[K, V](
writeEOS,
reallocThreshold,
handle,
taskContext.taskAttemptId()
taskContext.taskAttemptId(),
GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning, taskContext.partitionId)
)
}
val startTime = System.nanoTime()
@@ -25,7 +25,7 @@ import io.glutenproject.vectorized.{ArrowWritableColumnVector, NativeColumnarToR
import org.apache.spark.{Partitioner, RangePartitioner, ShuffleDependency}
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ColumnarShuffleDependency
import org.apache.spark.shuffle.{ColumnarShuffleDependency, GlutenShuffleUtils}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
@@ -146,14 +146,14 @@ object ExecUtil {

val nativePartitioning: NativePartitioning = newPartitioning match {
case SinglePartition =>
new NativePartitioning("single", 1)
new NativePartitioning(GlutenShuffleUtils.SinglePartitioningShortName, 1)
case RoundRobinPartitioning(n) =>
new NativePartitioning("rr", n)
new NativePartitioning(GlutenShuffleUtils.RoundRobinPartitioningShortName, n)
case HashPartitioning(exprs, n) =>
new NativePartitioning("hash", n)
new NativePartitioning(GlutenShuffleUtils.HashPartitioningShortName, n)
// range partitioning fall back to row-based partition id computation
case RangePartitioning(orders, n) =>
new NativePartitioning("range", n)
new NativePartitioning(GlutenShuffleUtils.RangePartitioningShortName, n)
}

val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] &&
@@ -928,7 +928,6 @@ class VeloxTestSettings extends BackendTestSettings {
enableSuite[GlutenDataSourceStrategySuite]
enableSuite[GlutenDataSourceSuite]
enableSuite[GlutenFileFormatWriterSuite]
.excludeByPrefix("empty file should be skipped while write to file")
enableSuite[GlutenFileIndexSuite]
enableSuite[GlutenParquetCodecSuite]
// Unsupported compression codec.
@@ -22,17 +22,4 @@ import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest
class GlutenFileFormatWriterSuite
extends FileFormatWriterSuite
with GlutenSQLTestsBaseTrait
with CodegenInterpretedPlanTest {

test("gluten empty file should be skipped while write to file") {
withTempPath {
path =>
spark.range(100).repartition(10).where("id = 50").write.parquet(path.toString)
val partFiles = path
.listFiles()
.filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_"))
// result only one row, gluten result is more reasonable
assert(partFiles.length === 1)
}
}
}
with CodegenInterpretedPlanTest {}
@@ -782,7 +782,6 @@ class VeloxTestSettings extends BackendTestSettings {
enableSuite[GlutenDataSourceStrategySuite]
enableSuite[GlutenDataSourceSuite]
enableSuite[GlutenFileFormatWriterSuite]
.excludeByPrefix("empty file should be skipped while write to file")
enableSuite[GlutenFileIndexSuite]
enableSuite[GlutenFileMetadataStructSuite]
enableSuite[GlutenParquetV1AggregatePushDownSuite]
@@ -22,17 +22,4 @@ import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest
class GlutenFileFormatWriterSuite
extends FileFormatWriterSuite
with GlutenSQLTestsBaseTrait
with CodegenInterpretedPlanTest {

test("gluten empty file should be skipped while write to file") {
withTempPath {
path =>
spark.range(100).repartition(10).where("id = 50").write.parquet(path.toString)
val partFiles = path
.listFiles()
.filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_"))
// result only one row, gluten result is more reasonable
assert(partFiles.length === 1)
}
}
}
with CodegenInterpretedPlanTest {}
@@ -805,7 +805,6 @@ class VeloxTestSettings extends BackendTestSettings {
enableSuite[GlutenDataSourceStrategySuite]
enableSuite[GlutenDataSourceSuite]
enableSuite[GlutenFileFormatWriterSuite]
.excludeByPrefix("empty file should be skipped while write to file")
enableSuite[GlutenFileIndexSuite]
enableSuite[GlutenFileMetadataStructSuite]
.exclude("SPARK-41896: Filter on row_index and a stored column at the same time")