Skip to content

Commit

Permalink
Define a prompt builder
Browse files Browse the repository at this point in the history
  • Loading branch information
ksylvest committed Jul 18, 2024
1 parent ae441d4 commit 641e2a3
Show file tree
Hide file tree
Showing 28 changed files with 1,185 additions and 283 deletions.
4 changes: 2 additions & 2 deletions Gemfile.lock
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
PATH
remote: .
specs:
omniai (1.5.2)
omniai (1.6.0)
event_stream_parser
http
zeitwerk
Expand Down Expand Up @@ -60,7 +60,7 @@ GEM
rainbow (3.1.1)
rake (13.2.1)
regexp_parser (2.9.2)
rexml (3.3.1)
rexml (3.3.2)
strscan
rspec (3.13.0)
rspec-core (~> 3.13.0)
Expand Down
54 changes: 29 additions & 25 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,25 +122,30 @@ client = OmniAI::OpenAI::Client.new(timeout: {

Clients that support chat (e.g. Anthropic w/ "Claude", Google w/ "Gemini", Mistral w/ "LeChat", OpenAI w/ "ChatGPT", etc) generate completions using the following calls:

#### Completions using Single Message
#### Completions using a Simple Prompt

Generating a completion is as simple as sending in the text:

```ruby
completion = client.chat('Tell me a joke.')
completion.choice.message.content # '...'
completion.choice.message.content # 'Why don't scientists trust atoms? They make up everything!'
```

#### Completions using Multiple Messages
#### Completions using a Complex Prompt

More complex completions are generated using a block w/ various system / user messages:

```ruby
messages = [
{
role: OmniAI::Chat::Role::SYSTEM,
content: 'You are a helpful assistant with an expertise in geography.',
},
'What is the capital of Canada?'
]
completion = client.chat(messages, model: '...', temperature: 0.7, format: :json)
completion.choice.message.content # '...'
completion = client.chat do |prompt|
prompt.system 'You are a helpful assistant with an expertise in animals.'
prompt.user do |message|
message.text 'What animals are in the attached photos?'
message.url('https://.../cat.jpeg', "image/jpeg")
message.url('https://.../dog.jpeg', "image/jpeg")
message.file('./hamster.jpeg', "image/jpeg")
end
end
completion.choice.message.content # 'They are photos of a cat, a cat, and a hamster.'
```

#### Completions using Streaming via Proc
Expand All @@ -167,20 +172,19 @@ client.chat('Tell me a story', stream: $stdout)
A chat can also be initialized with tools:

```ruby
client.chat('What is the weather in "London, England" and "Madrid, Spain"?', tools: [
OmniAI::Tool.new(
proc { |location:, unit: 'celsius'| "It is #{rand(20..50)}° #{unit} in #{location}" },
name: 'Weather',
description: 'Lookup the weather in a location',
parameters: OmniAI::Tool::Parameters.new(
properties: {
location: OmniAI::Tool::Property.string(description: 'The city and country (e.g. Toronto, Canada).'),
unit: OmniAI::Tool::Property.string(enum: %w[celcius farenheit]),
},
required: %i[location]
)
tool = OmniAI::Tool.new(
proc { |location:, unit: 'celsius'| "#{rand(20..50)}° #{unit} in #{location}" },
name: 'Weather',
description: 'Lookup the weather in a location',
parameters: OmniAI::Tool::Parameters.new(
properties: {
location: OmniAI::Tool::Property.string(description: 'e.g. Toronto'),
unit: OmniAI::Tool::Property.string(enum: %w[celcius farenheit]),
},
required: %i[location]
)
])
)
client.chat('What is the weather in "London" and "Madrid"?', tools: [tool])
```

### Transcribe
Expand Down
37 changes: 14 additions & 23 deletions lib/omniai/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,20 @@ def self.process!(...)
new(...).process!
end

# @param messages [String] required
# @param prompt [OmniAI::Chat::Prompt, String, nil] optional
# @param client [OmniAI::Client] the client
# @param model [String] required
# @param temperature [Float, nil] optional
# @param stream [Proc, IO, nil] optional
# @param tools [Array<OmniAI::Tool>] optional
# @param format [Symbol, nil] optional - :json
def initialize(messages, client:, model:, temperature: nil, stream: nil, tools: nil, format: nil)
@messages = arrayify(messages)
# @yield [prompt] optional
def initialize(prompt = nil, client:, model:, temperature: nil, stream: nil, tools: nil, format: nil, &block)
raise ArgumentError, 'prompt or block is required' if !prompt && !block

@prompt = prompt ? Prompt.parse(prompt) : Prompt.new
block&.call(@prompt)

@client = client
@model = model
@temperature = temperature
Expand All @@ -79,9 +84,12 @@ def process!
protected

# Used to spawn another chat with the same configuration using different messages.
def spawn!(messages)
#
# @param prompt [OmniAI::Chat::Prompt]
# @return [OmniAI::Chat::Prompt]
def spawn!(prompt)
self.class.new(
messages,
prompt,
client: @client,
model: @model,
temperature: @temperature,
Expand Down Expand Up @@ -118,7 +126,7 @@ def complete!(response:)

if @tools && completion.tool_call_list.any?
spawn!([
*@messages,
*@prompt.serialize,
*completion.choices.map(&:message).map(&:data),
*(completion.tool_call_list.map { |tool_call| execute_tool_call(tool_call) }),
])
Expand Down Expand Up @@ -146,23 +154,6 @@ def stream!(response:)
@stream.puts if @stream.is_a?(IO) || @stream.is_a?(StringIO)
end

# @return [Array<Hash>]
def messages
@messages.map do |content|
case content
when String then { role: Role::USER, content: }
when Hash then content
else raise Error, "Unsupported content=#{content.inspect}"
end
end
end

# @param value [Object, Array<Object>]
# @return [Array<Object>]
def arrayify(value)
value.is_a?(Array) ? value : [value]
end

# @return [HTTP::Response]
def request!
@client
Expand Down
29 changes: 29 additions & 0 deletions lib/omniai/chat/content.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# frozen_string_literal: true

module OmniAI
class Chat
# A placeholder for parts of a message. Any subclass must implement the serializable interface.
class Content
# @param context [Context] optional
#
# @return [String]
def serialize(context: nil)
raise NotImplementedError, ' # {self.class}#serialize undefined'
end

# @param data [hash]
# @param context [Context] optional
#
# @return [Content]
def self.deserialize(data, context: nil)
raise ArgumentError, "untyped data=#{data.inspect}" unless data.key?('type')

case data['type']
when 'text' then Text.deserialize(data, context:)
when /(.*)_url/ then URL.deserialize(data, context:)
else raise ArgumentError, "unknown type=#{data['type'].inspect}"
end
end
end
end
end
27 changes: 0 additions & 27 deletions lib/omniai/chat/content/file.rb

This file was deleted.

56 changes: 0 additions & 56 deletions lib/omniai/chat/content/media.rb

This file was deleted.

17 changes: 0 additions & 17 deletions lib/omniai/chat/content/text.rb

This file was deleted.

41 changes: 0 additions & 41 deletions lib/omniai/chat/content/url.rb

This file was deleted.

42 changes: 42 additions & 0 deletions lib/omniai/chat/context.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# frozen_string_literal: true

module OmniAI
class Chat
# Used to handle the setup of serializer / deserializer methods for each type.
#
# Usage:
#
# OmniAI::Chat::Context.build do |context|
# context.serializers[:prompt] = (prompt, context:) -> { ... }
# context.serializers[:message] = (prompt, context:) -> { ... }
# context.serializers[:file] = (prompt, context:) -> { ... }
# context.serializers[:text] = (prompt, context:) -> { ... }
# context.serializers[:url] = (prompt, context:) -> { ... }
# context.deserializers[:prompt] = (data, context:) -> { Prompt.new(...) }
# context.deserializers[:message] = (data, context:) -> { Message.new(...) }
# context.deserializers[:file] = (data, context:) -> { File.new(...) }
# context.deserializers[:text] = (data, context:) -> { Text.new(...) }
# context.deserializers[:url] = (data, context:) -> { URL.new(...) }
# end
class Context
# @return [Hash]
attr_accessor :serializers

# @return [Hash]
attr_reader :deserializers

# @return [Context]
def self.build(&block)
new.tap do |context|
block&.call(context)
end
end

# @return [Context]
def initialize
@serializers = {}
@deserializers = {}
end
end
end
end
Loading

0 comments on commit 641e2a3

Please sign in to comment.