Skip to content

Commit

Permalink
Fix: Add chunking for GPT4Adapter requests bigger than 2048 tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
tiadams committed Jun 13, 2024
1 parent 17f2f0f commit 54ba9db
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions datastew/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def get_embedding(self, text: str, model="text-embedding-ada-002"):
logging.info(f"Getting embedding for {text}")
try:
if text is None or text == "" or text is np.nan:
logging.warn(f"Empty text passed to get_embedding")
logging.warning(f"Empty text passed to get_embedding")
return None
if isinstance(text, str):
text = text.replace("\n", " ")
Expand All @@ -32,10 +32,18 @@ def get_embedding(self, text: str, model="text-embedding-ada-002"):
logging.error(f"Error getting embedding for {text}: {e}")
return None

def get_embeddings(self, messages: [str], model="text-embedding-ada-002"):
# store index of nan entries
response = openai.Embedding.create(input=messages, model=model)
return [item["embedding"] for item in response["data"]]
def get_embeddings(self, messages: list[str], model="text-embedding-ada-002", max_chunk_length=2048):
    """Compute embeddings for a list of messages, chunking over-long inputs.

    Each message at most ``max_chunk_length`` characters long produces one
    embedding. Longer messages are split into consecutive character chunks
    and each chunk is embedded separately, so the returned list may contain
    MORE entries than ``messages`` when long inputs are present.

    :param messages: texts to embed
    :param model: OpenAI embedding model name
    :param max_chunk_length: maximum characters per request chunk
    :return: list of embeddings (entries may be None when get_embedding
             fails or receives empty text)
    """
    embeddings = []
    for message in messages:
        if len(message) <= max_chunk_length:
            embeddings.append(self.get_embedding(message, model))
            continue
        # Split the message into fixed-size character chunks.
        # NOTE(review): chunking is by characters, not tokens — a
        # 2048-char chunk can still exceed the model's token limit;
        # confirm against the tokenizer if hard limits matter.
        chunks = [message[i:i + max_chunk_length]
                  for i in range(0, len(message), max_chunk_length)]
        # start=1 so progress reads "1/3 .. 3/3" instead of "0/3 .. 2/3"
        for idx, chunk in enumerate(chunks, start=1):
            logging.info(f'Processing chunk {idx}/{len(chunks)}')
            embeddings.append(self.get_embedding(chunk, model))
    return embeddings


class MPNetAdapter(EmbeddingModel):
Expand Down

0 comments on commit 54ba9db

Please sign in to comment.