Merge pull request #82 from frankroeder/feat/custom_model
Add support for custom provider
frankroeder authored Dec 20, 2024
2 parents ccc93b5 + 034142b commit ab76cbc
Showing 6 changed files with 60 additions and 11 deletions.
28 changes: 28 additions & 0 deletions README.md
@@ -547,6 +547,34 @@ Below, we provide an example for [lualine](https://github.com/nvim-lualine/luali

```

## Adding a custom provider
If your provider is not available, you can reuse an existing provider with a
different endpoint and a custom selection of models. To do this, add the
`custom` provider to the list of providers as follows:
```lua
providers = {
  custom = {
    style = "openai",
    api_key = os.getenv "CUSTOM_API_KEY",
    endpoint = "https://api.openai.com/v1/chat/completions",
    models = {
      "gpt-4o-mini",
      "gpt-4o",
    },
    -- parameters to summarize chat
    topic = {
      model = "gpt-4o-mini",
      params = { max_completion_tokens = 64 },
    },
    -- default parameters
    params = {
      chat = { temperature = 1.1, top_p = 1 },
      command = { temperature = 1.1, top_p = 1 },
    },
  }
}
```

## Bonus

Access parrot.nvim directly from your terminal:
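For orientation, here is a minimal, hypothetical sketch of how such a `custom` block plugs into the plugin's setup call; the environment variable, endpoint, and model names are placeholders rather than part of this commit:

```lua
-- Hedged sketch: reuse the OpenAI request style under a custom provider name.
-- CUSTOM_API_KEY, the endpoint, and the model names are illustrative placeholders.
require("parrot").setup {
  providers = {
    custom = {
      style = "openai", -- which built-in provider class handles the requests
      api_key = os.getenv "CUSTOM_API_KEY",
      endpoint = "https://my-gateway.example.com/v1/chat/completions",
      models = { "my-model-small", "my-model-large" },
    },
  },
}
```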
4 changes: 3 additions & 1 deletion lua/parrot/chat_handler.lua
@@ -62,7 +62,9 @@ end
 function ChatHandler:set_provider(selected_prov, is_chat)
   local endpoint = self.providers[selected_prov].endpoint
   local api_key = self.providers[selected_prov].api_key
-  local _prov = init_provider(selected_prov, endpoint, api_key)
+  local style = self.providers[selected_prov].style
+  local models = self.providers[selected_prov].models
+  local _prov = init_provider(selected_prov, endpoint, api_key, style, models)
   self.current_provider[is_chat and "chat" or "command"] = _prov
   self.state:set_provider(_prov.name, is_chat)
   self.state:refresh(self.available_providers, self.available_models)
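As a hedged illustration of what the two new lookups read, the entries below show the assumed table shape; for built-in providers `style` and `models` are simply absent, so `init_provider` receives `nil` for both and behaves as before:

```lua
-- Assumed shape of self.providers[selected_prov] (placeholder values only):
local providers = {
  custom = {
    endpoint = "https://my-gateway.example.com/v1/chat/completions",
    api_key = os.getenv "CUSTOM_API_KEY",
    style = "openai",              -- read by the new `style` lookup
    models = { "my-model-small" }, -- read by the new `models` lookup
  },
  openai = {
    endpoint = "https://api.openai.com/v1/chat/completions",
    api_key = os.getenv "OPENAI_API_KEY",
    -- no style/models: both lookups yield nil, preserving the old call
  },
}
```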
14 changes: 13 additions & 1 deletion lua/parrot/config.lua
@@ -172,6 +172,12 @@ local defaults = {
         command = { temperature = 1.1, top_p = 1 },
       },
     },
+    custom = {
+      style = "openai",
+      api_key = "",
+      endpoint = "https://api.openai.com/v1/chat/completions",
+      topic_prompt = topic_prompt,
+    },
   },
   cmd_prefix = "Prt",
   curl_params = {},
@@ -374,7 +380,13 @@ function M.setup(opts)
 
   local available_models = {}
   for _, prov_name in ipairs(M.available_providers) do
-    local _prov = init_provider(prov_name, M.providers[prov_name].endpoint, M.providers[prov_name].api_key)
+    local _prov = init_provider(
+      prov_name,
+      M.providers[prov_name].endpoint,
+      M.providers[prov_name].api_key,
+      M.providers[prov_name].style or nil,
+      M.providers[prov_name].models or nil
+    )
     -- do not make an API call on startup
     available_models[prov_name] = _prov:get_available_models(false)
   end
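A hedged sketch of what this startup loop yields when a custom provider is configured; the model names are placeholders and the built-in lists are abbreviated:

```lua
-- Illustrative contents of available_models after M.setup (no API calls made):
local available_models = {
  openai = { "gpt-4o", "gpt-4-turbo" },            -- provider's built-in defaults
  custom = { "my-model-small", "my-model-large" }, -- taken verbatim from the config
}
```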
14 changes: 8 additions & 6 deletions lua/parrot/provider/init.lua
@@ -16,7 +16,7 @@ local M = {}
 ---@param endpoint string # API endpoint for the provider
 ---@param api_key string|table # API key or routine for authentication
 ---@return table # returns initialized provider
-M.init_provider = function(prov_name, endpoint, api_key)
+M.init_provider = function(prov_name, endpoint, api_key, style, models)
   local providers = {
     anthropic = Anthropic,
     gemini = Gemini,
@@ -30,12 +30,14 @@ M.init_provider = function(prov_name, endpoint, api_key)
     xai = xAI,
   }
 
-  local ProviderClass = providers[prov_name]
-  if not ProviderClass then
-    logger.error("Unknown provider " .. prov_name)
-    return {}
+  if providers[prov_name] then
+    return providers[prov_name]:new(endpoint, api_key)
+  elseif style and providers[style] then
+    return providers[style]:new(endpoint, api_key, prov_name, models)
   end
-  return ProviderClass:new(endpoint, api_key)
+
+  logger.error("Unknown provider " .. prov_name)
+  return {}
 end
 
 return M
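To make the new dispatch concrete, a hedged example under assumed names: a known provider name resolves directly, while an unknown name with a valid `style` reuses that style's class and keeps the custom name and model list.

```lua
local init_provider = require("parrot.provider").init_provider
local key = os.getenv "CUSTOM_API_KEY" -- placeholder

-- Built-in name: first branch, style/models are not needed.
local openai = init_provider("openai", "https://api.openai.com/v1/chat/completions", key)

-- Unknown name plus style = "openai": second branch reuses the OpenAI class.
local custom = init_provider(
  "my_gateway", -- hypothetical provider name
  "https://my-gateway.example.com/v1/chat/completions", -- hypothetical endpoint
  key,
  "openai",
  { "my-model-small", "my-model-large" }
)
-- custom.name is "my_gateway"; its model list is returned without an API call.
```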
2 changes: 1 addition & 1 deletion lua/parrot/provider/ollama.lua
@@ -104,7 +104,7 @@ function Ollama:get_available_models()
     logger.error("Ollama is not installed or not in PATH.")
     return {}
   end
-  local endpoint_api = self.endpoint:gsub("chat", "")
+  local endpoint_api = self.endpoint:gsub("chat", "")
 
   local job = Job:new({
     command = "curl",
9 changes: 7 additions & 2 deletions lua/parrot/provider/openai.lua
@@ -36,11 +36,12 @@ local AVAILABLE_API_PARAMETERS = {
 ---@param endpoint string
 ---@param api_key string|table
 ---@return OpenAI
-function OpenAI:new(endpoint, api_key)
+function OpenAI:new(endpoint, api_key, name, models)
   return setmetatable({
     endpoint = endpoint,
     api_key = api_key,
-    name = "openai",
+    models = models,
+    name = name or "openai",
   }, self)
 end
 
@@ -145,6 +146,10 @@ end
 ---@param online boolean Whether to fetch models online
 ---@return string[]
 function OpenAI:get_available_models(online)
+  if self.models then
+    return self.models
+  end
+
   local ids = {
     "gpt-4o",
     "gpt-4-turbo",
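A hedged usage sketch of the extended constructor (module path and values assumed): old two-argument calls keep the `"openai"` name and `nil` models, so `get_available_models` still falls back to the built-in id list, while a reused instance returns its fixed list without any network request.

```lua
local OpenAI = require("parrot.provider.openai") -- assumed module path
local key = os.getenv "OPENAI_API_KEY"

-- Unchanged call sites: name defaults to "openai", models stays nil.
local default = OpenAI:new("https://api.openai.com/v1/chat/completions", key)

-- Reused under a custom identity with a fixed model list (placeholders).
local reused = OpenAI:new(
  "https://my-gateway.example.com/v1/chat/completions",
  key,
  "my_gateway",
  { "my-model-small", "my-model-large" }
)
print(vim.inspect(reused:get_available_models(false)))
-- -> { "my-model-small", "my-model-large" }
```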
