Skip to content

Commit

Permalink
Merge pull request #77 from devoxx/issue-70
Browse files Browse the repository at this point in the history
Issue 70
  • Loading branch information
stephanj authored May 26, 2024
2 parents 841b909 + 42c04de commit 8b2655e
Show file tree
Hide file tree
Showing 68 changed files with 978 additions and 510 deletions.
14 changes: 7 additions & 7 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,19 @@ plugins {
}

group = "com.devoxx.genie"
version = "0.1.11"
version = "0.1.12"

repositories {
mavenCentral()
}

dependencies {
implementation("dev.langchain4j:langchain4j:0.30.0")
implementation("dev.langchain4j:langchain4j-ollama:0.30.0")
implementation("dev.langchain4j:langchain4j-local-ai:0.30.0")
implementation("dev.langchain4j:langchain4j-open-ai:0.30.0")
implementation("dev.langchain4j:langchain4j-anthropic:0.30.0")
implementation("dev.langchain4j:langchain4j-mistral-ai:0.30.0")
implementation("dev.langchain4j:langchain4j:0.31.0")
implementation("dev.langchain4j:langchain4j-ollama:0.31.0")
implementation("dev.langchain4j:langchain4j-local-ai:0.31.0")
implementation("dev.langchain4j:langchain4j-open-ai:0.31.0")
implementation("dev.langchain4j:langchain4j-anthropic:0.31.0")
implementation("dev.langchain4j:langchain4j-mistral-ai:0.31.0")
implementation("org.commonmark:commonmark:0.22.0")

compileOnly("org.projectlombok:lombok:1.18.32")
Expand Down
3 changes: 1 addition & 2 deletions src/main/java/com/devoxx/genie/action/AddFileAction.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import com.intellij.openapi.actionSystem.CommonDataKeys;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.vfs.VirtualFile;
import com.intellij.openapi.wm.ToolWindow;
import com.intellij.openapi.wm.ToolWindowManager;
import org.jetbrains.annotations.NotNull;

import static com.devoxx.genie.ui.util.WindowPluginUtil.ensureToolWindowVisible;
Expand All @@ -17,6 +15,7 @@ public class AddFileAction extends AnAction {

/**
* Add file to the window context.
*
* @param e the action event
*/
@Override
Expand Down
7 changes: 5 additions & 2 deletions src/main/java/com/devoxx/genie/action/AddSnippetAction.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public class AddSnippetAction extends AnAction {

/**
* Add a snippet to the tool window.
*
* @param e the action event
*/
@Override
Expand All @@ -53,6 +54,7 @@ public void actionPerformed(@NotNull AnActionEvent e) {

/**
* Add the selected file to the file list manager.
*
* @param selectedFile the selected file
*/
private static void addSelectedFile(VirtualFile selectedFile) {
Expand All @@ -65,9 +67,10 @@ private static void addSelectedFile(VirtualFile selectedFile) {

/**
* Create a virtual file and add it to the file list manager.
* @param originalFile the original file
*
* @param originalFile the original file
* @param selectionModel the selection model
* @param selectedText the selected text
* @param selectedText the selected text
*/
private void createAndAddVirtualFile(VirtualFile originalFile,
SelectionModel selectionModel,
Expand Down
35 changes: 18 additions & 17 deletions src/main/java/com/devoxx/genie/chatmodel/ChatModelFactory.java
Original file line number Diff line number Diff line change
@@ -1,45 +1,46 @@
package com.devoxx.genie.chatmodel;

import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.model.enumarations.ModelProvider;
import com.devoxx.genie.ui.SettingsState;
import dev.langchain4j.model.chat.ChatLanguageModel;
import org.jetbrains.annotations.NotNull;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;

import java.util.List;

public interface ChatModelFactory {

/**
* Create a chat model with the given parameters.
*
* @param chatModel the chat model
* @return the chat model
*/
ChatLanguageModel createChatModel(ChatModel chatModel);

/**
* Create a streaming chat model with the given parameters.
*
* @param chatModel the chat model
* @return the streaming chat model
*/
default StreamingChatLanguageModel createStreamingChatModel(ChatModel chatModel) {
return null;
}

/**
* List the available model names.
*
* @return the list of model names
*/
default List<String> getModelNames() {
return List.of();
}

default String getApiKey() {
return "";
}

/**
* Get the base URL by the model type.
* @param modelProvider the language model provider
* @return the base URL
* Get the model provider API key.
*
* @return the API key
*/
default String getBaseUrlByType(@NotNull ModelProvider modelProvider) {
return switch (modelProvider) {
case GPT4All -> SettingsState.getInstance().getGpt4allModelUrl();
case LMStudio -> SettingsState.getInstance().getLmstudioModelUrl();
case Ollama -> SettingsState.getInstance().getOllamaModelUrl();
default -> "na";
};
default String getApiKey() {
return "";
}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package com.devoxx.genie.chatmodel;

import com.devoxx.genie.model.enumarations.ModelProvider;
import com.devoxx.genie.chatmodel.anthropic.AnthropicChatModelFactory;
import com.devoxx.genie.chatmodel.deepinfra.DeepInfraChatModelFactory;
import com.devoxx.genie.chatmodel.gemini.GeminiChatModelFactory;
import com.devoxx.genie.chatmodel.groq.GroqChatModelFactory;
import com.devoxx.genie.chatmodel.jan.JanChatModelFactory;
import com.devoxx.genie.chatmodel.mistral.MistralChatModelFactory;
import com.devoxx.genie.chatmodel.ollama.OllamaChatModelFactory;
import com.devoxx.genie.chatmodel.openai.OpenAIChatModelFactory;
import com.devoxx.genie.chatmodel.gemini.GeminiChatModelFactory;
import com.devoxx.genie.chatmodel.jan.JanChatModelFactory;
import com.devoxx.genie.model.enumarations.ModelProvider;
import org.jetbrains.annotations.NotNull;

import java.util.Map;
Expand All @@ -29,10 +29,11 @@ public class ChatModelFactoryProvider {
ModelProvider.DeepInfra, DeepInfraChatModelFactory::new,
ModelProvider.Gemini, GeminiChatModelFactory::new,
ModelProvider.Jan, JanChatModelFactory::new
);
);

/**
* Get the factory by provider.
*
* @param provider the provider
* @return the factory
*/
Expand Down
36 changes: 28 additions & 8 deletions src/main/java/com/devoxx/genie/chatmodel/ChatModelProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
import com.devoxx.genie.model.enumarations.ModelProvider;
import com.devoxx.genie.model.request.ChatMessageContext;
import com.devoxx.genie.ui.SettingsState;
import com.intellij.openapi.diagnostic.Logger;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import lombok.Setter;
import org.jetbrains.annotations.NotNull;

Expand All @@ -24,12 +24,9 @@

@Setter
public class ChatModelProvider {
private static final Logger LOG = Logger.getInstance(ChatModelProvider.class);

private final Map<ModelProvider, ChatModelFactory> factories = new HashMap<>();

private String modelName;

public ChatModelProvider() {
factories.put(ModelProvider.Ollama, new OllamaChatModelFactory());
factories.put(ModelProvider.LMStudio, new LMStudioChatModelFactory());
Expand All @@ -44,22 +41,44 @@ public ChatModelProvider() {

/**
* Get the chat language model for selected model provider.
*
* @param chatMessageContext the chat message context
* @return the chat language model
*/
public ChatLanguageModel getChatLanguageModel(@NotNull ChatMessageContext chatMessageContext) {
ChatModel chatModel = initChatModel(chatMessageContext);
return getFactory(chatMessageContext).createChatModel(chatModel);
}

/**
* Get the streaming chat language model for selected model provider.
*
* @param chatMessageContext the chat message context
* @return the streaming chat language model
*/
public StreamingChatLanguageModel getStreamingChatLanguageModel(@NotNull ChatMessageContext chatMessageContext) {
ChatModel chatModel = initChatModel(chatMessageContext);
return getFactory(chatMessageContext).createStreamingChatModel(chatModel);
}

/**
* Get the chat model factory for the selected model provider.
*
* @param chatMessageContext the chat message context
* @return the chat model factory
*/
private @NotNull ChatModelFactory getFactory(@NotNull ChatMessageContext chatMessageContext) {
ModelProvider provider = ModelProvider.valueOf(chatMessageContext.getLlmProvider());
LOG.info("Chat model provider: " + provider);
ChatModelFactory factory = factories.get(provider);
if (factory == null) {
throw new IllegalArgumentException("No factory for provider: " + provider);
}
ChatModel chatModel = initChatModel(chatMessageContext);
return factory.createChatModel(chatModel);
return factory;
}

/**
* Initialize chat model settings by default or by user settings.
*
* @return the chat model
*/
public @NotNull ChatModel initChatModel(@NotNull ChatMessageContext chatMessageContext) {
Expand All @@ -78,8 +97,9 @@ public ChatLanguageModel getChatLanguageModel(@NotNull ChatMessageContext chatMe
/**
* Set max output tokens.
* Some extra work because of the settings state that didn't like the integer input field.
*
* @param settingsState the settings state
* @param chatModel the chat model
* @param chatModel the chat model
*/
private static void setMaxOutputTokens(@NotNull SettingsState settingsState, ChatModel chatModel) {
String maxOutputTokens = settingsState.getMaxOutputTokens();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.ui.SettingsState;
import dev.langchain4j.model.anthropic.AnthropicChatModel;
import dev.langchain4j.model.anthropic.AnthropicStreamingChatModel;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import org.jetbrains.annotations.NotNull;

import java.util.List;
Expand All @@ -25,6 +27,17 @@ public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
.build();
}

@Override
public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel chatModel) {
return AnthropicStreamingChatModel.builder()
.apiKey(getApiKey())
.modelName(chatModel.getModelName())
.temperature(chatModel.getTemperature())
.topP(chatModel.getTopP())
.maxTokens(chatModel.getMaxTokens())
.build();
}

@Override
public String getApiKey() {
return SettingsState.getInstance().getAnthropicKey();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.ui.SettingsState;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import org.jetbrains.annotations.NotNull;

import java.time.Duration;
Expand All @@ -26,25 +28,37 @@ public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
.build();
}

@Override
public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel chatModel) {
return OpenAiStreamingChatModel.builder()
.baseUrl("https://api.deepinfra.com/v1/openai")
.apiKey(getApiKey())
.modelName(chatModel.getModelName())
.temperature(chatModel.getTemperature())
.topP(chatModel.getTopP())
.timeout(Duration.ofSeconds(chatModel.getTimeout()))
.build();
}

@Override
public String getApiKey() {
return SettingsState.getInstance().getDeepInfraKey();
}

@Override
public List<String> getModelNames() {
return List.of(
"meta-llama/Meta-Llama-3-70B-Instruct",
"meta-llama/Meta-Llama-3-8B-Instruct",
"mistralai/Mixtral-8x7B-Instruct-v0.1",
"mistralai/Mixtral-8x22B-Instruct-v0.1",
"microsoft/WizardLM-2-8x22B",
"microsoft/WizardLM-2-7B",
"databricks/dbrx-instruct",
"openchat/openchat_3.5",
"google/gemma-7b-it",
"Phind/Phind-CodeLlama-34B-v2",
"bigcode/starcoder2-15b"
);
return List.of(
"meta-llama/Meta-Llama-3-70B-Instruct",
"meta-llama/Meta-Llama-3-8B-Instruct",
"mistralai/Mixtral-8x7B-Instruct-v0.1",
"mistralai/Mixtral-8x22B-Instruct-v0.1",
"microsoft/WizardLM-2-8x22B",
"microsoft/WizardLM-2-7B",
"databricks/dbrx-instruct",
"openchat/openchat_3.5",
"google/gemma-7b-it",
"Phind/Phind-CodeLlama-34B-v2",
"bigcode/starcoder2-15b"
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import com.devoxx.genie.chatmodel.ChatModelFactory;
import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.model.enumarations.ModelProvider;
import com.devoxx.genie.ui.SettingsState;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.localai.LocalAiChatModel;
import dev.langchain4j.model.localai.LocalAiStreamingChatModel;
import org.jetbrains.annotations.NotNull;

import java.time.Duration;
Expand All @@ -14,9 +15,8 @@ public class GPT4AllChatModelFactory implements ChatModelFactory {

@Override
public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
chatModel.setBaseUrl(SettingsState.getInstance().getGpt4allModelUrl());
return LocalAiChatModel.builder()
.baseUrl(getBaseUrlByType(ModelProvider.GPT4All))
.baseUrl(SettingsState.getInstance().getGpt4allModelUrl())
.modelName("test-model")
.maxRetries(chatModel.getMaxRetries())
.maxTokens(chatModel.getMaxTokens())
Expand All @@ -25,4 +25,15 @@ public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
.topP(chatModel.getTopP())
.build();
}

public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel chatModel) {
return LocalAiStreamingChatModel.builder()
.baseUrl(SettingsState.getInstance().getGpt4allModelUrl())
.modelName("test-model")
.temperature(chatModel.getTemperature())
.topP(chatModel.getTopP())
.timeout(Duration.ofSeconds(chatModel.getTimeout()))
.build();
}

}
Loading

0 comments on commit 8b2655e

Please sign in to comment.