Skip to content

Commit

Permalink
feat/metadata
Browse files Browse the repository at this point in the history
allow to request metadata in queries
  • Loading branch information
JarbasAl committed Jul 25, 2024
1 parent 9fa2200 commit 9c02c06
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions ovos_plugin_manager/templates/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,18 +356,21 @@ def delete_document(self, document: str) -> None:
"""
self.db.delete_embeddings(document)

def query(self, document: str, top_k: int = 5) -> List[Tuple[str, float]]:
def query(self, document: str, top_k: int = 5,
return_metadata: bool = False) -> List[Tuple[str, float]]:
"""Query the database for the top_k closest embeddings to the document.
Args:
document (str): The document to query.
top_k (int, optional): The number of top results to return. Defaults to 5.
return_metadata (bool, optional): Whether to include metadata in the results. Defaults to False.
Returns:
List[Tuple[str, float]]: List of tuples containing the document and distance.
"""
embeddings = self.get_text_embeddings(document)
return self.db.query(embeddings, top_k)
return self.db.query(embeddings, top_k,
return_metadata=return_metadata)

def distance(self, text_a: str, text_b: str, metric: str = "cosine") -> float:
"""Calculate the distance between embeddings of two texts.
Expand Down Expand Up @@ -452,18 +455,21 @@ def predict(self, frame: EmbeddingsArray, top_k: int = 3, thresh: float = 0.15)
return None
return best_match[0]

def query(self, frame: EmbeddingsArray, top_k: int = 5) -> List[Tuple[str, float]]:
def query(self, frame: EmbeddingsArray, top_k: int = 5,
return_metadata: bool = False) -> List[Tuple[str, float]]:
"""Query the database for the top_k closest face embeddings to the frame.
Args:
frame (np.ndarray): The input image frame containing a face.
top_k (int, optional): The number of top results to return. Defaults to 5.
return_metadata (bool, optional): Whether to include metadata in the results. Defaults to False.
Returns:
List[Tuple[str, float]]: List of tuples containing the user ID and distance.
"""
embeddings = self.get_face_embeddings(frame)
return self.db.query(embeddings, top_k)
return self.db.query(embeddings, top_k,
return_metadata=return_metadata)

def distance(self, face_a: EmbeddingsArray, face_b: EmbeddingsArray, metric: str = "cosine") -> float:
"""Calculate the distance between embeddings of two faces.
Expand Down Expand Up @@ -564,18 +570,21 @@ def predict(self, audio_data: EmbeddingsArray, top_k: int = 3, thresh: float = 0
return None
return best_match[0]

def query(self, audio_data: EmbeddingsArray, top_k: int = 5) -> List[Tuple[str, float]]:
def query(self, audio_data: EmbeddingsArray, top_k: int = 5,
return_metadata: bool = False) -> List[Tuple[str, float]]:
"""Query the database for the top_k closest voice embeddings to the audio_data.
Args:
audio_data (np.ndarray): The input audio data.
top_k (int, optional): The number of top results to return. Defaults to 5.
return_metadata (bool, optional): Whether to include metadata in the results. Defaults to False.
Returns:
List[Tuple[str, float]]: List of tuples containing the user ID and distance.
"""
embeddings = self.get_voice_embeddings(audio_data)
return self.db.query(embeddings, top_k)
return self.db.query(embeddings, top_k,
return_metadata=return_metadata)

def distance(self, voice_a: EmbeddingsArray, voice_b: EmbeddingsArray, metric: str = "cosine") -> float:
"""Calculate the distance between embeddings of two voices.
Expand Down

0 comments on commit 9c02c06

Please sign in to comment.