diff --git a/Gemfile.lock b/Gemfile.lock
index d0c4151..96c4381 100644
--- a/Gemfile.lock
+++ b/Gemfile.lock
@@ -1,7 +1,7 @@
 PATH
   remote: .
   specs:
-    omniai (1.6.6)
+    omniai (1.7.0)
       event_stream_parser
       http
       zeitwerk
diff --git a/README.md b/README.md
index 5c16122..aa55713 100644
--- a/README.md
+++ b/README.md
@@ -229,6 +229,29 @@
 tempfile.close
 tempfile.unlink
 ```
+### Embeddings
+
+Clients that support embeddings (e.g. OpenAI, Mistral) convert text to embeddings via the following:
+
+```ruby
+response = client.embed('The quick brown fox jumps over a lazy dog')
+response.usage # #<OmniAI::Embed::Usage prompt_tokens=... total_tokens=...>
+response.embedding # [0.1, 0.2, ...]
+```
+
+Batches of text can also be converted to embeddings via the following:
+
+```ruby
+response = client.embed([
+  '...',
+  '...',
+])
+response.usage # #<OmniAI::Embed::Usage prompt_tokens=... total_tokens=...>
+response.embeddings.each do |embedding|
+  embedding # [0.1, 0.2, ...]
+end
+```
+
 ## CLI
 
 OmniAI packages a basic command line interface (CLI) to allow for exploration of various APIs. Detailed CLI documentation can be found via help:
diff --git a/lib/omniai/client.rb b/lib/omniai/client.rb
index 1a6cd03..3401e7b 100644
--- a/lib/omniai/client.rb
+++ b/lib/omniai/client.rb
@@ -173,5 +173,15 @@ def transcribe(io, model:, language: nil, prompt: nil, temperature: nil, format:
     def speak(input, model:, voice:, speed: nil, format: nil, &stream)
       raise NotImplementedError, "#{self.class.name}#speak undefined"
     end
+
+    # @raise [OmniAI::Error]
+    #
+    # @param input [String] required
+    # @param model [String] required
+    #
+    # @return [OmniAI::Embed::Response]
+    def embed(input, model:)
+      raise NotImplementedError, "#{self.class.name}#embed undefined"
+    end
   end
 end
diff --git a/lib/omniai/embed.rb b/lib/omniai/embed.rb
new file mode 100644
index 0000000..4cddb02
--- /dev/null
+++ b/lib/omniai/embed.rb
@@ -0,0 +1,80 @@
+# frozen_string_literal: true
+
+module OmniAI
+  # An abstract class that provides a consistent interface for processing embedding requests.
+  #
+  # Usage:
+  #
+  #   class OmniAI::OpenAI::Embed < OmniAI::Embed
+  #     module Model
+  #       SMALL = "text-embedding-3-small"
+  #       LARGE = "text-embedding-3-large"
+  #       ADA = "text-embedding-ada-002"
+  #     end
+  #
+  #     protected
+  #
+  #     # @return [Hash]
+  #     def payload
+  #       { ... }
+  #     end
+  #
+  #     # @return [String]
+  #     def path
+  #       "..."
+  #     end
+  #   end
+  #
+  #   client.embed(input, model: "...")
+  class Embed
+    def self.process!(...)
+      new(...).process!
+    end
+
+    # @param input [String] required
+    # @param client [Client] the client
+    # @param model [String] required
+    #
+    # @return [Response]
+    def initialize(input, client:, model:)
+      @input = input
+      @client = client
+      @model = model
+    end
+
+    # @raise [Error]
+    # @return [Response]
+    def process!
+      response = request!
+      raise HTTPError, response.flush unless response.status.ok?
+
+      parse!(response:)
+    end
+
+    protected
+
+    # @param response [HTTP::Response]
+    # @return [Response]
+    def parse!(response:)
+      Response.new(data: response.parse)
+    end
+
+    # @return [HTTP::Response]
+    def request!
+      @client
+        .connection
+        .accept(:json)
+        .post(path, json: payload)
+    end
+
+    # @return [Hash]
+    def payload
+      raise NotImplementedError, "#{self.class.name}#payload undefined"
+    end
+
+    # @return [String]
+    def path
+      raise NotImplementedError, "#{self.class.name}#path undefined"
+    end
+  end
+end
diff --git a/lib/omniai/embed/response.rb b/lib/omniai/embed/response.rb
new file mode 100644
index 0000000..329e212
--- /dev/null
+++ b/lib/omniai/embed/response.rb
@@ -0,0 +1,59 @@
+# frozen_string_literal: true
+
+module OmniAI
+  class Embed
+    # The response returned by the API.
+    class Response
+      # @return [Hash]
+      attr_accessor :data
+
+      # @param data [Hash]
+      # @param context [OmniAI::Context] optional
+      def initialize(data:, context: nil)
+        @data = data
+        @context = context
+      end
+
+      # @return [String]
+      def inspect
+        "#<#{self.class.name}>"
+      end
+
+      # @return [Usage]
+      def usage
+        @usage ||= begin
+          deserializer = @context&.deserializers&.[](:usage)
+
+          if deserializer
+            deserializer.call(@data, context: @context)
+          else
+            prompt_tokens = @data.dig('usage', 'prompt_tokens')
+            total_tokens = @data.dig('usage', 'total_tokens')
+
+            Usage.new(prompt_tokens:, total_tokens:)
+          end
+        end
+      end
+
+      # @param index [Integer] optional
+      #
+      # @return [Array<Float>]
+      def embedding(index: 0)
+        embeddings[index]
+      end
+
+      # @return [Array<Array<Float>>]
+      def embeddings
+        @embeddings ||= begin
+          deserializer = @context&.deserializers&.[](:embeddings)
+
+          if deserializer
+            deserializer.call(@data, context: @context)
+          else
+            @data['data'].map { |embedding| embedding['embedding'] }
+          end
+        end
+      end
+    end
+  end
+end
diff --git a/lib/omniai/embed/usage.rb b/lib/omniai/embed/usage.rb
new file mode 100644
index 0000000..4a57f1a
--- /dev/null
+++ b/lib/omniai/embed/usage.rb
@@ -0,0 +1,26 @@
+# frozen_string_literal: true
+
+module OmniAI
+  class Embed
+    # Token usage returned by the API.
+    class Usage
+      # @return [Integer]
+      attr_accessor :prompt_tokens
+
+      # @return [Integer]
+      attr_accessor :total_tokens
+
+      # @param prompt_tokens [Integer]
+      # @param total_tokens [Integer]
+      def initialize(prompt_tokens:, total_tokens:)
+        @prompt_tokens = prompt_tokens
+        @total_tokens = total_tokens
+      end
+
+      # @return [String]
+      def inspect
+        "#<#{self.class.name} prompt_tokens=#{@prompt_tokens} total_tokens=#{@total_tokens}>"
+      end
+    end
+  end
+end
diff --git a/lib/omniai/version.rb b/lib/omniai/version.rb
index 618c866..d0d4359 100644
--- a/lib/omniai/version.rb
+++ b/lib/omniai/version.rb
@@ -1,5 +1,5 @@
 # frozen_string_literal: true
 
 module OmniAI
-  VERSION = '1.6.6'
+  VERSION = '1.7.0'
 end
diff --git a/spec/omniai/client_spec.rb b/spec/omniai/client_spec.rb
index e4433a1..3d0956c 100644
--- a/spec/omniai/client_spec.rb
+++ b/spec/omniai/client_spec.rb
@@ -146,6 +146,10 @@
     it { expect { client.chat('Hello!', model: '...') }.to raise_error(NotImplementedError) }
   end
 
+  describe '#embed' do
+    it { expect { client.embed('Hello!', model: '...') }.to raise_error(NotImplementedError) }
+  end
+
   describe '#inspect' do
     it { expect(client.inspect).to eq('#') }
   end
diff --git a/spec/omniai/embed/response_spec.rb b/spec/omniai/embed/response_spec.rb
new file mode 100644
index 0000000..79be6c0
--- /dev/null
+++ b/spec/omniai/embed/response_spec.rb
@@ -0,0 +1,70 @@
+# frozen_string_literal: true
+
+RSpec.describe OmniAI::Embed::Response do
+  subject(:response) { described_class.new(data:, context:) }
+
+  let(:context) { nil }
+  let(:data) do
+    {
+      'data' => [{ 'embedding' => [0.0] }],
+      'usage' => {
+        'prompt_tokens' => 2,
+        'total_tokens' => 4,
+      },
+    }
+  end
+
+  describe '#inspect' do
+    it { expect(response.inspect).to eql('#<OmniAI::Embed::Response>') }
+  end
+
+  describe '#embedding' do
+    context 'without a context' do
+      let(:context) { nil }
+
+      it { expect(response.embedding).to eql([0.0]) }
+    end
+
+    context 'with a context' do
+      let(:context) do
+        OmniAI::Context.build do |context|
+          context.deserializers[:embeddings] = lambda { |data, *|
+            data['data'].map { |entry| entry['embedding'] }
+          }
+        end
+      end
+
+      it { expect(response.embedding).to eql([0.0]) }
+    end
+  end
+
+  describe '#usage' do
+    it { expect(response.usage).to be_a(OmniAI::Embed::Usage) }
+    it { expect(response.usage.prompt_tokens).to be(2) }
+    it { expect(response.usage.total_tokens).to be(4) }
+
+    context 'without a context' do
+      let(:context) { nil }
+
+      it { expect(response.usage).to be_a(OmniAI::Embed::Usage) }
+      it { expect(response.usage.prompt_tokens).to be(2) }
+      it { expect(response.usage.total_tokens).to be(4) }
+    end
+
+    context 'with a context' do
+      let(:context) do
+        OmniAI::Context.build do |context|
+          context.deserializers[:usage] = lambda { |data, *|
+            prompt_tokens = data['usage']['prompt_tokens']
+            total_tokens = data['usage']['total_tokens']
+            OmniAI::Embed::Usage.new(prompt_tokens:, total_tokens:)
+          }
+        end
+      end
+
+      it { expect(response.usage).to be_a(OmniAI::Embed::Usage) }
+      it { expect(response.usage.prompt_tokens).to be(2) }
+      it { expect(response.usage.total_tokens).to be(4) }
+    end
+  end
+end
diff --git a/spec/omniai/embed/usage_spec.rb b/spec/omniai/embed/usage_spec.rb
new file mode 100644
index 0000000..39f22b8
--- /dev/null
+++ b/spec/omniai/embed/usage_spec.rb
@@ -0,0 +1,20 @@
+# frozen_string_literal: true
+
+RSpec.describe OmniAI::Embed::Usage do
+  subject(:usage) { described_class.new(prompt_tokens:, total_tokens:) }
+
+  let(:prompt_tokens) { 2 }
+  let(:total_tokens) { 4 }
+
+  describe '#inspect' do
+    it { expect(usage.inspect).to eql('#<OmniAI::Embed::Usage prompt_tokens=2 total_tokens=4>') }
+  end
+
+  describe '#prompt_tokens' do
+    it { expect(usage.prompt_tokens).to be(2) }
+  end
+
+  describe '#total_tokens' do
+    it { expect(usage.total_tokens).to be(4) }
+  end
+end
diff --git a/spec/omniai/embed_spec.rb b/spec/omniai/embed_spec.rb
new file mode 100644
index 0000000..07ae818
--- /dev/null
+++ b/spec/omniai/embed_spec.rb
@@ -0,0 +1,75 @@
+# frozen_string_literal: true
+
+class FakeClient < OmniAI::Client
+  def connection
+    HTTP.persistent('http://localhost:8080')
+  end
+end
+
+class FakeEmbed < OmniAI::Embed
+  module Model
+    FAKE = 'fake'
+  end
+
+  def path
+    '/embed'
+  end
+
+  def payload
+    { input: @input, model: @model }
+  end
+end
+
+RSpec.describe OmniAI::Embed do
+  subject(:embed) { described_class.new(input, model:, client:) }
+
+  let(:input) { 'The quick brown fox jumps over a lazy dog.' }
+  let(:model) { '...' }
+  let(:client) { OmniAI::Client.new(api_key: '...') }
+
+  describe '#path' do
+    it { expect { embed.send(:path) }.to raise_error(NotImplementedError) }
+  end
+
+  describe '#payload' do
+    it { expect { embed.send(:payload) }.to raise_error(NotImplementedError) }
+  end
+
+  describe '.process!' do
+    subject(:process!) { FakeEmbed.process!(input, model:, client:) }
+
+    let(:client) { FakeClient.new(api_key: '...') }
+    let(:model) { FakeEmbed::Model::FAKE }
+
+    context 'when OK' do
+      before do
+        stub_request(:post, 'http://localhost:8080/embed')
+          .with(body: {
+            input:,
+            model:,
+          })
+          .to_return_json(status: 200, body: {
+            data: [
+              { index: 0, embedding: [0.0] },
+            ],
+            usage: { prompt_tokens: 2, total_tokens: 4 },
+          })
+      end
+
+      it { expect(process!).to be_a(OmniAI::Embed::Response) }
+    end
+
+    context 'when UNPROCESSABLE' do
+      before do
+        stub_request(:post, 'http://localhost:8080/embed')
+          .with(body: {
+            input:,
+            model:,
+          })
+          .to_return(status: 422, body: 'An unknown error occurred.')
+      end
+
+      it { expect { process! }.to raise_error(OmniAI::Error) }
+    end
+  end
+end
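
For reviewers: a provider gem wires into this abstraction by subclassing `OmniAI::Embed` and defining `path` and `payload`, per the usage comment in `lib/omniai/embed.rb`. Below is a minimal sketch of such a subclass; the `OmniAI::Example` namespace, the `/v1/embeddings` path, and the payload shape are illustrative assumptions, not part of this diff.

```ruby
# frozen_string_literal: true

require 'omniai'

module OmniAI
  module Example
    # Hypothetical provider integration: OmniAI::Embed#process! handles the
    # POST, the status check, and parsing into an OmniAI::Embed::Response;
    # a subclass only supplies the endpoint and the request body.
    class Embed < OmniAI::Embed
      module Model
        SMALL = 'text-embedding-3-small'
      end

      protected

      # @return [String] the endpoint POSTed to via the client connection
      def path
        '/v1/embeddings'
      end

      # @return [Hash] the JSON request body
      def payload
        { input: @input, model: @model }
      end
    end
  end
end

# response = OmniAI::Example::Embed.process!('Hello!', client:, model: OmniAI::Example::Embed::Model::SMALL)
# response.embedding # [0.1, 0.2, ...]
```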
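Similarly, the `context:` deserializers exercised in `spec/omniai/embed/response_spec.rb` let a provider remap a non-standard response body. A minimal sketch, reusing the fixture shape from the specs:

```ruby
require 'omniai'

# Deserializers registered under :embeddings and :usage take precedence over
# the default parsing in OmniAI::Embed::Response.
context = OmniAI::Context.build do |ctx|
  ctx.deserializers[:embeddings] = lambda { |data, *|
    data['data'].map { |entry| entry['embedding'] }
  }
  ctx.deserializers[:usage] = lambda { |data, *|
    OmniAI::Embed::Usage.new(
      prompt_tokens: data['usage']['prompt_tokens'],
      total_tokens: data['usage']['total_tokens'],
    )
  }
end

data = {
  'data' => [{ 'embedding' => [0.0] }],
  'usage' => { 'prompt_tokens' => 2, 'total_tokens' => 4 },
}

response = OmniAI::Embed::Response.new(data:, context:)
response.embedding          # [0.0]
response.usage.total_tokens # 4
```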