From 62036bf1d0837fe5c941e5ac9653e183379f5ac8 Mon Sep 17 00:00:00 2001 From: Kerwin Zhang Date: Thu, 11 Jul 2024 15:06:10 +0800 Subject: [PATCH] [CELEBORN] Add compression for row-based shuffle (#6380) --- .../celeborn/CelebornShuffleManager.java | 29 +++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornShuffleManager.java b/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornShuffleManager.java index 6b450dea75ff..a0516d177415 100644 --- a/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornShuffleManager.java +++ b/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornShuffleManager.java @@ -54,6 +54,12 @@ public class CelebornShuffleManager implements ShuffleManager { private static final String LOCAL_SHUFFLE_READER_KEY = "spark.sql.adaptive.localShuffleReader.enabled"; + private static final String CELEBORN_COMPRESSION_CODEC_KEY = + CelebornConf.SHUFFLE_COMPRESSION_CODEC().key(); + + private static final String SPARK_CELEBORN_COMPRESSION_CODEC_KEY = + "spark." + CELEBORN_COMPRESSION_CODEC_KEY; + private static final CelebornShuffleWriterFactory writerFactory; static { @@ -78,6 +84,8 @@ public class CelebornShuffleManager implements ShuffleManager { private final SparkConf conf; private final CelebornConf celebornConf; + private final SparkConf rowBasedConf; + private final CelebornConf rowBasedCelebornConf; // either be "{appId}_{appAttemptId}" or "{appId}" private String appUniqueId; @@ -89,6 +97,8 @@ public class CelebornShuffleManager implements ShuffleManager { ConcurrentHashMap.newKeySet(); private final CelebornShuffleFallbackPolicyRunner fallbackPolicyRunner; + private final String celebornDefaultCodec; + // for Celeborn 0.4.0 private final Object shuffleIdTracker; @@ -110,6 +120,16 @@ public CelebornShuffleManager(SparkConf conf) { CelebornUtils.createInstance(CelebornUtils.EXECUTOR_SHUFFLE_ID_TRACKER_NAME); this.throwsFetchFailure = CelebornUtils.getThrowsFetchFailure(celebornConf); + + this.celebornDefaultCodec = CelebornConf.SHUFFLE_COMPRESSION_CODEC().defaultValueString(); + + this.rowBasedConf = conf.clone(); + this.rowBasedCelebornConf = celebornConf.clone(); + if ("none" + .equalsIgnoreCase(conf.get(SPARK_CELEBORN_COMPRESSION_CODEC_KEY, celebornDefaultCodec))) { + rowBasedConf.set(SPARK_CELEBORN_COMPRESSION_CODEC_KEY, celebornDefaultCodec); + rowBasedCelebornConf.set(CELEBORN_COMPRESSION_CODEC_KEY, celebornDefaultCodec); + } } private boolean isDriver() { @@ -133,7 +153,8 @@ private SparkShuffleManager vanillaCelebornShuffleManager() { synchronized (this) { if (_vanillaCelebornShuffleManager == null) { _vanillaCelebornShuffleManager = - SparkUtils.instantiateClass(VANILLA_CELEBORN_SHUFFLE_MANAGER_NAME, conf, isDriver()); + SparkUtils.instantiateClass( + VANILLA_CELEBORN_SHUFFLE_MANAGER_NAME, rowBasedConf, isDriver()); } } } @@ -330,6 +351,10 @@ public ShuffleReader getReader( if (handle instanceof CelebornShuffleHandle) { @SuppressWarnings("unchecked") CelebornShuffleHandle h = (CelebornShuffleHandle) handle; + CelebornConf readerConf = celebornConf; + if (!(h.dependency() instanceof ColumnarShuffleDependency)) { + readerConf = rowBasedCelebornConf; + } return CelebornUtils.getCelebornShuffleReader( h, startPartition, @@ -337,7 +362,7 @@ public ShuffleReader getReader( startMapIndex, endMapIndex, context, - celebornConf, + readerConf, metrics, shuffleIdTracker); }