This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Connect to index in KB init #38

Merged 21 commits on Sep 11, 2023

4 changes: 2 additions & 2 deletions .github/workflows/CI.yml
@@ -31,7 +31,7 @@ jobs:
     needs: run-linters
     strategy:
       matrix:
-        python-version: [3.8, 3.9, '3.10', 3.11]
+        python-version: [3.9, '3.10', 3.11]
 
     steps:
       - uses: actions/checkout@v3
@@ -53,7 +53,7 @@ jobs:
           PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY }}
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
           PINECONE_ENVIRONMENT: eu-west1-gcp
-        run: poetry run pytest -n 3 --dist loadgroup --html=report_system.html --self-contained-html tests/system
+        run: poetry run pytest -n 3 --dist loadscope --html=report_system.html --self-contained-html tests/system
       - name: upload pytest report.html
         uses: actions/upload-artifact@v3
         if: always()
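Note on the run command: pytest-xdist's --dist loadgroup groups tests by an explicit xdist_group mark, whereas --dist loadscope sends all tests from the same module (or class) to the same worker, so module-scoped fixtures are set up once per worker. A minimal sketch of a test module that benefits from loadscope scheduling (hypothetical, not from this repo):

import pytest

@pytest.fixture(scope="module")
def shared_index():
    # Expensive per-module setup (e.g., provisioning a test index).
    # Under --dist loadscope, every test in this module runs on the
    # same xdist worker, so this fixture is built only once there.
    return {"name": "context-engine-test"}

def test_upsert(shared_index):
    assert shared_index["name"].startswith("context-engine-")

def test_query(shared_index):
    assert shared_index["name"].startswith("context-engine-")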
274 changes: 157 additions & 117 deletions context_engine/knoweldge_base/knowledge_base.py
@@ -1,8 +1,16 @@
-from copy import deepcopy
+from datetime import datetime
+import time
 from typing import List, Optional
+from copy import deepcopy
 import pandas as pd
-import pinecone
+from pinecone import list_indexes, delete_index, create_index, init \
+    as pinecone_init, whoami as pinecone_whoami
 
+try:
+    from pinecone import GRPCIndex as Index
+except ImportError:
+    from pinecone import Index
+
 from pinecone_datasets import Dataset, DatasetMetadata, DenseModelMetadata
 
 from context_engine.knoweldge_base.base import BaseKnowledgeBase
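The import block prefers the gRPC client when available: pinecone-client exposes GRPCIndex only when its grpc extras are installed, and aliasing it to Index lets the rest of the module stay transport-agnostic, falling back to the HTTP-based Index otherwise. A generic sketch of the same optional-dependency pattern (illustrative, not from this PR):

# Prefer a faster drop-in implementation when its extra package is
# installed; otherwise fall back to the baseline one.
try:
    import ujson as json  # optional accelerated JSON library
except ImportError:
    import json  # standard-library fallback

def dumps(obj: dict) -> str:
    # Callers never need to know which implementation was imported.
    return json.dumps(obj)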
@@ -16,6 +24,9 @@
 
 
 INDEX_NAME_PREFIX = "context-engine-"
+TIMEOUT_INDEX_CREATE = 300
+TIMEOUT_INDEX_PROVISION = 30
+INDEX_PROVISION_TIME_INTERVAL = 3
 
 
 class KnowledgeBase(BaseKnowledgeBase):
@@ -25,7 +36,7 @@ class KnowledgeBase(BaseKnowledgeBase):
     DEFAULT_RERANKER = TransparentReranker
 
     def __init__(self,
-                 index_name_suffix: str,
+                 index_name: str,
                  *,
                  encoder: Optional[RecordEncoder] = None,
                  chunker: Optional[Chunker] = None,
@@ -35,129 +46,164 @@ def __init__(self,
         if default_top_k < 1:
             raise ValueError("default_top_k must be greater than 0")
 
-        if index_name_suffix.startswith(INDEX_NAME_PREFIX):
-            index_name = index_name_suffix
-        else:
-            index_name = INDEX_NAME_PREFIX + index_name_suffix
-
-        self._index_name = index_name
+        self._index_name = self._get_full_index_name(index_name)
         self._default_top_k = default_top_k
         self._encoder = encoder if encoder is not None else self.DEFAULT_RECORD_ENCODER()  # noqa: E501
         self._chunker = chunker if chunker is not None else self.DEFAULT_CHUNKER()
         self._reranker = reranker if reranker is not None else self.DEFAULT_RERANKER()
 
-        self._index: Optional[pinecone.Index] = None
-
-    def connect(self, force: bool = False):
-        if self._index is None or force:
-
-            try:
-                pinecone.init()
-                pinecone.whoami()
-            except Exception as e:
-                raise RuntimeError("Failed to connect to Pinecone. "
-                                   "Please check your credentials") from e
+        self._index: Optional[Index] = self._connect_index(self._index_name)
+
+    @staticmethod
+    def _connect_pinecone():
+        try:
+            pinecone_init()
+            pinecone_whoami()
+        except Exception as e:
+            raise RuntimeError("Failed to connect to Pinecone. "
+                               "Please check your credentials and try again") from e
+
+    @classmethod
+    def _connect_index(cls,
+                       full_index_name: str,
+                       connect_pinecone: bool = True
+                       ) -> Index:
+        if connect_pinecone:
+            cls._connect_pinecone()
+
+        if full_index_name not in list_indexes():
+            raise RuntimeError(
+                f"Index {full_index_name} does not exist. "
+                "Please create it first using `create_with_new_index()`"
+            )
 
-        try:
-            self._index = pinecone.Index(index_name=self._index_name)
-            self._index.describe_index_stats()
-        except Exception as e:
-            raise RuntimeError(
-                f"Failed to connect to the index {self._index_name}. "
-                "Please check your credentials and index name"
-            ) from e
-
-    def create_index(self,
-                     dimension: Optional[int] = None,
-                     indexed_fields: List[str] = ['document_id'],
-                     **kwargs
-                     ):
-        """
-        Create a new Pinecone index that will be used to store the documents
-        Args:
-            dimension (Optional[int]): The dimension of the vectors to be indexed.
-                                       The knowledge base will try to infer it from the
-                                       encoder if not provided.
-            indexed_fields (List[str]): The fields that will be indexed and can be used
-                                        for metadata filtering.
-                                        Defaults to ['document_id'].
-                                        The 'text' field cannot be used for filtering.
-            **kwargs: Any additional arguments will be passed to the
-                      `pinecone.create_index()` function.
-            Keyword Args:
-                index_type: type of index, one of {"approximated", "exact"}, defaults to
-                            "approximated".
-                metric (str, optional): type of metric used in the vector index, one of
-                    {"cosine", "dotproduct", "euclidean"}, defaults to "cosine".
-                    - Use "cosine" for cosine similarity,
-                    - "dotproduct" for dot-product,
-                    - and "euclidean" for euclidean distance.
-                replicas (int, optional): the number of replicas, defaults to 1.
-                    - Use at least 2 replicas if you need high availability (99.99%
-                      uptime) for querying.
-                    - For additional throughput (QPS) your index needs to support,
-                      provision additional replicas.
-                shards (int, optional): the number of shards per index, defaults to 1.
-                    - Use 1 shard per 1GB of vectors.
-                pods (int, optional): Total number of pods to be used by the index.
-                                      pods = shard*replicas.
-                pod_type (str, optional): the pod type to be used for the index.
-                                          can be one of p1 or s1.
-                index_config: Advanced configuration options for the index.
-                metadata_config (dict, optional): Configuration related to the metadata
-                                                  index.
-                source_collection (str, optional): Collection name to create the index from.
-                timeout (int, optional): Timeout for wait until index gets ready.
-                    If None, wait indefinitely; if >=0, time out after this many seconds;
-                    if -1, return immediately and do not wait. Default: None.
-        Returns:
-            None
-        """
-
-        if len(indexed_fields) == 0:
-            raise ValueError("Indexed_fields must contain at least one field")
+        try:
+            index = Index(index_name=full_index_name)
+            index.describe_index_stats()
+        except Exception as e:
+            raise RuntimeError(
+                f"Unexpected error while connecting to index {full_index_name}. "
+                f"Please check your credentials and try again."
+            ) from e
+        return index
+
+    @classmethod
+    def create_with_new_index(cls,
+                              index_name: str,
+                              *,
+                              encoder: RecordEncoder,
+                              chunker: Chunker,
+                              reranker: Optional[Reranker] = None,
+                              default_top_k: int = 10,
+                              indexed_fields: Optional[List[str]] = None,
+                              dimension: Optional[int] = None,
+                              create_index_params: Optional[dict] = None
+                              ) -> 'KnowledgeBase':
+
+        # validate inputs
+        if indexed_fields is None:
+            indexed_fields = ['document_id']
+        elif "document_id" not in indexed_fields:
+            indexed_fields.append('document_id')
+
         if 'text' in indexed_fields:
             raise ValueError("The 'text' field cannot be used for metadata filtering. "
                              "Please remove it from indexed_fields")
 
-        if self._index is not None:
-            raise RuntimeError("Index already exists")
-
         if dimension is None:
-            if self._encoder.dimension is not None:
-                dimension = self._encoder.dimension
+            if encoder.dimension is not None:
+                dimension = encoder.dimension
             else:
                 raise ValueError("Could not infer dimension from encoder. "
                                  "Please provide the vectors' dimension")
 
-        pinecone.init()
-        pinecone.create_index(name=self._index_name,
-                              dimension=dimension,
-                              metadata_config={
-                                  'indexed': indexed_fields
-                              },
-                              **kwargs)
-        self.connect()
+        # connect to pinecone and create index
+        cls._connect_pinecone()
 
-    def delete_index(self):
-        if self._index_name not in pinecone.list_indexes():
+        full_index_name = cls._get_full_index_name(index_name)
+
+        if full_index_name in list_indexes():
             raise RuntimeError(
-                "Index does not exist.")
-        pinecone.delete_index(self._index_name)
+                f"Index {full_index_name} already exists. "
+                "If you wish to delete it, use `delete_index()`. "
+                "If you wish to connect to it, "
+                "directly initialize a `KnowledgeBase` instance"
+            )
+
+        # create index
+        create_index_params = create_index_params or {}
+        try:
+            create_index(name=full_index_name,
+                         dimension=dimension,
+                         metadata_config={
+                             'indexed': indexed_fields
+                         },
+                         timeout=TIMEOUT_INDEX_CREATE,
+                         **create_index_params)
+        except Exception as e:
+            raise RuntimeError(
+                f"Unexpected error while creating index {full_index_name}. "
+                f"Please try again."
+            ) from e
+
+        # wait for index to be provisioned
+        cls._wait_for_index_provision(full_index_name=full_index_name)
+
+        # initialize KnowledgeBase
+        return cls(index_name=index_name,
+                   encoder=encoder,
+                   chunker=chunker,
+                   reranker=reranker,
+                   default_top_k=default_top_k)
+
+    @classmethod
+    def _wait_for_index_provision(cls,
+                                  full_index_name: str):
+        start_time = time.time()
+        while True:
+            try:
+                cls._connect_index(full_index_name,
+                                   connect_pinecone=False)
+                break
+            except RuntimeError:
+                pass
+
+            time_passed = time.time() - start_time
+            if time_passed > TIMEOUT_INDEX_PROVISION:
+                raise RuntimeError(
+                    f"Index {full_index_name} failed to provision "
+                    f"for {time_passed} seconds. "
+                    f"Please try creating KnowledgeBase again in a few minutes."
+                )
+            time.sleep(INDEX_PROVISION_TIME_INTERVAL)
+
+    @staticmethod
+    def _get_full_index_name(index_name: str) -> str:
+        if index_name.startswith(INDEX_NAME_PREFIX):
+            return index_name
+        else:
+            return INDEX_NAME_PREFIX + index_name
+
+    @property
+    def index_name(self) -> str:
+        return self._index_name
+
+    def delete_index(self):
+        self._validate_not_deleted()
+        delete_index(self._index_name)
+        self._index = None
+
+    def _validate_not_deleted(self):
+        if self._index is None:
+            raise RuntimeError(
+                "index was deleted. "
+                "Please create it first using `create_with_new_index()`"
+            )
 
     def query(self,
               queries: List[Query],
               global_metadata_filter: Optional[dict] = None
               ) -> List[QueryResult]:
 
-        if self._index is None:
-            raise RuntimeError(
-                "Index does not exist. Please call `connect()` first")
-
         queries: List[KBQuery] = self._encoder.encode_queries(queries)
 
         results: List[KBQueryResult] = [self._query_index(q, global_metadata_filter)
@@ -172,17 +218,16 @@ def query(self,
     def _query_index(self,
                      query: KBQuery,
                      global_metadata_filter: Optional[dict]) -> KBQueryResult:
-        if self._index is None:
-            raise RuntimeError(
-                "Index does not exist. Please call `connect()` first")
+        self._validate_not_deleted()
 
         metadata_filter = deepcopy(query.metadata_filter)
         if global_metadata_filter is not None:
             if metadata_filter is None:
                 metadata_filter = {}
             metadata_filter.update(global_metadata_filter)
         top_k = query.top_k if query.top_k else self._default_top_k
-        result = self._index.query(vector=query.values,
+
+        result = self._index.query(vector=query.values,  # type: ignore
                                    sparse_vector=query.sparse_values,
                                    top_k=top_k,
                                    namespace=query.namespace,
@@ -207,9 +252,8 @@ def upsert(self,
                documents: List[Document],
                namespace: str = "",
                batch_size: int = 100):
-        if self._index is None:
-            raise RuntimeError(
-                "Index does not exist. Please call `connect()` first")
+        self._validate_not_deleted()
+
         chunks = self._chunker.chunk_documents(documents)
         encoded_chunks = self._encoder.encode_documents(chunks)
 
@@ -246,9 +290,8 @@ def upsert_dataframe(self,
                          df: pd.DataFrame,
                          namespace: str = "",
                          batch_size: int = 100):
-        if self._index is None:
-            raise RuntimeError(
-                "Index does not exist. Please call `connect()` first")
+        self._validate_not_deleted()
+
         expected_columns = ["id", "text", "metadata"]
         if not all([c in df.columns for c in expected_columns]):
             raise ValueError(
Expand All @@ -262,11 +305,8 @@ def upsert_dataframe(self,
def delete(self,
document_ids: List[str],
namespace: str = "") -> None:
if self._index is None:
raise RuntimeError(
"Index does not exist. Please call `connect()` first")

self._index.delete(
self._validate_not_deleted()
self._index.delete( # type: ignore
filter={"document_id": {"$in": document_ids}},
namespace=namespace
)
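Taken together, the diff replaces the old connect() step: __init__ now connects eagerly via _connect_index(), and index creation moves to the create_with_new_index() classmethod. A minimal usage sketch under assumed conditions (PINECONE_API_KEY / PINECONE_ENVIRONMENT and encoder credentials set in the environment; my_encoder, my_chunker, documents, and queries are hypothetical placeholders):

# One-time setup: creates "context-engine-my-app" and waits for it
# to be provisioned before returning a connected KnowledgeBase.
kb = KnowledgeBase.create_with_new_index(index_name="my-app",
                                         encoder=my_encoder,
                                         chunker=my_chunker)

# Subsequent runs: __init__ connects to the existing index directly
# and raises a RuntimeError if it does not exist.
kb = KnowledgeBase(index_name="my-app")
kb.upsert(documents)              # documents: List[Document]
results = kb.query(queries)       # queries: List[Query]
kb.delete(document_ids=["doc1"])  # deletes by document_id filter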
1 change: 1 addition & 0 deletions pyproject.toml
@@ -31,6 +31,7 @@ asyncio = "^3.4.3"
 pytest-asyncio = "^0.14.0"
 pytest-mock = "^3.6.1"
 
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"