This repository has been archived by the owner on Nov 13, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 122
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support OctoAI LLM and embeddings (#301)
* "fixed typo in dense.py docstring" * adding octoAI embeddings * added octoai system test * increased batch size * added information for OctoAI env vars * updated record_encoder batch size * support for OctoAI LLM adaptor * changed prefix after code review * added OctoAI to llm unit tests * fixed linting
- Loading branch information
Showing
8 changed files
with
263 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# =========================================================== | ||
# Configuration file for Canopy Server | ||
# =========================================================== | ||
tokenizer: | ||
# ------------------------------------------------------------------------------------------- | ||
# Tokenizer configuration | ||
# Use LLamaTokenizer from HuggingFace with the relevant OSS model (e.g. LLama2) | ||
# ------------------------------------------------------------------------------------------- | ||
type: LlamaTokenizer # Options: [OpenAITokenizer, LlamaTokenizer] | ||
params: | ||
model_name: hf-internal-testing/llama-tokenizer | ||
|
||
chat_engine: | ||
# ------------------------------------------------------------------------------------------- | ||
# Chat engine configuration | ||
# Use OctoAI as the open source LLM provider | ||
# You can find the list of supported LLMs at https://octo.ai/docs/text-gen-solution/rest-api | ||
# ------------------------------------------------------------------------------------------- | ||
params: | ||
max_prompt_tokens: 2048 # The maximum number of tokens to use for input prompt to the LLM. | ||
llm: &llm | ||
type: OctoAILLM | ||
params: | ||
model_name: mistral-7b-instruct-fp16 # The name of the model to use. | ||
|
||
# query_builder: | ||
# type: FunctionCallingQueryGenerator # Options: [FunctionCallingQueryGenerator, LastMessageQueryGenerator, InstructionQueryGenerator] | ||
# llm: | ||
# type: OctoAILLM | ||
# params: | ||
# model_name: mistral-7b-instruct-fp16 | ||
|
||
context_engine: | ||
# ------------------------------------------------------------------------------------------------------------- | ||
# ContextEngine configuration | ||
# ------------------------------------------------------------------------------------------------------------- | ||
knowledge_base: | ||
# ----------------------------------------------------------------------------------------------------------- | ||
# KnowledgeBase configuration | ||
# ----------------------------------------------------------------------------------------------------------- | ||
record_encoder: | ||
# -------------------------------------------------------------------------- | ||
# Configuration for the RecordEncoder subcomponent of the knowledge base. | ||
# Use OctoAI's Embedding endpoint for dense encoding | ||
# -------------------------------------------------------------------------- | ||
type: OctoAIRecordEncoder | ||
params: | ||
model_name: # The name of the model to use for encoding | ||
thenlper/gte-large | ||
batch_size: 2048 # The number of document chunks to encode in each call to the encoding model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import os | ||
from typing import List | ||
from pinecone_text.dense.openai_encoder import OpenAIEncoder | ||
from canopy.knowledge_base.models import KBDocChunk, KBEncodedDocChunk, KBQuery | ||
from canopy.knowledge_base.record_encoder.dense import DenseRecordEncoder | ||
from canopy.models.data_models import Query | ||
|
||
OCTOAI_BASE_URL = "https://text.octoai.run/v1" | ||
|
||
|
||
class OctoAIRecordEncoder(DenseRecordEncoder): | ||
""" | ||
OctoAIRecordEncoder is a type of DenseRecordEncoder that uses the OpenAI `embeddings` API. | ||
The implementation uses the `OpenAIEncoder` class from the `pinecone-text` library. | ||
For more information about see: https://github.com/pinecone-io/pinecone-text | ||
""" # noqa: E501 | ||
""" | ||
Initialize the OctoAIRecordEncoder | ||
Args: | ||
api_key: The OctoAI Endpoint API Key | ||
base_url: The Base URL for the OctoAI Endpoint | ||
model_name: The name of the OctoAI embeddings model to use for encoding. See https://octo.ai/docs/text-gen-solution/getting-started | ||
batch_size: The number of documents or queries to encode at once. | ||
Defaults to 1. | ||
**kwargs: Additional arguments to pass to the underlying `pinecone-text. OpenAIEncoder`. | ||
""" # noqa: E501 | ||
def __init__(self, | ||
*, | ||
api_key: str = "", | ||
base_url: str = OCTOAI_BASE_URL, | ||
model_name: str = "thenlper/gte-large", | ||
batch_size: int = 1024, | ||
**kwargs): | ||
|
||
octoai_api_key = api_key or os.environ.get("OCTOAI_API_KEY") | ||
if not octoai_api_key: | ||
raise ValueError( | ||
"An OctoAI API token is required to use OctoAI. " | ||
"Please provide it as an argument " | ||
"or set the OCTOAI_API_KEY environment variable." | ||
) | ||
octoai_base_url = base_url | ||
encoder = OpenAIEncoder(model_name, | ||
base_url=octoai_base_url, api_key=octoai_api_key, | ||
**kwargs) | ||
super().__init__(dense_encoder=encoder, batch_size=batch_size) | ||
|
||
def encode_documents(self, documents: List[KBDocChunk]) -> List[KBEncodedDocChunk]: | ||
""" | ||
Encode a list of documents, takes a list of KBDocChunk and returns a list of KBEncodedDocChunk. | ||
Args: | ||
documents: A list of KBDocChunk to encode. | ||
Returns: | ||
encoded chunks: A list of KBEncodedDocChunk, with the `values` field populated by the generated embeddings vector. | ||
""" # noqa: E501 | ||
return super().encode_documents(documents) | ||
|
||
async def _aencode_documents_batch(self, | ||
documents: List[KBDocChunk] | ||
) -> List[KBEncodedDocChunk]: | ||
raise NotImplementedError | ||
|
||
async def _aencode_queries_batch(self, queries: List[Query]) -> List[KBQuery]: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from typing import Optional, Any | ||
import os | ||
from canopy.llm import OpenAILLM | ||
from canopy.llm.models import Function | ||
from canopy.models.data_models import Messages | ||
|
||
OCTOAI_BASE_URL = "https://text.octoai.run/v1" | ||
|
||
|
||
class OctoAILLM(OpenAILLM): | ||
""" | ||
OctoAI LLM wrapper built on top of the OpenAI Python client. | ||
Note: OctoAI requires a valid API key to use this class. | ||
You can set the "OCTOAI_API_KEY" environment variable. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
model_name: str = "mistral-7b-instruct-fp16", | ||
*, | ||
base_url: Optional[str] = OCTOAI_BASE_URL, | ||
api_key: Optional[str] = None, | ||
**kwargs: Any, | ||
): | ||
octoai_api_key = api_key or os.environ.get("OCTOAI_API_KEY") | ||
if not octoai_api_key: | ||
raise ValueError( | ||
"OctoAI API key is required to use OctoAI. " | ||
"If you haven't done it, please sign up at https://octo.ai \n" | ||
"The key can be provided as an argument or " | ||
"via the OCTOAI_API_KEY environment variable." | ||
) | ||
octoai_base_url = base_url | ||
super().__init__( | ||
model_name, | ||
api_key=octoai_api_key, | ||
base_url=octoai_base_url, | ||
**kwargs | ||
) | ||
|
||
def enforced_function_call( | ||
self, | ||
system_prompt: str, | ||
chat_history: Messages, | ||
function: Function, | ||
*, | ||
max_tokens: Optional[int] = None, | ||
model_params: Optional[dict] = None, | ||
) -> dict: | ||
raise NotImplementedError("OctoAI doesn't support function calling.") | ||
|
||
def aenforced_function_call(self, | ||
system_prompt: str, | ||
chat_history: Messages, | ||
function: Function, | ||
*, | ||
max_tokens: Optional[int] = None, | ||
model_params: Optional[dict] = None | ||
): | ||
raise NotImplementedError("OctoAI doesn't support function calling.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import pytest | ||
|
||
from canopy.knowledge_base.models import KBDocChunk | ||
from canopy.knowledge_base.record_encoder.octoai import OctoAIRecordEncoder | ||
from canopy.models.data_models import Query | ||
|
||
|
||
documents = [KBDocChunk( | ||
id=f"doc_1_{i}", | ||
text=f"Sample document {i}", | ||
document_id=f"doc_{i}", | ||
metadata={"test": i}, | ||
source="doc_1", | ||
) | ||
for i in range(4) | ||
] | ||
|
||
queries = [Query(text="Sample query 1"), | ||
Query(text="Sample query 2"), | ||
Query(text="Sample query 3"), | ||
Query(text="Sample query 4")] | ||
|
||
|
||
@pytest.fixture | ||
def encoder(): | ||
return OctoAIRecordEncoder(batch_size=2) | ||
|
||
|
||
def test_dimension(encoder): | ||
assert encoder.dimension == 1024 | ||
|
||
|
||
@pytest.mark.parametrize("items,function", | ||
[(documents, "encode_documents"), | ||
(queries, "encode_queries"), | ||
([], "encode_documents"), | ||
([], "encode_queries")]) | ||
def test_encode_documents(encoder, items, function): | ||
|
||
encoded_documents = getattr(encoder, function)(items) | ||
|
||
assert len(encoded_documents) == len(items) | ||
assert all(len(encoded.values) == encoder.dimension | ||
for encoded in encoded_documents) | ||
|
||
|
||
@pytest.mark.asyncio | ||
@pytest.mark.parametrize("items,function", | ||
[("aencode_documents", documents), | ||
("aencode_queries", queries)]) | ||
async def test_aencode_not_implemented(encoder, function, items): | ||
with pytest.raises(NotImplementedError): | ||
await encoder.aencode_queries(items) |