
Utility function to get a setup & cleanup function for mapping each partition #456

Open. Wants to merge 13 commits into base: master.
45 changes: 40 additions & 5 deletions core/src/main/scala/spark/RDD.scala
@@ -1,13 +1,10 @@
package spark

import java.net.URL
import java.util.{Date, Random}
import java.util.{HashMap => JHashMap}
import java.util.Random

import scala.collection.Map
import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap

import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable
@@ -26,14 +23,16 @@ import spark.rdd.FlatMappedRDD
import spark.rdd.GlommedRDD
import spark.rdd.MappedRDD
import spark.rdd.MapPartitionsRDD
import spark.rdd.MapPartitionsWithSetupAndCleanup
import spark.rdd.MapPartitionsWithSplitRDD
import spark.rdd.PipedRDD
import spark.rdd.SampledRDD
import spark.rdd.UnionRDD
import spark.rdd.ZippedRDD
import spark.storage.StorageLevel

import SparkContext._
import spark.SparkContext._
import spark.RDD.PartitionMapper

/**
* A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -339,6 +338,18 @@ abstract class RDD[T: ClassManifest](
preservesPartitioning: Boolean = false): RDD[U] =
new MapPartitionsWithSplitRDD(this, sc.clean(f), preservesPartitioning)

/**
* Return a new RDD by applying a function to every element in this RDD, with extra setup & cleanup
* at the beginning & end of processing every partition.
*
* This might be useful if you need to set up some resources per task & clean them up at the end, e.g.
* a db connection.
*/
def mapWithSetupAndCleanup[U: ClassManifest](
Member commented:
I would rename the instances of WithSetupAndCleanup to just WithCleanup for simplicity. Having cleanup will also imply that there's something to be cleaned up.

m: PartitionMapper[T,U],
preservesPartitioning: Boolean = false): RDD[U] =
new MapPartitionsWithSetupAndCleanup(this, m, preservesPartitioning)

/**
* Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
* second element in each RDD, etc. Assumes that the two RDDs have the *same number of
@@ -680,3 +691,27 @@ abstract class RDD[T: ClassManifest](
origin)

}

object RDD {

/**
* Defines a map function over elements of an RDD, but with extra setup and cleanup
* that happens at the start and end of processing each partition
*/
trait PartitionMapper[T,U] extends Serializable {
/**
* called at the start of processing each partition
*/
def setup(partition: Int)

/**
* transform one element of the partition
*/
def map(t: T) : U

/**
* called at the end of each partition. This will get called even if the map failed (e.g., an exception was thrown)
*/
def cleanup
}
}
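For illustration, a REPL-style sketch of how PartitionMapper might be used (not part of this diff). It assumes an existing SparkContext named sc, and LoggingMapper is a made-up example: the writer is opened once per partition in setup and closed in cleanup.

import java.io.{File, FileWriter, PrintWriter}
import spark.RDD.PartitionMapper

// Hypothetical mapper: keeps one log file open per partition.
class LoggingMapper extends PartitionMapper[Int, Int] {
  @transient private var out: PrintWriter = _

  // opened once per task, before any element is mapped
  def setup(partition: Int) {
    out = new PrintWriter(new FileWriter(new File("/tmp/partition-" + partition + ".log")))
  }

  // applied to every element of the partition
  def map(i: Int): Int = {
    out.println("mapping " + i)
    i * 2
  }

  // runs via the task's on-complete callback, even if map threw an exception
  def cleanup {
    if (out != null) out.close()
  }
}

val doubled = sc.parallelize(1 to 100, 4).mapWithSetupAndCleanup(new LoggingMapper).collect()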
12 changes: 12 additions & 0 deletions core/src/main/scala/spark/api/java/JavaDoublePartitionMapper.java
@@ -0,0 +1,12 @@
package spark.api.java;

import java.io.Serializable;

public abstract class JavaDoublePartitionMapper<T> implements Serializable {

public abstract void setup(int partition);

public abstract Double map(T t) throws Exception;

public abstract void cleanup();
}
14 changes: 14 additions & 0 deletions core/src/main/scala/spark/api/java/JavaPairPartitionMapper.java
@@ -0,0 +1,14 @@
package spark.api.java;

import scala.Tuple2;

import java.io.Serializable;

public abstract class JavaPairPartitionMapper<T, K, V> implements Serializable {

public abstract void setup(int partition);

public abstract Tuple2<K,V> map(T t) throws Exception;

public abstract void cleanup();
}
37 changes: 37 additions & 0 deletions core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -10,6 +10,8 @@ import spark.api.java.function.{Function2 => JFunction2, Function => JFunction,
import spark.partial.{PartialResult, BoundedDouble}
import spark.storage.StorageLevel
import com.google.common.base.Optional
import spark.RDD.PartitionMapper
import spark.api.java.ManifestHelper.fakeManifest


trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround[T] {
@@ -116,6 +118,41 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround
JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(f.keyType(), f.valueType())
}

/**
* Return a new RDD by applying a function to each element of the RDD, with an additional
* setup & cleanup that happens before & after computing each partition
*/
def mapWithSetupAndCleanup[U](m: PartitionMapper[T,U]): JavaRDD[U] = {
JavaRDD.fromRDD(rdd.mapWithSetupAndCleanup(m)(fakeManifest[U]))(fakeManifest[U])
}

/**
* Return a new RDD by applying a function to each element of the RDD, with an additional
* setup & cleanup that happens before & after computing each partition
*/
def mapWithSetupAndCleanup[K,V](m: JavaPairPartitionMapper[T,K,V]): JavaPairRDD[K,V] = {
val scalaMapper = new PartitionMapper[T,(K,V)] {
Member commented:
Can JavaPairPartitionMapper<T, K, V> be an abstract class that extends or implements PartitionMapper<T, Tuple2<K, V>>? If you can do that, then you wouldn't have to wrap the Java PartitionMapper to convert it into its Scala counterpart.

def setup(partition:Int) = m.setup(partition)
def map(t:T) = m.map(t)
def cleanup = m.cleanup()
}
JavaPairRDD.fromRDD(rdd.mapWithSetupAndCleanup(scalaMapper)(fakeManifest[(K,V)]))(
fakeManifest[K], fakeManifest[V])
}
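To make the reviewer's suggestion above concrete, a hedged Scala sketch (the class name is hypothetical, and this is not what the patch currently does): if the Java-facing pair mapper were itself a subtype of the Scala trait, this overload could pass it straight through without wrapping.

import spark.RDD.PartitionMapper

// Sketch of the suggested alternative: a Java-friendly abstract class that is already a
// PartitionMapper[T, (K, V)], so no adapter is needed on the Scala side. Java subclasses
// would implement setup, map (returning scala.Tuple2<K, V>) and cleanup as before.
abstract class JavaTuple2PartitionMapper[T, K, V] extends PartitionMapper[T, (K, V)]

// The overload above would then reduce to:
// JavaPairRDD.fromRDD(rdd.mapWithSetupAndCleanup(m)(fakeManifest[(K, V)]))(
//   fakeManifest[K], fakeManifest[V])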

/**
* Return a new RDD by applying a function to each element of the RDD, with an additional
* setup & cleanup that happens before & after computing each partition
*/
def mapWithSetupAndCleanup(m: JavaDoublePartitionMapper[T]): JavaDoubleRDD = {
val scalaMapper = new PartitionMapper[T,Double] {
def setup(partition:Int) = m.setup(partition)
def map(t:T) = m.map(t)
def cleanup = m.cleanup()
}
JavaDoubleRDD.fromRDD(rdd.mapWithSetupAndCleanup(scalaMapper)(manifest[Double]))
}

/**
* Return an RDD created by coalescing all elements within each partition into an array.
*/
11 changes: 11 additions & 0 deletions core/src/main/scala/spark/api/java/ManifestHelper.java
@@ -0,0 +1,11 @@
package spark.api.java;

import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;

class ManifestHelper {
Member commented:
This ManifestHelper class is a good idea. We could also use it to create the fake manifests in the Java Function* classes.


public static <R> ClassManifest<R> fakeManifest() {
return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
}
}
26 changes: 26 additions & 0 deletions core/src/main/scala/spark/rdd/MapPartitionsWithSetupAndCleanup.scala
@@ -0,0 +1,26 @@
package spark.rdd

import spark.{TaskContext, Split, RDD}
import spark.RDD.PartitionMapper

/**
* Applies a PartitionMapper to every element of the parent RDD: setup() is called once per
* partition before any element is mapped, and cleanup() is registered as a task on-complete
* callback so it runs after the partition has been processed, even if the map failed.
*/
class MapPartitionsWithSetupAndCleanup[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
m: PartitionMapper[T,U],
preservesPartitioning: Boolean
) extends RDD[U](prev){

override def getSplits = firstParent[T].splits

override val partitioner = if (preservesPartitioning) prev.partitioner else None

override def compute(split: Split, context: TaskContext) = {
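// register the cleanup callback before setup and map, so cleanup still runs if either of them throws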
context.addOnCompleteCallback(m.cleanup _)
m.setup(split.index)
firstParent[T].iterator(split, context).map(m.map _)
}

}
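A small standalone snippet (illustration only, mirroring the "naive alternative" in the test below) of why cleanup cannot simply be called at the end of compute: compute returns a lazy iterator, so no elements have been mapped yet at that point, and cleanup has to be deferred to task completion instead.

object LazyIteratorDemo extends App {
  var cleanedUp = false

  // compute() returns something like this: a lazy view over the parent iterator
  val mapped = Iterator(1, 2, 3).map { i =>
    println("mapping " + i + " (cleanedUp = " + cleanedUp + ")")
    i * 2
  }

  cleanedUp = true        // "cleaning up" here would be too early...
  println(mapped.toList)  // ...every element is mapped only now, when the iterator is drained
}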
46 changes: 46 additions & 0 deletions core/src/test/scala/spark/RDDSuite.scala
@@ -1,9 +1,11 @@
package spark

import scala.collection.mutable.HashMap
import scala.collection.Set
import org.scalatest.FunSuite
import spark.SparkContext._
import spark.rdd.{CoalescedRDD, PartitionPruningRDD}
import spark.RDD.PartitionMapper

class RDDSuite extends FunSuite with LocalSparkContext {

@@ -173,4 +175,48 @@ class RDDSuite extends FunSuite with LocalSparkContext {
assert(prunedData.size === 1)
assert(prunedData(0) === 10)
}

test("mapPartitionWithSetupAndCleanup") {
sc = new SparkContext("local[4]", "test")
val data = sc.parallelize(1 to 100, 4)
val acc = sc.accumulableCollection(new HashMap[Int,Set[Int]]())
val mapped = data.mapWithSetupAndCleanup(new PartitionMapper[Int,Int](){
var partition = -1
var values = Set[Int]()
def setup(partition:Int) {this.partition = partition}
def map(i:Int) = {values += i; i * 2}
def cleanup = {
//the purpose of this strange code is just to make sure this method is called
// after the data has been iterated through completely.
acc.localValue += (partition -> values)
}
}).collect

assert(mapped.toSet === (1 to 100).map{_ * 2}.toSet)
assert(acc.value.keySet == (0 to 3).toSet)
acc.value.foreach { case(partition, values) =>
assert(values.size === 25)
}


//the naive alternative doesn't work
val acc2 = sc.accumulableCollection(new HashMap[Int,Set[Int]]())
val m2 = data.mapPartitionsWithSplit{
case (partition, itr) =>
var values = Set[Int]()
val mItr = itr.map{i => values += i; i * 2}
//you haven't actually put anything into values yet, b/c itr.map defines another
// iterator, which is lazily computed. so the Set is empty
acc2.localValue += (partition -> values)
mItr
}.collect

assert(m2.toSet === (1 to 100).map{_ * 2}.toSet)
assert(acc2.value.keySet === (0 to 3).toSet)
//this condition will fail
// acc2.value.foreach { case(partition, values) =>
Member commented:
If this is supposed to fail, should we wrap it in an intercept block and check that an appropriate exception is actually thrown?

Contributor (author) commented:
Well, it's not really supposed to fail -- there is no exception, it just doesn't give the "expected" result. That second part isn't really a unit test at all, it's just documentation of why this method is needed. It probably doesn't belong here at all; I just wanted it as part of the pull request to demonstrate why the method was needed.

I guess I can just write it up somewhere else (seems too long to put in the spark docs also, at least in the current layout ...)

// assert(values.size === 25)
// }

}
}