From 66495b677a728ff75a8674b217672cd51aece640 Mon Sep 17 00:00:00 2001
From: Bharath Kumarasubramanian
Date: Tue, 21 Nov 2023 11:56:08 -0800
Subject: [PATCH] SAMZA-2796: Introduce config knob for framework thread sub DAG execution (#1691)

Description
As part of SAMZA-2781, we use the framework thread pool to execute message hand-offs and sub-DAG execution. This change adds a config knob so that users can opt in to the feature rather than having it enabled by default.

Changes
Introduce a config knob to use the framework executor.

Tests
Added unit tests.

Usage Instructions
Refer to the configuration documentation. To use the framework thread pool for sub-DAG execution and message hand-off, set job.operator.framework.executor.enabled to true (two illustrative sketches, not part of the patch, follow the end of the diff).
---
 .../versioned/jobs/configuration-table.html | 10 ++
 .../org/apache/samza/config/JobConfig.java | 8 ++
 .../samza/operators/impl/OperatorImpl.java | 92 ++++++++------
 .../apache/samza/container/TaskInstance.scala | 6 +-
 .../operators/impl/TestOperatorImpl.java | 112 ++++++++++++++++--
 5 files changed, 175 insertions(+), 53 deletions(-)

diff --git a/docs/learn/documentation/versioned/jobs/configuration-table.html b/docs/learn/documentation/versioned/jobs/configuration-table.html
index e00c983d8b..390be03761 100644
--- a/docs/learn/documentation/versioned/jobs/configuration-table.html
+++ b/docs/learn/documentation/versioned/jobs/configuration-table.html
@@ -494,6 +494,16 @@

Samza Configuration Reference

+
+                    job.operator.framework.executor.enabled
+                    false
+
+                    If enabled, the framework thread pool is used for message hand-off and sub-DAG execution. Otherwise,
+                    execution falls back to the caller thread or the Java fork-join pool, depending on the type of work
+                    chained as part of the message hand-off.
+
+
                     Zookeeper-based job configuration
diff --git a/samza-core/src/main/java/org/apache/samza/config/JobConfig.java b/samza-core/src/main/java/org/apache/samza/config/JobConfig.java
index 3d0b532625..17f527252e 100644
--- a/samza-core/src/main/java/org/apache/samza/config/JobConfig.java
+++ b/samza-core/src/main/java/org/apache/samza/config/JobConfig.java
@@ -197,6 +197,10 @@ public class JobConfig extends MapConfig {
   public static final String JOB_ELASTICITY_FACTOR = "job.elasticity.factor";
   public static final int DEFAULT_JOB_ELASTICITY_FACTOR = 1;
+  public static final String JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED = "job.operator.framework.executor.enabled";
+
+  public static final boolean DEFAULT_JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED = false;
+
   public JobConfig(Config config) {
     super(config);
   }
@@ -528,4 +532,8 @@ public int getElasticityFactor() {
   public String getCoordinatorExecuteCommand() {
     return get(COORDINATOR_EXECUTE_COMMAND, DEFAULT_COORDINATOR_EXECUTE_COMMAND);
   }
+
+  public boolean getOperatorFrameworkExecutorEnabled() {
+    return getBoolean(JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED, DEFAULT_JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED);
+  }
 }
\ No newline at end of file
diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
index 8b477d42db..c870264e91 100644
--- a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
+++ b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
@@ -22,6 +22,8 @@
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionStage;
 import java.util.concurrent.ExecutorService;
+import java.util.function.Consumer;
+import java.util.function.Function;
 import org.apache.samza.SamzaException;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
@@ -95,6 +97,7 @@ public abstract class OperatorImpl {
   private ControlMessageSender controlMessageSender;
   private int elasticityFactor;
   private ExecutorService operatorExecutor;
+  private boolean operatorExecutorEnabled;
   /**
    * Initialize this {@link OperatorImpl} and its user-defined functions.
@@ -136,7 +139,9 @@ public final void init(InternalTaskContext internalTaskContext) {
     this.taskModel = taskContext.getTaskModel();
     this.callbackScheduler = taskContext.getCallbackScheduler();
     handleInit(context);
-    this.elasticityFactor = new JobConfig(config).getElasticityFactor();
+    JobConfig jobConfig = new JobConfig(config);
+    this.elasticityFactor = jobConfig.getElasticityFactor();
+    this.operatorExecutorEnabled = jobConfig.getOperatorFrameworkExecutorEnabled();
     this.operatorExecutor = context.getTaskContext().getOperatorExecutor();
     initialized = true;
@@ -192,21 +197,20 @@ public final CompletionStage onMessageAsync(M message, MessageCollector co
           getOpImplId(), getOperatorSpec().getSourceLocation(), expectedType, actualType), e);
     }
-    CompletionStage result = completableResultsFuture.thenComposeAsync(results -> {
+    CompletionStage result = composeFutureWithExecutor(completableResultsFuture, results -> {
       long endNs = this.highResClock.nanoTime();
       this.handleMessageNs.update(endNs - startNs);
       return CompletableFuture.allOf(results.stream()
-          .flatMap(r -> this.registeredOperators.stream()
-              .map(op -> op.onMessageAsync(r, collector, coordinator)))
+          .flatMap(r -> this.registeredOperators.stream().map(op -> op.onMessageAsync(r, collector, coordinator)))
          .toArray(CompletableFuture[]::new));
-    }, operatorExecutor);
+    });
     WatermarkFunction watermarkFn = getOperatorSpec().getWatermarkFn();
     if (watermarkFn != null) {
       // check whether there is new watermark emitted from the user function
       Long outputWm = watermarkFn.getOutputWatermark();
-      return result.thenComposeAsync(ignored -> propagateWatermark(outputWm, collector, coordinator), operatorExecutor);
+      return composeFutureWithExecutor(result, ignored -> propagateWatermark(outputWm, collector, coordinator));
     }
     return result;
@@ -245,11 +249,9 @@ public final CompletionStage onTimer(MessageCollector collector, TaskCoord
             .map(op -> op.onMessageAsync(r, collector, coordinator)))
         .toArray(CompletableFuture[]::new));
-    return resultFuture.thenComposeAsync(x ->
-        CompletableFuture.allOf(this.registeredOperators
-            .stream()
-            .map(op -> op.onTimer(collector, coordinator))
-            .toArray(CompletableFuture[]::new)), operatorExecutor);
+    return composeFutureWithExecutor(resultFuture, x -> CompletableFuture.allOf(this.registeredOperators.stream()
+        .map(op -> op.onTimer(collector, coordinator))
+        .toArray(CompletableFuture[]::new)));
   }
   /**
@@ -315,15 +317,14 @@ public final CompletionStage aggregateEndOfStream(EndOfStreamMessage eos,
       }
       // populate the end-of-stream through the dag
-      endOfStreamFuture = onEndOfStream(collector, coordinator)
-          .thenAcceptAsync(result -> {
-            if (eosStates.allEndOfStream()) {
-              // all inputs have been end-of-stream, shut down the task
-              LOG.info("All input streams have reached the end for task {}", taskName.getTaskName());
-              coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
-              coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
-            }
-          }, operatorExecutor);
+      endOfStreamFuture = acceptFutureWithExecutor(onEndOfStream(collector, coordinator), result -> {
+        if (eosStates.allEndOfStream()) {
+          // all inputs have been end-of-stream, shut down the task
+          LOG.info("All input streams have reached the end for task {}", taskName.getTaskName());
+          coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+          coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        }
+      });
     }
     return endOfStreamFuture;
@@ -347,10 +348,10 @@ private CompletionStage onEndOfStream(MessageCollector collector, TaskCoor
             .map(op
                -> op.onMessageAsync(r, collector, coordinator)))
         .toArray(CompletableFuture[]::new));
-    endOfStreamFuture = resultFuture.thenComposeAsync(x ->
-        CompletableFuture.allOf(this.registeredOperators.stream()
+    endOfStreamFuture = composeFutureWithExecutor(resultFuture, x -> CompletableFuture.allOf(
+        this.registeredOperators.stream()
            .map(op -> op.onEndOfStream(collector, coordinator))
-           .toArray(CompletableFuture[]::new)), operatorExecutor);
+           .toArray(CompletableFuture[]::new)));
     }
     return endOfStreamFuture;
@@ -406,15 +407,14 @@ public final CompletionStage aggregateDrainMessages(DrainMessage drainMess
         controlMessageSender.broadcastToOtherPartitions(new DrainMessage(drainMessage.getRunId()), ssp, collector);
       }
-      drainFuture = onDrainOfStream(collector, coordinator)
-          .thenAcceptAsync(result -> {
-            if (drainStates.areAllStreamsDrained()) {
-              // All input streams have been drained, shut down the task
-              LOG.info("All input streams have been drained for task {}. Requesting shutdown.", taskName.getTaskName());
-              coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
-              coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
-            }
-          }, operatorExecutor);
+      drainFuture = acceptFutureWithExecutor(onDrainOfStream(collector, coordinator), result -> {
+        if (drainStates.areAllStreamsDrained()) {
+          // All input streams have been drained, shut down the task
+          LOG.info("All input streams have been drained for task {}. Requesting shutdown.", taskName.getTaskName());
+          coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+          coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        }
+      });
     }
     return drainFuture;
@@ -439,10 +439,10 @@ private CompletionStage onDrainOfStream(MessageCollector collector, TaskCo
         .toArray(CompletableFuture[]::new));
       // propagate DrainMessage to downstream operators
-      drainFuture = resultFuture.thenComposeAsync(x ->
-          CompletableFuture.allOf(this.registeredOperators.stream()
+      drainFuture = composeFutureWithExecutor(resultFuture, x -> CompletableFuture.allOf(
+          this.registeredOperators.stream()
             .map(op -> op.onDrainOfStream(collector, coordinator))
-            .toArray(CompletableFuture[]::new)), operatorExecutor);
+            .toArray(CompletableFuture[]::new)));
     }
     return drainFuture;
@@ -474,8 +474,8 @@ public final CompletionStage aggregateWatermark(WatermarkMessage watermark
       controlMessageSender.broadcastToOtherPartitions(new WatermarkMessage(watermark), ssp, collector);
     }
     // populate the watermark through the dag
-    watermarkFuture = onWatermark(watermark, collector, coordinator)
-        .thenAcceptAsync(ignored -> watermarkStates.updateAggregateMetric(ssp, watermark), operatorExecutor);
+    watermarkFuture = acceptFutureWithExecutor(onWatermark(watermark, collector, coordinator),
+        ignored -> watermarkStates.updateAggregateMetric(ssp, watermark));
     }
     return watermarkFuture;
@@ -530,8 +530,8 @@ private CompletionStage onWatermark(long watermark, MessageCollector colle
          .toArray(CompletableFuture[]::new));
       }
-      watermarkFuture = watermarkFuture.thenComposeAsync(res -> propagateWatermark(outputWm, collector, coordinator),
-          operatorExecutor);
+      watermarkFuture =
+          composeFutureWithExecutor(watermarkFuture, res -> propagateWatermark(outputWm, collector, coordinator));
     }
     return watermarkFuture;
@@ -679,6 +679,20 @@ final Collection handleMessage(M message, MessageCollector collector, TaskCo
        .toCompletableFuture().join();
   }
+  @VisibleForTesting
+  final CompletionStage composeFutureWithExecutor(CompletionStage futureToChain,
+      Function> fn) {
+    return operatorExecutorEnabled ?
        futureToChain.thenComposeAsync(fn, operatorExecutor)
+        : futureToChain.thenCompose(fn);
+  }
+
+  @VisibleForTesting
+  final CompletionStage acceptFutureWithExecutor(CompletionStage futureToChain,
+      Consumer consumer) {
+    return operatorExecutorEnabled ? futureToChain.thenAcceptAsync(consumer, operatorExecutor)
+        : futureToChain.thenAccept(consumer);
+  }
+
   private HighResolutionClock createHighResClock(Config config) {
     MetricsConfig metricsConfig = new MetricsConfig(config);
     // The timer metrics calculation here is only enabled for debugging
diff --git a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
index 285e7c8778..89738e2de0 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
@@ -93,9 +93,13 @@ class TaskInstance(
     val jobConfig = new JobConfig(jobContext.getConfig)
     val taskExecutorFactory = ReflectionUtil.getObj(jobConfig.getTaskExecutorFactory, classOf[TaskExecutorFactory])
+    var operatorExecutor = Option.empty[java.util.concurrent.ExecutorService].orNull
+    if (jobConfig.getOperatorFrameworkExecutorEnabled) {
+      operatorExecutor = taskExecutorFactory.getOperatorExecutor(taskName, jobContext.getConfig)
+    }
     new TaskContextImpl(taskModel, metrics.registry, kvStoreSupplier, tableManager,
       new CallbackSchedulerImpl(epochTimeScheduler), offsetManager, jobModel, streamMetadataCache,
-      systemStreamPartitions, taskExecutorFactory.getOperatorExecutor(taskName, jobContext.getConfig))
+      systemStreamPartitions, operatorExecutor)
   }
   // need separate field for this instead of using it through Context, since Context throws an exception if it is null
   private val applicationTaskContextOption = applicationTaskContextFactoryOption
diff --git a/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java b/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java
index 9cb307d57a..6709417b98 100644
--- a/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java
+++ b/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java
@@ -18,16 +18,24 @@
  */
 package org.apache.samza.operators.impl;
+import com.google.common.collect.ImmutableMap;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionStage;
+import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.function.Consumer;
+import java.util.function.Function;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.context.ContainerContext;
 import org.apache.samza.context.Context;
 import org.apache.samza.context.InternalTaskContext;
-import org.apache.samza.context.MockContext;
+import org.apache.samza.context.JobContext;
+import org.apache.samza.context.TaskContext;
 import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metrics.Counter;
 import org.apache.samza.metrics.MetricsRegistryMap;
@@ -44,33 +52,111 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
-import static org.mockito.Matchers.anyLong;
-import static org.mockito.Matchers.anyObject;
-import static org.mockito.Matchers.anyString;
-import static org.mockito.Mockito.mock;
-import static
        org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
+import static org.mockito.Matchers.*;
+import static org.mockito.Mockito.*;
 public class TestOperatorImpl {
   private Context context;
   private InternalTaskContext internalTaskContext;
+  private JobContext jobContext;
+
+  private TaskContext taskContext;
+
+  private ContainerContext containerContext;
+
   @Before
   public void setup() {
-    this.context = new MockContext();
+    this.context = mock(Context.class);
     this.internalTaskContext = mock(InternalTaskContext.class);
+    this.jobContext = mock(JobContext.class);
+    this.taskContext = mock(TaskContext.class);
+    this.containerContext = mock(ContainerContext.class);
     when(this.internalTaskContext.getContext()).thenReturn(this.context);
     // might be necessary in the future
     when(this.internalTaskContext.fetchObject(EndOfStreamStates.class.getName())).thenReturn(mock(EndOfStreamStates.class));
     when(this.internalTaskContext.fetchObject(WatermarkStates.class.getName())).thenReturn(mock(WatermarkStates.class));
-    when(this.context.getTaskContext().getTaskMetricsRegistry()).thenReturn(new MetricsRegistryMap());
-    when(this.context.getTaskContext().getTaskModel()).thenReturn(mock(TaskModel.class));
-    when(this.context.getTaskContext().getOperatorExecutor()).thenReturn(Executors.newSingleThreadExecutor());
-    when(this.context.getContainerContext().getContainerMetricsRegistry()).thenReturn(new MetricsRegistryMap());
+    when(this.context.getJobContext()).thenReturn(jobContext);
+    when(this.context.getTaskContext()).thenReturn(taskContext);
+    when(this.taskContext.getTaskMetricsRegistry()).thenReturn(new MetricsRegistryMap());
+    when(this.taskContext.getTaskModel()).thenReturn(mock(TaskModel.class));
+    when(this.taskContext.getOperatorExecutor()).thenReturn(Executors.newSingleThreadExecutor());
+    when(this.context.getContainerContext()).thenReturn(containerContext);
+    when(containerContext.getContainerMetricsRegistry()).thenReturn(new MetricsRegistryMap());
+  }
+
+  @Test
+  public void testComposeFutureWithExecutorWithFrameworkExecutorEnabled() {
+    OperatorImpl opImpl = new TestOpImpl(mock(Object.class));
+    ExecutorService mockExecutor = mock(ExecutorService.class);
+    CompletionStage mockFuture = mock(CompletionStage.class);
+    Function> mockFunction = mock(Function.class);
+
+    Config config = new MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "true"));
+
+    when(this.taskContext.getOperatorExecutor()).thenReturn(mockExecutor);
+    when(this.jobContext.getConfig()).thenReturn(config);
+
+    opImpl.init(this.internalTaskContext);
+    opImpl.composeFutureWithExecutor(mockFuture, mockFunction);
+
+    verify(mockFuture).thenComposeAsync(eq(mockFunction), eq(mockExecutor));
+  }
+
+  @Test
+  public void testComposeFutureWithExecutorWithFrameworkExecutorDisabled() {
+    OperatorImpl opImpl = new TestOpImpl(mock(Object.class));
+    ExecutorService mockExecutor = mock(ExecutorService.class);
+    CompletionStage mockFuture = mock(CompletionStage.class);
+    Function> mockFunction = mock(Function.class);
+
+    Config config = new MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "false"));
+
+    when(this.taskContext.getOperatorExecutor()).thenReturn(mockExecutor);
+    when(this.jobContext.getConfig()).thenReturn(config);
+
+    opImpl.init(this.internalTaskContext);
+    opImpl.composeFutureWithExecutor(mockFuture, mockFunction);
+
+    verify(mockFuture).thenCompose(eq(mockFunction));
   }
+  @Test
+  public void
        testAcceptFutureWithExecutorWithFrameworkExecutorDisabled() {
+    OperatorImpl opImpl = new TestOpImpl(mock(Object.class));
+    ExecutorService mockExecutor = mock(ExecutorService.class);
+    CompletionStage mockFuture = mock(CompletionStage.class);
+    Consumer mockConsumer = mock(Consumer.class);
+
+    Config config = new MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "false"));
+
+    when(this.taskContext.getOperatorExecutor()).thenReturn(mockExecutor);
+    when(this.jobContext.getConfig()).thenReturn(config);
+
+    opImpl.init(this.internalTaskContext);
+    opImpl.acceptFutureWithExecutor(mockFuture, mockConsumer);
+
+    verify(mockFuture).thenAccept(eq(mockConsumer));
+  }
+
+  @Test
+  public void testAcceptFutureWithExecutorWithFrameworkExecutorEnabled() {
+    OperatorImpl opImpl = new TestOpImpl(mock(Object.class));
+    ExecutorService mockExecutor = mock(ExecutorService.class);
+    CompletionStage mockFuture = mock(CompletionStage.class);
+    Consumer mockConsumer = mock(Consumer.class);
+
+    Config config = new MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "true"));
+
+    when(this.taskContext.getOperatorExecutor()).thenReturn(mockExecutor);
+    when(this.jobContext.getConfig()).thenReturn(config);
+
+    opImpl.init(this.internalTaskContext);
+    opImpl.acceptFutureWithExecutor(mockFuture, mockConsumer);
+
+    verify(mockFuture).thenAcceptAsync(eq(mockConsumer), eq(mockExecutor));
+  }
 @Test(expected = IllegalStateException.class)
 public void testMultipleInitShouldThrow() {
   OperatorImpl opImpl = new TestOpImpl(mock(Object.class));
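
==== Illustrative example 1 (not part of the patch) ====
The Usage Instructions above amount to setting a single job property. The sketch below shows the knob being read back through the JobConfig accessor this patch adds, using the same MapConfig/ImmutableMap construction the unit tests use; the class name FrameworkExecutorKnobExample is hypothetical and exists only for illustration.

// A minimal sketch, assuming samza-core and Guava are on the classpath.
import com.google.common.collect.ImmutableMap;
import org.apache.samza.config.JobConfig;
import org.apache.samza.config.MapConfig;

public class FrameworkExecutorKnobExample {
  public static void main(String[] args) {
    // Opt in to running message hand-off and sub-DAG execution on the framework thread pool.
    JobConfig enabled = new JobConfig(
        new MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "true")));
    System.out.println(enabled.getOperatorFrameworkExecutorEnabled()); // true

    // Default is false: execution stays on the caller thread or the fork-join pool, as before.
    JobConfig defaults = new JobConfig(new MapConfig());
    System.out.println(defaults.getOperatorFrameworkExecutorEnabled()); // false
  }
}
==== End example 1 ====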
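
==== Illustrative example 2 (not part of the patch) ====
The new composeFutureWithExecutor/acceptFutureWithExecutor helpers in OperatorImpl both follow one pattern: when the knob is on, chain the next stage on the framework executor via thenComposeAsync/thenAcceptAsync; otherwise chain with thenCompose/thenAccept so the completing thread (or the common pool) runs it. The standalone sketch below demonstrates that pattern with only java.util.concurrent APIs; the class and method names (ConditionalChainingSketch, compose) are hypothetical, not Samza APIs.

// A minimal sketch of the conditional-chaining pattern, under the assumptions stated above.
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Function;

public class ConditionalChainingSketch {
  private final boolean useDedicatedExecutor;  // mirrors operatorExecutorEnabled
  private final ExecutorService executor;      // mirrors operatorExecutor

  ConditionalChainingSketch(boolean useDedicatedExecutor, ExecutorService executor) {
    this.useDedicatedExecutor = useDedicatedExecutor;
    this.executor = executor;
  }

  // Chain the next stage on the dedicated executor only when the flag is on.
  <T, R> CompletionStage<R> compose(CompletionStage<T> stage,
      Function<? super T, ? extends CompletionStage<R>> fn) {
    return useDedicatedExecutor ? stage.thenComposeAsync(fn, executor) : stage.thenCompose(fn);
  }

  public static void main(String[] args) {
    ExecutorService pool = Executors.newSingleThreadExecutor();
    ConditionalChainingSketch enabled = new ConditionalChainingSketch(true, pool);
    CompletionStage<String> result = enabled.compose(
        CompletableFuture.completedFuture("message"),
        m -> CompletableFuture.supplyAsync(() -> m + " -> handled"));
    System.out.println(result.toCompletableFuture().join()); // message -> handled
    pool.shutdown();
  }
}
==== End example 2 ====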