Skip to content
This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Connect to index in KB init #38

Merged
merged 21 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 118 additions & 112 deletions context_engine/knoweldge_base/knowledge_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from copy import deepcopy
from datetime import datetime
from typing import List, Optional
from copy import deepcopy
import pandas as pd
import pinecone
from pinecone import list_indexes, delete_index, create_index, init \
as pinecone_init, whoami as pinecone_whoami

try:
from pinecone import GRPCIndex as Index
except ImportError:
from pinecone import Index

from pinecone_datasets import Dataset, DatasetMetadata, DenseModelMetadata

from context_engine.knoweldge_base.base import BaseKnowledgeBase
Expand All @@ -25,7 +32,7 @@ class KnowledgeBase(BaseKnowledgeBase):
DEFAULT_RERANKER = TransparentReranker

def __init__(self,
index_name_suffix: str,
index_name: str,
*,
encoder: Optional[RecordEncoder] = None,
chunker: Optional[Chunker] = None,
Expand All @@ -35,129 +42,133 @@ def __init__(self,
if default_top_k < 1:
raise ValueError("default_top_k must be greater than 0")

if index_name_suffix.startswith(INDEX_NAME_PREFIX):
index_name = index_name_suffix
else:
index_name = INDEX_NAME_PREFIX + index_name_suffix

self._index_name = index_name
self._index_name = self._get_full_index_name(index_name)
acatav marked this conversation as resolved.
Show resolved Hide resolved
self._default_top_k = default_top_k
self._encoder = encoder if encoder is not None else self.DEFAULT_RECORD_ENCODER() # noqa: E501
self._chunker = chunker if chunker is not None else self.DEFAULT_CHUNKER()
self._reranker = reranker if reranker is not None else self.DEFAULT_RERANKER()

self._index: Optional[pinecone.Index] = None

def connect(self, force: bool = False):
if self._index is None or force:

try:
pinecone.init()
pinecone.whoami()
except Exception as e:
raise RuntimeError("Failed to connect to Pinecone. "
"Please check your credentials") from e

try:
self._index = pinecone.Index(index_name=self._index_name)
self._index.describe_index_stats()
except Exception as e:
raise RuntimeError(
f"Failed to connect to the index {self._index_name}. "
"Please check your credentials and index name"
) from e

def create_index(self,
dimension: Optional[int] = None,
indexed_fields: List[str] = ['document_id'],
**kwargs
):
"""
Create a new Pinecone index that will be used to store the documents
Args:
dimension (Optional[int]): The dimension of the vectors to be indexed.
The knowledge base will try to infer it from the
encoder if not provided.
indexed_fields (List[str]): The fields that will be indexed and can be used
for metadata filtering.
Defaults to ['document_id'].
The 'text' field cannot be used for filtering.
**kwargs: Any additional arguments will be passed to the
`pinecone.create_index()` function.

Keyword Args:
index_type: type of index, one of {"approximated", "exact"}, defaults to
"approximated".
metric (str, optional): type of metric used in the vector index, one of
{"cosine", "dotproduct", "euclidean"}, defaults to "cosine".
- Use "cosine" for cosine similarity,
- "dotproduct" for dot-product,
- and "euclidean" for euclidean distance.
replicas (int, optional): the number of replicas, defaults to 1.
- Use at least 2 replicas if you need high availability (99.99%
uptime) for querying.
- For additional throughput (QPS) your index needs to support,
provision additional replicas.
shards (int, optional): the number of shards per index, defaults to 1.
- Use 1 shard per 1GB of vectors.
pods (int, optional): Total number of pods to be used by the index.
pods = shard*replicas.
pod_type (str, optional): the pod type to be used for the index.
can be one of p1 or s1.
index_config: Advanced configuration options for the index.
metadata_config (dict, optional): Configuration related to the metadata
index.
source_collection (str, optional): Collection name to create the index from.
timeout (int, optional): Timeout for wait until index gets ready.
If None, wait indefinitely; if >=0, time out after this many seconds;
if -1, return immediately and do not wait. Default: None.

Returns:
None
"""

if len(indexed_fields) == 0:
raise ValueError("Indexed_fields must contain at least one field")
self._index: Optional[Index] = self._connect_index()

@staticmethod
def _connect_pinecone():
    """Initialize the Pinecone client and verify the credentials work.

    Raises:
        RuntimeError: If initialization or the `whoami` round-trip fails.
    """
    try:
        pinecone_init()
        pinecone_whoami()
    except Exception as e:
        msg = "Failed to connect to Pinecone. Please check your credentials"
        raise RuntimeError(msg) from e

def _connect_index(self) -> Index:
    """Connect to the Pinecone index named ``self._index_name``.

    Returns:
        Index: A live handle to the Pinecone index.

    Raises:
        RuntimeError: If the index does not exist, or if connecting to it
            (or fetching its stats) fails.
    """
    self._connect_pinecone()

    if self._index_name not in list_indexes():
        raise RuntimeError(
            f"Index {self._index_name} does not exist. "
            # Fixed: the two literals below were concatenated without a
            # separating space ("...create_with_new_index()`or use...").
            "Please create it first using `create_with_new_index()` "
            "or use the `ce create` command line"
        )

    try:
        index = Index(index_name=self._index_name)
        # Cheap round-trip to verify the connection actually works.
        index.describe_index_stats()
    except Exception as e:
        raise RuntimeError(
            # Fixed: missing space after the period between the two
            # concatenated message parts; second part needs no f-prefix.
            f"Unexpected error while connecting to index {self._index_name}. "
            "Please check your credentials and try again."
        ) from e
    return index

@staticmethod
def create_with_new_index(index_name: str,
                          *,
                          encoder: RecordEncoder,
                          chunker: Chunker,
                          reranker: Optional[Reranker] = None,
                          default_top_k: int = 10,
                          indexed_fields: Optional[List[str]] = None,
                          dimension: Optional[int] = None,
                          **kwargs) -> 'KnowledgeBase':
    """Create a new Pinecone index and return a KnowledgeBase connected to it.

    Args:
        index_name: Name of the index (the standard prefix is added if
            missing).
        encoder: Record encoder; also used to infer the vector dimension
            when ``dimension`` is not given.
        chunker: Document chunker for the new KnowledgeBase.
        reranker: Optional reranker; the KnowledgeBase default is used
            when None.
        default_top_k: Default number of results per query.
        indexed_fields: Metadata fields to index for filtering.
            'document_id' is always included; 'text' is not allowed.
        dimension: Vector dimension; inferred from ``encoder`` when None.
        **kwargs: Passed through to ``pinecone.create_index()``.

    Returns:
        KnowledgeBase: A new instance connected to the created index.

    Raises:
        ValueError: If 'text' is in ``indexed_fields`` or the dimension
            cannot be inferred.
        RuntimeError: If the index already exists, creation fails, or the
            index is still provisioning.
    """
    # Work on a copy so the caller's list argument is never mutated
    # (the original appended 'document_id' to the caller's list in place).
    if indexed_fields is None:
        indexed_fields = ['document_id']
    else:
        indexed_fields = list(indexed_fields)
        if 'document_id' not in indexed_fields:
            indexed_fields.append('document_id')

    if 'text' in indexed_fields:
        raise ValueError("The 'text' field cannot be used for metadata filtering. "
                         "Please remove it from indexed_fields")

    full_index_name = KnowledgeBase._get_full_index_name(index_name)

    KnowledgeBase._connect_pinecone()

    if full_index_name in list_indexes():
        raise RuntimeError(
            f"Index {full_index_name} already exists. "
            "If you wish to delete it, use `delete_index()`. "
            # Fixed: missing space between the two concatenated literals.
            "If you wish to connect to it, "
            "directly initialize a `KnowledgeBase` instance"
        )

    if dimension is None:
        if encoder.dimension is not None:
            dimension = encoder.dimension
        else:
            raise ValueError("Could not infer dimension from encoder. "
                             "Please provide the vectors' dimension")

    try:
        create_index(name=full_index_name,
                     dimension=dimension,
                     metadata_config={
                         'indexed': indexed_fields
                     },
                     **kwargs)
    except Exception as e:
        raise RuntimeError(
            # Fixed: missing space after the period between message parts.
            f"Unexpected error while creating index {full_index_name}. "
            "Please try again."
        ) from e

    if full_index_name not in list_indexes():
        raise RuntimeError(
            # Fixed: sentences were concatenated without separating spaces.
            f"Index {full_index_name} is probably still provisioning. "
            "Please try creating KnowledgeBase again in a few minutes. "
            "Or simply run `context-engine create` command line."
        )

    return KnowledgeBase(index_name=index_name,
                         encoder=encoder,
                         chunker=chunker,
                         reranker=reranker,
                         default_top_k=default_top_k)

pinecone.init()
pinecone.create_index(name=self._index_name,
dimension=dimension,
metadata_config={
'indexed': indexed_fields
},
**kwargs)
self.connect()
@staticmethod
def _get_full_index_name(index_name: str) -> str:
    """Return *index_name* with the standard prefix, adding it only if absent."""
    prefix = INDEX_NAME_PREFIX
    return index_name if index_name.startswith(prefix) else prefix + index_name

def delete_index(self):
    """Delete the underlying Pinecone index and clear the local handle.

    Raises:
        RuntimeError: If the index does not exist.
    """
    # NOTE(review): the source span contained both the old
    # (`pinecone.list_indexes`) and new (`list_indexes`) diff sides fused
    # together; this is the reconstructed post-change version.
    if self._index_name not in list_indexes():
        raise RuntimeError(
            "Index does not exist.")
    # Bare `delete_index` resolves to the module-level pinecone function,
    # not this method (which is only reachable via the instance).
    delete_index(self._index_name)
    self._index = None

def _validate_not_deleted(self):
    """Raise if the index handle was cleared by a prior `delete_index()` call.

    Raises:
        RuntimeError: If ``self._index`` is None.
    """
    if self._index is None:
        raise RuntimeError(
            "index was deleted. "
            # Fixed: the two literals below were concatenated without a
            # separating space ("...create_with_new_index()`or use...").
            "Please create it first using `create_with_new_index()` "
            "or use the `context-engine create` command line"
        )

def query(self,
queries: List[Query],
global_metadata_filter: Optional[dict] = None
) -> List[QueryResult]:

if self._index is None:
raise RuntimeError(
"Index does not exist. Please call `connect()` first")

queries: List[KBQuery] = self._encoder.encode_queries(queries)

results: List[KBQueryResult] = [self._query_index(q, global_metadata_filter)
Expand All @@ -172,6 +183,7 @@ def query(self,
def _query_index(self,
query: KBQuery,
global_metadata_filter: Optional[dict]) -> KBQueryResult:
self._validate_not_deleted()
if self._index is None:
raise RuntimeError(
"Index does not exist. Please call `connect()` first")
Expand All @@ -182,7 +194,8 @@ def _query_index(self,
metadata_filter = {}
metadata_filter.update(global_metadata_filter)
top_k = query.top_k if query.top_k else self._default_top_k
result = self._index.query(vector=query.values,

result = self._index.query(vector=query.values, # type: ignore
sparse_vector=query.sparse_values,
top_k=top_k,
namespace=query.namespace,
Expand All @@ -207,9 +220,8 @@ def upsert(self,
documents: List[Document],
namespace: str = "",
batch_size: int = 100):
if self._index is None:
raise RuntimeError(
"Index does not exist. Please call `connect()` first")
self._validate_not_deleted()

chunks = self._chunker.chunk_documents(documents)
encoded_chunks = self._encoder.encode_documents(chunks)

Expand Down Expand Up @@ -246,9 +258,6 @@ def upsert_dataframe(self,
df: pd.DataFrame,
namespace: str = "",
batch_size: int = 100):
if self._index is None:
raise RuntimeError(
"Index does not exist. Please call `connect()` first")
expected_columns = ["id", "text", "metadata"]
if not all([c in df.columns for c in expected_columns]):
raise ValueError(
Expand All @@ -262,11 +271,8 @@ def upsert_dataframe(self,
def delete(self,
document_ids: List[str],
namespace: str = "") -> None:
if self._index is None:
raise RuntimeError(
"Index does not exist. Please call `connect()` first")

self._index.delete(
self._validate_not_deleted()
self._index.delete( # type: ignore
filter={"document_id": {"$in": document_ids}},
namespace=namespace
)
Expand Down
Loading