
Commit

Merge pull request #22 from sergiobayona/anthropic-support

Anthropic support

sergiobayona authored May 28, 2024
2 parents 757992b + 8e0079f commit e96ab3b
Showing 26 changed files with 678 additions and 138 deletions.
21 changes: 17 additions & 4 deletions README.md
@@ -26,20 +26,33 @@ Instructor-rb is a Ruby library that makes it a breeze to work with structured o
require 'instructor'
```

- 3. At the beginning of your script, initialize and patch the OpenAI client:
+ 3. At the beginning of your script, initialize and patch the client:

For the OpenAI client:

```ruby
client = Instructor.from_openai(OpenAI::Client)
```
For the Anthropic client:

```ruby
- client = Instructor.patch(OpenAI::Client)
+ client = Instructor.from_anthropic(Anthropic::Client)
```

## Usage

- export your OpenAI API key:
+ export your API key:

```bash
export OPENAI_API_KEY=sk-...
```

or for Anthropic:

```bash
export ANTHROPIC_API_KEY=sk-...
```

Then use Instructor by defining your schema in Ruby using the `define_schema` block and [EasyTalk](https://github.com/sergiobayona/easy_talk)'s schema definition syntax. Here's an example:

```ruby
@@ -54,7 +54,7 @@ class UserDetail
end
end

- client = Instructor.patch(OpenAI::Client).new
+ client = Instructor.from_openai(OpenAI::Client).new

user = client.chat(
parameters: {
```
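For context, a complete call in the new `from_openai` style looks roughly like the following sketch; the model name and prompt text are illustrative, not part of this diff:

```ruby
require 'instructor'

class UserDetail
  include EasyTalk::Model

  define_schema do
    property :name, String
    property :age, Integer
  end
end

client = Instructor.from_openai(OpenAI::Client).new

user = client.chat(
  parameters: {
    model: 'gpt-3.5-turbo', # illustrative model name
    messages: [{ role: 'user', content: 'Extract Jason is 25 years old' }]
  },
  response_model: UserDetail
)

user.name # => "Jason"
user.age  # => 25
```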
2 changes: 1 addition & 1 deletion docs/examples/action_items.md
@@ -53,7 +53,7 @@ To extract action items from a meeting transcript, we use the **`extract_action_
```Ruby

def extract_action_items(data)
- client = Instructor.patch(OpenAI::Client).new
+ client = Instructor.from_openai(OpenAI::Client).new

client.chat(
parameters: {
```
2 changes: 1 addition & 1 deletion docs/index.md
@@ -44,7 +44,7 @@ class UserDetail
end
end

- client = Instructor.patch(OpenAI::Client).new
+ client = Instructor.from_openai(OpenAI::Client).new

user = client.chat(
parameters: {
1 change: 1 addition & 0 deletions instructor-rb.gemspec
@@ -32,6 +32,7 @@ Gem::Specification.new do |spec|
spec.require_paths = ['lib']

spec.add_dependency 'activesupport', '~> 7.0'
spec.add_dependency 'anthropic', '~> 0.2'
spec.add_dependency 'easy_talk', '~> 0.2'
spec.add_dependency 'ruby-openai', '~> 7'
spec.add_development_dependency 'pry-byebug', '~> 3.10'
11 changes: 10 additions & 1 deletion lib/instructor.rb
@@ -1,11 +1,14 @@
# frozen_string_literal: true

require 'openai'
require 'anthropic'
require 'easy_talk'
require 'active_support/all'
require_relative 'instructor/version'
require_relative 'instructor/openai/patch'
require_relative 'instructor/openai/response'
require_relative 'instructor/anthropic/patch'
require_relative 'instructor/anthropic/response'
require_relative 'instructor/mode'

# Instructor makes it easy to reliably get structured data like JSON from Large Language Models (LLMs)
@@ -30,8 +33,14 @@ def self.mode
# @param openai_client [OpenAI::Client] The OpenAI client to be patched.
# @param mode [Symbol] The mode to be used. Default is `Instructor::Mode::TOOLS.function`.
# @return [OpenAI::Client] The patched OpenAI client.
- def self.patch(openai_client, mode: Instructor::Mode::TOOLS.function)
+ def self.from_openai(openai_client, mode: Instructor::Mode::TOOLS.function)
@mode = mode
openai_client.prepend(Instructor::OpenAI::Patch)
end

# @param anthropic_client [Anthropic::Client] The Anthropic client to be patched.
# @return [Anthropic::Client] The patched Anthropic client.
def self.from_anthropic(anthropic_client)
anthropic_client.prepend(Instructor::Anthropic::Patch)
end
end
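Taken together, the module now exposes two symmetrical entry points. A minimal sketch of each (assuming both gems are installed and API keys are exported as shown in the README):

```ruby
require 'instructor'

# OpenAI: accepts an optional mode, defaulting to Instructor::Mode::TOOLS.function.
openai_client = Instructor.from_openai(OpenAI::Client).new

# Anthropic: takes no mode argument; the patch always drives Anthropic's tools beta.
anthropic_client = Instructor.from_anthropic(Anthropic::Client).new
```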
55 changes: 55 additions & 0 deletions lib/instructor/anthropic/patch.rb
@@ -0,0 +1,55 @@
# frozen_string_literal: true

require 'anthropic'
require 'instructor/base/patch'

# The Instructor module provides functionality for interacting with Anthropic's messages API.
module Instructor
module Anthropic
# The `Patch` module provides methods for patching and modifying the Anthropic client behavior.
module Patch
include Instructor::Base::Patch

# Sends a message request to the API and processes the response.
#
# @param parameters [Hash] The parameters for the message request as expected by the Anthropic client.
# @param response_model [Class] The response model class.
# @param max_retries [Integer] The maximum number of retries. Default is 0.
# @param validation_context [Hash] The validation context for the parameters. Optional.
# @return [Object] The processed response.
def messages(parameters:, response_model: nil, max_retries: 0, validation_context: nil)
with_retries(max_retries, [JSON::ParserError, Instructor::ValidationError, Faraday::ParsingError]) do
model = determine_model(response_model)
function = build_function(model)
parameters[:max_tokens] = 1024 unless parameters.key?(:max_tokens)
parameters = prepare_parameters(parameters, validation_context, function)
::Anthropic.configuration.extra_headers = { 'anthropic-beta' => 'tools-2024-04-04' }
response = ::Anthropic::Client.json_post(path: '/messages', parameters:)
process_response(response, model)
end
end

# Processes the API response.
#
# @param response [Hash] The API response.
# @param model [Class] The response model class.
# @return [Object] The processed response.
def process_response(response, model)
parsed_response = Response.new(response).parse
iterable? ? process_multiple_responses(parsed_response, model) : process_single_response(parsed_response, model)
end

# Builds the function details for the API request.
#
# @param model [Class] The response model class.
# @return [Hash] The function details.
def build_function(model)
{
name: generate_function_name(model),
description: generate_description(model),
input_schema: model.json_schema
}
end
end
end
end
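A call through the patched client might look like this sketch; the Claude model name and prompt are assumptions for illustration, and `UserDetail` is the schema class from the README:

```ruby
client = Instructor.from_anthropic(Anthropic::Client).new

user = client.messages(
  parameters: {
    # max_tokens may be omitted; the patch defaults it to 1024
    model: 'claude-3-haiku-20240307', # assumed model name
    messages: [{ role: 'user', content: 'Extract Jason is 25 years old' }]
  },
  response_model: UserDetail,
  max_retries: 2
)
```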
57 changes: 57 additions & 0 deletions lib/instructor/anthropic/response.rb
@@ -0,0 +1,57 @@
# frozen_string_literal: true

module Instructor
module Anthropic
# The Response class represents the response received from the Anthropic API.
# It takes the raw response and provides convenience methods to access the chat completions,
# tool calls, function responses, and parsed arguments.
class Response
# Initializes a new instance of the Response class.
#
# @param response [Hash] The response received from the Anthropic API.
def initialize(response)
@response = response
end

# Parses the function response(s) and returns the parsed arguments.
#
# @return [Array, Hash] The parsed arguments.
# @raise [StandardError] if the API response contains an error.
def parse
raise StandardError, error_message if error?

if single_response?
arguments.first
else
arguments
end
end

private

def content
@response['content']
end

def tool_calls
return [] unless content.is_a?(Array)

content.select { |c| c['type'] == 'tool_use' }
end

def single_response?
tool_calls&.size == 1
end

def arguments
tool_calls.map { |tc| tc['input'] }
end

def error?
@response['type'] == 'error'
end

def error_message
"#{@response.dig('error', 'type')} - #{@response.dig('error', 'message')}"
end
end
end
end
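To make the parsing rules concrete, here is a sketch against a hand-built hash in the shape the class expects; the field values are illustrative:

```ruby
raw = {
  'type' => 'message',
  'content' => [
    { 'type' => 'text', 'text' => 'Using the tool...' },
    { 'type' => 'tool_use', 'input' => { 'name' => 'Jason', 'age' => 25 } }
  ]
}

# A single tool_use block unwraps to its input hash; several would return an array.
Instructor::Anthropic::Response.new(raw).parse
# => { 'name' => 'Jason', 'age' => 25 }
```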
143 changes: 143 additions & 0 deletions lib/instructor/base/patch.rb
@@ -0,0 +1,143 @@
# frozen_string_literal: true

module Instructor
module Base
# The `Patch` module provides common methods for patching and modifying the client behavior.
module Patch
# Generates the function name for the API request.
# You can customize the function name for the LLM by adding a `title` key to the schema.
# Example:
# ```ruby
# class User
# include EasyTalk::Model
# define_schema do
# title 'User'
# property :name, String
# property :age, Integer
# end
# end
# ```
# The function name will be `User`.
# If the `title` key is not present, the function name will be the model's name.
# @param model [Class] The response model class.
# @return [String] The generated function name.
def generate_function_name(model)
model.schema.fetch(:title, model.name)
end

# Generates the description for the function.
# You can customize the instructions for the LLM by adding an `instructions` class method to the response model.
# Example:
# ```ruby
# class User
# include EasyTalk::Model
# def self.instructions
# 'Extract the user name and age from the response'
# end
#
# define_schema do ...
# end
# ```
#
# @param model [Class] The response model class.
# @return [String] The generated description.
def generate_description(model)
if model.respond_to?(:instructions)
raise Instructor::Error, 'The instructions must be a string' unless model.instructions.is_a?(String)

model.instructions
else
"Correctly extracted `#{model.name}` with all the required parameters with correct types"
end
end

private

# Executes a block of code with retries in case of specific exceptions.
#
# @param max_retries [Integer] The maximum number of retries.
# @param exceptions [Array<Class>] The exceptions to catch and retry.
# @yield The block of code to execute.
def with_retries(max_retries, exceptions, &block)
attempts = 0
begin
block.call
rescue *exceptions
attempts += 1
retry if attempts < max_retries
raise
end
end

# Prepares the parameters for the chat request.
#
# @param parameters [Hash] The original parameters.
# @param validation_context [Hash] The validation context for the parameters.
# @param function [Hash] The function details.
# @return [Hash] The prepared parameters.
def prepare_parameters(parameters, validation_context, function)
parameters = apply_validation_context(parameters, validation_context)
parameters.merge(tools: [function])
end

# Processes multiple responses from the API.
#
# @param parsed_response [Array<Hash>] The parsed API responses.
# @param model [Class] The response model class.
# @return [Array<Object>] The processed responses.
def process_multiple_responses(parsed_response, model)
parsed_response.map do |response|
instance = model.new(response)
instance.valid? ? instance : raise(Instructor::ValidationError)
end
end

# Processes a single response from the API.
#
# @param parsed_response [Hash] The parsed API response.
# @param model [Class] The response model class.
# @return [Object] The processed response.
def process_single_response(parsed_response, model)
instance = model.new(parsed_response)
instance.valid? ? instance : raise(Instructor::ValidationError)
end

# Determines the response model based on the provided value.
#
# @param response_model [Class] The response model class or typed array.
# @return [Class] The determined response model class.
def determine_model(response_model)
if response_model.is_a?(T::Types::TypedArray)
@iterable = true
response_model.type.raw_type
else
@iterable = false
response_model
end
end

# Applies the validation context to the parameters.
#
# @param parameters [Hash] The original parameters.
# @param validation_context [Hash] The validation context.
# @return [Hash] The parameters with applied validation context.
def apply_validation_context(parameters, validation_context)
return parameters unless validation_context.is_a?(Hash)

[validation_context].each_with_index do |message, index|
parameters[:messages][index][:content] = parameters[:messages][index][:content] % message
end

parameters
end

# Checks if the response is iterable.
#
# @return [Boolean] `true` if the response is iterable, `false` otherwise.
def iterable?
@iterable
end
end
end
end
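Note that `apply_validation_context` relies on Ruby's `String#%` with named placeholders (`%{key}`), so message content can be written as a template and filled in per call. A sketch using the Anthropic entry point shown above (the model name and placeholder name are illustrative):

```ruby
user = client.messages(
  parameters: {
    model: 'claude-3-haiku-20240307', # assumed model name
    messages: [{ role: 'user', content: 'Extract the user from: %{document}' }]
  },
  response_model: UserDetail,
  validation_context: { document: 'Jason is 25 years old' }
)
```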