Commit 5ab09ff
[Redis]: Vector database added. (#2032)
Parent: 13374a1
Showing 7 changed files with 311 additions and 1 deletion.
@@ -0,0 +1,44 @@
[Redis](https://redis.io/) is a scalable, real-time database that can store, search, and analyze vector data.

### Installation
```bash
pip install redis redisvl
```

Run Redis Stack using Docker:
```bash
docker run -d --name redis-stack -p 6379:6379 -p 8001:8001 redis/redis-stack:latest
```
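
To verify the server is reachable before wiring it into mem0, a quick ping with the `redis` client works; a minimal check, assuming the default port mapping above:

```python
import redis

# Assumes the redis-stack container above is running on localhost:6379
client = redis.Redis.from_url("redis://localhost:6379")
print(client.ping())  # True when the server is reachable
```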

### Usage

```python
import os
from mem0 import Memory

os.environ["OPENAI_API_KEY"] = "sk-xx"

config = {
    "vector_store": {
        "provider": "redis",
        "config": {
            "collection_name": "mem0",
            "embedding_model_dims": 1536,
            "redis_url": "redis://localhost:6379"
        }
    }
}

m = Memory.from_config(config)
m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
```
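
Retrieval goes through the same `Memory` object. A short sketch, assuming `m` from the snippet above and mem0's standard `search`/`get_all` methods:

```python
# Semantic search over Alice's stored memories (assumes `m` from the config above)
related = m.search("What does Alice like to do?", user_id="alice")

# Or list everything stored for the user
all_memories = m.get_all(user_id="alice")
```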

### Config

Let's see the available parameters for the `redis` config:

| Parameter | Description | Default Value |
| --- | --- | --- |
| `collection_name` | The name of the collection to store the vectors | `mem0` |
| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
| `redis_url` | The URL of the Redis server | `None` |
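
Since `redis_url` accepts the standard Redis URL scheme, pointing at a remote or password-protected instance only changes the config values; a sketch where the host and password are placeholders:

```python
config = {
    "vector_store": {
        "provider": "redis",
        "config": {
            "collection_name": "mem0",
            "embedding_model_dims": 1536,
            # Standard scheme: redis://[:password@]host:port[/db]
            "redis_url": "redis://:my-secret@redis.example.com:6379",
        }
    }
}
```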
@@ -0,0 +1,26 @@

```python
from typing import Any, Dict

from pydantic import BaseModel, Field, model_validator


# TODO: Upgrade to latest pydantic version
class RedisDBConfig(BaseModel):
    redis_url: str = Field(..., description="Redis URL")
    collection_name: str = Field("mem0", description="Collection name")
    embedding_model_dims: int = Field(1536, description="Embedding model dimensions")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = {
        "arbitrary_types_allowed": True,
    }
```
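
Because the `model_validator` above rejects unknown keys, misspelled parameters fail fast at construction time. A small demonstration, assuming `RedisDBConfig` is imported from this module (pydantic wraps the raised `ValueError` in a `ValidationError`):

```python
from pydantic import ValidationError

cfg = RedisDBConfig(redis_url="redis://localhost:6379")
print(cfg.collection_name, cfg.embedding_model_dims)  # mem0 1536

try:
    RedisDBConfig(redis_url="redis://localhost:6379", url="localhost")  # typo'd key
except ValidationError as exc:
    print(exc)  # mentions "Extra fields not allowed: url. ..."
```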
@@ -0,0 +1,236 @@

```python
import json
import logging
from datetime import datetime
from functools import reduce

import numpy as np
import pytz
import redis
from redis.commands.search.query import Query
from redisvl.index import SearchIndex
from redisvl.query import VectorQuery
from redisvl.query.filter import Tag

from mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)

# TODO: Improve, as these are not the best fields from Redis's perspective. Might do away with them.
DEFAULT_FIELDS = [
    {"name": "memory_id", "type": "tag"},
    {"name": "hash", "type": "tag"},
    {"name": "agent_id", "type": "tag"},
    {"name": "run_id", "type": "tag"},
    {"name": "user_id", "type": "tag"},
    {"name": "memory", "type": "text"},
    {"name": "metadata", "type": "text"},
    # TODO: Although the field is numeric, it also accepts strings.
    {"name": "created_at", "type": "numeric"},
    {"name": "updated_at", "type": "numeric"},
    {
        "name": "embedding",
        "type": "vector",
        "attrs": {"distance_metric": "cosine", "algorithm": "flat", "datatype": "float32"},
    },
]
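
# Note on the schema above: in RediSearch, "tag" fields support exact-match
# filtering (used here for the id-style fields), "text" fields are full-text
# searchable, "numeric" fields allow range queries and sorting (used for the
# timestamps), and the "vector" field holds float32 embeddings in a FLAT
# (brute-force) index compared by cosine distance.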

excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}


class MemoryResult:
    def __init__(self, id: str, payload: dict, score: float = None):
        self.id = id
        self.payload = payload
        self.score = score


class RedisDB(VectorStoreBase):
    def __init__(
        self,
        redis_url: str,
        collection_name: str,
        embedding_model_dims: int,
    ):
        """
        Initialize the Redis vector store.

        Args:
            redis_url (str): Redis URL.
            collection_name (str): Collection name.
            embedding_model_dims (int): Embedding model dimensions.
        """
        index_schema = {
            "name": collection_name,
            "prefix": f"mem0:{collection_name}",
        }

        # Note: .copy() is shallow, so setting dims below mutates the shared
        # DEFAULT_FIELDS entry; harmless here because the index is created
        # immediately afterwards.
        fields = DEFAULT_FIELDS.copy()
        fields[-1]["attrs"]["dims"] = embedding_model_dims

        self.schema = {"index": index_schema, "fields": fields}

        self.client = redis.Redis.from_url(redis_url)
        self.index = SearchIndex.from_dict(self.schema)
        self.index.set_client(self.client)
        self.index.create(overwrite=True)

    # TODO: Implement multi-index support.
    def create_col(self, name, vector_size, distance):
        raise NotImplementedError("Collection/Index creation not supported yet.")

    def insert(self, vectors: list, payloads: list = None, ids: list = None):
        data = []
        for vector, payload, id in zip(vectors, payloads, ids):
            # Start with required fields
            entry = {
                "memory_id": id,
                "hash": payload["hash"],
                "memory": payload["data"],
                "created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
                "embedding": np.array(vector, dtype=np.float32).tobytes(),
            }

            # Conditionally add optional fields
            for field in ["agent_id", "run_id", "user_id"]:
                if field in payload:
                    entry[field] = payload[field]

            # Add metadata, excluding the reserved keys stored as top-level fields
            entry["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})

            data.append(entry)
        self.index.load(data, id_field="memory_id")

    def search(self, query: list, limit: int = 5, filters: dict = None):
        conditions = [Tag(key) == value for key, value in filters.items() if value is not None]
        filter = reduce(lambda x, y: x & y, conditions)

        v = VectorQuery(
            vector=np.array(query, dtype=np.float32).tobytes(),
            vector_field_name="embedding",
            return_fields=["memory_id", "hash", "agent_id", "run_id", "user_id", "memory", "metadata", "created_at"],
            filter_expression=filter,
            num_results=limit,
        )

        results = self.index.query(v)

        return [
            MemoryResult(
                id=result["memory_id"],
                score=result["vector_distance"],
                payload={
                    "hash": result["hash"],
                    "data": result["memory"],
                    "created_at": datetime.fromtimestamp(
                        int(result["created_at"]), tz=pytz.timezone("US/Pacific")
                    ).isoformat(timespec="microseconds"),
                    **(
                        {
                            "updated_at": datetime.fromtimestamp(
                                int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
                            ).isoformat(timespec="microseconds")
                        }
                        if "updated_at" in result
                        else {}
                    ),
                    **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
                    **{k: v for k, v in json.loads(result["metadata"]).items()},
                },
            )
            for result in results
        ]

    def delete(self, vector_id):
        self.index.drop_keys(f"{self.schema['index']['prefix']}:{vector_id}")

    def update(self, vector_id=None, vector=None, payload=None):
        data = {
            "memory_id": vector_id,
            "hash": payload["hash"],
            "memory": payload["data"],
            "created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
            "updated_at": int(datetime.fromisoformat(payload["updated_at"]).timestamp()),
            "embedding": np.array(vector, dtype=np.float32).tobytes(),
        }

        for field in ["agent_id", "run_id", "user_id"]:
            if field in payload:
                data[field] = payload[field]

        data["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
        self.index.load(data=[data], keys=[f"{self.schema['index']['prefix']}:{vector_id}"], id_field="memory_id")

    def get(self, vector_id):
        result = self.index.fetch(vector_id)
        payload = {
            "hash": result["hash"],
            "data": result["memory"],
            "created_at": datetime.fromtimestamp(int(result["created_at"]), tz=pytz.timezone("US/Pacific")).isoformat(
                timespec="microseconds"
            ),
            **(
                {
                    "updated_at": datetime.fromtimestamp(
                        int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
                    ).isoformat(timespec="microseconds")
                }
                if "updated_at" in result
                else {}
            ),
            **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
            **{k: v for k, v in json.loads(result["metadata"]).items()},
        }

        return MemoryResult(id=result["memory_id"], payload=payload)

    def list_cols(self):
        return self.index.listall()

    def delete_col(self):
        self.index.delete()

    def col_info(self, name):
        return self.index.info()

    def list(self, filters: dict = None, limit: int = None) -> list:
        """
        List the most recently created memories from the vector store.
        """
        conditions = [Tag(key) == value for key, value in filters.items() if value is not None]
        filter = reduce(lambda x, y: x & y, conditions)
        query = Query(str(filter)).sort_by("created_at", asc=False)
        if limit is not None:
            query = query.paging(0, limit)

        results = self.index.search(query)
        return [
            [
                MemoryResult(
                    id=result["memory_id"],
                    payload={
                        "hash": result["hash"],
                        "data": result["memory"],
                        "created_at": datetime.fromtimestamp(
                            int(result["created_at"]), tz=pytz.timezone("US/Pacific")
                        ).isoformat(timespec="microseconds"),
                        **(
                            {
                                "updated_at": datetime.fromtimestamp(
                                    int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
                                ).isoformat(timespec="microseconds")
                            }
                            if result.__dict__.get("updated_at")
                            else {}
                        ),
                        **{
                            field: result[field]
                            for field in ["agent_id", "run_id", "user_id"]
                            if field in result.__dict__
                        },
                        **{k: v for k, v in json.loads(result["metadata"]).items()},
                    },
                )
                for result in results.docs
            ]
        ]
```
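
Putting the pieces together, the store can also be exercised directly, without the higher-level `Memory` API. The sketch below uses a toy 4-dimensional embedding rather than a real model, assumes a local Redis Stack, and mirrors the payload keys `insert` expects:

```python
from datetime import datetime, timezone

# Toy example: 4-dim vectors instead of a real embedding model
store = RedisDB(redis_url="redis://localhost:6379", collection_name="demo", embedding_model_dims=4)

payload = {
    "hash": "abc123",
    "data": "Likes to play cricket on weekends",
    "user_id": "alice",
    "created_at": datetime.now(timezone.utc).isoformat(),
    "category": "hobbies",  # anything outside excluded_keys lands in the metadata field
}
store.insert(vectors=[[0.1, 0.2, 0.3, 0.4]], payloads=[payload], ids=["mem-1"])

hits = store.search(query=[0.1, 0.2, 0.3, 0.4], limit=1, filters={"user_id": "alice"})
for hit in hits:
    print(hit.id, hit.score, hit.payload["data"])
```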