Skip to content
This repository has been archived by the owner on May 12, 2021. It is now read-only.

Commit

Permalink
[PIO-193] Async support for predict and storage access, blocking code
Browse files Browse the repository at this point in the history
wrapped in blocking construct
  • Loading branch information
Chris Wewerka committed Nov 6, 2018
1 parent 998938b commit 18866a0
Show file tree
Hide file tree
Showing 19 changed files with 596 additions and 545 deletions.
2 changes: 1 addition & 1 deletion conf/log4j.properties
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ log4j.appender.console.layout.ConversionPattern=[%p] [%c{1}] %m%n%throwable{0}

# file appender
log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.File=${pio.log.dir}/pio.log
# NOTE(review): this commit replaces the configurable ${pio.log.dir} location
# with a hard-coded /tmp path — confirm this is intentional and not a leftover
# debugging change, as it silently discards the pio.log.dir setting.
log4j.appender.file.File=/tmp/pio.log
log4j.appender.file.layout=org.apache.log4j.EnhancedPatternLayout
log4j.appender.file.layout.ConversionPattern=%d %-5p %c [%t] - %m%n

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ import org.apache.predictionio.workflow.PersistentModelManifest
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future, blocking}
import scala.language.postfixOps
import scala.reflect._

/** Base class of a local algorithm.
Expand Down Expand Up @@ -72,11 +75,11 @@ abstract class LAlgorithm[PD, M : ClassTag, Q, P]
val glomQs: RDD[Array[(Long, Q)]] = qs.glom()
val cartesian: RDD[(M, Array[(Long, Q)])] = mRDD.cartesian(glomQs)
cartesian.flatMap { case (m, qArray) =>
qArray.map { case (qx, q) => (qx, predict(m, q)) }
qArray.map { case (qx, q) => (qx, blocking { Await.result(predict(m, q)(scala.concurrent.ExecutionContext.global), 60 minutes) }) }
}
}

/** Casts the erased base model back to `M` and delegates to the asynchronous
  * [[predict]].
  *
  * @param localBaseModel a trained model, statically typed `Any` by the framework
  * @param q an input query
  * @param ec the execution context the asynchronous prediction runs on
  * @return a future of the predicted result
  */
def predictBase(localBaseModel: Any, q: Q)(implicit ec: ExecutionContext): Future[P] = {
  predict(localBaseModel.asInstanceOf[M], q)
}

Expand All @@ -87,7 +90,7 @@ abstract class LAlgorithm[PD, M : ClassTag, Q, P]
* @param q An input query.
* @return A prediction.
*/
def predict(m: M, q: Q): P
def predict(m: M, q: Q)(implicit ec: ExecutionContext): Future[P]

/** :: DeveloperApi ::
* Engine developers should not use this directly (read on to see how local
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future, blocking}
import scala.language.postfixOps
import scala.reflect._

/** Base class of a parallel-to-local algorithm.
Expand Down Expand Up @@ -67,10 +70,10 @@ abstract class P2LAlgorithm[PD, M: ClassTag, Q: ClassTag, P]
* @return Batch of predicted results
*/
def batchPredict(m: M, qs: RDD[(Long, Q)]): RDD[(Long, P)] = {
qs.mapValues { q => predict(m, q) }
qs.mapValues { q => blocking { Await.result(predict(m, q)(scala.concurrent.ExecutionContext.global), 60 minutes) } }
}

/** Casts the erased base model back to `M` and delegates to the asynchronous
  * [[predict]].
  *
  * @param bm a trained model, statically typed `Any` by the framework
  * @param q an input query
  * @param ec the execution context the asynchronous prediction runs on
  * @return a future of the predicted result
  */
def predictBase(bm: Any, q: Q)(implicit ec: ExecutionContext): Future[P] =
  predict(bm.asInstanceOf[M], q)

/** Implement this method to produce a prediction from a query and trained
  * model.
  *
  * @param model A model
  * @param query An input query.
  * @param ec The execution context the asynchronous prediction runs on.
  * @return A future of the prediction.
  */
def predict(model: M, query: Q)(implicit ec: ExecutionContext): Future[P]

/** :: DeveloperApi ::
* Engine developers should not use this directly (read on to see how
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import org.apache.predictionio.workflow.PersistentModelManifest
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

import scala.concurrent.{ExecutionContext, Future}

/** Base class of a parallel algorithm.
*
* A parallel algorithm can be run in parallel on a cluster and produces a
Expand Down Expand Up @@ -72,7 +74,7 @@ abstract class PAlgorithm[PD, M, Q, P]
// Batch prediction is optional for parallel algorithms: subclasses that
// support batch scoring override this; otherwise callers fail fast at call
// time with NotImplementedError rather than at construction time.
def batchPredict(m: M, qs: RDD[(Long, Q)]): RDD[(Long, P)] =
throw new NotImplementedError("batchPredict not implemented")

/** Casts the erased base model back to `M` and delegates to the asynchronous
  * [[predict]].
  *
  * @param baseModel a trained model, statically typed `Any` by the framework
  * @param query an input query
  * @param ec the execution context the asynchronous prediction runs on
  * @return a future of the predicted result
  */
def predictBase(baseModel: Any, query: Q)(implicit ec: ExecutionContext): Future[P] = {
  predict(baseModel.asInstanceOf[M], query)
}

Expand All @@ -83,7 +85,7 @@ abstract class PAlgorithm[PD, M, Q, P]
* @param query An input query.
* @return A prediction.
*/
def predict(model: M, query: Q): P
def predict(model: M, query: Q)(implicit ec: ExecutionContext): Future[P]

/** :: DeveloperApi ::
* Engine developers should not use this directly (read on to see how parallel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import net.jodah.typetools.TypeResolver
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

import scala.concurrent.{ExecutionContext, Future}

/** :: DeveloperApi ::
* Base trait with default custom query serializer, exposed to engine developer
* via [[org.apache.predictionio.controller.CustomQuerySerializer]]
Expand Down Expand Up @@ -90,7 +92,7 @@ abstract class BaseAlgorithm[PD, M, Q, P]
* @return Predicted result
*/
@DeveloperApi
def predictBase(bm: Any, q: Q): P
def predictBase(bm: Any, q: Q)(implicit ec: ExecutionContext): Future[P]

/** :: DeveloperApi ::
* Engine developers should not use this directly. Prepare a model for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@ import org.apache.predictionio.workflow.CleanupFunctions
import org.apache.spark.rdd.RDD
import org.json4s._
import org.json4s.native.JsonMethods._
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.concurrent.blocking
import scala.concurrent.{Await, Future}
import scala.language.existentials
import scala.concurrent.ExecutionContext.Implicits.global

case class BatchPredictConfig(
inputFilePath: String = "batchpredict-input.json",
Expand Down Expand Up @@ -207,23 +212,26 @@ object BatchPredict extends Logging {
// Deploy logic. First call Serving.supplement, then Algo.predict,
// finally Serving.serve.
val supplementedQuery = serving.supplementBase(query)
// TODO: Parallelize the following.
val predictions = algorithms.zip(models).map { case (a, m) =>
val predictionsFuture = Future.sequence(algorithms.zip(models).map { case (a, m) =>
a.predictBase(m, supplementedQuery)
}
})
// Notice that it is by design to call Serving.serve with the
// *original* query.
val prediction = serving.serveBase(query, predictions)
// Combine query with prediction, so the batch results are
// self-descriptive.
val predictionJValue = JsonExtractor.toJValue(
jsonExtractorOption,
Map("query" -> query,
"prediction" -> prediction),
algorithms.head.querySerializer,
algorithms.head.gsonTypeAdapterFactories)
// Return JSON string
compact(render(predictionJValue))
val predFutureRdds = predictionsFuture.map {
predictions =>
val prediction = serving.serveBase(query, predictions)
// Combine query with prediction, so the batch results are
// self-descriptive.
val predictionJValue = JsonExtractor.toJValue(
jsonExtractorOption,
Map("query" -> query,
"prediction" -> prediction),
algorithms.head.querySerializer,
algorithms.head.gsonTypeAdapterFactories)
// Return JSON string
compact(render(predictionJValue))
}
blocking { Await.result(predFutureRdds, 60 minutes) }
}

predictionsRDD.saveAsTextFile(config.outputFilePath)
Expand Down
Loading

0 comments on commit 18866a0

Please sign in to comment.