-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
The general usage is: omniai embed Hello!
- Loading branch information
Showing
5 changed files
with
174 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# frozen_string_literal: true | ||
|
||
module OmniAI | ||
class CLI | ||
# Used for CLI usage of 'omnia embed'. | ||
class EmbedHandler < BaseHandler | ||
# @param argv [Array<String>] | ||
def handle!(argv:) | ||
parser.parse!(argv) | ||
|
||
if argv.empty? | ||
listen! | ||
else | ||
embed(input: argv.join(' ')) | ||
end | ||
end | ||
|
||
private | ||
|
||
def listen! | ||
@stdout.puts('Type "exit" or "quit" to leave.') | ||
|
||
loop do | ||
@stdout.print('# ') | ||
@stdout.flush | ||
input = @stdin.gets&.chomp | ||
|
||
break if input.nil? || input.match?(/\A(exit|quit)\z/i) | ||
|
||
embed(input:) | ||
rescue Interrupt | ||
break | ||
end | ||
end | ||
|
||
# @param input [String] | ||
def embed(input:) | ||
response = client.embed(input, **@args) | ||
@stdout.puts(response.embedding) | ||
end | ||
|
||
# @return [OptionParser] | ||
def parser | ||
OptionParser.new do |options| | ||
options.banner = 'usage: omniai embed [options] "<prompt>"' | ||
|
||
options.on('-h', '--help', 'help') do | ||
@stdout.puts(options) | ||
exit | ||
end | ||
|
||
options.on('-p', '--provider=PROVIDER', 'provider') { |provider| @provider = provider } | ||
options.on('-m', '--model=MODEL', 'model') { |model| @args[:model] = model } | ||
end | ||
end | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# frozen_string_literal: true | ||
|
||
RSpec.describe OmniAI::CLI::EmbedHandler do | ||
let(:stdin) { StringIO.new } | ||
let(:stdout) { StringIO.new } | ||
let(:provider) { 'fake' } | ||
let(:model) { 'fake' } | ||
|
||
describe '.handle!' do | ||
subject(:handle!) { described_class.handle!(argv:, stdin:, stdout:, provider:) } | ||
|
||
let(:client) { instance_double(OmniAI::Client) } | ||
|
||
context 'when chatting' do | ||
let(:argv) do | ||
[ | ||
'--model', model, | ||
'--provider', provider, | ||
prompt, | ||
] | ||
end | ||
|
||
before do | ||
allow(OmniAI::Client).to receive(:find).with(provider:) { client } | ||
allow(client).to receive(:embed) { OmniAI::Embed::Response.new(data: { 'data' => [{ 'embedding' => [0.0] }] }) } | ||
end | ||
|
||
context 'with a prompt' do | ||
let(:prompt) { 'The quick brown fox jumps over a lazy dog.' } | ||
|
||
it 'runs calls chat' do | ||
handle! | ||
expect(stdout.string).to eql("0.0\n") | ||
end | ||
end | ||
|
||
context 'without a prompt' do | ||
let(:prompt) { nil } | ||
|
||
let(:stdin) { StringIO.new('The quick brown fox jumps over a lazy dog.') } | ||
|
||
it 'runs calls listen' do | ||
handle! | ||
expect(stdout.string).to include('Type "exit" or "quit" to leave.') | ||
expect(stdout.string).to include('0.0') | ||
end | ||
end | ||
end | ||
|
||
context 'with a help flag' do | ||
%w[-h --help].each do |option| | ||
let(:argv) { [option] } | ||
|
||
it "prints help with '#{option}'" do | ||
expect { handle! }.to raise_error(SystemExit) | ||
expect(stdout.string).not_to be_empty | ||
end | ||
end | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters