Skip to content

Commit

Permalink
Continuous benchmark tweaks (#118)
Browse files Browse the repository at this point in the history
- Less logging
- Add an optional timeout to the client (necessary for shard-merge requests, which can exceed the default 30s limit)
- Only use L2 and Angular in the benchmark
- The L2 dataset has the same dimensions as MNIST/Fashion-MNIST. Recall is still pretty poor.
  • Loading branch information
alexklibisz authored Jul 29, 2020
1 parent 3a67e30 commit 5301738
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,52 +9,29 @@ import zio.console.Console
*/
object ContinuousBenchmark extends App {

private val randomDenseFloats = Dataset.RandomDenseFloat(1000, 50000, 1000)
private val randomSparseBools = Dataset.RandomSparseBool(3000, 50000, 1000)
private val randomDenseFloats = Dataset.RandomDenseFloat(784, 60000, 10000)
private val field = "vec"
private val bucket = s"elastiknn-benchmarks"
private val k = 100

private val experiments = Seq(
// L2 exact, LSH
// L2
Experiment(
randomDenseFloats,
Mapping.DenseFloat(randomDenseFloats.dims),
NearestNeighborsQuery.Exact(field, Similarity.L2),
Mapping.L2Lsh(randomDenseFloats.dims, 400, 1, 3),
Mapping.L2Lsh(randomDenseFloats.dims, 300, 2, 3),
Seq(
Query(NearestNeighborsQuery.L2Lsh(field, 1000), k)
Query(NearestNeighborsQuery.L2Lsh(field, 4000), k)
)
),
Experiment(
randomDenseFloats,
Mapping.DenseFloat(randomDenseFloats.dims),
NearestNeighborsQuery.Exact(field, Similarity.Angular),
// Angular LSH seems to benefit a lot from increasing k.
Mapping.AngularLsh(randomDenseFloats.dims, 250, 3),
Seq(
Query(NearestNeighborsQuery.AngularLsh(field, 1000), k)
)
),
// Jaccard exact, sparse indexed, LSH
Experiment(
randomSparseBools,
Mapping.SparseBool(randomSparseBools.dims),
NearestNeighborsQuery.Exact(field, Similarity.Jaccard),
Mapping.JaccardLsh(randomSparseBools.dims, 400, 1),
Seq(
Query(NearestNeighborsQuery.JaccardLsh(field, 1000), k)
)
),
// Angular exact, LSH on Glove100 dataset. Still experimental, excluded by default.
// Angular
Experiment(
Dataset.AnnbGlove25,
Mapping.DenseFloat(Dataset.AnnbGlove25.dims),
NearestNeighborsQuery.Exact(field, Similarity.Angular),
Mapping.AngularLsh(Dataset.AnnbGlove25.dims, 250, 3),
Mapping.AngularLsh(Dataset.AnnbGlove25.dims, 100, 3),
Seq(
Query(NearestNeighborsQuery.AngularLsh(field, 1000), k),
Query(NearestNeighborsQuery.AngularLsh(field, 2000), k),
Query(NearestNeighborsQuery.AngularLsh(field, 4000), k)
)
),
Expand All @@ -64,8 +41,6 @@ object ContinuousBenchmark extends App {
NearestNeighborsQuery.Exact(field, Similarity.Angular),
Mapping.PermutationLsh(Dataset.AnnbGlove25.dims, 15, repeating = false),
Seq(
Query(NearestNeighborsQuery.PermutationLsh(field, Similarity.Angular, 1000), k),
Query(NearestNeighborsQuery.PermutationLsh(field, Similarity.Angular, 2000), k),
Query(NearestNeighborsQuery.PermutationLsh(field, Similarity.Angular, 4000), k)
)
),
Expand All @@ -75,8 +50,6 @@ object ContinuousBenchmark extends App {
NearestNeighborsQuery.Exact(field, Similarity.Angular),
Mapping.PermutationLsh(Dataset.AnnbGlove25.dims, 15, repeating = true),
Seq(
Query(NearestNeighborsQuery.PermutationLsh(field, Similarity.Angular, 1000), k),
Query(NearestNeighborsQuery.PermutationLsh(field, Similarity.Angular, 2000), k),
Query(NearestNeighborsQuery.PermutationLsh(field, Similarity.Angular, 4000), k)
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ object ElastiknnZioClient {
override def execute[T, U](t: T)(implicit handler: Handler[T, U], manifest: Manifest[U]): Task[Response[U]]
}

def fromFutureClient(host: String, port: Int, strictFailure: Boolean): Layer[Throwable, Has[Service]] =
def fromFutureClient(host: String, port: Int, strictFailure: Boolean, timeoutMillis: Int): Layer[Throwable, Has[Service]] =
ZLayer.fromEffect(ZIO.fromFuture { implicit ec =>
val client = ElastiknnClient.futureClient(host, port, strictFailure)
val client = ElastiknnClient.futureClient(host, port, strictFailure, timeoutMillis)
val service = new Service {
override def execute[T, U](t: T)(implicit handler: Handler[T, U], manifest: Manifest[U]): Task[Response[U]] =
Task.fromFuture(_ => client.execute(t))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ object Execute extends App {
case (vec, i) =>
for {
(dur, res) <- eknnClient.nearestNeighbors(trainIndex, eknnQuery.withVec(vec), k, storedIdField).timed
_ <- if (i % 10 == 0) log.debug(s"Completed query $i in ${dur.toMillis} ms") else ZIO.succeed(())
_ <- if (i % 100 == 0) log.debug(s"Completed query $i in $trainIndex in ${dur.toMillis} ms") else ZIO.succeed(())
} yield QueryResult(res.result.hits.hits.map(_.id), res.result.took)
}
(dur, responses) <- requests.run(ZSink.collectAll).timed
Expand Down Expand Up @@ -223,7 +223,7 @@ object Execute extends App {
blockingWithS3 ++
(blockingWithS3 >>> ResultClient.s3(params.resultsBucket, params.resultsPrefix)) ++
(blockingWithS3 >>> DatasetClient.s3(params.datasetsBucket, params.datasetsPrefix)) ++
ElastiknnZioClient.fromFutureClient("localhost", 9200, true) ++
ElastiknnZioClient.fromFutureClient("localhost", 9200, true, 60000) ++
Slf4jLogger.make((_, s) => s, Some(getClass.getSimpleName))

val logic = for {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ import com.sksamuel.elastic4s.requests.bulk.{BulkResponse, BulkResponseItem}
import com.sksamuel.elastic4s.requests.indexes.PutMappingResponse
import com.sksamuel.elastic4s.requests.searches.{SearchRequest, SearchResponse}
import org.apache.http.HttpHost
import org.elasticsearch.client.RestClient
import org.apache.http.client.config.RequestConfig
import org.apache.http.impl.nio.client.HttpAsyncClientBuilder
import org.elasticsearch.client.{RestClient, RestClientBuilder}
import org.elasticsearch.client.RestClientBuilder.{HttpClientConfigCallback, RequestConfigCallback}

import scala.annotation.tailrec
import scala.concurrent.{ExecutionContext, Future}
Expand Down Expand Up @@ -80,9 +83,14 @@ object ElastiknnClient {

final case class StrictFailureException(message: String, cause: Throwable = None.orNull) extends RuntimeException(message, cause)

def futureClient(host: String = "localhost", port: Int = 9200, strictFailure: Boolean = true)(
def futureClient(host: String = "localhost", port: Int = 9200, strictFailure: Boolean = true, timeoutMillis: Int = 30000)(
implicit ec: ExecutionContext): ElastiknnFutureClient = {
val rc: RestClient = RestClient.builder(new HttpHost(host, port)).build()
val rc: RestClient = RestClient
.builder(new HttpHost(host, port))
.setRequestConfigCallback(
(requestConfigBuilder: RequestConfig.Builder) => requestConfigBuilder.setSocketTimeout(timeoutMillis)
)
.build()
val jc: JavaClient = new JavaClient(rc)
new ElastiknnFutureClient {
implicit val executor: Executor[Future] = Executor.FutureExecutor(ec)
Expand Down

0 comments on commit 5301738

Please sign in to comment.