diff --git a/docs/learn/documentation/versioned/jobs/configuration-table.html b/docs/learn/documentation/versioned/jobs/configuration-table.html index e00c983d8b..390be03761 100644 --- a/docs/learn/documentation/versioned/jobs/configuration-table.html +++ b/docs/learn/documentation/versioned/jobs/configuration-table.html @@ -494,6 +494,16 @@

Samza Configuration Reference

+ + job.operator.framework.executor.enabled + false + + If enabled, framework thread pool will be used for message hand off and sub DAG execution. Otherwise, the + execution will fall back to using caller thread or java fork join pool depending on the type of work + chained as part of message hand off. + + + Zookeeper-based job configuration diff --git a/samza-core/src/main/java/org/apache/samza/config/JobConfig.java b/samza-core/src/main/java/org/apache/samza/config/JobConfig.java index 3d0b532625..17f527252e 100644 --- a/samza-core/src/main/java/org/apache/samza/config/JobConfig.java +++ b/samza-core/src/main/java/org/apache/samza/config/JobConfig.java @@ -197,6 +197,10 @@ public class JobConfig extends MapConfig { public static final String JOB_ELASTICITY_FACTOR = "job.elasticity.factor"; public static final int DEFAULT_JOB_ELASTICITY_FACTOR = 1; + public static final String JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED = "job.operator.framework.executor.enabled"; + + public static final boolean DEFAULT_JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED = false; + public JobConfig(Config config) { super(config); } @@ -528,4 +532,8 @@ public int getElasticityFactor() { public String getCoordinatorExecuteCommand() { return get(COORDINATOR_EXECUTE_COMMAND, DEFAULT_COORDINATOR_EXECUTE_COMMAND); } + + public boolean getOperatorFrameworkExecutorEnabled() { + return getBoolean(JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED, DEFAULT_JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED); + } } \ No newline at end of file diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java index 8b477d42db..c870264e91 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java +++ b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java @@ -22,6 +22,8 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutorService; +import java.util.function.Consumer; +import java.util.function.Function; import org.apache.samza.SamzaException; import org.apache.samza.config.Config; import org.apache.samza.config.JobConfig; @@ -95,6 +97,7 @@ public abstract class OperatorImpl { private ControlMessageSender controlMessageSender; private int elasticityFactor; private ExecutorService operatorExecutor; + private boolean operatorExecutorEnabled; /** * Initialize this {@link OperatorImpl} and its user-defined functions. @@ -136,7 +139,9 @@ public final void init(InternalTaskContext internalTaskContext) { this.taskModel = taskContext.getTaskModel(); this.callbackScheduler = taskContext.getCallbackScheduler(); handleInit(context); - this.elasticityFactor = new JobConfig(config).getElasticityFactor(); + JobConfig jobConfig = new JobConfig(config); + this.elasticityFactor = jobConfig.getElasticityFactor(); + this.operatorExecutorEnabled = jobConfig.getOperatorFrameworkExecutorEnabled(); this.operatorExecutor = context.getTaskContext().getOperatorExecutor(); initialized = true; @@ -192,21 +197,20 @@ public final CompletionStage onMessageAsync(M message, MessageCollector co getOpImplId(), getOperatorSpec().getSourceLocation(), expectedType, actualType), e); } - CompletionStage result = completableResultsFuture.thenComposeAsync(results -> { + CompletionStage result = composeFutureWithExecutor(completableResultsFuture, results -> { long endNs = this.highResClock.nanoTime(); this.handleMessageNs.update(endNs - startNs); return CompletableFuture.allOf(results.stream() - .flatMap(r -> this.registeredOperators.stream() - .map(op -> op.onMessageAsync(r, collector, coordinator))) + .flatMap(r -> this.registeredOperators.stream().map(op -> op.onMessageAsync(r, collector, coordinator))) .toArray(CompletableFuture[]::new)); - }, operatorExecutor); + }); WatermarkFunction watermarkFn = getOperatorSpec().getWatermarkFn(); if (watermarkFn != null) { // check whether there is new watermark emitted from the user function Long outputWm = watermarkFn.getOutputWatermark(); - return result.thenComposeAsync(ignored -> propagateWatermark(outputWm, collector, coordinator), operatorExecutor); + return composeFutureWithExecutor(result, ignored -> propagateWatermark(outputWm, collector, coordinator)); } return result; @@ -245,11 +249,9 @@ public final CompletionStage onTimer(MessageCollector collector, TaskCoord .map(op -> op.onMessageAsync(r, collector, coordinator))) .toArray(CompletableFuture[]::new)); - return resultFuture.thenComposeAsync(x -> - CompletableFuture.allOf(this.registeredOperators - .stream() - .map(op -> op.onTimer(collector, coordinator)) - .toArray(CompletableFuture[]::new)), operatorExecutor); + return composeFutureWithExecutor(resultFuture, x -> CompletableFuture.allOf(this.registeredOperators.stream() + .map(op -> op.onTimer(collector, coordinator)) + .toArray(CompletableFuture[]::new))); } /** @@ -315,15 +317,14 @@ public final CompletionStage aggregateEndOfStream(EndOfStreamMessage eos, } // populate the end-of-stream through the dag - endOfStreamFuture = onEndOfStream(collector, coordinator) - .thenAcceptAsync(result -> { - if (eosStates.allEndOfStream()) { - // all inputs have been end-of-stream, shut down the task - LOG.info("All input streams have reached the end for task {}", taskName.getTaskName()); - coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK); - coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK); - } - }, operatorExecutor); + endOfStreamFuture = acceptFutureWithExecutor(onEndOfStream(collector, coordinator), result -> { + if (eosStates.allEndOfStream()) { + // all inputs have been end-of-stream, shut down the task + LOG.info("All input streams have reached the end for task {}", taskName.getTaskName()); + coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK); + coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK); + } + }); } return endOfStreamFuture; @@ -347,10 +348,10 @@ private CompletionStage onEndOfStream(MessageCollector collector, TaskCoor .map(op -> op.onMessageAsync(r, collector, coordinator))) .toArray(CompletableFuture[]::new)); - endOfStreamFuture = resultFuture.thenComposeAsync(x -> - CompletableFuture.allOf(this.registeredOperators.stream() + endOfStreamFuture = composeFutureWithExecutor(resultFuture, x -> CompletableFuture.allOf( + this.registeredOperators.stream() .map(op -> op.onEndOfStream(collector, coordinator)) - .toArray(CompletableFuture[]::new)), operatorExecutor); + .toArray(CompletableFuture[]::new))); } return endOfStreamFuture; @@ -406,15 +407,14 @@ public final CompletionStage aggregateDrainMessages(DrainMessage drainMess controlMessageSender.broadcastToOtherPartitions(new DrainMessage(drainMessage.getRunId()), ssp, collector); } - drainFuture = onDrainOfStream(collector, coordinator) - .thenAcceptAsync(result -> { - if (drainStates.areAllStreamsDrained()) { - // All input streams have been drained, shut down the task - LOG.info("All input streams have been drained for task {}. Requesting shutdown.", taskName.getTaskName()); - coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK); - coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK); - } - }, operatorExecutor); + drainFuture = acceptFutureWithExecutor(onDrainOfStream(collector, coordinator), result -> { + if (drainStates.areAllStreamsDrained()) { + // All input streams have been drained, shut down the task + LOG.info("All input streams have been drained for task {}. Requesting shutdown.", taskName.getTaskName()); + coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK); + coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK); + } + }); } return drainFuture; @@ -439,10 +439,10 @@ private CompletionStage onDrainOfStream(MessageCollector collector, TaskCo .toArray(CompletableFuture[]::new)); // propagate DrainMessage to downstream operators - drainFuture = resultFuture.thenComposeAsync(x -> - CompletableFuture.allOf(this.registeredOperators.stream() + drainFuture = composeFutureWithExecutor(resultFuture, x -> CompletableFuture.allOf( + this.registeredOperators.stream() .map(op -> op.onDrainOfStream(collector, coordinator)) - .toArray(CompletableFuture[]::new)), operatorExecutor); + .toArray(CompletableFuture[]::new))); } return drainFuture; @@ -474,8 +474,8 @@ public final CompletionStage aggregateWatermark(WatermarkMessage watermark controlMessageSender.broadcastToOtherPartitions(new WatermarkMessage(watermark), ssp, collector); } // populate the watermark through the dag - watermarkFuture = onWatermark(watermark, collector, coordinator) - .thenAcceptAsync(ignored -> watermarkStates.updateAggregateMetric(ssp, watermark), operatorExecutor); + watermarkFuture = acceptFutureWithExecutor(onWatermark(watermark, collector, coordinator), + ignored -> watermarkStates.updateAggregateMetric(ssp, watermark)); } return watermarkFuture; @@ -530,8 +530,8 @@ private CompletionStage onWatermark(long watermark, MessageCollector colle .toArray(CompletableFuture[]::new)); } - watermarkFuture = watermarkFuture.thenComposeAsync(res -> propagateWatermark(outputWm, collector, coordinator), - operatorExecutor); + watermarkFuture = + composeFutureWithExecutor(watermarkFuture, res -> propagateWatermark(outputWm, collector, coordinator)); } return watermarkFuture; @@ -679,6 +679,20 @@ final Collection handleMessage(M message, MessageCollector collector, TaskCo .toCompletableFuture().join(); } + @VisibleForTesting + final CompletionStage composeFutureWithExecutor(CompletionStage futureToChain, + Function> fn) { + return operatorExecutorEnabled ? futureToChain.thenComposeAsync(fn, operatorExecutor) + : futureToChain.thenCompose(fn); + } + + @VisibleForTesting + final CompletionStage acceptFutureWithExecutor(CompletionStage futureToChain, + Consumer consumer) { + return operatorExecutorEnabled ? futureToChain.thenAcceptAsync(consumer, operatorExecutor) + : futureToChain.thenAccept(consumer); + } + private HighResolutionClock createHighResClock(Config config) { MetricsConfig metricsConfig = new MetricsConfig(config); // The timer metrics calculation here is only enabled for debugging diff --git a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala index 285e7c8778..89738e2de0 100644 --- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala +++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala @@ -93,9 +93,13 @@ class TaskInstance( val jobConfig = new JobConfig(jobContext.getConfig) val taskExecutorFactory = ReflectionUtil.getObj(jobConfig.getTaskExecutorFactory, classOf[TaskExecutorFactory]) + var operatorExecutor = Option.empty[java.util.concurrent.ExecutorService].orNull + if (jobConfig.getOperatorFrameworkExecutorEnabled) { + operatorExecutor = taskExecutorFactory.getOperatorExecutor(taskName, jobContext.getConfig) + } new TaskContextImpl(taskModel, metrics.registry, kvStoreSupplier, tableManager, new CallbackSchedulerImpl(epochTimeScheduler), offsetManager, jobModel, streamMetadataCache, - systemStreamPartitions, taskExecutorFactory.getOperatorExecutor(taskName, jobContext.getConfig)) + systemStreamPartitions, operatorExecutor) } // need separate field for this instead of using it through Context, since Context throws an exception if it is null private val applicationTaskContextOption = applicationTaskContextFactoryOption diff --git a/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java b/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java index 9cb307d57a..6709417b98 100644 --- a/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java +++ b/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java @@ -18,16 +18,24 @@ */ package org.apache.samza.operators.impl; +import com.google.common.collect.ImmutableMap; import java.util.Collection; import java.util.Collections; import java.util.Iterator; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.function.Consumer; +import java.util.function.Function; +import org.apache.samza.config.Config; +import org.apache.samza.config.MapConfig; +import org.apache.samza.context.ContainerContext; import org.apache.samza.context.Context; import org.apache.samza.context.InternalTaskContext; -import org.apache.samza.context.MockContext; +import org.apache.samza.context.JobContext; +import org.apache.samza.context.TaskContext; import org.apache.samza.job.model.TaskModel; import org.apache.samza.metrics.Counter; import org.apache.samza.metrics.MetricsRegistryMap; @@ -44,33 +52,111 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; -import static org.mockito.Matchers.anyLong; -import static org.mockito.Matchers.anyObject; -import static org.mockito.Matchers.anyString; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.Matchers.*; +import static org.mockito.Mockito.*; public class TestOperatorImpl { private Context context; private InternalTaskContext internalTaskContext; + private JobContext jobContext; + + private TaskContext taskContext; + + private ContainerContext containerContext; + @Before public void setup() { - this.context = new MockContext(); + this.context = mock(Context.class); this.internalTaskContext = mock(InternalTaskContext.class); + this.jobContext = mock(JobContext.class); + this.taskContext = mock(TaskContext.class); + this.containerContext = mock(ContainerContext.class); when(this.internalTaskContext.getContext()).thenReturn(this.context); // might be necessary in the future when(this.internalTaskContext.fetchObject(EndOfStreamStates.class.getName())).thenReturn(mock(EndOfStreamStates.class)); when(this.internalTaskContext.fetchObject(WatermarkStates.class.getName())).thenReturn(mock(WatermarkStates.class)); - when(this.context.getTaskContext().getTaskMetricsRegistry()).thenReturn(new MetricsRegistryMap()); - when(this.context.getTaskContext().getTaskModel()).thenReturn(mock(TaskModel.class)); - when(this.context.getTaskContext().getOperatorExecutor()).thenReturn(Executors.newSingleThreadExecutor()); - when(this.context.getContainerContext().getContainerMetricsRegistry()).thenReturn(new MetricsRegistryMap()); + when(this.context.getJobContext()).thenReturn(jobContext); + when(this.context.getTaskContext()).thenReturn(taskContext); + when(this.taskContext.getTaskMetricsRegistry()).thenReturn(new MetricsRegistryMap()); + when(this.taskContext.getTaskModel()).thenReturn(mock(TaskModel.class)); + when(this.taskContext.getOperatorExecutor()).thenReturn(Executors.newSingleThreadExecutor()); + when(this.context.getContainerContext()).thenReturn(containerContext); + when(containerContext.getContainerMetricsRegistry()).thenReturn(new MetricsRegistryMap()); + } + + @Test + public void testComposeFutureWithExecutorWithFrameworkExecutorEnabled() { + OperatorImpl opImpl = new TestOpImpl(mock(Object.class)); + ExecutorService mockExecutor = mock(ExecutorService.class); + CompletionStage mockFuture = mock(CompletionStage.class); + Function> mockFunction = mock(Function.class); + + Config config = new MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "true")); + + when(this.taskContext.getOperatorExecutor()).thenReturn(mockExecutor); + when(this.jobContext.getConfig()).thenReturn(config); + + opImpl.init(this.internalTaskContext); + opImpl.composeFutureWithExecutor(mockFuture, mockFunction); + + verify(mockFuture).thenComposeAsync(eq(mockFunction), eq(mockExecutor)); + } + + @Test + public void testComposeFutureWithExecutorWithFrameworkExecutorDisabled() { + OperatorImpl opImpl = new TestOpImpl(mock(Object.class)); + ExecutorService mockExecutor = mock(ExecutorService.class); + CompletionStage mockFuture = mock(CompletionStage.class); + Function> mockFunction = mock(Function.class); + + Config config = new MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "false")); + + when(this.taskContext.getOperatorExecutor()).thenReturn(mockExecutor); + when(this.jobContext.getConfig()).thenReturn(config); + + opImpl.init(this.internalTaskContext); + opImpl.composeFutureWithExecutor(mockFuture, mockFunction); + + verify(mockFuture).thenCompose(eq(mockFunction)); } + @Test + public void testAcceptFutureWithExecutorWithFrameworkExecutorDisabled() { + OperatorImpl opImpl = new TestOpImpl(mock(Object.class)); + ExecutorService mockExecutor = mock(ExecutorService.class); + CompletionStage mockFuture = mock(CompletionStage.class); + Consumer mockConsumer = mock(Consumer.class); + + Config config = new MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "false")); + + when(this.taskContext.getOperatorExecutor()).thenReturn(mockExecutor); + when(this.jobContext.getConfig()).thenReturn(config); + + opImpl.init(this.internalTaskContext); + opImpl.acceptFutureWithExecutor(mockFuture, mockConsumer); + + verify(mockFuture).thenAccept(eq(mockConsumer)); + } + + @Test + public void testAcceptFutureWithExecutorWithFrameworkExecutorEnabled() { + OperatorImpl opImpl = new TestOpImpl(mock(Object.class)); + ExecutorService mockExecutor = mock(ExecutorService.class); + CompletionStage mockFuture = mock(CompletionStage.class); + Consumer mockConsumer = mock(Consumer.class); + + Config config = new MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "true")); + + when(this.taskContext.getOperatorExecutor()).thenReturn(mockExecutor); + when(this.jobContext.getConfig()).thenReturn(config); + + opImpl.init(this.internalTaskContext); + opImpl.acceptFutureWithExecutor(mockFuture, mockConsumer); + + verify(mockFuture).thenAcceptAsync(eq(mockConsumer), eq(mockExecutor)); + } @Test(expected = IllegalStateException.class) public void testMultipleInitShouldThrow() { OperatorImpl opImpl = new TestOpImpl(mock(Object.class));