diff --git a/build.gradle.kts b/build.gradle.kts index 72792602..cf1f7deb 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -7,7 +7,7 @@ plugins { } group = "com.devoxx.genie" -version = "0.2.5" +version = "0.2.6" repositories { mavenCentral() diff --git a/src/main/java/com/devoxx/genie/chatmodel/ChatModelFactoryProvider.java b/src/main/java/com/devoxx/genie/chatmodel/ChatModelFactoryProvider.java index c272f086..9896f537 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/ChatModelFactoryProvider.java +++ b/src/main/java/com/devoxx/genie/chatmodel/ChatModelFactoryProvider.java @@ -4,11 +4,9 @@ import com.devoxx.genie.chatmodel.deepinfra.DeepInfraChatModelFactory; import com.devoxx.genie.chatmodel.gemini.GeminiChatModelFactory; import com.devoxx.genie.chatmodel.groq.GroqChatModelFactory; -import com.devoxx.genie.chatmodel.jan.JanChatModelFactory; import com.devoxx.genie.chatmodel.mistral.MistralChatModelFactory; import com.devoxx.genie.chatmodel.ollama.OllamaChatModelFactory; import com.devoxx.genie.chatmodel.openai.OpenAIChatModelFactory; -import com.devoxx.genie.model.LanguageModel; import com.devoxx.genie.model.enumarations.ModelProvider; import org.jetbrains.annotations.NotNull; @@ -28,7 +26,7 @@ public class ChatModelFactoryProvider { ModelProvider.Mistral.getName(), MistralChatModelFactory::new, ModelProvider.Groq.getName(), GroqChatModelFactory::new, ModelProvider.DeepInfra.getName(), DeepInfraChatModelFactory::new, - ModelProvider.Gemini.getName(), GeminiChatModelFactory::new + ModelProvider.Google.getName(), GeminiChatModelFactory::new // TODO Removed because currently is broken by latest Jan! version // ModelProvider.Jan, JanChatModelFactory::new ); diff --git a/src/main/java/com/devoxx/genie/chatmodel/ChatModelProvider.java b/src/main/java/com/devoxx/genie/chatmodel/ChatModelProvider.java index 53eda20e..5d7a9346 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/ChatModelProvider.java +++ b/src/main/java/com/devoxx/genie/chatmodel/ChatModelProvider.java @@ -17,6 +17,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel; import lombok.Setter; +import org.jetbrains.annotations.Contract; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; @@ -38,7 +39,7 @@ public ChatModelProvider() { factories.put(ModelProvider.Mistral, new MistralChatModelFactory()); factories.put(ModelProvider.Anthropic, new AnthropicChatModelFactory()); factories.put(ModelProvider.Groq, new GroqChatModelFactory()); - factories.put(ModelProvider.Gemini, new GeminiChatModelFactory()); + factories.put(ModelProvider.Google, new GeminiChatModelFactory()); // TODO Currently broken by latest Jan! 
version // factories.put(ModelProvider.Jan, new JanChatModelFactory()); } @@ -107,7 +108,8 @@ private String getModelName(@Nullable LanguageModel languageModel) { .orElseGet(() -> getDefaultModelName(languageModel.getProvider())); } - private String getDefaultModelName(@Nullable ModelProvider provider) { + @Contract(pure = true) + private @NotNull String getDefaultModelName(@Nullable ModelProvider provider) { if (provider == null) { return "DefaultModel"; } @@ -118,7 +120,8 @@ private String getDefaultModelName(@Nullable ModelProvider provider) { }; } - private static void setMaxOutputTokens(@NotNull DevoxxGenieStateService settingsState, ChatModel chatModel) { + private static void setMaxOutputTokens(@NotNull DevoxxGenieStateService settingsState, + @NotNull ChatModel chatModel) { Integer maxOutputTokens = settingsState.getMaxOutputTokens(); chatModel.setMaxTokens(maxOutputTokens != null ? maxOutputTokens : Constant.MAX_OUTPUT_TOKENS); } diff --git a/src/main/java/com/devoxx/genie/chatmodel/gemini/GeminiChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/gemini/GeminiChatModelFactory.java index 57a46c6e..35ca3b0a 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/gemini/GeminiChatModelFactory.java +++ b/src/main/java/com/devoxx/genie/chatmodel/gemini/GeminiChatModelFactory.java @@ -32,6 +32,6 @@ public String getApiKey() { @Override public List getModels() { - return getModels(ModelProvider.Gemini); + return getModels(ModelProvider.Google); } } diff --git a/src/main/java/com/devoxx/genie/model/LanguageModel.java b/src/main/java/com/devoxx/genie/model/LanguageModel.java index 6d8aa218..0f7380c8 100644 --- a/src/main/java/com/devoxx/genie/model/LanguageModel.java +++ b/src/main/java/com/devoxx/genie/model/LanguageModel.java @@ -5,6 +5,8 @@ import lombok.Data; import org.jetbrains.annotations.NotNull; +import java.util.Comparator; + @Data @Builder public class LanguageModel implements Comparable { @@ -16,11 +18,72 @@ public class LanguageModel implements Comparable { private double outputCost; private int contextWindow; - public int compareTo(@NotNull LanguageModel languageModel) { - return this.displayName.compareTo(languageModel.displayName); + @Override + public String toString() { + return displayName; } - public String toString() { - return provider.getName(); + @Override + public int compareTo(@NotNull LanguageModel other) { + return new ModelVersionComparator().compare(this.displayName, other.displayName); + } + + private static class ModelVersionComparator implements Comparator { + @Override + public int compare(String v1, String v2) { + String[] parts1 = v1.split(" "); + String[] parts2 = v2.split(" "); + + // Compare model names + int modelNameCompare = parts1[0].compareTo(parts2[0]); + if (modelNameCompare != 0) return modelNameCompare; + + // Extract version strings + String version1 = parts1.length > 1 ? parts1[1] : ""; + String version2 = parts2.length > 1 ? 
parts2[1] : ""; + + // Handle special versions (Sonnet, Haiku, Opus) + if (isSpecialVersion(version1) || isSpecialVersion(version2)) { + return compareSpecialVersions(version1, version2); + } + + // Compare version strings + return compareVersions(version1, version2); + } + + private boolean isSpecialVersion(@NotNull String version) { + return version.equals("Sonnet") || version.equals("Haiku") || version.equals("Opus"); + } + + private int compareSpecialVersions(@NotNull String v1, String v2) { + if (v1.equals(v2)) return 0; + if (v1.equals("Opus")) return 1; + if (v2.equals("Opus")) return -1; + if (v1.equals("Sonnet")) return 1; + if (v2.equals("Sonnet")) return -1; + return v1.compareTo(v2); + } + + private int compareVersions(@NotNull String v1, @NotNull String v2) { + String[] parts1 = v1.split("[^a-zA-Z0-9]+"); + String[] parts2 = v2.split("[^a-zA-Z0-9]+"); + + for (int i = 0; i < Math.max(parts1.length, parts2.length); i++) { + String part1 = i < parts1.length ? parts1[i] : ""; + String part2 = i < parts2.length ? parts2[i] : ""; + + int cmp = compareAlphanumeric(part1, part2); + if (cmp != 0) return cmp; + } + + return 0; + } + + private int compareAlphanumeric(@NotNull String s1, String s2) { + if (s1.matches("\\d+") && s2.matches("\\d+")) { + return Integer.compare(Integer.parseInt(s1), Integer.parseInt(s2)); + } + return s1.compareTo(s2); + } } } diff --git a/src/main/java/com/devoxx/genie/model/enumarations/ModelProvider.java b/src/main/java/com/devoxx/genie/model/enumarations/ModelProvider.java index 790d886d..90022191 100644 --- a/src/main/java/com/devoxx/genie/model/enumarations/ModelProvider.java +++ b/src/main/java/com/devoxx/genie/model/enumarations/ModelProvider.java @@ -13,7 +13,7 @@ public enum ModelProvider { Mistral("Mistral"), Groq("Groq"), DeepInfra("DeepInfra"), - Gemini("Gemini"); + Google("Google"); private final String name; diff --git a/src/main/java/com/devoxx/genie/model/gemini/GeminiChatModel.java b/src/main/java/com/devoxx/genie/model/gemini/GeminiChatModel.java index d0c9078d..a56dd6d7 100644 --- a/src/main/java/com/devoxx/genie/model/gemini/GeminiChatModel.java +++ b/src/main/java/com/devoxx/genie/model/gemini/GeminiChatModel.java @@ -43,15 +43,6 @@ public GeminiChatModel(String apiKey, .modelName(modelName) .timeout(getOrDefault(timeout, ofSeconds(60))) .build(); - -// messageRequest.setGenerationConfig(GenerationConfig.builder() -// .maxOutputTokens(maxTokens) -// .temperature(temperature) -// .build()); - -// messageRequest.setSystemInstruction(SystemInstruction.builder() -// .parts(List.of(Part.builder().text("Always return response in markdown").build())) -// .build()); } @Override diff --git a/src/main/java/com/devoxx/genie/service/LLMModelRegistryService.java b/src/main/java/com/devoxx/genie/service/LLMModelRegistryService.java index bd666135..c1b354e4 100644 --- a/src/main/java/com/devoxx/genie/service/LLMModelRegistryService.java +++ b/src/main/java/com/devoxx/genie/service/LLMModelRegistryService.java @@ -273,7 +273,7 @@ private void addDeepInfraModels() { private void addGeminiModels() { models.add(LanguageModel.builder() - .provider(ModelProvider.Gemini) + .provider(ModelProvider.Google) .modelName("gemini-1.5-flash-latest") .displayName("Gemini 1.5 Flash") .inputCost(0.7) @@ -283,17 +283,17 @@ private void addGeminiModels() { .build()); models.add(LanguageModel.builder() - .provider(ModelProvider.Gemini) + .provider(ModelProvider.Google) .modelName("gemini-1.5-pro-latest") .displayName("Gemini 1.5 Pro") .inputCost(7) .outputCost(21) - 
.contextWindow(1_000_000) + .contextWindow(2_000_000) .apiKeyUsed(true) .build()); models.add(LanguageModel.builder() - .provider(ModelProvider.Gemini) + .provider(ModelProvider.Google) .modelName("gemini-1.0-pro") .displayName("Gemini 1.0 Pro") .inputCost(0.5) diff --git a/src/main/java/com/devoxx/genie/service/LLMProviderService.java b/src/main/java/com/devoxx/genie/service/LLMProviderService.java index 5fbf2ce0..731fc15c 100644 --- a/src/main/java/com/devoxx/genie/service/LLMProviderService.java +++ b/src/main/java/com/devoxx/genie/service/LLMProviderService.java @@ -9,10 +9,9 @@ import java.util.*; import java.util.function.Supplier; import java.util.stream.Collectors; -import java.util.stream.Stream; import static com.devoxx.genie.model.enumarations.ModelProvider.*; -import static com.devoxx.genie.model.enumarations.ModelProvider.Gemini; +import static com.devoxx.genie.model.enumarations.ModelProvider.Google; public class LLMProviderService { @@ -37,7 +36,7 @@ public List getModelProvidersWithApiKeyConfigured() { providerKeyMap.put(Mistral, settings::getMistralKey); providerKeyMap.put(Groq, settings::getGroqKey); providerKeyMap.put(DeepInfra, settings::getDeepInfraKey); - providerKeyMap.put(Gemini, settings::getGeminiKey); + providerKeyMap.put(Google, settings::getGeminiKey); // Filter out cloud LLM providers that do not have a key List providersWithRequiredKey = LLMModelRegistryService.getInstance().getModels() diff --git a/src/main/java/com/devoxx/genie/service/ProjectContentService.java b/src/main/java/com/devoxx/genie/service/ProjectContentService.java index 51fdfeea..d6086b3a 100644 --- a/src/main/java/com/devoxx/genie/service/ProjectContentService.java +++ b/src/main/java/com/devoxx/genie/service/ProjectContentService.java @@ -122,17 +122,25 @@ public void calculateTokensAndCost(Project project, .thenAccept(projectContent -> { int tokenCount = ENCODING.countTokens(projectContent); double estimatedInputCost = calculateCost(tokenCount, inputCost); - String message = String.format("Project contains %s. Estimated min. cost using %s is $%.6f", + String message = String.format("Project contains %s. Estimated min. cost using %s %s is $%.6f", WindowContextFormatterUtil.format(tokenCount, "tokens"), + provider.getName(), languageModel.getDisplayName(), estimatedInputCost); + + // Add check for token count exceeding max context size + if (tokenCount > languageModel.getContextWindow()) { + message += String.format(". 
Total project size exceeds model's max context of %s tokens.", + WindowContextFormatterUtil.format(languageModel.getContextWindow())); + } + NotificationUtil.sendNotification(project, message); }); } private Encoding getEncodingForProvider(@NotNull ModelProvider provider) { return switch (provider) { - case OpenAI, Anthropic, Gemini -> + case OpenAI, Anthropic, Google -> Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE); case Mistral, DeepInfra, Groq -> // These often use the Llama tokenizer or similar diff --git a/src/main/java/com/devoxx/genie/ui/DevoxxGenieToolWindowContent.java b/src/main/java/com/devoxx/genie/ui/DevoxxGenieToolWindowContent.java index 40cafcb9..f689eb9d 100644 --- a/src/main/java/com/devoxx/genie/ui/DevoxxGenieToolWindowContent.java +++ b/src/main/java/com/devoxx/genie/ui/DevoxxGenieToolWindowContent.java @@ -35,9 +35,8 @@ import javax.swing.*; import java.awt.*; import java.awt.event.ActionEvent; +import java.util.*; import java.util.List; -import java.util.Optional; -import java.util.ResourceBundle; import java.util.stream.Stream; import static com.devoxx.genie.model.Constant.MESSAGES; @@ -211,7 +210,7 @@ private void addModelProvidersToComboBox() { providerService.getLocalModelProviders().stream() ) .distinct() - .sorted() + .sorted(Comparator.comparing(ModelProvider::getName)) .forEach(modelProviderComboBox::addItem); } @@ -369,13 +368,12 @@ private void updateModelNamesComboBox(String modelProvider) { */ private void populateModelNames(@NotNull ChatModelFactory chatModelFactory) { modelNameComboBox.removeAllItems(); - List modelNames = chatModelFactory.getModels(); + List modelNames = new ArrayList<>(chatModelFactory.getModels()); if (modelNames.isEmpty()) { hideModelNameComboBox(); } else { - modelNames.stream() - .sorted() - .forEach(modelNameComboBox::addItem); + modelNames.sort(Comparator.naturalOrder()); + modelNames.forEach(modelNameComboBox::addItem); } } diff --git a/src/main/java/com/devoxx/genie/ui/panel/ActionButtonsPanel.java b/src/main/java/com/devoxx/genie/ui/panel/ActionButtonsPanel.java index d37ddc7b..72c3e956 100644 --- a/src/main/java/com/devoxx/genie/ui/panel/ActionButtonsPanel.java +++ b/src/main/java/com/devoxx/genie/ui/panel/ActionButtonsPanel.java @@ -247,7 +247,7 @@ private boolean isProjectContextSupportedProvider() { return selectedProvider != null && ( selectedProvider.equals(ModelProvider.OpenAI) || selectedProvider.equals(ModelProvider.Anthropic) || - selectedProvider.equals(ModelProvider.Gemini) + selectedProvider.equals(ModelProvider.Google) ); } @@ -409,7 +409,7 @@ private void addProjectToContext() { return; } - if (!modelProvider.equals(ModelProvider.Gemini) && + if (!modelProvider.equals(ModelProvider.Google) && !modelProvider.equals(ModelProvider.Anthropic) && !modelProvider.equals(ModelProvider.OpenAI)) { NotificationUtil.sendNotification(project, diff --git a/src/main/java/com/devoxx/genie/ui/settings/DevoxxGenieStateService.java b/src/main/java/com/devoxx/genie/ui/settings/DevoxxGenieStateService.java index a1742d9e..647966f4 100644 --- a/src/main/java/com/devoxx/genie/ui/settings/DevoxxGenieStateService.java +++ b/src/main/java/com/devoxx/genie/ui/settings/DevoxxGenieStateService.java @@ -135,11 +135,39 @@ public void setModelCost(ModelProvider provider, } } - public double getModelInputCost(ModelProvider provider, String modelName) { +// public double getModelInputCost(ModelProvider provider, String modelName) { +// if (DefaultLLMSettingsUtil.isApiBasedProvider(provider)) { +// String key = 
provider.getName() + ":" + modelName; +// return modelInputCosts.getOrDefault(key, +// DefaultLLMSettingsUtil.DEFAULT_INPUT_COSTS.getOrDefault(new DefaultLLMSettingsUtil.CostKey(provider, modelName), 0.0)); +// } +// return 0.0; +// } + + public double getModelInputCost(@NotNull ModelProvider provider, String modelName) { + String key = provider.getName() + ":" + modelName; + double cost = modelInputCosts.getOrDefault(key, 0.0); + if (cost == 0.0) { + DefaultLLMSettingsUtil.CostKey costKey = new DefaultLLMSettingsUtil.CostKey(provider, modelName); + cost = DefaultLLMSettingsUtil.DEFAULT_INPUT_COSTS.getOrDefault(costKey, 0.0); + if (cost == 0.0) { + // Fallback to similar model names + for (Map.Entry entry : DefaultLLMSettingsUtil.DEFAULT_INPUT_COSTS.entrySet()) { + if (entry.getKey().provider == provider && entry.getKey().modelName.startsWith(modelName.split("-")[0])) { + cost = entry.getValue(); + break; + } + } + } + } + return cost; + } + + public double getModelOutputCost(ModelProvider provider, String modelName) { if (DefaultLLMSettingsUtil.isApiBasedProvider(provider)) { String key = provider.getName() + ":" + modelName; - return modelInputCosts.getOrDefault(key, - DefaultLLMSettingsUtil.DEFAULT_INPUT_COSTS.getOrDefault(new DefaultLLMSettingsUtil.CostKey(provider, modelName), 0.0)); + return modelOutputCosts.getOrDefault(key, + DefaultLLMSettingsUtil.DEFAULT_OUTPUT_COSTS.getOrDefault(new DefaultLLMSettingsUtil.CostKey(provider, modelName), 0.0)); } return 0.0; } diff --git a/src/main/java/com/devoxx/genie/ui/settings/llm/LLMProvidersComponent.java b/src/main/java/com/devoxx/genie/ui/settings/llm/LLMProvidersComponent.java index 8f21ebfd..b9fbd99d 100644 --- a/src/main/java/com/devoxx/genie/ui/settings/llm/LLMProvidersComponent.java +++ b/src/main/java/com/devoxx/genie/ui/settings/llm/LLMProvidersComponent.java @@ -89,7 +89,7 @@ public JPanel createPanel() { addSettingRow(panel, gbc, "Anthropic API Key", createTextWithPasswordButton(anthropicApiKeyField, "https://console.anthropic.com/settings/keys")); addSettingRow(panel, gbc, "Groq API Key", createTextWithPasswordButton(groqApiKeyField, "https://console.groq.com/keys")); addSettingRow(panel, gbc, "DeepInfra API Key", createTextWithPasswordButton(deepInfraApiKeyField, "https://deepinfra.com/dash/api_keys")); - addSettingRow(panel, gbc, "Gemini API Key", createTextWithPasswordButton(geminiApiKeyField, "https://aistudio.google.com/app/apikey")); + addSettingRow(panel, gbc, "Google Gemini API Key", createTextWithPasswordButton(geminiApiKeyField, "https://aistudio.google.com/app/apikey")); addSection(panel, gbc, "Search Providers"); addSettingRow(panel, gbc, "Tavily Web Search API Key", createTextWithPasswordButton(tavilySearchApiKeyField, "https://app.tavily.com/home")); diff --git a/src/main/java/com/devoxx/genie/ui/util/WindowContextFormatterUtil.java b/src/main/java/com/devoxx/genie/ui/util/WindowContextFormatterUtil.java index f977e163..2ecdcbf4 100644 --- a/src/main/java/com/devoxx/genie/ui/util/WindowContextFormatterUtil.java +++ b/src/main/java/com/devoxx/genie/ui/util/WindowContextFormatterUtil.java @@ -6,11 +6,11 @@ public class WindowContextFormatterUtil { public static @NotNull String format(int tokens, String suffix) { if (tokens >= 1_000_000_000) { - return String.format("%.2fB %s", tokens / 1_000_000_000.0, suffix); + return String.format("%dB %s", (tokens / 1_000_000_000), suffix); } else if (tokens >= 1_000_000) { - return String.format("%.2fM %s", tokens / 1_000_000.0, suffix); + return String.format("%dM %s", (tokens / 
1_000_000), suffix); } else if (tokens >= 1_000) { - return String.format("%.2fK %s", tokens / 1_000.0, suffix); + return String.format("%dK %s", (tokens / 1_000), suffix); } else { return String.format("%d %s", tokens, suffix); } diff --git a/src/main/java/com/devoxx/genie/util/DefaultLLMSettingsUtil.java b/src/main/java/com/devoxx/genie/util/DefaultLLMSettingsUtil.java index 96909a0e..2de19ed6 100644 --- a/src/main/java/com/devoxx/genie/util/DefaultLLMSettingsUtil.java +++ b/src/main/java/com/devoxx/genie/util/DefaultLLMSettingsUtil.java @@ -1,95 +1,35 @@ package com.devoxx.genie.util; - +import com.devoxx.genie.model.LanguageModel; import com.devoxx.genie.model.enumarations.ModelProvider; +import com.devoxx.genie.service.LLMModelRegistryService; import java.util.HashMap; import java.util.Map; import java.util.Objects; -import static dev.langchain4j.model.anthropic.AnthropicChatModelName.*; -import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_INSTANT_1_2; -import static dev.langchain4j.model.mistralai.MistralAiChatModelName.*; - public class DefaultLLMSettingsUtil { + public static final Map DEFAULT_INPUT_COSTS = new HashMap<>(); public static final Map DEFAULT_OUTPUT_COSTS = new HashMap<>(); static { - // OpenAI models - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.OpenAI, "gpt-4"), 0.03); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.OpenAI, "gpt-4"), 0.06); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.OpenAI, "gpt-3.5-turbo"), 0.0015); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.OpenAI, "gpt-3.5-turbo"), 0.002); - - // Anthropic models - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.Anthropic, "claude-3-5-sonnet-20240620"), 3.0); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.Anthropic, "claude-3-5-sonnet-20240620"), 15.0); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.Anthropic, CLAUDE_3_OPUS_20240229.toString()), 15.0); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.Anthropic, CLAUDE_3_OPUS_20240229.toString()), 75.0); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.Anthropic, CLAUDE_3_SONNET_20240229.toString()), 3.0); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.Anthropic, CLAUDE_3_SONNET_20240229.toString()), 15.0); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.Anthropic, CLAUDE_3_HAIKU_20240307.toString()), 0.25); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.Anthropic, CLAUDE_3_HAIKU_20240307.toString()), 1.25); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.Anthropic, CLAUDE_2_1.toString()), 8.0); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.Anthropic, CLAUDE_2_1.toString()), 24.0); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.Anthropic, CLAUDE_2.toString()), 8.0); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.Anthropic, CLAUDE_2.toString()), 24.0); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.Anthropic, CLAUDE_INSTANT_1_2.toString()), 0.8); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.Anthropic, CLAUDE_INSTANT_1_2.toString()), 2.4); - - // Gemini - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.Gemini, "gemini-pro"), 0.5); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.Gemini, "gemini-pro"), 1.5); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.Gemini, "gemini-1.5-pro-latest"), 7.0); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.Gemini, "gemini-1.5-pro-latest"), 21.0); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.Gemini, "gemini-1.5-flash-latest"), 0.7); - DEFAULT_OUTPUT_COSTS.put(new 
CostKey(ModelProvider.Gemini, "gemini-1.5-flash-latest"), 2.1); - - // DeepInfra - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "meta-llama/Meta-Llama-3-70B-Instruct"), 0.56); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "meta-llama/Meta-Llama-3-70B-Instruct"), 0.77); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "meta-llama/Meta-Llama-3-8B-Instruct"), 0.064); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "meta-llama/Meta-Llama-3-8B-Instruct"), 0.064); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "mistralai/Mixtral-8x7B-Instruct-v0.1"), 0.24); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "mistralai/Mixtral-8x7B-Instruct-v0.1"), 0.24); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "mistralai/Mixtral-8x22B-Instruct-v0.1"), 0.65); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "mistralai/Mixtral-8x22B-Instruct-v0.1"), 0.65); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "mistralai/Mistral-7B-Instruct-v0.3"), 0.07); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "mistralai/Mistral-7B-Instruct-v0.3"), 0.07); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "microsoft/WizardLM-2-8x22B"), 0.65); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "microsoft/WizardLM-2-8x22B"), 0.65); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "microsoft/WizardLM-2-7B"), 0.07); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "microsoft/WizardLM-2-7B"), 0.07); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "openchat/openchat_3.5"), 0.07); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "openchat/openchat_3.5"), 0.07); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "google/gemma-1.1-7b-it"), 0.07); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "google/gemma-1.1-7b-it"), 0.07); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "Phind/Phind-CodeLlama-34B-v2"), 0.6); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "Phind/Phind-CodeLlama-34B-v2"), 0.6); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "cognitivecomputations/dolphin-2.6-mixtral-8x7b"), 0.24); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.DeepInfra, "cognitivecomputations/dolphin-2.6-mixtral-8x7b"), 0.24); + initializeDefaultCosts(); + } - // Mistral - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.Mistral, OPEN_MISTRAL_7B.toString()), 0.25); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.Mistral, OPEN_MISTRAL_7B.toString()), 0.25); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.Mistral, OPEN_MIXTRAL_8x7B.toString()), 0.7); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.Mistral, OPEN_MIXTRAL_8x7B.toString()), 0.7); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.Mistral, MISTRAL_SMALL_LATEST.toString()), 1.0); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.Mistral, MISTRAL_SMALL_LATEST.toString()), 3.0); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.Mistral, MISTRAL_MEDIUM_LATEST.toString()), 2.7); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.Mistral, MISTRAL_MEDIUM_LATEST.toString()), 8.1); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.Mistral, MISTRAL_LARGE_LATEST.toString()), 4.0); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.Mistral, MISTRAL_LARGE_LATEST.toString()), 12.0); + private static void initializeDefaultCosts() { + LLMModelRegistryService 
modelRegistry = LLMModelRegistryService.getInstance(); + for (LanguageModel model : modelRegistry.getModels()) { + if (isApiBasedProvider(model.getProvider())) { + DEFAULT_INPUT_COSTS.put(new CostKey(model.getProvider(), model.getModelName()), model.getInputCost()); + DEFAULT_OUTPUT_COSTS.put(new CostKey(model.getProvider(), model.getModelName()), model.getOutputCost()); + } + } - // Groq - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.Groq, "gemma-7b-it"), 0.07); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.Groq, "gemma-7b-it"), 0.07); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.Groq, "llama3-8b-8192"), 0.05); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.Groq, "llama3-8b-8192"), 0.08); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.Groq, "llama3-70b-8192"), 0.59); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.Groq, "llama3-70b-8192"), 0.79); - DEFAULT_INPUT_COSTS.put(new CostKey(ModelProvider.Groq, "mixtral-8x7b-32768"), 0.24); - DEFAULT_OUTPUT_COSTS.put(new CostKey(ModelProvider.Groq, "mixtral-8x7b-32768"), 0.24); + // Print out the contents for verification + System.out.println("DEFAULT_INPUT_COSTS contents:"); + for (Map.Entry entry : DEFAULT_INPUT_COSTS.entrySet()) { + System.out.println(entry.getKey() + " -> " + entry.getValue()); + } } public static boolean isApiBasedProvider(ModelProvider provider) { @@ -98,7 +38,7 @@ public static boolean isApiBasedProvider(ModelProvider provider) { provider == ModelProvider.Mistral || provider == ModelProvider.Groq || provider == ModelProvider.DeepInfra || - provider == ModelProvider.Gemini; + provider == ModelProvider.Google; } public static class CostKey { @@ -115,12 +55,18 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; CostKey costKey = (CostKey) o; - return provider == costKey.provider && modelName.equals(costKey.modelName); + return provider == costKey.provider && + Objects.equals(modelName, costKey.modelName); } @Override public int hashCode() { return Objects.hash(provider, modelName); } + + @Override + public String toString() { + return "CostKey{provider=" + provider + ", modelName='" + modelName + "'}"; + } } } diff --git a/src/main/java/com/devoxx/genie/util/LLMProviderUtil.java b/src/main/java/com/devoxx/genie/util/LLMProviderUtil.java index 696bc2d0..b784bfa4 100644 --- a/src/main/java/com/devoxx/genie/util/LLMProviderUtil.java +++ b/src/main/java/com/devoxx/genie/util/LLMProviderUtil.java @@ -17,7 +17,7 @@ public static java.util.List getApiKeyEnabledProviders() { case Mistral -> !settings.getMistralKey().isEmpty(); case Groq -> !settings.getGroqKey().isEmpty(); case DeepInfra -> !settings.getDeepInfraKey().isEmpty(); - case Gemini -> !settings.getGeminiKey().isEmpty(); + case Google -> !settings.getGeminiKey().isEmpty(); default -> false; }) .collect(Collectors.toList()); diff --git a/src/main/resources/META-INF/plugin.xml b/src/main/resources/META-INF/plugin.xml index e30f1e43..69e96e53 100644 --- a/src/main/resources/META-INF/plugin.xml +++ b/src/main/resources/META-INF/plugin.xml @@ -35,6 +35,13 @@ ]]> v0.2.6 +
+  • Renamed Gemini LLM provider to Google
+  • Increased Gemini 1.5 Pro context window to 2M tokens
+  • Sorted LLM providers and model names alphabetically in the combo box
+  • Refactored LLM cost calculation

v0.2.5

  • Feat #171: Support OpenAI GPT 4o mini
diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties
index 0fd5fbe6..f4730531 100644
--- a/src/main/resources/application.properties
+++ b/src/main/resources/application.properties
@@ -1,2 +1,2 @@
-#Fri Jul 19 01:07:30 EEST 2024
-version=0.2.5
+#Mon Jul 22 11:43:01 CEST 2024
+version=0.2.6
diff --git a/src/test/java/com/devoxx/genie/service/PromptExecutionServiceIT.java b/src/test/java/com/devoxx/genie/service/PromptExecutionServiceIT.java
index dc3986d9..ebc6dc8a 100644
--- a/src/test/java/com/devoxx/genie/service/PromptExecutionServiceIT.java
+++ b/src/test/java/com/devoxx/genie/service/PromptExecutionServiceIT.java
@@ -83,7 +83,7 @@ public void testExecuteQueryAnthropic() {
     @Test
     public void testExecuteQueryGemini() {
         LanguageModel model = LanguageModel.builder()
-            .provider(ModelProvider.Gemini)
+            .provider(ModelProvider.Google)
             .modelName("gemini-pro")
             .displayName("Gemini Pro")
             .apiKeyUsed(true)
@@ -132,7 +132,7 @@ private ChatLanguageModel createChatModel(LanguageModel languageModel) {
                 .apiKey(dotenv.get("ANTHROPIC_API_KEY"))
                 .modelName(languageModel.getModelName())
                 .build();
-            case Gemini -> GeminiChatModel.builder()
+            case Google -> GeminiChatModel.builder()
                 .apiKey(dotenv.get("GEMINI_API_KEY"))
                 .modelName(languageModel.getModelName())
                 .build();
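
Note on the combo-box sorting change in DevoxxGenieToolWindowContent: a plain .sorted() orders an enum by its declaration order, not alphabetically, which is why the diff switches to Comparator.comparing(ModelProvider::getName). Below is a minimal standalone sketch of that difference; the class name and the enum subset are chosen for illustration only and are not the plugin's full ModelProvider list.

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class ProviderSortSketch {

    // Hypothetical subset of the ModelProvider enum; declaration order is deliberately non-alphabetical.
    enum ModelProvider {
        Ollama("Ollama"), OpenAI("OpenAI"), Anthropic("Anthropic"), Google("Google");

        private final String name;

        ModelProvider(String name) { this.name = name; }

        public String getName() { return name; }
    }

    public static void main(String[] args) {
        List<ModelProvider> providers = new ArrayList<>(List.of(ModelProvider.values()));

        // Comparator.naturalOrder() (what a bare sorted() uses) keeps declaration order for enums:
        providers.sort(Comparator.naturalOrder());
        System.out.println(providers); // [Ollama, OpenAI, Anthropic, Google]

        // Sorting on the provider name, as the combo box now does:
        providers.sort(Comparator.comparing(ModelProvider::getName));
        System.out.println(providers); // [Anthropic, Google, Ollama, OpenAI]
    }
}

Model names get similar treatment: modelNames.sort(Comparator.naturalOrder()) now goes through the version-aware compareTo added to LanguageModel in this diff.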
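
Note on the cost-calculation refactor: the reworked getModelInputCost in DevoxxGenieStateService resolves a model's input cost in three steps: a user-configured override, then the exact default for that provider and model name, then the first default whose model name shares the same prefix family (the part before the first "-"). The following is a minimal sketch of that fallback order only; it uses plain "provider:model" string keys and illustrative cost values instead of the plugin's CostKey map.

import java.util.Map;

public class CostLookupSketch {

    // Hypothetical defaults keyed as "provider:modelName"; the plugin now derives its defaults from LLMModelRegistryService.
    static final Map<String, Double> DEFAULT_INPUT_COSTS = Map.of(
            "OpenAI:gpt-4o", 5.0,
            "Google:gemini-1.5-pro-latest", 7.0);

    // Mirrors the fallback order in getModelInputCost: user override -> exact default -> same model family.
    static double inputCost(Map<String, Double> userCosts, String provider, String modelName) {
        String key = provider + ":" + modelName;
        double cost = userCosts.getOrDefault(key, 0.0);
        if (cost == 0.0) {
            cost = DEFAULT_INPUT_COSTS.getOrDefault(key, 0.0);
        }
        if (cost == 0.0) {
            String family = modelName.split("-")[0]; // e.g. "gpt" for "gpt-4o-mini"
            cost = DEFAULT_INPUT_COSTS.entrySet().stream()
                    .filter(e -> e.getKey().startsWith(provider + ":" + family))
                    .mapToDouble(Map.Entry::getValue)
                    .findFirst()
                    .orElse(0.0);
        }
        return cost;
    }

    public static void main(String[] args) {
        // No override and no exact default for "gpt-4o-mini", so the family fallback finds the "gpt-4o" entry.
        System.out.println(inputCost(Map.of(), "OpenAI", "gpt-4o-mini")); // 5.0
    }
}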
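
Note on the WindowContextFormatterUtil change: the format strings switch from the two-decimal %.2f form to %d with integer division, so token counts are now truncated to whole K/M/B units. A small self-contained sketch with the same logic as the updated method (class name chosen for illustration) showing the effect:

public class TokenFormatSketch {

    // Same logic as the updated WindowContextFormatterUtil.format: integer division truncates toward zero.
    static String format(int tokens, String suffix) {
        if (tokens >= 1_000_000_000) return String.format("%dB %s", tokens / 1_000_000_000, suffix);
        if (tokens >= 1_000_000) return String.format("%dM %s", tokens / 1_000_000, suffix);
        if (tokens >= 1_000) return String.format("%dK %s", tokens / 1_000, suffix);
        return String.format("%d %s", tokens, suffix);
    }

    public static void main(String[] args) {
        System.out.println(format(2_000_000, "tokens")); // 2M tokens (the new Gemini 1.5 Pro context window)
        System.out.println(format(1_500, "tokens"));     // 1K tokens (was "1.50K tokens" with the old %.2f format)
    }
}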