Skip to content

Commit

Permalink
feat(chat_gpt): compress chats once tasks are completed (#366)
Browse files Browse the repository at this point in the history
also remove older messages from chats when there is no memory
  • Loading branch information
stakach authored Nov 3, 2023
1 parent 1783a73 commit d74716c
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 36 deletions.
4 changes: 2 additions & 2 deletions shard.lock
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ shards:

openai:
git: https://github.com/spider-gazelle/crystal-openai.git
version: 0.9.0+git.commit.e6bfaba7758f992d7cb81cad0109180d5be2d958
version: 0.9.0+git.commit.f5d57e6973fa52b494bb084a99253da5dae8dad8

openapi-generator:
git: https://github.com/place-labs/openapi-generator.git
Expand Down Expand Up @@ -267,7 +267,7 @@ shards:

placeos-models:
git: https://github.com/placeos/models.git
version: 9.25.2
version: 9.26.0

placeos-resource:
git: https://github.com/place-labs/resource.git
Expand Down
5 changes: 3 additions & 2 deletions src/constants.cr
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ module PlaceOS::Api
PROD = ENV["SG_ENV"]?.try(&.downcase) == "production"

# Open AI
OPENAI_API_KEY = ENV["OPENAI_API_KEY"]?
OPENAI_API_BASE = ENV["OPENAI_API_BASE"]? # Set this to Azure URL only if Azure OpenAI is used
OPENAI_API_KEY = ENV["OPENAI_API_KEY"]?
OPENAI_API_BASE = ENV["OPENAI_API_BASE"]? # Set this to Azure URL only if Azure OpenAI is used
OPENAI_MAX_TOKENS = ENV["OPENAI_MAX_TOKENS"]?.try(&.to_i) || 8192

# CHANGELOG
#################################################################################################
Expand Down
12 changes: 7 additions & 5 deletions src/placeos-rest-api/controllers/chat_gpt.cr
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,17 @@ module PlaceOS::Api
chat.destroy
end

record Config, api_key : String, api_base : String?
record Config, api_key : String, api_base : String?, max_tokens : Int32

protected def config
if internals = authority.internals["openai"]?
key = internals["api_key"]?.try &.as_s || Api::OPENAI_API_KEY || raise Error::NotFound.new("missing openai api_key configuration")
Config.new(key, internals["api_base"]?.try &.as_s || Api::OPENAI_API_BASE)
api_key = internals["api_key"]?.try &.as_s || Api::OPENAI_API_KEY || raise Error::NotFound.new("missing openai api_key configuration")
api_base = internals["api_base"]?.try &.as_s || Api::OPENAI_API_BASE
max_tokens = internals["max_tokens"]?.try &.as_i || Api::OPENAI_MAX_TOKENS
Config.new(api_key, api_base, max_tokens)
else
key = Api::OPENAI_API_KEY || raise Error::NotFound.new("missing openai api_key configuration")
Config.new(key, Api::OPENAI_API_BASE)
api_key = Api::OPENAI_API_KEY || raise Error::NotFound.new("missing openai api_key configuration")
Config.new(api_key, Api::OPENAI_API_BASE, Api::OPENAI_MAX_TOKENS)
end
end
end
Expand Down
214 changes: 187 additions & 27 deletions src/placeos-rest-api/controllers/chat_gpt/chat_manager.cr
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ module PlaceOS::Api
ws_sockets[ws_id] = {ws, id, c, req, e}
{ws, id, c, req, e}
end
resp = openai_interaction(client, completion_req, executor, message, chat_id)
ws.send(resp.to_json)
openai_interaction(client, completion_req, executor, message, chat_id) do |resp|
ws.send(resp.to_json)
end
end
rescue error
Log.warn(exception: error) { "failure processing chat message" }
Expand All @@ -77,14 +78,18 @@ module PlaceOS::Api

private def setup(chat, chat_payload)
client = build_client
executor = build_executor(chat, chat_payload)
executor = build_executor(chat)
chat_completion = build_completion(build_prompt(chat, chat_payload), executor.functions)

{client, executor, chat_completion}
end

private def build_client
app_config = app.config

# we save 10% of the tokens to hold the latest request and new output, should be enough
@max_tokens = (app_config.max_tokens.to_f * 0.90).to_i

config = if base = app_config.api_base
OpenAI::Client::Config.azure(api_key: app_config.api_key, api_base: base)
else
Expand All @@ -103,31 +108,170 @@ module PlaceOS::Api
)
end

private def openai_interaction(client, request, executor, message, chat_id) : NamedTuple(chat_id: String, message: String?)
@max_tokens : Int32 = 0
@total_tokens : Int32 = 0

# ameba:disable Metrics/CyclomaticComplexity
private def openai_interaction(client, request, executor, message, chat_id, &) : Nil
request.messages << OpenAI::ChatMessage.new(role: :user, content: message)
save_history(chat_id, :user, message)

# track token usage
discardable_tokens = 0
tracking_total = 0
calculate_discard = false
save_initial_msg = true

# ensure we don't loop forever
count = 0
loop do
count += 1
if count > 20
yield({chat_id: chat_id, message: "sorry, I am unable to complete that task", type: :response})
request.messages.truncate(0..0) # leave only the prompt
break
end

# cleanup old messages, saving first system prompt and then removing messages beyond that until we're within the limit
ensure_request_fits(request)

# track token usage
resp = client.chat_completion(request)

if save_initial_msg
save_initial_msg = false

# the first request is actually the prompt + user message
# we always want to keep the prompt so we need to guestimate how many tokens this user message actually contains
# this doesn't need to be highly accurate
if request.messages.size == 2
calculate_initial_request_size(request, resp.usage)
save_history(chat_id, :user, message, request.messages[1].tokens)
else
save_history(chat_id, :user, message, resp.usage.prompt_tokens - @total_tokens)
end
end
@total_tokens = resp.usage.total_tokens

if calculate_discard
discardable_tokens += resp.usage.prompt_tokens - tracking_total
calculate_discard = false
end
tracking_total = @total_tokens

# save relevant history
msg = resp.choices.first.message
msg.tokens = resp.usage.completion_tokens
request.messages << msg
save_history(chat_id, msg)
save_history(chat_id, msg) unless msg.function_call || (msg.role.function? && msg.name != "task_complete")

# perform function calls until we get a response for the user
if func_call = msg.function_call
func_res = executor.execute(func_call)
discardable_tokens += resp.usage.completion_tokens

# handle the AI not providing a valid function name, we want it to retry
func_res = begin
executor.execute(func_call)
rescue ex
Log.error(exception: ex) { "executing function call" }
reply = "Encountered error: #{ex.message}"
result = DriverResponse.new(reply).as(JSON::Serializable)
request.messages << OpenAI::ChatMessage.new(:function, result.to_pretty_json, func_call.name)
next
end

# process the function result
case func_res.name
when "task_complete"
cleanup_messages(request, discardable_tokens)
discardable_tokens = 0
summary = TaskCompleted.from_json func_call.arguments.as_s
yield({chat_id: chat_id, message: "condensing progress: #{summary.details}", type: :progress, function: func_res.name, usage: resp.usage, compressed_usage: @total_tokens})
when "list_function_schemas"
calculate_discard = true
discover = FunctionDiscovery.from_json func_call.arguments.as_s
yield({chat_id: chat_id, message: "checking #{discover.id} capabilities", type: :progress, function: func_res.name, usage: resp.usage})
when "call_function"
calculate_discard = true
execute = FunctionExecutor.from_json func_call.arguments.as_s
yield({chat_id: chat_id, message: "performing action: #{execute.id}.#{execute.function}(#{execute.parameters})", type: :progress, function: func_res.name, usage: resp.usage})
end
request.messages << func_res
save_history(chat_id, msg)
next
end
break {chat_id: chat_id, message: msg.content}

cleanup_messages(request, discardable_tokens)
yield({chat_id: chat_id, message: msg.content, type: :response, usage: resp.usage, compressed_usage: @total_tokens})
break
end
end

private def ensure_request_fits(request)
return if @total_tokens < @max_tokens

messages = request.messages

# NOTE:: we need at least one user message in the request
num_user = messages.count(&.role.user?)

# let the LLM know some information has been removed
if messages[1].role.user?
# we inject a message to the AI to indicate that some messages have been removed
messages.insert(1, OpenAI::ChatMessage.new(role: :system, content: "some earlier messages have been removed", tokens: 6))
@total_tokens += 6
end

delete_at = 2

loop do
msg = messages.delete_at(delete_at)
if msg.role.user?
if num_user == 1
messages.insert(delete_at, msg)
delete_at += 1
next
end

num_user -= 1
end
@total_tokens -= msg.tokens

break if @total_tokens <= @max_tokens || messages[delete_at]?.nil?
end
end

private def save_history(chat_id : String, role : PlaceOS::Model::ChatMessage::Role, message : String, func_name : String? = nil, func_args : JSON::Any? = nil) : Nil
PlaceOS::Model::ChatMessage.create!(role: role, chat_id: chat_id, content: message, function_name: func_name, function_args: func_args)
private def calculate_initial_request_size(request, usage)
msg = request.messages.pop
prompt = request.messages.pop

prompt_size = prompt.content.as(String).count(' ')
msg_size = msg.content.as(String).count(' ')

token_part = usage.prompt_tokens / (prompt_size + msg_size)

msg_tokens = (token_part * msg_size).to_i
prompt_tokens = (token_part * prompt_size).to_i

msg.tokens = msg_tokens
prompt.tokens = prompt_tokens

request.messages << prompt
request.messages << msg
end

private def cleanup_messages(request, discardable_tokens)
# keep task summaries
request.messages.reject! { |mess| mess.function_call || (mess.role.function? && mess.name != "task_complete") }

# a good estimate of the total tokens once the cleanup is complete
@total_tokens = @total_tokens - discardable_tokens
end

private def save_history(chat_id : String, role : PlaceOS::Model::ChatMessage::Role, message : String, tokens : Int32, func_name : String? = nil, func_args : JSON::Any? = nil) : Nil
PlaceOS::Model::ChatMessage.create!(role: role, chat_id: chat_id, content: message, tokens: tokens, function_name: func_name, function_args: func_args)
end

private def save_history(chat_id : String, msg : OpenAI::ChatMessage)
save_history(chat_id, PlaceOS::Model::ChatMessage::Role.parse(msg.role.to_s), msg.content || "", msg.name, msg.function_call.try &.arguments)
save_history(chat_id, PlaceOS::Model::ChatMessage::Role.parse(msg.role.to_s), msg.content || "", msg.tokens, msg.name, msg.function_call.try &.arguments)
end

private def build_prompt(chat : PlaceOS::Model::Chat, chat_payload : Payload?)
Expand All @@ -140,15 +284,17 @@ module PlaceOS::Api
role: :system,
content: String.build { |str|
str << payload.prompt
str << "\n\nrequest function lists and call functions as required to fulfil requests.\n"
str << "\n\nrequest function schemas and call functions as required to fulfil requests.\n"
str << "make sure to interpret results and reply appropriately once you have all the information.\n"
str << "remember to only use valid capability ids, they can be found in this JSON:\n```json\n#{payload.capabilities.to_json}\n```\n\n"
str << "remember to use valid capability ids, they can be found in this JSON:\n```json\n#{payload.capabilities.to_json}\n```\n\n"
str << "you must have a schema for a function before calling it\n"
str << "my name is: #{user.name}\n"
str << "my email is: #{user.email}\n"
str << "my phone number is: #{user.phone}\n" if user.phone.presence
str << "my swipe card number is: #{user.card_number}\n" if user.card_number.presence
str << "my user_id is: #{user.id}\n"
str << "use these details in function calls as required\n"
str << "use these details in function calls as required.\n"
str << "perform one task at a time, making as many function calls as required to complete a task. Once a task is complete call the task_complete function with details of the progress you've made.\n"
str << "the chat client prepends the date-time each message was sent at in the following format YYYY-MM-DD HH:mm:ss +ZZ:ZZ:ZZ"
}
)
Expand All @@ -163,9 +309,12 @@ module PlaceOS::Api
func_call = OpenAI::ChatFunctionCall.new(name, args)
end
end
messages << OpenAI::ChatMessage.new(role: OpenAI::ChatMessageRole.parse(hist.role.to_s), content: hist.content,
messages << OpenAI::ChatMessage.new(
role: OpenAI::ChatMessageRole.parse(hist.role.to_s),
content: hist.content,
name: hist.function_name,
function_call: func_call
function_call: func_call,
tokens: hist.tokens
)
end
end
Expand All @@ -177,19 +326,12 @@ module PlaceOS::Api
Payload.from_json grab_driver_status(chat, LLM_DRIVER, LLM_DRIVER_PROMPT)
end

private def build_executor(chat, payload : Payload?)
private def build_executor(chat)
executor = OpenAI::FunctionExecutor.new

description = if payload
"You have the following capability list, described in the following JSON:\n```json\n#{payload.capabilities.to_json}\n```\n" +
"if a request could benefit from these capabilities, obtain the list of functions by providing the id string."
else
"if a request could benefit from a capability, obtain the list of functions by providing the id string"
end

executor.add(
name: "list_function_schemas",
description: description,
description: "if a request could benefit from a capability, obtain the list of function schemas by providing the id string",
clz: FunctionDiscovery
) do |call|
request = call.as(FunctionDiscovery)
Expand All @@ -206,7 +348,8 @@ module PlaceOS::Api
executor.add(
name: "call_function",
description: "Executes functionality offered by a capability, you'll need to obtain the function schema to perform requests",
clz: FunctionExecutor) do |call|
clz: FunctionExecutor
) do |call|
request = call.as(FunctionExecutor)
reply = "No response received"
begin
Expand All @@ -219,6 +362,15 @@ module PlaceOS::Api
DriverResponse.new(reply).as(JSON::Serializable)
end

executor.add(
name: "task_complete",
description: "Once a task is complete, call this function with the details that are relevant to the conversion. Provide enough detail so you don't perform the actions again and can formulate a response to the user",
clz: TaskCompleted
) do |call|
request = call.as(TaskCompleted)
request.as(JSON::Serializable)
end

executor
end

Expand Down Expand Up @@ -279,6 +431,14 @@ module PlaceOS::Api
getter id : String
end

private struct TaskCompleted
extend OpenAI::FuncMarker
include JSON::Serializable

@[JSON::Field(description: "the details of the task that are relevant to continuing the conversion")]
getter details : String
end

private record DriverResponse, body : String do
include JSON::Serializable
end
Expand Down

0 comments on commit d74716c

Please sign in to comment.