Skip to content

Commit

Permalink
Merge pull request #412 from devoxx/issue-411
Browse files Browse the repository at this point in the history
Issue 411
  • Loading branch information
stephanj authored Dec 21, 2024
2 parents 315b075 + 786d79c commit 3325961
Show file tree
Hide file tree
Showing 32 changed files with 194 additions and 122 deletions.
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ plugins {
}

group = "com.devoxx.genie"
version = "0.4.6"
version = "0.4.7"

repositories {
mavenCentral()
Expand Down
19 changes: 15 additions & 4 deletions core/src/main/java/com/devoxx/genie/model/LanguageModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,19 @@ public class LanguageModel implements Comparable<LanguageModel> {
private boolean apiKeyUsed;
private double inputCost;
private double outputCost;
private int contextWindow;
private int inputMaxTokens;
private int outputMaxTokens;

/**
 * Default constructor: an OpenAI-provider model with blank names, no API key,
 * zero cost and zero token limits.
 *
 * <p>Delegates to the full 8-argument constructor. Note: a constructor may
 * contain at most one {@code this(...)} invocation and it must be the first
 * statement, so only the new 8-argument delegation (with separate
 * inputMaxTokens / outputMaxTokens) is kept; the old 7-argument call no
 * longer matches any constructor after contextWindow was split.
 */
public LanguageModel() {
    this(ModelProvider.OpenAI,
            "",      // modelName
            "",      // displayName
            false,   // apiKeyUsed
            0.0,     // inputCost
            0.0,     // outputCost
            0,       // inputMaxTokens
            0);      // outputMaxTokens
}

public LanguageModel(ModelProvider provider,
Expand All @@ -29,14 +38,16 @@ public LanguageModel(ModelProvider provider,
boolean apiKeyUsed,
double inputCost,
double outputCost,
int contextWindow) {
int inputMaxTokens,
int outputMaxTokens) {
this.provider = provider;
this.modelName = modelName;
this.displayName = displayName;
this.apiKeyUsed = apiKeyUsed;
this.inputCost = inputCost;
this.outputCost = outputCost;
this.contextWindow = contextWindow;
this.inputMaxTokens = inputMaxTokens;
this.outputMaxTokens = outputMaxTokens;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ private void addDirectoryToContext(Project project, @NotNull VirtualFile directo
.filter(model -> model.getProvider().getName().equals(selectedProvider.getName()) &&
model.getModelName().equals(selectedModel))
.findFirst()
.map(LanguageModel::getContextWindow);
.map(LanguageModel::getInputMaxTokens);

ProjectContentService.getInstance()
.getDirectoryContent(project, directory, contextWindow.orElse(settings.getDefaultWindowContext()), false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@

import com.devoxx.genie.chatmodel.cloud.anthropic.AnthropicChatModelFactory;
import com.devoxx.genie.chatmodel.cloud.azureopenai.AzureOpenAIChatModelFactory;
import com.devoxx.genie.chatmodel.local.customopenai.CustomOpenAIChatModelFactory;
import com.devoxx.genie.chatmodel.cloud.deepinfra.DeepInfraChatModelFactory;
import com.devoxx.genie.chatmodel.cloud.deepseek.DeepSeekChatModelFactory;
import com.devoxx.genie.chatmodel.cloud.google.GoogleChatModelFactory;
import com.devoxx.genie.chatmodel.local.gpt4all.GPT4AllChatModelFactory;
import com.devoxx.genie.chatmodel.cloud.groq.GroqChatModelFactory;
import com.devoxx.genie.chatmodel.cloud.mistral.MistralChatModelFactory;
import com.devoxx.genie.chatmodel.cloud.openai.OpenAIChatModelFactory;
import com.devoxx.genie.chatmodel.cloud.openrouter.OpenRouterChatModelFactory;
import com.devoxx.genie.chatmodel.local.customopenai.CustomOpenAIChatModelFactory;
import com.devoxx.genie.chatmodel.local.gpt4all.GPT4AllChatModelFactory;
import com.devoxx.genie.chatmodel.local.jan.JanChatModelFactory;
import com.devoxx.genie.chatmodel.local.llamaCPP.LlamaChatModelFactory;
import com.devoxx.genie.chatmodel.local.lmstudio.LMStudioChatModelFactory;
import com.devoxx.genie.chatmodel.cloud.mistral.MistralChatModelFactory;
import com.devoxx.genie.chatmodel.local.ollama.OllamaChatModelFactory;
import com.devoxx.genie.chatmodel.cloud.openai.OpenAIChatModelFactory;
import com.devoxx.genie.chatmodel.cloud.openrouter.OpenRouterChatModelFactory;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
package com.devoxx.genie.chatmodel;

import com.devoxx.genie.chatmodel.local.gpt4all.GPT4AllChatModelFactory;
import com.devoxx.genie.chatmodel.local.jan.JanChatModelFactory;
import com.devoxx.genie.chatmodel.local.lmstudio.LMStudioChatModelFactory;
import com.devoxx.genie.chatmodel.local.ollama.OllamaChatModelFactory;
import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.model.Constant;
import com.devoxx.genie.model.LanguageModel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public List<LanguageModel> getModels() {
.displayName(DevoxxGenieStateService.getInstance().getAzureOpenAIDeployment())
.inputCost(0.0)
.outputCost(0.0)
.contextWindow(0)
.inputMaxTokens(0)
.apiKeyUsed(true)
.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import dev.langchain4j.model.googleai.GoogleAiGeminiStreamingChatModel;
import org.jetbrains.annotations.NotNull;

import java.time.Duration;
import java.util.List;

public class GoogleChatModelFactory implements ChatModelFactory {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public List<LanguageModel> getModels() {
.displayName(model.getName())
.inputCost(inputCost)
.outputCost(outputCost)
.contextWindow(model.getContextLength() == null ? model.getTopProvider().getContextLength() : model.getContextLength())
.inputMaxTokens(model.getContextLength() == null ? model.getTopProvider().getContextLength() : model.getContextLength())
.apiKeyUsed(true)
.build();
synchronized (modelNames) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ protected LanguageModel buildLanguageModel(Object model) {
.displayName(janModel.getName())
.inputCost(0)
.outputCost(0)
.contextWindow(janModel.getCtxLen() == null ? 8_000 : janModel.getSettings().getCtxLen())
.inputMaxTokens(janModel.getCtxLen() == null ? 8_000 : janModel.getSettings().getCtxLen())
.apiKeyUsed(false)
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public List<LanguageModel> getModels() {
.displayName(TEST_MODEL)
.inputCost(0)
.outputCost(0)
.contextWindow(8000)
.inputMaxTokens(8000)
.apiKeyUsed(false)
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ protected LanguageModel buildLanguageModel(Object model) {
.displayName(lmStudioModel.getId())
.inputCost(0)
.outputCost(0)
.contextWindow(DEFAULT_CONTEXT_LENGTH)
.inputMaxTokens(DEFAULT_CONTEXT_LENGTH)
.apiKeyUsed(false)
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ protected LanguageModel buildLanguageModel(Object model) throws IOException {
.displayName(ollamaModel.getName())
.inputCost(0)
.outputCost(0)
.contextWindow(contextWindow)
.inputMaxTokens(contextWindow)
.apiKeyUsed(false)
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ private LanguageModel createDefaultLanguageModel(@NotNull DevoxxGenieSettingsSer
.apiKeyUsed(false)
.inputCost(0)
.outputCost(0)
.contextWindow(4096)
.inputMaxTokens(4096)
.build();
} else {
String modelName = stateService.getSelectedLanguageModel(project.getLocationHash());
Expand All @@ -146,7 +146,7 @@ private LanguageModel createDefaultLanguageModel(@NotNull DevoxxGenieSettingsSer
.apiKeyUsed(false)
.inputCost(0)
.outputCost(0)
.contextWindow(128_000)
.inputMaxTokens(128_000)
.build();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ private int getWindowContext() {
LanguageModel languageModel = (LanguageModel) modelNameComboBox.getSelectedItem();
int tokenLimit = 4096;
if (languageModel != null) {
tokenLimit = languageModel.getContextWindow();
tokenLimit = languageModel.getInputMaxTokens();
}
return tokenLimit;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public void calculateTokensAndCost() {
return;
}

int maxTokens = selectedModel.getContextWindow();
int maxTokens = selectedModel.getInputMaxTokens();
boolean isApiKeyBased = DefaultLLMSettingsUtil.isApiKeyBasedProvider(selectedProvider);

// Perform the token and cost calculation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
import com.devoxx.genie.util.ChatMessageContextUtil;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.project.Project;
import dev.langchain4j.data.message.*;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;
import org.jetbrains.annotations.NotNull;
Expand Down
Loading

0 comments on commit 3325961

Please sign in to comment.