diff --git a/docs/learn/documentation/versioned/jobs/configuration-table.html b/docs/learn/documentation/versioned/jobs/configuration-table.html
index e00c983d8b..390be03761 100644
--- a/docs/learn/documentation/versioned/jobs/configuration-table.html
+++ b/docs/learn/documentation/versioned/jobs/configuration-table.html
@@ -494,6 +494,16 @@
Samza Configuration Reference
+
+ job.operator.framework.executor.enabled
+ false
+
+ If enabled, framework thread pool will be used for message hand off and sub DAG execution. Otherwise, the
+ execution will fall back to using caller thread or java fork join pool depending on the type of work
+ chained as part of message hand off.
+
+
+
Zookeeper-based job configuration
diff --git a/samza-core/src/main/java/org/apache/samza/config/JobConfig.java b/samza-core/src/main/java/org/apache/samza/config/JobConfig.java
index 3d0b532625..17f527252e 100644
--- a/samza-core/src/main/java/org/apache/samza/config/JobConfig.java
+++ b/samza-core/src/main/java/org/apache/samza/config/JobConfig.java
@@ -197,6 +197,10 @@ public class JobConfig extends MapConfig {
public static final String JOB_ELASTICITY_FACTOR = "job.elasticity.factor";
public static final int DEFAULT_JOB_ELASTICITY_FACTOR = 1;
+ public static final String JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED = "job.operator.framework.executor.enabled";
+
+ public static final boolean DEFAULT_JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED = false;
+
public JobConfig(Config config) {
super(config);
}
@@ -528,4 +532,8 @@ public int getElasticityFactor() {
public String getCoordinatorExecuteCommand() {
return get(COORDINATOR_EXECUTE_COMMAND, DEFAULT_COORDINATOR_EXECUTE_COMMAND);
}
+
+ public boolean getOperatorFrameworkExecutorEnabled() {
+ return getBoolean(JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED, DEFAULT_JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED);
+ }
}
\ No newline at end of file
diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
index 8b477d42db..c870264e91 100644
--- a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
+++ b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
@@ -22,6 +22,8 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutorService;
+import java.util.function.Consumer;
+import java.util.function.Function;
import org.apache.samza.SamzaException;
import org.apache.samza.config.Config;
import org.apache.samza.config.JobConfig;
@@ -95,6 +97,7 @@ public abstract class OperatorImpl {
private ControlMessageSender controlMessageSender;
private int elasticityFactor;
private ExecutorService operatorExecutor;
+ private boolean operatorExecutorEnabled;
/**
* Initialize this {@link OperatorImpl} and its user-defined functions.
@@ -136,7 +139,9 @@ public final void init(InternalTaskContext internalTaskContext) {
this.taskModel = taskContext.getTaskModel();
this.callbackScheduler = taskContext.getCallbackScheduler();
handleInit(context);
- this.elasticityFactor = new JobConfig(config).getElasticityFactor();
+ JobConfig jobConfig = new JobConfig(config);
+ this.elasticityFactor = jobConfig.getElasticityFactor();
+ this.operatorExecutorEnabled = jobConfig.getOperatorFrameworkExecutorEnabled();
this.operatorExecutor = context.getTaskContext().getOperatorExecutor();
initialized = true;
@@ -192,21 +197,20 @@ public final CompletionStage onMessageAsync(M message, MessageCollector co
getOpImplId(), getOperatorSpec().getSourceLocation(), expectedType, actualType), e);
}
- CompletionStage result = completableResultsFuture.thenComposeAsync(results -> {
+ CompletionStage result = composeFutureWithExecutor(completableResultsFuture, results -> {
long endNs = this.highResClock.nanoTime();
this.handleMessageNs.update(endNs - startNs);
return CompletableFuture.allOf(results.stream()
- .flatMap(r -> this.registeredOperators.stream()
- .map(op -> op.onMessageAsync(r, collector, coordinator)))
+ .flatMap(r -> this.registeredOperators.stream().map(op -> op.onMessageAsync(r, collector, coordinator)))
.toArray(CompletableFuture[]::new));
- }, operatorExecutor);
+ });
WatermarkFunction watermarkFn = getOperatorSpec().getWatermarkFn();
if (watermarkFn != null) {
// check whether there is new watermark emitted from the user function
Long outputWm = watermarkFn.getOutputWatermark();
- return result.thenComposeAsync(ignored -> propagateWatermark(outputWm, collector, coordinator), operatorExecutor);
+ return composeFutureWithExecutor(result, ignored -> propagateWatermark(outputWm, collector, coordinator));
}
return result;
@@ -245,11 +249,9 @@ public final CompletionStage onTimer(MessageCollector collector, TaskCoord
.map(op -> op.onMessageAsync(r, collector, coordinator)))
.toArray(CompletableFuture[]::new));
- return resultFuture.thenComposeAsync(x ->
- CompletableFuture.allOf(this.registeredOperators
- .stream()
- .map(op -> op.onTimer(collector, coordinator))
- .toArray(CompletableFuture[]::new)), operatorExecutor);
+ return composeFutureWithExecutor(resultFuture, x -> CompletableFuture.allOf(this.registeredOperators.stream()
+ .map(op -> op.onTimer(collector, coordinator))
+ .toArray(CompletableFuture[]::new)));
}
/**
@@ -315,15 +317,14 @@ public final CompletionStage aggregateEndOfStream(EndOfStreamMessage eos,
}
// populate the end-of-stream through the dag
- endOfStreamFuture = onEndOfStream(collector, coordinator)
- .thenAcceptAsync(result -> {
- if (eosStates.allEndOfStream()) {
- // all inputs have been end-of-stream, shut down the task
- LOG.info("All input streams have reached the end for task {}", taskName.getTaskName());
- coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
- coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
- }
- }, operatorExecutor);
+ endOfStreamFuture = acceptFutureWithExecutor(onEndOfStream(collector, coordinator), result -> {
+ if (eosStates.allEndOfStream()) {
+ // all inputs have been end-of-stream, shut down the task
+ LOG.info("All input streams have reached the end for task {}", taskName.getTaskName());
+ coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+ coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+ }
+ });
}
return endOfStreamFuture;
@@ -347,10 +348,10 @@ private CompletionStage onEndOfStream(MessageCollector collector, TaskCoor
.map(op -> op.onMessageAsync(r, collector, coordinator)))
.toArray(CompletableFuture[]::new));
- endOfStreamFuture = resultFuture.thenComposeAsync(x ->
- CompletableFuture.allOf(this.registeredOperators.stream()
+ endOfStreamFuture = composeFutureWithExecutor(resultFuture, x -> CompletableFuture.allOf(
+ this.registeredOperators.stream()
.map(op -> op.onEndOfStream(collector, coordinator))
- .toArray(CompletableFuture[]::new)), operatorExecutor);
+ .toArray(CompletableFuture[]::new)));
}
return endOfStreamFuture;
@@ -406,15 +407,14 @@ public final CompletionStage aggregateDrainMessages(DrainMessage drainMess
controlMessageSender.broadcastToOtherPartitions(new DrainMessage(drainMessage.getRunId()), ssp, collector);
}
- drainFuture = onDrainOfStream(collector, coordinator)
- .thenAcceptAsync(result -> {
- if (drainStates.areAllStreamsDrained()) {
- // All input streams have been drained, shut down the task
- LOG.info("All input streams have been drained for task {}. Requesting shutdown.", taskName.getTaskName());
- coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
- coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
- }
- }, operatorExecutor);
+ drainFuture = acceptFutureWithExecutor(onDrainOfStream(collector, coordinator), result -> {
+ if (drainStates.areAllStreamsDrained()) {
+ // All input streams have been drained, shut down the task
+ LOG.info("All input streams have been drained for task {}. Requesting shutdown.", taskName.getTaskName());
+ coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+ coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+ }
+ });
}
return drainFuture;
@@ -439,10 +439,10 @@ private CompletionStage onDrainOfStream(MessageCollector collector, TaskCo
.toArray(CompletableFuture[]::new));
// propagate DrainMessage to downstream operators
- drainFuture = resultFuture.thenComposeAsync(x ->
- CompletableFuture.allOf(this.registeredOperators.stream()
+ drainFuture = composeFutureWithExecutor(resultFuture, x -> CompletableFuture.allOf(
+ this.registeredOperators.stream()
.map(op -> op.onDrainOfStream(collector, coordinator))
- .toArray(CompletableFuture[]::new)), operatorExecutor);
+ .toArray(CompletableFuture[]::new)));
}
return drainFuture;
@@ -474,8 +474,8 @@ public final CompletionStage aggregateWatermark(WatermarkMessage watermark
controlMessageSender.broadcastToOtherPartitions(new WatermarkMessage(watermark), ssp, collector);
}
// populate the watermark through the dag
- watermarkFuture = onWatermark(watermark, collector, coordinator)
- .thenAcceptAsync(ignored -> watermarkStates.updateAggregateMetric(ssp, watermark), operatorExecutor);
+ watermarkFuture = acceptFutureWithExecutor(onWatermark(watermark, collector, coordinator),
+ ignored -> watermarkStates.updateAggregateMetric(ssp, watermark));
}
return watermarkFuture;
@@ -530,8 +530,8 @@ private CompletionStage onWatermark(long watermark, MessageCollector colle
.toArray(CompletableFuture[]::new));
}
- watermarkFuture = watermarkFuture.thenComposeAsync(res -> propagateWatermark(outputWm, collector, coordinator),
- operatorExecutor);
+ watermarkFuture =
+ composeFutureWithExecutor(watermarkFuture, res -> propagateWatermark(outputWm, collector, coordinator));
}
return watermarkFuture;
@@ -679,6 +679,20 @@ final Collection handleMessage(M message, MessageCollector collector, TaskCo
.toCompletableFuture().join();
}
+ @VisibleForTesting
+ final CompletionStage composeFutureWithExecutor(CompletionStage futureToChain,
+ Function super T, ? extends CompletionStage> fn) {
+ return operatorExecutorEnabled ? futureToChain.thenComposeAsync(fn, operatorExecutor)
+ : futureToChain.thenCompose(fn);
+ }
+
+ @VisibleForTesting
+ final CompletionStage acceptFutureWithExecutor(CompletionStage futureToChain,
+ Consumer super T> consumer) {
+ return operatorExecutorEnabled ? futureToChain.thenAcceptAsync(consumer, operatorExecutor)
+ : futureToChain.thenAccept(consumer);
+ }
+
private HighResolutionClock createHighResClock(Config config) {
MetricsConfig metricsConfig = new MetricsConfig(config);
// The timer metrics calculation here is only enabled for debugging
diff --git a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
index 285e7c8778..89738e2de0 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
@@ -93,9 +93,13 @@ class TaskInstance(
val jobConfig = new JobConfig(jobContext.getConfig)
val taskExecutorFactory = ReflectionUtil.getObj(jobConfig.getTaskExecutorFactory, classOf[TaskExecutorFactory])
+ var operatorExecutor = Option.empty[java.util.concurrent.ExecutorService].orNull
+ if (jobConfig.getOperatorFrameworkExecutorEnabled) {
+ operatorExecutor = taskExecutorFactory.getOperatorExecutor(taskName, jobContext.getConfig)
+ }
new TaskContextImpl(taskModel, metrics.registry, kvStoreSupplier, tableManager,
new CallbackSchedulerImpl(epochTimeScheduler), offsetManager, jobModel, streamMetadataCache,
- systemStreamPartitions, taskExecutorFactory.getOperatorExecutor(taskName, jobContext.getConfig))
+ systemStreamPartitions, operatorExecutor)
}
// need separate field for this instead of using it through Context, since Context throws an exception if it is null
private val applicationTaskContextOption = applicationTaskContextFactoryOption
diff --git a/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java b/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java
index 9cb307d57a..6709417b98 100644
--- a/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java
+++ b/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java
@@ -18,16 +18,24 @@
*/
package org.apache.samza.operators.impl;
+import com.google.common.collect.ImmutableMap;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
+import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
+import java.util.function.Consumer;
+import java.util.function.Function;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.context.ContainerContext;
import org.apache.samza.context.Context;
import org.apache.samza.context.InternalTaskContext;
-import org.apache.samza.context.MockContext;
+import org.apache.samza.context.JobContext;
+import org.apache.samza.context.TaskContext;
import org.apache.samza.job.model.TaskModel;
import org.apache.samza.metrics.Counter;
import org.apache.samza.metrics.MetricsRegistryMap;
@@ -44,33 +52,111 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
-import static org.mockito.Matchers.anyLong;
-import static org.mockito.Matchers.anyObject;
-import static org.mockito.Matchers.anyString;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
+import static org.mockito.Matchers.*;
+import static org.mockito.Mockito.*;
public class TestOperatorImpl {
private Context context;
private InternalTaskContext internalTaskContext;
+ private JobContext jobContext;
+
+ private TaskContext taskContext;
+
+ private ContainerContext containerContext;
+
@Before
public void setup() {
- this.context = new MockContext();
+ this.context = mock(Context.class);
this.internalTaskContext = mock(InternalTaskContext.class);
+ this.jobContext = mock(JobContext.class);
+ this.taskContext = mock(TaskContext.class);
+ this.containerContext = mock(ContainerContext.class);
when(this.internalTaskContext.getContext()).thenReturn(this.context);
// might be necessary in the future
when(this.internalTaskContext.fetchObject(EndOfStreamStates.class.getName())).thenReturn(mock(EndOfStreamStates.class));
when(this.internalTaskContext.fetchObject(WatermarkStates.class.getName())).thenReturn(mock(WatermarkStates.class));
- when(this.context.getTaskContext().getTaskMetricsRegistry()).thenReturn(new MetricsRegistryMap());
- when(this.context.getTaskContext().getTaskModel()).thenReturn(mock(TaskModel.class));
- when(this.context.getTaskContext().getOperatorExecutor()).thenReturn(Executors.newSingleThreadExecutor());
- when(this.context.getContainerContext().getContainerMetricsRegistry()).thenReturn(new MetricsRegistryMap());
+ when(this.context.getJobContext()).thenReturn(jobContext);
+ when(this.context.getTaskContext()).thenReturn(taskContext);
+ when(this.taskContext.getTaskMetricsRegistry()).thenReturn(new MetricsRegistryMap());
+ when(this.taskContext.getTaskModel()).thenReturn(mock(TaskModel.class));
+ when(this.taskContext.getOperatorExecutor()).thenReturn(Executors.newSingleThreadExecutor());
+ when(this.context.getContainerContext()).thenReturn(containerContext);
+ when(containerContext.getContainerMetricsRegistry()).thenReturn(new MetricsRegistryMap());
+ }
+
+ @Test
+ public void testComposeFutureWithExecutorWithFrameworkExecutorEnabled() {
+ OperatorImpl opImpl = new TestOpImpl(mock(Object.class));
+ ExecutorService mockExecutor = mock(ExecutorService.class);
+ CompletionStage mockFuture = mock(CompletionStage.class);
+ Function> mockFunction = mock(Function.class);
+
+ Config config = new MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "true"));
+
+ when(this.taskContext.getOperatorExecutor()).thenReturn(mockExecutor);
+ when(this.jobContext.getConfig()).thenReturn(config);
+
+ opImpl.init(this.internalTaskContext);
+ opImpl.composeFutureWithExecutor(mockFuture, mockFunction);
+
+ verify(mockFuture).thenComposeAsync(eq(mockFunction), eq(mockExecutor));
+ }
+
+ @Test
+ public void testComposeFutureWithExecutorWithFrameworkExecutorDisabled() {
+ OperatorImpl opImpl = new TestOpImpl(mock(Object.class));
+ ExecutorService mockExecutor = mock(ExecutorService.class);
+ CompletionStage mockFuture = mock(CompletionStage.class);
+ Function> mockFunction = mock(Function.class);
+
+ Config config = new MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "false"));
+
+ when(this.taskContext.getOperatorExecutor()).thenReturn(mockExecutor);
+ when(this.jobContext.getConfig()).thenReturn(config);
+
+ opImpl.init(this.internalTaskContext);
+ opImpl.composeFutureWithExecutor(mockFuture, mockFunction);
+
+ verify(mockFuture).thenCompose(eq(mockFunction));
}
+ @Test
+ public void testAcceptFutureWithExecutorWithFrameworkExecutorDisabled() {
+ OperatorImpl opImpl = new TestOpImpl(mock(Object.class));
+ ExecutorService mockExecutor = mock(ExecutorService.class);
+ CompletionStage mockFuture = mock(CompletionStage.class);
+ Consumer mockConsumer = mock(Consumer.class);
+
+ Config config = new MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "false"));
+
+ when(this.taskContext.getOperatorExecutor()).thenReturn(mockExecutor);
+ when(this.jobContext.getConfig()).thenReturn(config);
+
+ opImpl.init(this.internalTaskContext);
+ opImpl.acceptFutureWithExecutor(mockFuture, mockConsumer);
+
+ verify(mockFuture).thenAccept(eq(mockConsumer));
+ }
+
+ @Test
+ public void testAcceptFutureWithExecutorWithFrameworkExecutorEnabled() {
+ OperatorImpl opImpl = new TestOpImpl(mock(Object.class));
+ ExecutorService mockExecutor = mock(ExecutorService.class);
+ CompletionStage mockFuture = mock(CompletionStage.class);
+ Consumer mockConsumer = mock(Consumer.class);
+
+ Config config = new MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "true"));
+
+ when(this.taskContext.getOperatorExecutor()).thenReturn(mockExecutor);
+ when(this.jobContext.getConfig()).thenReturn(config);
+
+ opImpl.init(this.internalTaskContext);
+ opImpl.acceptFutureWithExecutor(mockFuture, mockConsumer);
+
+ verify(mockFuture).thenAcceptAsync(eq(mockConsumer), eq(mockExecutor));
+ }
@Test(expected = IllegalStateException.class)
public void testMultipleInitShouldThrow() {
OperatorImpl opImpl = new TestOpImpl(mock(Object.class));