Skip to content

Commit

Permalink
Renamed Gemini LLM Provider to Google. Increased context window to 2M…
Browse files Browse the repository at this point in the history
… for Gemini Pro. Sorting LLM providers and model names. Calc cost fix
  • Loading branch information
stephanj committed Jul 22, 2024
1 parent 9b02a96 commit 9d52a90
Show file tree
Hide file tree
Showing 20 changed files with 173 additions and 132 deletions.
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ plugins {
}

group = "com.devoxx.genie"
version = "0.2.5"
version = "0.2.6"

repositories {
mavenCentral()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
import com.devoxx.genie.chatmodel.deepinfra.DeepInfraChatModelFactory;
import com.devoxx.genie.chatmodel.gemini.GeminiChatModelFactory;
import com.devoxx.genie.chatmodel.groq.GroqChatModelFactory;
import com.devoxx.genie.chatmodel.jan.JanChatModelFactory;
import com.devoxx.genie.chatmodel.mistral.MistralChatModelFactory;
import com.devoxx.genie.chatmodel.ollama.OllamaChatModelFactory;
import com.devoxx.genie.chatmodel.openai.OpenAIChatModelFactory;
import com.devoxx.genie.model.LanguageModel;
import com.devoxx.genie.model.enumarations.ModelProvider;
import org.jetbrains.annotations.NotNull;

Expand All @@ -28,7 +26,7 @@ public class ChatModelFactoryProvider {
ModelProvider.Mistral.getName(), MistralChatModelFactory::new,
ModelProvider.Groq.getName(), GroqChatModelFactory::new,
ModelProvider.DeepInfra.getName(), DeepInfraChatModelFactory::new,
ModelProvider.Gemini.getName(), GeminiChatModelFactory::new
ModelProvider.Google.getName(), GeminiChatModelFactory::new
// TODO Removed because currently is broken by latest Jan! version
// ModelProvider.Jan, JanChatModelFactory::new
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import lombok.Setter;
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

Expand All @@ -38,7 +39,7 @@ public ChatModelProvider() {
factories.put(ModelProvider.Mistral, new MistralChatModelFactory());
factories.put(ModelProvider.Anthropic, new AnthropicChatModelFactory());
factories.put(ModelProvider.Groq, new GroqChatModelFactory());
factories.put(ModelProvider.Gemini, new GeminiChatModelFactory());
factories.put(ModelProvider.Google, new GeminiChatModelFactory());
// TODO Currently broken by latest Jan! version
// factories.put(ModelProvider.Jan, new JanChatModelFactory());
}
Expand Down Expand Up @@ -107,7 +108,8 @@ private String getModelName(@Nullable LanguageModel languageModel) {
.orElseGet(() -> getDefaultModelName(languageModel.getProvider()));
}

private String getDefaultModelName(@Nullable ModelProvider provider) {
@Contract(pure = true)
private @NotNull String getDefaultModelName(@Nullable ModelProvider provider) {
if (provider == null) {
return "DefaultModel";
}
Expand All @@ -118,7 +120,8 @@ private String getDefaultModelName(@Nullable ModelProvider provider) {
};
}

private static void setMaxOutputTokens(@NotNull DevoxxGenieStateService settingsState, ChatModel chatModel) {
private static void setMaxOutputTokens(@NotNull DevoxxGenieStateService settingsState,
@NotNull ChatModel chatModel) {
Integer maxOutputTokens = settingsState.getMaxOutputTokens();
chatModel.setMaxTokens(maxOutputTokens != null ? maxOutputTokens : Constant.MAX_OUTPUT_TOKENS);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.Gemini);
return getModels(ModelProvider.Google);
}
}
71 changes: 67 additions & 4 deletions src/main/java/com/devoxx/genie/model/LanguageModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import lombok.Data;
import org.jetbrains.annotations.NotNull;

import java.util.Comparator;

@Data
@Builder
public class LanguageModel implements Comparable<LanguageModel> {
Expand All @@ -16,11 +18,72 @@ public class LanguageModel implements Comparable<LanguageModel> {
private double outputCost;
private int contextWindow;

public int compareTo(@NotNull LanguageModel languageModel) {
return this.displayName.compareTo(languageModel.displayName);
@Override
public String toString() {
return displayName;
}

public String toString() {
return provider.getName();
@Override
public int compareTo(@NotNull LanguageModel other) {
return new ModelVersionComparator().compare(this.displayName, other.displayName);
}

private static class ModelVersionComparator implements Comparator<String> {
@Override
public int compare(String v1, String v2) {
String[] parts1 = v1.split(" ");
String[] parts2 = v2.split(" ");

// Compare model names
int modelNameCompare = parts1[0].compareTo(parts2[0]);
if (modelNameCompare != 0) return modelNameCompare;

// Extract version strings
String version1 = parts1.length > 1 ? parts1[1] : "";
String version2 = parts2.length > 1 ? parts2[1] : "";

// Handle special versions (Sonnet, Haiku, Opus)
if (isSpecialVersion(version1) || isSpecialVersion(version2)) {
return compareSpecialVersions(version1, version2);
}

// Compare version strings
return compareVersions(version1, version2);
}

private boolean isSpecialVersion(@NotNull String version) {
return version.equals("Sonnet") || version.equals("Haiku") || version.equals("Opus");
}

private int compareSpecialVersions(@NotNull String v1, String v2) {
if (v1.equals(v2)) return 0;
if (v1.equals("Opus")) return 1;
if (v2.equals("Opus")) return -1;
if (v1.equals("Sonnet")) return 1;
if (v2.equals("Sonnet")) return -1;
return v1.compareTo(v2);
}

private int compareVersions(@NotNull String v1, @NotNull String v2) {
String[] parts1 = v1.split("[^a-zA-Z0-9]+");
String[] parts2 = v2.split("[^a-zA-Z0-9]+");

for (int i = 0; i < Math.max(parts1.length, parts2.length); i++) {
String part1 = i < parts1.length ? parts1[i] : "";
String part2 = i < parts2.length ? parts2[i] : "";

int cmp = compareAlphanumeric(part1, part2);
if (cmp != 0) return cmp;
}

return 0;
}

private int compareAlphanumeric(@NotNull String s1, String s2) {
if (s1.matches("\\d+") && s2.matches("\\d+")) {
return Integer.compare(Integer.parseInt(s1), Integer.parseInt(s2));
}
return s1.compareTo(s2);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public enum ModelProvider {
Mistral("Mistral"),
Groq("Groq"),
DeepInfra("DeepInfra"),
Gemini("Gemini");
Google("Google");

private final String name;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,6 @@ public GeminiChatModel(String apiKey,
.modelName(modelName)
.timeout(getOrDefault(timeout, ofSeconds(60)))
.build();

// messageRequest.setGenerationConfig(GenerationConfig.builder()
// .maxOutputTokens(maxTokens)
// .temperature(temperature)
// .build());

// messageRequest.setSystemInstruction(SystemInstruction.builder()
// .parts(List.of(Part.builder().text("Always return response in markdown").build()))
// .build());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ private void addDeepInfraModels() {
private void addGeminiModels() {

models.add(LanguageModel.builder()
.provider(ModelProvider.Gemini)
.provider(ModelProvider.Google)
.modelName("gemini-1.5-flash-latest")
.displayName("Gemini 1.5 Flash")
.inputCost(0.7)
Expand All @@ -283,17 +283,17 @@ private void addGeminiModels() {
.build());

models.add(LanguageModel.builder()
.provider(ModelProvider.Gemini)
.provider(ModelProvider.Google)
.modelName("gemini-1.5-pro-latest")
.displayName("Gemini 1.5 Pro")
.inputCost(7)
.outputCost(21)
.contextWindow(1_000_000)
.contextWindow(2_000_000)
.apiKeyUsed(true)
.build());

models.add(LanguageModel.builder()
.provider(ModelProvider.Gemini)
.provider(ModelProvider.Google)
.modelName("gemini-1.0-pro")
.displayName("Gemini 1.0 Pro")
.inputCost(0.5)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
import java.util.*;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.devoxx.genie.model.enumarations.ModelProvider.*;
import static com.devoxx.genie.model.enumarations.ModelProvider.Gemini;
import static com.devoxx.genie.model.enumarations.ModelProvider.Google;

public class LLMProviderService {

Expand All @@ -37,7 +36,7 @@ public List<ModelProvider> getModelProvidersWithApiKeyConfigured() {
providerKeyMap.put(Mistral, settings::getMistralKey);
providerKeyMap.put(Groq, settings::getGroqKey);
providerKeyMap.put(DeepInfra, settings::getDeepInfraKey);
providerKeyMap.put(Gemini, settings::getGeminiKey);
providerKeyMap.put(Google, settings::getGeminiKey);

// Filter out cloud LLM providers that do not have a key
List<ModelProvider> providersWithRequiredKey = LLMModelRegistryService.getInstance().getModels()
Expand Down
12 changes: 10 additions & 2 deletions src/main/java/com/devoxx/genie/service/ProjectContentService.java
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,25 @@ public void calculateTokensAndCost(Project project,
.thenAccept(projectContent -> {
int tokenCount = ENCODING.countTokens(projectContent);
double estimatedInputCost = calculateCost(tokenCount, inputCost);
String message = String.format("Project contains %s. Estimated min. cost using %s is $%.6f",
String message = String.format("Project contains %s. Estimated min. cost using %s %s is $%.6f",
WindowContextFormatterUtil.format(tokenCount, "tokens"),
provider.getName(),
languageModel.getDisplayName(),
estimatedInputCost);

// Add check for token count exceeding max context size
if (tokenCount > languageModel.getContextWindow()) {
message += String.format(". Total project size exceeds model's max context of %s tokens.",
WindowContextFormatterUtil.format(languageModel.getContextWindow()));
}

NotificationUtil.sendNotification(project, message);
});
}

private Encoding getEncodingForProvider(@NotNull ModelProvider provider) {
return switch (provider) {
case OpenAI, Anthropic, Gemini ->
case OpenAI, Anthropic, Google ->
Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE);
case Mistral, DeepInfra, Groq ->
// These often use the Llama tokenizer or similar
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@
import javax.swing.*;
import java.awt.*;
import java.awt.event.ActionEvent;
import java.util.*;
import java.util.List;
import java.util.Optional;
import java.util.ResourceBundle;
import java.util.stream.Stream;

import static com.devoxx.genie.model.Constant.MESSAGES;
Expand Down Expand Up @@ -211,7 +210,7 @@ private void addModelProvidersToComboBox() {
providerService.getLocalModelProviders().stream()
)
.distinct()
.sorted()
.sorted(Comparator.comparing(ModelProvider::getName))
.forEach(modelProviderComboBox::addItem);
}

Expand Down Expand Up @@ -369,13 +368,12 @@ private void updateModelNamesComboBox(String modelProvider) {
*/
private void populateModelNames(@NotNull ChatModelFactory chatModelFactory) {
modelNameComboBox.removeAllItems();
List<LanguageModel> modelNames = chatModelFactory.getModels();
List<LanguageModel> modelNames = new ArrayList<>(chatModelFactory.getModels());
if (modelNames.isEmpty()) {
hideModelNameComboBox();
} else {
modelNames.stream()
.sorted()
.forEach(modelNameComboBox::addItem);
modelNames.sort(Comparator.naturalOrder());
modelNames.forEach(modelNameComboBox::addItem);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ private boolean isProjectContextSupportedProvider() {
return selectedProvider != null && (
selectedProvider.equals(ModelProvider.OpenAI) ||
selectedProvider.equals(ModelProvider.Anthropic) ||
selectedProvider.equals(ModelProvider.Gemini)
selectedProvider.equals(ModelProvider.Google)
);
}

Expand Down Expand Up @@ -409,7 +409,7 @@ private void addProjectToContext() {
return;
}

if (!modelProvider.equals(ModelProvider.Gemini) &&
if (!modelProvider.equals(ModelProvider.Google) &&
!modelProvider.equals(ModelProvider.Anthropic) &&
!modelProvider.equals(ModelProvider.OpenAI)) {
NotificationUtil.sendNotification(project,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,39 @@ public void setModelCost(ModelProvider provider,
}
}

public double getModelInputCost(ModelProvider provider, String modelName) {
// public double getModelInputCost(ModelProvider provider, String modelName) {
// if (DefaultLLMSettingsUtil.isApiBasedProvider(provider)) {
// String key = provider.getName() + ":" + modelName;
// return modelInputCosts.getOrDefault(key,
// DefaultLLMSettingsUtil.DEFAULT_INPUT_COSTS.getOrDefault(new DefaultLLMSettingsUtil.CostKey(provider, modelName), 0.0));
// }
// return 0.0;
// }

public double getModelInputCost(@NotNull ModelProvider provider, String modelName) {
String key = provider.getName() + ":" + modelName;
double cost = modelInputCosts.getOrDefault(key, 0.0);
if (cost == 0.0) {
DefaultLLMSettingsUtil.CostKey costKey = new DefaultLLMSettingsUtil.CostKey(provider, modelName);
cost = DefaultLLMSettingsUtil.DEFAULT_INPUT_COSTS.getOrDefault(costKey, 0.0);
if (cost == 0.0) {
// Fallback to similar model names
for (Map.Entry<DefaultLLMSettingsUtil.CostKey, Double> entry : DefaultLLMSettingsUtil.DEFAULT_INPUT_COSTS.entrySet()) {
if (entry.getKey().provider == provider && entry.getKey().modelName.startsWith(modelName.split("-")[0])) {
cost = entry.getValue();
break;
}
}
}
}
return cost;
}

public double getModelOutputCost(ModelProvider provider, String modelName) {
if (DefaultLLMSettingsUtil.isApiBasedProvider(provider)) {
String key = provider.getName() + ":" + modelName;
return modelInputCosts.getOrDefault(key,
DefaultLLMSettingsUtil.DEFAULT_INPUT_COSTS.getOrDefault(new DefaultLLMSettingsUtil.CostKey(provider, modelName), 0.0));
return modelOutputCosts.getOrDefault(key,
DefaultLLMSettingsUtil.DEFAULT_OUTPUT_COSTS.getOrDefault(new DefaultLLMSettingsUtil.CostKey(provider, modelName), 0.0));
}
return 0.0;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public JPanel createPanel() {
addSettingRow(panel, gbc, "Anthropic API Key", createTextWithPasswordButton(anthropicApiKeyField, "https://console.anthropic.com/settings/keys"));
addSettingRow(panel, gbc, "Groq API Key", createTextWithPasswordButton(groqApiKeyField, "https://console.groq.com/keys"));
addSettingRow(panel, gbc, "DeepInfra API Key", createTextWithPasswordButton(deepInfraApiKeyField, "https://deepinfra.com/dash/api_keys"));
addSettingRow(panel, gbc, "Gemini API Key", createTextWithPasswordButton(geminiApiKeyField, "https://aistudio.google.com/app/apikey"));
addSettingRow(panel, gbc, "Google Gemini API Key", createTextWithPasswordButton(geminiApiKeyField, "https://aistudio.google.com/app/apikey"));

addSection(panel, gbc, "Search Providers");
addSettingRow(panel, gbc, "Tavily Web Search API Key", createTextWithPasswordButton(tavilySearchApiKeyField, "https://app.tavily.com/home"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ public class WindowContextFormatterUtil {

public static @NotNull String format(int tokens, String suffix) {
if (tokens >= 1_000_000_000) {
return String.format("%.2fB %s", tokens / 1_000_000_000.0, suffix);
return String.format("%dB %s", (tokens / 1_000_000_000), suffix);
} else if (tokens >= 1_000_000) {
return String.format("%.2fM %s", tokens / 1_000_000.0, suffix);
return String.format("%dM %s", (tokens / 1_000_000), suffix);
} else if (tokens >= 1_000) {
return String.format("%.2fK %s", tokens / 1_000.0, suffix);
return String.format("%dK %s", (tokens / 1_000), suffix);
} else {
return String.format("%d %s", tokens, suffix);
}
Expand Down
Loading

0 comments on commit 9d52a90

Please sign in to comment.