Skip to content

Commit

Permalink
Support for embeddings via CLI
Browse files Browse the repository at this point in the history
The general usage is:

omniai embed Hello!
  • Loading branch information
ksylvest committed Aug 2, 2024
1 parent 6b56e8f commit a9e3e09
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 6 deletions.
36 changes: 36 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,39 @@ Type 'exit' or 'quit' to abort.
```
The warmest place on earth is Africa.
```

### Embed

#### w/ input

```bash
omniai embed "The quick brown fox jumps over a lazy dog."
```

```
0.0
...
```

#### w/o input

```bash
omniai embed --provider="openai" --model="text-embedding-ada-002"
```

```
Type 'exit' or 'quit' to abort.
# Whe quick brown fox jumps over a lazy dog.
```

```
0.0
...
```

0.0
...

```
```
12 changes: 8 additions & 4 deletions lib/omniai/cli.rb
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,14 @@ def parse(argv = ARGV)
command = argv.shift
return if command.nil?

case command
when 'chat' then ChatHandler.handle!(stdin: @stdin, stdout: @stdout, provider: @provider, argv:)
else raise Error, "unsupported command=#{command.inspect}"
end
handler =
case command
when 'chat' then ChatHandler
when 'embed' then EmbedHandler
else raise Error, "unsupported command=#{command.inspect}"
end

handler.handle!(stdin: @stdin, stdout: @stdout, provider: @provider, argv:)
end

private
Expand Down
58 changes: 58 additions & 0 deletions lib/omniai/cli/embed_handler.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# frozen_string_literal: true

module OmniAI
class CLI
# Used for CLI usage of 'omnia embed'.
class EmbedHandler < BaseHandler
# @param argv [Array<String>]
def handle!(argv:)
parser.parse!(argv)

if argv.empty?
listen!
else
embed(input: argv.join(' '))
end
end

private

def listen!
@stdout.puts('Type "exit" or "quit" to leave.')

loop do
@stdout.print('# ')
@stdout.flush
input = @stdin.gets&.chomp

break if input.nil? || input.match?(/\A(exit|quit)\z/i)

embed(input:)
rescue Interrupt
break
end
end

# @param input [String]
def embed(input:)
response = client.embed(input, **@args)
@stdout.puts(response.embedding)
end

# @return [OptionParser]
def parser
OptionParser.new do |options|
options.banner = 'usage: omniai embed [options] "<prompt>"'

options.on('-h', '--help', 'help') do
@stdout.puts(options)
exit
end

options.on('-p', '--provider=PROVIDER', 'provider') { |provider| @provider = provider }
options.on('-m', '--model=MODEL', 'model') { |model| @args[:model] = model }
end
end
end
end
end
61 changes: 61 additions & 0 deletions spec/omniai/cli/embed_handler_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# frozen_string_literal: true

RSpec.describe OmniAI::CLI::EmbedHandler do
let(:stdin) { StringIO.new }
let(:stdout) { StringIO.new }
let(:provider) { 'fake' }
let(:model) { 'fake' }

describe '.handle!' do
subject(:handle!) { described_class.handle!(argv:, stdin:, stdout:, provider:) }

let(:client) { instance_double(OmniAI::Client) }

context 'when chatting' do
let(:argv) do
[
'--model', model,
'--provider', provider,
prompt,
]
end

before do
allow(OmniAI::Client).to receive(:find).with(provider:) { client }
allow(client).to receive(:embed) { OmniAI::Embed::Response.new(data: { 'data' => [{ 'embedding' => [0.0] }] }) }
end

context 'with a prompt' do
let(:prompt) { 'The quick brown fox jumps over a lazy dog.' }

it 'runs calls chat' do
handle!
expect(stdout.string).to eql("0.0\n")
end
end

context 'without a prompt' do
let(:prompt) { nil }

let(:stdin) { StringIO.new('The quick brown fox jumps over a lazy dog.') }

it 'runs calls listen' do
handle!
expect(stdout.string).to include('Type "exit" or "quit" to leave.')
expect(stdout.string).to include('0.0')
end
end
end

context 'with a help flag' do
%w[-h --help].each do |option|
let(:argv) { [option] }

it "prints help with '#{option}'" do
expect { handle! }.to raise_error(SystemExit)
expect(stdout.string).not_to be_empty
end
end
end
end
end
13 changes: 11 additions & 2 deletions spec/omniai/cli_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
RSpec.describe OmniAI::CLI do
subject(:cli) { described_class.new(stdin:, stdout:, provider:) }

let(:stdin) { StringIO.new(prompt) }
let(:stdin) { StringIO.new(text) }
let(:stdout) { StringIO.new }
let(:provider) { 'fake' }

let(:client) { instance_double(OmniAI::Client) }

let(:prompt) do
let(:text) do
<<~TEXT
What is the capital of Canada?
exit
Expand All @@ -31,6 +31,15 @@
end
end

context 'with a embed command' do
it 'forwards the command to OmniAI::CLI::EmbedHandler' do
allow(OmniAI::CLI::EmbedHandler).to receive(:handle!)
cli.parse(['embed', 'What is the capital of Canada?'])
expect(OmniAI::CLI::EmbedHandler).to have_received(:handle!)
.with(stdin:, stdout:, provider:, argv: ['What is the capital of Canada?'])
end
end

context 'with an unknown command' do
it 'raises an error' do
expect { cli.parse(['unknown']) }.to raise_error(OmniAI::Error, 'unsupported command="unknown"')
Expand Down

0 comments on commit a9e3e09

Please sign in to comment.