Added stageId <--> jobId mapping in DAGScheduler #842

Open · wants to merge 9 commits into master
Changes from all commits
10 changes: 7 additions & 3 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -235,6 +235,10 @@ private[spark] class MapOutputTracker extends Logging {
}
}

def has(shuffleId: Int): Boolean = {
cachedSerializedStatuses.get(shuffleId).isDefined || mapStatuses.contains(shuffleId)
}

def getSerializedLocations(shuffleId: Int): Array[Byte] = {
var statuses: Array[MapStatus] = null
var epochGotten: Long = -1
@@ -247,12 +251,12 @@ private[spark] class MapOutputTracker extends Logging {
case Some(bytes) =>
return bytes
case None =>
statuses = mapStatuses(shuffleId)
statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]())
epochGotten = epoch
}
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
// out a snapshot of the locations as "locs"; let's serialize and return that
// out a snapshot of the locations as "statuses"; let's serialize and return that
val bytes = serializeStatuses(statuses)
logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
// Add them into the table only if the epoch hasn't changed while we were working
@@ -261,7 +265,7 @@ private[spark] class MapOutputTracker extends Logging {
cachedSerializedStatuses(shuffleId) = bytes
}
}
return bytes
bytes
}

// Serialize an array of map output locations into an efficient byte format so that we can send
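The MapOutputTracker changes are small but load-bearing for the scheduler work below: `has` lets the DAGScheduler ask whether a shuffle is already registered, and the `getOrElse` default keeps `getSerializedLocations` from throwing on an unknown shuffleId. A minimal sketch of the lookup pattern, with a stand-in `MapStatus` type and no locking (both assumptions; the real tracker synchronizes and caches serialized bytes per epoch):

```scala
import scala.collection.mutable.HashMap

object TrackerSketch {
  type MapStatus = String  // stand-in for the real MapStatus class

  val cachedSerializedStatuses = new HashMap[Int, Array[Byte]]
  val mapStatuses = new HashMap[Int, Array[MapStatus]]

  // A shuffle is "known" if it is cached in serialized form or still live.
  def has(shuffleId: Int): Boolean =
    cachedSerializedStatuses.get(shuffleId).isDefined || mapStatuses.contains(shuffleId)

  // An unknown id now yields an empty snapshot instead of NoSuchElementException.
  def snapshot(shuffleId: Int): Array[MapStatus] =
    mapStatuses.getOrElse(shuffleId, Array[MapStatus]())

  def main(args: Array[String]) {
    mapStatuses(0) = Array("hostA", "hostB")
    assert(has(0) && !has(1))
    assert(snapshot(1).isEmpty)
  }
}
```
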
228 changes: 185 additions & 43 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -109,6 +109,10 @@ class DAGScheduler(

val nextStageId = new AtomicInteger(0)

val jobIdToStageIds = new TimeStampedHashMap[Int, HashSet[Int]]

val stageIdToJobIds = new TimeStampedHashMap[Int, HashSet[Int]]

val stageIdToStage = new TimeStampedHashMap[Int, Stage]

val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
@@ -179,7 +183,7 @@ class DAGScheduler(
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
val stage = newStage(shuffleDep.rdd, Some(shuffleDep), jobId)
val stage = newOrUsedStage(shuffleDep.rdd, shuffleDep, jobId)
shuffleToMapStage(shuffleDep.shuffleId) = stage
stage
}
@@ -188,7 +192,8 @@ class DAGScheduler(
/**
* Create a Stage for the given RDD, either as a shuffle map stage (for a ShuffleDependency) or
* as a result stage for the final RDD used directly in an action. The stage will also be
* associated with the provided jobId.
* associated with the provided jobId. Shuffle map stages, whose shuffleId may have previously
* been registered in the MapOutputTracker, should be (re)-created using newOrUsedStage.
*/
private def newStage(
rdd: RDD[_],
@@ -197,19 +202,42 @@ class DAGScheduler(
callSite: Option[String] = None)
: Stage =
{
if (shuffleDep != None) {
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of partitions is unknown
logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size)
}
val id = nextStageId.getAndIncrement()
val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
stageIdToStage(id) = stage
registerJobIdWithStages(jobId, stage)
stageToInfos(stage) = StageInfo(stage)
stage
}

/**
* Create a shuffle map Stage for the given RDD. The stage will also be associated with the
* provided jobId. If a stage for the shuffleId existed previously so that the shuffleId is
* present in the MapOutputTracker, then the number and location of available outputs are
* recovered from the MapOutputTracker
*/
private def newOrUsedStage(
rdd: RDD[_],
shuffleDep: ShuffleDependency[_,_],
jobId: Int,
callSite: Option[String] = None)
: Stage =
{
val stage = newStage(rdd, Some(shuffleDep), jobId, callSite)
if (mapOutputTracker.has(shuffleDep.shuffleId)) {
val serLocs = mapOutputTracker.getSerializedLocations(shuffleDep.shuffleId)
val locs = mapOutputTracker.deserializeStatuses(serLocs)
for (i <- 0 until locs.size) stage.outputLocs(i) = List(locs(i))
stage.numAvailableOutputs = locs.size
} else {
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of partitions is unknown
logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.size)
}
stage
}
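
To illustrate the recover-or-register branch in `newOrUsedStage`: when the tracker already knows the shuffle, the new stage's output bookkeeping is rebuilt from the tracker's snapshot, so map tasks that already ran are not resubmitted. A hedged sketch with stand-in types (`StageSketch` and the string locations are hypothetical):

```scala
class StageSketch(numPartitions: Int) {
  // one location list per partition; Nil means "output not yet available"
  val outputLocs = Array.fill[List[String]](numPartitions)(Nil)
  var numAvailableOutputs = 0
}

object RecoverSketch {
  def restoreOutputs(stage: StageSketch, locs: Array[String]) {
    for (i <- 0 until locs.size) stage.outputLocs(i) = List(locs(i))
    stage.numAvailableOutputs = locs.size  // these partitions need no re-run
  }

  def main(args: Array[String]) {
    val s = new StageSketch(2)
    restoreOutputs(s, Array("hostA", "hostB"))  // recovered, not recomputed
    assert(s.numAvailableOutputs == 2)
  }
}
```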

/**
* Get or create the list of parent stages for a given RDD. The stages will be assigned the
* provided jobId if they haven't already been created with a lower jobId.
@@ -261,6 +289,90 @@ class DAGScheduler(
missing.toList
}

/**
* Registers the given jobId among the jobs that need the given stage and
* all of that stage's ancestors.
*/
private def registerJobIdWithStages(jobId: Int, stage: Stage) {
def registerJobIdWithStageList(stages: List[Stage]) {
if (!stages.isEmpty) {
val s = stages.head
stageIdToJobIds.getOrElseUpdate(s.id, new HashSet[Int]()) += jobId
val parents = getParentStages(s.rdd, jobId)
val parentsWithoutThisJobId = parents.filter(p => !stageIdToJobIds.get(p.id).exists(_.contains(jobId)))
registerJobIdWithStageList(parentsWithoutThisJobId ++ stages.tail)
}
}
registerJobIdWithStageList(List(stage))
}
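
The helper above is a worklist traversal of the stage DAG: it tags a stage with the jobId, then pushes only the parents that don't already carry the tag, so shared ancestors are visited once per job rather than once per path. A standalone sketch of the same shape, with a hypothetical `Node` in place of `Stage`:

```scala
import scala.collection.mutable.{HashMap, HashSet}

object TagSketch {
  case class Node(id: Int, parents: List[Node])

  val tags = new HashMap[Int, HashSet[Int]]  // nodeId -> jobIds that need it

  def tag(jobId: Int, root: Node) {
    def walk(work: List[Node]) {
      if (!work.isEmpty) {
        val n = work.head
        tags.getOrElseUpdate(n.id, new HashSet[Int]()) += jobId
        val untagged = n.parents.filter(p => !tags.get(p.id).exists(_.contains(jobId)))
        walk(untagged ++ work.tail)
      }
    }
    walk(List(root))
  }

  def main(args: Array[String]) {
    val a = Node(0, Nil)
    val diamond = Node(3, List(Node(1, List(a)), Node(2, List(a))))
    tag(42, diamond)
    assert(tags(0) == HashSet(42))  // shared ancestor tagged exactly once
  }
}
```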

private def jobIdToStageIdsAdd(jobId: Int) {
val stageSet = jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]())
stageIdToJobIds.foreach { case (stageId, jobSet) =>
if (jobSet.contains(jobId)) {
stageSet += stageId
}
}
}
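
Since `stageIdToJobIds` is the authoritative mapping, the forward `jobIdToStageIds` entry is derived by scanning it, i.e. by inverting a multimap. A tiny sketch with made-up ids:

```scala
import scala.collection.mutable.{HashMap, HashSet}

object InvertSketch {
  val stageIdToJobIds = HashMap(1 -> HashSet(0), 2 -> HashSet(0, 1), 3 -> HashSet(1))

  def stagesForJob(jobId: Int): HashSet[Int] = {
    val stageSet = new HashSet[Int]
    stageIdToJobIds.foreach { case (stageId, jobSet) =>
      if (jobSet.contains(jobId)) stageSet += stageId
    }
    stageSet
  }

  def main(args: Array[String]) {
    assert(stagesForJob(0) == HashSet(1, 2))  // job 0 needs stages 1 and 2
  }
}
```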

private def jobIdToStageIdsRemove(jobId: Int) {
def removeStage(stageId: Int) {
// data structures based on Stage
stageIdToStage.get(stageId).foreach { s =>
stageToInfos -= s
shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleToMapStage.remove(_))
if (pendingTasks.contains(s) && !pendingTasks(s).isEmpty) {
logError("Tasks still pending for stage %d even though there are no more jobs registered for that stage."
.format(stageId))
}
pendingTasks -= s
if (waiting.contains(s)) {
logError("Still waiting on stage %d even though there are no more jobs registered for that stage."
.format(stageId))
waiting -= s
}
if (running.contains(s)) {
logError("Stage %d still running even though there are no more jobs registered for that stage."
.format(stageId))
running -= s
}
if (failed.contains(s)) {
logError("Stage %d still registered as failed even though there are no more jobs registered for that stage."
.format(stageId))
failed -= s
}
}
// data structures based on StageId
stageIdToStage -= stageId
stageIdToJobIds -= stageId

logDebug("After removal of stage %d, remaining stages = %d".format(stageId, stageIdToStage.size))
}

if (!jobIdToStageIds.contains(jobId)) {
logError("Trying to remove unregistered job " + jobId)
} else {
val registeredStages = jobIdToStageIds(jobId)
if (registeredStages.isEmpty) {
logError("No stages registered for job " + jobId)
} else {
stageIdToJobIds.filterKeys(stageId => registeredStages.contains(stageId)).foreach {
case (stageId, jobSet) =>
if (!jobSet.contains(jobId)) {
logError("Job %d not registered for stage %d even though that stage was registered for the job"
.format(jobId, stageId))
} else {
jobSet -= jobId
if (jobSet.isEmpty) { // nobody needs this stage anymore
removeStage(stageId)
}
}
}
}
jobIdToStageIds -= jobId
}
}
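
The removal path is effectively reference counting: each stage's job set is the count, removing a job decrements it, and a stage is evicted from the scheduler's structures only when its set empties. A reduced sketch of that core (the real removal also clears `stageIdToStage`, `stageToInfos`, and the pending/waiting/running/failed sets):

```scala
import scala.collection.mutable.{HashMap, HashSet}

object RefCountSketch {
  val stageIdToJobIds = new HashMap[Int, HashSet[Int]]

  def removeJob(jobId: Int) {
    for ((stageId, jobSet) <- stageIdToJobIds.toList) {  // snapshot: safe to mutate
      jobSet -= jobId
      if (jobSet.isEmpty) stageIdToJobIds -= stageId  // nobody needs this stage anymore
    }
  }

  def main(args: Array[String]) {
    stageIdToJobIds(1) = HashSet(0, 1)  // stage 1 shared by jobs 0 and 1
    stageIdToJobIds(2) = HashSet(1)     // stage 2 used only by job 1
    removeJob(1)
    assert(stageIdToJobIds.contains(1))   // job 0 still needs stage 1
    assert(!stageIdToJobIds.contains(2))  // stage 2 is evicted
  }
}
```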

/**
* Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a
* JobWaiter whose getResult() method will return the result of the job when it is complete.
@@ -354,10 +466,11 @@ class DAGScheduler(
// Compute very short actions like first() or take() with no parent stages locally.
runLocally(job)
} else {
listenerBus.post(SparkListenerJobStart(job, properties))
idToActiveJob(jobId) = job
activeJobs += job
resultStageToJob(finalStage) = job
jobIdToStageIdsAdd(jobId)
listenerBus.post(SparkListenerJobStart(job, jobIdToStageIds(jobId).toArray, properties))
submitStage(finalStage)
}

@@ -375,6 +488,11 @@ class DAGScheduler(
completion.task, completion.reason, completion.taskInfo, completion.taskMetrics))
handleTaskCompletion(completion)

case LocalJobCompleted(stage) =>
stageIdToJobIds -= stage.id // clean up data structures that were populated for a local job,
stageIdToStage -= stage.id // but that won't get cleaned up via the normal paths through
stageToInfos -= stage // completion events or stage abort

case TaskSetFailed(taskSet, reason) =>
abortStage(stageIdToStage(taskSet.stageId), reason)

@@ -488,30 +606,51 @@ class DAGScheduler(
} catch {
case e: Exception =>
job.listener.jobFailed(e)
} finally {
eventQueue.put(LocalJobCompleted(job.finalStage))
}
}
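
Posting `LocalJobCompleted` from a `finally` block means the event reaches the queue whether the locally-run job succeeds or throws, so the per-stage maps populated for it are torn down exactly once on either path. The shape of that guarantee, with a hypothetical event type:

```scala
import java.util.concurrent.LinkedBlockingQueue

object FinallySketch {
  case class LocalJobCompleted(stageId: Int)

  val eventQueue = new LinkedBlockingQueue[LocalJobCompleted]

  def runLocallySketch(stageId: Int)(body: => Unit) {
    try {
      body
    } catch {
      case e: Exception => println("job failed: " + e)  // reported, not rethrown
    } finally {
      eventQueue.put(LocalJobCompleted(stageId))  // posted on every path
    }
  }

  def main(args: Array[String]) {
    runLocallySketch(7) { sys.error("boom") }
    assert(eventQueue.size == 1)  // cleanup event arrived despite the failure
  }
}
```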

/** Finds the earliest-created active job that needs the stage */
// TODO: Probably should actually find among the active jobs that need this
// stage the one with the highest priority (highest-priority pool, earliest created).
// That should take care of at least part of the priority inversion problem with
// cross-job dependencies.
private def activeJobForStage(stage: Stage): Option[Int] = {
if (stageIdToJobIds.contains(stage.id)) {
val jobsThatUseStage: Array[Int] = stageIdToJobIds(stage.id).toArray.sorted
jobsThatUseStage.find(idToActiveJob.contains(_))
} else {
None
}
}
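
Because job ids are handed out by an increasing counter, sorting a stage's job ids and taking the first one that is still active yields the oldest active job (per the TODO above, pool priority is not yet considered). A one-liner sketch with made-up ids:

```scala
object EarliestJobSketch {
  def main(args: Array[String]) {
    val jobsThatUseStage = Array(7, 3, 5)
    val activeJobs = Set(5, 7)
    val earliest = jobsThatUseStage.sorted.find(activeJobs.contains(_))
    assert(earliest == Some(5))  // job 3 is older but no longer active
  }
}
```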

/** Submits stage, but first recursively submits any missing parents. */
private def submitStage(stage: Stage) {
logDebug("submitStage(" + stage + ")")
if (!waiting(stage) && !running(stage) && !failed(stage)) {
val missing = getMissingParentStages(stage).sortBy(_.id)
logDebug("missing: " + missing)
if (missing == Nil) {
logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
submitMissingTasks(stage)
running += stage
} else {
for (parent <- missing) {
submitStage(parent)
val jobId = activeJobForStage(stage)
if (jobId.isDefined) {
logDebug("submitStage(" + stage + ")")
if (!waiting(stage) && !running(stage) && !failed(stage)) {
val missing = getMissingParentStages(stage).sortBy(_.id)
logDebug("missing: " + missing)
if (missing == Nil) {
logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
submitMissingTasks(stage, jobId.get)
running += stage
} else {
for (parent <- missing) {
submitStage(parent)
}
waiting += stage
}
waiting += stage
}
} else {
abortStage(stage, "No active job for stage " + stage.id)
}
}
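
The reworked `submitStage` still recurses parent-first, but now aborts outright when no active job needs the stage. A toy model of the recursion (stand-in `St` nodes are hypothetical; `completed` plays the role of "parent outputs available"):

```scala
import scala.collection.mutable.{HashSet, ListBuffer}

object SubmitSketch {
  case class St(id: Int, parents: List[St])

  val waiting = new HashSet[St]
  val running = new HashSet[St]
  val completed = new HashSet[St]
  val submitted = new ListBuffer[Int]

  def missingParents(s: St): List[St] = s.parents.filter(p => !completed.contains(p))

  def submit(s: St) {
    if (!waiting.contains(s) && !running.contains(s)) {
      if (missingParents(s).sortBy(_.id) == Nil) {
        submitted += s.id  // stands in for submitMissingTasks(s, jobId)
        running += s
      } else {
        missingParents(s).foreach(submit(_))
        waiting += s  // resubmitted when a parent's completion event arrives
      }
    }
  }

  def main(args: Array[String]) {
    val a = St(0, Nil); val b = St(1, List(a)); val c = St(2, List(a, b))
    submit(c)
    assert(submitted.toList == List(0))  // only the root runs now
    assert(waiting.contains(b) && waiting.contains(c))
  }
}
```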

/** Called when stage's parents are available and we can now do its task. */
private def submitMissingTasks(stage: Stage) {
private def submitMissingTasks(stage: Stage, jobId: Int) {
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
@@ -533,7 +672,7 @@ class DAGScheduler(
}
// must be run listener before possible NotSerializableException
// should be "StageSubmitted" first and then "JobEnded"
val properties = idToActiveJob(stage.jobId).properties
val properties = idToActiveJob(jobId).properties
listenerBus.post(SparkListenerStageSubmitted(stage, tasks.size, properties))

if (tasks.size > 0) {
@@ -576,7 +715,7 @@ class DAGScheduler(
def markStageAsFinished(stage: Stage) = {
val serviceTime = stage.submissionTime match {
case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0)
case _ => "Unkown"
case _ => "Unknown"
}
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
stage.completionTime = Some(System.currentTimeMillis)
@@ -605,6 +744,7 @@ class DAGScheduler(
resultStageToJob -= stage
markStageAsFinished(stage)
listenerBus.post(SparkListenerJobEnd(job, JobSucceeded))
jobIdToStageIdsRemove(job.jobId)
}
job.listener.taskSucceeded(rt.outputId, event.result)
}
@@ -640,7 +780,7 @@ class DAGScheduler(
changeEpoch = true)
}
clearCacheLocs()
if (stage.outputLocs.count(_ == Nil) != 0) {
if (stage.outputLocs.exists(_ == Nil)) {
// Some tasks had failed; let's resubmit this stage
// TODO: Lower-level scheduler should also deal with this
logInfo("Resubmitting " + stage + " (" + stage.name +
@@ -657,9 +797,12 @@ class DAGScheduler(
}
waiting --= newlyRunnable
running ++= newlyRunnable
for (stage <- newlyRunnable.sortBy(_.id)) {
for {
stage <- newlyRunnable.sortBy(_.id)
jobId <- activeJobForStage(stage)
} {
logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable")
submitMissingTasks(stage)
submitMissingTasks(stage, jobId)
}
}
}
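
The `for` block above mixes a `List` generator with an `Option` generator; it desugars to `flatMap`/`map`, so a stage whose `activeJobForStage` is `None` is silently skipped rather than submitted with a bogus jobId. A compact illustration with hypothetical stage and job numbers:

```scala
object ForSketch {
  // pretend stage 2 has no active job left
  def activeJobFor(stage: Int): Option[Int] = if (stage != 2) Some(stage * 10) else None

  def main(args: Array[String]) {
    val newlyRunnable = List(3, 1, 2)
    val submissions = for {
      stage <- newlyRunnable.sorted
      jobId <- activeJobFor(stage)
    } yield (stage, jobId)
    assert(submissions == List((1, 10), (3, 30)))  // stage 2 dropped
  }
}
```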
@@ -752,6 +895,7 @@ class DAGScheduler(
val error = new SparkException("Job failed: " + reason)
job.listener.jobFailed(error)
listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage))))
jobIdToStageIdsRemove(job.jobId)
idToActiveJob -= resultStage.jobId
activeJobs -= job
resultStageToJob -= resultStage
@@ -816,30 +960,28 @@ class DAGScheduler(
case n: NarrowDependency[_] =>
for (inPart <- n.getParents(partition)) {
val locs = getPreferredLocs(n.rdd, inPart)
if (locs != Nil)
if (locs != Nil) {
return locs
}
}
case _ =>
})
Nil
}

private def cleanup(cleanupTime: Long) {
var sizeBefore = stageIdToStage.size
stageIdToStage.clearOldValues(cleanupTime)
logInfo("stageIdToStage " + sizeBefore + " --> " + stageIdToStage.size)

sizeBefore = shuffleToMapStage.size
shuffleToMapStage.clearOldValues(cleanupTime)
logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size)

sizeBefore = pendingTasks.size
pendingTasks.clearOldValues(cleanupTime)
logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size)

sizeBefore = stageToInfos.size
stageToInfos.clearOldValues(cleanupTime)
logInfo("stageToInfos " + sizeBefore + " --> " + stageToInfos.size)
Map(
"stageIdToStage" -> stageIdToStage,
"shuffleToMapStage" -> shuffleToMapStage,
"pendingTasks" -> pendingTasks,
"stageToInfos" -> stageToInfos,
"jobIdToStageIds" -> jobIdToStageIds,
"stageIdToJobIds" -> stageIdToJobIds).
foreach { case(s, t) => {
val sizeBefore = t.size
t.clearOldValues(cleanupTime)
logInfo("%s %d --> %d".format(s, sizeBefore, t.size))
}}
}
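
The rewritten `cleanup` replaces four copies of the size/clear/log sequence with one table-driven loop over (name, map) pairs, so the two new maps are covered for free. The pattern in isolation, with a hypothetical `Clearable` trait standing in for `TimeStampedHashMap`:

```scala
object CleanupSketch {
  trait Clearable {
    def size: Int
    def clearOldValues(cutoff: Long): Unit
  }

  class FakeMap(var size: Int) extends Clearable {
    def clearOldValues(cutoff: Long) { size = 0 }  // pretend everything is stale
  }

  def cleanup(structures: Map[String, Clearable], cutoff: Long) {
    structures.foreach { case (name, t) =>
      val sizeBefore = t.size
      t.clearOldValues(cutoff)
      println("%s %d --> %d".format(name, sizeBefore, t.size))
    }
  }

  def main(args: Array[String]) {
    cleanup(Map("stageIdToStage" -> new FakeMap(3), "stageToInfos" -> new FakeMap(5)), 0L)
  }
}
```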

def stop() {