diff --git a/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/commands/OptimizeTableCommandOverwrites.scala b/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/commands/OptimizeTableCommandOverwrites.scala
index f7a180b6a239c..b897010d5bb38 100644
--- a/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/commands/OptimizeTableCommandOverwrites.scala
+++ b/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/commands/OptimizeTableCommandOverwrites.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.delta.commands
import org.apache.gluten.expression.ConverterUtils
-
+import org.apache.gluten.memory.CHThreadGroup
import org.apache.spark.{TaskContext, TaskOutputFileAlreadyExistException}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
@@ -38,13 +38,11 @@ import org.apache.spark.sql.execution.datasources.v2.clickhouse.utils.CHDataSour
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{SerializableConfiguration, SystemClock, Utils}
-
import org.apache.hadoop.fs.{FileAlreadyExistsException, Path}
import org.apache.hadoop.mapreduce.{TaskAttemptContext, TaskAttemptID, TaskID, TaskType}
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import java.util.{Date, UUID}
-
import scala.collection.mutable.ArrayBuffer
object OptimizeTableCommandOverwrites extends Logging {
@@ -76,7 +74,7 @@ object OptimizeTableCommandOverwrites extends Logging {
sparkPartitionId: Int,
sparkAttemptNumber: Int
): MergeTreeWriteTaskResult = {
-
+ CHThreadGroup.registerNewThreadGroup()
val jobId = SparkHadoopWriterUtils.createJobID(new Date(description.jobIdInstant), sparkStageId)
val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId)
val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber)
diff --git a/backends-clickhouse/src/main/delta-23/org/apache/spark/sql/delta/commands/OptimizeTableCommandOverwrites.scala b/backends-clickhouse/src/main/delta-23/org/apache/spark/sql/delta/commands/OptimizeTableCommandOverwrites.scala
index f7a180b6a239c..b897010d5bb38 100644
--- a/backends-clickhouse/src/main/delta-23/org/apache/spark/sql/delta/commands/OptimizeTableCommandOverwrites.scala
+++ b/backends-clickhouse/src/main/delta-23/org/apache/spark/sql/delta/commands/OptimizeTableCommandOverwrites.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.delta.commands
import org.apache.gluten.expression.ConverterUtils
-
+import org.apache.gluten.memory.CHThreadGroup
import org.apache.spark.{TaskContext, TaskOutputFileAlreadyExistException}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
@@ -38,13 +38,11 @@ import org.apache.spark.sql.execution.datasources.v2.clickhouse.utils.CHDataSour
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{SerializableConfiguration, SystemClock, Utils}
-
import org.apache.hadoop.fs.{FileAlreadyExistsException, Path}
import org.apache.hadoop.mapreduce.{TaskAttemptContext, TaskAttemptID, TaskID, TaskType}
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import java.util.{Date, UUID}
-
import scala.collection.mutable.ArrayBuffer
object OptimizeTableCommandOverwrites extends Logging {
@@ -76,7 +74,7 @@ object OptimizeTableCommandOverwrites extends Logging {
sparkPartitionId: Int,
sparkAttemptNumber: Int
): MergeTreeWriteTaskResult = {
-
+ CHThreadGroup.registerNewThreadGroup()
val jobId = SparkHadoopWriterUtils.createJobID(new Date(description.jobIdInstant), sparkStageId)
val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId)
val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber)
diff --git a/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/delta/commands/OptimizeTableCommandOverwrites.scala b/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/delta/commands/OptimizeTableCommandOverwrites.scala
index 7b4c3231b8c31..ef30aaad2294f 100644
--- a/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/delta/commands/OptimizeTableCommandOverwrites.scala
+++ b/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/delta/commands/OptimizeTableCommandOverwrites.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.delta.commands
import org.apache.gluten.expression.ConverterUtils
-
+import org.apache.gluten.memory.CHThreadGroup
import org.apache.spark.{TaskContext, TaskOutputFileAlreadyExistException}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
@@ -40,13 +40,11 @@ import org.apache.spark.sql.execution.datasources.v2.clickhouse.utils.CHDataSour
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{SerializableConfiguration, SystemClock, Utils}
-
import org.apache.hadoop.fs.{FileAlreadyExistsException, Path}
import org.apache.hadoop.mapreduce.{TaskAttemptContext, TaskAttemptID, TaskID, TaskType}
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import java.util.{Date, UUID}
-
import scala.collection.mutable.ArrayBuffer
object OptimizeTableCommandOverwrites extends Logging {
@@ -78,7 +76,7 @@ object OptimizeTableCommandOverwrites extends Logging {
sparkPartitionId: Int,
sparkAttemptNumber: Int
): MergeTreeWriteTaskResult = {
-
+ CHThreadGroup.registerNewThreadGroup()
val jobId = SparkHadoopWriterUtils.createJobID(new Date(description.jobIdInstant), sparkStageId)
val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId)
val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber)
diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/memory/CHThreadGroup.java b/backends-clickhouse/src/main/java/org/apache/gluten/memory/CHThreadGroup.java
new file mode 100644
index 0000000000000..a06c552a9f6ba
--- /dev/null
+++ b/backends-clickhouse/src/main/java/org/apache/gluten/memory/CHThreadGroup.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.memory;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.util.TaskResource;
+import org.apache.spark.util.TaskResources;
+
+public class CHThreadGroup implements TaskResource {
+
+ /**
+ * Register a new thread group for the current task. This method should be called at beginning of
+ * the task.
+ */
+ public static void registerNewThreadGroup() {
+ if (TaskResources.isResourceRegistered(CHThreadGroup.class.getName())) return;
+ CHThreadGroup group = new CHThreadGroup();
+ TaskResources.addResource(CHThreadGroup.class.getName(), group);
+ TaskContext.get()
+ .addTaskCompletionListener(
+ (context -> {
+ context.taskMetrics().incPeakExecutionMemory(group.getPeakMemory());
+ }));
+ }
+
+ private long thread_group_id = 0;
+ private long peak_memory = -1;
+
+ private CHThreadGroup() {
+ thread_group_id = createThreadGroup();
+ }
+
+ public long getPeakMemory() {
+ if (peak_memory < 0) {
+ peak_memory = threadGroupPeakMemory(thread_group_id);
+ }
+ return peak_memory;
+ }
+
+ @Override
+ public void release() throws Exception {
+ if (peak_memory < 0) {
+ peak_memory = threadGroupPeakMemory(thread_group_id);
+ }
+ releaseThreadGroup(thread_group_id);
+ }
+
+ @Override
+ public int priority() {
+ return TaskResource.super.priority();
+ }
+
+ @Override
+ public String resourceName() {
+ return "CHThreadGroup";
+ }
+
+ private static native long createThreadGroup();
+
+ private static native long threadGroupPeakMemory(long id);
+
+ private static native void releaseThreadGroup(long id);
+}
diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHManagedCHReservationListener.java b/backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHManagedCHReservationListener.java
deleted file mode 100644
index 1f560a905858c..0000000000000
--- a/backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHManagedCHReservationListener.java
+++ /dev/null
@@ -1,152 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.memory.alloc;
-
-import org.apache.gluten.GlutenConfig;
-import org.apache.gluten.memory.SimpleMemoryUsageRecorder;
-import org.apache.gluten.memory.memtarget.MemoryTarget;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.concurrent.atomic.AtomicLong;
-
-public class CHManagedCHReservationListener implements CHReservationListener {
-
- private static final Logger LOG = LoggerFactory.getLogger(CHManagedCHReservationListener.class);
-
- private MemoryTarget target;
- private final SimpleMemoryUsageRecorder usage;
- private final boolean throwIfMemoryExceed =
- GlutenConfig.getConf().chColumnarThrowIfMemoryExceed();
- private volatile boolean open = true;
-
- private final AtomicLong currentMemory = new AtomicLong(0L);
-
- public CHManagedCHReservationListener(MemoryTarget target, SimpleMemoryUsageRecorder usage) {
- this.target = target;
- this.usage = usage;
- }
-
- @Override
- public void reserveOrThrow(long size) {
- if (!throwIfMemoryExceed) {
- reserve(size);
- return;
- }
-
- synchronized (this) {
- if (!open) {
- return;
- }
- if (LOG.isDebugEnabled()) {
- LOG.debug(String.format("reserve memory size from native: %d", size));
- }
- long granted = target.borrow(size);
- if (granted < size) {
- target.repay(granted);
- throw new UnsupportedOperationException(
- "Not enough spark off-heap execution memory. "
- + "Acquired: "
- + size
- + ", granted: "
- + granted
- + ". "
- + "Try tweaking config option spark.memory.offHeap.size to "
- + "get larger space to run this application. ");
- }
- currentMemory.addAndGet(size);
- usage.inc(size);
- }
- }
-
- @Override
- public long reserve(long size) {
- synchronized (this) {
- if (!open) {
- return 0L;
- }
- if (LOG.isDebugEnabled()) {
- LOG.debug(String.format("reserve memory (without exception) size from native: %d", size));
- }
- long granted = target.borrow(size);
- if (granted < size && (LOG.isWarnEnabled())) {
- LOG.warn(
- String.format(
- "Not enough spark off-heap execution memory. "
- + "Acquired: %d, granted: %d. Try tweaking config option "
- + "spark.memory.offHeap.size to get larger space "
- + "to run this application.",
- size, granted));
- }
- currentMemory.addAndGet(granted);
- usage.inc(size);
- return granted;
- }
- }
-
- @Override
- public long unreserve(long size) {
- synchronized (this) {
- if (!open) {
- return 0L;
- }
- long memoryToFree = size;
- if ((currentMemory.get() - size) < 0L) {
- if (LOG.isDebugEnabled()) {
- LOG.debug(
- String.format(
- "The current used memory' %d will be less than 0(%d) after free %d",
- currentMemory.get(), currentMemory.get() - size, size));
- }
- memoryToFree = currentMemory.get();
- }
- if (memoryToFree == 0L) {
- return memoryToFree;
- }
- if (LOG.isDebugEnabled()) {
- LOG.debug(String.format("unreserve memory size from native: %d", memoryToFree));
- }
- target.repay(memoryToFree);
- currentMemory.addAndGet(-memoryToFree);
- usage.inc(-size);
- return memoryToFree;
- }
- }
-
- @Override
- public void inactivate() {
- synchronized (this) {
- // for some reasons, memory audit in the native code may not be 100% accurate
- // we'll allow the inaccuracy
- if (currentMemory.get() > 0) {
- unreserve(currentMemory.get());
- } else if (currentMemory.get() < 0) {
- reserve(currentMemory.get());
- }
- currentMemory.set(0L);
-
- target = null; // make it gc reachable
- open = false;
- }
- }
-
- @Override
- public long currentMemory() {
- return currentMemory.get();
- }
-}
diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHNativeMemoryAllocator.java b/backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHNativeMemoryAllocator.java
deleted file mode 100644
index 0df0757c8b542..0000000000000
--- a/backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHNativeMemoryAllocator.java
+++ /dev/null
@@ -1,71 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.memory.alloc;
-
-/**
- * Like {@link org.apache.gluten.vectorized.NativePlanEvaluator}, this along with {@link
- * CHNativeMemoryAllocators}, as built-in toolkit for managing native memory allocations.
- */
-public class CHNativeMemoryAllocator {
-
- private final long nativeInstanceId;
- private final CHReservationListener listener;
-
- public CHNativeMemoryAllocator(long nativeInstanceId, CHReservationListener listener) {
- this.nativeInstanceId = nativeInstanceId;
- this.listener = listener;
- }
-
- public static CHNativeMemoryAllocator getDefault() {
- return new CHNativeMemoryAllocator(getDefaultAllocator(), CHReservationListener.NOOP);
- }
-
- public static CHNativeMemoryAllocator getDefaultForUT() {
- return new CHNativeMemoryAllocator(
- createListenableAllocator(CHReservationListener.NOOP), CHReservationListener.NOOP);
- }
-
- public static CHNativeMemoryAllocator createListenable(CHReservationListener listener) {
- return new CHNativeMemoryAllocator(createListenableAllocator(listener), listener);
- }
-
- public CHReservationListener listener() {
- return listener;
- }
-
- public long getNativeInstanceId() {
- return this.nativeInstanceId;
- }
-
- public long getBytesAllocated() {
- if (this.nativeInstanceId == -1L) return 0;
- return bytesAllocated(this.nativeInstanceId);
- }
-
- public void close() {
- if (this.nativeInstanceId == -1L) return;
- releaseAllocator(this.nativeInstanceId);
- }
-
- private static native long getDefaultAllocator();
-
- private static native long createListenableAllocator(CHReservationListener listener);
-
- private static native void releaseAllocator(long allocatorId);
-
- private static native long bytesAllocated(long allocatorId);
-}
diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHNativeMemoryAllocatorManagerImpl.java b/backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHNativeMemoryAllocatorManagerImpl.java
deleted file mode 100644
index 2c1c2547fa818..0000000000000
--- a/backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHNativeMemoryAllocatorManagerImpl.java
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.memory.alloc;
-
-public class CHNativeMemoryAllocatorManagerImpl implements CHNativeMemoryAllocatorManager {
- private final CHNativeMemoryAllocator managed;
-
- public CHNativeMemoryAllocatorManagerImpl(CHNativeMemoryAllocator managed) {
- this.managed = managed;
- }
-
- @Override
- public void release() {
- managed.close();
- managed.listener().inactivate();
- }
-
- @Override
- public CHNativeMemoryAllocator getManaged() {
- return managed;
- }
-
- @Override
- public int priority() {
- return 0; // lowest priority
- }
-
- @Override
- public String resourceName() {
- return "CHNativeMemoryAllocatorManager_" + managed.getNativeInstanceId();
- }
-}
diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHNativeMemoryAllocators.java b/backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHNativeMemoryAllocators.java
deleted file mode 100644
index 0f30972fcd44d..0000000000000
--- a/backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHNativeMemoryAllocators.java
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.memory.alloc;
-
-import org.apache.gluten.memory.SimpleMemoryUsageRecorder;
-import org.apache.gluten.memory.memtarget.MemoryTargets;
-import org.apache.gluten.memory.memtarget.Spiller;
-import org.apache.gluten.memory.memtarget.Spillers;
-
-import org.apache.spark.memory.TaskMemoryManager;
-import org.apache.spark.util.TaskResources;
-
-import java.util.Collections;
-
-/**
- * Built-in toolkit for managing native memory allocations. To use the facility, one should import
- * Gluten's C++ library then create the c++ instance using following example code:
- *
- *
```c++ auto* allocator = reinterpret_cast(allocator_id); ```
- *
- * The ID "allocator_id" can be retrieved from Java API {@link
- * CHNativeMemoryAllocator#getNativeInstanceId()}.
- *
- *
FIXME: to export the native APIs in a standard way
- */
-public abstract class CHNativeMemoryAllocators {
- private CHNativeMemoryAllocators() {}
-
- private static final CHNativeMemoryAllocator GLOBAL = CHNativeMemoryAllocator.getDefault();
-
- private static CHNativeMemoryAllocatorManager createNativeMemoryAllocatorManager(
- String name,
- TaskMemoryManager taskMemoryManager,
- Spiller spiller,
- SimpleMemoryUsageRecorder usage) {
-
- CHManagedCHReservationListener rl =
- new CHManagedCHReservationListener(
- MemoryTargets.newConsumer(taskMemoryManager, name, spiller, Collections.emptyMap()),
- usage);
- return new CHNativeMemoryAllocatorManagerImpl(CHNativeMemoryAllocator.createListenable(rl));
- }
-
- public static CHNativeMemoryAllocator contextInstance() {
- if (!TaskResources.inSparkTask()) {
- return globalInstance();
- }
-
- final String id = CHNativeMemoryAllocatorManager.class.toString();
- if (!TaskResources.isResourceRegistered(id)) {
- final CHNativeMemoryAllocatorManager manager =
- createNativeMemoryAllocatorManager(
- "ContextInstance",
- TaskResources.getLocalTaskContext().taskMemoryManager(),
- Spillers.NOOP,
- TaskResources.getSharedUsage());
- TaskResources.addResource(id, manager);
- }
- return ((CHNativeMemoryAllocatorManager) TaskResources.getResource(id)).getManaged();
- }
-
- public static CHNativeMemoryAllocator contextInstanceForUT() {
- return CHNativeMemoryAllocator.getDefaultForUT();
- }
-
- public static CHNativeMemoryAllocator createSpillable(String name, Spiller spiller) {
- if (!TaskResources.inSparkTask()) {
- throw new IllegalStateException("spiller must be used in a Spark task");
- }
-
- final CHNativeMemoryAllocatorManager manager =
- createNativeMemoryAllocatorManager(
- name,
- TaskResources.getLocalTaskContext().taskMemoryManager(),
- spiller,
- TaskResources.getSharedUsage());
- TaskResources.addAnonymousResource(manager);
- // force add memory consumer to task memory manager, will release by inactivate
- manager.getManaged().listener().reserve(1);
- return manager.getManaged();
- }
-
- public static CHNativeMemoryAllocator globalInstance() {
- return GLOBAL;
- }
-}
diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHReservationListener.java b/backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHReservationListener.java
deleted file mode 100644
index 926c4426d3d5d..0000000000000
--- a/backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHReservationListener.java
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.memory.alloc;
-
-public interface CHReservationListener {
- CHReservationListener NOOP =
- new CHReservationListener() {
- @Override
- public void reserveOrThrow(long size) {}
-
- @Override
- public long reserve(long size) {
- return 0L;
- }
-
- @Override
- public long unreserve(long size) {
- return 0L;
- }
-
- @Override
- public void inactivate() {}
-
- @Override
- public long currentMemory() {
- return 0L;
- }
- };
-
- long reserve(long size);
-
- void reserveOrThrow(long size);
-
- long unreserve(long size);
-
- void inactivate();
-
- long currentMemory();
-}
diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/CHNativeExpressionEvaluator.java b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/CHNativeExpressionEvaluator.java
index 01f38cb3b90be..adcf827eaf167 100644
--- a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/CHNativeExpressionEvaluator.java
+++ b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/CHNativeExpressionEvaluator.java
@@ -18,7 +18,7 @@
import org.apache.gluten.GlutenConfig;
import org.apache.gluten.backendsapi.BackendsApiManager;
-import org.apache.gluten.memory.alloc.CHNativeMemoryAllocators;
+import org.apache.gluten.memory.CHThreadGroup;
import org.apache.gluten.substrait.expression.ExpressionBuilder;
import org.apache.gluten.substrait.expression.StringMapNode;
import org.apache.gluten.substrait.extensions.AdvancedExtensionNode;
@@ -81,9 +81,7 @@ private static Map getNativeBackendConf() {
public static void injectWriteFilesTempPath(String path, String fileName) {
ExpressionEvaluatorJniWrapper.injectWriteFilesTempPath(
- CHNativeMemoryAllocators.contextInstance().getNativeInstanceId(),
- path.getBytes(StandardCharsets.UTF_8),
- fileName.getBytes(StandardCharsets.UTF_8));
+ path.getBytes(StandardCharsets.UTF_8), fileName.getBytes(StandardCharsets.UTF_8));
}
// Used by WholeStageTransform to create the native computing pipeline and
@@ -93,9 +91,9 @@ public static BatchIterator createKernelWithBatchIterator(
byte[][] splitInfo,
List iterList,
boolean materializeInput) {
+ CHThreadGroup.registerNewThreadGroup();
long handle =
nativeCreateKernelWithIterator(
- CHNativeMemoryAllocators.contextInstance().getNativeInstanceId(),
wsPlan,
splitInfo,
iterList.toArray(new GeneralInIterator[0]),
@@ -106,10 +104,10 @@ public static BatchIterator createKernelWithBatchIterator(
// Only for UT.
public static BatchIterator createKernelWithBatchIterator(
- long allocId, byte[] wsPlan, byte[][] splitInfo, List iterList) {
+ byte[] wsPlan, byte[][] splitInfo, List iterList) {
+ CHThreadGroup.registerNewThreadGroup();
long handle =
nativeCreateKernelWithIterator(
- allocId,
wsPlan,
splitInfo,
iterList.toArray(new GeneralInIterator[0]),
diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/CHShuffleSplitterJniWrapper.java b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/CHShuffleSplitterJniWrapper.java
index 815bf472c027a..864cc4eb70ace 100644
--- a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/CHShuffleSplitterJniWrapper.java
+++ b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/CHShuffleSplitterJniWrapper.java
@@ -30,14 +30,9 @@ public long make(
String dataFile,
String localDirs,
int subDirsPerLocalDir,
- boolean preferSpill,
long spillThreshold,
String hashAlgorithm,
- boolean throwIfMemoryExceed,
- boolean flushBlockBufferBeforeEvict,
long maxSortBufferSize,
- boolean spillFirstlyBeforeStop,
- boolean forceExternalSort,
boolean forceMemorySort) {
return nativeMake(
part.getShortName(),
@@ -51,14 +46,9 @@ public long make(
dataFile,
localDirs,
subDirsPerLocalDir,
- preferSpill,
spillThreshold,
hashAlgorithm,
- throwIfMemoryExceed,
- flushBlockBufferBeforeEvict,
maxSortBufferSize,
- spillFirstlyBeforeStop,
- forceExternalSort,
forceMemorySort);
}
@@ -71,9 +61,6 @@ public long makeForRSS(
long spillThreshold,
String hashAlgorithm,
Object pusher,
- boolean throwIfMemoryExceed,
- boolean flushBlockBufferBeforeEvict,
- boolean forceExternalSort,
boolean forceMemorySort) {
return nativeMakeForRSS(
part.getShortName(),
@@ -87,9 +74,6 @@ public long makeForRSS(
spillThreshold,
hashAlgorithm,
pusher,
- throwIfMemoryExceed,
- flushBlockBufferBeforeEvict,
- forceExternalSort,
forceMemorySort);
}
@@ -105,14 +89,9 @@ public native long nativeMake(
String dataFile,
String localDirs,
int subDirsPerLocalDir,
- boolean preferSpill,
long spillThreshold,
String hashAlgorithm,
- boolean throwIfMemoryExceed,
- boolean flushBlockBufferBeforeEvict,
long maxSortBufferSize,
- boolean spillFirstlyBeforeStop,
- boolean forceSort,
boolean forceMemorySort);
public native long nativeMakeForRSS(
@@ -127,15 +106,10 @@ public native long nativeMakeForRSS(
long spillThreshold,
String hashAlgorithm,
Object pusher,
- boolean throwIfMemoryExceed,
- boolean flushBlockBufferBeforeEvict,
- boolean forceSort,
boolean forceMemorySort);
public native void split(long splitterId, long block);
- public native long evict(long splitterId);
-
public native CHSplitResult stop(long splitterId) throws IOException;
public native void close(long splitterId);
diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/ExpressionEvaluatorJniWrapper.java b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/ExpressionEvaluatorJniWrapper.java
index a5a474d2a2521..e73b293d618e6 100644
--- a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/ExpressionEvaluatorJniWrapper.java
+++ b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/ExpressionEvaluatorJniWrapper.java
@@ -32,11 +32,9 @@ public class ExpressionEvaluatorJniWrapper {
/**
* Create a native compute kernel and return a columnar result iterator.
*
- * @param allocatorId allocator id
* @return iterator instance id
*/
public static native long nativeCreateKernelWithIterator(
- long allocatorId,
byte[] wsPlan,
byte[][] splitInfo,
GeneralInIterator[] batchItr,
@@ -46,9 +44,7 @@ public static native long nativeCreateKernelWithIterator(
/**
* Set the temp path for writing files.
*
- * @param allocatorId allocator id for current task attempt(or thread)
* @param path the temp path for writing files
*/
- public static native void injectWriteFilesTempPath(
- long allocatorId, byte[] path, byte[] filename);
+ public static native void injectWriteFilesTempPath(byte[] path, byte[] filename);
}
diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/SimpleExpressionEval.java b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/SimpleExpressionEval.java
index b09cccb4580fd..d6cfec31969b8 100644
--- a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/SimpleExpressionEval.java
+++ b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/SimpleExpressionEval.java
@@ -17,6 +17,7 @@
package org.apache.gluten.vectorized;
import org.apache.gluten.execution.ColumnarNativeIterator;
+import org.apache.gluten.memory.CHThreadGroup;
import org.apache.gluten.substrait.plan.PlanNode;
import io.substrait.proto.Plan;
@@ -37,6 +38,7 @@ public SimpleExpressionEval(ColumnarNativeIterator blockStream, PlanNode planNod
LOG.debug(String.format("SimpleExpressionEval exec plan: %s", plan.toString()));
}
byte[] planData = plan.toByteArray();
+ CHThreadGroup.registerNewThreadGroup();
instance = createNativeInstance(blockStream, planData);
}
diff --git a/backends-clickhouse/src/main/java/org/apache/spark/sql/execution/datasources/CHDatasourceJniWrapper.java b/backends-clickhouse/src/main/java/org/apache/spark/sql/execution/datasources/CHDatasourceJniWrapper.java
index c041ee352c421..2bb3d44e0ff16 100644
--- a/backends-clickhouse/src/main/java/org/apache/spark/sql/execution/datasources/CHDatasourceJniWrapper.java
+++ b/backends-clickhouse/src/main/java/org/apache/spark/sql/execution/datasources/CHDatasourceJniWrapper.java
@@ -28,8 +28,7 @@ public native long nativeInitMergeTreeWriterWrapper(
String taskId,
String partition_dir,
String bucket_dir,
- byte[] confArray,
- long allocId);
+ byte[] confArray);
public native String nativeMergeMTParts(
byte[] plan,
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala
index d54eb59036d86..f27ee09df8f0d 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala
@@ -82,6 +82,7 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
listIterator,
materializeInput
)
+
}
private def createCloseIterator(
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala
index c113f8d4dd319..db9bba5f170a3 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala
@@ -18,8 +18,7 @@ package org.apache.spark.shuffle
import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
-import org.apache.gluten.memory.alloc.CHNativeMemoryAllocators
-import org.apache.gluten.memory.memtarget.{MemoryTarget, Spiller, Spillers}
+import org.apache.gluten.memory.CHThreadGroup
import org.apache.gluten.vectorized._
import org.apache.spark.SparkEnv
@@ -54,13 +53,7 @@ class CHColumnarShuffleWriter[K, V](
private val splitSize = GlutenConfig.getConf.maxBatchSize
private val customizedCompressCodec =
GlutenShuffleUtils.getCompressionCodec(conf).toUpperCase(Locale.ROOT)
- private val preferSpill = GlutenConfig.getConf.chColumnarShufflePreferSpill
- private val throwIfMemoryExceed = GlutenConfig.getConf.chColumnarThrowIfMemoryExceed
- private val flushBlockBufferBeforeEvict =
- GlutenConfig.getConf.chColumnarFlushBlockBufferBeforeEvict
private val maxSortBufferSize = GlutenConfig.getConf.chColumnarMaxSortBufferSize
- private val spillFirstlyBeforeStop = GlutenConfig.getConf.chColumnarSpillFirstlyBeforeStop
- private val forceExternalSortShuffle = GlutenConfig.getConf.chColumnarForceExternalSortShuffle
private val forceMemorySortShuffle = GlutenConfig.getConf.chColumnarForceMemorySortShuffle
private val spillThreshold = GlutenConfig.getConf.chColumnarShuffleSpillThreshold
private val jniWrapper = new CHShuffleSplitterJniWrapper
@@ -81,6 +74,7 @@ class CHColumnarShuffleWriter[K, V](
@throws[IOException]
override def write(records: Iterator[Product2[K, V]]): Unit = {
+ CHThreadGroup.registerNewThreadGroup()
internalCHWrite(records)
}
@@ -108,36 +102,11 @@ class CHColumnarShuffleWriter[K, V](
dataTmp.getAbsolutePath,
localDirs,
subDirsPerLocalDir,
- preferSpill,
spillThreshold,
CHBackendSettings.shuffleHashAlgorithm,
- throwIfMemoryExceed,
- flushBlockBufferBeforeEvict,
maxSortBufferSize,
- spillFirstlyBeforeStop,
- forceExternalSortShuffle,
forceMemorySortShuffle
)
- CHNativeMemoryAllocators.createSpillable(
- "ShuffleWriter",
- new Spiller() {
- override def spill(self: MemoryTarget, phase: Spiller.Phase, size: Long): Long = {
- if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
- return 0L;
- }
- if (nativeSplitter == 0) {
- throw new IllegalStateException(
- "Fatal: spill() called before a shuffle writer " +
- "is created. This behavior should be optimized by moving memory " +
- "allocations from make() to split()")
- }
- logError(s"Gluten shuffle writer: Trying to spill $size bytes of data")
- val spilled = splitterJniWrapper.evict(nativeSplitter);
- logError(s"Gluten shuffle writer: Spilled $spilled / $size bytes of data")
- spilled
- }
- }
- )
}
while (records.hasNext) {
val cb = records.next()._2.asInstanceOf[ColumnarBatch]
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHFormatWriterInjects.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHFormatWriterInjects.scala
index 9d4c26e5a47b4..06d5b152716d3 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHFormatWriterInjects.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHFormatWriterInjects.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.datasources.v1
import org.apache.gluten.execution.datasource.GlutenRowSplitter
+import org.apache.gluten.memory.CHThreadGroup
import org.apache.gluten.vectorized.CHColumnVector
import org.apache.spark.sql.SparkSession
@@ -37,6 +38,7 @@ trait CHFormatWriterInjects extends GlutenFormatWriterInjectsBase {
nativeConf: java.util.Map[String, String]): OutputWriter = {
val originPath = path
val datasourceJniWrapper = new CHDatasourceJniWrapper();
+ CHThreadGroup.registerNewThreadGroup()
val instance =
datasourceJniWrapper.nativeInitFileWriterWrapper(
path,
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHMergeTreeWriterInjects.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHMergeTreeWriterInjects.scala
index e11406d566195..815879a65934f 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHMergeTreeWriterInjects.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHMergeTreeWriterInjects.scala
@@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.v1
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.expression.ConverterUtils
-import org.apache.gluten.memory.alloc.CHNativeMemoryAllocators
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.substrait.`type`.ColumnTypeNode
import org.apache.gluten.substrait.SubstraitContext
@@ -97,7 +96,6 @@ class CHMergeTreeWriterInjects extends GlutenFormatWriterInjectsBase {
// use table schema instead of data schema
SparkShimLoader.getSparkShims.attributesFromStruct(tableSchema)
)
- val allocId = CHNativeMemoryAllocators.contextInstance.getNativeInstanceId
val datasourceJniWrapper = new CHDatasourceJniWrapper()
val instance =
datasourceJniWrapper.nativeInitMergeTreeWriterWrapper(
@@ -107,8 +105,7 @@ class CHMergeTreeWriterInjects extends GlutenFormatWriterInjectsBase {
context.getTaskAttemptID.getTaskID.getId.toString,
context.getConfiguration.get("mapreduce.task.gluten.mergetree.partition.dir"),
context.getConfiguration.get("mapreduce.task.gluten.mergetree.bucketid.str"),
- buildNativeConf(nativeConf),
- allocId
+ buildNativeConf(nativeConf)
)
new MergeTreeOutputWriter(database, tableName, datasourceJniWrapper, instance, path)
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatWriter.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatWriter.scala
index ad2f3851627c0..506bdd03b4f16 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatWriter.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatWriter.scala
@@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.execution.datasources.v1.clickhouse
+import org.apache.gluten.memory.CHThreadGroup
+
import org.apache.spark.{SparkException, TaskContext, TaskOutputFileAlreadyExistException}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils}
@@ -263,7 +265,7 @@ object MergeTreeFileFormatWriter extends Logging {
iterator: Iterator[InternalRow],
concurrentOutputWriterSpec: Option[ConcurrentOutputWriterSpec]
): MergeTreeWriteTaskResult = {
-
+ CHThreadGroup.registerNewThreadGroup();
val jobId = SparkHadoopWriterUtils.createJobID(new Date(jobIdInstant), sparkStageId)
val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId)
val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber)
diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHNativeMemoryAllocatorManager.java b/backends-clickhouse/src/test/java/org/apache/gluten/utils/TestExceptionUtils.java
similarity index 74%
rename from backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHNativeMemoryAllocatorManager.java
rename to backends-clickhouse/src/test/java/org/apache/gluten/utils/TestExceptionUtils.java
index 6645748e56c2c..74e2141203b03 100644
--- a/backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHNativeMemoryAllocatorManager.java
+++ b/backends-clickhouse/src/test/java/org/apache/gluten/utils/TestExceptionUtils.java
@@ -14,11 +14,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.memory.alloc;
+package org.apache.gluten.utils;
-import org.apache.spark.util.TaskResource;
-
-/** Resource manager implementation that manages a {@link CHNativeMemoryAllocator}. */
-public interface CHNativeMemoryAllocatorManager extends TaskResource {
- CHNativeMemoryAllocator getManaged();
+public class TestExceptionUtils {
+ public static native void generateNativeException();
}
diff --git a/backends-clickhouse/src/test/java/org/apache/spark/memory/TestTaskMemoryManagerSuite.java b/backends-clickhouse/src/test/java/org/apache/spark/memory/TestTaskMemoryManagerSuite.java
deleted file mode 100644
index 905ffacde023d..0000000000000
--- a/backends-clickhouse/src/test/java/org/apache/spark/memory/TestTaskMemoryManagerSuite.java
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.memory;
-
-import org.apache.gluten.memory.SimpleMemoryUsageRecorder;
-import org.apache.gluten.memory.alloc.CHManagedCHReservationListener;
-import org.apache.gluten.memory.alloc.CHNativeMemoryAllocator;
-import org.apache.gluten.memory.alloc.CHNativeMemoryAllocatorManagerImpl;
-import org.apache.gluten.memory.memtarget.MemoryTargets;
-import org.apache.gluten.memory.memtarget.Spillers;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.internal.config.package$;
-import org.junit.After;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Test;
-
-import java.util.Collections;
-
-public class TestTaskMemoryManagerSuite {
- static {
- // for skip loading lib in NativeMemoryAllocator
- System.setProperty("spark.sql.testkey", "true");
- }
-
- protected TaskMemoryManager taskMemoryManager;
- protected CHManagedCHReservationListener listener;
- protected CHNativeMemoryAllocatorManagerImpl manager;
-
- @Before
- public void initMemoryManager() {
- final SparkConf conf =
- new SparkConf()
- .set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), true)
- .set(package$.MODULE$.MEMORY_OFFHEAP_SIZE(), 1000L);
- taskMemoryManager = new TaskMemoryManager(new UnifiedMemoryManager(conf, 1000L, 500L, 1), 0);
-
- listener =
- new CHManagedCHReservationListener(
- MemoryTargets.newConsumer(
- taskMemoryManager, "test", Spillers.NOOP, Collections.emptyMap()),
- new SimpleMemoryUsageRecorder());
-
- manager = new CHNativeMemoryAllocatorManagerImpl(new CHNativeMemoryAllocator(-1L, listener));
- }
-
- @After
- public void destroyMemoryManager() {
- taskMemoryManager = null;
- listener = null;
- manager = null;
- }
-
- @Test
- public void testCHNativeMemoryManager() {
- listener.reserveOrThrow(100L);
- Assert.assertEquals(100L, taskMemoryManager.getMemoryConsumptionForThisTask());
-
- listener.unreserve(100L);
- Assert.assertEquals(0L, taskMemoryManager.getMemoryConsumptionForThisTask());
- }
-
- @Test
- public void testMemoryFreeLessThanMalloc() {
- listener.reserveOrThrow(100L);
- Assert.assertEquals(100L, taskMemoryManager.getMemoryConsumptionForThisTask());
-
- listener.unreserve(200L);
- Assert.assertEquals(0L, taskMemoryManager.getMemoryConsumptionForThisTask());
- }
-
- @Test
- public void testMemoryLeak() {
- listener.reserveOrThrow(100L);
- Assert.assertEquals(100L, taskMemoryManager.getMemoryConsumptionForThisTask());
-
- listener.unreserve(100L);
- Assert.assertEquals(0L, taskMemoryManager.getMemoryConsumptionForThisTask());
-
- listener.reserveOrThrow(100L);
- Assert.assertEquals(100L, taskMemoryManager.getMemoryConsumptionForThisTask());
-
- listener.reserveOrThrow(100L);
- Assert.assertEquals(200L, taskMemoryManager.getMemoryConsumptionForThisTask());
-
- try {
- manager.release();
- } catch (Exception e) {
- Assert.assertTrue(e instanceof UnsupportedOperationException);
- }
- }
-
- @Test(expected = UnsupportedOperationException.class)
- public void testAcquireLessMemory() {
- listener.reserveOrThrow(100L);
- Assert.assertEquals(100L, taskMemoryManager.getMemoryConsumptionForThisTask());
-
- listener.reserveOrThrow(1000L);
- }
-}
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarExternalSortShuffleSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarExternalSortShuffleSuite.scala
deleted file mode 100644
index be36cd998485d..0000000000000
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarExternalSortShuffleSuite.scala
+++ /dev/null
@@ -1,128 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.execution
-
-import org.apache.spark.SparkConf
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-
-class GlutenClickHouseColumnarExternalSortShuffleSuite
- extends GlutenClickHouseTPCHAbstractSuite
- with AdaptiveSparkPlanHelper {
-
- override protected val tablesPath: String = basePath + "/tpch-data-ch"
- override protected val tpchQueries: String = rootPath + "queries/tpch-queries-ch"
- override protected val queriesResults: String = rootPath + "mergetree-queries-output"
-
- /** Run Gluten + ClickHouse Backend with ColumnarShuffleManager */
- override protected def sparkConf: SparkConf = {
- super.sparkConf
- .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
- .set("spark.io.compression.codec", "LZ4")
- .set("spark.sql.shuffle.partitions", "5")
- .set("spark.sql.autoBroadcastJoinThreshold", "10MB")
- .set("spark.sql.adaptive.enabled", "true")
- .set("spark.gluten.sql.columnar.backend.ch.forceExternalSortShuffle", "true")
- }
-
- test("TPCH Q1") {
- runTPCHQuery(1) { df => }
- }
-
- test("TPCH Q2") {
- runTPCHQuery(2) { df => }
- }
-
- test("TPCH Q3") {
- runTPCHQuery(3) { df => }
- }
-
- test("TPCH Q4") {
- runTPCHQuery(4) { df => }
- }
-
- test("TPCH Q5") {
- runTPCHQuery(5) { df => }
- }
-
- test("TPCH Q6") {
- runTPCHQuery(6) { df => }
- }
-
- test("TPCH Q7") {
- runTPCHQuery(7) { df => }
- }
-
- test("TPCH Q8") {
- runTPCHQuery(8) { df => }
- }
-
- test("TPCH Q9") {
- runTPCHQuery(9) { df => }
- }
-
- test("TPCH Q10") {
- runTPCHQuery(10) { df => }
- }
-
- test("TPCH Q11") {
- runTPCHQuery(11) { df => }
- }
-
- test("TPCH Q12") {
- runTPCHQuery(12) { df => }
- }
-
- test("TPCH Q13") {
- runTPCHQuery(13) { df => }
- }
-
- test("TPCH Q14") {
- runTPCHQuery(14) { df => }
- }
-
- test("TPCH Q15") {
- runTPCHQuery(15) { df => }
- }
-
- test("TPCH Q16") {
- runTPCHQuery(16, noFallBack = false) { df => }
- }
-
- test("TPCH Q17") {
- runTPCHQuery(17) { df => }
- }
-
- test("TPCH Q18") {
- runTPCHQuery(18) { df => }
- }
-
- test("TPCH Q19") {
- runTPCHQuery(19) { df => }
- }
-
- test("TPCH Q20") {
- runTPCHQuery(20) { df => }
- }
-
- test("TPCH Q21") {
- runTPCHQuery(21, noFallBack = false) { df => }
- }
-
- test("TPCH Q22") {
- runTPCHQuery(22) { df => }
- }
-}
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeExceptionSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeExceptionSuite.scala
index a0fac50598d8c..cac1a8c5b3464 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeExceptionSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeExceptionSuite.scala
@@ -17,8 +17,7 @@
package org.apache.gluten.execution
import org.apache.gluten.GlutenConfig
-import org.apache.gluten.memory.alloc.{CHNativeMemoryAllocator, CHReservationListener}
-import org.apache.gluten.utils.UTSystemParameters
+import org.apache.gluten.utils.{TestExceptionUtils, UTSystemParameters}
import org.apache.spark.SparkConf
@@ -31,12 +30,11 @@ class GlutenClickHouseNativeExceptionSuite extends GlutenClickHouseWholeStageTra
test("native exception caught by jvm") {
try {
- val x = new CHNativeMemoryAllocator(100, CHReservationListener.NOOP)
- x.close() // this will incur a native exception
+ TestExceptionUtils.generateNativeException()
assert(false)
} catch {
case e: Exception =>
- assert(e.getMessage.contains("allocator 100 not found"))
+ assert(e.getMessage.contains("test native exception"))
}
}
}
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHousePreferSpillColumnarShuffleAQESuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHousePreferSpillColumnarShuffleAQESuite.scala
deleted file mode 100644
index 1884f850718ad..0000000000000
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHousePreferSpillColumnarShuffleAQESuite.scala
+++ /dev/null
@@ -1,169 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.execution
-
-import org.apache.spark.SparkConf
-import org.apache.spark.sql.execution.CoalescedPartitionSpec
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec}
-
-class GlutenClickHousePreferSpillColumnarShuffleAQESuite
- extends GlutenClickHouseTPCHAbstractSuite
- with AdaptiveSparkPlanHelper {
-
- override protected val tablesPath: String = basePath + "/tpch-data-ch"
- override protected val tpchQueries: String = rootPath + "queries/tpch-queries-ch"
- override protected val queriesResults: String = rootPath + "mergetree-queries-output"
-
- /** Run Gluten + ClickHouse Backend with ColumnarShuffleManager */
- override protected def sparkConf: SparkConf = {
- super.sparkConf
- .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
- .set("spark.io.compression.codec", "LZ4")
- .set("spark.sql.shuffle.partitions", "5")
- .set("spark.sql.autoBroadcastJoinThreshold", "10MB")
- .set("spark.sql.adaptive.enabled", "true")
- .set("spark.gluten.sql.columnar.backend.ch.shuffle.preferSpill", "true")
- }
-
- test("TPCH Q1") {
- runTPCHQuery(1) {
- df =>
- assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec])
-
- val colCustomShuffleReaderExecs = collect(df.queryExecution.executedPlan) {
- case csr: AQEShuffleReadExec => csr
- }
- assert(colCustomShuffleReaderExecs.size == 2)
- val coalescedPartitionSpec0 = colCustomShuffleReaderExecs(0)
- .partitionSpecs(0)
- .asInstanceOf[CoalescedPartitionSpec]
- assert(coalescedPartitionSpec0.startReducerIndex == 0)
- assert(coalescedPartitionSpec0.endReducerIndex == 5)
- val coalescedPartitionSpec1 = colCustomShuffleReaderExecs(1)
- .partitionSpecs(0)
- .asInstanceOf[CoalescedPartitionSpec]
- assert(coalescedPartitionSpec1.startReducerIndex == 0)
- assert(coalescedPartitionSpec1.endReducerIndex == 5)
- }
- }
-
- test("TPCH Q2") {
- runTPCHQuery(2) { df => }
- }
-
- test("TPCH Q3") {
- runTPCHQuery(3) { df => }
- }
-
- test("TPCH Q4") {
- runTPCHQuery(4) { df => }
- }
-
- test("TPCH Q5") {
- runTPCHQuery(5) { df => }
- }
-
- test("TPCH Q6") {
- runTPCHQuery(6) { df => }
- }
-
- test("TPCH Q7") {
- runTPCHQuery(7) { df => }
- }
-
- test("TPCH Q8") {
- runTPCHQuery(8) { df => }
- }
-
- test("TPCH Q9") {
- runTPCHQuery(9) { df => }
- }
-
- test("TPCH Q10") {
- runTPCHQuery(10) { df => }
- }
-
- test("TPCH Q11") {
- runTPCHQuery(11) {
- df =>
- assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec])
- val adaptiveSparkPlanExec = collectWithSubqueries(df.queryExecution.executedPlan) {
- case adaptive: AdaptiveSparkPlanExec => adaptive
- }
- assert(adaptiveSparkPlanExec.size == 2)
- }
- }
-
- test("TPCH Q12") {
- runTPCHQuery(12) { df => }
- }
-
- test("TPCH Q13") {
- runTPCHQuery(13) { df => }
- }
-
- test("TPCH Q14") {
- runTPCHQuery(14) { df => }
- }
-
- test("TPCH Q15") {
- runTPCHQuery(15) {
- df =>
- assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec])
- val adaptiveSparkPlanExec = collectWithSubqueries(df.queryExecution.executedPlan) {
- case adaptive: AdaptiveSparkPlanExec => adaptive
- }
- assert(adaptiveSparkPlanExec.size == 2)
- }
- }
-
- test("TPCH Q16") {
- runTPCHQuery(16, noFallBack = false) { df => }
- }
-
- test("TPCH Q17") {
- runTPCHQuery(17) { df => }
- }
-
- test("TPCH Q18") {
- runTPCHQuery(18) { df => }
- }
-
- test("TPCH Q19") {
- runTPCHQuery(19) { df => }
- }
-
- test("TPCH Q20") {
- runTPCHQuery(20) { df => }
- }
-
- test("TPCH Q21") {
- runTPCHQuery(21, noFallBack = false) { df => }
- }
-
- test("TPCH Q22") {
- runTPCHQuery(22) {
- df =>
- assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec])
- val adaptiveSparkPlanExec = collectWithSubqueries(df.queryExecution.executedPlan) {
- case adaptive: AdaptiveSparkPlanExec => adaptive
- }
- assert(adaptiveSparkPlanExec.size == 3)
- assert(adaptiveSparkPlanExec(1) == adaptiveSparkPlanExec(2))
- }
- }
-}
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseMetricsUTUtils.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseMetricsUTUtils.scala
index 3253e04bb36cc..801b60dda7f20 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseMetricsUTUtils.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseMetricsUTUtils.scala
@@ -17,7 +17,6 @@
package org.apache.gluten.execution.metrics
import org.apache.gluten.execution.WholeStageTransformer
-import org.apache.gluten.memory.alloc.CHNativeMemoryAllocators
import org.apache.gluten.metrics.{MetricsUtil, NativeMetrics}
import org.apache.gluten.utils.SubstraitPlanPrinterUtil
import org.apache.gluten.vectorized.{CHNativeExpressionEvaluator, GeneralInIterator}
@@ -45,9 +44,7 @@ object GlutenClickHouseMetricsUTUtils {
SubstraitPlanPrinterUtil.jsonToSubstraitPlan(
substraitPlanJsonStr.replaceAll("basePath", basePath.substring(1)))
- val mockMemoryAllocator = CHNativeMemoryAllocators.contextInstanceForUT()
val resIter = CHNativeExpressionEvaluator.createKernelWithBatchIterator(
- mockMemoryAllocator.getNativeInstanceId,
substraitPlan.toByteArray,
new Array[Array[Byte]](0),
inBatchIters)
@@ -75,7 +72,6 @@ object GlutenClickHouseMetricsUTUtils {
iter.foreach(_.toString)
resIter.close()
- mockMemoryAllocator.close()
nativeMetricsList.toSeq
}
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala
index 1b3df81667a0b..4b5a5b328cb3f 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala
@@ -23,6 +23,7 @@ import org.apache.gluten.vectorized.GeneralInIterator
import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.InputIteratorTransformer
+import org.apache.spark.util.TaskResources
import scala.collection.JavaConverters._
@@ -152,57 +153,59 @@ class GlutenClickHouseTPCHMetricsSuite extends GlutenClickHouseTPCHAbstractSuite
}
test("test tpch wholestage execute") {
- val inBatchIters = new java.util.ArrayList[GeneralInIterator](0)
- val outputAttributes = new java.util.ArrayList[Attribute](0)
- val nativeMetricsList = GlutenClickHouseMetricsUTUtils
- .executeSubstraitPlan(
- substraitPlansDatPath + "/tpch-q4-wholestage-2.json",
- basePath,
- inBatchIters,
- outputAttributes
- )
-
- assert(nativeMetricsList.size == 1)
- val nativeMetricsData = nativeMetricsList(0)
- assert(nativeMetricsData.metricsDataList.size() == 3)
-
- assert(nativeMetricsData.metricsDataList.get(0).getName.equals("kRead"))
- assert(
- nativeMetricsData.metricsDataList
- .get(0)
- .getSteps
- .get(0)
- .getProcessors
- .get(0)
- .getOutputRows == 600572)
-
- assert(nativeMetricsData.metricsDataList.get(1).getName.equals("kFilter"))
- assert(
- nativeMetricsData.metricsDataList
- .get(1)
- .getSteps
- .get(0)
- .getProcessors
- .get(0)
- .getInputRows == 600572)
- assert(
- nativeMetricsData.metricsDataList
- .get(1)
- .getSteps
- .get(0)
- .getProcessors
- .get(0)
- .getOutputRows == 379809)
-
- assert(nativeMetricsData.metricsDataList.get(2).getName.equals("kProject"))
- assert(
- nativeMetricsData.metricsDataList
- .get(2)
- .getSteps
- .get(0)
- .getProcessors
- .get(0)
- .getOutputRows == 379809)
+ TaskResources.runUnsafe {
+ val inBatchIters = new java.util.ArrayList[GeneralInIterator](0)
+ val outputAttributes = new java.util.ArrayList[Attribute](0)
+ val nativeMetricsList = GlutenClickHouseMetricsUTUtils
+ .executeSubstraitPlan(
+ substraitPlansDatPath + "/tpch-q4-wholestage-2.json",
+ basePath,
+ inBatchIters,
+ outputAttributes
+ )
+
+ assert(nativeMetricsList.size == 1)
+ val nativeMetricsData = nativeMetricsList(0)
+ assert(nativeMetricsData.metricsDataList.size() == 3)
+
+ assert(nativeMetricsData.metricsDataList.get(0).getName.equals("kRead"))
+ assert(
+ nativeMetricsData.metricsDataList
+ .get(0)
+ .getSteps
+ .get(0)
+ .getProcessors
+ .get(0)
+ .getOutputRows == 600572)
+
+ assert(nativeMetricsData.metricsDataList.get(1).getName.equals("kFilter"))
+ assert(
+ nativeMetricsData.metricsDataList
+ .get(1)
+ .getSteps
+ .get(0)
+ .getProcessors
+ .get(0)
+ .getInputRows == 600572)
+ assert(
+ nativeMetricsData.metricsDataList
+ .get(1)
+ .getSteps
+ .get(0)
+ .getProcessors
+ .get(0)
+ .getOutputRows == 379809)
+
+ assert(nativeMetricsData.metricsDataList.get(2).getName.equals("kProject"))
+ assert(
+ nativeMetricsData.metricsDataList
+ .get(2)
+ .getSteps
+ .get(0)
+ .getProcessors
+ .get(0)
+ .getOutputRows == 379809)
+ }
}
test("Check TPCH Q2 metrics updater") {
@@ -310,106 +313,108 @@ class GlutenClickHouseTPCHMetricsSuite extends GlutenClickHouseTPCHAbstractSuite
}
test("GLUTEN-1754: test agg func covar_samp, covar_pop final stage execute") {
- val inBatchIters = new java.util.ArrayList[GeneralInIterator](0)
- val outputAttributes = new java.util.ArrayList[Attribute](0)
- val nativeMetricsList = GlutenClickHouseMetricsUTUtils
- .executeSubstraitPlan(
- substraitPlansDatPath + "/covar_samp-covar_pop-partial-agg-stage.json",
- basePath,
- inBatchIters,
- outputAttributes
- )
-
- assert(nativeMetricsList.size == 1)
- val nativeMetricsData = nativeMetricsList(0)
- assert(nativeMetricsData.metricsDataList.size() == 5)
-
- assert(nativeMetricsData.metricsDataList.get(0).getName.equals("kRead"))
- assert(
- nativeMetricsData.metricsDataList
- .get(0)
- .getSteps
- .get(0)
- .getProcessors
- .get(0)
- .getOutputRows == 600572)
-
- assert(nativeMetricsData.metricsDataList.get(1).getName.equals("kFilter"))
- assert(
- nativeMetricsData.metricsDataList
- .get(1)
- .getSteps
- .get(0)
- .getProcessors
- .get(0)
- .getInputRows == 600572)
- assert(
- nativeMetricsData.metricsDataList
- .get(1)
- .getSteps
- .get(0)
- .getProcessors
- .get(0)
- .getOutputRows == 591673)
-
- assert(nativeMetricsData.metricsDataList.get(2).getName.equals("kProject"))
-
- assert(nativeMetricsData.metricsDataList.get(3).getName.equals("kProject"))
- assert(nativeMetricsData.metricsDataList.get(4).getName.equals("kAggregate"))
- assert(
- nativeMetricsData.metricsDataList
- .get(4)
- .getSteps
- .get(0)
- .getProcessors
- .get(0)
- .getInputRows == 591673)
- assert(
- nativeMetricsData.metricsDataList
- .get(4)
- .getSteps
- .get(0)
- .getProcessors
- .get(0)
- .getOutputRows == 4)
-
- assert(
- nativeMetricsData.metricsDataList
- .get(4)
- .getSteps
- .get(0)
- .getProcessors
- .get(0)
- .getOutputRows == 4)
-
- val inBatchItersFinal = new java.util.ArrayList[GeneralInIterator](
- Array(0).map(iter => new ColumnarNativeIterator(Iterator.empty.asJava)).toSeq.asJava)
- val outputAttributesFinal = new java.util.ArrayList[Attribute](0)
-
- val nativeMetricsListFinal = GlutenClickHouseMetricsUTUtils
- .executeSubstraitPlan(
- substraitPlansDatPath + "/covar_samp-covar_pop-final-agg-stage.json",
- basePath,
- inBatchItersFinal,
- outputAttributesFinal
- )
-
- assert(nativeMetricsListFinal.size == 1)
- val nativeMetricsDataFinal = nativeMetricsListFinal(0)
- assert(nativeMetricsDataFinal.metricsDataList.size() == 3)
-
- assert(nativeMetricsDataFinal.metricsDataList.get(0).getName.equals("kRead"))
- assert(nativeMetricsDataFinal.metricsDataList.get(1).getName.equals("kAggregate"))
- assert(nativeMetricsDataFinal.metricsDataList.get(1).getSteps.size() == 2)
- assert(
- nativeMetricsDataFinal.metricsDataList
- .get(1)
- .getSteps
- .get(0)
- .getName
- .equals("GraceMergingAggregatedStep"))
- assert(
- nativeMetricsDataFinal.metricsDataList.get(1).getSteps.get(1).getName.equals("Expression"))
- assert(nativeMetricsDataFinal.metricsDataList.get(2).getName.equals("kProject"))
+ TaskResources.runUnsafe {
+ val inBatchIters = new java.util.ArrayList[GeneralInIterator](0)
+ val outputAttributes = new java.util.ArrayList[Attribute](0)
+ val nativeMetricsList = GlutenClickHouseMetricsUTUtils
+ .executeSubstraitPlan(
+ substraitPlansDatPath + "/covar_samp-covar_pop-partial-agg-stage.json",
+ basePath,
+ inBatchIters,
+ outputAttributes
+ )
+
+ assert(nativeMetricsList.size == 1)
+ val nativeMetricsData = nativeMetricsList(0)
+ assert(nativeMetricsData.metricsDataList.size() == 5)
+
+ assert(nativeMetricsData.metricsDataList.get(0).getName.equals("kRead"))
+ assert(
+ nativeMetricsData.metricsDataList
+ .get(0)
+ .getSteps
+ .get(0)
+ .getProcessors
+ .get(0)
+ .getOutputRows == 600572)
+
+ assert(nativeMetricsData.metricsDataList.get(1).getName.equals("kFilter"))
+ assert(
+ nativeMetricsData.metricsDataList
+ .get(1)
+ .getSteps
+ .get(0)
+ .getProcessors
+ .get(0)
+ .getInputRows == 600572)
+ assert(
+ nativeMetricsData.metricsDataList
+ .get(1)
+ .getSteps
+ .get(0)
+ .getProcessors
+ .get(0)
+ .getOutputRows == 591673)
+
+ assert(nativeMetricsData.metricsDataList.get(2).getName.equals("kProject"))
+
+ assert(nativeMetricsData.metricsDataList.get(3).getName.equals("kProject"))
+ assert(nativeMetricsData.metricsDataList.get(4).getName.equals("kAggregate"))
+ assert(
+ nativeMetricsData.metricsDataList
+ .get(4)
+ .getSteps
+ .get(0)
+ .getProcessors
+ .get(0)
+ .getInputRows == 591673)
+ assert(
+ nativeMetricsData.metricsDataList
+ .get(4)
+ .getSteps
+ .get(0)
+ .getProcessors
+ .get(0)
+ .getOutputRows == 4)
+
+ assert(
+ nativeMetricsData.metricsDataList
+ .get(4)
+ .getSteps
+ .get(0)
+ .getProcessors
+ .get(0)
+ .getOutputRows == 4)
+
+ val inBatchItersFinal = new java.util.ArrayList[GeneralInIterator](
+ Array(0).map(iter => new ColumnarNativeIterator(Iterator.empty.asJava)).toSeq.asJava)
+ val outputAttributesFinal = new java.util.ArrayList[Attribute](0)
+
+ val nativeMetricsListFinal = GlutenClickHouseMetricsUTUtils
+ .executeSubstraitPlan(
+ substraitPlansDatPath + "/covar_samp-covar_pop-final-agg-stage.json",
+ basePath,
+ inBatchItersFinal,
+ outputAttributesFinal
+ )
+
+ assert(nativeMetricsListFinal.size == 1)
+ val nativeMetricsDataFinal = nativeMetricsListFinal(0)
+ assert(nativeMetricsDataFinal.metricsDataList.size() == 3)
+
+ assert(nativeMetricsDataFinal.metricsDataList.get(0).getName.equals("kRead"))
+ assert(nativeMetricsDataFinal.metricsDataList.get(1).getName.equals("kAggregate"))
+ assert(nativeMetricsDataFinal.metricsDataList.get(1).getSteps.size() == 2)
+ assert(
+ nativeMetricsDataFinal.metricsDataList
+ .get(1)
+ .getSteps
+ .get(0)
+ .getName
+ .equals("GraceMergingAggregatedStep"))
+ assert(
+ nativeMetricsDataFinal.metricsDataList.get(1).getSteps.get(1).getName.equals("Expression"))
+ assert(nativeMetricsDataFinal.metricsDataList.get(2).getName.equals("kProject"))
+ }
}
}
diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp
index 1be9c09a8c266..3bfddd59a746a 100644
--- a/cpp-ch/local-engine/Common/CHUtil.cpp
+++ b/cpp-ch/local-engine/Common/CHUtil.cpp
@@ -627,7 +627,7 @@ DB::Context::ConfigurationPtr BackendInitializerUtil::initConfig(std::mapsetString(CH_TASK_MEMORY, backend_conf_map.at(GLUTEN_TASK_OFFHEAP));
+ config->setString(MemoryConfig::CH_TASK_MEMORY, backend_conf_map.at(GLUTEN_TASK_OFFHEAP));
}
const bool use_current_directory_as_tmp = config->getBool("use_current_directory_as_tmp", false);
@@ -1050,17 +1050,6 @@ String DateTimeUtil::convertTimeZone(const String & time_zone)
return res;
}
-UInt64 MemoryUtil::getCurrentMemoryUsage(size_t depth)
-{
- Int64 current_memory_usage = 0;
- auto * current_mem_tracker = DB::CurrentThread::getMemoryTracker();
- for (size_t i = 0; i < depth && current_mem_tracker; ++i)
- current_mem_tracker = current_mem_tracker->getParent();
- if (current_mem_tracker)
- current_memory_usage = current_mem_tracker->get();
- return current_memory_usage < 0 ? 0 : current_memory_usage;
-}
-
UInt64 MemoryUtil::getMemoryRSS()
{
long rss = 0L;
diff --git a/cpp-ch/local-engine/Common/CHUtil.h b/cpp-ch/local-engine/Common/CHUtil.h
index 98139fb49a5b3..4366dde2efd17 100644
--- a/cpp-ch/local-engine/Common/CHUtil.h
+++ b/cpp-ch/local-engine/Common/CHUtil.h
@@ -188,7 +188,6 @@ class BackendInitializerUtil
inline static const std::string SPARK_SESSION_TIME_ZONE = "spark.sql.session.timeZone";
inline static const String GLUTEN_TASK_OFFHEAP = "spark.gluten.memory.task.offHeap.size.in.bytes";
- inline static const String CH_TASK_MEMORY = "off_heap_per_task";
/// On yarn mode, native writing on hdfs cluster takes yarn container user as the user passed to libhdfs3, which
/// will cause permission issue because yarn container user is not the owner of the hdfs dir to be written.
@@ -252,7 +251,6 @@ class DateTimeUtil
class MemoryUtil
{
public:
- static UInt64 getCurrentMemoryUsage(size_t depth = 1);
static UInt64 getMemoryRSS();
};
diff --git a/cpp-ch/local-engine/Common/GlutenConfig.h b/cpp-ch/local-engine/Common/GlutenConfig.h
new file mode 100644
index 0000000000000..782df7f5413d4
--- /dev/null
+++ b/cpp-ch/local-engine/Common/GlutenConfig.h
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace local_engine
+{
+struct MemoryConfig
+{
+ inline static const String EXTRA_MEMORY_HARD_LIMIT = "extra_memory_hard_limit";
+ inline static const String CH_TASK_MEMORY = "off_heap_per_task";
+ inline static const String SPILL_MEM_RATIO = "spill_mem_ratio";
+
+ size_t extra_memory_hard_limit = 0;
+ size_t off_heap_per_task = 0;
+ double spill_mem_ratio = 0.9;
+
+ static MemoryConfig loadFromContext(DB::ContextPtr context)
+ {
+ MemoryConfig config;
+ config.extra_memory_hard_limit = context->getConfigRef().getUInt64(EXTRA_MEMORY_HARD_LIMIT, 0);
+ config.off_heap_per_task = context->getConfigRef().getUInt64(CH_TASK_MEMORY, 0);
+ config.spill_mem_ratio = context->getConfigRef().getUInt64(SPILL_MEM_RATIO, 0.9);
+ return config;
+ }
+};
+
+struct GraceMergingAggregateConfig
+{
+ inline static const String MAX_GRACE_AGGREGATE_MERGING_BUCKETS = "max_grace_aggregate_merging_buckets";
+ inline static const String THROW_ON_OVERFLOW_GRACE_AGGREGATE_MERGING_BUCKETS = "throw_on_overflow_grace_aggregate_merging_buckets";
+ inline static const String AGGREGATED_KEYS_BEFORE_EXTEND_GRACE_AGGREGATE_MERGING_BUCKETS = "aggregated_keys_before_extend_grace_aggregate_merging_buckets";
+ inline static const String MAX_PENDING_FLUSH_BLOCKS_PER_GRACE_AGGREGATE_MERGING_BUCKET = "max_pending_flush_blocks_per_grace_aggregate_merging_bucket";
+ inline static const String MAX_ALLOWED_MEMORY_USAGE_RATIO_FOR_AGGREGATE_MERGING = "max_allowed_memory_usage_ratio_for_aggregate_merging";
+
+ size_t max_grace_aggregate_merging_buckets = 32;
+ bool throw_on_overflow_grace_aggregate_merging_buckets = false;
+ size_t aggregated_keys_before_extend_grace_aggregate_merging_buckets = 8192;
+ size_t max_pending_flush_blocks_per_grace_aggregate_merging_bucket = 1_MiB;
+ double max_allowed_memory_usage_ratio_for_aggregate_merging = 0.9;
+
+ static GraceMergingAggregateConfig loadFromContext(DB::ContextPtr context)
+ {
+ GraceMergingAggregateConfig config;
+ config.max_grace_aggregate_merging_buckets = context->getConfigRef().getUInt64(MAX_GRACE_AGGREGATE_MERGING_BUCKETS, 32);
+ config.throw_on_overflow_grace_aggregate_merging_buckets = context->getConfigRef().getBool(THROW_ON_OVERFLOW_GRACE_AGGREGATE_MERGING_BUCKETS, false);
+ config.aggregated_keys_before_extend_grace_aggregate_merging_buckets = context->getConfigRef().getUInt64(AGGREGATED_KEYS_BEFORE_EXTEND_GRACE_AGGREGATE_MERGING_BUCKETS, 8192);
+ config.max_pending_flush_blocks_per_grace_aggregate_merging_bucket = context->getConfigRef().getUInt64(MAX_PENDING_FLUSH_BLOCKS_PER_GRACE_AGGREGATE_MERGING_BUCKET, 1_MiB);
+ config.max_allowed_memory_usage_ratio_for_aggregate_merging = context->getConfigRef().getDouble(MAX_ALLOWED_MEMORY_USAGE_RATIO_FOR_AGGREGATE_MERGING, 0.9);
+ return config;
+ }
+};
+
+struct StreamingAggregateConfig
+{
+ inline static const String AGGREGATED_KEYS_BEFORE_STREAMING_AGGREGATING_EVICT = "aggregated_keys_before_streaming_aggregating_evict";
+ inline static const String MAX_MEMORY_USAGE_RATIO_FOR_STREAMING_AGGREGATING = "max_memory_usage_ratio_for_streaming_aggregating";
+ inline static const String HIGH_CARDINALITY_THRESHOLD_FOR_STREAMING_AGGREGATING = "high_cardinality_threshold_for_streaming_aggregating";
+ inline static const String ENABLE_STREAMING_AGGREGATING = "enable_streaming_aggregating";
+
+ size_t aggregated_keys_before_streaming_aggregating_evict = 1024;
+ double max_memory_usage_ratio_for_streaming_aggregating = 0.9;
+ double high_cardinality_threshold_for_streaming_aggregating = 0.8;
+ bool enable_streaming_aggregating = true;
+
+ static StreamingAggregateConfig loadFromContext(DB::ContextPtr context)
+ {
+ StreamingAggregateConfig config;
+ config.aggregated_keys_before_streaming_aggregating_evict = context->getConfigRef().getUInt64(AGGREGATED_KEYS_BEFORE_STREAMING_AGGREGATING_EVICT, 1024);
+ config.max_memory_usage_ratio_for_streaming_aggregating = context->getConfigRef().getDouble(MAX_MEMORY_USAGE_RATIO_FOR_STREAMING_AGGREGATING, 0.9);
+ config.high_cardinality_threshold_for_streaming_aggregating = context->getConfigRef().getDouble(HIGH_CARDINALITY_THRESHOLD_FOR_STREAMING_AGGREGATING, 0.8);
+ config.enable_streaming_aggregating = context->getConfigRef().getBool(ENABLE_STREAMING_AGGREGATING, true);
+ return config;
+ }
+};
+
+struct ExecutorConfig
+{
+ inline static const String DUMP_PIPELINE = "dump_pipeline";
+ inline static const String USE_LOCAL_FORMAT = "use_local_format";
+
+ bool dump_pipeline = false;
+ bool use_local_format = false;
+
+ static ExecutorConfig loadFromContext(DB::ContextPtr context)
+ {
+ ExecutorConfig config;
+ config.dump_pipeline = context->getConfigRef().getBool(DUMP_PIPELINE, false);
+ config.use_local_format = context->getConfigRef().getBool(USE_LOCAL_FORMAT, false);
+ return config;
+ }
+};
+
+struct HdfsConfig
+{
+ inline static const String HDFS_ASYNC = "hdfs.enable_async_io";
+
+ bool hdfs_async = true;
+
+ static HdfsConfig loadFromContext(DB::ContextPtr context)
+ {
+ HdfsConfig config;
+ config.hdfs_async = context->getConfigRef().getBool(HDFS_ASYNC, true);
+ return config;
+ }
+};
+
+struct S3Config
+{
+ inline static const String S3_LOCAL_CACHE_ENABLE = "s3.local_cache.enabled";
+ inline static const String S3_LOCAL_CACHE_MAX_SIZE = "s3.local_cache.max_size";
+ inline static const String S3_LOCAL_CACHE_CACHE_PATH = "s3.local_cache.cache_path";
+ inline static const String S3_GCS_ISSUE_COMPOSE_REQUEST = "s3.gcs_issue_compose_request";
+
+ bool s3_local_cache_enabled = false;
+ size_t s3_local_cache_max_size = 100_GiB;
+ String s3_local_cache_cache_path = "";
+ bool s3_gcs_issue_compose_request = false;
+
+ static S3Config loadFromContext(DB::ContextPtr context)
+ {
+ S3Config config;
+ config.s3_local_cache_enabled = context->getConfigRef().getBool(S3_LOCAL_CACHE_ENABLE, false);
+ config.s3_local_cache_max_size = context->getConfigRef().getUInt64(S3_LOCAL_CACHE_MAX_SIZE, 100_GiB);
+ config.s3_local_cache_cache_path = context->getConfigRef().getString(S3_LOCAL_CACHE_CACHE_PATH, "");
+ config.s3_gcs_issue_compose_request = context->getConfigRef().getBool(S3_GCS_ISSUE_COMPOSE_REQUEST, false);
+ return config;
+ }
+};
+
+struct MergeTreeConfig
+{
+ inline static const String TABLE_PART_METADATA_CACHE_MAX_COUNT = "table_part_metadata_cache_max_count";
+ inline static const String TABLE_METADATA_CACHE_MAX_COUNT = "table_metadata_cache_max_count";
+
+ size_t table_part_metadata_cache_max_count = 1000;
+ size_t table_metadata_cache_max_count = 100;
+
+ static MergeTreeConfig loadFromContext(DB::ContextPtr context)
+ {
+ MergeTreeConfig config;
+ config.table_part_metadata_cache_max_count = context->getConfigRef().getUInt64(TABLE_PART_METADATA_CACHE_MAX_COUNT, 1000);
+ config.table_metadata_cache_max_count = context->getConfigRef().getUInt64(TABLE_METADATA_CACHE_MAX_COUNT, 100);
+ return config;
+ }
+};
+}
+
diff --git a/cpp-ch/local-engine/Common/QueryContext.cpp b/cpp-ch/local-engine/Common/QueryContext.cpp
index f4d39c612430e..68934adad3671 100644
--- a/cpp-ch/local-engine/Common/QueryContext.cpp
+++ b/cpp-ch/local-engine/Common/QueryContext.cpp
@@ -15,33 +15,47 @@
* limitations under the License.
*/
#include "QueryContext.h"
+
+#include
+
#include
#include
-#include
#include
#include
+#include
+#include
+#include
+#include
+#include
namespace DB
{
namespace ErrorCodes
{
- extern const int LOGICAL_ERROR;
+extern const int LOGICAL_ERROR;
}
}
namespace local_engine
{
using namespace DB;
-thread_local std::shared_ptr query_scope;
-thread_local std::shared_ptr thread_status;
-ConcurrentMap allocator_map;
-int64_t initializeQuery(ReservationListenerWrapperPtr listener)
+struct QueryContext
+{
+ std::shared_ptr thread_status;
+ std::shared_ptr thread_group;
+ ContextMutablePtr query_context;
+};
+
+std::unordered_map> query_map;
+std::mutex query_map_mutex;
+
+int64_t QueryContextManager::initializeQuery()
{
- if (thread_status) return -1;
- auto query_context = Context::createCopy(SerializedPlanParser::global_context);
- query_context->makeQueryContext();
+ std::shared_ptr query_context = std::make_shared();
+ query_context->query_context = Context::createCopy(SerializedPlanParser::global_context);
+ query_context->query_context->makeQueryContext();
// empty input will trigger random query id to be set
// FileCache will check if query id is set to decide whether to skip cache or not
@@ -49,56 +63,115 @@ int64_t initializeQuery(ReservationListenerWrapperPtr listener)
//
// Notice:
// this generated random query id a qualified global queryid for the spark query
- query_context->setCurrentQueryId("");
-
- auto allocator_context = std::make_shared();
- allocator_context->thread_status = std::make_shared(true);
- allocator_context->query_scope = std::make_shared(query_context);
- allocator_context->group = std::make_shared(query_context);
- allocator_context->query_context = query_context;
- allocator_context->listener = listener;
- thread_status = allocator_context->thread_status;
- query_scope = allocator_context->query_scope;
- auto allocator_id = reinterpret_cast(allocator_context.get());
- CurrentMemoryTracker::before_alloc = [listener](Int64 size, bool throw_if_memory_exceed) -> void
+ query_context->query_context->setCurrentQueryId(toString(UUIDHelpers::generateV4()));
+ auto config = MemoryConfig::loadFromContext(query_context->query_context);
+ query_context->thread_status = std::make_shared(false);
+ query_context->thread_group = std::make_shared(query_context->query_context);
+ CurrentThread::attachToGroup(query_context->thread_group);
+ auto memory_limit = config.off_heap_per_task;
+
+ query_context->thread_group->memory_tracker.setSoftLimit(memory_limit);
+ query_context->thread_group->memory_tracker.setHardLimit(memory_limit + config.extra_memory_hard_limit);
+ std::lock_guard lock_guard(query_map_mutex);
+ int64_t id = reinterpret_cast(query_context->thread_group.get());
+ query_map.emplace(id, query_context);
+ return id;
+}
+
+DB::ContextMutablePtr QueryContextManager::currentQueryContext()
+{
+ if (!CurrentThread::getGroup())
{
- if (throw_if_memory_exceed)
- listener->reserveOrThrow(size);
- else
- listener->reserve(size);
- };
- CurrentMemoryTracker::before_free = [listener](Int64 size) -> void { listener->tryFree(size); };
- CurrentMemoryTracker::current_memory = [listener]() -> Int64 { return listener->currentMemory(); };
- allocator_map.insert(allocator_id, allocator_context);
- return allocator_id;
+ throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Thread group not found.");
+ }
+ std::lock_guard lock_guard(query_map_mutex);
+ int64_t id = reinterpret_cast(CurrentThread::getGroup().get());
+ return query_map[id]->query_context;
}
-void releaseAllocator(int64_t allocator_id)
+void QueryContextManager::logCurrentPerformanceCounters(ProfileEvents::Counters & counters)
{
- if (!allocator_map.get(allocator_id))
+ if (!CurrentThread::getGroup())
+ {
+ return;
+ }
+ if (logger->information())
{
- throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "allocator {} not found", allocator_id);
+ std::ostringstream msg;
+ msg << "\n---------------------Task Performance Counters-----------------------------\n";
+ for (ProfileEvents::Event event = ProfileEvents::Event(0); event < counters.num_counters; event++)
+ {
+ const auto * name = ProfileEvents::getName(event);
+ const auto * doc = ProfileEvents::getDocumentation(event);
+ auto & count = counters[event];
+ if (count == 0)
+ continue;
+ msg << std::setw(50) << std::setfill(' ') << std::left << name << "|"
+ << std::setw(20) << std::setfill(' ') << std::left << count.load()
+ << " | (" << doc << ")\n";
+ }
+ LOG_INFO(logger, "{}", msg.str());
}
- auto status = allocator_map.get(allocator_id)->thread_status;
- status->detachFromGroup();
- auto listener = allocator_map.get(allocator_id)->listener;
- if (status->untracked_memory < 0)
- listener->free(-status->untracked_memory);
- else if (status->untracked_memory > 0)
- listener->reserve(status->untracked_memory);
- allocator_map.erase(allocator_id);
- thread_status.reset();
- query_scope.reset();
}
-NativeAllocatorContextPtr getAllocator(int64_t allocator)
+size_t QueryContextManager::currentPeakMemory(int64_t id)
{
- return allocator_map.get(allocator);
+ std::lock_guard lock_guard(query_map_mutex);
+ if (!query_map.contains(id))
+ throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "context released {}", id);
+ return query_map[id]->thread_group->memory_tracker.getPeak();
}
-int64_t allocatorMemoryUsage(int64_t allocator_id)
+void QueryContextManager::finalizeQuery(int64_t id)
{
- return allocator_map.get(allocator_id)->thread_status->memory_tracker.get();
+ if (!CurrentThread::getGroup())
+ {
+ throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Thread group not found.");
+ }
+ std::shared_ptr context;
+ {
+ std::lock_guard lock_guard(query_map_mutex);
+ context = query_map[id];
+ }
+ auto query_context = context->thread_status->getQueryContext();
+ if (!query_context)
+ {
+ throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "query context not found");
+ }
+ context->thread_status->flushUntrackedMemory();
+ context->thread_status->finalizePerformanceCounters();
+ LOG_INFO(logger, "Task finished, peak memory usage: {} bytes", currentPeakMemory(id));
+
+ if (currentThreadGroupMemoryUsage() > 1_MiB)
+ {
+ LOG_WARNING(logger, "{} bytes memory didn't release, There may be a memory leak!", currentThreadGroupMemoryUsage());
+ }
+ logCurrentPerformanceCounters(context->thread_group->performance_counters);
+ context->thread_status->detachFromGroup();
+ context->thread_group.reset();
+ context->thread_status.reset();
+ query_context.reset();
+ {
+ std::lock_guard lock_guard(query_map_mutex);
+ query_map.erase(id);
+ }
}
+size_t currentThreadGroupMemoryUsage()
+{
+ if (!CurrentThread::getGroup())
+ {
+ throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Thread group not found, please call initializeQuery first.");
+ }
+ return CurrentThread::getGroup()->memory_tracker.get();
+}
+
+double currentThreadGroupMemoryUsageRatio()
+{
+ if (!CurrentThread::getGroup())
+ {
+ throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Thread group not found, please call initializeQuery first.");
+ }
+ return static_cast(CurrentThread::getGroup()->memory_tracker.get()) / CurrentThread::getGroup()->memory_tracker.getSoftLimit();
}
+}
\ No newline at end of file
diff --git a/cpp-ch/local-engine/Common/QueryContext.h b/cpp-ch/local-engine/Common/QueryContext.h
index 77d522dc9337a..0fbf4977321f1 100644
--- a/cpp-ch/local-engine/Common/QueryContext.h
+++ b/cpp-ch/local-engine/Common/QueryContext.h
@@ -15,30 +15,30 @@
* limitations under the License.
*/
#pragma once
-#include
#include
-#include
-#include
#include
namespace local_engine
{
-int64_t initializeQuery(ReservationListenerWrapperPtr listener);
-
-void releaseAllocator(int64_t allocator_id);
-
-int64_t allocatorMemoryUsage(int64_t allocator_id);
-
-struct NativeAllocatorContext
+class QueryContextManager
{
- std::shared_ptr query_scope;
- std::shared_ptr thread_status;
- DB::ContextMutablePtr query_context;
- std::shared_ptr group;
- ReservationListenerWrapperPtr listener;
-};
+public:
+ static QueryContextManager & instance()
+ {
+ static QueryContextManager instance;
+ return instance;
+ }
+ int64_t initializeQuery();
+ DB::ContextMutablePtr currentQueryContext();
+ void logCurrentPerformanceCounters(ProfileEvents::Counters& counters);
+ size_t currentPeakMemory(int64_t id);
+ void finalizeQuery(int64_t id);
-using NativeAllocatorContextPtr = std::shared_ptr;
+private:
+ QueryContextManager() = default;
+ LoggerPtr logger = getLogger("QueryContextManager");
+};
-NativeAllocatorContextPtr getAllocator(int64_t allocator);
+size_t currentThreadGroupMemoryUsage();
+double currentThreadGroupMemoryUsageRatio();
}
diff --git a/cpp-ch/local-engine/Operator/GraceMergingAggregatedStep.cpp b/cpp-ch/local-engine/Operator/GraceMergingAggregatedStep.cpp
index a9a2df276a594..82b498e58ff90 100644
--- a/cpp-ch/local-engine/Operator/GraceMergingAggregatedStep.cpp
+++ b/cpp-ch/local-engine/Operator/GraceMergingAggregatedStep.cpp
@@ -22,6 +22,8 @@
#include
#include
#include
+#include
+#include
namespace DB
{
@@ -114,12 +116,13 @@ GraceMergingAggregatedTransform::GraceMergingAggregatedTransform(const DB::Block
, no_pre_aggregated(no_pre_aggregated_)
, tmp_data_disk(std::make_unique(context_->getTempDataOnDisk()))
{
- max_buckets = context->getConfigRef().getUInt64("max_grace_aggregate_merging_buckets", 32);
- throw_on_overflow_buckets = context->getConfigRef().getBool("throw_on_overflow_grace_aggregate_merging_buckets", false);
- aggregated_keys_before_extend_buckets = context->getConfigRef().getUInt64("aggregated_keys_before_extend_grace_aggregate_merging_buckets", 8196);
+ auto config = GraceMergingAggregateConfig::loadFromContext(context);
+ max_buckets = config.max_grace_aggregate_merging_buckets;
+ throw_on_overflow_buckets = config.throw_on_overflow_grace_aggregate_merging_buckets;
+ aggregated_keys_before_extend_buckets = config.aggregated_keys_before_extend_grace_aggregate_merging_buckets;
aggregated_keys_before_extend_buckets = PODArrayUtil::adjustMemoryEfficientSize(aggregated_keys_before_extend_buckets);
- max_pending_flush_blocks_per_bucket = context->getConfigRef().getUInt64("max_pending_flush_blocks_per_grace_aggregate_merging_bucket", 1024 * 1024);
- max_allowed_memory_usage_ratio = context->getConfigRef().getDouble("max_allowed_memory_usage_ratio_for_aggregate_merging", 0.9);
+ max_pending_flush_blocks_per_bucket = config.max_pending_flush_blocks_per_grace_aggregate_merging_bucket;
+ max_allowed_memory_usage_ratio = config.max_allowed_memory_usage_ratio_for_aggregate_merging;
// bucket 0 is for in-memory data, it's just a placeholder.
buckets.emplace(0, BufferFileStream());
@@ -160,7 +163,7 @@ GraceMergingAggregatedTransform::Status GraceMergingAggregatedTransform::prepare
"Output one chunk. rows: {}, bytes: {}, current memory usage: {}",
output_chunk.getNumRows(),
ReadableSize(output_chunk.bytes()),
- ReadableSize(MemoryUtil::getCurrentMemoryUsage()));
+ ReadableSize(currentThreadGroupMemoryUsage()));
total_output_rows += output_chunk.getNumRows();
total_output_blocks++;
output.push(std::move(output_chunk));
@@ -189,7 +192,7 @@ GraceMergingAggregatedTransform::Status GraceMergingAggregatedTransform::prepare
"Input one new chunk. rows: {}, bytes: {}, current memory usage: {}",
input_chunk.getNumRows(),
ReadableSize(input_chunk.bytes()),
- ReadableSize(MemoryUtil::getCurrentMemoryUsage()));
+ ReadableSize(currentThreadGroupMemoryUsage()));
total_input_rows += input_chunk.getNumRows();
total_input_blocks++;
has_input = true;
@@ -277,7 +280,7 @@ bool GraceMergingAggregatedTransform::extendBuckets()
void GraceMergingAggregatedTransform::rehashDataVariants()
{
- auto before_memoery_usage = MemoryUtil::getCurrentMemoryUsage();
+ auto before_memoery_usage = currentThreadGroupMemoryUsage();
auto converter = currentDataVariantToBlockConverter(false);
checkAndSetupCurrentDataVariants();
@@ -318,7 +321,7 @@ void GraceMergingAggregatedTransform::rehashDataVariants()
current_bucket_index,
getBucketsNum(),
ReadableSize(before_memoery_usage),
- ReadableSize(MemoryUtil::getCurrentMemoryUsage()));
+ ReadableSize(currentThreadGroupMemoryUsage()));
};
DB::Blocks GraceMergingAggregatedTransform::scatterBlock(const DB::Block & block)
@@ -539,7 +542,7 @@ void GraceMergingAggregatedTransform::mergeOneBlock(const DB::Block &block, bool
block.info.bucket_num,
current_bucket_index,
getBucketsNum(),
- ReadableSize(MemoryUtil::getCurrentMemoryUsage()));
+ ReadableSize(currentThreadGroupMemoryUsage()));
/// the block could be one read from disk. block.info.bucket_num stores the number of buckets when it was scattered.
/// so if the buckets number is not changed since it was scattered, we don't need to scatter it again.
@@ -590,11 +593,13 @@ bool GraceMergingAggregatedTransform::isMemoryOverflow()
/// More greedy memory usage strategy.
if (!current_data_variants)
return false;
- if (!context->getSettingsRef().max_memory_usage)
+
+ auto memory_soft_limit = DB::CurrentThread::getGroup()->memory_tracker.getSoftLimit();
+ if (!memory_soft_limit)
return false;
- auto max_mem_used = static_cast(context->getSettingsRef().max_memory_usage * max_allowed_memory_usage_ratio);
+ auto max_mem_used = static_cast(memory_soft_limit * max_allowed_memory_usage_ratio);
auto current_result_rows = current_data_variants->size();
- auto current_mem_used = MemoryUtil::getCurrentMemoryUsage();
+ auto current_mem_used = currentThreadGroupMemoryUsage();
if (per_key_memory_usage > 0)
{
if (current_mem_used + per_key_memory_usage * current_result_rows >= max_mem_used)
diff --git a/cpp-ch/local-engine/Operator/StreamingAggregatingStep.cpp b/cpp-ch/local-engine/Operator/StreamingAggregatingStep.cpp
index 65d77f8e968f0..2235f4cbe45f5 100644
--- a/cpp-ch/local-engine/Operator/StreamingAggregatingStep.cpp
+++ b/cpp-ch/local-engine/Operator/StreamingAggregatingStep.cpp
@@ -19,8 +19,9 @@
#include
#include
#include
-#include
#include
+#include
+#include
#include
namespace DB
@@ -41,10 +42,11 @@ StreamingAggregatingTransform::StreamingAggregatingTransform(DB::ContextPtr cont
, aggregate_columns(params_->params.aggregates_size)
, params(params_)
{
- aggregated_keys_before_evict = context->getConfigRef().getUInt64("aggregated_keys_before_streaming_aggregating_evict", 1024);
+ auto config = StreamingAggregateConfig::loadFromContext(context);
+ aggregated_keys_before_evict = config.aggregated_keys_before_streaming_aggregating_evict;
aggregated_keys_before_evict = PODArrayUtil::adjustMemoryEfficientSize(aggregated_keys_before_evict);
- max_allowed_memory_usage_ratio = context->getConfigRef().getDouble("max_memory_usage_ratio_for_streaming_aggregating", 0.9);
- high_cardinality_threshold = context->getConfigRef().getDouble("high_cardinality_threshold_for_streaming_aggregating", 0.8);
+ max_allowed_memory_usage_ratio = config.max_memory_usage_ratio_for_streaming_aggregating;
+ high_cardinality_threshold = config.high_cardinality_threshold_for_streaming_aggregating;
}
StreamingAggregatingTransform::~StreamingAggregatingTransform()
@@ -60,7 +62,7 @@ StreamingAggregatingTransform::~StreamingAggregatingTransform()
total_clear_data_variants_num,
total_aggregate_time,
total_convert_data_variants_time,
- ReadableSize(MemoryUtil::getCurrentMemoryUsage()));
+ ReadableSize(currentThreadGroupMemoryUsage()));
}
StreamingAggregatingTransform::Status StreamingAggregatingTransform::prepare()
@@ -82,7 +84,7 @@ StreamingAggregatingTransform::Status StreamingAggregatingTransform::prepare()
"Output one chunk. rows: {}, bytes: {}, current memory usage: {}",
output_chunk.getNumRows(),
ReadableSize(output_chunk.bytes()),
- ReadableSize(MemoryUtil::getCurrentMemoryUsage()));
+ ReadableSize(currentThreadGroupMemoryUsage()));
total_output_rows += output_chunk.getNumRows();
total_output_blocks++;
if (!output_chunk.getNumRows())
@@ -125,7 +127,7 @@ StreamingAggregatingTransform::Status StreamingAggregatingTransform::prepare()
"Input one new chunk. rows: {}, bytes: {}, current memory usage: {}",
input_chunk.getNumRows(),
ReadableSize(input_chunk.bytes()),
- ReadableSize(MemoryUtil::getCurrentMemoryUsage()));
+ ReadableSize(currentThreadGroupMemoryUsage()));
total_input_rows += input_chunk.getNumRows();
total_input_blocks++;
has_input = true;
@@ -136,10 +138,10 @@ bool StreamingAggregatingTransform::needEvict()
{
if (input_finished)
return true;
- if (!context->getSettingsRef().max_memory_usage)
+ auto memory_soft_limit = DB::CurrentThread::getGroup()->memory_tracker.getSoftLimit();
+ if (!memory_soft_limit)
return false;
-
- auto max_mem_used = static_cast(context->getSettingsRef().max_memory_usage * max_allowed_memory_usage_ratio);
+ auto max_mem_used = static_cast(memory_soft_limit * max_allowed_memory_usage_ratio);
auto current_result_rows = data_variants->size();
/// avoid evict empty or too small aggregated results.
if (current_result_rows < aggregated_keys_before_evict)
@@ -150,7 +152,7 @@ bool StreamingAggregatingTransform::needEvict()
if (static_cast(total_output_rows)/total_input_rows > high_cardinality_threshold)
return true;
- auto current_mem_used = MemoryUtil::getCurrentMemoryUsage();
+ auto current_mem_used = currentThreadGroupMemoryUsage();
if (per_key_memory_usage > 0)
{
/// When we know each key memory usage, we can take a more greedy memory usage strategy
diff --git a/cpp-ch/local-engine/Parser/AggregateRelParser.cpp b/cpp-ch/local-engine/Parser/AggregateRelParser.cpp
index 0857995571d45..532b4114b8f07 100644
--- a/cpp-ch/local-engine/Parser/AggregateRelParser.cpp
+++ b/cpp-ch/local-engine/Parser/AggregateRelParser.cpp
@@ -29,6 +29,7 @@
#include
#include
#include
+#include
namespace DB
{
@@ -287,8 +288,8 @@ void AggregateRelParser::addMergingAggregatedStep()
settings.max_threads,
PODArrayUtil::adjustMemoryEfficientSize(settings.max_block_size),
settings.min_hit_rate_to_use_consecutive_keys_optimization);
- bool enable_streaming_aggregating = getContext()->getConfigRef().getBool("enable_streaming_aggregating", true);
- if (enable_streaming_aggregating)
+ auto config = StreamingAggregateConfig::loadFromContext(getContext());
+ if (config.enable_streaming_aggregating)
{
params.group_by_two_level_threshold = settings.group_by_two_level_threshold;
auto merging_step = std::make_unique(getContext(), plan->getCurrentDataStream(), params, false);
@@ -319,8 +320,8 @@ void AggregateRelParser::addCompleteModeAggregatedStep()
AggregateDescriptions aggregate_descriptions;
buildAggregateDescriptions(aggregate_descriptions);
const auto & settings = getContext()->getSettingsRef();
- bool enable_streaming_aggregating = getContext()->getConfigRef().getBool("enable_streaming_aggregating", true);
- if (enable_streaming_aggregating)
+ auto config = StreamingAggregateConfig::loadFromContext(getContext());
+ if (config.enable_streaming_aggregating)
{
Aggregator::Params params(
grouping_keys,
@@ -397,9 +398,9 @@ void AggregateRelParser::addAggregatingStep()
AggregateDescriptions aggregate_descriptions;
buildAggregateDescriptions(aggregate_descriptions);
const auto & settings = getContext()->getSettingsRef();
- bool enable_streaming_aggregating = getContext()->getConfigRef().getBool("enable_streaming_aggregating", true);
- if (enable_streaming_aggregating)
+ auto config = StreamingAggregateConfig::loadFromContext(getContext());
+ if (config.enable_streaming_aggregating)
{
// Disable spilling to disk.
// If group_by_two_level_threshold_bytes != 0, `Aggregator` will use memory usage as a condition to convert
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
index 151e24c6da456..5aaf006a362e8 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -82,6 +82,7 @@
#include
#include
#include
+#include
#include
#include
#include
@@ -1358,8 +1359,8 @@ std::unique_ptr SerializedPlanParser::createExecutor(DB::QueryPla
logger, "clickhouse plan [optimization={}]:\n{}", settings.query_plan_enable_optimizations, PlanUtil::explainPlan(*query_plan));
LOG_DEBUG(logger, "clickhouse pipeline:\n{}", QueryPipelineUtil::explainPipeline(pipeline));
- bool dump_pipeline = context->getConfigRef().getBool("dump_pipeline", false);
- return std::make_unique(std::move(query_plan), std::move(pipeline), dump_pipeline);
+ auto config = ExecutorConfig::loadFromContext(context);
+ return std::make_unique(std::move(query_plan), std::move(pipeline), config.dump_pipeline);
}
SerializedPlanParser::SerializedPlanParser(const ContextPtr & context_) : context(context_)
diff --git a/cpp-ch/local-engine/Parser/SortRelParser.cpp b/cpp-ch/local-engine/Parser/SortRelParser.cpp
index ea29e72d1324c..8fb97d6da5dd3 100644
--- a/cpp-ch/local-engine/Parser/SortRelParser.cpp
+++ b/cpp-ch/local-engine/Parser/SortRelParser.cpp
@@ -15,10 +15,12 @@
* limitations under the License.
*/
#include "SortRelParser.h"
+
+#include
#include
#include
-#include
#include
+#include
namespace DB
{
@@ -41,11 +43,11 @@ SortRelParser::parse(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, st
const auto & sort_rel = rel.sort();
auto sort_descr = parseSortDescription(sort_rel.sorts(), query_plan->getCurrentDataStream().header);
SortingStep::Settings settings(*getContext());
- size_t offheap_per_task = getContext()->getConfigRef().getUInt64("off_heap_per_task");
- double spill_mem_ratio = getContext()->getConfigRef().getDouble("spill_mem_ratio", 0.9);
- settings.worth_external_sort = [offheap_per_task, spill_mem_ratio]() -> bool
+ auto config = MemoryConfig::loadFromContext(getContext());
+ double spill_mem_ratio = config.spill_mem_ratio;
+ settings.worth_external_sort = [spill_mem_ratio]() -> bool
{
- return CurrentMemoryTracker::current_memory() > offheap_per_task * spill_mem_ratio;
+ return currentThreadGroupMemoryUsageRatio() > spill_mem_ratio;
};
auto sorting_step = std::make_unique(
query_plan->getCurrentDataStream(), sort_descr, limit, settings, false);
diff --git a/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp b/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp
index fd6f6fd81b5d1..1ab95abcca48d 100644
--- a/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp
+++ b/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp
@@ -133,26 +133,20 @@ void CachedShuffleWriter::lazyInitPartitionWriter(Block & input_sample)
if (partition_writer)
return;
-// auto avg_row_size = input_sample.allocatedBytes() / input_sample.rows();
-// auto overhead_memory = std::max(avg_row_size, input_sample.columns() * 16) * options.split_size * options.partition_num;
-// auto use_sort_shuffle = overhead_memory > options.spill_threshold * 0.5 || options.partition_num >= 300;
- auto use_external_sort_shuffle = options.force_external_sort;
- auto use_memory_sort_shuffle = options.force_mermory_sort;
- sort_shuffle = use_memory_sort_shuffle || use_external_sort_shuffle;
+ auto avg_row_size = input_sample.allocatedBytes() / input_sample.rows();
+ auto overhead_memory = std::max(avg_row_size, input_sample.columns() * 16) * options.split_size * options.partition_num;
+ auto use_sort_shuffle = overhead_memory > options.spill_threshold * 0.5 || options.partition_num >= 300;
+ sort_shuffle = use_sort_shuffle || options.force_memory_sort;
if (celeborn_client)
{
- if (use_external_sort_shuffle)
- partition_writer = std::make_unique(this, std::move(celeborn_client));
- else if (use_memory_sort_shuffle)
+ if (sort_shuffle)
partition_writer = std::make_unique(this, std::move(celeborn_client));
else
partition_writer = std::make_unique(this, std::move(celeborn_client));
}
else
{
- if (use_external_sort_shuffle)
- partition_writer = std::make_unique(this);
- else if (use_memory_sort_shuffle)
+ if (sort_shuffle)
partition_writer = std::make_unique(this);
else
partition_writer = std::make_unique(this);
@@ -169,9 +163,4 @@ SplitResult CachedShuffleWriter::stop()
return split_result;
}
-size_t CachedShuffleWriter::evictPartitions()
-{
- if (!partition_writer) return 0;
- return partition_writer->evictPartitions(true, options.flush_block_buffer_before_evict);
-}
}
\ No newline at end of file
diff --git a/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.h b/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.h
index e6395c8e47128..6de22f35d9bff 100644
--- a/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.h
+++ b/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.h
@@ -17,7 +17,7 @@
#pragma once
#include
#include
-#include
+#include
#include
#include
#include
@@ -46,7 +46,6 @@ class CachedShuffleWriter : public ShuffleWriterBase
~CachedShuffleWriter() override = default;
void split(DB::Block & block) override;
- size_t evictPartitions() override;
SplitResult stop() override;
private:
diff --git a/cpp-ch/local-engine/Shuffle/NativeSplitter.h b/cpp-ch/local-engine/Shuffle/NativeSplitter.h
index 201beb98a6ce9..71d63b61da78e 100644
--- a/cpp-ch/local-engine/Shuffle/NativeSplitter.h
+++ b/cpp-ch/local-engine/Shuffle/NativeSplitter.h
@@ -26,7 +26,7 @@
#include
#include
#include
-#include
+#include
#include
#include
diff --git a/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp b/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp
index d02c79e0a5d6d..a2ef0888aeff5 100644
--- a/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp
+++ b/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp
@@ -18,21 +18,19 @@
#include
#include
#include
-#include
#include
#include
#include
#include
#include
#include
-#include
#include
#include
#include
#include
-#include
#include
#include
+#include
#include
#include
@@ -51,15 +49,32 @@ namespace local_engine
{
static const String PARTITION_COLUMN_NAME = "partition";
-void PartitionWriter::write(const PartitionInfo & partition_info, DB::Block & block)
+int64_t searchLastPartitionIdIndex(ColumnPtr column, size_t start, size_t partition_id)
{
- /// PartitionWriter::write is alwasy the top frame who occupies evicting_or_writing
- if (evicting_or_writing)
- throw Exception(ErrorCodes::LOGICAL_ERROR, "PartitionWriter::write is invoked with evicting_or_writing being occupied");
+ const auto & int64_column = checkAndGetColumn(*column);
+ int64_t low = start, high = int64_column.size() - 1;
+ while (low <= high)
+ {
+ int64_t mid = low + (high - low) / 2;
+ if (int64_column.get64(mid) > partition_id)
+ high = mid - 1;
+ else
+ low = mid + 1;
+ if (int64_column.get64(high) == partition_id)
+ return high;
+ }
+ return -1;
+}
- evicting_or_writing = true;
- SCOPE_EXIT({ evicting_or_writing = false; });
+bool PartitionWriter::worthToSpill(size_t cache_size) const
+{
+ return (options->spill_threshold > 0 && cache_size >= options->spill_threshold) ||
+ currentThreadGroupMemoryUsageRatio() > settings.spill_mem_ratio;
+}
+void PartitionWriter::write(const PartitionInfo & partition_info, DB::Block & block)
+{
+ /// PartitionWriter::write is alwasy the top frame who occupies evicting_or_writing
Stopwatch watch;
size_t current_cached_bytes = bytes();
for (size_t partition_id = 0; partition_id < partition_info.partition_num; ++partition_id)
@@ -79,60 +94,48 @@ void PartitionWriter::write(const PartitionInfo & partition_info, DB::Block & bl
current_cached_bytes += block_buffer->bytes();
/// Only works for celeborn partitiion writer
- if (supportsEvictSinglePartition() && options->spill_threshold > 0 && current_cached_bytes >= options->spill_threshold)
+ if (supportsEvictSinglePartition() && worthToSpill(current_cached_bytes))
{
- /// If flush_block_buffer_before_evict is disabled, evict partitions from (last_partition_id+1)%partition_num to partition_id directly without flush,
- /// Otherwise flush partition block buffer if it's size is no less than average rows, then evict partitions as above.
- if (!options->flush_block_buffer_before_evict)
+ /// Calculate average rows of each partition block buffer
+ size_t avg_size = 0;
+ size_t cnt = 0;
+ for (size_t i = (last_partition_id + 1) % options->partition_num; i != (partition_id + 1) % options->partition_num;
+ i = (i + 1) % options->partition_num)
{
- for (size_t i = (last_partition_id + 1) % options->partition_num; i != (partition_id + 1) % options->partition_num;
- i = (i + 1) % options->partition_num)
- unsafeEvictSinglePartition(false, false, i);
+ avg_size += partition_block_buffer[i]->size();
+ ++cnt;
}
- else
- {
- /// Calculate average rows of each partition block buffer
- size_t avg_size = 0;
- size_t cnt = 0;
- for (size_t i = (last_partition_id + 1) % options->partition_num; i != (partition_id + 1) % options->partition_num;
- i = (i + 1) % options->partition_num)
- {
- avg_size += partition_block_buffer[i]->size();
- ++cnt;
- }
- avg_size /= cnt;
+ avg_size /= cnt;
- for (size_t i = (last_partition_id + 1) % options->partition_num; i != (partition_id + 1) % options->partition_num;
- i = (i + 1) % options->partition_num)
- {
- bool flush_block_buffer = partition_block_buffer[i]->size() >= avg_size;
- current_cached_bytes -= flush_block_buffer ? partition_block_buffer[i]->bytes() + partition_buffer[i]->bytes()
- : partition_buffer[i]->bytes();
- unsafeEvictSinglePartition(false, flush_block_buffer, i);
- }
- // std::cout << "current cached bytes after evict partitions is " << current_cached_bytes << " partition from "
- // << (last_partition_id + 1) % options->partition_num << " to " << partition_id << " average size:" << avg_size
- // << std::endl;
+ for (size_t i = (last_partition_id + 1) % options->partition_num; i != (partition_id + 1) % options->partition_num;
+ i = (i + 1) % options->partition_num)
+ {
+ bool flush_block_buffer = partition_block_buffer[i]->size() >= avg_size;
+ current_cached_bytes -= flush_block_buffer ? partition_block_buffer[i]->bytes() + partition_buffer[i]->bytes()
+ : partition_buffer[i]->bytes();
+ evictSinglePartition(i);
}
-
+ // std::cout << "current cached bytes after evict partitions is " << current_cached_bytes << " partition from "
+ // << (last_partition_id + 1) % options->partition_num << " to " << partition_id << " average size:" << avg_size
+ // << std::endl;
last_partition_id = partition_id;
}
}
/// Only works for local partition writer
- if (!supportsEvictSinglePartition() && options->spill_threshold && CurrentMemoryTracker::current_memory() >= options->spill_threshold)
- unsafeEvictPartitions(false, options->flush_block_buffer_before_evict);
+ if (!supportsEvictSinglePartition() && worthToSpill(current_cached_bytes))
+ evictPartitions();
shuffle_writer->split_result.total_split_time += watch.elapsedNanoseconds();
}
-size_t LocalPartitionWriter::unsafeEvictPartitions(bool for_memory_spill, bool flush_block_buffer)
+size_t LocalPartitionWriter::evictPartitions()
{
size_t res = 0;
size_t spilled_bytes = 0;
- auto spill_to_file = [this, for_memory_spill, flush_block_buffer, &res, &spilled_bytes]()
+ auto spill_to_file = [this, &res, &spilled_bytes]()
{
auto file = getNextSpillFile();
WriteBufferFromFile output(file, shuffle_writer->options.io_buffer_size);
@@ -148,12 +151,9 @@ size_t LocalPartitionWriter::unsafeEvictPartitions(bool for_memory_spill, bool f
{
auto & buffer = partition_buffer[partition_id];
- if (flush_block_buffer)
- {
- auto & block_buffer = partition_block_buffer[partition_id];
- if (!block_buffer->empty())
- buffer->addBlock(block_buffer->releaseColumns());
- }
+ auto & block_buffer = partition_block_buffer[partition_id];
+ if (!block_buffer->empty())
+ buffer->addBlock(block_buffer->releaseColumns());
if (buffer->empty())
continue;
@@ -177,24 +177,16 @@ size_t LocalPartitionWriter::unsafeEvictPartitions(bool for_memory_spill, bool f
};
Stopwatch spill_time_watch;
- if (for_memory_spill && options->throw_if_memory_exceed)
- {
- // escape memory track from current thread status; add untracked memory limit for create thread object, avoid trigger memory spill again
- IgnoreMemoryTracker ignore(settings.spill_memory_overhead);
- spill_to_file();
- }
- else
- {
- spill_to_file();
- }
+ spill_to_file();
shuffle_writer->split_result.total_spill_time += spill_time_watch.elapsedNanoseconds();
shuffle_writer->split_result.total_bytes_spilled += spilled_bytes;
+ LOG_INFO(logger, "spill shuffle data {} bytes, use spill time {} ms", spilled_bytes, spill_time_watch.elapsedMilliseconds());
return res;
}
String Spillable::getNextSpillFile()
{
- auto file_name = std::to_string(split_options.shuffle_id) + "_" + std::to_string(split_options.map_id) + "_" + std::to_string(spill_infos.size());
+ auto file_name = std::to_string(static_cast(split_options.shuffle_id)) + "_" + std::to_string(static_cast(split_options.map_id)) + "_" + std::to_string(spill_infos.size());
std::hash hasher;
auto hash = hasher(file_name);
auto dir_id = hash % split_options.local_dirs_list.size();
@@ -304,32 +296,28 @@ void SortBasedPartitionWriter::write(const PartitionInfo & info, DB::Block & blo
current_accumulated_bytes += accumulated_blocks.back().allocatedBytes();
current_accumulated_rows += accumulated_blocks.back().getNumRows();
shuffle_writer->split_result.total_write_time += write_time_watch.elapsedNanoseconds();
- if (options->spill_threshold && CurrentMemoryTracker::current_memory() >= options->spill_threshold)
- unsafeEvictPartitions(false, false);
+ if (worthToSpill(current_accumulated_bytes))
+ evictPartitions();
}
-LocalPartitionWriter::LocalPartitionWriter(CachedShuffleWriter * shuffle_writer_) : PartitionWriter(shuffle_writer_), Spillable(shuffle_writer_->options)
+LocalPartitionWriter::LocalPartitionWriter(CachedShuffleWriter * shuffle_writer_) : PartitionWriter(shuffle_writer_, getLogger("LocalPartitionWriter")), Spillable(shuffle_writer_->options)
{
}
-void LocalPartitionWriter::unsafeStop()
+void LocalPartitionWriter::stop()
{
WriteBufferFromFile output(options->data_file, options->io_buffer_size);
auto offsets = mergeSpills(shuffle_writer, output, {partition_block_buffer, partition_buffer});
shuffle_writer->split_result.partition_lengths = offsets;
}
-void PartitionWriterSettings::loadFromContext(DB::ContextPtr context)
-{
- spill_memory_overhead = context->getConfigRef().getUInt64("spill_memory_overhead", 50 << 20);
-}
-
-PartitionWriter::PartitionWriter(CachedShuffleWriter * shuffle_writer_)
+PartitionWriter::PartitionWriter(CachedShuffleWriter * shuffle_writer_, LoggerPtr logger_)
: shuffle_writer(shuffle_writer_)
, options(&shuffle_writer->options)
, partition_block_buffer(options->partition_num)
, partition_buffer(options->partition_num)
, last_partition_id(options->partition_num - 1)
+ , logger(logger_)
{
for (size_t partition_id = 0; partition_id < options->partition_num; ++partition_id)
{
@@ -339,26 +327,6 @@ PartitionWriter::PartitionWriter(CachedShuffleWriter * shuffle_writer_)
settings.loadFromContext(SerializedPlanParser::global_context);
}
-size_t PartitionWriter::evictPartitions(bool for_memory_spill, bool flush_block_buffer)
-{
- if (evicting_or_writing)
- return 0;
-
- evicting_or_writing = true;
- SCOPE_EXIT({ evicting_or_writing = false; });
- return unsafeEvictPartitions(for_memory_spill, flush_block_buffer);
-}
-
-void PartitionWriter::stop()
-{
- if (evicting_or_writing)
- throw Exception(ErrorCodes::LOGICAL_ERROR, "PartitionWriter::stop is invoked with evicting_or_writing being occupied");
-
- evicting_or_writing = true;
- SCOPE_EXIT({ evicting_or_writing = false; });
- return unsafeStop();
-}
-
size_t PartitionWriter::bytes() const
{
size_t bytes = 0;
@@ -372,7 +340,8 @@ size_t PartitionWriter::bytes() const
return bytes;
}
-size_t MemorySortLocalPartitionWriter::unsafeEvictPartitions(bool for_memory_spill, bool /*flush_block_buffer*/)
+
+size_t MemorySortLocalPartitionWriter::evictPartitions()
{
size_t res = 0;
size_t spilled_bytes = 0;
@@ -456,34 +425,26 @@ size_t MemorySortLocalPartitionWriter::unsafeEvictPartitions(bool for_memory_spi
};
Stopwatch spill_time_watch;
- if (for_memory_spill && options->throw_if_memory_exceed)
- {
- // escape memory track from current thread status; add untracked memory limit for create thread object, avoid trigger memory spill again
- IgnoreMemoryTracker ignore(settings.spill_memory_overhead);
- spill_to_file();
- }
- else
- {
- spill_to_file();
- }
+ spill_to_file();
shuffle_writer->split_result.total_spill_time += spill_time_watch.elapsedNanoseconds();
shuffle_writer->split_result.total_bytes_spilled += spilled_bytes;
+ LOG_INFO(logger, "spill shuffle data {} bytes, use spill time {} ms", spilled_bytes, spill_time_watch.elapsedMilliseconds());
return res;
}
-void MemorySortLocalPartitionWriter::unsafeStop()
+void MemorySortLocalPartitionWriter::stop()
{
- unsafeEvictPartitions(false, false);
+ evictPartitions();
WriteBufferFromFile output(options->data_file, options->io_buffer_size);
auto offsets = mergeSpills(shuffle_writer, output);
shuffle_writer->split_result.partition_lengths = offsets;
}
-size_t MemorySortCelebornPartitionWriter::unsafeEvictPartitions(bool for_memory_spill, bool flush_block_buffer)
+size_t MemorySortCelebornPartitionWriter::evictPartitions()
{
size_t res = 0;
size_t spilled_bytes = 0;
- auto spill_to_celeborn = [this, for_memory_spill, flush_block_buffer, &res, &spilled_bytes]()
+ auto spill_to_celeborn = [this, &res, &spilled_bytes]()
{
Stopwatch serialization_time_watch;
@@ -553,202 +514,49 @@ size_t MemorySortCelebornPartitionWriter::unsafeEvictPartitions(bool for_memory_
shuffle_writer->split_result.total_compress_time += compressed_output.getCompressTime();
shuffle_writer->split_result.total_io_time += compressed_output.getWriteTime();
-
shuffle_writer->split_result.total_serialize_time += serialization_time_watch.elapsedNanoseconds();
};
Stopwatch spill_time_watch;
- if (for_memory_spill && options->throw_if_memory_exceed)
- {
- // escape memory track from current thread status; add untracked memory limit for create thread object, avoid trigger memory spill again
- IgnoreMemoryTracker ignore(settings.spill_memory_overhead);
- spill_to_celeborn();
- }
- else
- {
- spill_to_celeborn();
- }
-
+ spill_to_celeborn();
shuffle_writer->split_result.total_spill_time += spill_time_watch.elapsedNanoseconds();
shuffle_writer->split_result.total_bytes_spilled += spilled_bytes;
+ LOG_INFO(logger, "spill shuffle data {} bytes, use spill time {} ms", spilled_bytes, spill_time_watch.elapsedMilliseconds());
return res;
}
-void MemorySortCelebornPartitionWriter::unsafeStop()
-{
- unsafeEvictPartitions(false, false);
-}
-
-size_t ExternalSortLocalPartitionWriter::unsafeEvictPartitions(bool, bool)
+void MemorySortCelebornPartitionWriter::stop()
{
- // escape memory track
- IgnoreMemoryTracker ignore(settings.spill_memory_overhead);
- if (accumulated_blocks.empty())
- return 0;
- if (max_merge_block_bytes)
- {
- max_merge_block_size = std::max(max_merge_block_bytes / (current_accumulated_bytes / current_accumulated_rows), 128UL);
- }
- Stopwatch watch;
- MergeSorter sorter(sort_header, std::move(accumulated_blocks), sort_description, max_merge_block_size, 0);
- streams.emplace_back(&tmp_data->createStream(sort_header));
- while (auto data = sorter.read())
- {
- Block serialized_block = sort_header.cloneWithColumns(data.detachColumns());
- streams.back()->write(serialized_block);
- }
- streams.back()->finishWriting();
- auto result = current_accumulated_bytes;
- current_accumulated_bytes = 0;
- current_accumulated_rows = 0;
- shuffle_writer->split_result.total_spill_time += watch.elapsedNanoseconds();
- return result;
+ evictPartitions();
}
-std::queue ExternalSortLocalPartitionWriter::mergeDataInMemory()
-{
- if (accumulated_blocks.empty())
- return {};
- std::queue result;
- MergeSorter sorter(sort_header, std::move(accumulated_blocks), sort_description, max_merge_block_size, 0);
- while (auto data = sorter.read())
- {
- Block serialized_block = sort_header.cloneWithColumns(data.detachColumns());
- result.push(serialized_block);
- }
- return result;
-}
-
-ExternalSortLocalPartitionWriter::MergeContext ExternalSortLocalPartitionWriter::prepareMerge()
-{
- MergeContext context;
- if (options->spill_firstly_before_stop)
- unsafeEvictPartitions(false, false);
- auto num_input = accumulated_blocks.empty() ? streams.size() : streams.size() + 1;
- std::unique_ptr algorithm = std::make_unique(
- sort_header, num_input, sort_description, max_merge_block_size, 0, SortingQueueStrategy::Batch);
- context.codec = CompressionCodecFactory::instance().get(boost::to_upper_copy(shuffle_writer->options.compress_method), {});
- auto sorted_memory_data = mergeDataInMemory();
- context.merger = std::make_unique(std::move(algorithm), streams, sorted_memory_data, output_header);
- return context;
-}
-
-void ExternalSortLocalPartitionWriter::unsafeStop()
-{
- // escape memory track
- IgnoreMemoryTracker ignore(settings.spill_memory_overhead);
- Stopwatch write_time_watch;
- // no data to write
- if (streams.empty() && accumulated_blocks.empty())
- return;
- auto context = prepareMerge();
- WriteBufferFromFile output(options->data_file, options->io_buffer_size);
- CompressedWriteBuffer compressed_output(output, context.codec, shuffle_writer->options.io_buffer_size);
- NativeWriter native_writer(compressed_output, output_header);
-
- std::vector partition_length(shuffle_writer->options.partition_num, 0);
- size_t current_file_size = 0;
- size_t current_partition_raw_size = 0;
- size_t current_partition_id = 0;
- auto finish_partition_if_needed = [&]()
- {
- if (!partition_length[current_partition_id])
- {
- compressed_output.sync();
- shuffle_writer->split_result.raw_partition_lengths[current_partition_id] = current_partition_raw_size;
- partition_length[current_partition_id] = output.count() - current_file_size;
- current_file_size = output.count();
- current_partition_id++;
- current_partition_raw_size = 0;
- }
- };
- while (!context.merger->isFinished())
- {
- auto result = context.merger->next();
- if (result.empty)
- break;
- for (auto & item : result.blocks)
- {
- while (item.second - current_partition_id > 1)
- finish_partition_if_needed();
- current_partition_raw_size += native_writer.write(item.first);
- }
- }
- while (shuffle_writer->options.partition_num - current_partition_id > 0)
- finish_partition_if_needed();
- shuffle_writer->split_result.partition_lengths = partition_length;
- shuffle_writer->split_result.total_write_time += write_time_watch.elapsedNanoseconds();
- shuffle_writer->split_result.total_compress_time += compressed_output.getCompressTime();
- shuffle_writer->split_result.total_io_time += compressed_output.getWriteTime();
-}
-
-void ExternalSortCelebornPartitionWriter::unsafeStop()
-{
- // escape memory track
- IgnoreMemoryTracker ignore(settings.spill_memory_overhead);
- Stopwatch write_time_watch;
- // no data to write
- if (streams.empty() && accumulated_blocks.empty())
- return;
- auto context = prepareMerge();
-
- WriteBufferFromOwnString output;
- CompressedWriteBuffer compressed_output(output, context.codec, shuffle_writer->options.io_buffer_size);
- NativeWriter native_writer(compressed_output, output_header);
- std::vector partition_length(shuffle_writer->options.partition_num, 0);
-
- while (!context.merger->isFinished())
- {
- auto result = context.merger->next();
- if (result.empty)
- break;
- for (auto & item : result.blocks)
- {
- shuffle_writer->split_result.raw_partition_lengths[item.second] += native_writer.write(item.first);
- compressed_output.sync();
- partition_length[item.second] += output.count();
- Stopwatch push_time;
- celeborn_client->pushPartitionData(item.second, output.str().data(), output.str().size());
- shuffle_writer->split_result.total_io_time += push_time.elapsedNanoseconds();
- output.restart();
- }
- }
-
- shuffle_writer->split_result.partition_lengths = partition_length;
- shuffle_writer->split_result.total_write_time += write_time_watch.elapsedNanoseconds();
- shuffle_writer->split_result.total_compress_time += compressed_output.getCompressTime();
- shuffle_writer->split_result.total_io_time += compressed_output.getWriteTime();
-}
CelebornPartitionWriter::CelebornPartitionWriter(CachedShuffleWriter * shuffleWriter, std::unique_ptr celeborn_client_)
- : PartitionWriter(shuffleWriter), celeborn_client(std::move(celeborn_client_))
+ : PartitionWriter(shuffleWriter, getLogger("CelebornPartitionWriter")), celeborn_client(std::move(celeborn_client_))
{
}
-size_t CelebornPartitionWriter::unsafeEvictPartitions(bool for_memory_spill, bool flush_block_buffer)
+size_t CelebornPartitionWriter::evictPartitions()
{
size_t res = 0;
for (size_t partition_id = 0; partition_id < options->partition_num; ++partition_id)
- res += unsafeEvictSinglePartition(for_memory_spill, flush_block_buffer, partition_id);
+ res += evictSinglePartition(partition_id);
return res;
}
-size_t CelebornPartitionWriter::unsafeEvictSinglePartition(bool for_memory_spill, bool flush_block_buffer, size_t partition_id)
+size_t CelebornPartitionWriter::evictSinglePartition(size_t partition_id)
{
size_t res = 0;
size_t spilled_bytes = 0;
- auto spill_to_celeborn = [this, for_memory_spill, flush_block_buffer, partition_id, &res, &spilled_bytes]()
+ auto spill_to_celeborn = [this,partition_id, &res, &spilled_bytes]()
{
Stopwatch serialization_time_watch;
auto & buffer = partition_buffer[partition_id];
- if (flush_block_buffer)
+ auto & block_buffer = partition_block_buffer[partition_id];
+ if (!block_buffer->empty())
{
- auto & block_buffer = partition_block_buffer[partition_id];
- if (!block_buffer->empty())
- {
- // std::cout << "flush block buffer for partition:" << partition_id << " rows:" << block_buffer->size() << std::endl;
- buffer->addBlock(block_buffer->releaseColumns());
- }
+ // std::cout << "flush block buffer for partition:" << partition_id << " rows:" << block_buffer->size() << std::endl;
+ buffer->addBlock(block_buffer->releaseColumns());
}
/// Skip empty buffer
@@ -781,26 +589,16 @@ size_t CelebornPartitionWriter::unsafeEvictSinglePartition(bool for_memory_spill
};
Stopwatch spill_time_watch;
- if (for_memory_spill && options->throw_if_memory_exceed)
- {
- // escape memory track from current thread status; add untracked memory limit for create thread object, avoid trigger memory spill again
- IgnoreMemoryTracker ignore(settings.spill_memory_overhead);
- spill_to_celeborn();
- }
- else
- {
- spill_to_celeborn();
- }
-
+ spill_to_celeborn();
shuffle_writer->split_result.total_spill_time += spill_time_watch.elapsedNanoseconds();
shuffle_writer->split_result.total_bytes_spilled += spilled_bytes;
+ LOG_INFO(logger, "spill shuffle data {} bytes, use spill time {} ms", spilled_bytes, spill_time_watch.elapsedMilliseconds());
return res;
}
-void CelebornPartitionWriter::unsafeStop()
+void CelebornPartitionWriter::stop()
{
- unsafeEvictPartitions(false, true);
-
+ evictPartitions();
for (const auto & length : shuffle_writer->split_result.partition_lengths)
shuffle_writer->split_result.total_bytes_written += length;
}
@@ -831,4 +629,6 @@ size_t Partition::spill(NativeWriter & writer)
return written_bytes;
}
+
+
}
diff --git a/cpp-ch/local-engine/Shuffle/PartitionWriter.h b/cpp-ch/local-engine/Shuffle/PartitionWriter.h
index e2c10b0cbd464..0c3c0be50f2d6 100644
--- a/cpp-ch/local-engine/Shuffle/PartitionWriter.h
+++ b/cpp-ch/local-engine/Shuffle/PartitionWriter.h
@@ -18,13 +18,13 @@
#include
#include
#include
+#include
#include
#include
-#include
#include
#include
#include
-#include
+#include
#include
@@ -60,53 +60,43 @@ class Partition
size_t cached_bytes = 0;
};
-struct PartitionWriterSettings
-{
- uint64_t spill_memory_overhead = 0;
-
- void loadFromContext(DB::ContextPtr context);
-};
-
class CachedShuffleWriter;
using PartitionPtr = std::shared_ptr;
class PartitionWriter : boost::noncopyable
{
public:
- explicit PartitionWriter(CachedShuffleWriter * shuffle_writer_);
+ explicit PartitionWriter(CachedShuffleWriter * shuffle_writer_, LoggerPtr logger_);
virtual ~PartitionWriter() = default;
virtual String getName() const = 0;
virtual void write(const PartitionInfo & info, DB::Block & block);
- size_t evictPartitions(bool for_memory_spill = false, bool flush_block_buffer = false);
- void stop();
+ virtual void stop() = 0;
protected:
+ virtual size_t evictPartitions() = 0;
+
size_t bytes() const;
- virtual size_t unsafeEvictPartitions(bool for_memory_spill, bool flush_block_buffer = false) = 0;
+ virtual bool worthToSpill(size_t cache_size) const;
virtual bool supportsEvictSinglePartition() const { return false; }
- virtual size_t unsafeEvictSinglePartition(bool for_memory_spill, bool flush_block_buffer, size_t partition_id)
+ virtual size_t evictSinglePartition(size_t partition_id)
{
throw DB::Exception(DB::ErrorCodes::NOT_IMPLEMENTED, "Evict single partition is not supported for {}", getName());
}
- virtual void unsafeStop() = 0;
-
CachedShuffleWriter * shuffle_writer;
const SplitOptions * options;
- PartitionWriterSettings settings;
+ MemoryConfig settings;
std::vector partition_block_buffer;
std::vector partition_buffer;
- /// Make sure memory spill doesn't happen while write/stop are executed.
- bool evicting_or_writing{false};
-
/// Only valid in celeborn partition writer
size_t last_partition_id;
+ LoggerPtr logger = nullptr;
};
class Spillable
@@ -138,21 +128,21 @@ class LocalPartitionWriter : public PartitionWriter, public Spillable
String getName() const override { return "LocalPartitionWriter"; }
-protected:
- size_t unsafeEvictPartitions(bool for_memory_spill, bool flush_block_buffer) override;
- void unsafeStop() override;
+ size_t evictPartitions() override;
+ void stop() override;
+
};
class SortBasedPartitionWriter : public PartitionWriter
{
-public:
- explicit SortBasedPartitionWriter(CachedShuffleWriter * shuffle_writer_) : PartitionWriter(shuffle_writer_)
+protected:
+ explicit SortBasedPartitionWriter(CachedShuffleWriter * shuffle_writer_, LoggerPtr logger) : PartitionWriter(shuffle_writer_, logger)
{
max_merge_block_size = options->split_size;
max_sort_buffer_size = options->max_sort_buffer_size;
max_merge_block_bytes = SerializedPlanParser::global_context->getSettings().prefer_external_sort_block_bytes;
}
-
+public:
String getName() const override { return "SortBasedPartitionWriter"; }
void write(const PartitionInfo & info, DB::Block & block) override;
size_t adaptiveBlockSize()
@@ -181,80 +171,32 @@ class MemorySortLocalPartitionWriter : public SortBasedPartitionWriter, public S
{
public:
explicit MemorySortLocalPartitionWriter(CachedShuffleWriter* shuffle_writer_)
- : SortBasedPartitionWriter(shuffle_writer_), Spillable(shuffle_writer_->options)
+ : SortBasedPartitionWriter(shuffle_writer_, getLogger("MemorySortLocalPartitionWriter")), Spillable(shuffle_writer_->options)
{
}
~MemorySortLocalPartitionWriter() override = default;
String getName() const override { return "MemorySortLocalPartitionWriter"; }
-protected:
- size_t unsafeEvictPartitions(bool for_memory_spill, bool flush_block_buffer) override;
- void unsafeStop() override;
+ size_t evictPartitions() override;
+ void stop() override;
};
class MemorySortCelebornPartitionWriter : public SortBasedPartitionWriter
{
public:
explicit MemorySortCelebornPartitionWriter(CachedShuffleWriter* shuffle_writer_, std::unique_ptr celeborn_client_)
- : SortBasedPartitionWriter(shuffle_writer_), celeborn_client(std::move(celeborn_client_))
+ : SortBasedPartitionWriter(shuffle_writer_, getLogger("MemorySortCelebornPartitionWriter")), celeborn_client(std::move(celeborn_client_))
{
}
+ String getName() const override { return "MemorySortCelebornPartitionWriter"; }
~MemorySortCelebornPartitionWriter() override = default;
-protected:
- size_t unsafeEvictPartitions(bool for_memory_spill, bool flush_block_buffer) override;
- void unsafeStop() override;
-
-private:
- std::unique_ptr celeborn_client;
-};
-
-class SortedPartitionDataMerger;
-
-class ExternalSortLocalPartitionWriter : public SortBasedPartitionWriter
-{
-public:
- struct MergeContext
- {
- CompressionCodecPtr codec;
- std::unique_ptr merger;
- };
-
- explicit ExternalSortLocalPartitionWriter(CachedShuffleWriter * shuffle_writer_) : SortBasedPartitionWriter(shuffle_writer_)
- {
- max_merge_block_size = options->split_size;
- max_sort_buffer_size = options->max_sort_buffer_size;
- max_merge_block_bytes = SerializedPlanParser::global_context->getSettings().prefer_external_sort_block_bytes;
- tmp_data = std::make_unique(SerializedPlanParser::global_context->getTempDataOnDisk());
- }
-
- ~ExternalSortLocalPartitionWriter() override = default;
-
- String getName() const override { return "ExternalSortLocalPartitionWriter"; }
+ void stop() override;
protected:
- size_t unsafeEvictPartitions(bool for_memory_spill, bool flush_block_buffer) override;
- /// Prepare for data merging, spill the remaining memory data,and create a merger object.
- MergeContext prepareMerge();
- void unsafeStop() override;
- std::queue mergeDataInMemory();
-
- TemporaryDataOnDiskPtr tmp_data;
- std::vector streams;
-};
-
-class ExternalSortCelebornPartitionWriter : public ExternalSortLocalPartitionWriter
-{
-public:
- explicit ExternalSortCelebornPartitionWriter(CachedShuffleWriter * shuffle_writer_, std::unique_ptr celeborn_client_)
- : ExternalSortLocalPartitionWriter(shuffle_writer_), celeborn_client(std::move(celeborn_client_))
- {
- }
-protected:
- void unsafeStop() override;
-
+ size_t evictPartitions() override;
private:
std::unique_ptr celeborn_client;
};
@@ -266,15 +208,12 @@ class CelebornPartitionWriter : public PartitionWriter
~CelebornPartitionWriter() override = default;
String getName() const override { return "CelebornPartitionWriter"; }
-
+ void stop() override;
protected:
- size_t unsafeEvictPartitions(bool for_memory_spill, bool flush_block_buffer) override;
-
+ size_t evictPartitions() override;
bool supportsEvictSinglePartition() const override { return true; }
- size_t unsafeEvictSinglePartition(bool for_memory_spill, bool flush_block_buffer, size_t partition_id) override;
-
- void unsafeStop() override;
-
+ size_t evictSinglePartition(size_t partition_id) override;
+private:
std::unique_ptr celeborn_client;
};
}
diff --git a/cpp-ch/local-engine/Shuffle/ShuffleCommon.cpp b/cpp-ch/local-engine/Shuffle/ShuffleCommon.cpp
new file mode 100644
index 0000000000000..e0d8c0e84eaa4
--- /dev/null
+++ b/cpp-ch/local-engine/Shuffle/ShuffleCommon.cpp
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include
+#include
+#include
+#include
+
+namespace local_engine
+{
+void ColumnsBuffer::add(DB::Block & block, int start, int end)
+{
+ if (!header)
+ header = block.cloneEmpty();
+
+ if (accumulated_columns.empty())
+ {
+ accumulated_columns.reserve(block.columns());
+ for (size_t i = 0; i < block.columns(); i++)
+ {
+ auto column = block.getColumns()[i]->cloneEmpty();
+ column->reserve(prefer_buffer_size);
+ accumulated_columns.emplace_back(std::move(column));
+ }
+ }
+
+ assert(!accumulated_columns.empty());
+ for (size_t i = 0; i < block.columns(); ++i)
+ {
+ if (!accumulated_columns[i]->onlyNull())
+ {
+ accumulated_columns[i]->insertRangeFrom(*block.getByPosition(i).column, start, end - start);
+ }
+ else
+ {
+ accumulated_columns[i]->insertMany(DB::Field(), end - start);
+ }
+ }
+}
+
+void ColumnsBuffer::appendSelective(
+ size_t column_idx,
+ const DB::Block & source,
+ const DB::IColumn::Selector & selector,
+ size_t from,
+ size_t length)
+{
+ if (!header)
+ header = source.cloneEmpty();
+
+ if (accumulated_columns.empty())
+ {
+ accumulated_columns.reserve(source.columns());
+ for (size_t i = 0; i < source.columns(); i++)
+ {
+ auto column = source.getColumns()[i]->convertToFullIfNeeded()->cloneEmpty();
+ column->reserve(prefer_buffer_size);
+ accumulated_columns.emplace_back(std::move(column));
+ }
+ }
+
+ if (!accumulated_columns[column_idx]->onlyNull())
+ {
+ accumulated_columns[column_idx]->insertRangeSelective(
+ *source.getByPosition(column_idx).column->convertToFullIfNeeded(),
+ selector,
+ from,
+ length);
+ }
+ else
+ {
+ accumulated_columns[column_idx]->insertMany(DB::Field(), length);
+ }
+}
+
+size_t ColumnsBuffer::size() const
+{
+ return accumulated_columns.empty() ? 0 : accumulated_columns[0]->size();
+}
+
+bool ColumnsBuffer::empty() const
+{
+ return accumulated_columns.empty() ? true : accumulated_columns[0]->empty();
+}
+
+DB::Block ColumnsBuffer::releaseColumns()
+{
+ DB::Columns columns(std::make_move_iterator(accumulated_columns.begin()), std::make_move_iterator(accumulated_columns.end()));
+ accumulated_columns.clear();
+
+ if (columns.empty())
+ return header.cloneEmpty();
+ else
+ return header.cloneWithColumns(columns);
+}
+
+DB::Block ColumnsBuffer::getHeader()
+{
+ return header;
+}
+
+ColumnsBuffer::ColumnsBuffer(size_t prefer_buffer_size_)
+ : prefer_buffer_size(prefer_buffer_size_)
+{
+}
+}
\ No newline at end of file
diff --git a/cpp-ch/local-engine/Shuffle/ShuffleSplitter.h b/cpp-ch/local-engine/Shuffle/ShuffleCommon.h
similarity index 60%
rename from cpp-ch/local-engine/Shuffle/ShuffleSplitter.h
rename to cpp-ch/local-engine/Shuffle/ShuffleCommon.h
index 75edea325c67b..d398362aa4b64 100644
--- a/cpp-ch/local-engine/Shuffle/ShuffleSplitter.h
+++ b/cpp-ch/local-engine/Shuffle/ShuffleCommon.h
@@ -47,14 +47,8 @@ struct SplitOptions
int compress_level;
size_t spill_threshold = 300 * 1024 * 1024;
std::string hash_algorithm;
- bool throw_if_memory_exceed = true;
- /// Whether to flush partition_block_buffer in PartitionWriter before evict.
- bool flush_block_buffer_before_evict = false;
size_t max_sort_buffer_size = 1_GiB;
- // Whether to spill firstly before stop external sort shuffle.
- bool spill_firstly_before_stop = true;
- bool force_external_sort = false;
- bool force_mermory_sort = false;
+ bool force_memory_sort = false;
};
class ColumnsBuffer
@@ -118,92 +112,6 @@ struct SplitResult
}
};
-class ShuffleSplitter;
-using ShuffleSplitterPtr = std::unique_ptr;
-class ShuffleSplitter : public ShuffleWriterBase
-{
-public:
- inline const static std::vector compress_methods = {"", "ZSTD", "LZ4"};
-
- static ShuffleSplitterPtr create(const std::string & short_name, const SplitOptions & options_);
-
- explicit ShuffleSplitter(const SplitOptions & options);
- virtual ~ShuffleSplitter() override
- {
- if (!stopped)
- stop();
- }
-
- void split(DB::Block & block) override;
- virtual void computeAndCountPartitionId(DB::Block &) { }
- std::vector getPartitionLength() const { return split_result.partition_lengths; }
- void writeIndexFile();
- SplitResult stop() override;
-
-private:
- void init();
- void initOutputIfNeeded(DB::Block & block);
- void splitBlockByPartition(DB::Block & block);
- void spillPartition(size_t partition_id);
- std::string getPartitionTempFile(size_t partition_id);
- void mergePartitionFiles();
- std::unique_ptr getPartitionWriteBuffer(size_t partition_id);
-
-protected:
- bool stopped = false;
- PartitionInfo partition_info;
- std::vector partition_buffer;
- std::vector> partition_outputs;
- std::vector> partition_write_buffers;
- std::vector> partition_cached_write_buffers;
- std::vector compressed_buffers;
- std::vector output_columns_indicies;
- DB::Block output_header;
- SplitOptions options;
- SplitResult split_result;
-};
-
-class RoundRobinSplitter : public ShuffleSplitter
-{
-public:
- static ShuffleSplitterPtr create(const SplitOptions & options);
-
- explicit RoundRobinSplitter(const SplitOptions & options_);
- virtual ~RoundRobinSplitter() override = default;
-
- void computeAndCountPartitionId(DB::Block & block) override;
-
-private:
- std::unique_ptr selector_builder;
-};
-
-class HashSplitter : public ShuffleSplitter
-{
-public:
- static ShuffleSplitterPtr create(const SplitOptions & options);
-
- explicit HashSplitter(SplitOptions options_);
- virtual ~HashSplitter() override = default;
-
- void computeAndCountPartitionId(DB::Block & block) override;
-
-private:
- std::unique_ptr selector_builder;
-};
-
-class RangeSplitter : public ShuffleSplitter
-{
-public:
- static ShuffleSplitterPtr create(const SplitOptions & options);
-
- explicit RangeSplitter(const SplitOptions & options_);
- virtual ~RangeSplitter() override = default;
-
- void computeAndCountPartitionId(DB::Block & block) override;
-
-private:
- std::unique_ptr selector_builder;
-};
struct SplitterHolder
{
std::unique_ptr splitter;
diff --git a/cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp b/cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp
deleted file mode 100644
index 9baf3c4692c81..0000000000000
--- a/cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp
+++ /dev/null
@@ -1,424 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "ShuffleSplitter.h"
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-namespace local_engine
-{
-
-void ShuffleSplitter::split(DB::Block & block)
-{
- if (block.rows() == 0)
- {
- return;
- }
- initOutputIfNeeded(block);
- computeAndCountPartitionId(block);
- Stopwatch split_time_watch;
- block = convertAggregateStateInBlock(block);
- split_result.total_split_time += split_time_watch.elapsedNanoseconds();
- splitBlockByPartition(block);
-}
-
-SplitResult ShuffleSplitter::stop()
-{
- // spill all buffers
- Stopwatch watch;
- for (size_t i = 0; i < options.partition_num; i++)
- {
- spillPartition(i);
- partition_outputs[i]->flush();
- partition_write_buffers[i]->sync();
- }
- for (auto * item : compressed_buffers)
- {
- if (item)
- {
- split_result.total_compress_time += item->getCompressTime();
- split_result.total_io_time += item->getWriteTime();
- }
- }
- split_result.total_serialize_time = split_result.total_spill_time - split_result.total_compress_time - split_result.total_io_time;
- partition_outputs.clear();
- partition_cached_write_buffers.clear();
- partition_write_buffers.clear();
- mergePartitionFiles();
- split_result.total_write_time += watch.elapsedNanoseconds();
- stopped = true;
- return split_result;
-}
-
-void ShuffleSplitter::initOutputIfNeeded(Block & block)
-{
- if (output_header.columns() == 0) [[unlikely]]
- {
- output_header = block.cloneEmpty();
- if (output_columns_indicies.empty())
- {
- output_header = block.cloneEmpty();
- for (size_t i = 0; i < block.columns(); ++i)
- {
- output_columns_indicies.push_back(i);
- }
- }
- else
- {
- ColumnsWithTypeAndName cols;
- for (const auto & index : output_columns_indicies)
- {
- cols.push_back(block.getByPosition(index));
- }
- output_header = DB::Block(cols);
- }
- }
-}
-
-void ShuffleSplitter::splitBlockByPartition(DB::Block & block)
-{
- Stopwatch split_time_watch;
- DB::Block out_block;
- for (size_t col = 0; col < output_header.columns(); ++col)
- {
- out_block.insert(block.getByPosition(output_columns_indicies[col]));
- }
- for (size_t col = 0; col < output_header.columns(); ++col)
- {
- for (size_t j = 0; j < partition_info.partition_num; ++j)
- {
- size_t from = partition_info.partition_start_points[j];
- size_t length = partition_info.partition_start_points[j + 1] - from;
- if (length == 0)
- continue; // no data for this partition continue;
- partition_buffer[j]->appendSelective(col, out_block, partition_info.partition_selector, from, length);
- }
- }
- split_result.total_split_time += split_time_watch.elapsedNanoseconds();
-
- for (size_t i = 0; i < options.partition_num; ++i)
- {
- auto & buffer = partition_buffer[i];
- if (buffer->size() >= options.split_size)
- {
- spillPartition(i);
- }
- }
-}
-
-ShuffleSplitter::ShuffleSplitter(const SplitOptions & options_) : options(options_)
-{
- init();
-}
-
-void ShuffleSplitter::init()
-{
- partition_buffer.resize(options.partition_num);
- partition_outputs.resize(options.partition_num);
- partition_write_buffers.resize(options.partition_num);
- partition_cached_write_buffers.resize(options.partition_num);
- split_result.partition_lengths.resize(options.partition_num);
- split_result.raw_partition_lengths.resize(options.partition_num);
- for (size_t partition_i = 0; partition_i < options.partition_num; ++partition_i)
- {
- partition_buffer[partition_i] = std::make_shared(options.split_size);
- split_result.partition_lengths[partition_i] = 0;
- split_result.raw_partition_lengths[partition_i] = 0;
- }
-}
-
-void ShuffleSplitter::spillPartition(size_t partition_id)
-{
- Stopwatch watch;
- if (!partition_outputs[partition_id])
- {
- partition_write_buffers[partition_id] = getPartitionWriteBuffer(partition_id);
- partition_outputs[partition_id]
- = std::make_unique(*partition_write_buffers[partition_id], output_header);
- }
- DB::Block result = partition_buffer[partition_id]->releaseColumns();
- if (result.rows() > 0)
- {
- partition_outputs[partition_id]->write(result);
- }
- split_result.total_spill_time += watch.elapsedNanoseconds();
- split_result.total_bytes_spilled += result.bytes();
-}
-
-void ShuffleSplitter::mergePartitionFiles()
-{
- Stopwatch merge_io_time;
- DB::WriteBufferFromFile data_write_buffer = DB::WriteBufferFromFile(options.data_file);
- std::string buffer;
- size_t buffer_size = options.io_buffer_size;
- buffer.reserve(buffer_size);
- for (size_t i = 0; i < options.partition_num; ++i)
- {
- auto file = getPartitionTempFile(i);
- DB::ReadBufferFromFile reader = DB::ReadBufferFromFile(file, options.io_buffer_size);
- while (reader.next())
- {
- auto bytes = reader.readBig(buffer.data(), buffer_size);
- data_write_buffer.write(buffer.data(), bytes);
- split_result.partition_lengths[i] += bytes;
- split_result.total_bytes_written += bytes;
- }
- reader.close();
- std::filesystem::remove(file);
- }
- split_result.total_io_time += merge_io_time.elapsedNanoseconds();
- data_write_buffer.close();
-}
-
-
-ShuffleSplitterPtr ShuffleSplitter::create(const std::string & short_name, const SplitOptions & options_)
-{
- if (short_name == "rr")
- return RoundRobinSplitter::create(options_);
- else if (short_name == "hash")
- return HashSplitter::create(options_);
- else if (short_name == "single")
- {
- SplitOptions options = options_;
- options.partition_num = 1;
- return RoundRobinSplitter::create(options);
- }
- else if (short_name == "range")
- return RangeSplitter::create(options_);
- else
- throw std::runtime_error("unsupported splitter " + short_name);
-}
-
-std::string ShuffleSplitter::getPartitionTempFile(size_t partition_id)
-{
- auto file_name = std::to_string(options.shuffle_id) + "_" + std::to_string(options.map_id) + "_" + std::to_string(partition_id);
- std::hash hasher;
- auto hash = hasher(file_name);
- auto dir_id = hash % options.local_dirs_list.size();
- auto sub_dir_id = (hash / options.local_dirs_list.size()) % options.num_sub_dirs;
-
- std::string dir = std::filesystem::path(options.local_dirs_list[dir_id]) / std::format("{:02x}", sub_dir_id);
- if (!std::filesystem::exists(dir))
- std::filesystem::create_directories(dir);
- return std::filesystem::path(dir) / file_name;
-}
-
-std::unique_ptr ShuffleSplitter::getPartitionWriteBuffer(size_t partition_id)
-{
- auto file = getPartitionTempFile(partition_id);
- if (partition_cached_write_buffers[partition_id] == nullptr)
- partition_cached_write_buffers[partition_id]
- = std::make_unique(file, options.io_buffer_size, O_CREAT | O_WRONLY | O_APPEND);
- if (!options.compress_method.empty()
- && std::find(compress_methods.begin(), compress_methods.end(), options.compress_method) != compress_methods.end())
- {
- auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), {});
- auto compressed = std::make_unique(*partition_cached_write_buffers[partition_id], codec);
- compressed_buffers.emplace_back(compressed.get());
- return compressed;
- }
- else
- {
- return std::move(partition_cached_write_buffers[partition_id]);
- }
-}
-
-void ShuffleSplitter::writeIndexFile()
-{
- auto index_file = options.data_file + ".index";
- auto writer = std::make_unique(index_file, options.io_buffer_size, O_CREAT | O_WRONLY | O_TRUNC);
- for (auto len : split_result.partition_lengths)
- {
- DB::writeIntText(len, *writer);
- DB::writeChar('\n', *writer);
- }
-}
-
-void ColumnsBuffer::add(DB::Block & block, int start, int end)
-{
- if (!header)
- header = block.cloneEmpty();
-
- if (accumulated_columns.empty())
- {
- accumulated_columns.reserve(block.columns());
- for (size_t i = 0; i < block.columns(); i++)
- {
- auto column = block.getColumns()[i]->cloneEmpty();
- column->reserve(prefer_buffer_size);
- accumulated_columns.emplace_back(std::move(column));
- }
- }
-
- assert(!accumulated_columns.empty());
- for (size_t i = 0; i < block.columns(); ++i)
- {
- if (!accumulated_columns[i]->onlyNull())
- {
- accumulated_columns[i]->insertRangeFrom(*block.getByPosition(i).column, start, end - start);
- }
- else
- {
- accumulated_columns[i]->insertMany(DB::Field(), end - start);
- }
- }
-}
-
-void ColumnsBuffer::appendSelective(
- size_t column_idx, const DB::Block & source, const DB::IColumn::Selector & selector, size_t from, size_t length)
-{
- if (!header)
- header = source.cloneEmpty();
-
- if (accumulated_columns.empty())
- {
- accumulated_columns.reserve(source.columns());
- for (size_t i = 0; i < source.columns(); i++)
- {
- auto column = source.getColumns()[i]->convertToFullIfNeeded()->cloneEmpty();
- column->reserve(prefer_buffer_size);
- accumulated_columns.emplace_back(std::move(column));
- }
- }
-
- if (!accumulated_columns[column_idx]->onlyNull())
- {
- accumulated_columns[column_idx]->insertRangeSelective(
- *source.getByPosition(column_idx).column->convertToFullIfNeeded(), selector, from, length);
- }
- else
- {
- accumulated_columns[column_idx]->insertMany(DB::Field(), length);
- }
-}
-
-size_t ColumnsBuffer::size() const
-{
- return accumulated_columns.empty() ? 0 : accumulated_columns[0]->size();
-}
-
-bool ColumnsBuffer::empty() const
-{
- return accumulated_columns.empty() ? true : accumulated_columns[0]->empty();
-}
-
-DB::Block ColumnsBuffer::releaseColumns()
-{
- DB::Columns columns(std::make_move_iterator(accumulated_columns.begin()), std::make_move_iterator(accumulated_columns.end()));
- accumulated_columns.clear();
-
- if (columns.empty())
- return header.cloneEmpty();
- else
- return header.cloneWithColumns(columns);
-}
-
-DB::Block ColumnsBuffer::getHeader()
-{
- return header;
-}
-
-ColumnsBuffer::ColumnsBuffer(size_t prefer_buffer_size_) : prefer_buffer_size(prefer_buffer_size_)
-{
-}
-
-RoundRobinSplitter::RoundRobinSplitter(const SplitOptions & options_) : ShuffleSplitter(options_)
-{
- Poco::StringTokenizer output_column_tokenizer(options_.out_exprs, ",");
- for (auto iter = output_column_tokenizer.begin(); iter != output_column_tokenizer.end(); ++iter)
- {
- output_columns_indicies.push_back(std::stoi(*iter));
- }
- selector_builder = std::make_unique(options.partition_num);
-}
-
-void RoundRobinSplitter::computeAndCountPartitionId(DB::Block & block)
-{
- Stopwatch watch;
- partition_info = selector_builder->build(block);
- split_result.total_compute_pid_time += watch.elapsedNanoseconds();
-}
-
-ShuffleSplitterPtr RoundRobinSplitter::create(const SplitOptions & options_)
-{
- return std::make_unique(options_);
-}
-
-HashSplitter::HashSplitter(SplitOptions options_) : ShuffleSplitter(options_)
-{
- Poco::StringTokenizer exprs_list(options_.hash_exprs, ",");
- std::vector hash_fields;
- for (auto iter = exprs_list.begin(); iter != exprs_list.end(); ++iter)
- {
- hash_fields.push_back(std::stoi(*iter));
- }
-
- Poco::StringTokenizer output_column_tokenizer(options_.out_exprs, ",");
- for (auto iter = output_column_tokenizer.begin(); iter != output_column_tokenizer.end(); ++iter)
- {
- output_columns_indicies.push_back(std::stoi(*iter));
- }
-
- selector_builder = std::make_unique(options.partition_num, hash_fields, options_.hash_algorithm);
-}
-
-std::unique_ptr HashSplitter::create(const SplitOptions & options_)
-{
- return std::make_unique(options_);
-}
-
-void HashSplitter::computeAndCountPartitionId(DB::Block & block)
-{
- Stopwatch watch;
- partition_info = selector_builder->build(block);
- split_result.total_compute_pid_time += watch.elapsedNanoseconds();
-}
-
-ShuffleSplitterPtr RangeSplitter::create(const SplitOptions & options_)
-{
- return std::make_unique(options_);
-}
-
-RangeSplitter::RangeSplitter(const SplitOptions & options_) : ShuffleSplitter(options_)
-{
- Poco::StringTokenizer output_column_tokenizer(options_.out_exprs, ",");
- for (auto iter = output_column_tokenizer.begin(); iter != output_column_tokenizer.end(); ++iter)
- {
- output_columns_indicies.push_back(std::stoi(*iter));
- }
- selector_builder = std::make_unique(options.hash_exprs, options.partition_num);
-}
-
-void RangeSplitter::computeAndCountPartitionId(DB::Block & block)
-{
- Stopwatch watch;
- partition_info = selector_builder->build(block);
- split_result.total_compute_pid_time += watch.elapsedNanoseconds();
-}
-}
diff --git a/cpp-ch/local-engine/Shuffle/SortedPartitionDataMerger.cpp b/cpp-ch/local-engine/Shuffle/SortedPartitionDataMerger.cpp
deleted file mode 100644
index 9c4ae6bf4680f..0000000000000
--- a/cpp-ch/local-engine/Shuffle/SortedPartitionDataMerger.cpp
+++ /dev/null
@@ -1,139 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "SortedPartitionDataMerger.h"
-using namespace DB;
-namespace local_engine
-{
-SortedPartitionDataMerger::SortedPartitionDataMerger(
- std::unique_ptr algorithm,
- const std::vector & streams,
- std::queue & extra_blocks_in_memory_,
- const Block & output_header_)
-{
- merging_algorithm = std::move(algorithm);
- IMergingAlgorithm::Inputs initial_inputs;
- bool use_in_memory_data = !extra_blocks_in_memory_.empty();
- for (auto * stream : streams)
- {
- Block data = stream->read();
- IMergingAlgorithm::Input input;
- input.set({data.getColumns(), data.rows()});
- initial_inputs.emplace_back(std::move(input));
- }
- if (use_in_memory_data)
- {
- IMergingAlgorithm::Input input;
- const auto & data = extra_blocks_in_memory_.front();
- input.set({data.getColumns(), data.rows()});
- initial_inputs.emplace_back(std::move(input));
- extra_blocks_in_memory_.pop();
- }
- for (int i = 0; i < streams.size(); ++i)
- sources.emplace_back(std::make_shared(streams[i], i));
- if (use_in_memory_data)
- sources.emplace_back(std::make_shared(extra_blocks_in_memory_, sources.size()));
- output_header = output_header_;
- merging_algorithm->initialize(std::move(initial_inputs));
-}
-
-int64_t searchLastPartitionIdIndex(ColumnPtr column, size_t start, size_t partition_id)
-{
- const auto & int64_column = checkAndGetColumn(*column);
- int64_t low = start, high = int64_column.size() - 1;
- while (low <= high)
- {
- int64_t mid = low + (high - low) / 2;
- if (int64_column.get64(mid) > partition_id)
- high = mid - 1;
- else
- low = mid + 1;
- if (int64_column.get64(high) == partition_id)
- return high;
- }
- return -1;
-}
-
-SortedPartitionDataMerger::Result SortedPartitionDataMerger::next()
-{
- if (finished)
- return Result{.empty = true};
- Chunk chunk;
- while (true)
- {
- auto result = merging_algorithm->merge();
-
- if (result.required_source >= 0)
- {
- auto stream = sources[result.required_source];
- auto block = stream->next();
- if (block)
- {
- IMergingAlgorithm::Input input;
- input.set({block.getColumns(), block.rows()});
- merging_algorithm->consume(input, stream->getPartitionId());
- }
- }
- if (result.chunk.getNumRows() > 0)
- {
- chunk = std::move(result.chunk);
- break;
- }
- if (result.is_finished)
- {
- finished = true;
- if (chunk.getNumRows() == 0)
- return Result{.empty = true};
- break;
- }
- }
- Result partitions;
- size_t row_idx = 0;
- Columns result_columns;
- result_columns.reserve(chunk.getColumns().size() - 1);
- for (size_t i = 0; i < chunk.getColumns().size() - 1; ++i)
- result_columns.push_back(chunk.getColumns()[i]);
- while (row_idx < chunk.getNumRows())
- {
- auto idx = searchLastPartitionIdIndex(chunk.getColumns().back(), row_idx, current_partition_id);
- if (idx >= 0)
- {
- if (row_idx == 0 && idx == chunk.getNumRows() - 1)
- {
- partitions.blocks.emplace_back(output_header.cloneWithColumns(result_columns), current_partition_id);
- break;
- }
- else
- {
- Columns cut_columns;
- cut_columns.reserve(result_columns.size());
- for (auto & result_column : result_columns)
- cut_columns.push_back(result_column->cut(row_idx, idx - row_idx + 1));
- partitions.blocks.emplace_back(output_header.cloneWithColumns(cut_columns), current_partition_id);
- row_idx = idx + 1;
- if (idx != chunk.getNumRows() - 1)
- current_partition_id++;
- }
- }
- else
- {
- current_partition_id++;
- }
- }
- return partitions;
-}
-}
\ No newline at end of file
diff --git a/cpp-ch/local-engine/Shuffle/SortedPartitionDataMerger.h b/cpp-ch/local-engine/Shuffle/SortedPartitionDataMerger.h
deleted file mode 100644
index e38f58647e963..0000000000000
--- a/cpp-ch/local-engine/Shuffle/SortedPartitionDataMerger.h
+++ /dev/null
@@ -1,93 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#pragma once
-#include
-#include
-#include
-
-
-namespace local_engine
-{
-
-int64_t searchLastPartitionIdIndex(DB::ColumnPtr column, size_t start, size_t partition_id);
-class SortedPartitionDataMerger;
-using SortedPartitionDataMergerPtr = std::unique_ptr;
-class SortedPartitionDataMerger
-{
-public:
- struct Result
- {
- bool empty = false;
- std::vector> blocks;
- };
-
- class SortedData
- {
- public:
- SortedData(DB::TemporaryFileStream * stream, size_t partitionId) : stream(stream), partition_id(partitionId) { }
- SortedData(const std::queue & blocksInMemory, size_t partitionId)
- : blocks_in_memory(blocksInMemory), partition_id(partitionId)
- {
- }
- DB::Block next()
- {
- if (stream)
- {
- auto data = stream->read();
- end = !data;
- return data;
- }
- if (!blocks_in_memory.empty())
- {
- auto block = blocks_in_memory.front();
- blocks_in_memory.pop();
- return block;
- }
- return {};
- }
- bool isEnd() const
- {
- if (stream)
- return stream->isEof();
- return blocks_in_memory.empty() || end;
- }
- size_t getPartitionId() const { return partition_id; }
-
- private:
- DB::TemporaryFileStream * stream = nullptr;
- std::queue blocks_in_memory;
- size_t partition_id;
- bool end = false;
- };
-
- SortedPartitionDataMerger(
- std::unique_ptr algorithm,
- const std::vector & streams,
- std::queue & extra_blocks_in_memory,
- const DB::Block & output_header);
- Result next();
- bool isFinished() const { return finished; }
-
-private:
- std::unique_ptr merging_algorithm;
- std::vector> sources;
- DB::Block output_header;
- bool finished = false;
- size_t current_partition_id = 0;
-};
-
-}
\ No newline at end of file
diff --git a/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.cpp b/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.cpp
index 403b845147fa4..3a394a4b4e2eb 100644
--- a/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.cpp
+++ b/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.cpp
@@ -149,7 +149,6 @@ bool SparkMergeTreeWriter::blockToPart(Block & block)
new_parts.emplace_back(writeTempPartAndFinalize(item, metadata_snapshot).part);
part_num++;
- manualFreeMemory(before_write_memory);
/// Reset earlier to free memory
item.block.clear();
item.partition.clear();
@@ -158,36 +157,6 @@ bool SparkMergeTreeWriter::blockToPart(Block & block)
return true;
}
-void SparkMergeTreeWriter::manualFreeMemory(size_t before_write_memory)
-{
- // If mergetree disk is not local fs, like remote fs s3 or hdfs,
- // it may alloc memory in current thread, and free on global thread.
- // Now, wo have not idea to clear global memory by used spark thread tracker.
- // So we manually correct the memory usage.
- if (isRemoteStorage && insert_without_local_storage)
- return;
-
- auto disk = storage->getStoragePolicy()->getAnyDisk();
- std::lock_guard lock(memory_mutex);
- auto * memory_tracker = CurrentThread::getMemoryTracker();
- if (memory_tracker && CurrentMemoryTracker::before_free)
- {
- CurrentThread::flushUntrackedMemory();
- const size_t ch_alloc = memory_tracker->get();
- if (disk->getName().contains("s3") && context->getSettings().s3_allow_parallel_part_upload && ch_alloc > before_write_memory)
- {
- const size_t diff_ch_alloc = before_write_memory - ch_alloc;
- memory_tracker->adjustWithUntrackedMemory(diff_ch_alloc);
- }
-
- const size_t spark_alloc = CurrentMemoryTracker::current_memory();
- const size_t diff_alloc = spark_alloc - memory_tracker->get();
-
- if (diff_alloc > 0)
- CurrentMemoryTracker::before_free(diff_alloc);
- }
-}
-
void SparkMergeTreeWriter::finalize()
{
chunkToPart(squashing->flush());
diff --git a/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.h b/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.h
index 269b0352c0566..d1fd82f371d6f 100644
--- a/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.h
+++ b/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.h
@@ -73,7 +73,6 @@ class SparkMergeTreeWriter
void checkAndMerge(bool force = false);
void safeEmplaceBackPart(DB::MergeTreeDataPartPtr);
void safeAddPart(DB::MergeTreeDataPartPtr);
- void manualFreeMemory(size_t before_write_memory);
void saveMetadata();
void commitPartToRemoteStorageIfNeeded();
void finalizeMerge();
diff --git a/cpp-ch/local-engine/Storages/StorageMergeTreeFactory.cpp b/cpp-ch/local-engine/Storages/StorageMergeTreeFactory.cpp
index 0731ac92cd078..c59d6ddb4bd41 100644
--- a/cpp-ch/local-engine/Storages/StorageMergeTreeFactory.cpp
+++ b/cpp-ch/local-engine/Storages/StorageMergeTreeFactory.cpp
@@ -16,6 +16,8 @@
*/
#include "StorageMergeTreeFactory.h"
+#include
+
namespace local_engine
{
@@ -67,14 +69,12 @@ DataPartsVector StorageMergeTreeFactory::getDataPartsByNames(const StorageID & i
{
DataPartsVector res;
auto table_name = getTableName(id, snapshot_id);
-
+ auto config = MergeTreeConfig::loadFromContext(SerializedPlanParser::global_context);
std::lock_guard lock(datapart_mutex);
std::unordered_set missing_names;
if (!datapart_map->has(table_name)) [[unlikely]]
{
- auto cache = std::make_shared>(
- SerializedPlanParser::global_context->getConfigRef().getInt64("table_part_metadata_cache_max_count", 1000000)
- );
+ auto cache = std::make_shared>(config.table_part_metadata_cache_max_count);
datapart_map->add(table_name, cache);
}
diff --git a/cpp-ch/local-engine/Storages/StorageMergeTreeFactory.h b/cpp-ch/local-engine/Storages/StorageMergeTreeFactory.h
index d7bcb93c07d7d..f372175bb02ce 100644
--- a/cpp-ch/local-engine/Storages/StorageMergeTreeFactory.h
+++ b/cpp-ch/local-engine/Storages/StorageMergeTreeFactory.h
@@ -15,6 +15,7 @@
* limitations under the License.
*/
#pragma once
+#include
#include
#include
#include
@@ -34,11 +35,11 @@ class StorageMergeTreeFactory
static DataPartsVector getDataPartsByNames(const StorageID & id, const String & snapshot_id, std::unordered_set part_name);
static void init_cache_map()
{
+ auto config = MergeTreeConfig::loadFromContext(SerializedPlanParser::global_context);
auto & storage_map_v = storage_map;
if (!storage_map_v)
{
- storage_map_v = std::make_unique>(
- SerializedPlanParser::global_context->getConfigRef().getInt64("table_metadata_cache_max_count", 100));
+ storage_map_v = std::make_unique>(config.table_metadata_cache_max_count);
}
else
{
@@ -48,7 +49,7 @@ class StorageMergeTreeFactory
if (!datapart_map_v)
{
datapart_map_v = std::make_unique>>>(
- SerializedPlanParser::global_context->getConfigRef().getInt64("table_metadata_cache_max_count", 100));
+ config.table_metadata_cache_max_count);
}
else
{
diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/FormatFile.cpp b/cpp-ch/local-engine/Storages/SubstraitSource/FormatFile.cpp
index 0221afd885141..00acaf58398cc 100644
--- a/cpp-ch/local-engine/Storages/SubstraitSource/FormatFile.cpp
+++ b/cpp-ch/local-engine/Storages/SubstraitSource/FormatFile.cpp
@@ -36,6 +36,7 @@
#include
#endif
+#include
#include
namespace DB
@@ -81,8 +82,8 @@ FormatFilePtr FormatFileUtil::createFile(
#if USE_PARQUET
if (file.has_parquet())
{
- bool useLocalFormat = context->getConfigRef().getBool("use_local_format", false);
- return std::make_shared(context, file, read_buffer_builder, useLocalFormat);
+ auto config = ExecutorConfig::loadFromContext(context);
+ return std::make_shared(context, file, read_buffer_builder, config.use_local_format);
}
#endif
diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/ReadBufferBuilder.cpp b/cpp-ch/local-engine/Storages/SubstraitSource/ReadBufferBuilder.cpp
index e73ca8ecee2bb..da15890070b09 100644
--- a/cpp-ch/local-engine/Storages/SubstraitSource/ReadBufferBuilder.cpp
+++ b/cpp-ch/local-engine/Storages/SubstraitSource/ReadBufferBuilder.cpp
@@ -48,6 +48,7 @@
#include
#include
#include
+#include
#include