From fd7f5d423dd33b0905d18f5739656c5441a22a15 Mon Sep 17 00:00:00 2001 From: dylanholmes <4370153+dylanholmes@users.noreply.github.com> Date: Wed, 15 May 2024 12:28:21 +0200 Subject: [PATCH] Fix RedisVectorStoreDriver bugs --- .../drivers/vector-store-drivers.md | 2 +- .../vector/redis_vector_store_driver.py | 12 ++- .../vector/test_redis_vector_store_driver.py | 80 +++++++++++++++---- 3 files changed, 74 insertions(+), 20 deletions(-) diff --git a/docs/griptape-framework/drivers/vector-store-drivers.md b/docs/griptape-framework/drivers/vector-store-drivers.md index e2e85d2c9..73a416c84 100644 --- a/docs/griptape-framework/drivers/vector-store-drivers.md +++ b/docs/griptape-framework/drivers/vector-store-drivers.md @@ -316,7 +316,7 @@ print(result) The format for creating a vector index should be similar to the following: ``` -FT.CREATE idx:griptape ON hash PREFIX 1 "griptape:" SCHEMA tag TAG vector VECTOR FLAT 6 TYPE FLOAT32 DIM 1536 DISTANCE_METRIC COSINE +FT.CREATE idx:griptape ON hash PREFIX 1 "griptape:" SCHEMA namespace TAG vector VECTOR FLAT 6 TYPE FLOAT32 DIM 1536 DISTANCE_METRIC COSINE ``` ## OpenSearch Vector Store Driver diff --git a/griptape/drivers/vector/redis_vector_store_driver.py b/griptape/drivers/vector/redis_vector_store_driver.py index db99725a3..3772818ab 100644 --- a/griptape/drivers/vector/redis_vector_store_driver.py +++ b/griptape/drivers/vector/redis_vector_store_driver.py @@ -64,6 +64,9 @@ def upsert_vector( mapping["vector"] = np.array(vector, dtype=np.float32).tobytes() mapping["vec_string"] = bytes_vector + if namespace: + mapping["namespace"] = namespace + if meta: mapping["metadata"] = json.dumps(meta) @@ -120,8 +123,9 @@ def query( vector = self.embedding_driver.embed_string(query) + filter_expression = f"(@namespace:{{{namespace}}})" if namespace else "*" query_expression = ( - Query(f"*=>[KNN {count or 10} @vector $vector as score]") + Query(f"{filter_expression}=>[KNN {count or 10} @vector $vector as score]") .sort_by("score") .return_fields("id", "score", "metadata", "vec_string") .paging(0, count or 10) @@ -134,15 +138,15 @@ def query( query_results = [] for document in results: - metadata = getattr(document, "metadata", None) + metadata = json.loads(document.metadata) if hasattr(document, "metadata") else None namespace = document.id.split(":")[0] if ":" in document.id else None vector_id = document.id.split(":")[1] if ":" in document.id else document.id - vector_float_list = json.loads(document["vec_string"]) if include_vectors else None + vector_float_list = json.loads(document.vec_string) if include_vectors else None query_results.append( BaseVectorStoreDriver.QueryResult( id=vector_id, vector=vector_float_list, - score=float(document["score"]), + score=float(document.score), meta=metadata, namespace=namespace, ) diff --git a/tests/unit/drivers/vector/test_redis_vector_store_driver.py b/tests/unit/drivers/vector/test_redis_vector_store_driver.py index b5dfa2832..3c98180e7 100644 --- a/tests/unit/drivers/vector/test_redis_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_redis_vector_store_driver.py @@ -1,3 +1,4 @@ +from unittest.mock import MagicMock import pytest import redis from tests.mocks.mock_embedding_driver import MockEmbeddingDriver @@ -6,19 +7,21 @@ class TestRedisVectorStorageDriver: @pytest.fixture(autouse=True) - def mock_redis(self, mocker): - fake_hgetall_response = {b"vector": b"\x00\x00\x80?\x00\x00\x00@\x00\x00@@", b"metadata": b'{"foo": "bar"}'} + def mock_client(self, mocker): + return mocker.patch("redis.Redis").return_value - mocker.patch.object(redis.StrictRedis, "hset", return_value=None) - mocker.patch.object(redis.StrictRedis, "hgetall", return_value=fake_hgetall_response) - mocker.patch.object(redis.StrictRedis, "keys", return_value=[b"some_namespace:some_vector_id"]) - - fake_redisearch = mocker.MagicMock() - fake_redisearch.search = mocker.MagicMock(return_value=mocker.MagicMock(docs=[])) - fake_redisearch.info = mocker.MagicMock(side_effect=Exception("Index not found")) - fake_redisearch.create_index = mocker.MagicMock(return_value=None) + @pytest.fixture + def mock_keys(self, mock_client): + mock_client.keys.return_value = [b"some_vector_id"] + return mock_client.keys - mocker.patch.object(redis.StrictRedis, "ft", return_value=fake_redisearch) + @pytest.fixture + def mock_hgetall(self, mock_client): + mock_client.hgetall.return_value = { + b"vector": b"\x00\x00\x80?\x00\x00\x00@\x00\x00@@", + b"metadata": b'{"foo": "bar"}', + } + return mock_client.hgetall @pytest.fixture def driver(self): @@ -26,23 +29,70 @@ def driver(self): host="localhost", port=6379, index="test_index", db=0, embedding_driver=MockEmbeddingDriver() ) + @pytest.fixture + def mock_search(self, mock_client): + mock_client.ft.return_value.search.return_value.docs = [ + MagicMock( + id="some_namespace:some_vector_id", + score="0.456198036671", + metadata='{"foo": "bar"}', + vec_string="[1.0, 2.0, 3.0]", + ) + ] + return mock_client.ft.return_value.search + def test_upsert_vector(self, driver): assert ( driver.upsert_vector([1.0, 2.0, 3.0], vector_id="some_vector_id", namespace="some_namespace") == "some_vector_id" ) - def test_load_entry(self, driver): + def test_load_entry(self, driver, mock_hgetall): + entry = driver.load_entry("some_vector_id") + mock_hgetall.assert_called_once_with("some_vector_id") + assert entry.id == "some_vector_id" + assert entry.vector == [1.0, 2.0, 3.0] + assert entry.meta == {"foo": "bar"} + + def test_load_entry_with_namespace(self, driver, mock_hgetall): entry = driver.load_entry("some_vector_id", namespace="some_namespace") + mock_hgetall.assert_called_once_with("some_namespace:some_vector_id") assert entry.id == "some_vector_id" assert entry.vector == [1.0, 2.0, 3.0] assert entry.meta == {"foo": "bar"} - def test_load_entries(self, driver): + def test_load_entries(self, driver, mock_keys, mock_hgetall): + entries = driver.load_entries() + mock_keys.assert_called_once_with("*") + mock_hgetall.assert_called_once_with("some_vector_id") + assert len(entries) == 1 + assert entries[0].vector == [1.0, 2.0, 3.0] + assert entries[0].meta == {"foo": "bar"} + + def test_load_entries_with_namespace(self, driver, mock_keys, mock_hgetall): entries = driver.load_entries(namespace="some_namespace") + mock_keys.assert_called_once_with("some_namespace:*") + mock_hgetall.assert_called_once_with("some_namespace:some_vector_id") assert len(entries) == 1 assert entries[0].vector == [1.0, 2.0, 3.0] assert entries[0].meta == {"foo": "bar"} - def test_query(self, driver): - assert driver.query("some_vector_id") == [] + def test_query(self, driver, mock_search): + results = driver.query("Some query") + mock_search.assert_called_once() + assert len(results) == 1 + assert results[0].namespace == "some_namespace" + assert results[0].id == "some_vector_id" + assert results[0].score == 0.456198036671 + assert results[0].meta == {"foo": "bar"} + assert results[0].vector is None + + def test_query_with_include_vectors(self, driver, mock_search): + results = driver.query("Some query", include_vectors=True) + mock_search.assert_called_once() + assert len(results) == 1 + assert results[0].namespace == "some_namespace" + assert results[0].id == "some_vector_id" + assert results[0].score == 0.456198036671 + assert results[0].meta == {"foo": "bar"} + assert results[0].vector == [1.0, 2.0, 3.0]