diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index dde131696f..bfb9a7bf47 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -713,6 +713,42 @@ abstract class RDD[T: ClassManifest](
     return buf.toArray
   }
 
+  /**
+   * Drop the first `drop` elements of this RDD and then take the next `num` elements. This currently
+   * scans the partitions *one by one*, so it will be slow if many partitions are required. In that case,
+   * use collect() to get the whole RDD as an array and drop/take from it locally instead.
+   */
+  def dropTake(drop: Int, num: Int): Array[T] = {
+    if (num == 0) {
+      return new Array[T](0)
+    }
+    val buf = new ArrayBuffer[T]
+    var p = 0
+    var dropped = sc.accumulator(0)
+    while (buf.size < num && p < partitions.size) {
+      val left = num - buf.size
+      // read how many elements have been dropped so far from the accumulator
+      val accDropped = dropped.value
+      val res = sc.runJob(this, (it: Iterator[T]) => {
+        var leftToDrop = drop - accDropped
+        while (leftToDrop > 0 && it.hasNext) {
+          it.next()
+          leftToDrop -= 1
+        }
+        // record how many elements this partition dropped
+        dropped += (drop - accDropped) - leftToDrop
+        // if there is still something left to drop, take nothing from this partition
+        val taken = if (leftToDrop > 0) it.take(0) else it.take(left)
+        taken.toArray
+      }, Array(p), true)
+      buf ++= res(0)
+      if (buf.size == num)
+        return buf.toArray
+      p += 1
+    }
+    return buf.toArray
+  }
+
   /**
    * Return the first element in this RDD.
    */
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index a761dd77c5..37eec0431d 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -26,6 +26,7 @@ class RDDSuite extends FunSuite with LocalSparkContext {
     assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
     assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4)))
     assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4"))
+    assert(nums.take(2).toList === List(1, 2))
     assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4)))
     val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _)))
     assert(partitionSums.collect().toList === List(3, 7))
@@ -43,6 +44,16 @@ class RDDSuite extends FunSuite with LocalSparkContext {
     intercept[UnsupportedOperationException] {
       nums.filter(_ > 5).reduce(_ + _)
     }
+    val sixteen = sc.makeRDD(Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16), 4)
+    // drop nothing
+    assert(sixteen.dropTake(0, 2).toList === List(1, 2))
+    // drop only from the first partition
+    assert(sixteen.dropTake(2, 2).toList === List(3, 4))
+    // drop 6 (all 4 from the first partition and 2 from the second), then take 2
+    assert(sixteen.dropTake(6, 2).toList === List(7, 8))
+    // drop 10 (all of the first two partitions and 2 from the third), then take 6 values;
+    // the take should spill over to the next partition
+    assert(sixteen.dropTake(10, 6).toList === List(11, 12, 13, 14, 15, 16))
   }
 
   test("SparkContext.union") {
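
For reference, a minimal usage sketch of the new dropTake API (not part of the patch; it assumes a SparkContext named sc is already in scope and mirrors the assertions added to RDDSuite):

    // 16 elements spread over 4 partitions, as in the test above
    val sixteen = sc.makeRDD((1 to 16).toArray, 4)
    // skip the first 6 elements (all of partition 0 plus 2 from partition 1), then take 2
    val pair = sixteen.dropTake(6, 2)        // Array(7, 8)
    // dropTake(0, n) behaves like take(n)
    val firstThree = sixteen.dropTake(0, 3)  // Array(1, 2, 3)

The accumulator is used here only as a driver-side counter between the per-partition jobs, so each subsequent single-partition job knows how many elements earlier partitions have already dropped.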