Support for embeddings
This adds a consistent API for converting text into embeddings (vectors).
ksylvest committed Aug 2, 2024
1 parent 08dc852 commit 6b56e8f
Showing 11 changed files with 369 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Gemfile.lock
@@ -1,7 +1,7 @@
PATH
  remote: .
  specs:
-    omniai (1.6.6)
+    omniai (1.7.0)
      event_stream_parser
      http
      zeitwerk
23 changes: 23 additions & 0 deletions README.md
@@ -229,6 +229,29 @@
tempfile.close
tempfile.unlink
```

### Embeddings

Clients that support embeddings (e.g. OpenAI or Mistral) can convert text into vectors as follows:

```ruby
response = client.embed('The quick brown fox jumps over a lazy dog')
response.usage # <OmniAI::Embed::Usage prompt_tokens=5 total_tokens=5>
response.embedding # [0.1, 0.2, ...]
```

Batches of text can also be embedded in a single call:

```ruby
response = client.embed([
  'John is a boy.',
  'Sally is a girl.',
])
response.usage # <OmniAI::Embed::Usage prompt_tokens=5 total_tokens=5>
response.embeddings.each do |embedding|
  embedding # [0.1, 0.2, ...]
end
```
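An individual vector can also be fetched from a batch by index via `embedding(index:)`, added in `lib/omniai/embed/response.rb`. A short sketch (the inputs are illustrative):

```ruby
response = client.embed([
  'John is a boy.',
  'Sally is a girl.',
])
response.embedding(index: 1) # the vector for the second input, e.g. [0.1, 0.2, ...]
```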

## CLI

OmniAI packages a basic command line interface (CLI) for exploring the various APIs. Detailed CLI documentation is available via help:
10 changes: 10 additions & 0 deletions lib/omniai/client.rb
@@ -173,5 +173,15 @@ def transcribe(io, model:, language: nil, prompt: nil, temperature: nil, format:
    def speak(input, model:, voice:, speed: nil, format: nil, &stream)
      raise NotImplementedError, "#{self.class.name}#speak undefined"
    end

    # @raise [OmniAI::Error]
    #
    # @param input [String] required
    # @param model [String] required
    #
    # @return [OmniAI::Embed::Response]
    def embed(input, model:)
      raise NotImplementedError, "#{self.class.name}#embed undefined"
    end
  end
end
80 changes: 80 additions & 0 deletions lib/omniai/embed.rb
@@ -0,0 +1,80 @@
# frozen_string_literal: true

module OmniAI
  # An abstract class that provides a consistent interface for processing embedding requests.
  #
  # Usage:
  #
  #   class OmniAI::OpenAI::Embed < OmniAI::Embed
  #     module Model
  #       SMALL = "text-embedding-3-small"
  #       LARGE = "text-embedding-3-large"
  #       ADA = "text-embedding-ada-002"
  #     end
  #
  #     protected
  #
  #     # @return [Hash]
  #     def payload
  #       { ... }
  #     end
  #
  #     # @return [String]
  #     def path
  #       "..."
  #     end
  #   end
  #
  #   client.embed(input, model: "...")
  class Embed
    def self.process!(...)
      new(...).process!
    end

    # @param input [String] required
    # @param client [Client] the client
    # @param model [String] required
    #
    # @return [Response]
    def initialize(input, client:, model:)
      @input = input
      @client = client
      @model = model
    end

    # @raise [Error]
    # @return [Response]
    def process!
      response = request!
      raise HTTPError, response.flush unless response.status.ok?

      parse!(response:)
    end

    protected

    # @param response [HTTP::Response]
    # @return [Response]
    def parse!(response:)
      Response.new(data: response.parse)
    end

    # @return [HTTP::Response]
    def request!
      @client
        .connection
        .accept(:json)
        .post(path, json: payload)
    end

    # @return [Hash]
    def payload
      raise NotImplementedError, "#{self.class.name}#payload undefined"
    end

    # @return [String]
    def path
      raise NotImplementedError, "#{self.class.name}#path undefined"
    end
  end
end
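For reference, a concrete subclass only needs to supply the two protected hooks, `path` and `payload`. A minimal sketch, not part of this commit: the class name, endpoint, and payload shape below are assumptions modeled on an OpenAI-style embeddings API.

```ruby
# Hypothetical subclass; the endpoint and payload shape are assumptions.
class OmniAI::OpenAI::Embed < OmniAI::Embed
  protected

  # @return [Hash]
  def payload
    { input: @input, model: @model }
  end

  # @return [String]
  def path
    '/v1/embeddings'
  end
end
```

With such a subclass in place, a provider client can implement `#embed` as `Embed.process!(input, client: self, model:)`; the base class then handles the POST, the status check, and parsing into a `Response`.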
59 changes: 59 additions & 0 deletions lib/omniai/embed/response.rb
@@ -0,0 +1,59 @@
# frozen_string_literal: true

module OmniAI
  class Embed
    # The response returned by the API.
    class Response
      # @return [Hash]
      attr_accessor :data

      # @param data [Hash]
      # @param context [OmniAI::Context] optional
      def initialize(data:, context: nil)
        @data = data
        @context = context
      end

      # @return [String]
      def inspect
        "#<#{self.class.name}>"
      end

      # @return [Usage]
      def usage
        @usage ||= begin
          deserializer = @context&.deserializers&.[](:usage)

          if deserializer
            deserializer.call(@data, context: @context)
          else
            prompt_tokens = @data.dig('usage', 'prompt_tokens')
            total_tokens = @data.dig('usage', 'total_tokens')

            Usage.new(prompt_tokens:, total_tokens:)
          end
        end
      end

      # @param index [Integer] optional
      #
      # @return [Array<Float>]
      def embedding(index: 0)
        embeddings[index]
      end

      # @return [Array<Array<Float>>]
      def embeddings
        @embeddings ||= begin
          deserializer = @context&.deserializers&.[](:embeddings)

          if deserializer
            deserializer.call(@data, context: @context)
          else
            @data['data'].map { |embedding| embedding['embedding'] }
          end
        end
      end
    end
  end
end
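`Response#usage` and `Response#embeddings` default to the OpenAI-style `data`/`usage` layout, but a provider whose JSON differs can register deserializers via `OmniAI::Context` (the same mechanism the specs below exercise). A sketch using a made-up `vectors` payload:

```ruby
# The 'vectors' response shape here is hypothetical, purely for illustration.
context = OmniAI::Context.build do |ctx|
  ctx.deserializers[:embeddings] = lambda { |data, *|
    data['vectors']
  }
end

response = OmniAI::Embed::Response.new(data: { 'vectors' => [[0.0, 0.1]] }, context:)
response.embedding # [0.0, 0.1]
```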
26 changes: 26 additions & 0 deletions lib/omniai/embed/usage.rb
@@ -0,0 +1,26 @@
# frozen_string_literal: true

module OmniAI
  class Embed
    # Token usage returned by the API.
    class Usage
      # @return [Integer]
      attr_accessor :prompt_tokens

      # @return [Integer]
      attr_accessor :total_tokens

      # @param prompt_tokens [Integer]
      # @param total_tokens [Integer]
      def initialize(prompt_tokens:, total_tokens:)
        @prompt_tokens = prompt_tokens
        @total_tokens = total_tokens
      end

      # @return [String]
      def inspect
        "#<#{self.class.name} prompt_tokens=#{@prompt_tokens} total_tokens=#{@total_tokens}>"
      end
    end
  end
end
2 changes: 1 addition & 1 deletion lib/omniai/version.rb
@@ -1,5 +1,5 @@
# frozen_string_literal: true

module OmniAI
-  VERSION = '1.6.6'
+  VERSION = '1.7.0'
end
4 changes: 4 additions & 0 deletions spec/omniai/client_spec.rb
@@ -146,6 +146,10 @@
    it { expect { client.chat('Hello!', model: '...') }.to raise_error(NotImplementedError) }
  end

  describe '#embed' do
    it { expect { client.embed('Hello!', model: '...') }.to raise_error(NotImplementedError) }
  end

  describe '#inspect' do
    it { expect(client.inspect).to eq('#<OmniAI::Client api_key="abc***" host="http://localhost:8080">') }
  end
70 changes: 70 additions & 0 deletions spec/omniai/embed/response_spec.rb
@@ -0,0 +1,70 @@
# frozen_string_literal: true

RSpec.describe OmniAI::Embed::Response do
  subject(:response) { described_class.new(data:, context:) }

  let(:context) { nil }
  let(:data) do
    {
      'data' => [{ 'embedding' => [0.0] }],
      'usage' => {
        'prompt_tokens' => 2,
        'total_tokens' => 4,
      },
    }
  end

  describe '#inspect' do
    it { expect(response.inspect).to eql('#<OmniAI::Embed::Response>') }
  end

  describe '#embedding' do
    context 'without a context' do
      let(:context) { nil }

      it { expect(response.embedding).to eql([0.0]) }
    end

    context 'with a context' do
      let(:context) do
        OmniAI::Context.build do |context|
          context.deserializers[:embeddings] = lambda { |data, *|
            data['data'].map { |entry| entry['embedding'] }
          }
        end
      end

      it { expect(response.embedding).to eql([0.0]) }
    end
  end

  describe '#usage' do
    it { expect(response.usage).to be_a(OmniAI::Embed::Usage) }
    it { expect(response.usage.prompt_tokens).to be(2) }
    it { expect(response.usage.total_tokens).to be(4) }

    context 'without a context' do
      let(:context) { nil }

      it { expect(response.usage).to be_a(OmniAI::Embed::Usage) }
      it { expect(response.usage.prompt_tokens).to be(2) }
      it { expect(response.usage.total_tokens).to be(4) }
    end

    context 'with a context' do
      let(:context) do
        OmniAI::Context.build do |context|
          context.deserializers[:usage] = lambda { |data, *|
            prompt_tokens = data['usage']['prompt_tokens']
            total_tokens = data['usage']['total_tokens']
            OmniAI::Embed::Usage.new(prompt_tokens:, total_tokens:)
          }
        end
      end

      it { expect(response.usage).to be_a(OmniAI::Embed::Usage) }
      it { expect(response.usage.prompt_tokens).to be(2) }
      it { expect(response.usage.total_tokens).to be(4) }
    end
  end
end
20 changes: 20 additions & 0 deletions spec/omniai/embed/usage_spec.rb
@@ -0,0 +1,20 @@
# frozen_string_literal: true

RSpec.describe OmniAI::Embed::Usage do
  subject(:usage) { described_class.new(prompt_tokens:, total_tokens:) }

  let(:prompt_tokens) { 2 }
  let(:total_tokens) { 4 }

  describe '#inspect' do
    it { expect(usage.inspect).to eql('#<OmniAI::Embed::Usage prompt_tokens=2 total_tokens=4>') }
  end

  describe '#prompt_tokens' do
    it { expect(usage.prompt_tokens).to be(2) }
  end

  describe '#total_tokens' do
    it { expect(usage.total_tokens).to be(4) }
  end
end