Clean up DAGScheduler datastructures after job completes #414

Open
wants to merge 1 commit into base: master
24 changes: 22 additions & 2 deletions core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -60,6 +60,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with

val nextRunId = new AtomicInteger(0)

val runIdToStageIds = new HashMap[Int, HashSet[Int]]

val nextStageId = new AtomicInteger(0)

val idToStage = new TimeStampedHashMap[Int, Stage]
@@ -143,6 +145,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val id = nextStageId.getAndIncrement()
val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority), priority)
idToStage(id) = stage
val stageIdSet = runIdToStageIds.getOrElseUpdate(priority, new HashSet)
stageIdSet += id
stage
}

@@ -285,6 +289,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
case StopDAGScheduler =>
// Cancel any active jobs
for (job <- activeJobs) {
removeStages(job)
val error = new SparkException("Job cancelled because SparkContext was shut down")
job.listener.jobFailed(error)
}
@@ -420,13 +425,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
if (!job.finished(rt.outputId)) {
job.finished(rt.outputId) = true
job.numFinished += 1
job.listener.taskSucceeded(rt.outputId, event.result)
// If the whole job has finished, remove it
if (job.numFinished == job.numPartitions) {
activeJobs -= job
resultStageToJob -= stage
running -= stage
removeStages(job)
}
job.listener.taskSucceeded(rt.outputId, event.result)
}
case None =>
logInfo("Ignoring result from " + rt + " because its job has finished")
@@ -558,9 +564,10 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
job.listener.jobFailed(new SparkException("Job failed: " + reason))
activeJobs -= job
resultStageToJob -= resultStage
removeStages(job)
job.listener.jobFailed(new SparkException("Job failed: " + reason))
}
if (dependentStages.isEmpty) {
logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
@@ -637,6 +644,19 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size)
}

def removeStages(job: ActiveJob) = {
runIdToStageIds(job.runId).foreach(stageId => {
idToStage.get(stageId).map( stage => {
pendingTasks -= stage
waiting -= stage
running -= stage
failed -= stage
})
idToStage -= stageId
})
runIdToStageIds -= job.runId
}

def stop() {
eventQueue.put(StopDAGScheduler)
metadataCleaner.cancel()
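For reference (not part of the diff): a minimal sketch of the bookkeeping this change adds, with names mirroring the fields above. Each job's runId maps to the set of stage ids created for it, and removeStages walks that set to purge per-stage state once the job completes, fails, or the scheduler shuts down. The Stage type is replaced by a String stand-in so the sketch stays self-contained; it is not the actual DAGScheduler implementation.

import scala.collection.mutable.{HashMap, HashSet}

// Simplified, hypothetical sketch of the cleanup pattern introduced above.
class StageBookkeeping {
  val runIdToStageIds = new HashMap[Int, HashSet[Int]]
  val idToStage = new HashMap[Int, String]   // String stands in for Stage

  // Mirrors newStage: record each created stage under its job's runId.
  def registerStage(runId: Int, stageId: Int, stage: String) {
    idToStage(stageId) = stage
    runIdToStageIds.getOrElseUpdate(runId, new HashSet) += stageId
  }

  // Mirrors removeStages: drop every stage the job created, then the runId entry.
  def removeStages(runId: Int) {
    for (stageIds <- runIdToStageIds.get(runId); stageId <- stageIds) {
      idToStage -= stageId
    }
    runIdToStageIds -= runId
  }
}

In the real scheduler the same walk also clears pendingTasks, waiting, running and failed for each stage, as shown in the removeStages hunk above.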
88 changes: 88 additions & 0 deletions core/src/test/scala/spark/DAGSchedulerSuite.scala
@@ -0,0 +1,88 @@
package spark

import org.scalatest.FunSuite
import scheduler.{DAGScheduler, TaskSchedulerListener, TaskSet, TaskScheduler}
import collection.mutable

class TaskSchedulerMock(f: (Int) => TaskEndReason ) extends TaskScheduler {
// Listener object to pass upcalls into
var listener: TaskSchedulerListener = null
var taskCount = 0

override def start(): Unit = {}

// Disconnect from the cluster.
override def stop(): Unit = {}

// Submit a sequence of tasks to run.
override def submitTasks(taskSet: TaskSet): Unit = {
taskSet.tasks.foreach( task => {
val m = new mutable.HashMap[Long, Any]()
m.put(task.stageId, 1)
taskCount += 1
listener.taskEnded(task, f(taskCount), 1, m)
})
}

// Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
override def setListener(listener: TaskSchedulerListener) {
this.listener = listener
}

// Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
override def defaultParallelism(): Int = {
2
}
}

class DAGSchedulerSuite extends FunSuite {
def assertDagSchedulerEmpty(dagScheduler: DAGScheduler) = {
assert(dagScheduler.pendingTasks.isEmpty)
assert(dagScheduler.activeJobs.isEmpty)
assert(dagScheduler.failed.isEmpty)
assert(dagScheduler.runIdToStageIds.isEmpty)
assert(dagScheduler.idToStage.isEmpty)
assert(dagScheduler.resultStageToJob.isEmpty)
assert(dagScheduler.running.isEmpty)
assert(dagScheduler.shuffleToMapStage.isEmpty)
assert(dagScheduler.waiting.isEmpty)
}

test("oneGoodJob") {
val sc = new SparkContext("local", "test")
val dagScheduler = new DAGScheduler(new TaskSchedulerMock(count => Success))
try {
val rdd = new ParallelCollection(sc, 1.to(100).toSeq, 5, Map.empty)
val func = (tc: TaskContext, iter: Iterator[Int]) => 1
val callSite = Utils.getSparkCallSite

val result = dagScheduler.runJob(rdd, func, 0 until rdd.splits.size, callSite, false)
assertDagSchedulerEmpty(dagScheduler)
} finally {
dagScheduler.stop()
sc.stop()
// pause to let dagScheduler stop (separate thread)
Thread.sleep(10)
}
}

test("manyGoodJobs") {
val sc = new SparkContext("local", "test")
val dagScheduler = new DAGScheduler(new TaskSchedulerMock(count => Success))
try {
val rdd = new ParallelCollection(sc, 1.to(100).toSeq, 5, Map.empty)
val func = (tc: TaskContext, iter: Iterator[Int]) => 1
val callSite = Utils.getSparkCallSite

1.to(100).foreach( v => {
val result = dagScheduler.runJob(rdd, func, 0 until rdd.splits.size, callSite, false)
})
assertDagSchedulerEmpty(dagScheduler)
} finally {
dagScheduler.stop()
sc.stop()
// pause to let dagScheduler stop (separate thread)
Thread.sleep(10)
}
}
}
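One design note on the tests above: the trailing Thread.sleep(10) is a best-effort pause so the DAGScheduler's event-loop thread can exit before the next test starts. A hypothetical alternative (not part of this change) is a small polling helper that retries a condition up to a deadline instead of sleeping a fixed amount; a sketch in plain Scala, assuming nothing about the scheduler's internals:

// Hypothetical helper: poll `condition` every few milliseconds until it holds
// or `timeoutMs` elapses; returns whether the condition ended up true.
def waitUntil(timeoutMs: Long)(condition: => Boolean): Boolean = {
  val deadline = System.currentTimeMillis + timeoutMs
  while (!condition && System.currentTimeMillis < deadline) {
    Thread.sleep(5)
  }
  condition
}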