Skip to content
This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Commit

Permalink
Add support OctoAI LLM and embeddings (#301)
Browse files Browse the repository at this point in the history
* "fixed typo in dense.py docstring"

* adding octoAI embeddings

* added octoai system test

* increased batch size

* added information for OctoAI env vars

* updated record_encoder batch size

* support for OctoAI LLM adaptor

* changed prefix after code review

* added OctoAI to llm unit tests

* fixed linting
  • Loading branch information
ptorru authored Mar 5, 2024
1 parent fa82c84 commit 9f3f8fd
Show file tree
Hide file tree
Showing 8 changed files with 263 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ These optional environment variables are used to authenticate to other supported
| `JINA_API_KEY` | API key for Jina AI. Used to authenticate to JinaAI's services for embedding and chat API | You can find your OpenAI API key [here](https://platform.openai.com/account/api-keys). You might need to login or register to OpenAI services |
| `AZURE_OPENAI_ENDOINT`| The URL of the Azure OpenAI endpoint you deployed. | You can find this in the Azure OpenAI portal under _Keys and Endpoints`|
| `AZURE_OPENAI_API_KEY` | The API key to use for your Azure OpenAI models. | You can find this in the Azure OpenAI portal under _Keys and Endpoints`|
| `OCTOAI_API_KEY` | API key for OctoAI. Used to authenticate for open source LLMs served in OctoAI | You can sign up for OctoAI and find your API key [here](https://octo.ai/)

</details>

Expand Down Expand Up @@ -280,4 +281,3 @@ gunicorn canopy_server.app:app --worker-class uvicorn.workers.UvicornWorker --bi
> The server interacts with services like Pinecone and OpenAI using your own authentication credentials.
When deploying the server on a public web hosting provider, it is recommended to enable an authentication mechanism,
so that your server would only take requests from authenticated users.

50 changes: 50 additions & 0 deletions src/canopy/config_templates/octoai.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# ===========================================================
# Configuration file for Canopy Server
# ===========================================================
tokenizer:
# -------------------------------------------------------------------------------------------
# Tokenizer configuration
# Use LLamaTokenizer from HuggingFace with the relevant OSS model (e.g. LLama2)
# -------------------------------------------------------------------------------------------
type: LlamaTokenizer # Options: [OpenAITokenizer, LlamaTokenizer]
params:
model_name: hf-internal-testing/llama-tokenizer

chat_engine:
# -------------------------------------------------------------------------------------------
# Chat engine configuration
# Use OctoAI as the open source LLM provider
# You can find the list of supported LLMs at https://octo.ai/docs/text-gen-solution/rest-api
# -------------------------------------------------------------------------------------------
params:
max_prompt_tokens: 2048 # The maximum number of tokens to use for input prompt to the LLM.
llm: &llm
type: OctoAILLM
params:
model_name: mistral-7b-instruct-fp16 # The name of the model to use.

# query_builder:
# type: FunctionCallingQueryGenerator # Options: [FunctionCallingQueryGenerator, LastMessageQueryGenerator, InstructionQueryGenerator]
# llm:
# type: OctoAILLM
# params:
# model_name: mistral-7b-instruct-fp16

context_engine:
# -------------------------------------------------------------------------------------------------------------
# ContextEngine configuration
# -------------------------------------------------------------------------------------------------------------
knowledge_base:
# -----------------------------------------------------------------------------------------------------------
# KnowledgeBase configuration
# -----------------------------------------------------------------------------------------------------------
record_encoder:
# --------------------------------------------------------------------------
# Configuration for the RecordEncoder subcomponent of the knowledge base.
# Use OctoAI's Embedding endpoint for dense encoding
# --------------------------------------------------------------------------
type: OctoAIRecordEncoder
params:
model_name: # The name of the model to use for encoding
thenlper/gte-large
batch_size: 2048 # The number of document chunks to encode in each call to the encoding model
1 change: 1 addition & 0 deletions src/canopy/knowledge_base/record_encoder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .jina import JinaRecordEncoder
from .sentence_transformers import SentenceTransformerRecordEncoder
from .hybrid import HybridRecordEncoder
from .octoai import OctoAIRecordEncoder
68 changes: 68 additions & 0 deletions src/canopy/knowledge_base/record_encoder/octoai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import os
from typing import List
from pinecone_text.dense.openai_encoder import OpenAIEncoder
from canopy.knowledge_base.models import KBDocChunk, KBEncodedDocChunk, KBQuery
from canopy.knowledge_base.record_encoder.dense import DenseRecordEncoder
from canopy.models.data_models import Query

OCTOAI_BASE_URL = "https://text.octoai.run/v1"


class OctoAIRecordEncoder(DenseRecordEncoder):
"""
OctoAIRecordEncoder is a type of DenseRecordEncoder that uses the OpenAI `embeddings` API.
The implementation uses the `OpenAIEncoder` class from the `pinecone-text` library.
For more information about see: https://github.com/pinecone-io/pinecone-text
""" # noqa: E501
"""
Initialize the OctoAIRecordEncoder
Args:
api_key: The OctoAI Endpoint API Key
base_url: The Base URL for the OctoAI Endpoint
model_name: The name of the OctoAI embeddings model to use for encoding. See https://octo.ai/docs/text-gen-solution/getting-started
batch_size: The number of documents or queries to encode at once.
Defaults to 1.
**kwargs: Additional arguments to pass to the underlying `pinecone-text. OpenAIEncoder`.
""" # noqa: E501
def __init__(self,
*,
api_key: str = "",
base_url: str = OCTOAI_BASE_URL,
model_name: str = "thenlper/gte-large",
batch_size: int = 1024,
**kwargs):

octoai_api_key = api_key or os.environ.get("OCTOAI_API_KEY")
if not octoai_api_key:
raise ValueError(
"An OctoAI API token is required to use OctoAI. "
"Please provide it as an argument "
"or set the OCTOAI_API_KEY environment variable."
)
octoai_base_url = base_url
encoder = OpenAIEncoder(model_name,
base_url=octoai_base_url, api_key=octoai_api_key,
**kwargs)
super().__init__(dense_encoder=encoder, batch_size=batch_size)

def encode_documents(self, documents: List[KBDocChunk]) -> List[KBEncodedDocChunk]:
"""
Encode a list of documents, takes a list of KBDocChunk and returns a list of KBEncodedDocChunk.
Args:
documents: A list of KBDocChunk to encode.
Returns:
encoded chunks: A list of KBEncodedDocChunk, with the `values` field populated by the generated embeddings vector.
""" # noqa: E501
return super().encode_documents(documents)

async def _aencode_documents_batch(self,
documents: List[KBDocChunk]
) -> List[KBEncodedDocChunk]:
raise NotImplementedError

async def _aencode_queries_batch(self, queries: List[Query]) -> List[KBQuery]:
raise NotImplementedError
1 change: 1 addition & 0 deletions src/canopy/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .anyscale import AnyscaleLLM
from .azure_openai_llm import AzureOpenAILLM
from .cohere import CohereLLM
from .octoai import OctoAILLM
61 changes: 61 additions & 0 deletions src/canopy/llm/octoai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import Optional, Any
import os
from canopy.llm import OpenAILLM
from canopy.llm.models import Function
from canopy.models.data_models import Messages

OCTOAI_BASE_URL = "https://text.octoai.run/v1"


class OctoAILLM(OpenAILLM):
"""
OctoAI LLM wrapper built on top of the OpenAI Python client.
Note: OctoAI requires a valid API key to use this class.
You can set the "OCTOAI_API_KEY" environment variable.
"""

def __init__(
self,
model_name: str = "mistral-7b-instruct-fp16",
*,
base_url: Optional[str] = OCTOAI_BASE_URL,
api_key: Optional[str] = None,
**kwargs: Any,
):
octoai_api_key = api_key or os.environ.get("OCTOAI_API_KEY")
if not octoai_api_key:
raise ValueError(
"OctoAI API key is required to use OctoAI. "
"If you haven't done it, please sign up at https://octo.ai \n"
"The key can be provided as an argument or "
"via the OCTOAI_API_KEY environment variable."
)
octoai_base_url = base_url
super().__init__(
model_name,
api_key=octoai_api_key,
base_url=octoai_base_url,
**kwargs
)

def enforced_function_call(
self,
system_prompt: str,
chat_history: Messages,
function: Function,
*,
max_tokens: Optional[int] = None,
model_params: Optional[dict] = None,
) -> dict:
raise NotImplementedError("OctoAI doesn't support function calling.")

def aenforced_function_call(self,
system_prompt: str,
chat_history: Messages,
function: Function,
*,
max_tokens: Optional[int] = None,
model_params: Optional[dict] = None
):
raise NotImplementedError("OctoAI doesn't support function calling.")
30 changes: 28 additions & 2 deletions tests/system/llm/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jsonschema
import pytest

from canopy.llm import AzureOpenAILLM, AnyscaleLLM
from canopy.llm import AzureOpenAILLM, AnyscaleLLM, OctoAILLM
from canopy.models.data_models import Role, MessageBase, Context, StringContextContent # noqa
from canopy.models.api_models import ChatResponse, StreamingChatChunk # noqa
from canopy.llm.openai import OpenAILLM # noqa
Expand Down Expand Up @@ -60,7 +60,7 @@ def model_params_low_temperature():
return {"temperature": 0.2, "top_p": 0.5, "n": 1}


@pytest.fixture(params=[OpenAILLM, AzureOpenAILLM, AnyscaleLLM])
@pytest.fixture(params=[OpenAILLM, AzureOpenAILLM, AnyscaleLLM, OctoAILLM])
def openai_llm(request):
llm_class = request.param
if llm_class == AzureOpenAILLM:
Expand All @@ -73,6 +73,10 @@ def openai_llm(request):
if os.getenv("ANYSCALE_API_KEY") is None:
pytest.skip("Couldn't find Anyscale API key. Skipping Anyscale tests.")
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
elif llm_class == OctoAILLM:
if os.getenv("OCTOAI_API_KEY") is None:
pytest.skip[("Couldn't find OctoAI API key. Skipping OctoAI tests.")]
model_name = "mistral-7b-instruct"
else:
model_name = "gpt-3.5-turbo-0613"

Expand Down Expand Up @@ -121,6 +125,8 @@ def test_chat_completion_with_context(openai_llm, messages):
def test_enforced_function_call(openai_llm,
messages,
function_query_knowledgebase):
if isinstance(openai_llm, OctoAILLM):
pytest.skip("OctoAI doesn't support function calling at the moment")
result = openai_llm.enforced_function_call(
system_prompt=SYSTEM_PROMPT,
chat_history=messages,
Expand All @@ -134,11 +140,15 @@ def test_chat_completion_high_temperature(openai_llm,
if isinstance(openai_llm, AnyscaleLLM):
pytest.skip("Anyscale don't support n>1 for the moment.")

if isinstance(openai_llm, OctoAILLM):
pytest.skip("OctoAI doesn't support n>1 for the moment.")

response = openai_llm.chat_completion(
system_prompt=SYSTEM_PROMPT,
chat_history=messages,
model_params=model_params_high_temperature
)

assert_chat_completion(response,
num_choices=model_params_high_temperature["n"])

Expand All @@ -160,6 +170,9 @@ def test_enforced_function_call_high_temperature(openai_llm,
if isinstance(openai_llm, AnyscaleLLM):
pytest.skip("Anyscale don't support n>1 for the moment.")

if isinstance(openai_llm, OctoAILLM):
pytest.skip("OctoAI doesn't support function calling at the moment")

result = openai_llm.enforced_function_call(
system_prompt=SYSTEM_PROMPT,
chat_history=messages,
Expand All @@ -177,6 +190,9 @@ def test_enforced_function_call_low_temperature(openai_llm,
if isinstance(openai_llm, AnyscaleLLM):
model_params["top_p"] = 1.0

if isinstance(openai_llm, OctoAILLM):
pytest.skip("OctoAI doesn't support function calling at the moment")

result = openai_llm.enforced_function_call(
system_prompt=SYSTEM_PROMPT,
chat_history=messages,
Expand All @@ -191,6 +207,8 @@ def test_chat_completion_with_model_name(openai_llm, messages):
pytest.skip("In Azure the model name has to be a valid deployment")
elif isinstance(openai_llm, AnyscaleLLM):
new_model_name = "meta-llama/Llama-2-7b-chat-hf"
elif isinstance(openai_llm, OctoAILLM):
new_model_name = "codellama-7b-instruct"
else:
new_model_name = "gpt-3.5-turbo-1106"

Expand Down Expand Up @@ -248,6 +266,9 @@ def test_chat_complete_api_failure_populates(openai_llm,
def test_enforce_function_api_failure_populates(openai_llm,
messages,
function_query_knowledgebase):
if isinstance(openai_llm, OctoAILLM):
pytest.skip("OctoAI doesn't support function calling at the moment")

openai_llm._client = MagicMock()
openai_llm._client.chat.completions.create.side_effect = Exception(
"API call failed")
Expand All @@ -261,6 +282,9 @@ def test_enforce_function_api_failure_populates(openai_llm,
def test_enforce_function_wrong_output_schema(openai_llm,
messages,
function_query_knowledgebase):
if isinstance(openai_llm, OctoAILLM):
pytest.skip("OctoAI doesn't support function calling at the moment")

openai_llm._client = MagicMock()
openai_llm._client.chat.completions.create.return_value = MagicMock(
choices=[MagicMock(
Expand Down Expand Up @@ -302,6 +326,8 @@ def test_enforce_function_unsupported_model(openai_llm,
def test_available_models(openai_llm):
if isinstance(openai_llm, AzureOpenAILLM):
pytest.skip("Azure does not support listing models")
if isinstance(openai_llm, OctoAILLM):
pytest.skip("OctoAI does not support listing models")
models = openai_llm.available_models
assert isinstance(models, list)
assert len(models) > 0
Expand Down
53 changes: 53 additions & 0 deletions tests/system/record_encoder/test_octoai_record_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest

from canopy.knowledge_base.models import KBDocChunk
from canopy.knowledge_base.record_encoder.octoai import OctoAIRecordEncoder
from canopy.models.data_models import Query


documents = [KBDocChunk(
id=f"doc_1_{i}",
text=f"Sample document {i}",
document_id=f"doc_{i}",
metadata={"test": i},
source="doc_1",
)
for i in range(4)
]

queries = [Query(text="Sample query 1"),
Query(text="Sample query 2"),
Query(text="Sample query 3"),
Query(text="Sample query 4")]


@pytest.fixture
def encoder():
return OctoAIRecordEncoder(batch_size=2)


def test_dimension(encoder):
assert encoder.dimension == 1024


@pytest.mark.parametrize("items,function",
[(documents, "encode_documents"),
(queries, "encode_queries"),
([], "encode_documents"),
([], "encode_queries")])
def test_encode_documents(encoder, items, function):

encoded_documents = getattr(encoder, function)(items)

assert len(encoded_documents) == len(items)
assert all(len(encoded.values) == encoder.dimension
for encoded in encoded_documents)


@pytest.mark.asyncio
@pytest.mark.parametrize("items,function",
[("aencode_documents", documents),
("aencode_queries", queries)])
async def test_aencode_not_implemented(encoder, function, items):
with pytest.raises(NotImplementedError):
await encoder.aencode_queries(items)

0 comments on commit 9f3f8fd

Please sign in to comment.