From 74bb0064aaad109583a2e9b0c7405add88a42d25 Mon Sep 17 00:00:00 2001 From: Stephan Janssen Date: Tue, 10 Dec 2024 10:11:40 +0100 Subject: [PATCH] Feature #175 Select which LLM providers you want to enable/disable --- .../com/devoxx/genie/model/LanguageModel.java | 3 +- .../model/enumarations/ModelProvider.java | 36 ++-- .../genie/chatmodel/ChatModelProvider.java | 14 +- .../anthropic/AnthropicChatModelFactory.java | 2 +- .../AzureOpenAIChatModelFactory.java | 2 +- .../deepinfra/DeepInfraChatModelFactory.java | 2 +- .../deepseek/DeepSeekChatModelFactory.java | 2 +- .../chatmodel/exo/ExoChatModelFactory.java | 16 +- .../google/GoogleChatModelFactory.java | 2 +- .../gpt4all/GPT4AllChatModelFactory.java | 2 +- .../chatmodel/groq/GroqChatModelFactory.java | 2 +- .../chatmodel/jan/JanChatModelFactory.java | 2 +- .../lmstudio/LMStudioChatModelFactory.java | 2 +- .../mistral/MistralChatModelFactory.java | 2 +- .../ollama/OllamaChatModelFactory.java | 2 +- .../openai/OpenAIChatModelFactory.java | 2 +- .../OpenRouterChatModelFactory.java | 2 +- .../controller/ActionPanelController.java | 10 +- .../service/LLMModelRegistryService.java | 186 +++++++++--------- .../genie/service/LLMProviderService.java | 20 +- .../genie/service/ProjectContentService.java | 4 +- .../genie/service/PromptExecutionService.java | 2 +- .../ui/DevoxxGenieToolWindowContent.java | 1 - .../genie/ui/panel/ActionButtonsPanel.java | 16 +- .../genie/ui/panel/ChatResponsePanel.java | 2 +- .../genie/ui/panel/LlmProviderPanel.java | 29 ++- .../settings/AbstractSettingsComponent.java | 8 + .../ui/settings/DevoxxGenieStateService.java | 33 +++- .../settings/llm/LLMProvidersComponent.java | 150 +++++++++++--- .../llm/LLMProvidersConfigurable.java | 78 +++++++- .../genie/util/ChatMessageContextUtil.java | 2 +- .../genie/util/DefaultLLMSettingsUtil.java | 16 +- .../devoxx/genie/util/LLMProviderUtil.java | 18 +- .../service/PromptExecutionServiceIT.java | 20 +- 34 files changed, 446 insertions(+), 244 deletions(-) diff --git a/core/src/main/java/com/devoxx/genie/model/LanguageModel.java b/core/src/main/java/com/devoxx/genie/model/LanguageModel.java index 73492219..2f0da97e 100644 --- a/core/src/main/java/com/devoxx/genie/model/LanguageModel.java +++ b/core/src/main/java/com/devoxx/genie/model/LanguageModel.java @@ -3,7 +3,6 @@ import com.devoxx.genie.model.enumarations.ModelProvider; import lombok.Builder; import lombok.Data; -import lombok.EqualsAndHashCode; import org.jetbrains.annotations.NotNull; import java.util.Comparator; @@ -21,7 +20,7 @@ public class LanguageModel implements Comparable { private int contextWindow; public LanguageModel() { - this(ModelProvider.OpenAI, "", "", false, 0.0, 0.0, 0); + this(ModelProvider.OPENAI, "", "", false, 0.0, 0.0, 0); } public LanguageModel(ModelProvider provider, diff --git a/core/src/main/java/com/devoxx/genie/model/enumarations/ModelProvider.java b/core/src/main/java/com/devoxx/genie/model/enumarations/ModelProvider.java index c3657143..782e5786 100644 --- a/core/src/main/java/com/devoxx/genie/model/enumarations/ModelProvider.java +++ b/core/src/main/java/com/devoxx/genie/model/enumarations/ModelProvider.java @@ -8,27 +8,28 @@ @Getter public enum ModelProvider { - Ollama("Ollama", Type.LOCAL), - LMStudio("LMStudio", Type.LOCAL), - GPT4All("GPT4All", Type.LOCAL), - Jan("Jan", Type.LOCAL), - OpenAI("OpenAI", Type.CLOUD), - Anthropic("Anthropic", Type.CLOUD), - Mistral("Mistral", Type.CLOUD), - Groq("Groq", Type.CLOUD), - DeepInfra("DeepInfra", Type.CLOUD), - Google("Google", Type.CLOUD), - Exo("Exo (Experimental)", Type.LOCAL), - LLaMA("LLaMA.c++", Type.LOCAL), - OpenRouter("OpenRouter", Type.CLOUD), - DeepSeek("DeepSeek", Type.CLOUD), - Jlama("Jlama (Experimental /w REST API)", Type.LOCAL), - AzureOpenAI("AzureOpenAI", Type.OPTIONAL); + OPENAI("OpenAI", Type.CLOUD), + ANTHROPIC("Anthropic", Type.CLOUD), + MISTRAL("Mistral", Type.CLOUD), + GROQ("Groq", Type.CLOUD), + DEEP_INFRA("DeepInfra", Type.CLOUD), + GOOGLE("Google", Type.CLOUD), + LLAMA("LLaMA.c++", Type.LOCAL), + OPEN_ROUTER("OpenRouter", Type.CLOUD), + DEEP_SEEK("DeepSeek", Type.CLOUD), + AZURE_OPEN_AI("AzureOpenAI", Type.CLOUD), + OLLAMA("Ollama", Type.LOCAL), + LMSTUDIO("LMStudio", Type.LOCAL), + GPT_4_ALL("GPT4All", Type.LOCAL), + JAN("Jan", Type.LOCAL), + JLAMA("Jlama (Experimental /w REST API)", Type.LOCAL), + EXO("Exo (Experimental)", Type.LOCAL), + CUSTOM_OPEN_AI("CustomOpenAI", Type.OPTIONAL); public enum Type { LOCAL, // Local Providers CLOUD, // Cloud Providers - OPTIONAL // Optional Providers(Need to be enabled from settings, due to inconvenient setup) + OPTIONAL // Optional Providers (Need to be enabled from settings, due to inconvenient setup) } private final String name; @@ -39,6 +40,7 @@ public enum Type { this.type = type; } + @Override public String toString() { return name; } diff --git a/src/main/java/com/devoxx/genie/chatmodel/ChatModelProvider.java b/src/main/java/com/devoxx/genie/chatmodel/ChatModelProvider.java index 4663d2e7..83396fcf 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/ChatModelProvider.java +++ b/src/main/java/com/devoxx/genie/chatmodel/ChatModelProvider.java @@ -20,7 +20,7 @@ @Setter public class ChatModelProvider { - private static final ModelProvider DEFAULT_PROVIDER = ModelProvider.OpenAI; // Choose an appropriate default + private static final ModelProvider DEFAULT_PROVIDER = ModelProvider.OPENAI; // Choose an appropriate default public ChatLanguageModel getChatLanguageModel(@NotNull ChatMessageContext chatMessageContext) { ChatModel chatModel = initChatModel(chatMessageContext); @@ -72,22 +72,22 @@ private void setLocalBaseUrl(@NotNull LanguageModel languageModel, DevoxxGenieSettingsService stateService) { // Set base URL for local providers switch (languageModel.getProvider()) { - case LMStudio: + case LMSTUDIO: chatModel.setBaseUrl(stateService.getLmstudioModelUrl()); break; - case Ollama: + case OLLAMA: chatModel.setBaseUrl(stateService.getOllamaModelUrl()); break; - case GPT4All: + case GPT_4_ALL: chatModel.setBaseUrl(stateService.getGpt4allModelUrl()); break; - case Exo: + case EXO: chatModel.setBaseUrl(stateService.getExoModelUrl()); break; - case LLaMA: + case LLAMA: chatModel.setBaseUrl(stateService.getLlamaCPPUrl()); break; - case Jlama: + case JLAMA: chatModel.setBaseUrl(stateService.getJlamaUrl()); break; // Add other local providers as needed diff --git a/src/main/java/com/devoxx/genie/chatmodel/anthropic/AnthropicChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/anthropic/AnthropicChatModelFactory.java index 8798729a..d25edebc 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/anthropic/AnthropicChatModelFactory.java +++ b/src/main/java/com/devoxx/genie/chatmodel/anthropic/AnthropicChatModelFactory.java @@ -45,6 +45,6 @@ public String getApiKey() { @Override public List getModels() { - return getModels(ModelProvider.Anthropic); + return getModels(ModelProvider.ANTHROPIC); } } diff --git a/src/main/java/com/devoxx/genie/chatmodel/azureopenai/AzureOpenAIChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/azureopenai/AzureOpenAIChatModelFactory.java index a88cbc74..17b892cf 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/azureopenai/AzureOpenAIChatModelFactory.java +++ b/src/main/java/com/devoxx/genie/chatmodel/azureopenai/AzureOpenAIChatModelFactory.java @@ -59,7 +59,7 @@ public String getApiKey() { @Override public List getModels() { return List.of(LanguageModel.builder() - .provider(ModelProvider.AzureOpenAI) + .provider(ModelProvider.AZURE_OPEN_AI) .modelName(DevoxxGenieStateService.getInstance().getAzureOpenAIDeployment()) .displayName(DevoxxGenieStateService.getInstance().getAzureOpenAIDeployment()) .inputCost(0.0) diff --git a/src/main/java/com/devoxx/genie/chatmodel/deepinfra/DeepInfraChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/deepinfra/DeepInfraChatModelFactory.java index f8fbbdad..2b161d6a 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/deepinfra/DeepInfraChatModelFactory.java +++ b/src/main/java/com/devoxx/genie/chatmodel/deepinfra/DeepInfraChatModelFactory.java @@ -49,6 +49,6 @@ public String getApiKey() { @Override public List getModels() { - return getModels(ModelProvider.DeepInfra); + return getModels(ModelProvider.DEEP_INFRA); } } diff --git a/src/main/java/com/devoxx/genie/chatmodel/deepseek/DeepSeekChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/deepseek/DeepSeekChatModelFactory.java index f139a277..a6608ddc 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/deepseek/DeepSeekChatModelFactory.java +++ b/src/main/java/com/devoxx/genie/chatmodel/deepseek/DeepSeekChatModelFactory.java @@ -50,6 +50,6 @@ public String getApiKey() { @Override public List getModels() { - return getModels(ModelProvider.DeepSeek); + return getModels(ModelProvider.DEEP_SEEK); } } diff --git a/src/main/java/com/devoxx/genie/chatmodel/exo/ExoChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/exo/ExoChatModelFactory.java index ca6e825b..0d08060e 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/exo/ExoChatModelFactory.java +++ b/src/main/java/com/devoxx/genie/chatmodel/exo/ExoChatModelFactory.java @@ -55,7 +55,7 @@ public List getModels() { .modelName("llama-3.1-405b") .displayName("Llama 3.1 405B") .apiKeyUsed(false) - .provider(ModelProvider.Exo) + .provider(ModelProvider.EXO) .outputCost(0) .inputCost(0) .contextWindow(131_000) @@ -65,7 +65,7 @@ public List getModels() { .modelName("llama-3.1-8b") .displayName("Llama 3.1 8B") .apiKeyUsed(false) - .provider(ModelProvider.Exo) + .provider(ModelProvider.EXO) .outputCost(0) .inputCost(0) .contextWindow(8_000) @@ -75,7 +75,7 @@ public List getModels() { .modelName("llama-3.1-70b") .displayName("Llama 3.1 70B") .apiKeyUsed(false) - .provider(ModelProvider.Exo) + .provider(ModelProvider.EXO) .outputCost(0) .inputCost(0) .contextWindow(131_000) @@ -85,7 +85,7 @@ public List getModels() { .modelName("llama-3-8b") .displayName("Llama 3 8B") .apiKeyUsed(false) - .provider(ModelProvider.Exo) + .provider(ModelProvider.EXO) .outputCost(0) .inputCost(0) .contextWindow(8_000) @@ -96,7 +96,7 @@ public List getModels() { .modelName("mistral-nemo") .displayName("Mistral Nemo") .apiKeyUsed(false) - .provider(ModelProvider.Exo) + .provider(ModelProvider.EXO) .outputCost(0) .inputCost(0) .contextWindow(8_000) @@ -107,7 +107,7 @@ public List getModels() { .modelName("mistral-large") .displayName("Mistral Large") .apiKeyUsed(false) - .provider(ModelProvider.Exo) + .provider(ModelProvider.EXO) .outputCost(0) .inputCost(0) .contextWindow(8_000) @@ -118,7 +118,7 @@ public List getModels() { .modelName("deepseek-coder-v2-lite") .displayName("Deepseek Coder V2 Lite") .apiKeyUsed(false) - .provider(ModelProvider.Exo) + .provider(ModelProvider.EXO) .outputCost(0) .inputCost(0) .contextWindow(8_000) @@ -129,7 +129,7 @@ public List getModels() { .modelName("llava-1.5-7b-hf") .displayName("Llava 1.5 7B HF") .apiKeyUsed(false) - .provider(ModelProvider.Exo) + .provider(ModelProvider.EXO) .outputCost(0) .inputCost(0) .contextWindow(8_000) diff --git a/src/main/java/com/devoxx/genie/chatmodel/google/GoogleChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/google/GoogleChatModelFactory.java index cbcbf249..10721be6 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/google/GoogleChatModelFactory.java +++ b/src/main/java/com/devoxx/genie/chatmodel/google/GoogleChatModelFactory.java @@ -30,6 +30,6 @@ public String getApiKey() { @Override public List getModels() { - return getModels(ModelProvider.Google); + return getModels(ModelProvider.GOOGLE); } } diff --git a/src/main/java/com/devoxx/genie/chatmodel/gpt4all/GPT4AllChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/gpt4all/GPT4AllChatModelFactory.java index f505a356..607a07ee 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/gpt4all/GPT4AllChatModelFactory.java +++ b/src/main/java/com/devoxx/genie/chatmodel/gpt4all/GPT4AllChatModelFactory.java @@ -43,7 +43,7 @@ public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel ch @Override public List getModels() { LanguageModel lmStudio = LanguageModel.builder() - .provider(ModelProvider.GPT4All) + .provider(ModelProvider.GPT_4_ALL) .modelName("GPT4All") .inputCost(0) .outputCost(0) diff --git a/src/main/java/com/devoxx/genie/chatmodel/groq/GroqChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/groq/GroqChatModelFactory.java index 8e753ce3..35788b31 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/groq/GroqChatModelFactory.java +++ b/src/main/java/com/devoxx/genie/chatmodel/groq/GroqChatModelFactory.java @@ -47,6 +47,6 @@ public String getApiKey() { @Override public List getModels() { - return getModels(ModelProvider.Groq); + return getModels(ModelProvider.GROQ); } } diff --git a/src/main/java/com/devoxx/genie/chatmodel/jan/JanChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/jan/JanChatModelFactory.java index 9eca3b55..35b25d34 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/jan/JanChatModelFactory.java +++ b/src/main/java/com/devoxx/genie/chatmodel/jan/JanChatModelFactory.java @@ -71,7 +71,7 @@ public List getModels() { for (Data model : models) { CompletableFuture future = CompletableFuture.runAsync(() -> { LanguageModel languageModel = LanguageModel.builder() - .provider(ModelProvider.Jan) + .provider(ModelProvider.JAN) .modelName(model.getId()) .displayName(model.getName()) .inputCost(0) diff --git a/src/main/java/com/devoxx/genie/chatmodel/lmstudio/LMStudioChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/lmstudio/LMStudioChatModelFactory.java index 252931fe..20427d8c 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/lmstudio/LMStudioChatModelFactory.java +++ b/src/main/java/com/devoxx/genie/chatmodel/lmstudio/LMStudioChatModelFactory.java @@ -74,7 +74,7 @@ public List getModels() { for (LMStudioModelEntryDTO model : lmStudioModels) { CompletableFuture future = CompletableFuture.runAsync(() -> { LanguageModel languageModel = LanguageModel.builder() - .provider(ModelProvider.LMStudio) + .provider(ModelProvider.LMSTUDIO) .modelName(model.getId()) .displayName(model.getId()) .inputCost(0) diff --git a/src/main/java/com/devoxx/genie/chatmodel/mistral/MistralChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/mistral/MistralChatModelFactory.java index 11a33860..1ec4dc8e 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/mistral/MistralChatModelFactory.java +++ b/src/main/java/com/devoxx/genie/chatmodel/mistral/MistralChatModelFactory.java @@ -47,6 +47,6 @@ public String getApiKey() { @Override public List getModels() { - return getModels(ModelProvider.Mistral); + return getModels(ModelProvider.MISTRAL); } } diff --git a/src/main/java/com/devoxx/genie/chatmodel/ollama/OllamaChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/ollama/OllamaChatModelFactory.java index 03e803bb..0f799ed2 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/ollama/OllamaChatModelFactory.java +++ b/src/main/java/com/devoxx/genie/chatmodel/ollama/OllamaChatModelFactory.java @@ -75,7 +75,7 @@ public List getModels() { try { int contextWindow = OllamaApiService.getModelContext(model.getName()); LanguageModel languageModel = LanguageModel.builder() - .provider(ModelProvider.Ollama) + .provider(ModelProvider.OLLAMA) .modelName(model.getName()) .displayName(model.getName()) .inputCost(0) diff --git a/src/main/java/com/devoxx/genie/chatmodel/openai/OpenAIChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/openai/OpenAIChatModelFactory.java index 6c7c4ffa..0de6f785 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/openai/OpenAIChatModelFactory.java +++ b/src/main/java/com/devoxx/genie/chatmodel/openai/OpenAIChatModelFactory.java @@ -59,6 +59,6 @@ public String getApiKey() { @Override public List getModels() { - return getModels(ModelProvider.OpenAI); + return getModels(ModelProvider.OPENAI); } } diff --git a/src/main/java/com/devoxx/genie/chatmodel/openrouter/OpenRouterChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/openrouter/OpenRouterChatModelFactory.java index 67015b3b..1a7edc61 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/openrouter/OpenRouterChatModelFactory.java +++ b/src/main/java/com/devoxx/genie/chatmodel/openrouter/OpenRouterChatModelFactory.java @@ -86,7 +86,7 @@ public List getModels() { double outputCost = convertAndScalePrice(model.getPricing().getCompletion()); LanguageModel languageModel = LanguageModel.builder() - .provider(ModelProvider.OpenRouter) + .provider(ModelProvider.OPEN_ROUTER) .modelName(model.getId()) .displayName(model.getName()) .inputCost(inputCost) diff --git a/src/main/java/com/devoxx/genie/controller/ActionPanelController.java b/src/main/java/com/devoxx/genie/controller/ActionPanelController.java index bf0965d5..368b1e5c 100644 --- a/src/main/java/com/devoxx/genie/controller/ActionPanelController.java +++ b/src/main/java/com/devoxx/genie/controller/ActionPanelController.java @@ -155,10 +155,10 @@ private boolean validateAndPreparePrompt(String actionCommand, private LanguageModel createDefaultLanguageModel(@NotNull DevoxxGenieSettingsService stateService) { ModelProvider selectedProvider = (ModelProvider) modelProviderComboBox.getSelectedItem(); if (selectedProvider != null && - (selectedProvider.equals(ModelProvider.LMStudio) || - selectedProvider.equals(ModelProvider.GPT4All) || - selectedProvider.equals(ModelProvider.Jlama) || - selectedProvider.equals(ModelProvider.LLaMA))) { + (selectedProvider.equals(ModelProvider.LMSTUDIO) || + selectedProvider.equals(ModelProvider.GPT_4_ALL) || + selectedProvider.equals(ModelProvider.JLAMA) || + selectedProvider.equals(ModelProvider.LLAMA))) { return LanguageModel.builder() .provider(selectedProvider) .apiKeyUsed(false) @@ -169,7 +169,7 @@ private LanguageModel createDefaultLanguageModel(@NotNull DevoxxGenieSettingsSer } else { String modelName = stateService.getSelectedLanguageModel(project.getLocationHash()); return LanguageModel.builder() - .provider(selectedProvider != null ? selectedProvider : ModelProvider.OpenAI) + .provider(selectedProvider != null ? selectedProvider : ModelProvider.OPENAI) .modelName(modelName) .apiKeyUsed(false) .inputCost(0) diff --git a/src/main/java/com/devoxx/genie/service/LLMModelRegistryService.java b/src/main/java/com/devoxx/genie/service/LLMModelRegistryService.java index c05180ff..6ebbd77a 100644 --- a/src/main/java/com/devoxx/genie/service/LLMModelRegistryService.java +++ b/src/main/java/com/devoxx/genie/service/LLMModelRegistryService.java @@ -39,9 +39,9 @@ public LLMModelRegistryService() { private void addAnthropicModels() { String claude2 = CLAUDE_2.toString(); - models.put(ModelProvider.Anthropic.getName() + "-" + claude2, + models.put(ModelProvider.ANTHROPIC.getName() + "-" + claude2, LanguageModel.builder() - .provider(ModelProvider.Anthropic) + .provider(ModelProvider.ANTHROPIC) .modelName(claude2) .displayName("Claude 2.0") .inputCost(8) @@ -51,9 +51,9 @@ private void addAnthropicModels() { .build()); String claude21 = CLAUDE_2_1.toString(); - models.put(ModelProvider.Anthropic.getName() + "-" + claude21, + models.put(ModelProvider.ANTHROPIC.getName() + "-" + claude21, LanguageModel.builder() - .provider(ModelProvider.Anthropic) + .provider(ModelProvider.ANTHROPIC) .modelName(claude21) .displayName("Claude 2.1") .inputCost(8) @@ -63,9 +63,9 @@ private void addAnthropicModels() { .build()); String claudeHaiku3 = CLAUDE_3_HAIKU_20240307.toString(); - models.put(ModelProvider.Anthropic.getName() + "-" + claudeHaiku3, + models.put(ModelProvider.ANTHROPIC.getName() + "-" + claudeHaiku3, LanguageModel.builder() - .provider(ModelProvider.Anthropic) + .provider(ModelProvider.ANTHROPIC) .modelName(claudeHaiku3) .displayName("Claude 3 Haiku") .inputCost(0.25) @@ -75,9 +75,9 @@ private void addAnthropicModels() { .build()); String claudeSonnet3 = CLAUDE_3_SONNET_20240229.toString(); - models.put(ModelProvider.Anthropic.getName() + "-" + claudeSonnet3, + models.put(ModelProvider.ANTHROPIC.getName() + "-" + claudeSonnet3, LanguageModel.builder() - .provider(ModelProvider.Anthropic) + .provider(ModelProvider.ANTHROPIC) .modelName(claudeSonnet3) .displayName("Claude 3 Sonnet") .inputCost(3) @@ -87,9 +87,9 @@ private void addAnthropicModels() { .build()); String claudeOpus3 = CLAUDE_3_OPUS_20240229.toString(); - models.put(ModelProvider.Anthropic.getName() + "-" + claudeOpus3, + models.put(ModelProvider.ANTHROPIC.getName() + "-" + claudeOpus3, LanguageModel.builder() - .provider(ModelProvider.Anthropic) + .provider(ModelProvider.ANTHROPIC) .modelName(claudeOpus3) .displayName("Claude 3 Opus") .inputCost(15) @@ -99,9 +99,9 @@ private void addAnthropicModels() { .build()); String claudeSonnet35 = CLAUDE_3_5_SONNET_20241022.toString(); - models.put(ModelProvider.Anthropic.getName() + "-" + claudeSonnet35, + models.put(ModelProvider.ANTHROPIC.getName() + "-" + claudeSonnet35, LanguageModel.builder() - .provider(ModelProvider.Anthropic) + .provider(ModelProvider.ANTHROPIC) .modelName(claudeSonnet35) .displayName("Claude 3.5 Sonnet") .inputCost(3) @@ -111,9 +111,9 @@ private void addAnthropicModels() { .build()); String claudeHaiku35 = CLAUDE_3_5_HAIKU_20241022.toString(); - models.put(ModelProvider.Anthropic.getName() + "-" + claudeHaiku35, + models.put(ModelProvider.ANTHROPIC.getName() + "-" + claudeHaiku35, LanguageModel.builder() - .provider(ModelProvider.Anthropic) + .provider(ModelProvider.ANTHROPIC) .modelName(claudeHaiku35) .displayName("Claude 3.5 Haiku") .inputCost(1) @@ -126,9 +126,9 @@ private void addAnthropicModels() { private void addOpenAiModels() { String o1Mini = "o1-mini"; - models.put(ModelProvider.OpenAI.getName() + ":" + o1Mini, + models.put(ModelProvider.OPENAI.getName() + ":" + o1Mini, LanguageModel.builder() - .provider(ModelProvider.OpenAI) + .provider(ModelProvider.OPENAI) .modelName(o1Mini) .displayName("o1 mini") .inputCost(5) @@ -138,9 +138,9 @@ private void addOpenAiModels() { .build()); String o1Preview = "o1-preview"; - models.put(ModelProvider.OpenAI.getName() + ":" + o1Preview, + models.put(ModelProvider.OPENAI.getName() + ":" + o1Preview, LanguageModel.builder() - .provider(ModelProvider.OpenAI) + .provider(ModelProvider.OPENAI) .modelName(o1Preview) .displayName("o1 preview") .inputCost(10) @@ -150,9 +150,9 @@ private void addOpenAiModels() { .build()); String gpt35Turbo = GPT_3_5_TURBO.toString(); - models.put(ModelProvider.OpenAI.getName() + ":" + gpt35Turbo, + models.put(ModelProvider.OPENAI.getName() + ":" + gpt35Turbo, LanguageModel.builder() - .provider(ModelProvider.OpenAI) + .provider(ModelProvider.OPENAI) .modelName(gpt35Turbo) .displayName("GPT 3.5 Turbo") .inputCost(0.5) @@ -162,9 +162,9 @@ private void addOpenAiModels() { .build()); String gpt4 = GPT_4.toString(); - models.put(ModelProvider.OpenAI.getName() + ":" + gpt4, + models.put(ModelProvider.OPENAI.getName() + ":" + gpt4, LanguageModel.builder() - .provider(ModelProvider.OpenAI) + .provider(ModelProvider.OPENAI) .modelName(gpt4) .displayName("GPT 4") .inputCost(30) @@ -174,9 +174,9 @@ private void addOpenAiModels() { .build()); String gpt4TurboPreview = GPT_4_TURBO_PREVIEW.toString(); - models.put(ModelProvider.OpenAI.getName() + ":" + gpt4TurboPreview, + models.put(ModelProvider.OPENAI.getName() + ":" + gpt4TurboPreview, LanguageModel.builder() - .provider(ModelProvider.OpenAI) + .provider(ModelProvider.OPENAI) .modelName(gpt4TurboPreview) .displayName("GPT 4 Turbo") .inputCost(10) @@ -186,9 +186,9 @@ private void addOpenAiModels() { .build()); String gpt4o = GPT_4_O.toString(); - models.put(ModelProvider.OpenAI.getName() + ":" + gpt4o, + models.put(ModelProvider.OPENAI.getName() + ":" + gpt4o, LanguageModel.builder() - .provider(ModelProvider.OpenAI) + .provider(ModelProvider.OPENAI) .modelName(gpt4o) .displayName("GPT 4o") .inputCost(5) @@ -198,9 +198,9 @@ private void addOpenAiModels() { .build()); String gpt4oMini = GPT_4_O_MINI.toString(); - models.put(ModelProvider.OpenAI.getName() + ":" + gpt4oMini, + models.put(ModelProvider.OPENAI.getName() + ":" + gpt4oMini, LanguageModel.builder() - .provider(ModelProvider.OpenAI) + .provider(ModelProvider.OPENAI) .modelName(gpt4oMini) .displayName("GPT 4o mini") .inputCost(0.15) @@ -212,9 +212,9 @@ private void addOpenAiModels() { private void addDeepInfraModels() { String metaLlama31Instruct405B = "meta-llama/Meta-Llama-3.1-405B-Instruct"; - models.put(ModelProvider.DeepInfra.getName() + ":" + metaLlama31Instruct405B, + models.put(ModelProvider.DEEP_INFRA.getName() + ":" + metaLlama31Instruct405B, LanguageModel.builder() - .provider(ModelProvider.DeepInfra) + .provider(ModelProvider.DEEP_INFRA) .modelName(metaLlama31Instruct405B) .displayName("Meta Llama 3.1 405B") .inputCost(2.7) @@ -224,9 +224,9 @@ private void addDeepInfraModels() { .build()); String metaLlama31Instruct70B = "meta-llama/Meta-Llama-3.1-70B-Instruct"; - models.put(ModelProvider.DeepInfra.getName() + ":" + metaLlama31Instruct70B, + models.put(ModelProvider.DEEP_INFRA.getName() + ":" + metaLlama31Instruct70B, LanguageModel.builder() - .provider(ModelProvider.DeepInfra) + .provider(ModelProvider.DEEP_INFRA) .modelName(metaLlama31Instruct70B) .displayName("Meta Llama 3.1 70B") .inputCost(0.35) @@ -236,9 +236,9 @@ private void addDeepInfraModels() { .build()); String metaLlama31Instruct8B = "meta-llama/Meta-Llama-3.1-8B-Instruct"; - models.put(ModelProvider.DeepInfra.getName() + ":" + metaLlama31Instruct8B, + models.put(ModelProvider.DEEP_INFRA.getName() + ":" + metaLlama31Instruct8B, LanguageModel.builder() - .provider(ModelProvider.DeepInfra) + .provider(ModelProvider.DEEP_INFRA) .modelName(metaLlama31Instruct8B) .displayName("Meta Llama 3.1 8B") .inputCost(0.055) @@ -248,9 +248,9 @@ private void addDeepInfraModels() { .build()); String mistralNemoInstruct2407 = "mistralai/Mistral-Nemo-Instruct-2407"; - models.put(ModelProvider.DeepInfra.getName() + ":" + mistralNemoInstruct2407, + models.put(ModelProvider.DEEP_INFRA.getName() + ":" + mistralNemoInstruct2407, LanguageModel.builder() - .provider(ModelProvider.DeepInfra) + .provider(ModelProvider.DEEP_INFRA) .modelName(mistralNemoInstruct2407) .displayName("Mistral Nemo 12B") .inputCost(0.13) @@ -260,9 +260,9 @@ private void addDeepInfraModels() { .build()); String mistralMixtral8x7BInstruct = "mistralai/Mixtral-8x7B-Instruct-v0.1"; - models.put(ModelProvider.DeepInfra.getName() + ":" + mistralMixtral8x7BInstruct, + models.put(ModelProvider.DEEP_INFRA.getName() + ":" + mistralMixtral8x7BInstruct, LanguageModel.builder() - .provider(ModelProvider.DeepInfra) + .provider(ModelProvider.DEEP_INFRA) .modelName(mistralMixtral8x7BInstruct) .displayName("Mixtral 8x7B Instruct v0.1") .inputCost(0.24) @@ -272,9 +272,9 @@ private void addDeepInfraModels() { .build()); String mistralMixtral8x22BInstruct = "mistralai/Mixtral-8x22B-Instruct-v0.1"; - models.put(ModelProvider.DeepInfra.getName() + ":" + mistralMixtral8x22BInstruct, + models.put(ModelProvider.DEEP_INFRA.getName() + ":" + mistralMixtral8x22BInstruct, LanguageModel.builder() - .provider(ModelProvider.DeepInfra) + .provider(ModelProvider.DEEP_INFRA) .modelName(mistralMixtral8x22BInstruct) .displayName("Mixtral 8x22B Instruct v0.1") .inputCost(0.65) @@ -284,9 +284,9 @@ private void addDeepInfraModels() { .build()); String mistralMistral7BInstruct = "mistralai/Mistral-7B-Instruct-v0.3"; - models.put(ModelProvider.DeepInfra.getName() + ":" + mistralMistral7BInstruct, + models.put(ModelProvider.DEEP_INFRA.getName() + ":" + mistralMistral7BInstruct, LanguageModel.builder() - .provider(ModelProvider.DeepInfra) + .provider(ModelProvider.DEEP_INFRA) .modelName(mistralMistral7BInstruct) .displayName("Mistral 7B Instruct v0.3") .inputCost(0.07) @@ -296,9 +296,9 @@ private void addDeepInfraModels() { .build()); String microsoftWizardLM8x22B = "microsoft/WizardLM-2-8x22B"; - models.put(ModelProvider.DeepInfra.getName() + ":" + microsoftWizardLM8x22B, + models.put(ModelProvider.DEEP_INFRA.getName() + ":" + microsoftWizardLM8x22B, LanguageModel.builder() - .provider(ModelProvider.DeepInfra) + .provider(ModelProvider.DEEP_INFRA) .modelName(microsoftWizardLM8x22B) .displayName("Wizard LM 2 8x22B") .inputCost(0.5) @@ -308,9 +308,9 @@ private void addDeepInfraModels() { .build()); String microsoftWizardLM7B = "microsoft/WizardLM-2-7B"; - models.put(ModelProvider.DeepInfra.getName() + ":" + microsoftWizardLM7B, + models.put(ModelProvider.DEEP_INFRA.getName() + ":" + microsoftWizardLM7B, LanguageModel.builder() - .provider(ModelProvider.DeepInfra) + .provider(ModelProvider.DEEP_INFRA) .modelName(microsoftWizardLM7B) .displayName("Wizard LM 2 7B") .inputCost(0.055) @@ -320,9 +320,9 @@ private void addDeepInfraModels() { .build()); String openchat35 = "openchat/openchat_3.5"; - models.put(ModelProvider.DeepInfra.getName() + ":" + openchat35, + models.put(ModelProvider.DEEP_INFRA.getName() + ":" + openchat35, LanguageModel.builder() - .provider(ModelProvider.DeepInfra) + .provider(ModelProvider.DEEP_INFRA) .modelName(openchat35) .displayName("OpenChat 3.5") .inputCost(0.055) @@ -332,9 +332,9 @@ private void addDeepInfraModels() { .build()); String googleGemma9b = "google/gemma-2-9b-it"; - models.put(ModelProvider.DeepInfra.getName() + ":" + googleGemma9b, + models.put(ModelProvider.DEEP_INFRA.getName() + ":" + googleGemma9b, LanguageModel.builder() - .provider(ModelProvider.DeepInfra) + .provider(ModelProvider.DEEP_INFRA) .modelName(googleGemma9b) .displayName("Gemma 2 9B it") .inputCost(0.06) @@ -352,9 +352,9 @@ private void addDeepInfraModels() { private void addGeminiModels() { String gemini15Flash = "gemini-1.5-flash"; - models.put(ModelProvider.Google.getName() + ":" + gemini15Flash, + models.put(ModelProvider.GOOGLE.getName() + ":" + gemini15Flash, LanguageModel.builder() - .provider(ModelProvider.Google) + .provider(ModelProvider.GOOGLE) .modelName(gemini15Flash) .displayName("Gemini 1.5 Flash") .inputCost(0.0375) @@ -364,9 +364,9 @@ private void addGeminiModels() { .build()); String gemini15Pro = "gemini-1.5-pro"; - models.put(ModelProvider.Google.getName() + ":" + gemini15Pro, + models.put(ModelProvider.GOOGLE.getName() + ":" + gemini15Pro, LanguageModel.builder() - .provider(ModelProvider.Google) + .provider(ModelProvider.GOOGLE) .modelName(gemini15Pro) .displayName("Gemini 1.5 Pro") .inputCost(7) @@ -376,9 +376,9 @@ private void addGeminiModels() { .build()); String gemini15ProExp0801 = "gemini-1.5-pro-exp-0801"; - models.put(ModelProvider.Google.getName() + ":" + gemini15ProExp0801, + models.put(ModelProvider.GOOGLE.getName() + ":" + gemini15ProExp0801, LanguageModel.builder() - .provider(ModelProvider.Google) + .provider(ModelProvider.GOOGLE) .modelName(gemini15ProExp0801) .displayName("Gemini 1.5 Pro 0801") .inputCost(7) @@ -388,9 +388,9 @@ private void addGeminiModels() { .build()); String gemini10Pro = "gemini-1.0-pro"; - models.put(ModelProvider.Google.getName() + ":" + gemini10Pro, + models.put(ModelProvider.GOOGLE.getName() + ":" + gemini10Pro, LanguageModel.builder() - .provider(ModelProvider.Google) + .provider(ModelProvider.GOOGLE) .modelName(gemini10Pro) .displayName("Gemini 1.0 Pro") .inputCost(0.5) @@ -400,9 +400,9 @@ private void addGeminiModels() { .build()); String geminiExp1206 = "gemini-exp-1206"; - models.put(ModelProvider.Google.getName() + ":" + gemini10Pro, + models.put(ModelProvider.GOOGLE.getName() + ":" + gemini10Pro, LanguageModel.builder() - .provider(ModelProvider.Google) + .provider(ModelProvider.GOOGLE) .modelName(geminiExp1206) .displayName("Gemini Exp 1206") .inputCost(0) @@ -415,9 +415,9 @@ private void addGeminiModels() { private void addGroqModels() { String gemma7b = "gemma-7b-it"; - models.put(ModelProvider.Groq.getName() + ":" + gemma7b, + models.put(ModelProvider.GROQ.getName() + ":" + gemma7b, LanguageModel.builder() - .provider(ModelProvider.Groq) + .provider(ModelProvider.GROQ) .modelName(gemma7b) .displayName("Gemma 7B it") .inputCost(0.07) @@ -427,9 +427,9 @@ private void addGroqModels() { .build()); String gemma2 = "gemma2-9b-it"; - models.put(ModelProvider.Groq.getName() + ":" + gemma2, + models.put(ModelProvider.GROQ.getName() + ":" + gemma2, LanguageModel.builder() - .provider(ModelProvider.Groq) + .provider(ModelProvider.GROQ) .modelName(gemma2) .displayName("Gemma 2 9B it") .inputCost(0.2) @@ -439,9 +439,9 @@ private void addGroqModels() { .build()); String llama3 = "llama3-8b-8192"; - models.put(ModelProvider.Groq.getName() + ":" + llama3, + models.put(ModelProvider.GROQ.getName() + ":" + llama3, LanguageModel.builder() - .provider(ModelProvider.Groq) + .provider(ModelProvider.GROQ) .modelName(llama3) .displayName("Llama 3 8B") .inputCost(0.05) @@ -451,9 +451,9 @@ private void addGroqModels() { .build()); String llama31Versatile = "llama-3.1-70b-versatile"; - models.put(ModelProvider.Groq.getName() + ":" + llama31Versatile, + models.put(ModelProvider.GROQ.getName() + ":" + llama31Versatile, LanguageModel.builder() - .provider(ModelProvider.Groq) + .provider(ModelProvider.GROQ) .modelName(llama31Versatile) .displayName("Llama 3.1 70B") .inputCost(0.59) @@ -463,9 +463,9 @@ private void addGroqModels() { .build()); String llama31Instant = "llama-3.1-8b-instant"; - models.put(ModelProvider.Groq.getName() + ":" + llama31Instant, + models.put(ModelProvider.GROQ.getName() + ":" + llama31Instant, LanguageModel.builder() - .provider(ModelProvider.Groq) + .provider(ModelProvider.GROQ) .modelName(llama31Instant) .displayName("Llama 3.1 8B") .inputCost(0.05) @@ -475,9 +475,9 @@ private void addGroqModels() { .build()); String mixtral8x7b = "mixtral-8x7b-32768"; - models.put(ModelProvider.Groq.getName() + ":" + mixtral8x7b, + models.put(ModelProvider.GROQ.getName() + ":" + mixtral8x7b, LanguageModel.builder() - .provider(ModelProvider.Groq) + .provider(ModelProvider.GROQ) .modelName(mixtral8x7b) .displayName("Mixtral 8x7B") .inputCost(0.24) @@ -487,9 +487,9 @@ private void addGroqModels() { .build()); String llama370b = "llama3-70b-8192"; - models.put(ModelProvider.Groq.getName() + ":" + llama370b, + models.put(ModelProvider.GROQ.getName() + ":" + llama370b, LanguageModel.builder() - .provider(ModelProvider.Groq) + .provider(ModelProvider.GROQ) .modelName(llama370b) .displayName("Llama 3 70B") .inputCost(0.59) @@ -501,9 +501,9 @@ private void addGroqModels() { private void addMistralModels() { String openMistral7B = OPEN_MISTRAL_7B.toString(); - models.put(ModelProvider.Mistral.getName() + ":" + openMistral7B, + models.put(ModelProvider.MISTRAL.getName() + ":" + openMistral7B, LanguageModel.builder() - .provider(ModelProvider.Mistral) + .provider(ModelProvider.MISTRAL) .modelName(openMistral7B) .displayName("Mistral 7B") .inputCost(0.25) @@ -513,9 +513,9 @@ private void addMistralModels() { .build()); String openMixtral8x7B = OPEN_MIXTRAL_8x7B.toString(); - models.put(ModelProvider.Mistral.getName() + ":" + openMixtral8x7B, + models.put(ModelProvider.MISTRAL.getName() + ":" + openMixtral8x7B, LanguageModel.builder() - .provider(ModelProvider.Mistral) + .provider(ModelProvider.MISTRAL) .modelName(openMixtral8x7B) .displayName("Mistral 8x7B") .inputCost(0.7) @@ -525,9 +525,9 @@ private void addMistralModels() { .build()); String openMixtral8x22B = OPEN_MIXTRAL_8X22B.toString(); - models.put(ModelProvider.Mistral.getName() + ":" + openMixtral8x22B, + models.put(ModelProvider.MISTRAL.getName() + ":" + openMixtral8x22B, LanguageModel.builder() - .provider(ModelProvider.Mistral) + .provider(ModelProvider.MISTRAL) .modelName(openMixtral8x22B) .displayName("Mistral 8x22b") .inputCost(2) @@ -537,9 +537,9 @@ private void addMistralModels() { .build()); String mistralSmallLatest = MISTRAL_SMALL_LATEST.toString(); - models.put(ModelProvider.Mistral.getName() + ":" + mistralSmallLatest, + models.put(ModelProvider.MISTRAL.getName() + ":" + mistralSmallLatest, LanguageModel.builder() - .provider(ModelProvider.Mistral) + .provider(ModelProvider.MISTRAL) .modelName(mistralSmallLatest) .displayName("Mistral Small") .inputCost(1) @@ -549,9 +549,9 @@ private void addMistralModels() { .build()); String mistralMediumLatest = MISTRAL_MEDIUM_LATEST.toString(); - models.put(ModelProvider.Mistral.getName() + ":" + mistralMediumLatest, + models.put(ModelProvider.MISTRAL.getName() + ":" + mistralMediumLatest, LanguageModel.builder() - .provider(ModelProvider.Mistral) + .provider(ModelProvider.MISTRAL) .modelName(mistralMediumLatest) .displayName("Mistral Medium") .inputCost(2.7) @@ -561,9 +561,9 @@ private void addMistralModels() { .build()); String mistralLargeLatest = MISTRAL_LARGE_LATEST.toString(); - models.put(ModelProvider.Mistral.getName() + ":" + mistralLargeLatest, + models.put(ModelProvider.MISTRAL.getName() + ":" + mistralLargeLatest, LanguageModel.builder() - .provider(ModelProvider.Mistral) + .provider(ModelProvider.MISTRAL) .modelName(mistralLargeLatest) .displayName("Mistral Large") .inputCost(4) @@ -573,9 +573,9 @@ private void addMistralModels() { .build()); String codestral = "codestral-2405"; - models.put(ModelProvider.Mistral.getName() + ":" + codestral, + models.put(ModelProvider.MISTRAL.getName() + ":" + codestral, LanguageModel.builder() - .provider(ModelProvider.Mistral) + .provider(ModelProvider.MISTRAL) .modelName(codestral) .displayName("Codestral") .inputCost(1) @@ -587,9 +587,9 @@ private void addMistralModels() { private void addDeepSeekModels() { String coder = "deepseek-coder"; - models.put(ModelProvider.DeepSeek.getName() + ":" + coder, + models.put(ModelProvider.DEEP_SEEK.getName() + ":" + coder, LanguageModel.builder() - .provider(ModelProvider.DeepSeek) + .provider(ModelProvider.DEEP_SEEK) .modelName(coder) .displayName("DeepSeek Coder") .inputCost(0.14) @@ -599,9 +599,9 @@ private void addDeepSeekModels() { .build()); String chat = "deepseek-chat"; - models.put(ModelProvider.DeepSeek.getName() + ":" + chat, + models.put(ModelProvider.DEEP_SEEK.getName() + ":" + chat, LanguageModel.builder() - .provider(ModelProvider.DeepSeek) + .provider(ModelProvider.DEEP_SEEK) .modelName(chat) .displayName("DeepSeek Chat") .inputCost(0.14) @@ -621,7 +621,7 @@ public List getModels() { String apiKey = openRouterChatModelFactory.getApiKey(); if (apiKey != null && !apiKey.isEmpty()) { openRouterChatModelFactory.getModels().forEach(model -> - modelsCopy.put(ModelProvider.OpenRouter.getName() + ":" + model.getModelName(), model)); + modelsCopy.put(ModelProvider.OPEN_ROUTER.getName() + ":" + model.getModelName(), model)); } return new ArrayList<>(modelsCopy.values()); diff --git a/src/main/java/com/devoxx/genie/service/LLMProviderService.java b/src/main/java/com/devoxx/genie/service/LLMProviderService.java index 4eb9ebfb..52f656aa 100644 --- a/src/main/java/com/devoxx/genie/service/LLMProviderService.java +++ b/src/main/java/com/devoxx/genie/service/LLMProviderService.java @@ -21,15 +21,15 @@ public class LLMProviderService { static { DevoxxGenieStateService stateService = DevoxxGenieStateService.getInstance(); - providerKeyMap.put(OpenAI, stateService::getOpenAIKey); - providerKeyMap.put(Anthropic, stateService::getAnthropicKey); - providerKeyMap.put(Mistral, stateService::getMistralKey); - providerKeyMap.put(Groq, stateService::getGroqKey); - providerKeyMap.put(DeepInfra, stateService::getDeepInfraKey); - providerKeyMap.put(Google, stateService::getGeminiKey); - providerKeyMap.put(DeepSeek, stateService::getDeepSeekKey); - providerKeyMap.put(OpenRouter, stateService::getOpenRouterKey); - providerKeyMap.put(AzureOpenAI, stateService::getAzureOpenAIKey); + providerKeyMap.put(OPENAI, stateService::getOpenAIKey); + providerKeyMap.put(ANTHROPIC, stateService::getAnthropicKey); + providerKeyMap.put(MISTRAL, stateService::getMistralKey); + providerKeyMap.put(GROQ, stateService::getGroqKey); + providerKeyMap.put(DEEP_INFRA, stateService::getDeepInfraKey); + providerKeyMap.put(GOOGLE, stateService::getGeminiKey); + providerKeyMap.put(DEEP_SEEK, stateService::getDeepSeekKey); + providerKeyMap.put(OPEN_ROUTER, stateService::getOpenRouterKey); + providerKeyMap.put(AZURE_OPEN_AI, stateService::getAzureOpenAIKey); } @NotNull @@ -72,7 +72,7 @@ private List getOptionalProviders() { List optionalModelProviders = new ArrayList<>(); if (DevoxxGenieStateService.getInstance().getShowAzureOpenAIFields()) { - optionalModelProviders.add(AzureOpenAI); + optionalModelProviders.add(AZURE_OPEN_AI); } return optionalModelProviders; diff --git a/src/main/java/com/devoxx/genie/service/ProjectContentService.java b/src/main/java/com/devoxx/genie/service/ProjectContentService.java index 0b684c33..8da83c2b 100644 --- a/src/main/java/com/devoxx/genie/service/ProjectContentService.java +++ b/src/main/java/com/devoxx/genie/service/ProjectContentService.java @@ -77,9 +77,9 @@ public CompletableFuture getDirectoryContent(Project project, public static Encoding getEncodingForProvider(@NotNull ModelProvider provider) { return switch (provider) { - case OpenAI, Anthropic, Google, AzureOpenAI -> + case OPENAI, ANTHROPIC, GOOGLE, AZURE_OPEN_AI -> Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE); - case Mistral, DeepInfra, Groq, DeepSeek, OpenRouter -> + case MISTRAL, DEEP_INFRA, GROQ, DEEP_SEEK, OPEN_ROUTER -> // These often use the Llama tokenizer or similar Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.R50K_BASE); default -> diff --git a/src/main/java/com/devoxx/genie/service/PromptExecutionService.java b/src/main/java/com/devoxx/genie/service/PromptExecutionService.java index 82b93231..06cf75cf 100644 --- a/src/main/java/com/devoxx/genie/service/PromptExecutionService.java +++ b/src/main/java/com/devoxx/genie/service/PromptExecutionService.java @@ -123,7 +123,7 @@ private boolean isCanceled() { ChatMemoryService.getInstance().add(chatMessageContext.getProject(), response.content()); return response; } catch (Exception e) { - if (chatMessageContext.getLanguageModel().getProvider().equals(ModelProvider.Jan)) { + if (chatMessageContext.getLanguageModel().getProvider().equals(ModelProvider.JAN)) { throw new ModelNotActiveException("Selected Jan model is not active. Download and make it active or add API Key in Jan settings."); } ChatMemoryService.getInstance().removeLast(chatMessageContext.getProject()); diff --git a/src/main/java/com/devoxx/genie/ui/DevoxxGenieToolWindowContent.java b/src/main/java/com/devoxx/genie/ui/DevoxxGenieToolWindowContent.java index 80aa63c3..4bd1542c 100644 --- a/src/main/java/com/devoxx/genie/ui/DevoxxGenieToolWindowContent.java +++ b/src/main/java/com/devoxx/genie/ui/DevoxxGenieToolWindowContent.java @@ -107,7 +107,6 @@ private void setupListeners() { /** * Create the top panel. - * * @return the top panel */ private @NotNull JPanel createTopPanel() { diff --git a/src/main/java/com/devoxx/genie/ui/panel/ActionButtonsPanel.java b/src/main/java/com/devoxx/genie/ui/panel/ActionButtonsPanel.java index 39153559..839703fe 100644 --- a/src/main/java/com/devoxx/genie/ui/panel/ActionButtonsPanel.java +++ b/src/main/java/com/devoxx/genie/ui/panel/ActionButtonsPanel.java @@ -293,14 +293,14 @@ private void removeProjectContext() { } private boolean isSupportedProvider(@NotNull ModelProvider modelProvider) { - return modelProvider.equals(ModelProvider.Google) || - modelProvider.equals(ModelProvider.Anthropic) || - modelProvider.equals(ModelProvider.OpenAI) || - modelProvider.equals(ModelProvider.Mistral) || - modelProvider.equals(ModelProvider.DeepSeek) || - modelProvider.equals(ModelProvider.OpenRouter) || - modelProvider.equals(ModelProvider.DeepInfra) || - modelProvider.equals(ModelProvider.Ollama); + return modelProvider.equals(ModelProvider.GOOGLE) || + modelProvider.equals(ModelProvider.ANTHROPIC) || + modelProvider.equals(ModelProvider.OPENAI) || + modelProvider.equals(ModelProvider.MISTRAL) || + modelProvider.equals(ModelProvider.DEEP_SEEK) || + modelProvider.equals(ModelProvider.OPEN_ROUTER) || + modelProvider.equals(ModelProvider.DEEP_INFRA) || + modelProvider.equals(ModelProvider.OLLAMA); } private void addProjectToContext() { diff --git a/src/main/java/com/devoxx/genie/ui/panel/ChatResponsePanel.java b/src/main/java/com/devoxx/genie/ui/panel/ChatResponsePanel.java index b6b7eea1..4180bf2f 100644 --- a/src/main/java/com/devoxx/genie/ui/panel/ChatResponsePanel.java +++ b/src/main/java/com/devoxx/genie/ui/panel/ChatResponsePanel.java @@ -161,7 +161,7 @@ private void processGitDiff(@NotNull ChatMessageContext chatMessageContext, @Not * @return the updated token usage */ private static TokenUsage calcOllamaInputTokenCount(@NotNull ChatMessageContext chatMessageContext, TokenUsage tokenUsage) { - if (chatMessageContext.getLanguageModel().getProvider().equals(ModelProvider.Ollama)) { + if (chatMessageContext.getLanguageModel().getProvider().equals(ModelProvider.OLLAMA)) { int inputContextTokens = 0; if (chatMessageContext.getContext() != null) { Encoding encodingForProvider = ProjectContentService.getEncodingForProvider(chatMessageContext.getLanguageModel().getProvider()); diff --git a/src/main/java/com/devoxx/genie/ui/panel/LlmProviderPanel.java b/src/main/java/com/devoxx/genie/ui/panel/LlmProviderPanel.java index ab8d08e0..312b121a 100644 --- a/src/main/java/com/devoxx/genie/ui/panel/LlmProviderPanel.java +++ b/src/main/java/com/devoxx/genie/ui/panel/LlmProviderPanel.java @@ -34,7 +34,7 @@ public class LlmProviderPanel extends JBPanel implements LLMSe private static final Logger LOG = Logger.getInstance(LlmProviderPanel.class); - private final Project project; + private final transient Project project; @Getter private final JPanel contentPanel = new JPanel(); @@ -56,7 +56,7 @@ public class LlmProviderPanel extends JBPanel implements LLMSe * * @param project the project instance */ - public LlmProviderPanel(Project project) { + public LlmProviderPanel(@NotNull Project project) { super(new BorderLayout()); this.project = project; @@ -109,11 +109,32 @@ public LlmProviderPanel(Project project) { /** * Add the LLM providers to combobox. - * Only show the cloud-based LLM providers for which we have an API Key. + * Only show the enabled LLM providers. */ public void addModelProvidersToComboBox() { LLMProviderService providerService = LLMProviderService.getInstance(); + DevoxxGenieStateService stateService = DevoxxGenieStateService.getInstance(); + providerService.getAvailableModelProviders().stream() + .filter(provider -> switch (provider) { + case OLLAMA -> stateService.isOllamaEnabled(); + case LMSTUDIO -> stateService.isLmStudioEnabled(); + case GPT_4_ALL -> stateService.isGpt4AllEnabled(); + case JAN -> stateService.isJanEnabled(); + case EXO -> stateService.isExoEnabled(); + case LLAMA -> stateService.isLlamaCPPEnabled(); + case JLAMA -> stateService.isJlamaEnabled(); + case CUSTOM_OPEN_AI -> stateService.isCustomOpenAIEnabled(); + case OPENAI -> stateService.isOpenAIEnabled(); + case MISTRAL -> stateService.isMistralEnabled(); + case ANTHROPIC -> stateService.isAnthropicEnabled(); + case GROQ -> stateService.isGroqEnabled(); + case DEEP_INFRA -> stateService.isDeepInfraEnabled(); + case GOOGLE -> stateService.isGoogleEnabled(); + case DEEP_SEEK -> stateService.isDeepSeekEnabled(); + case OPEN_ROUTER -> stateService.isOpenRouterEnabled(); + case AZURE_OPEN_AI -> stateService.isAzureOpenAIEnabled(); + }) .distinct() .sorted(Comparator.comparing(ModelProvider::getName)) .forEach(modelProviderComboBox::addItem); @@ -128,7 +149,7 @@ private void refreshModels() { return; } - if (selectedProvider == ModelProvider.LMStudio || selectedProvider == ModelProvider.Ollama || selectedProvider == ModelProvider.Jan) { + if (selectedProvider == ModelProvider.LMSTUDIO || selectedProvider == ModelProvider.OLLAMA || selectedProvider == ModelProvider.JAN) { ApplicationManager.getApplication().invokeLater(() -> { refreshButton.setEnabled(false); diff --git a/src/main/java/com/devoxx/genie/ui/settings/AbstractSettingsComponent.java b/src/main/java/com/devoxx/genie/ui/settings/AbstractSettingsComponent.java index 70a53cc4..b6bea27e 100644 --- a/src/main/java/com/devoxx/genie/ui/settings/AbstractSettingsComponent.java +++ b/src/main/java/com/devoxx/genie/ui/settings/AbstractSettingsComponent.java @@ -53,6 +53,14 @@ protected void addSettingRow(@NotNull JPanel panel, @NotNull GridBagConstraints gbc.gridy++; } + protected void addProviderSettingRow(JPanel panel, GridBagConstraints gbc, String label, JCheckBox checkbox, JComponent urlComponent) { + JPanel providerPanel = new JPanel(new BorderLayout(5, 0)); + providerPanel.add(checkbox, BorderLayout.WEST); + providerPanel.add(urlComponent, BorderLayout.CENTER); + + addSettingRow(panel, gbc, label, providerPanel); + } + protected @NotNull JComponent createTextWithPasswordButton(JComponent jComponent, String url) { return createTextWithLinkButton(jComponent, url); } diff --git a/src/main/java/com/devoxx/genie/ui/settings/DevoxxGenieStateService.java b/src/main/java/com/devoxx/genie/ui/settings/DevoxxGenieStateService.java index f580ca48..370c9f04 100644 --- a/src/main/java/com/devoxx/genie/ui/settings/DevoxxGenieStateService.java +++ b/src/main/java/com/devoxx/genie/ui/settings/DevoxxGenieStateService.java @@ -33,7 +33,7 @@ public static DevoxxGenieStateService getInstance() { } private List excludedFiles = new ArrayList<>(Arrays.asList( - "package-lock.json", "yarn.lock", "pom.xml", "build.gradle", "settings.gradle" + "package-lock.json", "yarn.lock", "pom.xml", "build.gradle", "settings.gradle" )); private List customPrompts = new ArrayList<>(); @@ -69,6 +69,26 @@ public static DevoxxGenieStateService getInstance() { private String jlamaUrl = JLAMA_MODEL_URL; private String customOpenAIUrl = ""; + // + private boolean isOllamaEnabled = true; + private boolean isLmStudioEnabled = true; + private boolean isGpt4AllEnabled = true; + private boolean isJanEnabled = true; + private boolean isExoEnabled = true; + private boolean isLlamaCPPEnabled = true; + private boolean isJlamaEnabled = true; + private boolean isCustomOpenAIEnabled = false; + + private boolean isOpenAIEnabled = false; + private boolean isMistralEnabled = false; + private boolean isAnthropicEnabled = false; + private boolean isGroqEnabled = false; + private boolean isDeepInfraEnabled = false; + private boolean isGoogleEnabled = false; + private boolean isDeepSeekEnabled = false; + private boolean isOpenRouterEnabled = false; + private boolean isAzureOpenAIEnabled = false; + // LLM API Keys private String openAIKey = ""; private String mistralKey = ""; @@ -221,9 +241,16 @@ public void setSelectedProvider(@NotNull String projectLocation, String selected public String getSelectedProvider(@NotNull String projectLocation) { if (lastSelectedProvider != null) { - return lastSelectedProvider.getOrDefault(projectLocation, ModelProvider.Ollama.getName()); + return lastSelectedProvider.getOrDefault(projectLocation, ModelProvider.OLLAMA.getName()); } else { - return ModelProvider.Ollama.getName(); + return ModelProvider.OLLAMA.getName(); } } + + public boolean isAzureOpenAIEnabled() { + return showAzureOpenAIFields && + !azureOpenAIKey.isEmpty() && + !azureOpenAIEndpoint.isEmpty() && + !azureOpenAIDeployment.isEmpty(); + } } diff --git a/src/main/java/com/devoxx/genie/ui/settings/llm/LLMProvidersComponent.java b/src/main/java/com/devoxx/genie/ui/settings/llm/LLMProvidersComponent.java index 6c0f35b2..51a946bd 100644 --- a/src/main/java/com/devoxx/genie/ui/settings/llm/LLMProvidersComponent.java +++ b/src/main/java/com/devoxx/genie/ui/settings/llm/LLMProvidersComponent.java @@ -57,7 +57,40 @@ public class LLMProvidersComponent extends AbstractSettingsComponent { private final JCheckBox streamModeCheckBox = new JCheckBox("", stateService.getStreamMode()); @Getter - private final JCheckBox enableAzureOpenAI = new JCheckBox("", stateService.getShowAzureOpenAIFields()); + private final JCheckBox ollamaEnabledCheckBox = new JCheckBox("", stateService.isOllamaEnabled()); + @Getter + private final JCheckBox lmStudioEnabledCheckBox = new JCheckBox("", stateService.isLmStudioEnabled()); + @Getter + private final JCheckBox gpt4AllEnabledCheckBox = new JCheckBox("", stateService.isGpt4AllEnabled()); + @Getter + private final JCheckBox janEnabledCheckBox = new JCheckBox("", stateService.isJanEnabled()); + @Getter + private final JCheckBox exoEnabledCheckBox = new JCheckBox("", stateService.isExoEnabled()); + @Getter + private final JCheckBox llamaCPPEnabledCheckBox = new JCheckBox("", stateService.isLlamaCPPEnabled()); + @Getter + private final JCheckBox jlamaEnabledCheckBox = new JCheckBox("", stateService.isJlamaEnabled()); + @Getter + private final JCheckBox customOpenAIEnabledCheckBox = new JCheckBox("", stateService.isCustomOpenAIEnabled()); + + @Getter + private final JCheckBox openAIEnabledCheckBox = new JCheckBox("", stateService.isOpenAIEnabled()); + @Getter + private final JCheckBox mistralEnabledCheckBox = new JCheckBox("", stateService.isMistralEnabled()); + @Getter + private final JCheckBox anthropicEnabledCheckBox = new JCheckBox("", stateService.isAnthropicEnabled()); + @Getter + private final JCheckBox groqEnabledCheckBox = new JCheckBox("", stateService.isGroqEnabled()); + @Getter + private final JCheckBox deepInfraEnabledCheckBox = new JCheckBox("", stateService.isDeepInfraEnabled()); + @Getter + private final JCheckBox geminiEnabledCheckBox = new JCheckBox("", stateService.isGoogleEnabled()); + @Getter + private final JCheckBox deepSeekEnabledCheckBox = new JCheckBox("", stateService.isDeepSeekEnabled()); + @Getter + private final JCheckBox openRouterEnabledCheckBox = new JCheckBox("", stateService.isOpenRouterEnabled()); + @Getter + private final JCheckBox enableAzureOpenAICheckBox = new JCheckBox("", stateService.getShowAzureOpenAIFields()); private final java.util.List azureComponents = new ArrayList<>(); @@ -78,25 +111,43 @@ public JPanel createPanel() { addSection(panel, gbc, "Local Large Language Response"); addSettingRow(panel, gbc, "Enable Stream Mode (Beta)", streamModeCheckBox); + // Local LLM Providers section addSection(panel, gbc, "Local LLM Providers"); - addSettingRow(panel, gbc, "Ollama URL", createTextWithLinkButton(ollamaModelUrlField, "https://ollama.com")); - addSettingRow(panel, gbc, "LMStudio URL", createTextWithLinkButton(lmStudioModelUrlField, "https://lmstudio.ai/")); - addSettingRow(panel, gbc, "GPT4All URL", createTextWithLinkButton(gpt4AllModelUrlField, "https://gpt4all.io/")); - addSettingRow(panel, gbc, "Jan URL", createTextWithLinkButton(janModelUrlField, "https://jan.ai/download")); - addSettingRow(panel, gbc, "Exo URL", createTextWithLinkButton(exoModelUrlField, "https://github.com/exo-explore/exo")); - addSettingRow(panel, gbc, "LLaMA.c++ URL", createTextWithLinkButton(llamaCPPModelUrlField, "https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md")); - addSettingRow(panel, gbc, "JLama URL", createTextWithLinkButton(jlamaModelUrlField, "https://github.com/tjake/Jlama")); - addSettingRow(panel, gbc, "Custom OpenAI URL", customOpenAIUrlField); + addProviderSettingRow(panel, gbc, "Ollama URL", ollamaEnabledCheckBox, + createTextWithLinkButton(ollamaModelUrlField, "https://ollama.com")); + addProviderSettingRow(panel, gbc, "LMStudio URL", lmStudioEnabledCheckBox, + createTextWithLinkButton(lmStudioModelUrlField, "https://lmstudio.ai/")); + addProviderSettingRow(panel, gbc, "GPT4All URL", gpt4AllEnabledCheckBox, + createTextWithLinkButton(gpt4AllModelUrlField, "https://gpt4all.io/")); + addProviderSettingRow(panel, gbc, "Jan URL", janEnabledCheckBox, + createTextWithLinkButton(janModelUrlField, "https://jan.ai/download")); + addProviderSettingRow(panel, gbc, "Exo URL", exoEnabledCheckBox, + createTextWithLinkButton(exoModelUrlField, "https://github.com/exo-explore/exo")); + addProviderSettingRow(panel, gbc, "LLaMA.c++ URL", llamaCPPEnabledCheckBox, + createTextWithLinkButton(llamaCPPModelUrlField, "https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md")); + addProviderSettingRow(panel, gbc, "JLama URL", jlamaEnabledCheckBox, + createTextWithLinkButton(jlamaModelUrlField, "https://github.com/tjake/Jlama")); + addProviderSettingRow(panel, gbc, "Custom OpenAI URL", customOpenAIEnabledCheckBox, customOpenAIUrlField); + + // Cloud LLM Providers section addSection(panel, gbc, "Cloud LLM Providers"); - addSettingRow(panel, gbc, "OpenAI API Key", createTextWithPasswordButton(openAIKeyField, "https://platform.openai.com/api-keys")); - addSettingRow(panel, gbc, "Mistral API Key", createTextWithPasswordButton(mistralApiKeyField, "https://console.mistral.ai/api-keys")); - addSettingRow(panel, gbc, "Anthropic API Key", createTextWithPasswordButton(anthropicApiKeyField, "https://console.anthropic.com/settings/keys")); - addSettingRow(panel, gbc, "Groq API Key", createTextWithPasswordButton(groqApiKeyField, "https://console.groq.com/keys")); - addSettingRow(panel, gbc, "DeepInfra API Key", createTextWithPasswordButton(deepInfraApiKeyField, "https://deepinfra.com/dash/api_keys")); - addSettingRow(panel, gbc, "Google Gemini API Key", createTextWithPasswordButton(geminiApiKeyField, "https://aistudio.google.com/app/apikey")); - addSettingRow(panel, gbc, "Deep Seek API Key", createTextWithPasswordButton(deepSeekApiKeyField, "https://platform.deepseek.com/api_keys")); - addSettingRow(panel, gbc, "Open Router API Key", createTextWithPasswordButton(openRouterApiKeyField, "https://openrouter.ai/settings/keys")); + addProviderSettingRow(panel, gbc, "OpenAI API Key", openAIEnabledCheckBox, + createTextWithPasswordButton(openAIKeyField, "https://platform.openai.com/api-keys")); + addProviderSettingRow(panel, gbc, "Mistral API Key", mistralEnabledCheckBox, + createTextWithPasswordButton(mistralApiKeyField, "https://console.mistral.ai/api-keys")); + addProviderSettingRow(panel, gbc, "Anthropic API Key", anthropicEnabledCheckBox, + createTextWithPasswordButton(anthropicApiKeyField, "https://console.anthropic.com/settings/keys")); + addProviderSettingRow(panel, gbc, "Groq API Key", groqEnabledCheckBox, + createTextWithPasswordButton(groqApiKeyField, "https://console.groq.com/keys")); + addProviderSettingRow(panel, gbc, "DeepInfra API Key", deepInfraEnabledCheckBox, + createTextWithPasswordButton(deepInfraApiKeyField, "https://deepinfra.com/dash/api_keys")); + addProviderSettingRow(panel, gbc, "Google Gemini API Key", geminiEnabledCheckBox, + createTextWithPasswordButton(geminiApiKeyField, "https://aistudio.google.com/app/apikey")); + addProviderSettingRow(panel, gbc, "Deep Seek API Key", deepSeekEnabledCheckBox, + createTextWithPasswordButton(deepSeekApiKeyField, "https://platform.deepseek.com/api_keys")); + addProviderSettingRow(panel, gbc, "Open Router API Key", openRouterEnabledCheckBox, + createTextWithPasswordButton(openRouterApiKeyField, "https://openrouter.ai/settings/keys")); addAzureOpenAIPanel(panel, gbc); @@ -106,15 +157,60 @@ public JPanel createPanel() { return panel; } + private void updateUrlFieldState(@NotNull JCheckBox checkbox, + @NotNull JComponent urlComponent) { + urlComponent.setEnabled(checkbox.isSelected()); + } + + @Override + public void addListeners() { + // Keep existing listeners + enableAzureOpenAICheckBox.addItemListener(event -> { + azureComponents.forEach(comp -> comp.setVisible(event.getStateChange() == ItemEvent.SELECTED)); + panel.revalidate(); + panel.repaint(); + }); + + // Add new listeners for enable/disable checkboxes + ollamaEnabledCheckBox.addItemListener(e -> updateUrlFieldState(ollamaEnabledCheckBox, ollamaModelUrlField)); + lmStudioEnabledCheckBox.addItemListener(e -> updateUrlFieldState(lmStudioEnabledCheckBox, lmStudioModelUrlField)); + gpt4AllEnabledCheckBox.addItemListener(e -> updateUrlFieldState(gpt4AllEnabledCheckBox, gpt4AllModelUrlField)); + janEnabledCheckBox.addItemListener(e -> updateUrlFieldState(janEnabledCheckBox, janModelUrlField)); + exoEnabledCheckBox.addItemListener(e -> updateUrlFieldState(exoEnabledCheckBox, exoModelUrlField)); + llamaCPPEnabledCheckBox.addItemListener(e -> updateUrlFieldState(llamaCPPEnabledCheckBox, llamaCPPModelUrlField)); + jlamaEnabledCheckBox.addItemListener(e -> updateUrlFieldState(jlamaEnabledCheckBox, jlamaModelUrlField)); + customOpenAIEnabledCheckBox.addItemListener(e -> updateUrlFieldState(customOpenAIEnabledCheckBox, customOpenAIUrlField)); + + openAIEnabledCheckBox.addItemListener(e -> updateUrlFieldState(openAIEnabledCheckBox, openAIKeyField)); + mistralEnabledCheckBox.addItemListener(e -> updateUrlFieldState(mistralEnabledCheckBox, mistralApiKeyField)); + anthropicEnabledCheckBox.addItemListener(e -> updateUrlFieldState(anthropicEnabledCheckBox, anthropicApiKeyField)); + groqEnabledCheckBox.addItemListener(e -> updateUrlFieldState(groqEnabledCheckBox, groqApiKeyField)); + deepInfraEnabledCheckBox.addItemListener(e -> updateUrlFieldState(deepInfraEnabledCheckBox, deepInfraApiKeyField)); + geminiEnabledCheckBox.addItemListener(e -> updateUrlFieldState(geminiEnabledCheckBox, geminiApiKeyField)); + deepSeekEnabledCheckBox.addItemListener(e -> updateUrlFieldState(deepSeekEnabledCheckBox, deepSeekApiKeyField)); + openRouterEnabledCheckBox.addItemListener(e -> updateUrlFieldState(openRouterEnabledCheckBox, openRouterApiKeyField)); + enableAzureOpenAICheckBox.addItemListener(e -> updateUrlFieldState(enableAzureOpenAICheckBox, azureOpenAIEndpointField)); + } + + // In LLMProvidersComponent.java + private boolean isAzureConfigValid() { + return !azureOpenAIKeyField.getPassword().toString().isEmpty() + && !azureOpenAIEndpointField.getText().trim().isEmpty() + && !azureOpenAIDeploymentField.getText().trim().isEmpty(); + } + private void addAzureOpenAIPanel(JPanel panel, GridBagConstraints gbc) { - addSettingRow(panel, gbc, "Enable Azure OpenAI Provider", enableAzureOpenAI); + addSettingRow(panel, gbc, "Enable Azure OpenAI Provider", enableAzureOpenAICheckBox); - addAzureComponentsSettingRow(panel, gbc, "Azure OpenAI Endpoint", createTextWithLinkButton(azureOpenAIEndpointField, "https://learn.microsoft.com/en-us/azure/ai-services/openai/overview")); - addAzureComponentsSettingRow(panel, gbc, "Azure OpenAI Deployment", createTextWithLinkButton(azureOpenAIDeploymentField, "https://learn.microsoft.com/en-us/azure/ai-services/openai/overview")); - addAzureComponentsSettingRow(panel, gbc, "Azure OpenAI API Key", createTextWithPasswordButton(azureOpenAIKeyField, "https://learn.microsoft.com/en-us/azure/ai-services/openai/overview")); + addAzureComponentsSettingRow(panel, gbc, "Azure OpenAI Endpoint", + createTextWithLinkButton(azureOpenAIEndpointField, "https://learn.microsoft.com/en-us/azure/ai-services/openai/overview")); + addAzureComponentsSettingRow(panel, gbc, "Azure OpenAI Deployment", + createTextWithLinkButton(azureOpenAIDeploymentField, "https://learn.microsoft.com/en-us/azure/ai-services/openai/overview")); + addAzureComponentsSettingRow(panel, gbc, "Azure OpenAI API Key", + createTextWithPasswordButton(azureOpenAIKeyField, "https://learn.microsoft.com/en-us/azure/ai-services/openai/overview")); // Set initial visibility - boolean azureEnabled = enableAzureOpenAI.isSelected(); + boolean azureEnabled = enableAzureOpenAICheckBox.isSelected(); for (JComponent comp : azureComponents) { comp.setVisible(azureEnabled); } @@ -135,14 +231,4 @@ private void addAzureComponentsSettingRow(@NotNull JPanel panel, @NotNull GridBa gbc.insets = JBUI.insets(5); } - - @Override - public void addListeners() { - enableAzureOpenAI.addItemListener(event -> { - azureComponents.forEach(comp -> comp.setVisible(event.getStateChange() == ItemEvent.SELECTED)); - panel.revalidate(); - panel.repaint(); - }); - - } } diff --git a/src/main/java/com/devoxx/genie/ui/settings/llm/LLMProvidersConfigurable.java b/src/main/java/com/devoxx/genie/ui/settings/llm/LLMProvidersConfigurable.java index 9e1d03b4..5092bb16 100644 --- a/src/main/java/com/devoxx/genie/ui/settings/llm/LLMProvidersConfigurable.java +++ b/src/main/java/com/devoxx/genie/ui/settings/llm/LLMProvidersConfigurable.java @@ -72,11 +72,30 @@ public boolean isModified() { isModified |= isFieldModified(llmSettingsComponent.getExoModelUrlField(), settings.getExoModelUrl()); isModified |= isFieldModified(llmSettingsComponent.getCustomOpenAIUrlField(), settings.getCustomOpenAIUrl()); - isModified |= !settings.getShowAzureOpenAIFields().equals(llmSettingsComponent.getEnableAzureOpenAI().isSelected()); + isModified |= !settings.getShowAzureOpenAIFields().equals(llmSettingsComponent.getEnableAzureOpenAICheckBox().isSelected()); isModified |= isFieldModified(llmSettingsComponent.getAzureOpenAIEndpointField(), settings.getAzureOpenAIEndpoint()); isModified |= isFieldModified(llmSettingsComponent.getAzureOpenAIDeploymentField(), settings.getAzureOpenAIDeployment()); isModified |= isFieldModified(llmSettingsComponent.getAzureOpenAIKeyField(), settings.getAzureOpenAIKey()); + isModified |= settings.isOllamaEnabled() != llmSettingsComponent.getOllamaEnabledCheckBox().isSelected(); + isModified |= settings.isLmStudioEnabled() != llmSettingsComponent.getLmStudioEnabledCheckBox().isSelected(); + isModified |= settings.isGpt4AllEnabled() != llmSettingsComponent.getGpt4AllEnabledCheckBox().isSelected(); + isModified |= settings.isJanEnabled() != llmSettingsComponent.getJanEnabledCheckBox().isSelected(); + isModified |= settings.isExoEnabled() != llmSettingsComponent.getExoEnabledCheckBox().isSelected(); + isModified |= settings.isLlamaCPPEnabled() != llmSettingsComponent.getLlamaCPPEnabledCheckBox().isSelected(); + isModified |= settings.isJlamaEnabled() != llmSettingsComponent.getJlamaEnabledCheckBox().isSelected(); + isModified |= settings.isCustomOpenAIEnabled() != llmSettingsComponent.getCustomOpenAIEnabledCheckBox().isSelected(); + + isModified |= settings.isOpenAIEnabled() != llmSettingsComponent.getOpenAIEnabledCheckBox().isSelected(); + isModified |= settings.isMistralEnabled() != llmSettingsComponent.getMistralEnabledCheckBox().isSelected(); + isModified |= settings.isAnthropicEnabled() != llmSettingsComponent.getAnthropicEnabledCheckBox().isSelected(); + isModified |= settings.isGroqEnabled() != llmSettingsComponent.getGroqEnabledCheckBox().isSelected(); + isModified |= settings.isDeepInfraEnabled() != llmSettingsComponent.getDeepInfraEnabledCheckBox().isSelected(); + isModified |= settings.isGoogleEnabled() != llmSettingsComponent.getGeminiEnabledCheckBox().isSelected(); + isModified |= settings.isDeepSeekEnabled() != llmSettingsComponent.getDeepSeekEnabledCheckBox().isSelected(); + isModified |= settings.isOpenRouterEnabled() != llmSettingsComponent.getOpenRouterEnabledCheckBox().isSelected(); + isModified |= settings.getShowAzureOpenAIFields() != llmSettingsComponent.getEnableAzureOpenAICheckBox().isSelected(); + return isModified; } @@ -109,19 +128,41 @@ public void apply() { settings.setDeepSeekKey(new String(llmSettingsComponent.getDeepSeekApiKeyField().getPassword())); settings.setOpenRouterKey(new String(llmSettingsComponent.getOpenRouterApiKeyField().getPassword())); - settings.setShowAzureOpenAIFields(llmSettingsComponent.getEnableAzureOpenAI().isSelected()); + settings.setShowAzureOpenAIFields(llmSettingsComponent.getEnableAzureOpenAICheckBox().isSelected()); settings.setAzureOpenAIEndpoint(llmSettingsComponent.getAzureOpenAIEndpointField().getText()); settings.setAzureOpenAIDeployment(llmSettingsComponent.getAzureOpenAIDeploymentField().getText()); settings.setAzureOpenAIKey(new String(llmSettingsComponent.getAzureOpenAIKeyField().getPassword())); + settings.setOllamaEnabled(llmSettingsComponent.getOllamaEnabledCheckBox().isSelected()); + settings.setLmStudioEnabled(llmSettingsComponent.getLmStudioEnabledCheckBox().isSelected()); + settings.setGpt4AllEnabled(llmSettingsComponent.getGpt4AllEnabledCheckBox().isSelected()); + settings.setJanEnabled(llmSettingsComponent.getJanEnabledCheckBox().isSelected()); + settings.setExoEnabled(llmSettingsComponent.getExoEnabledCheckBox().isSelected()); + settings.setLlamaCPPEnabled(llmSettingsComponent.getLlamaCPPEnabledCheckBox().isSelected()); + settings.setJlamaEnabled(llmSettingsComponent.getJlamaEnabledCheckBox().isSelected()); + settings.setCustomOpenAIEnabled(llmSettingsComponent.getCustomOpenAIEnabledCheckBox().isSelected()); + + settings.setOpenAIEnabled(llmSettingsComponent.getOpenAIEnabledCheckBox().isSelected()); + settings.setMistralEnabled(llmSettingsComponent.getMistralEnabledCheckBox().isSelected()); + settings.setAnthropicEnabled(llmSettingsComponent.getAnthropicEnabledCheckBox().isSelected()); + settings.setGroqEnabled(llmSettingsComponent.getGroqEnabledCheckBox().isSelected()); + settings.setDeepInfraEnabled(llmSettingsComponent.getDeepInfraEnabledCheckBox().isSelected()); + settings.setGoogleEnabled(llmSettingsComponent.getGeminiEnabledCheckBox().isSelected()); + settings.setDeepSeekEnabled(llmSettingsComponent.getDeepSeekEnabledCheckBox().isSelected()); + settings.setOpenRouterEnabled(llmSettingsComponent.getOpenRouterEnabledCheckBox().isSelected()); + settings.setShowAzureOpenAIFields(llmSettingsComponent.getEnableAzureOpenAICheckBox().isSelected()); + // Only notify the listener if an API key has changed, so we can refresh the LLM providers list in the UI if (isModified) { - boolean hasKey = !settings.getAnthropicKey().isBlank() || - !settings.getOpenAIKey().isBlank() || - !settings.getOpenRouterKey().isBlank() || - !settings.getDeepSeekKey().isBlank() || - !settings.getDeepInfraKey().isBlank() || - !settings.getGeminiKey().isBlank(); + boolean hasKey = (!settings.getAnthropicKey().isBlank() && settings.isAnthropicEnabled()) || + (!settings.getOpenAIKey().isBlank() && settings.isOpenAIEnabled()) || + (!settings.getOpenRouterKey().isBlank() && settings.isOpenRouterEnabled()) || + (!settings.getDeepSeekKey().isBlank() && settings.isDeepSeekEnabled()) || + (!settings.getDeepInfraKey().isBlank() && settings.isDeepInfraEnabled()) || + (!settings.getGeminiKey().isBlank() && settings.isGoogleEnabled()) || + (!settings.getGroqKey().isBlank() && settings.isGroqEnabled()) || + (!settings.getMistralKey().isBlank() && settings.isMistralEnabled()) || + (!settings.getAzureOpenAIKey().isBlank() && settings.getShowAzureOpenAIFields()); ApplicationManager.getApplication().getMessageBus() .syncPublisher(AppTopics.SETTINGS_CHANGED_TOPIC) @@ -156,9 +197,28 @@ public void reset() { llmSettingsComponent.getDeepSeekApiKeyField().setText(settings.getDeepSeekKey()); llmSettingsComponent.getOpenRouterApiKeyField().setText(settings.getOpenRouterKey()); - llmSettingsComponent.getEnableAzureOpenAI().setSelected(settings.getShowAzureOpenAIFields()); + llmSettingsComponent.getEnableAzureOpenAICheckBox().setSelected(settings.getShowAzureOpenAIFields()); llmSettingsComponent.getAzureOpenAIEndpointField().setText(settings.getAzureOpenAIEndpoint()); llmSettingsComponent.getAzureOpenAIDeploymentField().setText(settings.getAzureOpenAIDeployment()); llmSettingsComponent.getAzureOpenAIKeyField().setText(settings.getAzureOpenAIKey()); + + llmSettingsComponent.getOllamaEnabledCheckBox().setSelected(settings.isOllamaEnabled()); + llmSettingsComponent.getLmStudioEnabledCheckBox().setSelected(settings.isLmStudioEnabled()); + llmSettingsComponent.getGpt4AllEnabledCheckBox().setSelected(settings.isGpt4AllEnabled()); + llmSettingsComponent.getJanEnabledCheckBox().setSelected(settings.isJanEnabled()); + llmSettingsComponent.getExoEnabledCheckBox().setSelected(settings.isExoEnabled()); + llmSettingsComponent.getLlamaCPPEnabledCheckBox().setSelected(settings.isLlamaCPPEnabled()); + llmSettingsComponent.getJlamaEnabledCheckBox().setSelected(settings.isJlamaEnabled()); + llmSettingsComponent.getCustomOpenAIEnabledCheckBox().setSelected(settings.isCustomOpenAIEnabled()); + + llmSettingsComponent.getOpenAIEnabledCheckBox().setSelected(settings.isOpenAIEnabled()); + llmSettingsComponent.getMistralEnabledCheckBox().setSelected(settings.isMistralEnabled()); + llmSettingsComponent.getAnthropicEnabledCheckBox().setSelected(settings.isAnthropicEnabled()); + llmSettingsComponent.getGroqEnabledCheckBox().setSelected(settings.isGroqEnabled()); + llmSettingsComponent.getDeepInfraEnabledCheckBox().setSelected(settings.isDeepInfraEnabled()); + llmSettingsComponent.getGeminiEnabledCheckBox().setSelected(settings.isGoogleEnabled()); + llmSettingsComponent.getDeepSeekEnabledCheckBox().setSelected(settings.isDeepSeekEnabled()); + llmSettingsComponent.getOpenRouterEnabledCheckBox().setSelected(settings.isOpenRouterEnabled()); + llmSettingsComponent.getEnableAzureOpenAICheckBox().setSelected(settings.getShowAzureOpenAIFields()); } } diff --git a/src/main/java/com/devoxx/genie/util/ChatMessageContextUtil.java b/src/main/java/com/devoxx/genie/util/ChatMessageContextUtil.java index 8f5c4d32..22ba68ae 100644 --- a/src/main/java/com/devoxx/genie/util/ChatMessageContextUtil.java +++ b/src/main/java/com/devoxx/genie/util/ChatMessageContextUtil.java @@ -125,7 +125,7 @@ private static void addEditorInfoToMessageContext(Editor editor, public static boolean isOpenAIo1Model(LanguageModel languageModel) { return languageModel != null && - languageModel.getProvider() == ModelProvider.OpenAI && + languageModel.getProvider() == ModelProvider.OPENAI && languageModel.getModelName() != null && languageModel.getModelName().toLowerCase().startsWith("o1-"); } diff --git a/src/main/java/com/devoxx/genie/util/DefaultLLMSettingsUtil.java b/src/main/java/com/devoxx/genie/util/DefaultLLMSettingsUtil.java index e2b80f44..133e29c0 100644 --- a/src/main/java/com/devoxx/genie/util/DefaultLLMSettingsUtil.java +++ b/src/main/java/com/devoxx/genie/util/DefaultLLMSettingsUtil.java @@ -28,14 +28,14 @@ public class DefaultLLMSettingsUtil { * @return true when API Key is required, meaning a cost is involved */ public static boolean isApiKeyBasedProvider(ModelProvider provider) { - return provider == ModelProvider.OpenAI || - provider == ModelProvider.Anthropic || - provider == ModelProvider.Mistral || - provider == ModelProvider.Groq || - provider == ModelProvider.DeepInfra || - provider == ModelProvider.Google || - provider == ModelProvider.OpenRouter || - provider == ModelProvider.AzureOpenAI; + return provider == ModelProvider.OPENAI || + provider == ModelProvider.ANTHROPIC || + provider == ModelProvider.MISTRAL || + provider == ModelProvider.GROQ || + provider == ModelProvider.DEEP_INFRA || + provider == ModelProvider.GOOGLE || + provider == ModelProvider.OPEN_ROUTER || + provider == ModelProvider.AZURE_OPEN_AI; } public record CostKey(ModelProvider provider, String modelName) { diff --git a/src/main/java/com/devoxx/genie/util/LLMProviderUtil.java b/src/main/java/com/devoxx/genie/util/LLMProviderUtil.java index bdf6f4ab..268626ef 100644 --- a/src/main/java/com/devoxx/genie/util/LLMProviderUtil.java +++ b/src/main/java/com/devoxx/genie/util/LLMProviderUtil.java @@ -19,16 +19,16 @@ public static List getApiKeyEnabledProviders() { DevoxxGenieSettingsService settings = DevoxxGenieStateService.getInstance(); return Arrays.stream(ModelProvider.values()) .filter(provider -> switch (provider) { - case OpenAI -> !settings.getOpenAIKey().isEmpty(); - case AzureOpenAI -> !settings.getAzureOpenAIKey().isEmpty() && + case OPENAI -> !settings.getOpenAIKey().isEmpty(); + case AZURE_OPEN_AI -> !settings.getAzureOpenAIKey().isEmpty() && !settings.getAzureOpenAIEndpoint().isEmpty() && !settings.getAzureOpenAIDeployment().isEmpty(); - case Anthropic -> !settings.getAnthropicKey().isEmpty(); - case Mistral -> !settings.getMistralKey().isEmpty(); - case Groq -> !settings.getGroqKey().isEmpty(); - case DeepInfra -> !settings.getDeepInfraKey().isEmpty(); - case DeepSeek -> !settings.getDeepSeekKey().isEmpty(); - case OpenRouter -> !settings.getOpenRouterKey().isEmpty(); - case Google -> !settings.getGeminiKey().isEmpty(); + case ANTHROPIC -> !settings.getAnthropicKey().isEmpty(); + case MISTRAL -> !settings.getMistralKey().isEmpty(); + case GROQ -> !settings.getGroqKey().isEmpty(); + case DEEP_INFRA -> !settings.getDeepInfraKey().isEmpty(); + case DEEP_SEEK -> !settings.getDeepSeekKey().isEmpty(); + case OPEN_ROUTER -> !settings.getOpenRouterKey().isEmpty(); + case GOOGLE -> !settings.getGeminiKey().isEmpty(); default -> false; }) .collect(Collectors.toList()); diff --git a/src/test/java/com/devoxx/genie/service/PromptExecutionServiceIT.java b/src/test/java/com/devoxx/genie/service/PromptExecutionServiceIT.java index 1fc0f7bf..e4415bee 100644 --- a/src/test/java/com/devoxx/genie/service/PromptExecutionServiceIT.java +++ b/src/test/java/com/devoxx/genie/service/PromptExecutionServiceIT.java @@ -56,7 +56,7 @@ private void mockSettingsState() { @Test public void testExecuteQueryOpenAI() { LanguageModel model = LanguageModel.builder() - .provider(ModelProvider.OpenAI) + .provider(ModelProvider.OPENAI) .modelName("gpt-3.5-turbo") .displayName("GPT-3.5 Turbo") .apiKeyUsed(true) @@ -70,7 +70,7 @@ public void testExecuteQueryOpenAI() { @Test public void testExecuteQueryAnthropic() { LanguageModel model = LanguageModel.builder() - .provider(ModelProvider.Anthropic) + .provider(ModelProvider.ANTHROPIC) .modelName("claude-3-5-sonnet-20240620") .displayName("claude-3-5-sonnet-20240620") .apiKeyUsed(true) @@ -84,7 +84,7 @@ public void testExecuteQueryAnthropic() { @Test public void testExecuteQueryGemini() { LanguageModel model = LanguageModel.builder() - .provider(ModelProvider.Google) + .provider(ModelProvider.GOOGLE) .modelName("gemini-1.5-flash") .displayName("Gemini 1.5 Flash") .apiKeyUsed(true) @@ -98,7 +98,7 @@ public void testExecuteQueryGemini() { @Test public void testExecuteQueryMistral() { LanguageModel model = LanguageModel.builder() - .provider(ModelProvider.Mistral) + .provider(ModelProvider.MISTRAL) .modelName("mistral-medium") .displayName("Mistral Medium") .apiKeyUsed(true) @@ -112,7 +112,7 @@ public void testExecuteQueryMistral() { @Test public void testExecuteQueryDeepInfra() { LanguageModel model = LanguageModel.builder() - .provider(ModelProvider.DeepInfra) + .provider(ModelProvider.DEEP_INFRA) .modelName("mistralai/Mixtral-8x7B-Instruct-v0.1") .displayName("Mixtral 8x7B") .apiKeyUsed(true) @@ -125,23 +125,23 @@ public void testExecuteQueryDeepInfra() { private ChatLanguageModel createChatModel(LanguageModel languageModel) { return switch (languageModel.getProvider()) { - case OpenAI -> OpenAiChatModel.builder() + case OPENAI -> OpenAiChatModel.builder() .apiKey(dotenv.get("OPENAI_API_KEY")) .modelName(languageModel.getModelName()) .build(); - case Anthropic -> AnthropicChatModel.builder() + case ANTHROPIC -> AnthropicChatModel.builder() .apiKey(dotenv.get("ANTHROPIC_API_KEY")) .modelName(languageModel.getModelName()) .build(); - case Google -> GoogleAiGeminiChatModel.builder() + case GOOGLE -> GoogleAiGeminiChatModel.builder() .apiKey(dotenv.get("GEMINI_API_KEY")) .modelName(languageModel.getModelName()) .build(); - case Mistral -> MistralAiChatModel.builder() + case MISTRAL -> MistralAiChatModel.builder() .apiKey(dotenv.get("MISTRAL_API_KEY")) .modelName(languageModel.getModelName()) .build(); - case DeepInfra -> OpenAiChatModel.builder() + case DEEP_INFRA -> OpenAiChatModel.builder() .baseUrl("https://api.deepinfra.com/v1/openai") .apiKey(dotenv.get("DEEPINFRA_API_KEY")) .modelName(languageModel.getModelName())