Skip to content

Commit

Permalink
[GLUTEN-8244][CORE] Softaffinity use consistent hash schedule (#8245)
Browse files Browse the repository at this point in the history
  • Loading branch information
yikf authored Dec 23, 2024
1 parent b0e9a04 commit 6c38842
Show file tree
Hide file tree
Showing 7 changed files with 497 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class GlutenClickhouseMergetreeSoftAffinitySuite
locationDistinct
)

assertResult(3)(partitions.size)
assertResult(2)(partitions.size)

for (partition <- partitions) {
val names =
Expand All @@ -100,7 +100,7 @@ class GlutenClickhouseMergetreeSoftAffinitySuite
.map(_.name.toInt)
.sorted
.mkString(",")
assert(names == "1,4,7,10" | names == "2,5,8" || names == "3,6,9")
assert(names == "1,2,3,5,6,7" | names == "4,8,9,10")
}
}

Expand Down
284 changes: 284 additions & 0 deletions gluten-core/src/main/java/org/apache/gluten/hash/ConsistentHash.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.hash;

import com.google.common.base.Preconditions;
import org.apache.commons.codec.digest.MurmurHash3;

import javax.annotation.concurrent.ThreadSafe;

import java.util.*;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
* A consistent hash ring implementation. the class is thread-safe.
*
* <p>It is mainly used to request across multiple nodes in a distributed system in a more balanced
* and stable way, minimizing the impact when nodes are added or removed.
*
* <p>The ConsistentHash support virtual nodes by replication of the physical nodes, called
* partition. It allocates the partition to the ring by hashing the partition key to a slot in the
* ring.
*
* @param <T> the type of node to be used in the ring.
*/
@ThreadSafe
public class ConsistentHash<T extends ConsistentHash.Node> {

// read-write lock for the ring.
private final ReadWriteLock lock = new ReentrantReadWriteLock(true);

// keep the mapping of node to its partitions, each partition is a slot in the ring.
private final Map<T, Set<Partition<T>>> nodes = new HashMap<>();

// keep the mapping of slot to partition, the partition actually is a virtual node.
private final SortedMap<Long, Partition<T>> ring = new TreeMap<>();

// the number of virtual nodes for each physical node.
private final int replicate;

// the hasher function to hash the key.
private final Hasher hasher;

public ConsistentHash(int replicate) {
Preconditions.checkArgument(replicate > 0, "HashRing require positive replicate number.");
this.replicate = replicate;
this.hasher =
(key, seed) -> {
byte[] data = key.getBytes();
return MurmurHash3.hash32x86(data, 0, data.length, seed);
};
}

public ConsistentHash(int replicate, Hasher hasher) {
Preconditions.checkArgument(replicate > 0, "HashRing require positive replicate number.");
Preconditions.checkArgument(hasher != null, "HashRing require non-null hasher.");
this.replicate = replicate;
this.hasher = hasher;
}

/**
* Add a node to the ring, the node will be replicated to `replicate` virtual nodes.
*
* @param node the node to be added to the ring.
* @return true if the node is added successfully, false otherwise.
*/
public boolean addNode(T node) {
lock.writeLock().lock();
try {
return add(node);
} finally {
lock.writeLock().unlock();
}
}

/**
* Remove the node from the ring, all the virtual nodes will be removed.
*
* @param node the node to be removed.
* @return true if the node is removed successfully, false otherwise.
*/
public boolean removeNode(T node) {
lock.writeLock().lock();
boolean removed = false;
try {
if (nodes.containsKey(node)) {
Set<Partition<T>> partitions = nodes.remove(node);
partitions.forEach(p -> ring.remove(p.getSlot()));
removed = true;
}
} finally {
lock.writeLock().unlock();
}
return removed;
}

/**
* Allocate the node by the key, the number of nodes to be located is specified by the count.
*
* @param key the key to locate the node.
* @param count the number of nodes to be located.
* @return a set of nodes located by the key.
*/
public Set<T> allocateNodes(String key, int count) {
lock.readLock().lock();
try {
Set<T> res = new HashSet<>();
if (key != null && count > 0) {
if (count < nodes.size()) {
long slot = hasher.hash(key, 0);
Iterator<Partition<T>> it = new AllocateIterator(slot);
while (it.hasNext() && res.size() < count) {
Partition<T> part = it.next();
res.add(part.getNode());
}
} else {
res.addAll(nodes.keySet());
}
}
return res;
} finally {
lock.readLock().unlock();
}
}

/**
* Get all the nodes in the ring.
*
* @return a set of nodes in the ring.
*/
public Set<T> getNodes() {
lock.readLock().lock();
try {
return new HashSet<>(nodes.keySet());
} finally {
lock.readLock().unlock();
}
}

/**
* Get the partitions of the node. return null if the node is not exists.
*
* @param node the node to get the partitions.
* @return a set of partitions of the node.
*/
public Set<Partition<T>> getPartition(T node) {
lock.readLock().lock();
try {
return nodes.get(node);
} finally {
lock.readLock().unlock();
}
}

/**
* Check if the node is in the ring.
*
* @param node the node to be checked.
* @return true if the node is in the ring, false otherwise.
*/
public boolean contains(T node) {
lock.readLock().lock();
try {
return nodes.containsKey(node);
} finally {
lock.readLock().unlock();
}
}

/**
* Check if the ring contains the slot.
*
* @param slot the slot to be checked.
* @return true if the slot is in the ring, false otherwise.
*/
public boolean ringContain(long slot) {
lock.readLock().lock();
try {
return ring.containsKey(slot);
} finally {
lock.readLock().unlock();
}
}

private boolean add(T node) {
boolean added = false;
if (node != null && !nodes.containsKey(node)) {
Set<Partition<T>> partitions =
IntStream.range(0, replicate)
.mapToObj(idx -> new Partition<T>(node, idx))
.collect(Collectors.toSet());
nodes.put(node, partitions);

// allocate slot.
for (Partition<T> partition : partitions) {
long slot;
int seed = 0;
do {
slot = this.hasher.hash(partition.getPartitionKey(), seed++);
} while (ring.containsKey(slot));

partition.setSlot(slot);
ring.put(slot, partition);
}
added = true;
}
return added;
}

public static class Partition<T extends ConsistentHash.Node> {
private final T node;

private final int index;

private long slot;

public Partition(T node, int index) {
this.node = node;
this.index = index;
}

public String getPartitionKey() {
return String.format("%s:%d", node, index);
}

public T getNode() {
return node;
}

public void setSlot(long slot) {
this.slot = slot;
}

public long getSlot() {
return this.slot;
}
}

/** Base interface for the node in the ring. */
public interface Node {
String key();
}

/** Base interface for the hash function. */
public interface Hasher {
long hash(String key, int seed);
}

private class AllocateIterator implements Iterator<Partition<T>> {
private final Iterator<Partition<T>> head;
private final Iterator<Partition<T>> tail;

AllocateIterator(long slot) {
this.head = ring.headMap(slot).values().iterator();
this.tail = ring.tailMap(slot).values().iterator();
}

@Override
public boolean hasNext() {
return head.hasNext() || tail.hasNext();
}

@Override
public Partition<T> next() {
return tail.hasNext() ? tail.next() : head.next();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
package org.apache.gluten.softaffinity

import org.apache.gluten.GlutenConfig
import org.apache.gluten.hash.ConsistentHash
import org.apache.gluten.logging.LogLevelUtil
import org.apache.gluten.softaffinity.strategy.SoftAffinityStrategy
import org.apache.gluten.softaffinity.strategy.{ConsistentHashSoftAffinityStrategy, ExecutorNode}
import org.apache.gluten.sql.shims.SparkShimLoader

import org.apache.spark.SparkEnv
Expand All @@ -38,7 +39,8 @@ abstract class AffinityManager extends LogLevelUtil with Logging {

private val resourceRWLock = new ReentrantReadWriteLock(true)

private val softAffinityAllocation = new SoftAffinityStrategy
lazy val softAffinityReplicationNum: Int =
GlutenConfig.GLUTEN_SOFT_AFFINITY_REPLICATIONS_NUM_DEFAULT_VALUE

lazy val minOnTargetHosts: Int = GlutenConfig.GLUTEN_SOFT_AFFINITY_MIN_TARGET_HOSTS_DEFAULT_VALUE

Expand Down Expand Up @@ -88,6 +90,9 @@ abstract class AffinityManager extends LogLevelUtil with Logging {
}
})

private val hashRing = new ConsistentHash[ExecutorNode](softAffinityReplicationNum)
private val softAffinityStrategy = new ConsistentHashSoftAffinityStrategy(hashRing)

private val rand = new Random(System.currentTimeMillis)

def totalExecutors(): Int = totalRegisteredExecutors.intValue()
Expand All @@ -107,6 +112,7 @@ abstract class AffinityManager extends LogLevelUtil with Logging {
executorsSet.add(execHostId._1)
idForExecutors += execHostId
sortedIdForExecutors = idForExecutors.sortBy(_._2)
hashRing.addNode(ExecutorNode(execHostId._1, execHostId._2))
totalRegisteredExecutors.addAndGet(1)
}
logOnLevel(
Expand Down Expand Up @@ -139,6 +145,7 @@ abstract class AffinityManager extends LogLevelUtil with Logging {
nodesExecutorsMap.remove(findedExecId._2)
}
sortedIdForExecutors = idForExecutors.sortBy(_._2)
hashRing.removeNode(ExecutorNode(execId, findedExecId._2))
totalRegisteredExecutors.addAndGet(-1)
}
logOnLevel(
Expand Down Expand Up @@ -237,7 +244,7 @@ abstract class AffinityManager extends LogLevelUtil with Logging {
if (nodesExecutorsMap.size < 1) {
Array.empty
} else {
softAffinityAllocation.allocateExecs(file, sortedIdForExecutors)
softAffinityStrategy.allocateExecs(file, softAffinityReplicationNum)
}
} finally {
resourceRWLock.readLock().unlock()
Expand Down Expand Up @@ -302,6 +309,10 @@ object SoftAffinityManager extends AffinityManager {
GlutenConfig.GLUTEN_SOFT_AFFINITY_ENABLED_DEFAULT_VALUE
)

override lazy val softAffinityReplicationNum: Int = SparkEnv.get.conf.getInt(
GlutenConfig.GLUTEN_SOFT_AFFINITY_REPLICATIONS_NUM,
GlutenConfig.GLUTEN_SOFT_AFFINITY_REPLICATIONS_NUM_DEFAULT_VALUE)

override lazy val minOnTargetHosts: Int = SparkEnv.get.conf.getInt(
GlutenConfig.GLUTEN_SOFT_AFFINITY_MIN_TARGET_HOSTS,
GlutenConfig.GLUTEN_SOFT_AFFINITY_MIN_TARGET_HOSTS_DEFAULT_VALUE
Expand Down
Loading

0 comments on commit 6c38842

Please sign in to comment.