Skip to content

Commit

Permalink
Correctly compute required_env_vars even for shorthand clients. (#1164)
Browse files Browse the repository at this point in the history
* This makes the playground experience much nicer

<!-- ELLIPSIS_HIDDEN -->

----

> [!IMPORTANT]
> Improve environment variable computation for shorthand clients in BAML
engine, with refactoring and minor formatting changes.
> 
>   - **Behavior**:
> - Correctly compute `required_env_vars` for shorthand clients in
`repr.rs` and `walker.rs`.
> - Update `required_env_vars()` in `IntermediateRepr` to handle
shorthand clients.
>   - **Refactoring**:
> - Refactor `ClientSpec` to handle provider/model as separate fields in
`repr.rs`.
> - Update `provider_to_env_vars()` to include more providers in
`walker.rs`.
>   - **Misc**:
> - Minor formatting changes in `ir_helpers/mod.rs` and
`runtime_interface.rs`.
> - Update `root-wasm32.code-workspace` to change path from `docs` to
`fern`.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for c421a46. It will automatically
update as commits are pushed.</sup>

<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
hellovai authored Nov 12, 2024
1 parent 603d9f7 commit 8b51b6e
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 74 deletions.
67 changes: 33 additions & 34 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ use crate::{
},
};
use anyhow::Result;
use baml_types::{BamlMap, BamlValue, BamlValueWithMeta, Constraint, ConstraintLevel, FieldType, LiteralValue, TypeValue};
use baml_types::{
BamlMap, BamlValue, BamlValueWithMeta, Constraint, ConstraintLevel, FieldType, LiteralValue,
TypeValue,
};
pub use to_baml_arg::ArgCoercer;

use super::repr;
Expand All @@ -29,12 +32,15 @@ pub type TestCaseWalker<'a> = Walker<'a, (&'a FunctionNode, &'a TestCase)>;
pub type ClassFieldWalker<'a> = Walker<'a, &'a Field>;

pub trait IRHelper {
fn find_enum(&self, enum_name: &str) -> Result<EnumWalker<'_>>;
fn find_class(&self, class_name: &str) -> Result<ClassWalker<'_>>;
fn find_function(&self, function_name: &str) -> Result<FunctionWalker<'_>>;
fn find_client(&self, client_name: &str) -> Result<ClientWalker<'_>>;
fn find_retry_policy(&self, retry_policy_name: &str) -> Result<RetryPolicyWalker<'_>>;
fn find_template_string(&self, template_string_name: &str) -> Result<TemplateStringWalker<'_>>;
fn find_enum<'a>(&'a self, enum_name: &str) -> Result<EnumWalker<'a>>;
fn find_class<'a>(&'a self, class_name: &str) -> Result<ClassWalker<'a>>;
fn find_function<'a>(&'a self, function_name: &str) -> Result<FunctionWalker<'a>>;
fn find_client<'a>(&'a self, client_name: &str) -> Result<ClientWalker<'a>>;
fn find_retry_policy<'a>(&'a self, retry_policy_name: &str) -> Result<RetryPolicyWalker<'a>>;
fn find_template_string<'a>(
&'a self,
template_string_name: &str,
) -> Result<TemplateStringWalker<'a>>;
fn find_test<'a>(
&'a self,
function: &'a FunctionWalker<'a>,
Expand All @@ -53,16 +59,10 @@ pub trait IRHelper {
) -> Result<BamlValueWithMeta<FieldType>>;
fn distribute_constraints<'a>(
&'a self,
field_type: &'a FieldType
field_type: &'a FieldType,
) -> (&'a FieldType, Vec<Constraint>);
fn type_has_constraints(
&self,
field_type: &FieldType
) -> bool;
fn type_has_checks(
&self,
field_type: &FieldType
) -> bool;
fn type_has_constraints(&self, field_type: &FieldType) -> bool;
fn type_has_checks(&self, field_type: &FieldType) -> bool;
}

impl IRHelper for IntermediateRepr {
Expand Down Expand Up @@ -119,14 +119,14 @@ impl IRHelper for IntermediateRepr {
}
}

fn find_client<'ir>(&'ir self, client_name: &str) -> Result<ClientWalker<'ir>> {
match self.walk_clients().find(|c| c.elem().name == client_name) {
fn find_client<'a>(&'a self, client_name: &str) -> Result<ClientWalker<'a>> {
match self.walk_clients().find(|c| c.name() == client_name) {
Some(c) => Ok(c),
None => {
// Get best match.
let clients = self
.walk_clients()
.map(|c| c.elem().name.as_str())
.map(|c| c.name().to_string())
.collect::<Vec<_>>();
error_not_found!("client", client_name, &clients)
}
Expand Down Expand Up @@ -378,7 +378,6 @@ impl IRHelper for IntermediateRepr {
}
}


/// Constraints may live in several places. A constrained base type stors its
/// constraints by wrapping itself in the `FieldType::Constrained` constructor.
/// Additionally, `FieldType::Class` may have constraints stored in its class node,
Expand All @@ -390,20 +389,19 @@ impl IRHelper for IntermediateRepr {
/// possible sources. Whenever querying a type for its constraints, you
/// should do so with this function, instead of searching manually for all
/// the places that Constraints can live.
fn distribute_constraints<'a>(&'a self, field_type: &'a FieldType) -> (&'a FieldType, Vec<Constraint>) {
fn distribute_constraints<'a>(
&'a self,
field_type: &'a FieldType,
) -> (&'a FieldType, Vec<Constraint>) {
match field_type {
FieldType::Class(class_name) => {
match self.find_class(class_name) {
Err(_) => (field_type, Vec::new()),
Ok(class_node) => (field_type, class_node.item.attributes.constraints.clone())
}
}
FieldType::Enum(enum_name) => {
match self.find_enum(enum_name) {
Err(_) => (field_type, Vec::new()),
Ok(enum_node) => (field_type, enum_node.item.attributes.constraints.clone())
}
}
FieldType::Class(class_name) => match self.find_class(class_name) {
Err(_) => (field_type, Vec::new()),
Ok(class_node) => (field_type, class_node.item.attributes.constraints.clone()),
},
FieldType::Enum(enum_name) => match self.find_enum(enum_name) {
Err(_) => (field_type, Vec::new()),
Ok(enum_node) => (field_type, enum_node.item.attributes.constraints.clone()),
},
// Check the first level to see if it's constrained.
FieldType::Constrained { base, constraints } => {
match base.as_ref() {
Expand All @@ -412,7 +410,8 @@ impl IRHelper for IntermediateRepr {
// The recursion here means that arbitrarily nested `FieldType::Constrained`s
// will be collapsed before the function returns.
FieldType::Constrained { .. } => {
let (sub_base, sub_constraints) = self.distribute_constraints(base.as_ref());
let (sub_base, sub_constraints) =
self.distribute_constraints(base.as_ref());
let combined_constraints = vec![constraints.clone(), sub_constraints]
.into_iter()
.flatten()
Expand Down
58 changes: 41 additions & 17 deletions engine/baml-lib/baml-core/src/ir/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,29 @@ impl IntermediateRepr {
&self.configuration
}

pub fn required_env_vars(&self) -> HashSet<&str> {
pub fn required_env_vars(&self) -> HashSet<String> {
// TODO: We should likely check the full IR.
let mut env_vars = HashSet::new();

self.clients
.iter()
.flat_map(|c| c.elem.options.iter())
.flat_map(|(_, expr)| expr.required_env_vars())
.collect::<HashSet<&str>>()
for client in self.walk_clients() {
client.required_env_vars().iter().for_each(|v| {
env_vars.insert(v.to_string());
});
}

// self.walk_functions().filter_map(
// |f| f.client_name()
// ).map(|c| c.required_env_vars())

// // for any functions, check for shorthand env vars
// self.functions
// .iter()
// .filter_map(|f| f.elem.configs())
// .into_iter()
// .flatten()
// .flat_map(|(expr)| expr.client.required_env_vars())
// .collect()
env_vars
}

/// Returns a list of all the recursive cycles in the IR.
Expand Down Expand Up @@ -406,7 +421,7 @@ impl WithRepr<FieldType> for ast::FieldType {
base: Box::new(base_class),
constraints,
},
_ => base_class
_ => base_class,
}
}
Some(Either::Right(enum_walker)) => {
Expand All @@ -415,9 +430,9 @@ impl WithRepr<FieldType> for ast::FieldType {
match maybe_constraints {
Some(constraints) if constraints.len() > 0 => FieldType::Constrained {
base: Box::new(base_type),
constraints
constraints,
},
_ => base_type
_ => base_type,
}
}
None => return Err(anyhow!("Field type uses unresolvable local identifier")),
Expand Down Expand Up @@ -514,9 +529,9 @@ pub enum Expression {
}

impl Expression {
pub fn required_env_vars(&self) -> Vec<&str> {
pub fn required_env_vars<'a>(&'a self) -> Vec<String> {
match self {
Expression::Identifier(Identifier::ENV(k)) => vec![k.as_str()],
Expression::Identifier(Identifier::ENV(k)) => vec![k.to_string()],
Expression::List(l) => l.iter().flat_map(Expression::required_env_vars).collect(),
Expression::Map(m) => m
.iter()
Expand Down Expand Up @@ -859,31 +874,40 @@ pub struct FunctionConfig {
#[derive(serde::Serialize, Clone, Debug)]
pub enum ClientSpec {
Named(String),
Shorthand(String),
/// Shorthand for "<provider>/<model>"
Shorthand(String, String),
}

impl ClientSpec {
pub fn as_str(&self) -> &str {
pub fn as_str(&self) -> String {
match self {
ClientSpec::Named(n) => n,
ClientSpec::Shorthand(n) => n,
ClientSpec::Named(n) => n.clone(),
ClientSpec::Shorthand(provider, model) => format!("{provider}/{model}"),
}
}

pub fn new_from_id(arg: String) -> Self {
if arg.contains("/") {
ClientSpec::Shorthand(arg)
let (provider, model) = arg.split_once("/").unwrap();
ClientSpec::Shorthand(provider.to_string(), model.to_string())
} else {
ClientSpec::Named(arg)
}
}

pub fn required_env_vars(&self) -> HashSet<String> {
match self {
ClientSpec::Named(n) => HashSet::new(),
ClientSpec::Shorthand(_, _) => HashSet::new(),
}
}
}

impl From<AstClientSpec> for ClientSpec {
fn from(spec: AstClientSpec) -> Self {
match spec {
AstClientSpec::Named(n) => ClientSpec::Named(n.to_string()),
AstClientSpec::Shorthand(n) => ClientSpec::Shorthand(n.to_string()),
AstClientSpec::Shorthand(provider, model) => ClientSpec::Shorthand(provider, model),
}
}
}
Expand Down
82 changes: 71 additions & 11 deletions engine/baml-lib/baml-core/src/ir/walker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,40 @@ use indexmap::IndexMap;

use internal_baml_parser_database::RetryPolicyStrategy;

use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

use crate::ir::jinja_helpers::render_expression;
use super::{
repr::{self, FunctionConfig},
Class, Client, Enum, EnumValue, Expression, Field, FunctionNode, Identifier, Impl, RetryPolicy,
TemplateString, TestCase, Walker,
repr::{self, FunctionConfig, WithRepr},
Class, Client, Enum, EnumValue, Expression, Field, FunctionNode, IRHelper, Identifier, Impl,
RetryPolicy, TemplateString, TestCase, Walker,
};
use crate::ir::jinja_helpers::render_expression;

fn provider_to_env_vars(
provider: &str,
) -> impl IntoIterator<Item = (Option<&'static str>, &'static str)> {
match provider {
"aws-bedrock" => vec![
(None, "AWS_ACCESS_KEY_ID"),
(None, "AWS_SECRET_ACCESS_KEY"),
(Some("region"), "AWS_REGION"),
],
"openai" => vec![(Some("api_key"), "OPENAI_API_KEY")],
"anthropic" => vec![(Some("api_key"), "ANTHROPIC_API_KEY")],
"google-ai" => vec![(Some("api_key"), "GOOGLE_API_KEY")],
"vertex-ai" => vec![
(Some("credentials"), "GOOGLE_APPLICATION_CREDENTIALS"),
(
Some("credentials_content"),
"GOOGLE_APPLICATION_CREDENTIALS_CONTENT",
),
],
"azure-openai" => vec![(Some("api_key"), "AZURE_OPENAI_API_KEY")],
"openai-generic" => vec![(Some("api_key"), "OPENAI_API_KEY")],
"ollama" => vec![],
other => vec![],
}
}

impl<'a> Walker<'a, &'a FunctionNode> {
pub fn name(&self) -> &'a str {
Expand All @@ -26,13 +52,30 @@ impl<'a> Walker<'a, &'a FunctionNode> {
true
}

pub fn client_name(&self) -> Option<&'a str> {
pub fn client_name(&self) -> Option<String> {
if let Some(c) = self.elem().configs.first() {
return Some(c.client.as_str());
}

None
}

pub fn required_env_vars(&'a self) -> Result<HashSet<String>> {
if let Some(c) = self.elem().configs.first() {
match &c.client {
repr::ClientSpec::Named(n) => {
let client: super::ClientWalker<'a> = self.db.find_client(n)?;
Ok(client.required_env_vars())
}
repr::ClientSpec::Shorthand(provider, _) => {
let env_vars = provider_to_env_vars(provider);
Ok(env_vars.into_iter().map(|(_, v)| v.to_string()).collect())
}
}
} else {
anyhow::bail!("Function {} has no client", self.name())
}
}

pub fn walk_impls(
&'a self,
) -> impl Iterator<Item = Walker<'a, (&'a repr::Function, &'a FunctionConfig)>> {
Expand Down Expand Up @@ -228,7 +271,6 @@ impl Expression {
}
}


impl<'a> Walker<'a, (&'a FunctionNode, &'a Impl)> {
#[allow(dead_code)]
pub fn function(&'a self) -> Walker<'a, &'a FunctionNode> {
Expand Down Expand Up @@ -328,11 +370,11 @@ impl<'a> Walker<'a, &'a Class> {
}

impl<'a> Walker<'a, &'a Client> {
pub fn elem(&self) -> &'a repr::Client {
pub fn elem(&'a self) -> &'a repr::Client {
&self.item.elem
}

pub fn name(&self) -> &str {
pub fn name(&'a self) -> &'a str {
&self.elem().name
}

Expand All @@ -344,9 +386,27 @@ impl<'a> Walker<'a, &'a Client> {
self.item.attributes.span.as_ref()
}

pub fn options(&self) -> &Vec<(String, Expression)> {
pub fn options(&'a self) -> &'a Vec<(String, Expression)> {
&self.elem().options
}

pub fn required_env_vars(&'a self) -> HashSet<String> {
let mut env_vars = self
.options()
.iter()
.flat_map(|(_, expr)| expr.required_env_vars())
.collect::<HashSet<String>>();

let options = self.options();
for (k, v) in provider_to_env_vars(self.elem().provider.as_str()) {
match k {
Some(k) if !options.iter().any(|(k2, _)| k2 == k) => env_vars.insert(v.to_string()),
None => env_vars.insert(v.to_string()),
_ => false,
};
}
env_vars
}
}

impl<'a> Walker<'a, &'a RetryPolicy> {
Expand Down
4 changes: 2 additions & 2 deletions engine/baml-lib/parser-database/src/walkers/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ pub enum ClientSpec {
Named(String),

/// Defined inline using shorthand "<provider>/<model>" syntax
Shorthand(String),
Shorthand(String, String),
}

impl<'db> FunctionWalker<'db> {
Expand All @@ -143,7 +143,7 @@ impl<'db> FunctionWalker<'db> {
match client.0.split_once("/") {
// TODO: do this in a more robust way
// actually validate which clients are and aren't allowed
Some((provider, model)) => Ok(ClientSpec::Shorthand(format!("{}/{}", provider, model))),
Some((provider, model)) => Ok(ClientSpec::Shorthand(provider.to_string(), model.to_string())),
None => match self.db.find_client(client.0.as_str()) {
Some(client) => Ok(ClientSpec::Named(client.name().to_string())),
None => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,6 @@ impl WithChat for VertexClient {
.properties
.get("model")
.and_then(|v| v.as_str().map(|s| s.to_string()))
.or_else(|| _ctx.env.get("default model").map(|s| s.to_string()))
.unwrap_or_else(|| "".to_string()),
metadata: LLMCompleteResponseMetadata {
baml_is_complete: match response.candidates[0].finish_reason {
Expand Down
Loading

0 comments on commit 8b51b6e

Please sign in to comment.