diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index b320be8863..87a4143877 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -60,6 +60,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
   val nextRunId = new AtomicInteger(0)
 
+  val runIdToStageIds = new HashMap[Int, HashSet[Int]]
+
   val nextStageId = new AtomicInteger(0)
 
   val idToStage = new TimeStampedHashMap[Int, Stage]
@@ -143,6 +145,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
     val id = nextStageId.getAndIncrement()
     val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority), priority)
     idToStage(id) = stage
+    val stageIdSet = runIdToStageIds.getOrElseUpdate(priority, new HashSet)
+    stageIdSet += id
     stage
   }
 
@@ -285,6 +289,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
         case StopDAGScheduler =>
           // Cancel any active jobs
           for (job <- activeJobs) {
+            removeStages(job)
            val error = new SparkException("Job cancelled because SparkContext was shut down")
            job.listener.jobFailed(error)
          }
@@ -420,13 +425,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
             if (!job.finished(rt.outputId)) {
               job.finished(rt.outputId) = true
               job.numFinished += 1
-              job.listener.taskSucceeded(rt.outputId, event.result)
               // If the whole job has finished, remove it
               if (job.numFinished == job.numPartitions) {
                 activeJobs -= job
                 resultStageToJob -= stage
                 running -= stage
+                removeStages(job)
               }
+              job.listener.taskSucceeded(rt.outputId, event.result)
             }
           case None =>
             logInfo("Ignoring result from " + rt + " because its job has finished")
@@ -558,9 +564,10 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
     val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
     for (resultStage <- dependentStages) {
       val job = resultStageToJob(resultStage)
-      job.listener.jobFailed(new SparkException("Job failed: " + reason))
       activeJobs -= job
       resultStageToJob -= resultStage
+      removeStages(job)
+      job.listener.jobFailed(new SparkException("Job failed: " + reason))
     }
     if (dependentStages.isEmpty) {
       logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
@@ -637,6 +644,19 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
     logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size)
   }
 
+  def removeStages(job: ActiveJob) = {
+    runIdToStageIds(job.runId).foreach(stageId => {
+      idToStage.get(stageId).map( stage => {
+        pendingTasks -= stage
+        waiting -= stage
+        running -= stage
+        failed -= stage
+      })
+      idToStage -= stageId
+    })
+    runIdToStageIds -= job.runId
+  }
+
   def stop() {
     eventQueue.put(StopDAGScheduler)
     metadataCleaner.cancel()
diff --git a/core/src/test/scala/spark/DAGSchedulerSuite.scala b/core/src/test/scala/spark/DAGSchedulerSuite.scala
new file mode 100644
index 0000000000..2a3b30ae42
--- /dev/null
+++ b/core/src/test/scala/spark/DAGSchedulerSuite.scala
@@ -0,0 +1,88 @@
+package spark
+
+import org.scalatest.FunSuite
+import scheduler.{DAGScheduler, TaskSchedulerListener, TaskSet, TaskScheduler}
+import collection.mutable
+
+class TaskSchedulerMock(f: (Int) => TaskEndReason ) extends TaskScheduler {
+  // Listener object to pass upcalls into
+  var listener: TaskSchedulerListener = null
+  var taskCount = 0
+
+  override def start(): Unit = {}
+
+  // Disconnect from the cluster.
+  override def stop(): Unit = {}
+
+  // Submit a sequence of tasks to run.
+  override def submitTasks(taskSet: TaskSet): Unit = {
+    taskSet.tasks.foreach( task => {
+      val m = new mutable.HashMap[Long, Any]()
+      m.put(task.stageId, 1)
+      taskCount += 1
+      listener.taskEnded(task, f(taskCount), 1, m)
+    })
+  }
+
+  // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
+  override def setListener(listener: TaskSchedulerListener) {
+    this.listener = listener
+  }
+
+  // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
+  override def defaultParallelism(): Int = {
+    2
+  }
+}
+
+class DAGSchedulerSuite extends FunSuite {
+  def assertDagSchedulerEmpty(dagScheduler: DAGScheduler) = {
+    assert(dagScheduler.pendingTasks.isEmpty)
+    assert(dagScheduler.activeJobs.isEmpty)
+    assert(dagScheduler.failed.isEmpty)
+    assert(dagScheduler.runIdToStageIds.isEmpty)
+    assert(dagScheduler.idToStage.isEmpty)
+    assert(dagScheduler.resultStageToJob.isEmpty)
+    assert(dagScheduler.running.isEmpty)
+    assert(dagScheduler.shuffleToMapStage.isEmpty)
+    assert(dagScheduler.waiting.isEmpty)
+  }
+
+  test("oneGoodJob") {
+    val sc = new SparkContext("local", "test")
+    val dagScheduler = new DAGScheduler(new TaskSchedulerMock(count => Success))
+    try {
+      val rdd = new ParallelCollection(sc, 1.to(100).toSeq, 5, Map.empty)
+      val func = (tc: TaskContext, iter: Iterator[Int]) => 1
+      val callSite = Utils.getSparkCallSite
+
+      val result = dagScheduler.runJob(rdd, func, 0 until rdd.splits.size, callSite, false)
+      assertDagSchedulerEmpty(dagScheduler)
+    } finally {
+      dagScheduler.stop()
+      sc.stop()
+      // pause to let dagScheduler stop (separate thread)
+      Thread.sleep(10)
+    }
+  }
+
+  test("manyGoodJobs") {
+    val sc = new SparkContext("local", "test")
+    val dagScheduler = new DAGScheduler(new TaskSchedulerMock(count => Success))
+    try {
+      val rdd = new ParallelCollection(sc, 1.to(100).toSeq, 5, Map.empty)
+      val func = (tc: TaskContext, iter: Iterator[Int]) => 1
+      val callSite = Utils.getSparkCallSite
+
+      1.to(100).foreach( v => {
+        val result = dagScheduler.runJob(rdd, func, 0 until rdd.splits.size, callSite, false)
+      })
+      assertDagSchedulerEmpty(dagScheduler)
+    } finally {
+      dagScheduler.stop()
+      sc.stop()
+      // pause to let dagScheduler stop (separate thread)
+      Thread.sleep(10)
+    }
+  }
+}