Merge pull request #22 from sergiobayona/anthropic-support: Anthropic support
Showing 26 changed files with 678 additions and 138 deletions.
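The diff below adds an Anthropic client patch, an Anthropic response parser, and the shared base patch module they both rely on. As a rough, hypothetical sketch of how the pieces fit together once a client is patched (the entry point is assumed here to be something like `Instructor.from_anthropic`, which is not among the files shown), a structured-extraction call might look like this:

```ruby
require 'instructor'

# Hypothetical response model using the EasyTalk schema DSL referenced in the
# doc comments below.
class UserDetail
  include EasyTalk::Model

  define_schema do
    property :name, String
    property :age, Integer
  end
end

# Assumption: an entry point along these lines applies Instructor::Anthropic::Patch
# to the Anthropic client (that wiring is part of the PR but not visible here).
client = Instructor.from_anthropic(Anthropic::Client).new

user = client.messages(
  parameters: {
    model: 'claude-3-haiku-20240307',
    messages: [{ role: 'user', content: 'Extract: Jason is 25 years old' }]
  },
  response_model: UserDetail
)
# => an instance of UserDetail with name "Jason" and age 25
```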
@@ -0,0 +1,55 @@ (new file: Instructor::Anthropic::Patch)
# frozen_string_literal: true

require 'anthropic'
require 'instructor/base/patch'

# The Instructor module provides functionality for interacting with Anthropic's messages API.
module Instructor
  module Anthropic
    # The `Patch` module provides methods for patching and modifying the Anthropic client behavior.
    module Patch
      include Instructor::Base::Patch

      # Sends a message request to the API and processes the response.
      #
      # @param parameters [Hash] The parameters for the messages request as expected by the Anthropic client.
      # @param response_model [Class] The response model class.
      # @param max_retries [Integer] The maximum number of retries. Default is 0.
      # @param validation_context [Hash] The validation context for the parameters. Optional.
      # @return [Object] The processed response.
      def messages(parameters:, response_model: nil, max_retries: 0, validation_context: nil)
        with_retries(max_retries, [JSON::ParserError, Instructor::ValidationError, Faraday::ParsingError]) do
          model = determine_model(response_model)
          function = build_function(model)
          parameters[:max_tokens] = 1024 unless parameters.key?(:max_tokens)
          parameters = prepare_parameters(parameters, validation_context, function)
          ::Anthropic.configuration.extra_headers = { 'anthropic-beta' => 'tools-2024-04-04' }
          response = ::Anthropic::Client.json_post(path: '/messages', parameters:)
          process_response(response, model)
        end
      end

      # Processes the API response.
      #
      # @param response [Hash] The API response.
      # @param model [Class] The response model class.
      # @return [Object] The processed response.
      def process_response(response, model)
        parsed_response = Response.new(response).parse
        iterable? ? process_multiple_responses(parsed_response, model) : process_single_response(parsed_response, model)
      end

      # Builds the function details for the API request.
      #
      # @param model [Class] The response model class.
      # @return [Hash] The function details.
      def build_function(model)
        {
          name: generate_function_name(model),
          description: generate_description(model),
          input_schema: model.json_schema
        }
      end
    end
  end
end
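For orientation, here is a sketch of the tool payload that `build_function` assembles for a hypothetical EasyTalk model (`UserDetail` from the sketch above); the exact `input_schema` contents depend on EasyTalk's `json_schema` output:

```ruby
# What build_function(UserDetail) would return, roughly:
function = {
  name: 'UserDetail', # generate_function_name: schema title, or the class name
  description: 'Correctly extracted `UserDetail` with all the required parameters with correct types',
  input_schema: UserDetail.json_schema # JSON schema generated by EasyTalk
}

# prepare_parameters then merges it into the request, so the body posted to
# the /messages endpoint looks roughly like:
#   { model: ..., max_tokens: 1024, messages: [...], tools: [function] }
```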
@@ -0,0 +1,57 @@ (new file: Instructor::Anthropic::Response)
# frozen_string_literal: true

module Instructor
  module Anthropic
    # The Response class represents the response received from the Anthropic API.
    # It takes the raw response and provides convenience methods to access the
    # tool calls and their parsed arguments.
    class Response
      # Initializes a new instance of the Response class.
      #
      # @param response [Hash] The response received from the Anthropic API.
      def initialize(response)
        @response = response
      end

      # Parses the function response(s) and returns the parsed arguments.
      #
      # @return [Array, Hash] The parsed arguments.
      # @raise [StandardError] if the API response contains an error.
      def parse
        raise StandardError, error_message if error?

        if single_response?
          arguments.first
        else
          arguments
        end
      end

      private

      def content
        @response['content']
      end

      def tool_calls
        content.is_a?(Array) && content.select { |c| c['type'] == 'tool_use' }
      end

      def single_response?
        tool_calls&.size == 1
      end

      def arguments
        tool_calls.map { |tc| tc['input'] }
      end

      def error?
        @response['type'] == 'error'
      end

      def error_message
        "#{@response.dig('error', 'type')} - #{@response.dig('error', 'message')}"
      end
    end
  end
end
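A small sketch of how `Instructor::Anthropic::Response#parse` behaves, using a trimmed-down, illustrative payload rather than a captured API response:

```ruby
raw = {
  'type' => 'message',
  'content' => [
    { 'type' => 'text', 'text' => 'Here is the extracted user.' },
    { 'type' => 'tool_use', 'name' => 'UserDetail',
      'input' => { 'name' => 'Jason', 'age' => 25 } }
  ]
}

# One tool_use block, so parse returns arguments.first:
Instructor::Anthropic::Response.new(raw).parse
# => { 'name' => 'Jason', 'age' => 25 }

error = {
  'type' => 'error',
  'error' => { 'type' => 'invalid_request_error', 'message' => 'max_tokens: required' }
}

# Error payloads raise with the "type - message" string from error_message:
Instructor::Anthropic::Response.new(error).parse
# raises StandardError, "invalid_request_error - max_tokens: required"
```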
@@ -0,0 +1,143 @@ (new file: Instructor::Base::Patch)
# frozen_string_literal: true

module Instructor
  module Base
    # The `Patch` module provides common methods for patching and modifying the client behavior.
    module Patch
      # Generates the function name for the API request.
      # You can customize the function name for the LLM by adding a `title` key to the schema.
      # Example:
      # ```ruby
      #   class User
      #     include EasyTalk::Model
      #     define_schema do
      #       title 'User'
      #       property :name, String
      #       property :age, Integer
      #     end
      #   end
      # ```
      # The function name will be `User`.
      # If the `title` key is not present, the function name will be the model's name.
      # @param model [Class] The response model class.
      # @return [String] The generated function name.
      def generate_function_name(model)
        model.schema.fetch(:title, model.name)
      end

      # Generates the description for the function.
      # You can customize the instructions for the LLM by adding an `instructions` class method to the response model.
      # Example:
      # ```ruby
      #   class User
      #     include EasyTalk::Model
      #     def self.instructions
      #       'Extract the user name and age from the response'
      #     end
      #
      #     define_schema do ...
      #   end
      # ```
      #
      # @param model [Class] The response model class.
      # @return [String] The generated description.
      def generate_description(model)
        if model.respond_to?(:instructions)
          raise Instructor::Error, 'The instructions must be a string' unless model.instructions.is_a?(String)

          model.instructions
        else
          "Correctly extracted `#{model.name}` with all the required parameters with correct types"
        end
      end

      private

      # Executes a block of code with retries in case of specific exceptions.
      #
      # @param max_retries [Integer] The maximum number of retries.
      # @param exceptions [Array<Class>] The exceptions to catch and retry.
      # @yield The block of code to execute.
      def with_retries(max_retries, exceptions, &block)
        attempts = 0
        begin
          block.call
        rescue *exceptions
          attempts += 1
          retry if attempts < max_retries
          raise
        end
      end

      # Prepares the parameters for the chat request.
      #
      # @param parameters [Hash] The original parameters.
      # @param validation_context [Hash] The validation context for the parameters.
      # @param function [Hash] The function details.
      # @return [Hash] The prepared parameters.
      def prepare_parameters(parameters, validation_context, function)
        parameters = apply_validation_context(parameters, validation_context)
        parameters.merge(tools: [function])
      end

      # Processes multiple responses from the API.
      #
      # @param parsed_response [Array<Hash>] The parsed API responses.
      # @param model [Class] The response model class.
      # @return [Array<Object>] The processed responses.
      def process_multiple_responses(parsed_response, model)
        parsed_response.map do |response|
          instance = model.new(response)
          instance.valid? ? instance : raise(Instructor::ValidationError)
        end
      end

      # Processes a single response from the API.
      #
      # @param parsed_response [Hash] The parsed API response.
      # @param model [Class] The response model class.
      # @return [Object] The processed response.
      def process_single_response(parsed_response, model)
        instance = model.new(parsed_response)
        instance.valid? ? instance : raise(Instructor::ValidationError)
      end

      # Determines the response model based on the provided value.
      #
      # @param response_model [Class] The response model class or typed array.
      # @return [Class] The determined response model class.
      def determine_model(response_model)
        if response_model.is_a?(T::Types::TypedArray)
          @iterable = true
          response_model.type.raw_type
        else
          @iterable = false
          response_model
        end
      end

      # Applies the validation context to the parameters.
      #
      # @param parameters [Hash] The original parameters.
      # @param validation_context [Hash] The validation context.
      # @return [Hash] The parameters with applied validation context.
      def apply_validation_context(parameters, validation_context)
        return parameters unless validation_context.is_a?(Hash)

        Array[validation_context].each_with_index do |message, index|
          parameters[:messages][index][:content] = parameters[:messages][index][:content] % message
        end

        parameters
      end

      # Checks if the response is iterable.
      #
      # @return [Boolean] `true` if the response is iterable, `false` otherwise.
      def iterable?
        @iterable
      end
    end
  end
end
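Two of these helpers are easiest to see with concrete values: `determine_model` unwraps a Sorbet typed array (so a `T::Array[...]` response model yields one instance per tool call), and `apply_validation_context` fills placeholders in the message content via Ruby's `String#%`. A brief sketch, reusing the hypothetical patched client and `UserDetail` model from the first example:

```ruby
users = client.messages(
  parameters: {
    model: 'claude-3-haiku-20240307',
    messages: [{ role: 'user', content: 'Extract every person from: %{text}' }]
  },
  # T::Array[UserDetail] is a T::Types::TypedArray, so determine_model marks the
  # request as iterable and unwraps it back to UserDetail.
  response_model: T::Array[UserDetail],
  # apply_validation_context formats the first message's content, producing
  # "Extract every person from: Jason is 25, Mary is 30".
  validation_context: { text: 'Jason is 25, Mary is 30' }
)
# => an array of UserDetail instances, validated by process_multiple_responses
```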