diff --git a/.env.sample b/.env.sample index 00826f4b..857108f5 100644 --- a/.env.sample +++ b/.env.sample @@ -1,6 +1,9 @@ # OpenAI API Key OPENAI_API_KEY= +# Cerebras API Key +CEREBRAS_API_KEY= + # Anthropic API Key ANTHROPIC_API_KEY= diff --git a/.gitignore b/.gitignore index 5b651c66..f607e5a5 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,6 @@ __pycache__/ env/ .env .google-adc +.DS_Store +aisuite/.DS_Store + diff --git a/aisuite/providers/cerebras_provider.py b/aisuite/providers/cerebras_provider.py new file mode 100644 index 00000000..00c5cfb1 --- /dev/null +++ b/aisuite/providers/cerebras_provider.py @@ -0,0 +1,26 @@ +import os + +from cerebras.cloud.sdk import Cerebras +from aisuite.provider import Provider + + +class CerebrasProvider(Provider): + def __init__(self, **config): + """ + Initialize the Cerebras provider with the given configuration. + Pass the entire configuration dictionary to the Cerebras client constructor. + """ + # Ensure API key is provided either in config or via environment variable + config.setdefault("api_key", os.getenv("CEREBRAS_API_KEY")) + if not config["api_key"]: + raise ValueError( + " API key is missing. Please provide it in the config or set the CEREBRAS_API_KEY environment variable." + ) + self.client = Cerebras(**config) + + def chat_completions_create(self, model, messages, **kwargs): + return self.client.chat.completions.create( + model=model, + messages=messages, + **kwargs # Pass any additional arguments to the Cerebras API + ) diff --git a/guides/README.md b/guides/README.md index 3079c29c..715f8e2c 100644 --- a/guides/README.md +++ b/guides/README.md @@ -9,6 +9,7 @@ Here're the instructions for: - [Google](google.md) - [Hugging Face](huggingface.md) - [OpenAI](openai.md) +- [Cerebras](cerebras.md) Unless otherwise stated, these guides have not been endorsed by the providers. diff --git a/guides/cerebras.md b/guides/cerebras.md new file mode 100644 index 00000000..06ca5237 --- /dev/null +++ b/guides/cerebras.md @@ -0,0 +1,42 @@ +# Cerebras + +To use Cerebras with `aisuite`, you'll need a [Cerebras account](https://console.cerebras.net/). After logging in, navigate to your account settings to generate an API key. Once you have your key, add it to your environment as follows: + +```shell +export CEREBRAS_API_KEY="your-cerebras-api-key" +``` + +## Create a Chat Completion + +Install the `cerebras` Python client: + +Example with pip: +```shell +pip install cerebras_cloud_sdk +``` + +Example with poetry: +```shell +poetry add cerebras_cloud_sdk +``` + +In your code: +```python +import aisuite as ai +client = ai.Client() + +provider = "cerebras" +model_id = "llama3.1-8b" + + +messages = [ + {"content": "Why is fast inference important?"}, +] + +response = client.chat.completions.create(provider=provider, model_id=model_id, messages=messages) + +print(response.choices[0].message.content) + +``` + +Happy coding! If you’d like to contribute, please read our [Contributing Guide](CONTRIBUTING.md). \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 8ae9295b..a8a68d9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,14 +1,26 @@ [tool.poetry] name = "aisuite" -version = "0.1.6" +version = "0.1.7" description = "Uniform access layer for LLMs" authors = ["Andrew Ng"] +maintainers = [ + { email = "standsleeping@gmail.com" }, + { name = "Kevin Solorio" }, + { name = "Rohit Prasad", email = "rohit.prasad15@gmail.com" }, + { name = "Jeff Tang" }, + { name = "Andrew Ng" }, + { name = "John Santerre" }, + { name = "Zachary Bloss", email = "zacharybloss@gmail.com" }, + { email = "rohit-rptless" }, + { name = "Bilal Kamal", email = "bilal.k.hamada@gmail.com" } +] readme = "README.md" [tool.poetry.dependencies] python = "^3.10" anthropic = { version = "^0.30.1", optional = true } boto3 = { version = "^1.34.144", optional = true } +cerebras-cloud-sdk = { version = "^1.12.2", optional = true } vertexai = { version = "^1.63.0", optional = true } groq = { version = "^0.9.0", optional = true } mistralai = { version = "^1.0.3", optional = true } @@ -25,7 +37,7 @@ huggingface = [] mistral = ["mistralai"] ollama = [] openai = ["openai"] -all = ["anthropic", "aws", "google", "groq", "mistral", "openai"] # To install all providers +all = ["anthropic", "aws", "google", "groq", "mistral", "openai", "cerebras-cloud-sdk"] # To install all providers [tool.poetry.group.dev.dependencies] pytest = "^8.2.2"