Code formatting #1986

Merged: 3 commits merged on Oct 29, 2024

Changes from 2 commits
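
Summary of the diff below: an automated formatting pass (reflowed long calls, trailing commas, double quotes) plus small cleanups. Unused imports and a dead local flag are removed, Memory.version is renamed to Memory.api_version, and telemetry events now report the names of caller-supplied keys (never their values) along with the API version, plus new events for memory create/update/delete.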
13 changes: 7 additions & 6 deletions cookbooks/helper/mem0_teachability.py
@@ -12,7 +12,7 @@
 from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent
 from termcolor import colored
 from mem0 import Memory
-from mem0.configs.base import MemoryConfig
+

 class Mem0Teachability(AgentCapability):
     def __init__(
@@ -60,7 +60,6 @@ def process_last_received_message(self, text: Union[Dict, str]):
         return expanded_text

     def _consider_memo_storage(self, comment: Union[Dict, str]):
-        memo_added = False
         response = self._analyze(
             comment,
             "Does any part of the TEXT ask the agent to perform a task or solve a problem? Answer with just one word, yes or no.",
@@ -85,8 +84,9 @@ def _consider_memo_storage(self, comment: Union[Dict, str]):

             if self.verbosity >= 1:
                 print(colored("\nREMEMBER THIS TASK-ADVICE PAIR", "light_yellow"))
-            self.memory.add([{"role": "user", "content": f"Task: {general_task}\nAdvice: {advice}"}], agent_id=self.agent_id)
-            memo_added = True
+            self.memory.add(
+                [{"role": "user", "content": f"Task: {general_task}\nAdvice: {advice}"}], agent_id=self.agent_id
+            )

         response = self._analyze(
             comment,
@@ -105,8 +105,9 @@ def _consider_memo_storage(self, comment: Union[Dict, str]):

             if self.verbosity >= 1:
                 print(colored("\nREMEMBER THIS QUESTION-ANSWER PAIR", "light_yellow"))
-            self.memory.add([{"role": "user", "content": f"Question: {question}\nAnswer: {answer}"}], agent_id=self.agent_id)
-            memo_added = True
+            self.memory.add(
+                [{"role": "user", "content": f"Question: {question}\nAnswer: {answer}"}], agent_id=self.agent_id
+            )

     def _consider_memo_retrieval(self, comment: Union[Dict, str]):
         if self.verbosity >= 1:
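
Note: the deleted memo_added flag was assigned but never read afterwards, so removing it changes no behavior, and the wrapped self.memory.add(...) calls are whitespace-only reflows. A minimal sketch of the call the cookbook makes, not from the PR, assuming a default local Memory configuration and a hypothetical agent id:

from mem0 import Memory

memory = Memory()  # assumes a default local configuration is available
general_task = "Summarize a report"         # hypothetical task text
advice = "Keep summaries under five lines"  # hypothetical advice text

# The wrapped form introduced by this PR builds exactly the same arguments
# as the old one-liner; only the line breaks changed.
memory.add(
    [{"role": "user", "content": f"Task: {general_task}\nAdvice: {advice}"}],
    agent_id="teachable_agent",  # hypothetical id; the cookbook passes self.agent_id
)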
18 changes: 10 additions & 8 deletions mem0/client/main.py
@@ -117,7 +117,7 @@ def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]:
         payload = self._prepare_payload(messages, kwargs)
         response = self.client.post("/v1/memories/", json=payload)
         response.raise_for_status()
-        capture_client_event("client.add", self)
+        capture_client_event("client.add", self, {"keys": list(kwargs.keys())})
         return response.json()

     @api_error_handler
@@ -162,7 +162,7 @@ def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
         capture_client_event(
             "client.get_all",
             self,
-            {"filters": len(params), "limit": kwargs.get("limit", 100)},
+            {"api_version": version, "keys": list(kwargs.keys())},
         )
         return response.json()

@@ -186,7 +186,7 @@ def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
         payload.update({k: v for k, v in kwargs.items() if v is not None})
         response = self.client.post(f"/{version}/memories/search/", json=payload)
         response.raise_for_status()
-        capture_client_event("client.search", self, {"limit": kwargs.get("limit", 100)})
+        capture_client_event("client.search", self, {"api_version": version, "keys": list(kwargs.keys())})
         return response.json()

     @api_error_handler
@@ -239,7 +239,7 @@ def delete_all(self, **kwargs) -> Dict[str, str]:
         params = self._prepare_params(kwargs)
         response = self.client.delete("/v1/memories/", params=params)
         response.raise_for_status()
-        capture_client_event("client.delete_all", self, {"params": len(params)})
+        capture_client_event("client.delete_all", self, {"keys": list(kwargs.keys())})
         return response.json()

     @api_error_handler
@@ -390,7 +390,7 @@ async def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]:
         payload = self.sync_client._prepare_payload(messages, kwargs)
         response = await self.async_client.post("/v1/memories/", json=payload)
         response.raise_for_status()
-        capture_client_event("async_client.add", self.sync_client)
+        capture_client_event("async_client.add", self.sync_client, {"keys": list(kwargs.keys())})
         return response.json()

     @api_error_handler
@@ -409,7 +409,7 @@ async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
         response = await self.async_client.post(f"/{version}/memories/", json=params)
         response.raise_for_status()
         capture_client_event(
-            "async_client.get_all", self.sync_client, {"filters": len(params), "limit": kwargs.get("limit", 100)}
+            "async_client.get_all", self.sync_client, {"api_version": version, "keys": list(kwargs.keys())}
         )
         return response.json()

@@ -419,7 +419,9 @@ async def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
         payload.update(self.sync_client._prepare_params(kwargs))
         response = await self.async_client.post(f"/{version}/memories/search/", json=payload)
         response.raise_for_status()
-        capture_client_event("async_client.search", self.sync_client, {"limit": kwargs.get("limit", 100)})
+        capture_client_event(
+            "async_client.search", self.sync_client, {"api_version": version, "keys": list(kwargs.keys())}
+        )
         return response.json()

     @api_error_handler
@@ -441,7 +443,7 @@ async def delete_all(self, **kwargs) -> Dict[str, str]:
         params = self.sync_client._prepare_params(kwargs)
         response = await self.async_client.delete("/v1/memories/", params=params)
         response.raise_for_status()
-        capture_client_event("async_client.delete_all", self.sync_client, {"params": len(params)})
+        capture_client_event("async_client.delete_all", self.sync_client, {"keys": list(kwargs.keys())})
         return response.json()

     @api_error_handler
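
Note: every change in this file follows one pattern, replacing aggregate counts such as {"filters": len(params)} or {"limit": kwargs.get("limit", 100)} with the list of keyword-argument names plus the API version. A rough sketch of the payload difference, with illustrative values only (key names are reported, never their values):

# Hypothetical caller arguments to MemoryClient.get_all / search / delete_all:
kwargs = {"user_id": "alice", "limit": 50}

old_payload = {"filters": 1, "limit": kwargs.get("limit", 100)}   # counts only
new_payload = {"api_version": "v1", "keys": list(kwargs.keys())}
# new_payload == {"api_version": "v1", "keys": ["user_id", "limit"]}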
2 changes: 0 additions & 2 deletions mem0/configs/vector_stores/chroma.py
@@ -1,5 +1,3 @@
-import subprocess
-import sys
 from typing import Any, ClassVar, Dict, Optional

 from pydantic import BaseModel, Field, model_validator
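
Note: subprocess and sys appear to be unused in this module (nothing else in the file changes), so dropping the two imports should be a no-op at runtime, a standard unused-import cleanup.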
29 changes: 16 additions & 13 deletions mem0/memory/main.py
@@ -37,11 +37,11 @@ def __init__(self, config: MemoryConfig = MemoryConfig()):
         self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
         self.db = SQLiteManager(self.config.history_db_path)
         self.collection_name = self.config.vector_store.config.collection_name
-        self.version = self.config.version
+        self.api_version = self.config.version

         self.enable_graph = False

-        if self.version == "v1.1" and self.config.graph_store.config:
+        if self.api_version == "v1.1" and self.config.graph_store.config:
             from mem0.memory.graph_memory import MemoryGraph

             self.graph = MemoryGraph(self.config)
@@ -119,7 +119,7 @@ def add(
             vector_store_result = future1.result()
             graph_result = future2.result()

-        if self.version == "v1.1":
+        if self.api_version == "v1.1":
             return {
                 "results": vector_store_result,
                 "relations": graph_result,
@@ -226,13 +226,13 @@ def _add_to_vector_store(self, messages, metadata, filters):
             except Exception as e:
                 logging.error(f"Error in new_memories_with_actions: {e}")

-        capture_event("mem0.add", self)
+        capture_event("mem0.add", self, {"version": self.api_version, "keys": list(filters.keys())})

         return returned_memories

     def _add_to_graph(self, messages, filters):
         added_entities = []
-        if self.version == "v1.1" and self.enable_graph:
+        if self.api_version == "v1.1" and self.enable_graph:
             if filters["user_id"]:
                 self.graph.user_id = filters["user_id"]
             elif filters["agent_id"]:
@@ -305,13 +305,13 @@ def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
         if run_id:
             filters["run_id"] = run_id

-        capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit})
+        capture_event("mem0.get_all", self, {"limit": limit, "keys": list(filters.keys())})

         with concurrent.futures.ThreadPoolExecutor() as executor:
             future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
             future_graph_entities = (
                 executor.submit(self.graph.get_all, filters, limit)
-                if self.version == "v1.1" and self.enable_graph
+                if self.api_version == "v1.1" and self.enable_graph
                 else None
             )

@@ -322,7 +322,7 @@ def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
         all_memories = future_memories.result()
         graph_entities = future_graph_entities.result() if future_graph_entities else None

-        if self.version == "v1.1":
+        if self.api_version == "v1.1":
             if self.enable_graph:
                 return {"results": all_memories, "relations": graph_entities}
             else:
@@ -398,14 +398,14 @@ def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None):
         capture_event(
             "mem0.search",
             self,
-            {"filters": len(filters), "limit": limit, "version": self.version},
+            {"limit": limit, "version": self.api_version, "keys": list(filters.keys())},
         )

         with concurrent.futures.ThreadPoolExecutor() as executor:
             future_memories = executor.submit(self._search_vector_store, query, filters, limit)
             future_graph_entities = (
                 executor.submit(self.graph.search, query, filters, limit)
-                if self.version == "v1.1" and self.enable_graph
+                if self.api_version == "v1.1" and self.enable_graph
                 else None
             )

@@ -416,7 +416,7 @@ def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None):
         original_memories = future_memories.result()
         graph_entities = future_graph_entities.result() if future_graph_entities else None

-        if self.version == "v1.1":
+        if self.api_version == "v1.1":
             if self.enable_graph:
                 return {"results": original_memories, "relations": graph_entities}
             else:
@@ -518,14 +518,14 @@ def delete_all(self, user_id=None, agent_id=None, run_id=None):
                 "At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method."
             )

-        capture_event("mem0.delete_all", self, {"filters": len(filters)})
+        capture_event("mem0.delete_all", self, {"keys": list(filters.keys())})
         memories = self.vector_store.list(filters=filters)[0]
         for memory in memories:
             self._delete_memory(memory.id)

         logger.info(f"Deleted {len(memories)} memories")

-        if self.version == "v1.1" and self.enable_graph:
+        if self.api_version == "v1.1" and self.enable_graph:
             self.graph.delete_all(filters)

         return {"message": "Memories deleted successfully!"}
@@ -561,6 +561,7 @@ def _create_memory(self, data, existing_embeddings, metadata=None):
             payloads=[metadata],
         )
         self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"])
+        capture_event("mem0._create_memory", self, {"memory_id": memory_id})
         return memory_id

     def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
@@ -603,6 +604,7 @@ def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
             created_at=new_metadata["created_at"],
             updated_at=new_metadata["updated_at"],
         )
+        capture_event("mem0._update_memory", self, {"memory_id": memory_id})
         return memory_id

     def _delete_memory(self, memory_id):
@@ -611,6 +613,7 @@ def _delete_memory(self, memory_id):
         prev_value = existing_memory.payload["data"]
         self.vector_store.delete(vector_id=memory_id)
         self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)
+        capture_event("mem0._delete_memory", self, {"memory_id": memory_id})
         return memory_id

     def reset(self):
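
Note: the rename touches only the instance attribute; the constructor still reads the same config field (self.api_version = self.config.version), and the three new telemetry events report only the affected memory_id. A hedged sketch of the visible effect, assuming a default MemoryConfig and a working default setup:

from mem0 import Memory
from mem0.configs.base import MemoryConfig

memory = Memory(MemoryConfig())
print(memory.api_version)  # mirrors config.version; the old memory.version attribute is gone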
2 changes: 1 addition & 1 deletion mem0/memory/telemetry.py
@@ -67,7 +67,7 @@ def capture_event(event_name, memory_instance, additional_data=None):
         "vector_store": f"{memory_instance.vector_store.__class__.__module__}.{memory_instance.vector_store.__class__.__name__}",
         "llm": f"{memory_instance.llm.__class__.__module__}.{memory_instance.llm.__class__.__name__}",
         "embedding_model": f"{memory_instance.embedding_model.__class__.__module__}.{memory_instance.embedding_model.__class__.__name__}",
-        "function": f"{memory_instance.__class__.__module__}.{memory_instance.__class__.__name__}.{memory_instance.version}",
+        "function": f"{memory_instance.__class__.__module__}.{memory_instance.__class__.__name__}.{memory_instance.api_version}",
     }
     if additional_data:
         event_data.update(additional_data)
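
For reference, capture_event assembles a base event_data dict (vector store, LLM, embedding model, and the function tag shown above) and then merges additional_data over it, which is how the new {"keys": [...]} and {"memory_id": ...} fields flow through. A simplified sketch of that merge, not the library's exact code:

def build_event_data(memory_instance, additional_data=None):
    # Base fields derived from the Memory instance (abridged).
    event_data = {
        "function": f"{memory_instance.__class__.__module__}."
        f"{memory_instance.__class__.__name__}.{memory_instance.api_version}",
    }
    if additional_data:
        # e.g. {"keys": [...]} from add(), or {"memory_id": ...} from _create_memory()
        event_data.update(additional_data)
    return event_data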
10 changes: 3 additions & 7 deletions tests/embeddings/test_azure_openai_embeddings.py
@@ -33,11 +33,7 @@ def test_embed_text(mock_openai_client):

 @pytest.mark.parametrize(
     "default_headers, expected_header",
-    [
-        (None, None),
-        ({"Test": "test_value"}, "test_value"),
-        ({}, None)
-    ],
+    [(None, None), ({"Test": "test_value"}, "test_value"), ({}, None)],
 )
 def test_embed_text_with_default_headers(default_headers, expected_header):
     config = BaseEmbedderConfig(
@@ -47,8 +43,8 @@ def test_embed_text_with_default_headers(default_headers, expected_header):
             "api_version": "test_version",
             "azure_endpoint": "test_endpoint",
             "azuer_deployment": "test_deployment",
-            "default_headers": default_headers
-        }
+            "default_headers": default_headers,
+        },
     )
     embedder = AzureOpenAIEmbedding(config)
     assert embedder.client.api_key == "test"
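
Note: the pre-existing azuer_deployment misspelling in the test config is left as-is, consistent with a formatting-only pass; the one-lined parametrize cases and the added trailing commas do not change what the test exercises.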
15 changes: 3 additions & 12 deletions tests/embeddings/test_gemini.py
@@ -12,17 +12,11 @@ def mock_genai():

 @pytest.fixture
 def config():
-    return BaseEmbedderConfig(
-        api_key="dummy_api_key",
-        model="test_model"
-    )
+    return BaseEmbedderConfig(api_key="dummy_api_key", model="test_model")


 def test_embed_query(mock_genai, config):
-
-    mock_embedding_response = {
-        'embedding': [0.1, 0.2, 0.3, 0.4]
-    }
+    mock_embedding_response = {"embedding": [0.1, 0.2, 0.3, 0.4]}
     mock_genai.return_value = mock_embedding_response

     embedder = GoogleGenAIEmbedding(config)
@@ -31,7 +25,4 @@ def test_embed_query(mock_genai, config):
     embedding = embedder.embed(text)

     assert embedding == [0.1, 0.2, 0.3, 0.4]
-    mock_genai.assert_called_once_with(
-        model="test_model",
-        content="Hello, world!"
-    )
+    mock_genai.assert_called_once_with(model="test_model", content="Hello, world!")
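
Note: switching 'embedding' to "embedding" and collapsing the multi-line calls is purely stylistic; Python string literals are identical under either quote style, so the assertions are unaffected.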
31 changes: 8 additions & 23 deletions tests/embeddings/test_vertexai_embeddings.py
@@ -1,7 +1,6 @@
 import pytest
 from unittest.mock import Mock, patch
 from mem0.embeddings.vertexai import VertexAIEmbedding
-from mem0.configs.embeddings.base import BaseEmbedderConfig


 @pytest.fixture
@@ -35,15 +34,11 @@ def test_embed_default_model(mock_text_embedding_model, mock_os_environ, mock_config):
     embedder = VertexAIEmbedding(config)

     mock_embedding = Mock(values=[0.1, 0.2, 0.3])
-    mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [
-        mock_embedding
-    ]
+    mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [mock_embedding]

-    result = embedder.embed("Hello world")
+    embedder.embed("Hello world")

-    mock_text_embedding_model.from_pretrained.assert_called_once_with(
-        "text-embedding-004"
-    )
+    mock_text_embedding_model.from_pretrained.assert_called_once_with("text-embedding-004")
     mock_text_embedding_model.from_pretrained.return_value.get_embeddings.assert_called_once_with(
         texts=["Hello world"], output_dimensionality=256
     )
@@ -60,15 +55,11 @@ def test_embed_custom_model(mock_text_embedding_model, mock_os_environ, mock_config):
     embedder = VertexAIEmbedding(config)

     mock_embedding = Mock(values=[0.4, 0.5, 0.6])
-    mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [
-        mock_embedding
-    ]
+    mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [mock_embedding]

     result = embedder.embed("Test embedding")

-    mock_text_embedding_model.from_pretrained.assert_called_with(
-        "custom-embedding-model"
-    )
+    mock_text_embedding_model.from_pretrained.assert_called_with("custom-embedding-model")
     mock_text_embedding_model.from_pretrained.return_value.get_embeddings.assert_called_once_with(
         texts=["Test embedding"], output_dimensionality=512
     )
@@ -93,26 +84,20 @@ def test_missing_credentials(mock_os, mock_text_embedding_model, mock_config):

     config = mock_config()

-    with pytest.raises(
-        ValueError, match="Google application credentials JSON is not provided"
-    ):
+    with pytest.raises(ValueError, match="Google application credentials JSON is not provided"):
         VertexAIEmbedding(config)


 @patch("mem0.embeddings.vertexai.TextEmbeddingModel")
-def test_embed_with_different_dimensions(
-    mock_text_embedding_model, mock_os_environ, mock_config
-):
+def test_embed_with_different_dimensions(mock_text_embedding_model, mock_os_environ, mock_config):
     mock_config.vertex_credentials_json = "/path/to/credentials.json"
     mock_config.return_value.embedding_dims = 1024

     config = mock_config()
     embedder = VertexAIEmbedding(config)

     mock_embedding = Mock(values=[0.1] * 1024)
-    mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [
-        mock_embedding
-    ]
+    mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [mock_embedding]

     result = embedder.embed("Large embedding test")

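
Note: the one behavioral-looking tweak here is result = embedder.embed("Hello world") becoming a bare embedder.embed("Hello world") in test_embed_default_model; the variable was never asserted on, so this likely just silences an unused-variable lint (e.g. ruff/flake8 F841). The other tests keep their result variables, so the cleanup is partial. A tiny self-contained illustration of the pattern, assuming nothing beyond unittest.mock:

from unittest.mock import Mock

embed = Mock(return_value=[0.1, 0.2, 0.3])
embed("Hello world")  # return value intentionally discarded (avoids an F841-style warning)
embed.assert_called_once_with("Hello world")  # behavior is verified via the mock instead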