Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor tuple processing manager #2516

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ trait InputGateway {

// Selects the next enabled channel that has a message ready, preferring control channels.
def tryPickChannel: Option[AmberFIFOChannel]

// Identity of the channel most recently selected by tryPickChannel, if any has been picked yet.
def getCurrentChannelId: Option[ChannelIdentity]

// All registered channels, control and data alike.
def getAllChannels: Iterable[AmberFIFOChannel]

// Only the data (non-control) channels.
def getAllDataChannels: Iterable[AmberFIFOChannel]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
package edu.uci.ics.amber.engine.architecture.messaginglayer

import edu.uci.ics.amber.engine.common.AmberLogging
import edu.uci.ics.amber.engine.common.virtualidentity.{ActorVirtualIdentity, ChannelIdentity}
import edu.uci.ics.amber.engine.common.virtualidentity.ActorVirtualIdentity
import edu.uci.ics.amber.engine.common.workflow.PortIdentity
import edu.uci.ics.texera.workflow.common.tuple.Tuple
import edu.uci.ics.texera.workflow.common.tuple.schema.Schema

import scala.collection.mutable

class InputManager(val actorId: ActorVirtualIdentity) extends AmberLogging {
private var inputBatch: Array[Tuple] = _
private var currentInputIdx: Int = -1
var currentChannelId: ChannelIdentity = _

private val ports: mutable.HashMap[PortIdentity, WorkerPort] = mutable.HashMap()
def getAllPorts: Set[PortIdentity] = {
Expand All @@ -36,25 +32,4 @@ class InputManager(val actorId: ActorVirtualIdentity) extends AmberLogging {
this.ports(portId).channels.values.forall(completed => completed)
}

def hasUnfinishedInput: Boolean = inputBatch != null && currentInputIdx + 1 < inputBatch.length

/** Advances the read cursor and returns the next tuple of the current batch.
  * Callers must check hasUnfinishedInput first: if the batch is exhausted this
  * throws ArrayIndexOutOfBoundsException, and if no batch was ever loaded it
  * throws NullPointerException.
  */
def getNextTuple: Tuple = {
currentInputIdx += 1
inputBatch(currentInputIdx)
}
/** Returns the tuple under the read cursor, or null when no batch has been
  * loaded yet or the loaded batch is empty.
  */
def getCurrentTuple: Tuple = {
  if (inputBatch == null || inputBatch.isEmpty) {
    null // TODO: create input exhausted
  } else {
    inputBatch(currentInputIdx)
  }
}

/** Loads a fresh batch received from the given channel and resets the read
  * cursor to just before the first tuple, so the next getNextTuple call
  * returns batch(0). Also records the channel as the current input channel.
  */
def initBatch(channelId: ChannelIdentity, batch: Array[Tuple]): Unit = {
currentChannelId = channelId
inputBatch = batch
currentInputIdx = -1
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ class NetworkInputGateway(val actorId: ActorVirtualIdentity)
private val inputChannels =
new mutable.HashMap[ChannelIdentity, AmberFIFOChannel]()

private var currentChannelId: Option[ChannelIdentity] = None

@transient lazy private val enforcers = mutable.ListBuffer[OrderEnforcer]()

def tryPickControlChannel: Option[AmberFIFOChannel] = {
Expand All @@ -33,6 +35,7 @@ class NetworkInputGateway(val actorId: ActorVirtualIdentity)
def tryPickChannel: Option[AmberFIFOChannel] = {
val control = tryPickControlChannel
val ret = if (control.isDefined) {
this.currentChannelId = control.map(_.channelId)
control
} else {
inputChannels
Expand All @@ -41,12 +44,18 @@ class NetworkInputGateway(val actorId: ActorVirtualIdentity)
!cid.isControl && channel.isEnabled && channel.hasMessage && enforcers
.forall(enforcer => enforcer.isCompleted || enforcer.canProceed(cid))
})
.map(_._2)
.map {
case (channelId, channel) =>
this.currentChannelId = Some(channelId)
channel
}
}
enforcers.filter(enforcer => enforcer.isCompleted).foreach(enforcer => enforcers -= enforcer)
ret
}

override def getCurrentChannelId: Option[ChannelIdentity] = this.currentChannelId

/** All registered data (non-control) input channels. */
def getAllDataChannels: Iterable[AmberFIFOChannel] =
  inputChannels.collect { case (channelId, channel) if !channelId.isControl => channel }

Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
package edu.uci.ics.amber.engine.architecture.messaginglayer

import edu.uci.ics.amber.engine.architecture.messaginglayer.OutputManager.{
DPOutputIterator,
getBatchSize,
toPartitioner
}
import edu.uci.ics.amber.engine.architecture.sendsemantics.partitioners._
import edu.uci.ics.amber.engine.architecture.sendsemantics.partitionings._
import edu.uci.ics.amber.engine.architecture.worker.DataProcessor.{FinalizeExecutor, FinalizePort}
import edu.uci.ics.amber.engine.common.AmberLogging
import edu.uci.ics.amber.engine.common.rpc.AsyncRPCServer.ControlCommand
import edu.uci.ics.amber.engine.common.tuple.amber.{SchemaEnforceable, TupleLike}
import edu.uci.ics.amber.engine.common.tuple.amber.SchemaEnforceable
import edu.uci.ics.amber.engine.common.virtualidentity.{ActorVirtualIdentity, ChannelIdentity}
import edu.uci.ics.amber.engine.common.workflow.{PhysicalLink, PortIdentity}
import edu.uci.ics.texera.workflow.common.tuple.schema.Schema
Expand Down Expand Up @@ -49,32 +47,6 @@ object OutputManager {
}
}

/** Iterator over the executor's output tuples. It drains a live iterator set
  * via setTupleOutput first, then yields any special marker tuples that were
  * queued with appendSpecialTupleToEnd (e.g. port-finalization markers).
  */
class DPOutputIterator extends Iterator[(TupleLike, Option[PortIdentity])] {
  val queue = new mutable.ListBuffer[(TupleLike, Option[PortIdentity])]
  @transient var outputIter: Iterator[(TupleLike, Option[PortIdentity])] = Iterator.empty

  /** Replaces the live output iterator; a null argument resets it to empty. */
  def setTupleOutput(outputIter: Iterator[(TupleLike, Option[PortIdentity])]): Unit = {
    this.outputIter = Option(outputIter).getOrElse(Iterator.empty)
  }

  /** Elements remain while either the live iterator or the marker queue is non-empty. */
  override def hasNext: Boolean = outputIter.hasNext || queue.nonEmpty

  /** Drains the live iterator first, then falls back to the queued markers. */
  override def next(): (TupleLike, Option[PortIdentity]) =
    if (outputIter.hasNext) outputIter.next()
    else queue.remove(0)

  /** Enqueues a marker tuple (associated with no port) to be emitted after all regular output. */
  def appendSpecialTupleToEnd(tuple: TupleLike): Unit = {
    queue += ((tuple, None))
  }
}
}

/** This class is a container of all the transfer partitioners.
Expand All @@ -87,7 +59,6 @@ class OutputManager(
outputGateway: NetworkOutputGateway
) extends AmberLogging {

val outputIterator: DPOutputIterator = new DPOutputIterator()
private val partitioners: mutable.Map[PhysicalLink, Partitioner] =
mutable.HashMap[PhysicalLink, Partitioner]()

Expand Down Expand Up @@ -183,14 +154,6 @@ class OutputManager(

def getPort(portId: PortIdentity): WorkerPort = ports(portId)

def hasUnfinishedOutput: Boolean = outputIterator.hasNext

/** Appends a FinalizePort marker for every registered output port, followed by
  * a single FinalizeExecutor marker, so downstream processing closes each port
  * before shutting down the executor.
  */
def finalizeOutput(): Unit = {
  for (outputPortId <- this.ports.keys) {
    outputIterator.appendSpecialTupleToEnd(FinalizePort(outputPortId, input = false))
  }
  outputIterator.appendSpecialTupleToEnd(FinalizeExecutor())
}
def getAllPortIds: Set[PortIdentity] = this.ports.keySet.toSet

}
Original file line number Diff line number Diff line change
Expand Up @@ -137,19 +137,19 @@ class DPThread(
//
// Main loop step 2: do input selection
//
var channelId: ChannelIdentity = null
var msgOpt: Option[WorkflowFIFOMessage] = None
var channelId: Option[ChannelIdentity] = None
var msg: Option[WorkflowFIFOMessage] = None
if (
dp.inputManager.hasUnfinishedInput || dp.outputManager.hasUnfinishedOutput || dp.pauseManager.isPaused
dp.tupleProcessingManager.inputIterator.hasNext || dp.tupleProcessingManager.outputIterator.hasNext || dp.pauseManager.isPaused
) {
dp.inputGateway.tryPickControlChannel match {
case Some(channel) =>
channelId = channel.channelId
msgOpt = Some(channel.take)
channelId = dp.inputGateway.getCurrentChannelId
msg = Some(channel.take)
case None =>
// continue processing
if (!dp.pauseManager.isPaused && !backpressureStatus) {
channelId = dp.inputManager.currentChannelId
channelId = dp.inputGateway.getCurrentChannelId
} else {
waitingForInput = true
}
Expand All @@ -162,20 +162,20 @@ class DPThread(
dp.inputGateway.tryPickChannel
} match {
case Some(channel) =>
channelId = channel.channelId
msgOpt = Some(channel.take)
channelId = dp.inputGateway.getCurrentChannelId
msg = Some(channel.take)
case None => waitingForInput = true
}
}

//
// Main loop step 3: process selected message payload
//
if (channelId != null) {
if (channelId.isDefined) {
// for logging, skip large data frames.
val msgToLog = msgOpt.filter(_.payload.isInstanceOf[ControlPayload])
logManager.withFaultTolerant(channelId, msgToLog) {
msgOpt match {
val msgToLog = msg.filter(_.payload.isInstanceOf[ControlPayload])
logManager.withFaultTolerant(channelId.get, msgToLog) {
msg match {
case None =>
dp.continueDataProcessing()
case Some(msg) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class DataProcessor(
val stateManager: WorkerStateManager = new WorkerStateManager()
val inputManager: InputManager = new InputManager(actorId)
val outputManager: OutputManager = new OutputManager(actorId, outputGateway)
val tupleProcessingManager: TupleProcessingManager = new TupleProcessingManager(actorId)
val channelMarkerManager: ChannelMarkerManager = new ChannelMarkerManager(actorId, inputGateway)
val serializationManager: SerializationManager = new SerializationManager(actorId)
def getQueuedCredit(channelId: ChannelIdentity): Long = {
Expand All @@ -80,12 +81,12 @@ class DataProcessor(
* process currentInputTuple through executor logic.
* this function is only called by the DP thread.
*/
private[this] def processInputTuple(tuple: Tuple): Unit = {
private[this] def processInputTuple(tuple: Tuple, portId: PortIdentity): Unit = {
try {
outputManager.outputIterator.setTupleOutput(
tupleProcessingManager.outputIterator.setInternalIter(
executor.processTupleMultiPort(
tuple,
this.inputGateway.getChannel(inputManager.currentChannelId).getPortId.id
portId.id
)
)
statisticsManager.increaseInputTupleCount()
Expand All @@ -101,13 +102,10 @@ class DataProcessor(
* process end of an input port with Executor.onFinish().
* this function is only called by the DP thread.
*/
private[this] def processInputExhausted(): Unit = {
private[this] def processOnFinish(portId: PortIdentity): Unit = {
try {
outputManager.outputIterator.setTupleOutput(
executor.onFinishMultiPort(
this.inputGateway.getChannel(inputManager.currentChannelId).getPortId.id
)
)
val output = executor.onFinishMultiPort(portId.id)
tupleProcessingManager.outputIterator.setInternalIter(output)
} catch safely {
case e =>
// forward input tuple to the user and pause DP thread
Expand All @@ -122,13 +120,13 @@ class DataProcessor(
adaptiveBatchingMonitor.startAdaptiveBatching()
var out: (TupleLike, Option[PortIdentity]) = null
try {
out = outputManager.outputIterator.next()
out = tupleProcessingManager.outputIterator.next()
} catch safely {
case e =>
// invalidate current output tuple
out = null
// also invalidate outputIterator
outputManager.outputIterator.setTupleOutput(Iterator.empty)
tupleProcessingManager.outputIterator.setInternalIter(Iterator.empty)
// forward input tuple to the user and pause DP thread
handleExecutorException(e)
}
Expand Down Expand Up @@ -163,10 +161,11 @@ class DataProcessor(

def continueDataProcessing(): Unit = {
val dataProcessingStartTime = System.nanoTime()
if (outputManager.hasUnfinishedOutput) {
if (tupleProcessingManager.outputIterator.hasNext) {
outputOneTuple()
} else {
processInputTuple(inputManager.getNextTuple)
val (tuple, portId) = tupleProcessingManager.inputIterator.next()
processInputTuple(tuple, portId)
}
statisticsManager.increaseDataProcessingTime(System.nanoTime() - dataProcessingStartTime)
}
Expand All @@ -188,22 +187,29 @@ class DataProcessor(
)
}
)
inputManager.initBatch(channelId, tuples)
processInputTuple(inputManager.getNextTuple)

tupleProcessingManager.inputIterator.setBatch(
inputGateway.getChannel(channelId).getPortId,
tuples
)
val (tuple, portId) = tupleProcessingManager.inputIterator.next()
processInputTuple(tuple, portId)
case EndOfUpstream() =>
val channel = this.inputGateway.getChannel(channelId)
val portId = channel.getPortId

this.inputManager.getPort(portId).channels(channelId) = true

if (inputManager.isPortCompleted(portId)) {
inputManager.initBatch(channelId, Array.empty)
processInputExhausted()
outputManager.outputIterator.appendSpecialTupleToEnd(FinalizePort(portId, input = true))
tupleProcessingManager.inputIterator.setBatch(portId, Array.empty)
processOnFinish(portId)
tupleProcessingManager.outputIterator.appendSpecialTupleToEnd(
FinalizePort(portId, input = true)
)
}
if (inputManager.getAllPorts.forall(portId => inputManager.isPortCompleted(portId))) {
// assuming all the output ports finalize after all input ports are finalized.
outputManager.finalizeOutput()
tupleProcessingManager.finalizeOutput(outputManager.getAllPortIds)
}
}
statisticsManager.increaseDataProcessingTime(System.nanoTime() - dataProcessingStartTime)
Expand Down
Loading
Loading