Skip to content
This repository has been archived by the owner on May 12, 2021. It is now read-only.

[PIO-193] Async support for predict method and storage access, blocking code wrapped in blocking construct #495

Open
wants to merge 18 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ import org.apache.predictionio.workflow.PersistentModelManifest
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future, blocking}
import scala.language.postfixOps
import scala.reflect._

/** Base class of a local algorithm.
Expand Down Expand Up @@ -72,13 +75,40 @@ abstract class LAlgorithm[PD, M : ClassTag, Q, P]
val glomQs: RDD[Array[(Long, Q)]] = qs.glom()
val cartesian: RDD[(M, Array[(Long, Q)])] = mRDD.cartesian(glomQs)
cartesian.flatMap { case (m, qArray) =>
qArray.map { case (qx, q) => (qx, predict(m, q)) }
qArray.map {
case (qx, q) =>
(qx,
Await.result(predictAsync(m, q)(scala.concurrent.ExecutionContext.global), 60 minutes) )
}
}
}

def predictBase(localBaseModel: Any, q: Q): P = {
override def predictBaseAsync(localBaseModel: Any, q: Q)(implicit ec: ExecutionContext)
: Future[P] =
predictAsync(localBaseModel.asInstanceOf[M], q)(ec)

@deprecated(message =
"this method is just here for backward compatibility, predictBaseAsync() is called now",
since = "0.14.0")
override def predictBase(localBaseModel: Any, q: Q): P =
predict(localBaseModel.asInstanceOf[M], q)
}

/** Implement this method to produce a Future of a prediction in a non blocking way
* from a query and trained model.
*
* This method is implemented to just delegate to blocking predict() for
* backward compatibility reasons.
* Definitely overwrite it to implement your blocking prediction method, and leave
* the old blocking predict() as it is (throwing an exception), it won't be called from
* now on.
*
* @param model Trained model produced by [[train]].
* @param query An input query.
* @param ec ExecutionContext to use for async operations
* @return A Future of a prediction.
*/
def predictAsync(model: M, query: Q)(implicit ec: ExecutionContext): Future[P] =
Future.successful(blocking(predict(model, query)))

/** Implement this method to produce a prediction from a query and trained
* model.
Expand All @@ -87,7 +117,9 @@ abstract class LAlgorithm[PD, M : ClassTag, Q, P]
* @param q An input query.
* @return A prediction.
*/
def predict(m: M, q: Q): P
@deprecated(message = "override non blocking predictAsync() instead", since = "0.14.0")
def predict(m: M, q: Q): P =
throw new NotImplementedError("predict() is deprecated, override predictAsync() instead")

/** :: DeveloperApi ::
* Engine developers should not use this directly (read on to see how local
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future, blocking}
import scala.language.postfixOps
import scala.reflect._

/** Base class of a parallel-to-local algorithm.
Expand Down Expand Up @@ -67,10 +70,35 @@ abstract class P2LAlgorithm[PD, M: ClassTag, Q: ClassTag, P]
* @return Batch of predicted results
*/
def batchPredict(m: M, qs: RDD[(Long, Q)]): RDD[(Long, P)] = {
qs.mapValues { q => predict(m, q) }
qs.mapValues { q =>
Await.result(predictAsync(m, q)(scala.concurrent.ExecutionContext.global), 60 minutes)
}
}

def predictBase(bm: Any, q: Q): P = predict(bm.asInstanceOf[M], q)
override def predictBaseAsync(bm: Any, q: Q)(implicit ec: ExecutionContext): Future[P] =
predictAsync(bm.asInstanceOf[M], q)(ec)

@deprecated(message =
"this method is just here for backward compatibility, predictBaseAsync() is called now",
since = "0.14.0")
override def predictBase(bm: Any, q: Q): P = predict(bm.asInstanceOf[M], q)

/** Implement this method to produce a Future of a prediction in a non blocking way
* from a query and trained model.
*
* This method is implemented to just delegate to blocking predict() for
* backward compatibility reasons.
* Definitely overwrite it to implement your blocking prediction method, and leave
* the old blocking predict() as it is (throwing an exception), it won't be called from
* now on.
*
* @param model Trained model produced by [[train]].
* @param query An input query.
* @param ec ExecutionContext to use for async operations
* @return A Future of a prediction.
*/
def predictAsync(model: M, query: Q)(implicit ec: ExecutionContext): Future[P] =
Future.successful(blocking(predict(model, query)))

/** Implement this method to produce a prediction from a query and trained
* model.
Expand All @@ -79,7 +107,9 @@ abstract class P2LAlgorithm[PD, M: ClassTag, Q: ClassTag, P]
* @param query An input query.
* @return A prediction.
*/
def predict(model: M, query: Q): P
@deprecated(message = "override non blocking predictAsync() instead", since = "0.14.0")
def predict(model: M, query: Q): P =
throw new NotImplementedError("predict() is deprecated, override predictAsync() instead")

/** :: DeveloperApi ::
* Engine developers should not use this directly (read on to see how
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import org.apache.predictionio.workflow.PersistentModelManifest
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

import scala.concurrent.{ExecutionContext, Future, blocking}

/** Base class of a parallel algorithm.
*
* A parallel algorithm can be run in parallel on a cluster and produces a
Expand Down Expand Up @@ -72,9 +74,31 @@ abstract class PAlgorithm[PD, M, Q, P]
def batchPredict(m: M, qs: RDD[(Long, Q)]): RDD[(Long, P)] =
throw new NotImplementedError("batchPredict not implemented")

def predictBase(baseModel: Any, query: Q): P = {
override def predictBaseAsync(bm: Any, q: Q)(implicit ec: ExecutionContext): Future[P] =
predictAsync(bm.asInstanceOf[M], q)(ec)

@deprecated(message =
"this method is just here for backward compatibility, predictBaseAsync() is called now",
since = "0.14.0")
override def predictBase(baseModel: Any, query: Q): P =
predict(baseModel.asInstanceOf[M], query)
}

/** Implement this method to produce a Future of a prediction in a non blocking way
* from a query and trained model.
*
* This method is implemented to just delegate to blocking predict() for
* backward compatibility reasons.
* Definitely overwrite it to implement your blocking prediction method, and leave
* the old blocking predict() as it is (throwing an exception), it won't be called from
* now on.
*
* @param model Trained model produced by [[train]].
* @param query An input query.
* @param ec ExecutionContext to use for async operations
* @return A Future of a prediction.
*/
def predictAsync(model: M, query: Q)(implicit ec: ExecutionContext): Future[P] =
Future.successful(blocking(predict(model, query)))

/** Implement this method to produce a prediction from a query and trained
* model.
Expand All @@ -83,7 +107,9 @@ abstract class PAlgorithm[PD, M, Q, P]
* @param query An input query.
* @return A prediction.
*/
def predict(model: M, query: Q): P
@deprecated(message = "override non blocking predictAsync() instead", since = "0.14.0")
def predict(model: M, query: Q): P =
throw new NotImplementedError("predict() is deprecated, override predictAsync() instead")

/** :: DeveloperApi ::
* Engine developers should not use this directly (read on to see how parallel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import net.jodah.typetools.TypeResolver
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

import scala.concurrent.{ExecutionContext, Future, blocking}

/** :: DeveloperApi ::
* Base trait with default custom query serializer, exposed to engine developer
* via [[org.apache.predictionio.controller.CustomQuerySerializer]]
Expand Down Expand Up @@ -81,6 +83,19 @@ abstract class BaseAlgorithm[PD, M, Q, P]
def batchPredictBase(sc: SparkContext, bm: Any, qs: RDD[(Long, Q)])
: RDD[(Long, P)]

/** :: DeveloperApi ::
* Engine developers should not use this directly. Called by serving to
* perform a single prediction.
*
* @param bm Model
* @param q Query
* @param ec ExecutionContext to use for async operations
* @return Future of a Predicted result
*/
@DeveloperApi
def predictBaseAsync(bm: Any, q: Q)(implicit ec: ExecutionContext): Future[P] =
Future.successful(blocking {predictBase(bm, q)})

/** :: DeveloperApi ::
* Engine developers should not use this directly. Called by serving to
* perform a single prediction.
Expand All @@ -90,7 +105,11 @@ abstract class BaseAlgorithm[PD, M, Q, P]
* @return Predicted result
*/
@DeveloperApi
def predictBase(bm: Any, q: Q): P
@deprecated(message = "override non blocking predictBaseAsync() instead", since = "0.14.0")
def predictBase(bm: Any, q: Q): P =
throw new NotImplementedError(
"predictBase() is deprecated, override predictBaseAsync() instead"
)

/** :: DeveloperApi ::
* Engine developers should not use this directly. Prepare a model for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@ import org.apache.predictionio.workflow.CleanupFunctions
import org.apache.spark.rdd.RDD
import org.json4s._
import org.json4s.native.JsonMethods._
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.concurrent.blocking
import scala.concurrent.{Await, Future}
import scala.language.existentials
import scala.concurrent.ExecutionContext.Implicits.global

case class BatchPredictConfig(
inputFilePath: String = "batchpredict-input.json",
Expand Down Expand Up @@ -207,23 +212,26 @@ object BatchPredict extends Logging {
// Deploy logic. First call Serving.supplement, then Algo.predict,
// finally Serving.serve.
val supplementedQuery = serving.supplementBase(query)
// TODO: Parallelize the following.
val predictions = algorithms.zip(models).map { case (a, m) =>
a.predictBase(m, supplementedQuery)
}
val predictionsFuture = Future.sequence(algorithms.zip(models).map { case (a, m) =>
a.predictBaseAsync(m, supplementedQuery)
})
// Notice that it is by design to call Serving.serve with the
// *original* query.
val prediction = serving.serveBase(query, predictions)
// Combine query with prediction, so the batch results are
// self-descriptive.
val predictionJValue = JsonExtractor.toJValue(
jsonExtractorOption,
Map("query" -> query,
"prediction" -> prediction),
algorithms.head.querySerializer,
algorithms.head.gsonTypeAdapterFactories)
// Return JSON string
compact(render(predictionJValue))
val predFutureRdds = predictionsFuture.map {
predictions =>
val prediction = serving.serveBase(query, predictions)
// Combine query with prediction, so the batch results are
// self-descriptive.
val predictionJValue = JsonExtractor.toJValue(
jsonExtractorOption,
Map("query" -> query,
"prediction" -> prediction),
algorithms.head.querySerializer,
algorithms.head.gsonTypeAdapterFactories)
// Return JSON string
compact(render(predictionJValue))
}
Await.result(predFutureRdds, 60 minutes)
}

predictionsRDD.saveAsTextFile(config.outputFilePath)
Expand Down
Loading