Skip to content

Commit

Permalink
Fix RedisVectorStoreDriver bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanholmes committed May 15, 2024
1 parent c39ad53 commit fd7f5d4
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 20 deletions.
2 changes: 1 addition & 1 deletion docs/griptape-framework/drivers/vector-store-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ print(result)

The format for creating a vector index should be similar to the following:
```
FT.CREATE idx:griptape ON hash PREFIX 1 "griptape:" SCHEMA tag TAG vector VECTOR FLAT 6 TYPE FLOAT32 DIM 1536 DISTANCE_METRIC COSINE
FT.CREATE idx:griptape ON hash PREFIX 1 "griptape:" SCHEMA namespace TAG vector VECTOR FLAT 6 TYPE FLOAT32 DIM 1536 DISTANCE_METRIC COSINE
```

## OpenSearch Vector Store Driver
Expand Down
12 changes: 8 additions & 4 deletions griptape/drivers/vector/redis_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def upsert_vector(
mapping["vector"] = np.array(vector, dtype=np.float32).tobytes()
mapping["vec_string"] = bytes_vector

if namespace:
mapping["namespace"] = namespace

if meta:
mapping["metadata"] = json.dumps(meta)

Expand Down Expand Up @@ -120,8 +123,9 @@ def query(

vector = self.embedding_driver.embed_string(query)

filter_expression = f"(@namespace:{{{namespace}}})" if namespace else "*"
query_expression = (
Query(f"*=>[KNN {count or 10} @vector $vector as score]")
Query(f"{filter_expression}=>[KNN {count or 10} @vector $vector as score]")
.sort_by("score")
.return_fields("id", "score", "metadata", "vec_string")
.paging(0, count or 10)
Expand All @@ -134,15 +138,15 @@ def query(

query_results = []
for document in results:
metadata = getattr(document, "metadata", None)
metadata = json.loads(document.metadata) if hasattr(document, "metadata") else None
namespace = document.id.split(":")[0] if ":" in document.id else None
vector_id = document.id.split(":")[1] if ":" in document.id else document.id
vector_float_list = json.loads(document["vec_string"]) if include_vectors else None
vector_float_list = json.loads(document.vec_string) if include_vectors else None
query_results.append(
BaseVectorStoreDriver.QueryResult(
id=vector_id,
vector=vector_float_list,
score=float(document["score"]),
score=float(document.score),
meta=metadata,
namespace=namespace,
)
Expand Down
80 changes: 65 additions & 15 deletions tests/unit/drivers/vector/test_redis_vector_store_driver.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from unittest.mock import MagicMock
import pytest
import redis
from tests.mocks.mock_embedding_driver import MockEmbeddingDriver
Expand All @@ -6,43 +7,92 @@

class TestRedisVectorStorageDriver:
@pytest.fixture(autouse=True)
def mock_redis(self, mocker):
fake_hgetall_response = {b"vector": b"\x00\x00\x80?\x00\x00\x00@\x00\x00@@", b"metadata": b'{"foo": "bar"}'}
def mock_client(self, mocker):
return mocker.patch("redis.Redis").return_value

mocker.patch.object(redis.StrictRedis, "hset", return_value=None)
mocker.patch.object(redis.StrictRedis, "hgetall", return_value=fake_hgetall_response)
mocker.patch.object(redis.StrictRedis, "keys", return_value=[b"some_namespace:some_vector_id"])

fake_redisearch = mocker.MagicMock()
fake_redisearch.search = mocker.MagicMock(return_value=mocker.MagicMock(docs=[]))
fake_redisearch.info = mocker.MagicMock(side_effect=Exception("Index not found"))
fake_redisearch.create_index = mocker.MagicMock(return_value=None)
@pytest.fixture
def mock_keys(self, mock_client):
mock_client.keys.return_value = [b"some_vector_id"]
return mock_client.keys

mocker.patch.object(redis.StrictRedis, "ft", return_value=fake_redisearch)
@pytest.fixture
def mock_hgetall(self, mock_client):
mock_client.hgetall.return_value = {
b"vector": b"\x00\x00\x80?\x00\x00\x00@\x00\x00@@",
b"metadata": b'{"foo": "bar"}',
}
return mock_client.hgetall

@pytest.fixture
def driver(self):
return RedisVectorStoreDriver(
host="localhost", port=6379, index="test_index", db=0, embedding_driver=MockEmbeddingDriver()
)

@pytest.fixture
def mock_search(self, mock_client):
mock_client.ft.return_value.search.return_value.docs = [
MagicMock(
id="some_namespace:some_vector_id",
score="0.456198036671",
metadata='{"foo": "bar"}',
vec_string="[1.0, 2.0, 3.0]",
)
]
return mock_client.ft.return_value.search

def test_upsert_vector(self, driver):
assert (
driver.upsert_vector([1.0, 2.0, 3.0], vector_id="some_vector_id", namespace="some_namespace")
== "some_vector_id"
)

def test_load_entry(self, driver):
def test_load_entry(self, driver, mock_hgetall):
entry = driver.load_entry("some_vector_id")
mock_hgetall.assert_called_once_with("some_vector_id")
assert entry.id == "some_vector_id"
assert entry.vector == [1.0, 2.0, 3.0]
assert entry.meta == {"foo": "bar"}

def test_load_entry_with_namespace(self, driver, mock_hgetall):
entry = driver.load_entry("some_vector_id", namespace="some_namespace")
mock_hgetall.assert_called_once_with("some_namespace:some_vector_id")
assert entry.id == "some_vector_id"
assert entry.vector == [1.0, 2.0, 3.0]
assert entry.meta == {"foo": "bar"}

def test_load_entries(self, driver):
def test_load_entries(self, driver, mock_keys, mock_hgetall):
entries = driver.load_entries()
mock_keys.assert_called_once_with("*")
mock_hgetall.assert_called_once_with("some_vector_id")
assert len(entries) == 1
assert entries[0].vector == [1.0, 2.0, 3.0]
assert entries[0].meta == {"foo": "bar"}

def test_load_entries_with_namespace(self, driver, mock_keys, mock_hgetall):
entries = driver.load_entries(namespace="some_namespace")
mock_keys.assert_called_once_with("some_namespace:*")
mock_hgetall.assert_called_once_with("some_namespace:some_vector_id")
assert len(entries) == 1
assert entries[0].vector == [1.0, 2.0, 3.0]
assert entries[0].meta == {"foo": "bar"}

def test_query(self, driver):
assert driver.query("some_vector_id") == []
def test_query(self, driver, mock_search):
results = driver.query("Some query")
mock_search.assert_called_once()
assert len(results) == 1
assert results[0].namespace == "some_namespace"
assert results[0].id == "some_vector_id"
assert results[0].score == 0.456198036671
assert results[0].meta == {"foo": "bar"}
assert results[0].vector is None

def test_query_with_include_vectors(self, driver, mock_search):
results = driver.query("Some query", include_vectors=True)
mock_search.assert_called_once()
assert len(results) == 1
assert results[0].namespace == "some_namespace"
assert results[0].id == "some_vector_id"
assert results[0].score == 0.456198036671
assert results[0].meta == {"foo": "bar"}
assert results[0].vector == [1.0, 2.0, 3.0]

0 comments on commit fd7f5d4

Please sign in to comment.