Add api args and verbose parameters to Bedrock provider #292

Closed
206 changes: 193 additions & 13 deletions R/provider-bedrock.R
@@ -7,37 +7,117 @@ NULL
#' Chat with an AWS bedrock model
#'
#' @description
#' [AWS Bedrock](https://aws.amazon.com/bedrock/) provides a number of chat
#' based models, including those Anthropic's
#' [Claude](https://aws.amazon.com/bedrock/claude/).
#' [AWS Bedrock](https://aws.amazon.com/bedrock/) provides a number of
#' language models, including Anthropic's
#' [Claude](https://aws.amazon.com/bedrock/claude/), accessed via the Bedrock
#' [Converse API](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html).
#' Although ellmer provides a default model, you'll generally need to use the
#' `model` argument to specify a model that you actually have access to.
#' If you use [cross-region inference](https://aws.amazon.com/blogs/machine-learning/getting-started-with-cross-region-inference-in-amazon-bedrock/),
#' pass the inference profile ID as the `model` argument, e.g.,
#' `model = "us.anthropic.claude-3-5-sonnet-20240620-v1:0"`.
#' For examples of tool usage, asynchronous input, and other advanced features,
#' see the [vignettes](https://posit-dev.github.io/ellmer/vignettes/) section
#' of the repo.
#'
#' ## Authentication
#'
#' Authenthication is handled through \{paws.common\}, so if authenthication
#'
#' Authentication is handled through \{paws.common\}, so if authentication
#' does not work for you automatically, you'll need to follow the advice
#' at <https://www.paws-r-sdk.com/#credentials>. In particular, if your
#' org uses AWS SSO, you'll need to run `aws sso login` at the terminal.
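#'
#' To check which credentials \{paws.common\} will resolve before chatting,
#' a quick sketch (the profile name is hypothetical, and
#' `locate_credentials()` is the upstream \{paws.common\} helper, not an
#' ellmer function):
#'
#' ```
#' creds <- paws.common::locate_credentials(profile = "my_profile_name")
#' str(creds) # inspect the resolved credential fields
#' ```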
#'
#' @param profile AWS profile to use.
#' @param api_args Optional list of arguments passed on to the Bedrock API.
#' Use this to customize model behavior. Valid arguments are `temperature`,
#' `top_p`, `top_k`, `stop_sequences`, and `max_tokens`, though certain
#' models may not support every parameter; check the AWS Bedrock model
#' documentation for specifics. Note that different model families
#' (Claude, Nova, Llama, etc.) natively use different parameter names
#' for the same concept (e.g., max_tokens, max_new_tokens, or
#' max_gen_len), but ellmer uses the parameter names above for
#' consistency across all models. The Converse API maps `temperature`,
#' `top_p`, `stop_sequences`, and `max_tokens` to the model-specific
#' native names; `top_k` is forwarded via `additionalModelRequestFields`.
#' @param verbose Logical. When `TRUE`, prints AWS credentials as well as
#' request and response headers/bodies for debugging.
#' @inheritParams chat_openai
#' @inherit chat_openai return
#' @family chatbots
#' @export
#' @examples
#' \dontrun{
#' # Basic usage
#' chat <- chat_bedrock()
#' chat$chat("Tell me three jokes about statisticians")
#'
#' # Using custom API parameters
#' chat <- chat_bedrock(
#' model = "us.meta.llama3-2-3b-instruct-v1:0",
#' api_args = list(
#' temperature = 0.7,
#' max_tokens = 2000
#' )
#' )
#'
#' # Enable verbose output for debugging requests and responses
#' chat <- chat_bedrock(verbose = TRUE)
#'
#' # Custom system prompt with API parameters
#' chat <- chat_bedrock(
#' system_prompt = "You are a helpful data science assistant",
#' api_args = list(temperature = 0.5)
#' )
#'
#' # Use a non-default AWS profile in ~/.aws/credentials
#' chat <- chat_bedrock(profile = "my_profile_name")
#'
#' # Image interpretation when using a vision-capable model
#' chat <- chat_bedrock(
#' model = "us.meta.llama3-2-11b-instruct-v1:0"
#' )
#' chat$chat(
#' "What's in this image?",
#' content_image_file("path/to/image.jpg")
#' )
#'
#' # The echo argument ("none", "text", or "all") determines whether input
#' # and/or output is echoed to the console. Note that "none" uses a
#' # non-streaming endpoint, whereas "text", "all", or TRUE uses a streaming
#' # endpoint. You can use verbose = TRUE to verify which endpoint is used.
#' chat <- chat_bedrock(verbose = TRUE)
#' chat$chat("What is 1 + 1?") # Streaming response
#' resp <- chat$chat("What is 1 + 1?", echo = "none") # Non-streaming response
#' resp # View response
#'
#' # Use echo = "none" in the client constructor to suppress streaming response
#' chat <- chat_bedrock(echo = "none")
#' resp <- chat$chat("What is 1 + 1?") # Non-streaming response
#' resp # View response
#' chat$chat("What is 1 + 1?", echo=TRUE) # Overrides client echo arg, uses streaming
#'
#' # $stream returns a generator, requiring concatenation of the streamed responses.
#' resp <- chat$stream("What is the capital of France?") # resp is a generator object
#' chunks <- coro::collect(resp) # returns list of partial text responses
#' complete_response <- paste(chunks, collapse = "") # Full text response, no echo
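#'
#' # Pass top_k (a sketch; it is forwarded via additionalModelRequestFields,
#' # so it only works on models whose native API accepts top_k, e.g. Claude)
#' chat <- chat_bedrock(api_args = list(top_k = 50))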
#' }
chat_bedrock <- function(system_prompt = NULL,
turns = NULL,
model = NULL,
profile = NULL,
echo = NULL) {
echo = NULL,
api_args = NULL,
verbose = FALSE) {

check_installed("paws.common", "AWS authentication")
cache <- aws_creds_cache(profile)
credentials <- paws_credentials(profile, cache = cache)

turns <- normalize_turns(turns, system_prompt)
model <- set_default(model, "anthropic.claude-3-5-sonnet-20240620-v1:0")

# Validate api_args if present (after the default model is resolved, so the
# model-specific checks never see a NULL model)
if (!is.null(api_args)) {
validate_parameters(api_args, model)
}
echo <- check_echo(echo)
@@ -47,7 +47,9 @@ chat_bedrock <- function(system_prompt = NULL,
model = model,
profile = profile,
region = credentials$region,
cache = cache
cache = cache,
api_args = if (is.null(api_args)) list() else api_args,
verbose = verbose
)

Chat$new(provider = provider, turns = turns, echo = echo)
@@ -60,17 +142,53 @@ ProviderBedrock <- new_class(
model = prop_string(),
profile = prop_string(allow_null = TRUE),
region = prop_string(),
cache = class_list
cache = class_list,
api_args = class_list,
verbose = class_logical
)
)

validate_parameters <- function(api_args, model) {
# Check for unsupported parameters in Llama models
if (!is.null(model) && grepl("llama", model, ignore.case = TRUE)) {
if (!is.null(api_args$top_k)) {
cli::cli_abort("top_k parameter is not supported for Llama models")
}
if (!is.null(api_args$stop_sequences)) {
cli::cli_abort("stop_sequences parameter is not supported for Llama models")
}
}

# Validate temperature (type-checked so non-numeric input fails clearly)
if (!is.null(api_args$temperature) &&
(!is.numeric(api_args$temperature) ||
api_args$temperature < 0 || api_args$temperature > 1)) {
cli::cli_abort("temperature must be a numeric value between 0 and 1, inclusive")
}

# Validate top_p
if (!is.null(api_args$top_p) &&
(!is.numeric(api_args$top_p) ||
api_args$top_p < 0 || api_args$top_p > 1)) {
cli::cli_abort("top_p must be a numeric value between 0 and 1, inclusive")
}

# Validate top_k
if (!is.null(api_args$top_k)) {
if (!is.numeric(api_args$top_k) || api_args$top_k <= 0 || api_args$top_k %% 1 != 0) {
cli::cli_abort("top_k must be a positive integer")
}
}
}
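
# A minimal sketch of what validate_parameters() rejects (hypothetical calls;
# model IDs as used in the examples above):
#
# validate_parameters(list(temperature = 1.5), "anthropic.claude-3-5-sonnet-20240620-v1:0")
# #> Error: temperature must be a numeric value between 0 and 1, inclusive
# validate_parameters(list(top_k = 50), "us.meta.llama3-2-3b-instruct-v1:0")
# #> Error: top_k parameter is not supported for Llama models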

method(chat_request, ProviderBedrock) <- function(provider,
stream = TRUE,
turns = list(),
tools = list(),
type = NULL,
extra_args = list()) {

# Validate parameters if api_args are present
if (length(provider@api_args) > 0) {
validate_parameters(provider@api_args, provider@model)
}

req <- request(paste0(
"https://bedrock-runtime.", provider@region, ".amazonaws.com"
))
@@ -88,6 +206,16 @@ method(chat_request, ProviderBedrock) <- function(provider,
aws_session_token = creds$session_token
)

if (provider@verbose) {
cli::cli_h3("AWS Credentials")
cli::cli_alert_info(paste0(
"Profile: ", provider@profile,
"; Key: ", creds$access_key_id,
"; Secret: ", substr(creds$secret_access_key, 1, 2), strrep("*", 4),
"; Session: ", creds$session_token,
"; Region: ", provider@region
))
}

req <- req_error(req, body = function(resp) {
body <- resp_body_json(resp)
body$Message %||% body$message
@@ -123,16 +251,55 @@ method(chat_request, ProviderBedrock) <- function(provider,
}

# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
req <- req_body_json(req, list(
# Build request body
body <- list(
messages = messages,
system = system,
toolConfig = toolConfig
))
)

# Add inference configuration from api_args if present
if (length(provider@api_args) > 0) {
inference_config <- list()

# Convert snake_case parameters to camelCase for Converse API
if (!is.null(provider@api_args$max_tokens)) {
inference_config$maxTokens <- provider@api_args$max_tokens
}
if (!is.null(provider@api_args$temperature)) {
inference_config$temperature <- provider@api_args$temperature
}
if (!is.null(provider@api_args$top_p)) {
inference_config$topP <- provider@api_args$top_p
}
if (!is.null(provider@api_args$top_k)) {
# topK is not a member of the Converse API's inferenceConfig, so it is
# passed via additionalModelRequestFields using the model's native field
# name (top_k for the Anthropic models; other families may differ)
body$additionalModelRequestFields <- list(top_k = provider@api_args$top_k)
}
if (!is.null(provider@api_args$stop_sequences)) {
inference_config$stopSequences <- provider@api_args$stop_sequences
}

# Only add inferenceConfig if we have parameters
if (length(inference_config) > 0) {
body$inferenceConfig <- inference_config
}
}
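
# For example, api_args = list(temperature = 0.7, max_tokens = 2000) produces
# a request body that serializes with (sketch of the JSON):
# "inferenceConfig": {"temperature": 0.7, "maxTokens": 2000}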

req <- req_body_json(req, body)

if (provider@verbose) {
cli::cli_h3("Request Body")
cat(jsonlite::toJSON(body, auto_unbox = TRUE, pretty = TRUE), "\n")
req <- httr2::req_verbose(req)
}

req
return(req)
}

method(chat_resp_stream, ProviderBedrock) <- function(provider, resp) {
if (provider@verbose) {
cli::cli_h3("Response Stream")
}
resp_stream_aws(resp)
}

@@ -145,15 +312,22 @@ method(stream_parse, ProviderBedrock) <- function(provider, event) {

body <- event$body
body$event_type <- event$headers$`:event-type`
body$p <- NULL # padding?
body$p <- NULL # padding? Looks like: "p": "abcdefghijklmnopqrstuvwxyzABCDEFGHIJ",

if (provider@verbose) {
cli::cli_h3("Response Chunk")
cat(jsonlite::toJSON(body, auto_unbox = TRUE, pretty = TRUE), "\n")
}

body
}
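
# For reference, a parsed text-delta chunk looks roughly like this (a sketch
# of the Converse stream event shape, not captured output):
# list(
#   contentBlockIndex = 0,
#   delta = list(text = "Hello"),
#   event_type = "contentBlockDelta"
# )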

method(stream_text, ProviderBedrock) <- function(provider, event) {
if (event$event_type == "contentBlockDelta") {
event$delta$text
}
}

method(stream_merge_chunks, ProviderBedrock) <- function(provider, result, chunk) {
i <- chunk$contentBlockIndex + 1

@@ -200,6 +374,12 @@ method(stream_merge_chunks, ProviderBedrock) <- function(provider, result, chunk
}

method(value_turn, ProviderBedrock) <- function(provider, result, has_type = FALSE) {
# Print response if verbose mode is enabled
if (provider@verbose) {
cli::cli_h3("Response Body")
cat(jsonlite::toJSON(result, auto_unbox = TRUE, pretty = TRUE), "\n")
}

contents <- lapply(result$output$message$content, function(content) {
if (has_name(content, "text")) {
ContentText(content$text)
@@ -310,7 +490,7 @@ paws_credentials <- function(profile, cache = aws_creds_cache(profile),
creds <- locate_aws_credentials(profile),
error = function(cnd) {
if (is_testing()) {
testthat::skip("Failed to locate AWS credentails")
testthat::skip("Failed to locate AWS credentials")
}
cli::cli_abort("No IAM credentials found.", parent = cnd)
}