
Remove create_with_new_index() #57

Merged: 15 commits, Oct 15, 2023
2 changes: 1 addition & 1 deletion resin/knoweldge_base/base.py
@@ -29,7 +29,7 @@ def delete(self,
pass

@abstractmethod
def verify_connection_health(self) -> None:
def verify_index_connection(self) -> None:
pass

@abstractmethod
142 changes: 56 additions & 86 deletions resin/knoweldge_base/knowledge_base.py
@@ -25,11 +25,6 @@
from resin.models.data_models import Query, Document


INDEX_DELETED_MESSAGE = (
"index was deleted. "
"Please create it first using `create_with_new_index()`"
)

INDEX_NAME_PREFIX = "resin--"
TIMEOUT_INDEX_CREATE = 300
TIMEOUT_INDEX_PROVISION = 30
@@ -42,7 +37,6 @@


class KnowledgeBase(BaseKnowledgeBase):

DEFAULT_RECORD_ENCODER = OpenAIRecordEncoder
DEFAULT_CHUNKER = MarkdownChunker
DEFAULT_RERANKER = TransparentReranker
@@ -54,6 +48,7 @@ def __init__(self,
chunker: Optional[Chunker] = None,
reranker: Optional[Reranker] = None,
default_top_k: int = 5,
index_params: Optional[dict] = None,
Contributor: Can't we take these params only in the create method? Accepting them here gives the illusion that you pass index params and get a KB built with them, when in fact they are only used by the create method.

Contributor (Author): We will explain it in the docstring. It has to be here so that you can do this:

kb = KnowledgeBase.from_config(cfg)
kb.create_index()

Contributor: Got it, not ideal but good enough for now. I think it means that we want the defaults to be in the config.yaml.
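For context, a minimal sketch of the flow this thread describes, based on the constructor and create_resin_index() added in this diff; the index name and the contents of index_params are hypothetical, and from_config() is the reviewer's shorthand rather than part of this PR:

from resin.knoweldge_base import KnowledgeBase

# index_params is only stored at construction time; it is consumed later,
# when create_resin_index() calls pinecone's create_index().
kb = KnowledgeBase(
    index_name="my-index",                # hypothetical name
    index_params={"metric": "cosine"},    # hypothetical create_index kwargs
)
kb.create_resin_index()                   # uses the stored index_params unless overridden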

):
if default_top_k < 1:
raise ValueError("default_top_k must be greater than 0")
@@ -64,7 +59,8 @@ def __init__(self,
self._chunker = chunker if chunker is not None else self.DEFAULT_CHUNKER()
self._reranker = reranker if reranker is not None else self.DEFAULT_RERANKER()

self._index: Optional[Index] = self._connect_index(self._index_name)
self._index: Optional[Index] = None
self._index_params = index_params

@staticmethod
def _connect_pinecone():
@@ -75,67 +71,56 @@ def _connect_pinecone():
raise RuntimeError("Failed to connect to Pinecone. "
"Please check your credentials and try again") from e

@classmethod
def _connect_index(cls,
full_index_name: str,
def _connect_index(self,
connect_pinecone: bool = True
) -> Index:
if connect_pinecone:
cls._connect_pinecone()
self._connect_pinecone()

if full_index_name not in list_indexes():
if self.index_name not in list_indexes():
raise RuntimeError(
f"Index {full_index_name} does not exist. "
"Please create it first using `create_with_new_index()`"
f"The index {self.index_name} does not exist or was deleted. "
"Please create it by calling knowledge_base.create_resin_index() or "
"running the `resin new` command"
)

try:
index = Index(index_name=full_index_name)
index.describe_index_stats()
index = Index(index_name=self.index_name)
except Exception as e:
raise RuntimeError(
f"Unexpected error while connecting to index {full_index_name}. "
f"Unexpected error while connecting to index {self.index_name}. "
f"Please check your credentials and try again."
) from e
return index

def verify_connection_health(self) -> None:
@property
def _connection_error_msg(self) -> str:
return (
f"KnowledgeBase is not connected to index {self.index_name}, "
f"Please call knowledge_base.connect(). "
)

def connect(self) -> None:
if self._index is None:
raise RuntimeError(INDEX_DELETED_MESSAGE)
self._index = self._connect_index()
self.verify_index_connection()

def verify_index_connection(self) -> None:
if self._index is None:
raise RuntimeError(self._connection_error_msg)

try:
self._index.describe_index_stats()
except Exception as e:
try:
pinecone_whoami()
except Exception:
raise RuntimeError(
"Failed to connect to Pinecone. "
"Please check your credentials and try again"
) from e

if self._index_name not in list_indexes():
raise RuntimeError(
f"index {self._index_name} does not exist anymore"
"and was probably deleted. "
"Please create it first using `create_with_new_index()`"
) from e
raise RuntimeError("Index unexpectedly did not respond. "
"Please try again in few moments") from e

@classmethod
def create_with_new_index(cls,
index_name: str,
*,
record_encoder: Optional[RecordEncoder] = None,
chunker: Optional[Chunker] = None,
reranker: Optional[Reranker] = None,
default_top_k: int = 10,
indexed_fields: Optional[List[str]] = None,
dimension: Optional[int] = None,
create_index_params: Optional[dict] = None
) -> 'KnowledgeBase':
raise RuntimeError(
"The index did not respond. Please check your credentials and try again"
) from e

def create_resin_index(self,
indexed_fields: Optional[List[str]] = None,
dimension: Optional[int] = None,
index_params: Optional[dict] = None
):
# validate inputs
if indexed_fields is None:
indexed_fields = ['document_id']
@@ -147,68 +132,53 @@ def create_with_new_index(cls,
"Please remove it from indexed_fields")

if dimension is None:
record_encoder = record_encoder if record_encoder is not None else cls.DEFAULT_RECORD_ENCODER() # noqa: E501
if record_encoder.dimension is not None:
dimension = record_encoder.dimension
if self._encoder.dimension is not None:
dimension = self._encoder.dimension
else:
raise ValueError("Could not infer dimension from encoder. "
"Please provide the vectors' dimension")

# connect to pinecone and create index
cls._connect_pinecone()
self._connect_pinecone()

full_index_name = cls._get_full_index_name(index_name)

if full_index_name in list_indexes():
if self.index_name in list_indexes():
raise RuntimeError(
f"Index {full_index_name} already exists. "
f"Index {self.index_name} already exists. "
"If you wish to delete it, use `delete_index()`. "
"If you wish to connect to it,"
"directly initialize a `KnowledgeBase` instance"
)

# create index
create_index_params = create_index_params or {}
index_params = index_params or self._index_params or {}
try:
create_index(name=full_index_name,
create_index(name=self.index_name,
dimension=dimension,
metadata_config={
'indexed': indexed_fields
},
timeout=TIMEOUT_INDEX_CREATE,
**create_index_params)
**index_params)
except Exception as e:
raise RuntimeError(
f"Unexpected error while creating index {full_index_name}."
f"Unexpected error while creating index {self.index_name}."
f"Please try again."
) from e

# wait for index to be provisioned
cls._wait_for_index_provision(full_index_name=full_index_name)

# initialize KnowledgeBase
return cls(index_name=index_name,
record_encoder=record_encoder,
chunker=chunker,
reranker=reranker,
default_top_k=default_top_k)

@classmethod
def _wait_for_index_provision(cls,
full_index_name: str):
self._wait_for_index_provision()

def _wait_for_index_provision(self):
start_time = time.time()
while True:
try:
cls._connect_index(full_index_name,
connect_pinecone=False)
self._index = self._connect_index(connect_pinecone=False)
break
except RuntimeError:
pass

time_passed = time.time() - start_time
if time_passed > TIMEOUT_INDEX_PROVISION:
raise RuntimeError(
f"Index {full_index_name} failed to provision "
f"Index {self.index_name} failed to provision "
f"for {time_passed} seconds."
f"Please try creating KnowledgeBase again in a few minutes."
)
@@ -234,19 +204,19 @@ def index_name(self) -> str:

def delete_index(self):
if self._index is None:
raise RuntimeError(INDEX_DELETED_MESSAGE)
raise RuntimeError(self._connection_error_msg)
delete_index(self._index_name)
self._index = None

def query(self,
queries: List[Query],
global_metadata_filter: Optional[dict] = None
) -> List[QueryResult]:
queries: List[KBQuery] = self._encoder.encode_queries(queries)

results: List[KBQueryResult] = [self._query_index(q, global_metadata_filter)
for q in queries]
if self._index is None:
raise RuntimeError(self._connection_error_msg)

queries = self._encoder.encode_queries(queries)
results = [self._query_index(q, global_metadata_filter) for q in queries]
results = self._reranker.rerank(results)

return [
@@ -267,7 +237,7 @@ def _query_index(self,
query: KBQuery,
global_metadata_filter: Optional[dict]) -> KBQueryResult:
if self._index is None:
raise RuntimeError(INDEX_DELETED_MESSAGE)
raise RuntimeError(self._connection_error_msg)

metadata_filter = deepcopy(query.metadata_filter)
if global_metadata_filter is not None:
@@ -303,7 +273,7 @@ def upsert(self,
namespace: str = "",
batch_size: int = 100):
if self._index is None:
raise RuntimeError(INDEX_DELETED_MESSAGE)
raise RuntimeError(self._connection_error_msg)

for doc in documents:
metadata_keys = set(doc.metadata.keys())
@@ -353,7 +323,7 @@ def upsert_dataframe(self,
namespace: str = "",
batch_size: int = 100):
if self._index is None:
raise RuntimeError(INDEX_DELETED_MESSAGE)
raise RuntimeError(self._connection_error_msg)

documents = self._df_to_documents(df)

@@ -363,7 +333,7 @@ def delete(self,
document_ids: List[str],
namespace: str = "") -> None:
if self._index is None:
raise RuntimeError(INDEX_DELETED_MESSAGE)
raise RuntimeError(self._connection_error_msg)

if self._is_starter_env():
for i in range(0, len(document_ids), DELETE_STARTER_BATCH_SIZE):
5 changes: 3 additions & 2 deletions resin_cli/app.py
@@ -136,7 +136,7 @@ async def delete(
)
async def health_check():
try:
await run_in_threadpool(kb.verify_connection_health)
await run_in_threadpool(kb.verify_index_connection)
except Exception as e:
err_msg = f"Failed connecting to Pinecone Index {kb._index_name}"
logger.exception(err_msg)
@@ -192,9 +192,10 @@ def _init_engines():
kb = KnowledgeBase(index_name=INDEX_NAME)
context_engine = ContextEngine(knowledge_base=kb)
llm = OpenAILLM()

chat_engine = ChatEngine(context_engine=context_engine, llm=llm)

kb.connect()


def start(host="0.0.0.0", port=8000, reload=False):
uvicorn.run("resin_cli.app:app",
19 changes: 12 additions & 7 deletions resin_cli/cli.py
@@ -11,7 +11,6 @@

from resin.knoweldge_base import KnowledgeBase
from resin.models.data_models import Document
from resin.knoweldge_base.knowledge_base import INDEX_NAME_PREFIX
from resin.tokenizer import OpenAITokenizer, Tokenizer
from resin_cli.data_loader import (
load_dataframe_from_path,
@@ -100,14 +99,13 @@ def health(host, port, ssl):
@click.argument("index-name", nargs=1, envvar="INDEX_NAME", type=str, required=True)
@click.option("--tokenizer-model", default="gpt-3.5-turbo", help="Tokenizer model")
def new(index_name, tokenizer_model):
kb = KnowledgeBase(index_name=index_name)
click.echo("Resin is going to create a new index: ", nl=False)
click.echo(click.style(f"{INDEX_NAME_PREFIX}{index_name}", fg="green"))
click.echo(click.style(f"{kb.index_name}", fg="green"))
click.confirm(click.style("Do you want to continue?", fg="red"), abort=True)
Tokenizer.initialize(OpenAITokenizer, tokenizer_model)
with spinner:
_ = KnowledgeBase.create_with_new_index(
index_name=index_name
)
kb.create_resin_index()
click.echo(click.style("Success!", fg="green"))
os.environ["INDEX_NAME"] = index_name

@@ -132,12 +130,19 @@ def upsert(index_name, data_path, tokenizer_model):
+ " please provide it with --data-path or set it with env var"
click.echo(click.style(msg, fg="red"), err=True)
sys.exit(1)

kb = KnowledgeBase(index_name=index_name)
try:
kb.connect()
except RuntimeError as e:
click.echo(click.style(str(e), fg="red"), err=True)
sys.exit(1)

click.echo("Resin is going to upsert data from ", nl=False)
click.echo(click.style(f'{data_path}', fg='yellow'), nl=False)
click.echo(" to index: ")
click.echo(click.style(f'{INDEX_NAME_PREFIX}{index_name} \n', fg='green'))
click.echo(click.style(f'{kb.index_name} \n', fg='green'))
with spinner:
kb = KnowledgeBase(index_name=index_name)
try:
data = load_dataframe_from_path(data_path)
except IndexNotUniqueError:
5 changes: 3 additions & 2 deletions tests/e2e/test_app.py
@@ -49,9 +49,10 @@ def index_name(testrun_uid):
@pytest.fixture(scope="module", autouse=True)
def knowledge_base(index_name):
pinecone.init()
KnowledgeBase.create_with_new_index(index_name=index_name,)
kb = KnowledgeBase(index_name=index_name)
kb.create_resin_index()

return KnowledgeBase(index_name=index_name)
return kb


@pytest.fixture(scope="module")
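Putting the diffs above together, the calling pattern after this PR looks roughly like the sketch below; the index name is a placeholder, and tokenizer initialization and error handling are omitted for brevity:

from resin.knoweldge_base import KnowledgeBase

kb = KnowledgeBase(index_name="my-index")   # placeholder name; the resin-- prefix is added internally

# Brand-new index: replaces the removed KnowledgeBase.create_with_new_index(...)
kb.create_resin_index()

# Existing index: connect explicitly before upserting or querying
# kb.connect()

kb.verify_index_connection()                # raises RuntimeError if the index is unreachable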