Skip to content
This repository has been archived by the owner on May 12, 2021. It is now read-only.

Commit

Permalink
[PIO-193] Async support for predict and storage access, blocking code
Browse files Browse the repository at this point in the history
wrapped in blocking construct
  • Loading branch information
Chris Wewerka committed Nov 6, 2018
1 parent 998938b commit 24c0448
Show file tree
Hide file tree
Showing 18 changed files with 595 additions and 544 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ import org.apache.predictionio.workflow.PersistentModelManifest
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future, blocking}
import scala.language.postfixOps
import scala.reflect._

/** Base class of a local algorithm.
Expand Down Expand Up @@ -72,11 +75,11 @@ abstract class LAlgorithm[PD, M : ClassTag, Q, P]
val glomQs: RDD[Array[(Long, Q)]] = qs.glom()
val cartesian: RDD[(M, Array[(Long, Q)])] = mRDD.cartesian(glomQs)
cartesian.flatMap { case (m, qArray) =>
qArray.map { case (qx, q) => (qx, predict(m, q)) }
qArray.map { case (qx, q) => (qx, blocking { Await.result(predict(m, q)(scala.concurrent.ExecutionContext.global), 60 minutes) }) }
}
}

/** :: DeveloperApi ::
  * Forward an untyped model to the typed [[predict]] method.
  *
  * @param localBaseModel The trained model as an untyped `Any`; cast to `M`.
  * @param q An input query.
  * @param ec Execution context used by the asynchronous prediction.
  * @return A `Future` of the predicted result.
  */
def predictBase(localBaseModel: Any, q: Q)(implicit ec: ExecutionContext): Future[P] = {
  predict(localBaseModel.asInstanceOf[M], q)
}

Expand All @@ -87,7 +90,7 @@ abstract class LAlgorithm[PD, M : ClassTag, Q, P]
* @param q An input query.
* @return A prediction.
*/
def predict(m: M, q: Q): P
def predict(m: M, q: Q)(implicit ec: ExecutionContext): Future[P]

/** :: DeveloperApi ::
* Engine developers should not use this directly (read on to see how local
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future, blocking}
import scala.language.postfixOps
import scala.reflect._

/** Base class of a parallel-to-local algorithm.
Expand Down Expand Up @@ -67,10 +70,10 @@ abstract class P2LAlgorithm[PD, M: ClassTag, Q: ClassTag, P]
* @return Batch of predicted results
*/
/** Score a batch of queries against a single trained model.
  *
  * `mapValues` is a synchronous Spark transformation, so the asynchronous
  * [[predict]] must be awaited inside the closure. No caller-supplied
  * ExecutionContext reaches this Spark task, hence the global context is
  * used; `blocking` marks the await so the pool can compensate for the
  * parked thread. NOTE(review): the 60-minute ceiling is a safety bound,
  * not an expected latency — confirm it suits the deployment.
  *
  * @param m A trained model.
  * @param qs An RDD of (index, query) pairs.
  * @return An RDD of (index, prediction) pairs.
  */
def batchPredict(m: M, qs: RDD[(Long, Q)]): RDD[(Long, P)] = {
  qs.mapValues { q =>
    blocking {
      Await.result(predict(m, q)(ExecutionContext.global), 60.minutes)
    }
  }
}

/** :: DeveloperApi ::
  * Forward an untyped model to the typed [[predict]], returning the
  * asynchronous result.
  *
  * @param bm The trained model as an untyped `Any`; cast to `M`.
  * @param q An input query.
  * @param ec Execution context used by the asynchronous prediction.
  * @return A `Future` of the predicted result.
  */
def predictBase(bm: Any, q: Q)(implicit ec: ExecutionContext): Future[P] =
  predict(bm.asInstanceOf[M], q)

/** Implement this method to produce a prediction from a query and trained
* model.
Expand All @@ -79,7 +82,7 @@ abstract class P2LAlgorithm[PD, M: ClassTag, Q: ClassTag, P]
* @param query An input query.
* @return A prediction.
*/
def predict(model: M, query: Q): P
def predict(model: M, query: Q)(implicit ec: ExecutionContext): Future[P]

/** :: DeveloperApi ::
* Engine developers should not use this directly (read on to see how
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import org.apache.predictionio.workflow.PersistentModelManifest
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

import scala.concurrent.{ExecutionContext, Future}

/** Base class of a parallel algorithm.
*
* A parallel algorithm can be run in parallel on a cluster and produces a
Expand Down Expand Up @@ -72,7 +74,7 @@ abstract class PAlgorithm[PD, M, Q, P]
// Batch prediction has no default for parallel algorithms: scoring an RDD of
// queries against a distributed model is engine-specific, so subclasses that
// need batch scoring (e.g. for evaluation) must override this.
def batchPredict(m: M, qs: RDD[(Long, Q)]): RDD[(Long, P)] =
throw new NotImplementedError("batchPredict not implemented")

/** :: DeveloperApi ::
  * Forward an untyped model to the typed [[predict]] method.
  *
  * @param baseModel The trained model as an untyped `Any`; cast to `M`.
  * @param query An input query.
  * @param ec Execution context used by the asynchronous prediction.
  * @return A `Future` of the predicted result.
  */
def predictBase(baseModel: Any, query: Q)(implicit ec: ExecutionContext): Future[P] = {
  predict(baseModel.asInstanceOf[M], query)
}

Expand All @@ -83,7 +85,7 @@ abstract class PAlgorithm[PD, M, Q, P]
* @param query An input query.
* @return A prediction.
*/
def predict(model: M, query: Q): P
def predict(model: M, query: Q)(implicit ec: ExecutionContext): Future[P]

/** :: DeveloperApi ::
* Engine developers should not use this directly (read on to see how parallel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import net.jodah.typetools.TypeResolver
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

import scala.concurrent.{ExecutionContext, Future}

/** :: DeveloperApi ::
* Base trait with default custom query serializer, exposed to engine developer
* via [[org.apache.predictionio.controller.CustomQuerySerializer]]
Expand Down Expand Up @@ -90,7 +92,7 @@ abstract class BaseAlgorithm[PD, M, Q, P]
* @return Predicted result
*/
@DeveloperApi
def predictBase(bm: Any, q: Q): P
def predictBase(bm: Any, q: Q)(implicit ec: ExecutionContext): Future[P]

/** :: DeveloperApi ::
* Engine developers should not use this directly. Prepare a model for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@ import org.apache.predictionio.workflow.CleanupFunctions
import org.apache.spark.rdd.RDD
import org.json4s._
import org.json4s.native.JsonMethods._
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.concurrent.blocking
import scala.concurrent.{Await, Future}
import scala.language.existentials
import scala.concurrent.ExecutionContext.Implicits.global

case class BatchPredictConfig(
inputFilePath: String = "batchpredict-input.json",
Expand Down Expand Up @@ -207,23 +212,26 @@ object BatchPredict extends Logging {
// Deploy logic. First call Serving.supplement, then Algo.predict,
// finally Serving.serve.
val supplementedQuery = serving.supplementBase(query)
// TODO: Parallelize the following.
val predictions = algorithms.zip(models).map { case (a, m) =>
val predictionsFuture = Future.sequence(algorithms.zip(models).map { case (a, m) =>
a.predictBase(m, supplementedQuery)
}
})
// Notice that it is by design to call Serving.serve with the
// *original* query.
val prediction = serving.serveBase(query, predictions)
// Combine query with prediction, so the batch results are
// self-descriptive.
val predictionJValue = JsonExtractor.toJValue(
jsonExtractorOption,
Map("query" -> query,
"prediction" -> prediction),
algorithms.head.querySerializer,
algorithms.head.gsonTypeAdapterFactories)
// Return JSON string
compact(render(predictionJValue))
val predFutureRdds = predictionsFuture.map {
predictions =>
val prediction = serving.serveBase(query, predictions)
// Combine query with prediction, so the batch results are
// self-descriptive.
val predictionJValue = JsonExtractor.toJValue(
jsonExtractorOption,
Map("query" -> query,
"prediction" -> prediction),
algorithms.head.querySerializer,
algorithms.head.gsonTypeAdapterFactories)
// Return JSON string
compact(render(predictionJValue))
}
blocking { Await.result(predFutureRdds, 60 minutes) }
}

predictionsRDD.saveAsTextFile(config.outputFilePath)
Expand Down
Loading

0 comments on commit 24c0448

Please sign in to comment.