Skip to content

Commit

Permalink
feat: Added tool support for parallel function calling
Browse files Browse the repository at this point in the history
  • Loading branch information
naqvis committed Nov 28, 2023
1 parent f5d57e6 commit 9ba725a
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 19 deletions.
38 changes: 22 additions & 16 deletions examples/function_call.cr
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ struct Weather
end
end

record WeatherRes, location : String, unit : WeatherUnit, temperature : Float64, description : String? do
record WeatherRes, location : String, unit : WeatherUnit, temperature : Float64 do
include JSON::Serializable
end

Expand All @@ -37,23 +37,30 @@ end
executor = OpenAI::FunctionExecutor.new

executor.add(
name: "get_weather",
name: "get_current_weather",
description: "Get the current weather of a given location",
clz: Weather # With this information, FunctionExecutor will auto populate the JSON Schema
) do |w| # In reality, this is an instance of `clz` param, but due to language restrictions, we are overcasting when invoke this block
w = w.as(Weather) # So here we have to downcast it back to. NOT Fany, but this is the best we can have with current Crystal restrictions :)
(WeatherRes.new(w.location, w.unit, rand(50), "sunny")).as(JSON::Serializable) # again we have to upcast the result to `Serializable`
clz: Weather # With this information, FunctionExecutor will auto populate the JSON Schema
) do |w| # In reality, this is an instance of `clz` param, but due to language restrictions, we are upcasting when invoking this block
w = w.as(Weather) # So here we have to downcast it back. NOT fancy, but this is the best we can have with current Crystal restrictions :)
unit, temp = case w.location.downcase
when "tokyo" then {WeatherUnit::CELSIUS, 10}
when "san fransisco" then {WeatherUnit::FAHRENHEIT, 72}
when "paris" then {WeatherUnit::CELSIUS, 22}
else
{WeatherUnit::CELSIUS, rand(50)}
end
(WeatherRes.new(w.location, unit, temp)).as(JSON::Serializable) # again we have to upcast the result to `Serializable`
end

client = OpenAI::Client.new

req = OpenAI::ChatCompletionRequest.new(
model: OpenAI::GPT3Dot5Turbo, # gpt-3.5-turbo
model: OpenAI::GPT3Dot5Turbo1106, # gpt-3.5-turbo-1106
messages: [
OpenAI::ChatMessage.new(role: :system, content: "You are an assistant that answers using the local slang of the given place, uncensored."),
],
functions: executor.functions,
function_call: "auto"
tools: executor.tools,
tool_choice: "auto"
)
puts "\n--------------------------"
puts "Conversation"
Expand All @@ -69,16 +76,15 @@ loop do
msg = resp.choices.first.message
req.messages << msg # don't forget to update the conversation with the latest response

if func_call = msg.function_call
puts "Trying to execute #{func_call.name} ..."
func_res = executor.execute(func_call) # execute might raise, so its good to catch. But for demo just assume all is good
if tool_calls = msg.tool_calls
puts "Trying to execute #{tool_calls.size} function calls in parallel ..."
func_res = executor.execute(tool_calls) # execute might raise, so its good to catch. But for demo just assume all is good
# At this point
# * requested function was found
# * requested function(s) was found
# * request was converted to its specified object for execution (`Weather` in this demo case)
# * Block was executed
# * Block returned object (`WeatherRes` in this case) was converted back to `OpenAI::ChatMessage` object
puts "Executed #{func_call.name}."
req.messages << func_res
# * Block returned object (`WeatherRes` in this case) was converted back to `Array(OpenAI::ChatMessage)` object
req.messages.concat(func_res)
next
end

Expand Down
5 changes: 4 additions & 1 deletion shard.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: openai
version: 0.9.0
version: 0.9.1

description: |
Unofficial Crystal client for OpenAI API. Supports ChatGPT, GPT-3, GPT-4, DALL·E 2, Whisper
Expand All @@ -18,6 +18,9 @@ dependencies:
connect-proxy:
github: spider-gazelle/connect-proxy

promise:
github: spider-gazelle/promise

development_dependencies:
webmock:
github: manastech/webmock.cr
Expand Down
1 change: 1 addition & 0 deletions spec/spec_helper.cr
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ CHAT_COMPLETION_RES = <<-JSON
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-3.5-turbo-0613",
"system_fingerprint": "fp_44709d6fcb",
"choices": [{
"index": 0,
"message": {
Expand Down
98 changes: 96 additions & 2 deletions src/openai/api/chat.cr
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
require "json"
require "promise"
require "./usage"
require "../stream"

Expand All @@ -13,6 +14,8 @@ module OpenAI
Assistant
# function
Function
# tool
Tool

def to_s(io : IO) : Nil
io << to_s
Expand All @@ -34,6 +37,8 @@ module OpenAI
ContentFilter
# API response still in progress or incomplete
Null
# if model called a tool
ToolCalls

def to_s(io : IO) : Nil
io << to_s
Expand Down Expand Up @@ -65,6 +70,21 @@ module OpenAI
include JSON::Serializable
end

# An object specifying the format the model must output.
# Serialized as e.g. `{"type": "json_object"}` to enable JSON mode on supported models.
record ResponseFormat, type : FormatType do
include JSON::Serializable
# Allowed values for `ResponseFormat#type`.
enum FormatType
Text
JsonObject

# Writes the serialized member name to `io`.
def to_s(io : IO) : Nil
io << to_s
end

# Snake_case wire format expected by the API, e.g. `JsonObject` -> "json_object".
def to_s : String
super.underscore
end
end
end
# The name and arguments of a function that should be called, as generated by the model.
#
# The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON,
Expand All @@ -90,6 +110,26 @@ module OpenAI
end
end

# A tool the model may call. Currently only function tools exist, so `type`
# defaults to `ToolType::Function` and `function` carries the function definition
# (name, description, JSON-schema parameters).
record ChatTool, type : ToolType = ToolType::Function, function : ChatFunction? = nil do
include JSON::Serializable

# Kinds of tools accepted by the API (only `Function` at present).
enum ToolType
Function

# Writes the serialized member name to `io`.
def to_s(io : IO) : Nil
io << to_s
end

# Lowercase wire format, e.g. "function".
def to_s : String
super.downcase
end
end
end

# A single tool invocation requested by the model. `id` must be echoed back as
# `tool_call_id` on the follow-up tool-role message; `type` is the tool kind
# (currently always "function"); `function` holds the name and JSON arguments.
record ChatToolCall, id : String, type : String, function : ChatFunctionCall do
include JSON::Serializable
end

struct ChatMessage
include JSON::Serializable

Expand All @@ -105,12 +145,19 @@ module OpenAI
getter name : String?

# The name and arguments of a function that should be called, as generated by the model.
@[Deprecated("Deprecated and replaced by tool_calls")]
getter function_call : ChatFunctionCall?

@[JSON::Field(ignore: true)]
property tokens : Int32 = 0

def initialize(@role, @content = nil, @name = nil, @function_call = nil, @tokens = 0)
# The tool calls generated by the model, such as function calls.
getter tool_calls : Array(ChatToolCall)?

# For Role=tool prompts this should be set to the ID given in the assistant's prior request to call a tool.
getter tool_call_id : String?

def initialize(@role, @content = nil, @name = nil, @function_call = nil, @tokens = 0, @tool_calls = nil, @tool_call_id = nil)
end
end

Expand All @@ -119,7 +166,8 @@ module OpenAI

def initialize(@model, @messages, @max_tokens = nil, @temperature = 1.0, @top_p = 1.0,
@stream = false, @stop = nil, @presence_penalty = 0.0, @frequency_penalty = 0.0,
@logit_bias = nil, @user = nil, @functions = nil, @function_call = nil)
@logit_bias = nil, @user = nil, @functions = nil, @function_call = nil,
@tools = nil, @tool_choice = nil)
end

# the model id
Expand Down Expand Up @@ -173,13 +221,37 @@ module OpenAI
# A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
property user : String? = nil

# An object specifying the format that the model must output.
# Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON.
# Important: when using JSON mode, you must also instruct the model to produce JSON yourself via a system or user message. Without this,
# the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request.
# Also note that the message content may be partially cut off if finish_reason="length", which indicates the generation exceeded max_tokens or the conversation
# exceeded the max context length.
property response_format : ResponseFormat? = nil

# This feature is in Beta. If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same seed
# and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor
# changes in the backend.
property seed : Int32? = nil

# A list of functions the model may generate JSON inputs for.
@[Deprecated("Deprecated in favor of tools")]
property functions : Array(ChatFunction)? = nil

# Controls how the model responds to function calls. none means the model does not call a function, and responds to the end-user.
# auto means the model can pick between an end-user or calling a function. Specifying a particular function via {"name": "my_function"}
# forces the model to call that function. none is the default when no functions are present. auto is the default if functions are present.
@[Deprecated("Deprecated in favor of tool_choice")]
property function_call : String | JSON::Any? = nil

# A list of tools the model may call. Currently, only functions are supported as a tool.
# Use this to provide a list of functions the model may generate JSON inputs for.
property tools : Array(ChatTool)? = nil
# Controls which (if any) function is called by the model. `none` means the model will not call a function and instead generates a message.
# `auto` means the model can pick between generating a message or calling a function. Specifying a particular function via
# `{"type: "function", "function": {"name": "my_function"}}` forces the model to call that function.
# `none` is the default when no functions are present. `auto` is the default if functions are present.
property tool_choice : String | JSON::Any? = nil
end

record ChatCompletionChoice, index : Int32, message : ChatMessage, finish_reason : FinishReason do
Expand Down Expand Up @@ -207,6 +279,10 @@ module OpenAI

# Usage statistics for the completion request.
getter usage : Usage

# This fingerprint represents the backend configuration that the model runs with.
# Can be used in conjunction with the seed request parameter to understand when backend changes have been made that might impact determinism.
getter system_fingerprint : String?
end

record ChatCompletionStreamChoiceDelta, role : ChatMessageRole?, content : String?, function_call : ChatFunctionCall? do
Expand Down Expand Up @@ -247,15 +323,18 @@ module OpenAI
class FunctionExecutor
alias Callback = JSON::Serializable -> JSON::Serializable
getter functions : Array(ChatFunction)
getter tools : Array(ChatTool)

# Creates an empty executor with no functions registered yet.
def initialize
@functions = Array(ChatFunction).new
@tools = Array(ChatTool).new
# Maps function name -> {argument class, callback block}; used by #execute to dispatch.
@map = Hash(String, {FuncMarker, Callback}).new
end

# Registers a callable function under `name`. The JSON schema advertised to the
# model is derived automatically from `clz` (the argument type the block expects).
# The definition is exposed both through the legacy `functions` list and the newer
# `tools` list, so callers can use either API style.
def add(name : String, description : String?, clz : U, &block : Callback) forall U
func = ChatFunction.new(name: name, description: description, parameters: JSON.parse(clz.json_schema.to_json))
functions << func
tools << ChatTool.new(function: func)
@map[name] = {clz, block}
end

Expand All @@ -267,5 +346,20 @@ module OpenAI
result = func.last.call(arg)
ChatMessage.new(:function, result.to_pretty_json, call.name)
end

# Executes every tool call requested by the model in parallel (one promise per
# call) and returns the resulting tool-role `ChatMessage`s, one per call.
#
# Raises `OpenAIError` when `calls` is empty or when a call references a function
# that was never registered via `#add`.
def execute(calls : Array(ChatToolCall))
raise OpenAIError.new "OpenAI returned response with no function call details" if calls.empty?
Promise.all(
calls.map do |call|
# Fall back to the last dot-separated segment so namespaced names like "ns.fn" still resolve.
raise OpenAIError.new "OpenAI called unknown function: name: '#{call.function.name}' with #{call.id}'" unless func = @map[call.function.name]? || @map[call.function.name.split('.', 2)[-1]]?
Promise(ChatMessage).defer(same_thread: true) do
# Arguments may arrive as a JSON string or as structured JSON; normalize to a string before parsing.
params = call.function.arguments.as_s? || call.function.arguments.to_s
arg = func.first.from_json(params)
result = func.last.call(arg)
# Echo the call id as tool_call_id so the API can correlate this result with its request.
ChatMessage.new(:tool, result.to_pretty_json, call.function.name, tool_call_id: call.id)
end
end
).get
end
end
end
4 changes: 4 additions & 0 deletions src/openai/constants.cr
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ module OpenAI
GPT40613 = "gpt-4-0613"
GPT40314 = "gpt-4-0314"
GPT4 = "gpt-4"
GPT41106 = "gpt-4-1106-preview"
GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613"
GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301"
GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k"
GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613"
GPT3Dot5Turbo = "gpt-3.5-turbo"
GPT3Dot5TurboInstruct = "gpt-3.5-turbo-instruct"
GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106"

# Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead.
GPT3TextDavinci003 = "text-davinci-003"
Expand Down Expand Up @@ -101,12 +103,14 @@ module OpenAI
GPT3Dot5Turbo0613,
GPT3Dot5Turbo16K,
GPT3Dot5Turbo16K0613,
GPT3Dot5Turbo1106,
GPT4,
GPT40314,
GPT40613,
GPT432K,
GPT432K0314,
GPT432K0613,
GPT41106,
],
"/chat/completions" => [
CodexCodeDavinci002,
Expand Down

0 comments on commit 9ba725a

Please sign in to comment.