Code formatting #1986

Merged · 3 commits · Oct 29, 2024
13 changes: 7 additions & 6 deletions cookbooks/helper/mem0_teachability.py
@@ -12,7 +12,7 @@
 from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent
 from termcolor import colored
 from mem0 import Memory
 from mem0.configs.base import MemoryConfig


 class Mem0Teachability(AgentCapability):
     def __init__(
@@ -60,7 +60,6 @@ def process_last_received_message(self, text: Union[Dict, str]):
         return expanded_text

     def _consider_memo_storage(self, comment: Union[Dict, str]):
-        memo_added = False
         response = self._analyze(
             comment,
             "Does any part of the TEXT ask the agent to perform a task or solve a problem? Answer with just one word, yes or no.",
@@ -85,8 +84,9 @@ def _consider_memo_storage(self, comment: Union[Dict, str]):

             if self.verbosity >= 1:
                 print(colored("\nREMEMBER THIS TASK-ADVICE PAIR", "light_yellow"))
-            self.memory.add([{"role": "user", "content": f"Task: {general_task}\nAdvice: {advice}"}], agent_id=self.agent_id)
-            memo_added = True
+            self.memory.add(
+                [{"role": "user", "content": f"Task: {general_task}\nAdvice: {advice}"}], agent_id=self.agent_id
+            )

         response = self._analyze(
             comment,
@@ -105,8 +105,9 @@ def _consider_memo_storage(self, comment: Union[Dict, str]):

             if self.verbosity >= 1:
                 print(colored("\nREMEMBER THIS QUESTION-ANSWER PAIR", "light_yellow"))
-            self.memory.add([{"role": "user", "content": f"Question: {question}\nAnswer: {answer}"}], agent_id=self.agent_id)
-            memo_added = True
+            self.memory.add(
+                [{"role": "user", "content": f"Question: {question}\nAnswer: {answer}"}], agent_id=self.agent_id
+            )

     def _consider_memo_retrieval(self, comment: Union[Dict, str]):
         if self.verbosity >= 1:
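Aside from line wrapping, the one functional tweak in this file is dropping the write-only `memo_added` flag: it was set after each `self.memory.add(...)` call but never read, so removing it changes nothing observable. A minimal sketch of the antipattern being removed (a hypothetical standalone example, not the method itself):

```python
# Hypothetical sketch of a write-only flag like the one this PR deletes.
def store(memory, agent_id: str, content: str) -> None:
    memo_added = False  # dead store: assigned but never read
    memory.add([{"role": "user", "content": content}], agent_id=agent_id)
    memo_added = True  # also dead; safe to delete, as this PR does

# After the cleanup, the body is just the memory.add(...) call.
```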
46 changes: 30 additions & 16 deletions mem0/client/main.py
@@ -117,7 +117,9 @@ def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]:
         payload = self._prepare_payload(messages, kwargs)
         response = self.client.post("/v1/memories/", json=payload)
         response.raise_for_status()
-        capture_client_event("client.add", self)
+        if "metadata" in kwargs:
+            del kwargs["metadata"]
+        capture_client_event("client.add", self, {"keys": list(kwargs.keys())})
         return response.json()

     @api_error_handler
@@ -135,7 +137,7 @@ def get(self, memory_id: str) -> Dict[str, Any]:
         """
         response = self.client.get(f"/v1/memories/{memory_id}/")
         response.raise_for_status()
-        capture_client_event("client.get", self)
+        capture_client_event("client.get", self, {"memory_id": memory_id})
         return response.json()

     @api_error_handler
@@ -159,10 +161,12 @@ def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
         elif version == "v2":
             response = self.client.post(f"/{version}/memories/", json=params)
         response.raise_for_status()
+        if "metadata" in kwargs:
+            del kwargs["metadata"]
         capture_client_event(
             "client.get_all",
             self,
-            {"filters": len(params), "limit": kwargs.get("limit", 100)},
+            {"api_version": version, "keys": list(kwargs.keys())},
         )
         return response.json()

@@ -186,7 +190,9 @@ def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
         payload.update({k: v for k, v in kwargs.items() if v is not None})
         response = self.client.post(f"/{version}/memories/search/", json=payload)
         response.raise_for_status()
-        capture_client_event("client.search", self, {"limit": kwargs.get("limit", 100)})
+        if "metadata" in kwargs:
+            del kwargs["metadata"]
+        capture_client_event("client.search", self, {"api_version": version, "keys": list(kwargs.keys())})
         return response.json()

     @api_error_handler
@@ -199,7 +205,7 @@ def update(self, memory_id: str, data: str) -> Dict[str, Any]:
         Returns:
             Dict[str, Any]: The response from the server.
         """
-        capture_client_event("client.update", self)
+        capture_client_event("client.update", self, {"memory_id": memory_id})
         response = self.client.put(f"/v1/memories/{memory_id}/", json={"text": data})
         response.raise_for_status()
         return response.json()
@@ -219,7 +225,7 @@ def delete(self, memory_id: str) -> Dict[str, Any]:
         """
         response = self.client.delete(f"/v1/memories/{memory_id}/")
         response.raise_for_status()
-        capture_client_event("client.delete", self)
+        capture_client_event("client.delete", self, {"memory_id": memory_id})
         return response.json()

     @api_error_handler
@@ -239,7 +245,7 @@ def delete_all(self, **kwargs) -> Dict[str, str]:
         params = self._prepare_params(kwargs)
         response = self.client.delete("/v1/memories/", params=params)
         response.raise_for_status()
-        capture_client_event("client.delete_all", self, {"params": len(params)})
+        capture_client_event("client.delete_all", self, {"keys": list(kwargs.keys())})
         return response.json()

     @api_error_handler
@@ -257,7 +263,7 @@ def history(self, memory_id: str) -> List[Dict[str, Any]]:
         """
         response = self.client.get(f"/v1/memories/{memory_id}/history/")
         response.raise_for_status()
-        capture_client_event("client.history", self)
+        capture_client_event("client.history", self, {"memory_id": memory_id})
         return response.json()

     @api_error_handler
@@ -390,14 +396,16 @@ async def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]:
         payload = self.sync_client._prepare_payload(messages, kwargs)
         response = await self.async_client.post("/v1/memories/", json=payload)
         response.raise_for_status()
-        capture_client_event("async_client.add", self.sync_client)
+        if "metadata" in kwargs:
+            del kwargs["metadata"]
+        capture_client_event("async_client.add", self.sync_client, {"keys": list(kwargs.keys())})
         return response.json()

     @api_error_handler
     async def get(self, memory_id: str) -> Dict[str, Any]:
         response = await self.async_client.get(f"/v1/memories/{memory_id}/")
         response.raise_for_status()
-        capture_client_event("async_client.get", self.sync_client)
+        capture_client_event("async_client.get", self.sync_client, {"memory_id": memory_id})
         return response.json()

     @api_error_handler
@@ -408,8 +416,10 @@ async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
         elif version == "v2":
             response = await self.async_client.post(f"/{version}/memories/", json=params)
         response.raise_for_status()
+        if "metadata" in kwargs:
+            del kwargs["metadata"]
         capture_client_event(
-            "async_client.get_all", self.sync_client, {"filters": len(params), "limit": kwargs.get("limit", 100)}
+            "async_client.get_all", self.sync_client, {"api_version": version, "keys": list(kwargs.keys())}
         )
         return response.json()

@@ -419,36 +429,40 @@ async def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
         payload.update(self.sync_client._prepare_params(kwargs))
         response = await self.async_client.post(f"/{version}/memories/search/", json=payload)
         response.raise_for_status()
-        capture_client_event("async_client.search", self.sync_client, {"limit": kwargs.get("limit", 100)})
+        if "metadata" in kwargs:
+            del kwargs["metadata"]
+        capture_client_event(
+            "async_client.search", self.sync_client, {"api_version": version, "keys": list(kwargs.keys())}
+        )
         return response.json()

     @api_error_handler
     async def update(self, memory_id: str, data: str) -> Dict[str, Any]:
         response = await self.async_client.put(f"/v1/memories/{memory_id}/", json={"text": data})
         response.raise_for_status()
-        capture_client_event("async_client.update", self.sync_client)
+        capture_client_event("async_client.update", self.sync_client, {"memory_id": memory_id})
         return response.json()

     @api_error_handler
     async def delete(self, memory_id: str) -> Dict[str, Any]:
         response = await self.async_client.delete(f"/v1/memories/{memory_id}/")
         response.raise_for_status()
-        capture_client_event("async_client.delete", self.sync_client)
+        capture_client_event("async_client.delete", self.sync_client, {"memory_id": memory_id})
         return response.json()

     @api_error_handler
     async def delete_all(self, **kwargs) -> Dict[str, str]:
         params = self.sync_client._prepare_params(kwargs)
         response = await self.async_client.delete("/v1/memories/", params=params)
         response.raise_for_status()
-        capture_client_event("async_client.delete_all", self.sync_client, {"params": len(params)})
+        capture_client_event("async_client.delete_all", self.sync_client, {"keys": list(kwargs.keys())})
         return response.json()

     @api_error_handler
     async def history(self, memory_id: str) -> List[Dict[str, Any]]:
         response = await self.async_client.get(f"/v1/memories/{memory_id}/history/")
         response.raise_for_status()
-        capture_client_event("async_client.history", self.sync_client)
+        capture_client_event("async_client.history", self.sync_client, {"memory_id": memory_id})
         return response.json()

     @api_error_handler
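Every method in this file now follows one telemetry pattern: strip `metadata` from `kwargs` so user-supplied values never leave the process, then report only the names of the remaining kwargs plus an identifying field such as `memory_id` or `api_version`. Note that `del kwargs["metadata"]` runs after the request payload is built, so the API call itself still receives the metadata; only the telemetry event omits it. A minimal sketch of the pattern, with `telemetry_payload` as a hypothetical helper name:

```python
from typing import Any, Dict

def telemetry_payload(api_version: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Mirror the event payload the updated client methods emit."""
    kwargs = dict(kwargs)  # copy; the real methods mutate kwargs after the request is sent
    kwargs.pop("metadata", None)  # key names of other kwargs are kept, values are not
    return {"api_version": api_version, "keys": list(kwargs.keys())}

assert telemetry_payload("v1", {"user_id": "alice", "metadata": {"x": 1}}) == {
    "api_version": "v1",
    "keys": ["user_id"],
}
```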
2 changes: 0 additions & 2 deletions mem0/configs/vector_stores/chroma.py
@@ -1,5 +1,3 @@
-import subprocess
-import sys
 from typing import Any, ClassVar, Dict, Optional

 from pydantic import BaseModel, Field, model_validator
29 changes: 16 additions & 13 deletions mem0/memory/main.py
@@ -37,11 +37,11 @@ def __init__(self, config: MemoryConfig = MemoryConfig()):
         self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
         self.db = SQLiteManager(self.config.history_db_path)
         self.collection_name = self.config.vector_store.config.collection_name
-        self.version = self.config.version
+        self.api_version = self.config.version

         self.enable_graph = False

-        if self.version == "v1.1" and self.config.graph_store.config:
+        if self.api_version == "v1.1" and self.config.graph_store.config:
             from mem0.memory.graph_memory import MemoryGraph

             self.graph = MemoryGraph(self.config)
@@ -119,7 +119,7 @@ def add(
             vector_store_result = future1.result()
             graph_result = future2.result()

-        if self.version == "v1.1":
+        if self.api_version == "v1.1":
             return {
                 "results": vector_store_result,
                 "relations": graph_result,
@@ -226,13 +226,13 @@ def _add_to_vector_store(self, messages, metadata, filters):
         except Exception as e:
             logging.error(f"Error in new_memories_with_actions: {e}")

-        capture_event("mem0.add", self)
+        capture_event("mem0.add", self, {"version": self.api_version, "keys": list(filters.keys())})

         return returned_memories

     def _add_to_graph(self, messages, filters):
         added_entities = []
-        if self.version == "v1.1" and self.enable_graph:
+        if self.api_version == "v1.1" and self.enable_graph:
             if filters["user_id"]:
                 self.graph.user_id = filters["user_id"]
             elif filters["agent_id"]:
@@ -305,13 +305,13 @@ def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
         if run_id:
             filters["run_id"] = run_id

-        capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit})
+        capture_event("mem0.get_all", self, {"limit": limit, "keys": list(filters.keys())})

         with concurrent.futures.ThreadPoolExecutor() as executor:
             future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
             future_graph_entities = (
                 executor.submit(self.graph.get_all, filters, limit)
-                if self.version == "v1.1" and self.enable_graph
+                if self.api_version == "v1.1" and self.enable_graph
                 else None
             )

@@ -322,7 +322,7 @@ def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
         all_memories = future_memories.result()
         graph_entities = future_graph_entities.result() if future_graph_entities else None

-        if self.version == "v1.1":
+        if self.api_version == "v1.1":
             if self.enable_graph:
                 return {"results": all_memories, "relations": graph_entities}
             else:
@@ -398,14 +398,14 @@ def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None):
         capture_event(
             "mem0.search",
             self,
-            {"filters": len(filters), "limit": limit, "version": self.version},
+            {"limit": limit, "version": self.api_version, "keys": list(filters.keys())},
         )

         with concurrent.futures.ThreadPoolExecutor() as executor:
             future_memories = executor.submit(self._search_vector_store, query, filters, limit)
             future_graph_entities = (
                 executor.submit(self.graph.search, query, filters, limit)
-                if self.version == "v1.1" and self.enable_graph
+                if self.api_version == "v1.1" and self.enable_graph
                 else None
             )

@@ -416,7 +416,7 @@ def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None):
         original_memories = future_memories.result()
         graph_entities = future_graph_entities.result() if future_graph_entities else None

-        if self.version == "v1.1":
+        if self.api_version == "v1.1":
             if self.enable_graph:
                 return {"results": original_memories, "relations": graph_entities}
             else:
@@ -518,14 +518,14 @@ def delete_all(self, user_id=None, agent_id=None, run_id=None):
                 "At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method."
             )

-        capture_event("mem0.delete_all", self, {"filters": len(filters)})
+        capture_event("mem0.delete_all", self, {"keys": list(filters.keys())})
         memories = self.vector_store.list(filters=filters)[0]
         for memory in memories:
             self._delete_memory(memory.id)

         logger.info(f"Deleted {len(memories)} memories")

-        if self.version == "v1.1" and self.enable_graph:
+        if self.api_version == "v1.1" and self.enable_graph:
             self.graph.delete_all(filters)

         return {"message": "Memories deleted successfully!"}
@@ -561,6 +561,7 @@ def _create_memory(self, data, existing_embeddings, metadata=None):
             payloads=[metadata],
         )
         self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"])
+        capture_event("mem0._create_memory", self, {"memory_id": memory_id})
         return memory_id

     def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
@@ -603,6 +604,7 @@ def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
             created_at=new_metadata["created_at"],
             updated_at=new_metadata["updated_at"],
         )
+        capture_event("mem0._update_memory", self, {"memory_id": memory_id})
         return memory_id

     def _delete_memory(self, memory_id):
@@ -611,6 +613,7 @@ def _delete_memory(self, memory_id):
         prev_value = existing_memory.payload["data"]
         self.vector_store.delete(vector_id=memory_id)
         self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)
+        capture_event("mem0._delete_memory", self, {"memory_id": memory_id})
         return memory_id

     def reset(self):
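Two threads run through this file: the attribute rename `self.version` to `self.api_version` (every `"v1.1"` branch follows it, and it is a breaking change for any caller that read `Memory(...).version`), and new `capture_event` calls in `_create_memory`, `_update_memory`, and `_delete_memory` that record the affected `memory_id`. A sketch of the rename's surface effect, assuming `MemoryConfig` still exposes the field under its old name `version`:

```python
from mem0 import Memory
from mem0.configs.base import MemoryConfig

# Illustrative only: constructing Memory wires up an embedder, LLM, and vector
# store, so running this needs a working provider config (e.g. API keys).
memory = Memory(MemoryConfig(version="v1.1"))

# The config field keeps its name; only the instance attribute moved.
assert memory.api_version == "v1.1"  # before this PR: memory.version
```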
2 changes: 1 addition & 1 deletion mem0/memory/telemetry.py
@@ -67,7 +67,7 @@ def capture_event(event_name, memory_instance, additional_data=None):
         "vector_store": f"{memory_instance.vector_store.__class__.__module__}.{memory_instance.vector_store.__class__.__name__}",
         "llm": f"{memory_instance.llm.__class__.__module__}.{memory_instance.llm.__class__.__name__}",
         "embedding_model": f"{memory_instance.embedding_model.__class__.__module__}.{memory_instance.embedding_model.__class__.__name__}",
-        "function": f"{memory_instance.__class__.__module__}.{memory_instance.__class__.__name__}.{memory_instance.version}",
+        "function": f"{memory_instance.__class__.__module__}.{memory_instance.__class__.__name__}.{memory_instance.api_version}",
     }
     if additional_data:
         event_data.update(additional_data)
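The `"function"` field is the only line touched here: it embeds the instance's version in a dotted identifier, so it has to follow the rename. The caller-supplied `additional_data` dicts (the `{"keys": ...}` and `{"memory_id": ...}` payloads added elsewhere in this PR) are merged over the base fields. A minimal sketch of that merge, keeping only the renamed field from the hunk above:

```python
def build_event_data(memory_instance, additional_data=None):
    # Trimmed to the renamed field; the real helper also records the
    # vector_store, llm, and embedding_model classes (see the hunk above).
    cls = memory_instance.__class__
    event_data = {
        "function": f"{cls.__module__}.{cls.__name__}.{memory_instance.api_version}",
    }
    if additional_data:
        event_data.update(additional_data)  # e.g. {"keys": ["user_id"]}
    return event_data
```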
10 changes: 3 additions & 7 deletions tests/embeddings/test_azure_openai_embeddings.py
@@ -33,11 +33,7 @@ def test_embed_text(mock_openai_client):

 @pytest.mark.parametrize(
     "default_headers, expected_header",
-    [
-        (None, None),
-        ({"Test": "test_value"}, "test_value"),
-        ({}, None)
-    ],
+    [(None, None), ({"Test": "test_value"}, "test_value"), ({}, None)],
 )
 def test_embed_text_with_default_headers(default_headers, expected_header):
     config = BaseEmbedderConfig(
@@ -47,8 +43,8 @@ def test_embed_text_with_default_headers(default_headers, expected_header):
             "api_version": "test_version",
             "azure_endpoint": "test_endpoint",
             "azuer_deployment": "test_deployment",
-            "default_headers": default_headers
-        }
+            "default_headers": default_headers,
+        },
     )
     embedder = AzureOpenAIEmbedding(config)
     assert embedder.client.api_key == "test"
15 changes: 3 additions & 12 deletions tests/embeddings/test_gemini.py
@@ -12,17 +12,11 @@ def mock_genai():

 @pytest.fixture
 def config():
-    return BaseEmbedderConfig(
-        api_key="dummy_api_key",
-        model="test_model"
-    )
+    return BaseEmbedderConfig(api_key="dummy_api_key", model="test_model")


 def test_embed_query(mock_genai, config):
-
-    mock_embedding_response = {
-        'embedding': [0.1, 0.2, 0.3, 0.4]
-    }
+    mock_embedding_response = {"embedding": [0.1, 0.2, 0.3, 0.4]}
     mock_genai.return_value = mock_embedding_response

     embedder = GoogleGenAIEmbedding(config)
@@ -31,7 +25,4 @@ def test_embed_query(mock_genai, config):
     embedding = embedder.embed(text)

     assert embedding == [0.1, 0.2, 0.3, 0.4]
-    mock_genai.assert_called_once_with(
-        model="test_model",
-        content="Hello, world!"
-    )
+    mock_genai.assert_called_once_with(model="test_model", content="Hello, world!")
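Both test files receive pure Black-style reformatting: short literals collapse onto one line, multi-line calls gain trailing commas, and single quotes become double quotes; no assertion logic changes. A tiny sketch showing the two forms are equivalent:

```python
# The reformatting is behavior-preserving: both literals build the same dict.
before = {
    'embedding': [0.1, 0.2, 0.3, 0.4]
}  # pre-PR style
after = {"embedding": [0.1, 0.2, 0.3, 0.4]}  # Black-style one-liner
assert before == after
```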