Skip to content

Commit

Permalink
Fix nested queries (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexklibisz authored Jul 14, 2020
1 parent f1cb38a commit f4d0a68
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 27 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.klibisz.elastiknn.client

import com.klibisz.elastiknn.ELASTIKNN_NAME
import com.klibisz.elastiknn.api.{ElasticsearchCodec, Mapping, NearestNeighborsQuery}
import com.sksamuel.elastic4s.{XContentBuilder, XContentFactory}
import com.sksamuel.elastic4s.requests.searches.queries.{CustomQuery, Query}

import scala.language.implicitConversions

/**
  * Bridges elastiknn's [[NearestNeighborsQuery]] to the elastic4s [[Query]] type so it can be
  * passed anywhere elastic4s expects a query (e.g. `search(index).query(...)`).
  */
object Elastic4sCompatibility {

  /** Adds an explicit `toQuery` method to [[NearestNeighborsQuery]]. */
  implicit class NearestNeighborsQueryCompat(nnq: NearestNeighborsQuery) {

    /**
      * Wrap the query in an elastic4s [[CustomQuery]].
      *
      * The body is built on each call to `buildQueryBody()`: a JSON object holding one raw field,
      * named `<ELASTIKNN_NAME>_nearest_neighbors`, whose value is the string returned by
      * `ElasticsearchCodec.nospaces` for the wrapped query.
      */
    def toQuery: Query = new CustomQuery {
      override def buildQueryBody(): XContentBuilder = {
        val queryName = s"${ELASTIKNN_NAME}_nearest_neighbors"
        val encodedQuery = ElasticsearchCodec.nospaces(nnq)
        XContentFactory.jsonBuilder.rawField(queryName, encodedQuery)
      }
    }
  }

  /** Implicit view from [[NearestNeighborsQuery]] to [[Query]]; delegates to `toQuery`. */
  implicit def convertQuery(nnq: NearestNeighborsQuery): Query = nnq.toQuery

}
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ object ElastiknnClient {
}
if (res.isError) {
Left(res.error.asException)
} else if (res.status != 200) Left(StrictFailureException(s"Returned non-200 response [$res] for request [$req]."))
} else if (res.status >= 300) Left(StrictFailureException(s"Returned non-200 response [$res] for request [$req]."))
else
res.result match {
case bulkResponse: BulkResponse if bulkResponse.hasFailures =>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
package com.klibisz.elastiknn.client

import com.klibisz.elastiknn.ELASTIKNN_NAME
import com.klibisz.elastiknn.api.{ElasticsearchCodec, Mapping, NearestNeighborsQuery, Vec}
import com.sksamuel.elastic4s.{ElasticDsl, Indexes, XContentBuilder, XContentFactory}
import com.sksamuel.elastic4s.requests.indexes.IndexRequest
import com.sksamuel.elastic4s.requests.mappings.PutMappingRequest
import com.sksamuel.elastic4s.requests.searches.SearchRequest
import com.sksamuel.elastic4s.requests.searches.queries.CustomQuery
import com.sksamuel.elastic4s.{ElasticDsl, Indexes, XContentFactory}
import com.klibisz.elastiknn.client.Elastic4sCompatibility._

/**
* Methods for creating Elastic4s requests for common elastiknn tasks.
Expand Down Expand Up @@ -43,21 +42,15 @@ trait ElastiknnRequests {
* @param storedIdField Field containing the document ID. See [[ElastiknnRequests.index()]] method.
* @return Instance of [[com.sksamuel.elastic4s.requests.searches.SearchRequest]].
*/
def nearestNeighbors(index: String, query: NearestNeighborsQuery, k: Int, storedIdField: String): SearchRequest = {
val json = ElasticsearchCodec.nospaces(query)
val customQuery = new CustomQuery {
override def buildQueryBody(): XContentBuilder =
XContentFactory.jsonBuilder.rawField(s"${ELASTIKNN_NAME}_nearest_neighbors", json)
}
def nearestNeighbors(index: String, query: NearestNeighborsQuery, k: Int, storedIdField: String): SearchRequest =
ElasticDsl
.search(index)
.query(customQuery)
.query(query)
.fetchSource(false)
.storedFields("_none_")
.docValues(Seq(storedIdField))
.preference(query.hashCode.toString)
.size(k)
}

/**
* Create a mapping containing a vector field and a stored ID field.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,28 +114,30 @@ final case class KnnQueryBuilder(query: NearestNeighborsQuery) extends AbstractQ
private def getMapping(context: QueryShardContext): Mapping = {
import KnnQueryBuilder.mappingCache
val index = context.index.getName
mappingCache.get(
(index, query.field),
() =>
try {
try {
mappingCache.get(
(index, query.field),
() => {
val client = context.getClient
val request = new GetFieldMappingsRequest().indices(index).fields(query.field)
val response = client.execute(GetFieldMappingsAction.INSTANCE, request).actionGet(1000)
val srcMap = response
val mappingMap = response
.mappings()
.get(index)
.get("_doc")
.get(query.field)
.sourceAsMap()
.get(query.field)
.asInstanceOf[JavaJsonMap]
val srcJson = javaMapEncoder(srcMap)
val mapping = ElasticsearchCodec.decodeJsonGet[Mapping](srcJson)
mapping
} catch {
case e: Exception => throw new RuntimeException(s"Failed to retrieve mapping at index [$index] field [${query.field}]", e)
}
)
.get(query.field.split('.').last) // For nested fields e.g. "foo.bar.vec" -> "vec"
val mappingJsonMap = mappingMap.asInstanceOf[JavaJsonMap]
val mappingJson = javaMapEncoder(mappingJsonMap)
ElasticsearchCodec.decodeJsonGet[Mapping](mappingJson)
}
)
} catch {
// Pass along only the cause instead of the ExecutionException.
case e: java.util.concurrent.ExecutionException =>
throw new RuntimeException(s"Failed to retrieve mapping at index [$index] field [${query.field}]", e.getCause)
}
}

override def doEquals(other: KnnQueryBuilder): Boolean = other.query == this.query
Expand All @@ -158,7 +160,10 @@ final case class KnnQueryBuilder(query: NearestNeighborsQuery) extends AbstractQ
val srcJson: Json = javaMapEncoder(srcMap)
val vector = ElasticsearchCodec.decodeJsonGet[api.Vec](srcJson)
supplier.set(copy(query.withVec(vector)))
l.asInstanceOf[ActionListener[Any]].onResponse(null)
l match {
case a: ActionListener[Any] => a.onResponse(null)
case _ =>
}
} catch {
case e: Exception => l.onFailure(ex(e))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.klibisz.elastiknn.testing
import com.klibisz.elastiknn.client.ElastiknnClient
import com.sksamuel.elastic4s.http.JavaClient
import com.sksamuel.elastic4s.{ElasticClient, Executor}
import com.sksamuel.elastic4s.ElasticDsl._
import org.apache.http.HttpHost
import org.elasticsearch.client.RestClient
import org.scalatest.AsyncTestSuite
Expand All @@ -24,6 +25,12 @@ trait ElasticAsyncClient {
ElasticClient(jc)
}

/**
  * Delete the given index when it exists; otherwise do nothing.
  *
  * A failed existence check is treated the same as "index does not exist", so only a failure of
  * the delete request itself fails the returned future.
  *
  * @param index Name of the index to delete.
  * @return Future completing with Unit once the (optional) delete has finished.
  */
protected def deleteIfExists(index: String): Future[Unit] =
  eknn
    .execute(indexExists(index))
    .map(_.result.exists)
    .recover { case _ => false }
    .flatMap { exists =>
      if (exists) eknn.execute(deleteIndex(index)).map(_ => ())
      else Future.successful(())
    }

// Future-based elastiknn client shared by the test helpers; built on first access via ElastiknnClient.futureClient().
protected lazy val eknn: ElastiknnClient[Future] = ElastiknnClient.futureClient()

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package com.klibisz.elastiknn.query

import com.klibisz.elastiknn.api._
import com.klibisz.elastiknn.client.Elastic4sCompatibility._
import com.klibisz.elastiknn.testing.{ElasticAsyncClient, SilentMatchers}
import com.sksamuel.elastic4s.ElasticDsl._
import com.sksamuel.elastic4s.XContentFactory
import com.sksamuel.elastic4s.requests.common.RefreshPolicy
import org.scalatest.{AsyncFunSpec, Inspectors, Matchers}

import scala.util.Random

/**
  * Checks that an exact nearest neighbors query finds a vector stored at the top level of a
  * document and at several depths of object nesting.
  */
class NearestNeighborsQuerySpec extends AsyncFunSpec with Matchers with Inspectors with ElasticAsyncClient with SilentMatchers {

  // https://github.com/alexklibisz/elastiknn/issues/60
  describe("Vectors in nested fields") {
    implicit val rng: Random = new Random(0)
    val index = "test-queries-nested-fields"
    val vec = Vec.DenseFloat.random(10)
    val mapping = Mapping.DenseFloat(vec.values.length)
    val nestedFields = Seq("vec", "foo.vec", "foo.bar.vec", "foo.bar.baz.vec")

    // Register one test per field path; every test recreates the same index from scratch.
    nestedFields.foreach { nestedField =>
      val pathParts = nestedField.split('.')
      val parents = pathParts.init // Enclosing objects, e.g. "foo.bar.vec" -> ["foo", "bar"]
      val leaf = pathParts.last // The vector field itself, e.g. "vec"

      // Mapping body: { "properties": { <parent>: { "properties": { ... <leaf>: <vector mapping> } } } }
      val mappingBuilder = XContentFactory.obj()
      mappingBuilder.startObject("properties")
      parents.foreach { parent =>
        mappingBuilder.startObject(parent)
        mappingBuilder.startObject("properties")
      }
      mappingBuilder.rawField(leaf, ElasticsearchCodec.mapping(mapping).spaces2)
      parents.foreach { _ =>
        mappingBuilder.endObject()
        mappingBuilder.endObject()
      }
      mappingBuilder.endObject()
      val mappingSource = mappingBuilder.string()

      // Document body: { <parent>: { ... <leaf>: <vector> } }
      val docBuilder = XContentFactory.obj()
      parents.foreach { parent =>
        docBuilder.startObject(parent)
      }
      docBuilder.rawField(leaf, ElasticsearchCodec.vector(vec).spaces2)
      parents.foreach { _ =>
        docBuilder.endObject()
      }
      docBuilder.endObject()
      val docSource = docBuilder.string()

      it(s"works with nested field: $nestedField") {
        for {
          _ <- deleteIfExists(index)
          _ <- eknn.execute(createIndex(index))
          _ <- eknn.execute(putMapping(index).rawSource(mappingSource))
          _ <- eknn.execute(indexInto(index).source(docSource).refresh(RefreshPolicy.IMMEDIATE))
          res <- eknn.execute(search(index).query(NearestNeighborsQuery.Exact(nestedField, Similarity.L2, vec)))
        } yield {
          // The only indexed document holds the query vector itself.
          val searchHits = res.result.hits
          searchHits.hits should have length 1
          searchHits.maxScore shouldBe 1.0
        }
      }
    }
  }
}

0 comments on commit f4d0a68

Please sign in to comment.