Skip to content

Commit

Permalink
support for Future-based predict-method, see apache/predictionio#495
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Wewerka committed Nov 9, 2018
1 parent 5445b93 commit 2267d9d
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 126 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ scalaVersion in ThisBuild := "2.11.11"

val mahoutVersion = "0.13.0"

val pioVersion = "0.12.0-incubating"
val pioVersion = "0.14.0-SNAPSHOT"

val elasticsearchVersion = "5.5.2"

Expand Down
30 changes: 18 additions & 12 deletions src/main/scala/EsClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ import java.util

import grizzled.slf4j.Logger
import org.apache.http.util.EntityUtils
import org.apache.predictionio.data.storage.{ DataMap, Storage, StorageClientConfig }
import org.apache.predictionio.data.storage.{DataMap, Storage, StorageClientConfig}
import org.apache.predictionio.workflow.CleanupFunctions
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.elasticsearch.client.RestClient
import org.apache.http.HttpHost
import org.apache.http.auth.{ AuthScope, UsernamePasswordCredentials }
import org.apache.http.auth.{AuthScope, UsernamePasswordCredentials}
import org.apache.http.entity.ContentType
import org.apache.http.entity.StringEntity
import org.apache.http.impl.client.BasicCredentialsProvider
Expand All @@ -42,6 +42,9 @@ import org.elasticsearch.spark._
import org.json4s.JValue
import org.json4s.DefaultFormats
import org.json4s.JsonAST.JString
import ScalaRestClient.ExtendedScalaRestClient

import scala.concurrent.{ExecutionContext, Future}
// import org.json4s.native.Serialization.writePretty
import com.actionml.helpers.{ ItemID, ItemProps }

Expand Down Expand Up @@ -367,20 +370,23 @@ object EsClient {
* @param indexName the index to search
* @return a Future of the parsed JSON search response, or None when the response status code is not 200
*/
def search(query: String, indexName: String): Option[JValue] = {
def search(query: String, indexName: String)(implicit ec: ExecutionContext): Future[Option[JValue]] = {
logger.info(s"Query:\n${query}")
val response = client.performRequest(
val responseFuture = client.performRequestFuture(
"POST",
s"/$indexName/_search",
Map.empty[String, String].asJava,
Map.empty[String, String],
new StringEntity(query, ContentType.APPLICATION_JSON))
response.getStatusLine.getStatusCode match {
case 200 =>
logger.info(s"Got source from query: ${query}")
Some(parse(EntityUtils.toString(response.getEntity)))
case _ =>
logger.info(s"Query: ${query}\nproduced status code: ${response.getStatusLine.getStatusCode}")
None
responseFuture.map {
response =>
response.getStatusLine.getStatusCode match {
case 200 =>
logger.info(s"Got source from query: ${query}")
Some(parse(EntityUtils.toString(response.getEntity)))
case _ =>
logger.info(s"Query: ${query}\nproduced status code: ${response.getStatusLine.getStatusCode}")
None
}
}
}

Expand Down
23 changes: 23 additions & 0 deletions src/main/scala/ScalaRestClient.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.actionml

import org.apache.http.{Header, HttpEntity}
import org.elasticsearch.client.{Response, ResponseListener, RestClient}
import scala.collection.JavaConverters._
import scala.concurrent.{Future, Promise}

/** Scala-friendly extensions for the Elasticsearch low-level [[RestClient]]. */
object ScalaRestClient {

  import scala.util.control.NonFatal

  /** Enriches a [[RestClient]] with a `Future`-returning variant of `performRequestAsync`.
    *
    * Bring it into scope with `import ScalaRestClient.ExtendedScalaRestClient`.
    *
    * @param restClient the client that requests are delegated to
    */
  implicit class ExtendedScalaRestClient(restClient: RestClient) {

    /** Submits a request via `RestClient.performRequestAsync` and exposes its outcome as a `Future`.
      *
      * @param method   HTTP method, passed through unchanged
      * @param endpoint request path, passed through unchanged
      * @param params   query-string parameters; converted to a `java.util.Map` for the client
      * @param entity   request body, passed through unchanged
      * @param headers  optional request headers, passed through unchanged
      * @return a future completed with the [[Response]] given to the listener's `onSuccess`, or failed with
      *         the exception given to `onFailure` or thrown while the request was being submitted
      */
    def performRequestFuture(method: String, endpoint: String, params: Map[String, String],
      entity: HttpEntity, headers: Header*): Future[Response] = {
      val promise = Promise[Response]()

      // The try* variants are used because the promise may already have been completed by the
      // catch clause below; a plain success/failure would then throw IllegalStateException
      // on the thread that invokes the listener.
      val responseListener = new ResponseListener {
        override def onSuccess(response: Response): Unit = {
          promise.trySuccess(response)
          ()
        }
        override def onFailure(exception: Exception): Unit = {
          promise.tryFailure(exception)
          ()
        }
      }

      try {
        restClient.performRequestAsync(method, endpoint, params.asJava, entity, responseListener, headers: _*)
      } catch {
        // Keep the Future contract: a failure raised while submitting the request is reported
        // through the returned future instead of being thrown at the caller.
        case NonFatal(e) =>
          promise.tryFailure(e)
          ()
      }

      promise.future
    }
  }
}
238 changes: 125 additions & 113 deletions src/main/scala/URAlgorithm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ package com.actionml
import java.util

import grizzled.slf4j.Logger
import org.apache.predictionio.controller.{ P2LAlgorithm, Params }
import org.apache.predictionio.data.storage.{ DataMap, Event, NullModel, PropertyMap }
import org.apache.predictionio.controller.{P2LAlgorithm, Params}
import org.apache.predictionio.data.storage.{DataMap, Event, NullModel, PropertyMap}
import org.apache.predictionio.data.store.LEventStore
import org.apache.mahout.math.cf.{ DownsamplableCrossOccurrenceDataset, SimilarityAnalysis }
import org.apache.mahout.math.cf.{DownsamplableCrossOccurrenceDataset, SimilarityAnalysis}
import org.apache.mahout.sparkbindings.indexeddataset.IndexedDatasetSpark
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
Expand All @@ -34,10 +34,12 @@ import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import com.actionml.helpers._

import scala.concurrent.ExecutionContext
import scala.collection.JavaConverters._
import scala.concurrent.Future
import scala.concurrent.duration.Duration
import scala.language.{ implicitConversions, postfixOps }
import scala.language.{implicitConversions, postfixOps}
import ScalaRestClient.ExtendedScalaRestClient

/** Available value for algorithm param "RecsModel" */
object RecsModels { // todo: replace this with rankings
Expand Down Expand Up @@ -481,51 +483,58 @@ class URAlgorithm(val ap: URAlgorithmParams)
* @todo Need to prune that query to minimum required for data include, for instance no need for the popularity
* ranking if no PopModel is being used, same for "must" clause and dates.
*/
def predict(model: NullModel, query: Query): PredictedResult = {
override def predictAsync(model: NullModel, query: Query)(implicit ec: ExecutionContext): Future[PredictedResult] = {

queryEventNames = query.eventNames.getOrElse(modelEventNames) // eventNames in query take precedence

val (queryStr, blacklist) = buildQuery(ap, query, rankingFieldNames)
// old es1 query
// val searchHitsOpt = EsClient.search(queryStr, esIndex, queryEventNames)
val searchHitsOpt = EsClient.search(queryStr, esIndex)

val withRanks = query.withRanks.getOrElse(false)
val predictedResults = searchHitsOpt match {
case Some(searchHits) =>
val hits = (searchHits \ "hits" \ "hits").extract[Seq[JValue]]
val recs = hits.map { hit =>
if (withRanks) {
val source = hit \ "source"
val ranks: Map[String, Double] = rankingsParams map { backfillParams =>
val backfillType = backfillParams.`type`.getOrElse(DefaultURAlgoParams.BackfillType)
val backfillFieldName = backfillParams.name.getOrElse(PopModel.nameByType(backfillType))
backfillFieldName -> (source \ backfillFieldName).extract[Double]
} toMap

ItemScore((hit \ "_id").extract[String], (hit \ "_score").extract[Double],
ranks = if (ranks.nonEmpty) Some(ranks) else None)
} else {
ItemScore((hit \ "_id").extract[String], (hit \ "_score").extract[Double])
}
}.toArray
logger.info(s"Results: ${hits.length} retrieved of a possible ${(searchHits \ "hits" \ "total").extract[Long]}")
PredictedResult(recs)

case _ =>
logger.info(s"No results for query ${parse(queryStr)}")
PredictedResult(Array.empty[ItemScore])
val queryStrBlacklistFuture = buildQuery(ap, query, rankingFieldNames)

queryStrBlacklistFuture.flatMap {
case (queryStr, blacklist) =>
// old es1 query
// val searchHitsOpt = EsClient.search(queryStr, esIndex, queryEventNames)
val searchHitsOptFuture = EsClient.search(queryStr, esIndex)

val withRanks = query.withRanks.getOrElse(false)
searchHitsOptFuture.map {
searchHitsOpt =>
val predictedResults = searchHitsOpt match {
case Some(searchHits) =>
val hits = (searchHits \ "hits" \ "hits").extract[Seq[JValue]]
val recs = hits.map { hit =>
if (withRanks) {
val source = hit \ "source"
val ranks: Map[String, Double] = rankingsParams map { backfillParams =>
val backfillType = backfillParams.`type`.getOrElse(DefaultURAlgoParams.BackfillType)
val backfillFieldName = backfillParams.name.getOrElse(PopModel.nameByType(backfillType))
backfillFieldName -> (source \ backfillFieldName).extract[Double]
} toMap

ItemScore((hit \ "_id").extract[String], (hit \ "_score").extract[Double],
ranks = if (ranks.nonEmpty) Some(ranks) else None)
} else {
ItemScore((hit \ "_id").extract[String], (hit \ "_score").extract[Double])
}
}.toArray
logger.info(s"Results: ${hits.length} retrieved of a possible ${(searchHits \ "hits" \ "total").extract[Long]}")
PredictedResult(recs)

case _ =>
logger.info(s"No results for query ${parse(queryStr)}")
PredictedResult(Array.empty[ItemScore])
}

// todo: is this needed to remove ranked items from recs?
//if (recsModel == RecsModels.CF) {
// PredictedResult(predictedResults.filter(_.score != 0.0))
//} else PredictedResult(predictedResults)

// should have all blacklisted items excluded
// todo: need to add dithering, mean, sigma, seed required, make a seed that only changes on some fixed time
// period so the recs ordering stays fixed for that time period.
predictedResults
}
}

// todo: is this needed to remove ranked items from recs?
//if (recsModel == RecsModels.CF) {
// PredictedResult(predictedResults.filter(_.score != 0.0))
//} else PredictedResult(predictedResults)

// should have all blacklisted items excluded
// todo: need to add dithering, mean, sigma, seed required, make a seed that only changes on some fixed time
// period so the recs ordering stays fixed for that time period.
predictedResults
}

/** Calculate all fields and items needed for ranking.
Expand Down Expand Up @@ -563,56 +572,60 @@ class URAlgorithm(val ap: URAlgorithmParams)
def buildQuery(
ap: URAlgorithmParams,
query: Query,
backfillFieldNames: Seq[String] = Seq.empty): (String, Seq[Event]) = {
backfillFieldNames: Seq[String] = Seq.empty)(implicit ec: ExecutionContext): Future[(String, Seq[Event])] = {

logger.info(s"Got query: \n${query}")

val startPos = query.from.getOrElse(0)
logger.info(s"from: ${startPos}")

try {
// create a list of all query correlators that can have a bias (boost or filter) attached
val (boostable, events) = getBiasedRecentUserActions(query)
logger.info(s"getBiasedRecentUserActions returned boostable: ${boostable} and events: ${events}")

// since users have action history and items have correlators and both correspond to the same "actions" like
// purchase or view, we'll pass both to the query if the user history or items correlators are empty
// then metadata or backfill must be relied on to return results.
val numRecs = if (query.num.isDefined) query.num.get else limit // num in query orerrides num in config
logger.info(s"UR query num = ${query.num}")
logger.info(s"query.num.getOrElse returned numRecs: ${numRecs}")

val should = buildQueryShould(query, boostable)
logger.info(s"buildQueryShould returned should: ${should}")
val must = buildQueryMust(query, boostable)
logger.info(s"buildQueryMust returned must: ${must}")
val mustNot = buildQueryMustNot(query, events)
logger.info(s"buildQueryMustNot returned mustNot: ${mustNot}")
val sort = buildQuerySort()
logger.info(s"buildQuerySort returned sort: ${sort}")

val json =
("from" -> startPos) ~
("size" -> numRecs) ~
("query" ->
("bool" ->
("should" -> should) ~
("must" -> must) ~
("must_not" -> mustNot) ~
("minimum_should_match" -> 1))) ~
("sort" -> sort)

logger.info(s"json is: ${json}")
val compactJson = compact(render(json))
logger.info(s"compact json is: ${compactJson}")

//logger.info(s"Query:\n$compactJson")
(compactJson, events)
} catch {
case e: IllegalArgumentException => {
logger.warn("whoops, IllegalArgumentException for something in buildQuery.")
("", Seq.empty[Event])
}
// create a list of all query correlators that can have a bias (boost or filter) attached
val biasedRecentUserActionsFuture = getBiasedRecentUserActions(query)

biasedRecentUserActionsFuture.map {
case (boostable, events) =>
try {
logger.info(s"getBiasedRecentUserActions returned boostable: ${boostable} and events: ${events}")

// Users have action history and items have correlators, and both correspond to the same "actions" like
// purchase or view, so we pass both to the query. If the user history or item correlators are empty,
// then metadata or backfill must be relied on to return results.
val numRecs = if (query.num.isDefined) query.num.get else limit // num in query overrides num in config
logger.info(s"UR query num = ${query.num}")
logger.info(s"query.num.getOrElse returned numRecs: ${numRecs}")

val should = buildQueryShould(query, boostable)
logger.info(s"buildQueryShould returned should: ${should}")
val must = buildQueryMust(query, boostable)
logger.info(s"buildQueryMust returned must: ${must}")
val mustNot = buildQueryMustNot(query, events)
logger.info(s"buildQueryMustNot returned mustNot: ${mustNot}")
val sort = buildQuerySort()
logger.info(s"buildQuerySort returned sort: ${sort}")

val json =
("from" -> startPos) ~
("size" -> numRecs) ~
("query" ->
("bool" ->
("should" -> should) ~
("must" -> must) ~
("must_not" -> mustNot) ~
("minimum_should_match" -> 1))) ~
("sort" -> sort)

logger.info(s"json is: ${json}")
val compactJson = compact(render(json))
logger.info(s"compact json is: ${compactJson}")

//logger.info(s"Query:\n$compactJson")
(compactJson, events)
} catch {
case e: IllegalArgumentException => {
logger.warn("whoops, IllegalArgumentException for something in buildQuery.")
("", Seq.empty[Event])
}
}
}
}

Expand Down Expand Up @@ -792,10 +805,10 @@ class URAlgorithm(val ap: URAlgorithmParams)
}

/** Get recent events of the user on items to create the recommendations query from */
def getBiasedRecentUserActions(query: Query): (Seq[BoostableCorrelators], Seq[Event]) = {
def getBiasedRecentUserActions(query: Query)(implicit ec: ExecutionContext): Future[(Seq[BoostableCorrelators], Seq[Event])] = {

val recentEvents = try {
LEventStore.findByEntity(
val recentEventsFuture =
LEventStore.findByEntityAsync(
appName = appName,
// entityType and entityId is specified for fast lookup
entityType = "user",
Expand All @@ -806,13 +819,9 @@ class URAlgorithm(val ap: URAlgorithmParams)
// targetEntityType = None,
// limit = Some(maxQueryEvents), // this will get all history then each action can be limited before using in
// the query
latest = true,
// set time limit to avoid super long DB access
timeout = Duration(200, "millis")).toSeq
} catch {
case e: scala.concurrent.TimeoutException =>
logger.error(s"Timeout when reading recent events. Empty list is used. $e")
Seq.empty[Event]
latest = true).map(_.toSeq)

val recoveredRecentEventsFuture = recentEventsFuture.recover {
case e: NoSuchElementException =>
logger.info("No user id for recs, returning item-based recs if an item is specified in the query.")
Seq.empty[Event]
Expand All @@ -821,21 +830,24 @@ class URAlgorithm(val ap: URAlgorithmParams)
Seq.empty[Event]
}

val userEventBias = query.userBias.getOrElse(userBias)
val userEventsBoost = if (userEventBias > 0 && userEventBias != 1) Some(userEventBias) else None
val rActions = queryEventNames.map { action =>
var items = Seq.empty[String]

for (event <- recentEvents) { // todo: use indidatorParams for each indicator type
if (event.event == action && items.size < indicatorParams(action).maxItemsPerUser) {
items = event.targetEntityId.get +: items
// todo: may throw exception and we should ignore the event instead of crashing
recoveredRecentEventsFuture.map {
recentEvents =>
val userEventBias = query.userBias.getOrElse(userBias)
val userEventsBoost = if (userEventBias > 0 && userEventBias != 1) Some(userEventBias) else None
val rActions = queryEventNames.map { action =>
var items = Seq.empty[String]

for (event <- recentEvents) { // todo: use indicatorParams for each indicator type
if (event.event == action && items.size < indicatorParams(action).maxItemsPerUser) {
items = event.targetEntityId.get +: items
// todo: may throw exception and we should ignore the event instead of crashing
}
// userBias may be None, which will cause no JSON output for this
}
BoostableCorrelators(action, items.distinct, userEventsBoost)
}
// userBias may be None, which will cause no JSON output for this
}
BoostableCorrelators(action, items.distinct, userEventsBoost)
(rActions, recentEvents)
}
(rActions, recentEvents)
}

/** get all metadata fields that potentially have boosts (not filters) */
Expand Down

0 comments on commit 2267d9d

Please sign in to comment.