Gemini (Google AI) support for existing data types (#669)
Important Note: this is not Vertex AI's Gemini endpoint, as that does
not support API keys and requires OAuth tokens.
This endpoint supports API keys, and is provided by Google AI.
anish-palakurthi authored Jun 12, 2024
1 parent 2423ace commit 6697ecd
Showing 30 changed files with 1,755 additions and 159 deletions.
2 changes: 1 addition & 1 deletion docs/docs/home/overview.mdx
@@ -21,7 +21,7 @@ Share your creations and ask questions in our [Discord](https://discord.gg/BTNBe

### Language features
- **Python and Typescript support**: Plug-and-play BAML with other languages
- **JSON correction**: BAML fix bad JSON returned by LLMs (e.g. unquoted keys, newlines, comments, extra quotes, and more)
- **JSON correction**: BAML fixes bad JSON returned by LLMs (e.g. unquoted keys, newlines, comments, extra quotes, and more)
- **Wide model support**: Ollama, Openai, Anthropic. Tested on small models like Llama2
- **Streaming**: Stream structured partial outputs
- **Resilience and fallback features**: Add retries, redundancy, to your LLM calls
38 changes: 33 additions & 5 deletions docs/docs/syntax/client/client.mdx
@@ -35,6 +35,7 @@ BAML ships with the following providers (you can also write your own!):
- `azure-openai`
- `anthropic`
- `ollama`
- `google-ai`
- Composite client providers
- `fallback`
- `round-robin`
@@ -122,23 +123,50 @@ Provider names:

Accepts any options as defined by [Ollama SDK](https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion).

```rust
client<llm> MyOllamaClient {
  provider ollama
  options {
    model llama2
  }
}
```
#### Requirements
Make sure you disable CORS restrictions if you are trying to run Ollama through the BAML VSCode playground:
1. In your terminal, run `OLLAMA_ORIGINS="*" ollama serve`
2. Run `ollama run llama2` (or your model), and you should be good to go.

1. For Ollama, in your terminal run `ollama serve`
2. In another window, run `ollama run llama2` (or your model), and you should be good to go.

```rust
client<llm> MyClient {
  provider ollama
  options {
    model mistral
    model llama2
    options {
      temperature 0
    }
  }
}
```

### Google

Provider names:
- `google-ai`

Accepts any options as defined by the [Gemini SDK](https://ai.google.dev/gemini-api/docs/api-overview).


```rust
client<llm> MyGoogleClient {
  provider google-ai
  options {
    model "gemini-1.5-pro-001"
    api_key env.GOOGLE_API_KEY
  }
}
```
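
For reference, here is a minimal sketch of a BAML function wired to this client. The function name, signature, and prompt are illustrative (not part of this commit); only the client reference follows the example above.

```rust
// Hypothetical function for illustration; assumes MyGoogleClient is defined as above.
function SummarizeText(input: string) -> string {
  client MyGoogleClient
  prompt #"
    {{ _.role("user") }}
    Summarize the following text in one sentence:
    {{ input }}
  "#
}
```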


### Fallback

The `baml-fallback` provider allows you to define a resilient client, by
@@ -184,7 +212,7 @@ client<llm> MyClient {
```

## Other providers
You can use the `openai` provider if the provider you're trying to use has the same ChatML response format.
- **JSON correction**: BAML fixes bad JSON returned by LLMs (e.g. unquoted keys, newlines, comments, extra quotes, and more)
You can use the `openai` provider if the provider you're trying to use has the same ChatML response format (e.g. HuggingFace via their Inference Endpoints, or your own local endpoint).

Some providers ask you to add a `base_url`, which you can do like this:
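
A minimal sketch, assuming an OpenAI-compatible endpoint; the URL, model name, and environment variable below are placeholders, not values from this commit:

```rust
// Hypothetical client for illustration; swap in your provider's real endpoint and key.
client<llm> MyCustomClient {
  provider openai
  options {
    base_url "https://my-endpoint.example.com/v1"
    api_key env.MY_PROVIDER_API_KEY
    model "my-model"
  }
}
```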

28 changes: 15 additions & 13 deletions docs/docs/syntax/type.mdx
@@ -32,6 +32,21 @@ title: Supported Types

- **Syntax:** `null`

### ✅ Images

You can use an image like this:

```rust
function DescribeImage(myImg: image) -> string {
  client GPT4Turbo
  prompt #"
    {{ _.role("user")}}
    Describe the image in four words:
    {{ myImg }}
  "#
}
```

### ⚠️ bytes

- Not yet supported. Use a `string[]` or `int[]` instead.
@@ -65,20 +80,7 @@ temperature of 32 degrees Fahrenheit or cost of $100.00.
the unit be part of the variable name. For example, `temperature_fahrenheit`
and `cost_usd` (see [@alias](/docs/syntax/class#alias)).
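
A minimal sketch of that convention in a class definition; the class and field names are illustrative:

```rust
// Hypothetical class for illustration; each float field carries its unit in the name.
class Measurements {
  temperature_fahrenheit float
  cost_usd float
}
```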

### ✅ Images

You can use an image like this:

```rust
function DescribeImage(myImg: image) -> string {
  client GPT4Turbo
  prompt #"
    {{ _.role("user")}}
    Describe the image in four words:
    {{ myImg }}
  "#
}
```


## Composite/Structured Types
4 changes: 4 additions & 0 deletions engine/.turbo/daemon/d1f419c7a2c5692a-turbo.log.2024-06-10
@@ -0,0 +1,4 @@
2024-06-10T23:15:35.616256Z WARN daemon_server: turborepo_lib::commands::daemon: daemon already running
2024-06-10T23:35:28.207259Z WARN daemon_server: turborepo_lib::commands::daemon: daemon already running
2024-06-10T23:48:30.344739Z WARN daemon_server: turborepo_lib::commands::daemon: daemon already running
2024-06-10T23:57:43.838177Z WARN daemon_server: turborepo_lib::commands::daemon: daemon already running
2 changes: 2 additions & 0 deletions engine/.turbo/daemon/d1f419c7a2c5692a-turbo.log.2024-06-11
@@ -0,0 +1,2 @@
2024-06-11T22:27:08.580345Z WARN daemon_server: turborepo_lib::commands::daemon: daemon already running
2024-06-11T22:40:18.568937Z WARN daemon_server: turborepo_lib::commands::daemon: daemon already running
1 change: 0 additions & 1 deletion engine/Cargo.toml
@@ -50,7 +50,6 @@ strum = { version = "0.26.2", features = ["derive"] }
strum_macros = "0.26.2"
walkdir = "2.5.0"
web-time = "1.1.0"

baml-types = { path = "baml-lib/baml-types" }
internal-baml-codegen = { path = "language-client-codegen" }
internal-baml-core = { path = "baml-lib/baml-core" }
9 changes: 9 additions & 0 deletions engine/baml-cli/src/init_command/clients.rs
@@ -136,3 +136,12 @@ fn anthropic_clients<T: From<&'static str> + AsRef<str>>() -> Vec<ClientConfig<T
},
]
}

fn google_clients<T: From<&'static str> + AsRef<str>>() -> Vec<ClientConfig<T>> {
vec![ClientConfig {
comment: None,
provider: "google-ai".into(),
name: "Gemini".into(),
params: vec![("model_name", "gemini"), ("api_key", "env.GOOGLE_API_KEY")],
}]
}
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@ pub(super) fn validate(ctx: &mut Context<'_>) {
"round-robin",
"baml-fallback",
"fallback",
"google-ai",
];

let suggestions: Vec<String> = allowed_providers
Empty file.
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@ use reqwest::Response;
use crate::{
internal::llm_client::{
primitive::{
anthropic::types::{AnthropicErrorResponse, AnthropicMessageResponse, StopReason},
anthropic::types::{AnthropicMessageResponse, StopReason},
request::{make_parsed_request, make_request, RequestBuilder},
},
traits::{
@@ -35,6 +35,7 @@ use crate::RuntimeContext;

use super::types::MessageChunk;

// Stores the properties required for making a POST request to the API
struct PostRequestProperities {
default_role: String,
base_url: String,
@@ -45,6 +46,7 @@ struct PostRequestProperities {
properties: HashMap<String, serde_json::Value>,
}

// Represents a client that interacts with the Anthropic API
pub struct AnthropicClient {
pub name: String,
retry_policy: Option<String>,
@@ -56,6 +58,8 @@ pub struct AnthropicClient {
client: reqwest::Client,
}

// Resolves/constructs the request properties (PostRequestProperities) from the client's options and the runtime context,
// filling in the required headers and parameters; options that are not set fall back to defaults
fn resolve_properties(
client: &ClientWalker,
ctx: &RuntimeContext,
@@ -131,6 +135,7 @@ fn resolve_properties(
})
}

// getters for client info
impl WithRetryPolicy for AnthropicClient {
fn retry_policy_name(&self) -> Option<&str> {
self.retry_policy.as_deref()
@@ -149,6 +154,7 @@ impl WithClient for AnthropicClient {

impl WithNoCompletion for AnthropicClient {}

// Processes chunks from the streaming response and converts them into a structured response format
impl SseResponseTrait for AnthropicClient {
fn response_stream(
&self,
@@ -285,6 +291,7 @@ impl SseResponseTrait for AnthropicClient {
}
}

// Handles streaming chat interactions: sends the prompt to the API and processes the response stream
impl WithStreamChat for AnthropicClient {
async fn stream_chat(
&self,
@@ -300,6 +307,7 @@ impl WithStreamChat for AnthropicClient {
}
}

// Constructs the base client and resolves its properties from the context
impl AnthropicClient {
pub fn new(client: &ClientWalker, ctx: &RuntimeContext) -> Result<AnthropicClient> {
Ok(Self {
@@ -324,6 +332,7 @@ impl AnthropicClient {
}
}

// Builds the HTTP requests sent to the API
impl RequestBuilder for AnthropicClient {
fn http_client(&self) -> &reqwest::Client {
&self.client
@@ -377,6 +386,7 @@ impl RequestBuilder for AnthropicClient {
if stream {
body_obj.insert("stream".into(), true.into());
}
log::info!("Request body: {:#?}", body);

req.json(&body)
}
@@ -447,17 +457,18 @@ impl WithChat for AnthropicClient {
}
}

// converts completion prompt into JSON body for request
fn convert_completion_prompt_to_body(prompt: &String) -> HashMap<String, serde_json::Value> {
let mut map = HashMap::new();
map.insert("prompt".into(), json!(prompt));
map
}

// converts chat prompt into JSON body for request
fn convert_chat_prompt_to_body(
prompt: &Vec<RenderedChatMessage>,
) -> HashMap<String, serde_json::Value> {
let mut map = HashMap::new();
log::debug!("converting chat prompt to body: {:#?}", prompt);

if let Some(first) = prompt.get(0) {
if first.role == "system" {
@@ -511,6 +522,7 @@ fn convert_chat_prompt_to_body(
return map;
}

// converts chat message parts into JSON content
fn convert_message_parts_to_content(parts: &Vec<ChatMessagePart>) -> serde_json::Value {
if parts.len() == 1 {
if let ChatMessagePart::Text(text) = &parts[0] {