Skip to content

Commit

Permalink
feat: support specifying "region" for aws-bedrock (#1150)
Browse files Browse the repository at this point in the history
- allow overriding `region`
- support `model` as well as `model_id` (reasoning in comments)
- get rid of `cfg_if` in `aws_client` (it's incompatible with rustfmt)
- make "Something went wrong" error messages better - actually propagate
LLMFailureReason.message in the exception itself
  • Loading branch information
sxlijin authored Nov 9, 2024
1 parent db89247 commit cbe3c92
Show file tree
Hide file tree
Showing 20 changed files with 1,386 additions and 1,422 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashMap;

use aws_config::Region;
use aws_config::{identity::IdentityCache, retry::RetryConfig, BehaviorVersion, ConfigLoader};
use aws_sdk_bedrockruntime::{self as bedrock, operation::converse::ConverseOutput};

Expand Down Expand Up @@ -35,6 +36,8 @@ use crate::{RenderCurlSettings, RuntimeContext};
struct RequestProperties {
model_id: String,

aws_region: Option<String>,

default_role: String,
inference_config: Option<bedrock::types::InferenceConfiguration>,
allowed_metadata: AllowedMetadata,
Expand Down Expand Up @@ -68,20 +71,38 @@ fn resolve_properties(client: &ClientWalker, ctx: &RuntimeContext) -> Result<Req
})
.collect::<Result<HashMap<_, _>>>()?;

let model_id = properties
.remove("model_id")
.context("model_id is required")?
.as_str()
.context("model_id should be a string")?
.to_string();
let model_id = {
// We allow `provider aws-bedrock` to specify the model using either `model_id` or `model`:
//
// - the Bedrock API itself only accepts `model_id`
// - but all other providers specify the model using `model`, so for someone used to using
// "openai/gpt-4o", they'll expect to be able to use `model gpt-4o`
// - if I were on the bedrock team, I would be _very_ hesitant to add a new request field
// `model` if I already have `model_id`, so I think using `model` isn't too risky
let maybe_model_id = properties.remove("model_id");
let maybe_model = properties.remove("model");

match (maybe_model_id, maybe_model) {
(Some(model_id), _) => model_id
.as_str()
.context("model_id should be a string")?
.to_string(),
(None, Some(model)) => model
.as_str()
.context("model should be a string")?
.to_string(),
_ => anyhow::bail!("model_id or model is required"),
}
};

let default_role = properties
.remove("default_role")
.and_then(|v| v.as_str().map(|s| s.to_string()))
.unwrap_or_else(|| "user".to_string());
let allowed_metadata = match properties.remove("allowed_role_metadata") {
Some(allowed_metadata) => serde_json::from_value(allowed_metadata)
.context("allowed_role_metadata must be an array of keys. For example: ['key1', 'key2']")?,
Some(allowed_metadata) => serde_json::from_value(allowed_metadata).context(
"allowed_role_metadata must be an array of keys. For example: ['key1', 'key2']",
)?,
None => AllowedMetadata::None,
};
let inference_config = match properties.remove("inference_configuration") {
Expand All @@ -93,8 +114,13 @@ fn resolve_properties(client: &ClientWalker, ctx: &RuntimeContext) -> Result<Req
None => None,
};

let aws_region = properties
.remove("region")
.and_then(|v| v.as_str().map(str::to_owned));

Ok(RequestProperties {
model_id,
aws_region,
default_role,
inference_config,
allowed_metadata,
Expand All @@ -113,7 +139,7 @@ impl AwsClient {
context: RenderContext_Client {
name: client.name().into(),
provider: client.elem().provider.clone(),
default_role: default_role,
default_role,
},
features: ModelFeatures {
chat: true,
Expand All @@ -138,42 +164,46 @@ impl AwsClient {
// TODO: this should be memoized on client construction, but because config loading is async,
// we can't do this in AwsClient::new (which is called from LLMPRimitiveProvider::try_from)
async fn client_anyhow(&self) -> Result<bedrock::Client> {
let loader: ConfigLoader = {
cfg_if::cfg_if! {
if #[cfg(target_arch = "wasm32")] {
use aws_config::Region;
use aws_credential_types::Credentials;

let (aws_region, aws_access_key_id, aws_secret_access_key) = match (
self.properties.ctx_env.get("AWS_REGION"),
self.properties.ctx_env.get("AWS_ACCESS_KEY_ID"),
self.properties.ctx_env.get("AWS_SECRET_ACCESS_KEY"),
) {
(Some(aws_region), Some(aws_access_key_id), Some(aws_secret_access_key)) => {
(aws_region, aws_access_key_id, aws_secret_access_key)
}
_ => {
anyhow::bail!(
"AWS_REGION, AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY must be set in the environment"
)
}
};
#[cfg(not(target_arch = "wasm32"))]
let loader: ConfigLoader = aws_config::defaults(BehaviorVersion::latest());

let loader = super::wasm::load_aws_config()
.region(Region::new(aws_region.clone()))
.credentials_provider(Credentials::new(
aws_access_key_id.clone(),
aws_secret_access_key.clone(),
None,
None,
"baml-runtime/wasm",
));

loader
} else {
aws_config::defaults(BehaviorVersion::latest())
#[cfg(target_arch = "wasm32")]
let loader: ConfigLoader = {
use aws_config::Region;
use aws_credential_types::Credentials;

let (aws_region, aws_access_key_id, aws_secret_access_key) = match (
self.properties.ctx_env.get("AWS_REGION"),
self.properties.ctx_env.get("AWS_ACCESS_KEY_ID"),
self.properties.ctx_env.get("AWS_SECRET_ACCESS_KEY"),
) {
(Some(aws_region), Some(aws_access_key_id), Some(aws_secret_access_key)) => {
(aws_region, aws_access_key_id, aws_secret_access_key)
}
}
_ => {
anyhow::bail!(
"AWS_REGION, AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY must be set in the environment"
)
}
};

let loader = super::wasm::load_aws_config()
.region(Region::new(aws_region.clone()))
.credentials_provider(Credentials::new(
aws_access_key_id.clone(),
aws_secret_access_key.clone(),
None,
None,
"baml-runtime/wasm",
));

loader
};

let loader = if let Some(aws_region) = &self.properties.aws_region {
loader.region(Region::new(aws_region.clone()))
} else {
loader
};

let config = loader
Expand Down
2 changes: 1 addition & 1 deletion engine/language_client_python/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ impl BamlError {
baml_runtime::internal::llm_client::ErrorCode::Other(2) => {
PyErr::new::<BamlClientError, _>(format!(
"Something went wrong with the LLM client: {}",
err
failed.message
))
}
baml_runtime::internal::llm_client::ErrorCode::Other(_)
Expand Down
Loading

0 comments on commit cbe3c92

Please sign in to comment.