From 43c73c6bdc148123eb7d0b5ff1383497d5706cd9 Mon Sep 17 00:00:00 2001 From: Oskar Stark Date: Fri, 6 Dec 2024 11:19:10 +0100 Subject: [PATCH] [BC BREAK] Extend `VectorSearchInterface::query()` for a `$minScore` --- src/Bridge/Azure/Store/SearchStore.php | 2 +- src/Bridge/ChromaDB/Store.php | 2 +- src/Bridge/MongoDB/Store.php | 19 ++++++++++++++++--- src/Bridge/Pinecone/Store.php | 2 +- src/Store/VectorStoreInterface.php | 2 +- 5 files changed, 20 insertions(+), 7 deletions(-) diff --git a/src/Bridge/Azure/Store/SearchStore.php b/src/Bridge/Azure/Store/SearchStore.php index e1e74144..53d11b96 100644 --- a/src/Bridge/Azure/Store/SearchStore.php +++ b/src/Bridge/Azure/Store/SearchStore.php @@ -34,7 +34,7 @@ public function add(VectorDocument ...$documents): void ]); } - public function query(Vector $vector, array $options = []): array + public function query(Vector $vector, array $options = [], ?float $minScore = null): array { $result = $this->request('search', [ 'vectorQueries' => [$this->buildVectorQuery($vector)], diff --git a/src/Bridge/ChromaDB/Store.php b/src/Bridge/ChromaDB/Store.php index 57fdb60e..cd2cfe91 100644 --- a/src/Bridge/ChromaDB/Store.php +++ b/src/Bridge/ChromaDB/Store.php @@ -34,7 +34,7 @@ public function add(VectorDocument ...$documents): void $collection->add($ids, $vectors, $metadata); } - public function query(Vector $vector, array $options = []): array + public function query(Vector $vector, array $options = [], ?float $minScore = null): array { $collection = $this->client->getOrCreateCollection($this->collectionName); $queryResponse = $collection->query( diff --git a/src/Bridge/MongoDB/Store.php b/src/Bridge/MongoDB/Store.php index 7d72ab8e..93325491 100644 --- a/src/Bridge/MongoDB/Store.php +++ b/src/Bridge/MongoDB/Store.php @@ -95,9 +95,9 @@ public function add(VectorDocument ...$documents): void * filter?: array * } $options */ - public function query(Vector $vector, array $options = []): array + public function query(Vector $vector, array $options = [], ?float $minScore = null): array { - $results = $this->getCollection()->aggregate([ + $pipeline = [ [ '$vectorSearch' => array_merge([ 'index' => $this->indexName, @@ -112,7 +112,20 @@ public function query(Vector $vector, array $options = []): array 'score' => ['$meta' => 'vectorSearchScore'], ], ], - ], ['typeMap' => ['root' => 'array', 'document' => 'array', 'array' => 'array']]); + ]; + + if (null !== $minScore) { + $pipeline[] = [ + '$match' => [ + 'score' => ['$gte' => $minScore], + ], + ]; + } + + $results = $this->getCollection()->aggregate( + $pipeline, + ['typeMap' => ['root' => 'array', 'document' => 'array', 'array' => 'array']] + ); $documents = []; diff --git a/src/Bridge/Pinecone/Store.php b/src/Bridge/Pinecone/Store.php index fa100886..1b2164ce 100644 --- a/src/Bridge/Pinecone/Store.php +++ b/src/Bridge/Pinecone/Store.php @@ -43,7 +43,7 @@ public function add(VectorDocument ...$documents): void $this->getVectors()->upsert($vectors); } - public function query(Vector $vector, array $options = []): array + public function query(Vector $vector, array $options = [], ?float $minScore = null): array { $response = $this->getVectors()->query( vector: $vector->getData(), diff --git a/src/Store/VectorStoreInterface.php b/src/Store/VectorStoreInterface.php index 4676e168..87bc2cdf 100644 --- a/src/Store/VectorStoreInterface.php +++ b/src/Store/VectorStoreInterface.php @@ -14,5 +14,5 @@ interface VectorStoreInterface extends StoreInterface * * @return VectorDocument[] */ - public function query(Vector $vector, array $options = []): array; + public function query(Vector $vector, array $options = [], ?float $minScore = null): array; }