From a9e3e09b9de72543d886307f69283d0eefe1dc63 Mon Sep 17 00:00:00 2001 From: Kevin Sylvestre Date: Fri, 2 Aug 2024 12:17:19 -0700 Subject: [PATCH] Support for embeddings via CLI The general usage is: omniai embed Hello! --- README.md | 36 ++++++++++++++++ lib/omniai/cli.rb | 12 ++++-- lib/omniai/cli/embed_handler.rb | 58 +++++++++++++++++++++++++ spec/omniai/cli/embed_handler_spec.rb | 61 +++++++++++++++++++++++++++ spec/omniai/cli_spec.rb | 13 +++++- 5 files changed, 174 insertions(+), 6 deletions(-) create mode 100644 lib/omniai/cli/embed_handler.rb create mode 100644 spec/omniai/cli/embed_handler_spec.rb diff --git a/README.md b/README.md index aa55713..bebe5fd 100644 --- a/README.md +++ b/README.md @@ -286,3 +286,39 @@ Type 'exit' or 'quit' to abort. ``` The warmest place on earth is Africa. ``` + +### Embed + +#### w/ input + +```bash +omniai embed "The quick brown fox jumps over a lazy dog." +``` + +``` +0.0 +... +``` + +#### w/o input + +```bash +omniai embed --provider="openai" --model="text-embedding-ada-002" +``` + +``` +Type 'exit' or 'quit' to abort. +# Whe quick brown fox jumps over a lazy dog. +``` + +``` +0.0 +... +``` + +0.0 +... + +``` + +``` diff --git a/lib/omniai/cli.rb b/lib/omniai/cli.rb index f0a0b80..b098e0a 100644 --- a/lib/omniai/cli.rb +++ b/lib/omniai/cli.rb @@ -28,10 +28,14 @@ def parse(argv = ARGV) command = argv.shift return if command.nil? - case command - when 'chat' then ChatHandler.handle!(stdin: @stdin, stdout: @stdout, provider: @provider, argv:) - else raise Error, "unsupported command=#{command.inspect}" - end + handler = + case command + when 'chat' then ChatHandler + when 'embed' then EmbedHandler + else raise Error, "unsupported command=#{command.inspect}" + end + + handler.handle!(stdin: @stdin, stdout: @stdout, provider: @provider, argv:) end private diff --git a/lib/omniai/cli/embed_handler.rb b/lib/omniai/cli/embed_handler.rb new file mode 100644 index 0000000..cf383bf --- /dev/null +++ b/lib/omniai/cli/embed_handler.rb @@ -0,0 +1,58 @@ +# frozen_string_literal: true + +module OmniAI + class CLI + # Used for CLI usage of 'omnia embed'. + class EmbedHandler < BaseHandler + # @param argv [Array] + def handle!(argv:) + parser.parse!(argv) + + if argv.empty? + listen! + else + embed(input: argv.join(' ')) + end + end + + private + + def listen! + @stdout.puts('Type "exit" or "quit" to leave.') + + loop do + @stdout.print('# ') + @stdout.flush + input = @stdin.gets&.chomp + + break if input.nil? || input.match?(/\A(exit|quit)\z/i) + + embed(input:) + rescue Interrupt + break + end + end + + # @param input [String] + def embed(input:) + response = client.embed(input, **@args) + @stdout.puts(response.embedding) + end + + # @return [OptionParser] + def parser + OptionParser.new do |options| + options.banner = 'usage: omniai embed [options] ""' + + options.on('-h', '--help', 'help') do + @stdout.puts(options) + exit + end + + options.on('-p', '--provider=PROVIDER', 'provider') { |provider| @provider = provider } + options.on('-m', '--model=MODEL', 'model') { |model| @args[:model] = model } + end + end + end + end +end diff --git a/spec/omniai/cli/embed_handler_spec.rb b/spec/omniai/cli/embed_handler_spec.rb new file mode 100644 index 0000000..f5165a2 --- /dev/null +++ b/spec/omniai/cli/embed_handler_spec.rb @@ -0,0 +1,61 @@ +# frozen_string_literal: true + +RSpec.describe OmniAI::CLI::EmbedHandler do + let(:stdin) { StringIO.new } + let(:stdout) { StringIO.new } + let(:provider) { 'fake' } + let(:model) { 'fake' } + + describe '.handle!' do + subject(:handle!) { described_class.handle!(argv:, stdin:, stdout:, provider:) } + + let(:client) { instance_double(OmniAI::Client) } + + context 'when chatting' do + let(:argv) do + [ + '--model', model, + '--provider', provider, + prompt, + ] + end + + before do + allow(OmniAI::Client).to receive(:find).with(provider:) { client } + allow(client).to receive(:embed) { OmniAI::Embed::Response.new(data: { 'data' => [{ 'embedding' => [0.0] }] }) } + end + + context 'with a prompt' do + let(:prompt) { 'The quick brown fox jumps over a lazy dog.' } + + it 'runs calls chat' do + handle! + expect(stdout.string).to eql("0.0\n") + end + end + + context 'without a prompt' do + let(:prompt) { nil } + + let(:stdin) { StringIO.new('The quick brown fox jumps over a lazy dog.') } + + it 'runs calls listen' do + handle! + expect(stdout.string).to include('Type "exit" or "quit" to leave.') + expect(stdout.string).to include('0.0') + end + end + end + + context 'with a help flag' do + %w[-h --help].each do |option| + let(:argv) { [option] } + + it "prints help with '#{option}'" do + expect { handle! }.to raise_error(SystemExit) + expect(stdout.string).not_to be_empty + end + end + end + end +end diff --git a/spec/omniai/cli_spec.rb b/spec/omniai/cli_spec.rb index 82ae98c..0c37660 100644 --- a/spec/omniai/cli_spec.rb +++ b/spec/omniai/cli_spec.rb @@ -3,13 +3,13 @@ RSpec.describe OmniAI::CLI do subject(:cli) { described_class.new(stdin:, stdout:, provider:) } - let(:stdin) { StringIO.new(prompt) } + let(:stdin) { StringIO.new(text) } let(:stdout) { StringIO.new } let(:provider) { 'fake' } let(:client) { instance_double(OmniAI::Client) } - let(:prompt) do + let(:text) do <<~TEXT What is the capital of Canada? exit @@ -31,6 +31,15 @@ end end + context 'with a embed command' do + it 'forwards the command to OmniAI::CLI::EmbedHandler' do + allow(OmniAI::CLI::EmbedHandler).to receive(:handle!) + cli.parse(['embed', 'What is the capital of Canada?']) + expect(OmniAI::CLI::EmbedHandler).to have_received(:handle!) + .with(stdin:, stdout:, provider:, argv: ['What is the capital of Canada?']) + end + end + context 'with an unknown command' do it 'raises an error' do expect { cli.parse(['unknown']) }.to raise_error(OmniAI::Error, 'unsupported command="unknown"')