tweaks to the SparkLR example #872

Open · wants to merge 2 commits into master
16 changes: 16 additions & 0 deletions core/src/main/scala/spark/util/Vector.scala
@@ -83,6 +83,22 @@ class Vector(val elements: Array[Double]) extends Serializable {

   def addInPlace(other: Vector) = this += other
 
+  /**
+   * Perform a saxpy operation: multiply the given vector by the given scalar and add the
+   * result to this vector, returning this vector.
+   */
+  def saxpy(a: Double, x: Vector) = {
+    if (length != x.length)
+      throw new IllegalArgumentException("Vectors of different length")
+
+    var i = 0
+    while (i < length) {
+      elements(i) += a * x(i)
+      i += 1
+    }
+    this
+  }
+
   def * (scale: Double): Vector = Vector(length, i => this(i) * scale)
 
   def multiply (d: Double) = this * d
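For context: saxpy is the BLAS level-1 update y ← a·x + y. Here is a minimal sketch of what the new method does, with illustrative values; it assumes only the Vector class above (whose zeros factory and * operator already exist in this file):

    import spark.util.Vector

    val y = Vector.zeros(3)                    // the accumulator
    val x = new Vector(Array(1.0, 2.0, 3.0))
    y.saxpy(0.5, x)                            // y <- 0.5 * x + y, updated in place
    // y.elements is now Array(0.5, 1.0, 1.5).
    // The equivalent y += (x * 0.5) first allocates a temporary Vector for
    // x * 0.5; saxpy performs the same update with no intermediate allocation.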
32 changes: 26 additions & 6 deletions examples/src/main/scala/spark/examples/SparkLR.scala
@@ -21,19 +21,33 @@
 import java.util.Random
 import scala.math.exp
 import spark.util.Vector
 import spark._
+import com.esotericsoftware.kryo.Kryo
 
 /**
  * Logistic regression based classification.
  */
 object SparkLR {
-  val N = 10000    // Number of data points
+  case class DataPoint(x: Vector, y: Double)
+
+  class MyRegistrator extends KryoRegistrator {
+    override def registerClasses(kryo: Kryo) {
+      kryo.setRegistrationRequired(true)
+
+      kryo.register(classOf[scala.collection.mutable.WrappedArray.ofRef[_]])
+      kryo.register(classOf[java.lang.Class[_]])
Member: Hey Mike, why do you need to register WrappedArray and Class? Doesn't seem like they'll occur in our data here.

Contributor (author): I came up with that list by adding setRegistrationRequired, then fixing each exception that was thrown. So Kryo said they were used; I haven't considered why.

Member: I see. Did Kryo actually make a performance difference for you at all? We are not caching data in serialized form here, so it would only be used to send back task results, but that's just one Vector per partition. I think the WrappedArrays are because we somehow send that back within an array.
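For readers unfamiliar with this Kryo feature: setRegistrationRequired(true) makes Kryo reject any class that was not explicitly registered, which enables exactly the trial-and-error loop described above. A rough standalone sketch of that behavior (plain Kryo, outside Spark's serializer wiring):

    import com.esotericsoftware.kryo.Kryo
    import com.esotericsoftware.kryo.io.Output

    val kryo = new Kryo()
    kryo.setRegistrationRequired(true)

    // Writing an object of an unregistered class now throws an exception that
    // names the class; registering each offender in turn produces a list like
    // the one in this registrator.
    kryo.register(classOf[Array[Double]])
    kryo.writeClassAndObject(new Output(1024), Array(1.0, 2.0))   // ok: double[] is registered
    // kryo.writeClassAndObject(new Output(1024), Array("a", "b"))  // would throw: String[] is not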

+      kryo.register(classOf[DataPoint])
+      kryo.register(classOf[Array[DataPoint]])
+      kryo.register(classOf[Vector])
+      kryo.register(classOf[Array[Double]])
+    }
+  }
+
+  var N = 10000    // Number of data points
   val D = 10       // Number of dimensions
   val R = 0.7      // Scaling factor
   val ITERATIONS = 5
   val rand = new Random(42)
 
-  case class DataPoint(x: Vector, y: Double)
-
   def generateData = {
     def generatePoint(i: Int) = {
       val y = if (i % 2 == 0) -1 else 1
@@ -45,12 +59,17 @@

   def main(args: Array[String]) {
     if (args.length == 0) {
-      System.err.println("Usage: SparkLR <master> [<slices>]")
+      System.err.println("Usage: SparkLR <master> [<slices> [<points>]]")
       System.exit(1)
     }
 
+    System.setProperty("spark.serializer", "spark.KryoSerializer")
+    System.setProperty("spark.kryo.registrator", "spark.examples.SparkLR$MyRegistrator")
     val sc = new SparkContext(args(0), "SparkLR",
       System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
 
     val numSlices = if (args.length > 1) args(1).toInt else 2
+    if (args.length > 2) N = args(2).toInt
     val points = sc.parallelize(generateData, numSlices).cache()

     // Initialize w to a random value
@@ -59,9 +78,10 @@

     for (i <- 1 to ITERATIONS) {
       println("On iteration " + i)
+      val zero = Vector.zeros(D)
       val gradient = points.map { p =>
-        (1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y * p.x
-      }.reduce(_ + _)
+        ((1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y, p.x)
+      }.aggregate(zero)((sum,v) => sum saxpy (v._1,v._2), _ += _)
Member: Change the second argument to sum.saxpy(v._1, v._2); we don't use infix notation for methods unless they're operators.
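Aside: the infix form sum saxpy (v._1, v._2) and the dotted form compile to the same call, so this is purely a style convention. The line as the reviewer suggests it would read:

    }.aggregate(zero)((sum, v) => sum.saxpy(v._1, v._2), _ += _)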

       w -= gradient
     }

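The substance of the change: the old reduce(_ + _) allocates a fresh Vector for every scaled point (the * p.x) and another for every pairwise + inside the reduce, while the new version emits cheap (scalar, Vector) pairs and folds each partition into a single mutable accumulator via saxpy, merging per-partition sums with the in-place +=. A schematic before/after, assuming the surrounding definitions in SparkLR (points, w, D):

    // Before: each map emits a freshly allocated D-dimensional Vector, and
    // every + inside the reduce allocates another one.
    val gradient = points.map { p =>
      (1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y * p.x
    }.reduce(_ + _)

    // After: map emits only (scalar, existing Vector) pairs; each partition
    // folds them into its own copy of the zero Vector in place, and the
    // per-partition sums are merged with the in-place +=.
    val zero = Vector.zeros(D)
    val gradient = points.map { p =>
      ((1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y, p.x)
    }.aggregate(zero)((sum, v) => sum.saxpy(v._1, v._2), _ += _)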