From aaa3343f9179195e72f37db225212648fadd06e3 Mon Sep 17 00:00:00 2001
From: Stephan Janssen
Date: Mon, 16 Dec 2024 09:53:04 +0100
Subject: [PATCH] Issue #392 GPT4All regression fixed

---
 .../com/devoxx/genie/model/gpt4all/Model.java |  32 +++++
 .../genie/model/gpt4all/ModelPermission.java  |  46 +++++++
 .../genie/model/gpt4all/ResponseDTO.java      |  17 +++
 .../java/com/devoxx/genie/model/jan/Data.java |  12 ++
 .../devoxx/genie/model/jan/ResponseDTO.java   |   5 +-
 .../chatmodel/ChatModelFactoryProvider.java   |   2 +
 .../chatmodel/LocalChatModelFactory.java      | 113 +++++++++++++++++
 .../gpt4all/GPT4AllChatModelFactory.java      |  69 +++++------
 .../chatmodel/jan/JanChatModelFactory.java    |  98 ++++-----------
 .../chatmodel/lmstudio/LMStudioChatModel.java |   7 +-
 .../lmstudio/LMStudioChatModelFactory.java    | 107 ++++++----------
 .../ollama/OllamaChatModelFactory.java        | 117 ++++++------------
 .../genie/service/gpt4all/GPT4AllService.java |  55 ++++++++
 .../genie/ui/panel/LlmProviderPanel.java      |   7 +-
 src/main/resources/META-INF/plugin.xml        |   2 +
 src/main/resources/application.properties     |   2 +-
 16 files changed, 417 insertions(+), 274 deletions(-)
 create mode 100644 core/src/main/java/com/devoxx/genie/model/gpt4all/Model.java
 create mode 100644 core/src/main/java/com/devoxx/genie/model/gpt4all/ModelPermission.java
 create mode 100644 core/src/main/java/com/devoxx/genie/model/gpt4all/ResponseDTO.java
 create mode 100644 src/main/java/com/devoxx/genie/chatmodel/LocalChatModelFactory.java
 create mode 100644 src/main/java/com/devoxx/genie/service/gpt4all/GPT4AllService.java

diff --git a/core/src/main/java/com/devoxx/genie/model/gpt4all/Model.java b/core/src/main/java/com/devoxx/genie/model/gpt4all/Model.java
new file mode 100644
index 00000000..4d91fd28
--- /dev/null
+++ b/core/src/main/java/com/devoxx/genie/model/gpt4all/Model.java
@@ -0,0 +1,32 @@
+package com.devoxx.genie.model.gpt4all;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.Getter;
+import lombok.Setter;
+
+import java.util.List;
+
+@Getter
+@Setter
+public class Model {
+    @JsonProperty("created")
+    private long created;
+
+    @JsonProperty("id")
+    private String id;
+
+    @JsonProperty("object")
+    private String object;
+
+    @JsonProperty("owned_by")
+    private String ownedBy;
+
+    @JsonProperty("parent")
+    private String parent;
+
+    @JsonProperty("permissions")
+    private List<ModelPermission> permissions;
+
+    @JsonProperty("root")
+    private String root;
+}
\ No newline at end of file
@JsonProperty("is_blocking") + private boolean isBlocking; + + @JsonProperty("object") + private String object; + + @JsonProperty("organization") + private String organization; +} \ No newline at end of file diff --git a/core/src/main/java/com/devoxx/genie/model/gpt4all/ResponseDTO.java b/core/src/main/java/com/devoxx/genie/model/gpt4all/ResponseDTO.java new file mode 100644 index 00000000..eda9ca09 --- /dev/null +++ b/core/src/main/java/com/devoxx/genie/model/gpt4all/ResponseDTO.java @@ -0,0 +1,17 @@ +package com.devoxx.genie.model.gpt4all; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Getter; +import lombok.Setter; + +import java.util.List; + +@Getter +@Setter +public class ResponseDTO { + @JsonProperty("data") + private List data; + + @JsonProperty("object") + private String object; +} \ No newline at end of file diff --git a/core/src/main/java/com/devoxx/genie/model/jan/Data.java b/core/src/main/java/com/devoxx/genie/model/jan/Data.java index 9150fd0b..1ff9007c 100644 --- a/core/src/main/java/com/devoxx/genie/model/jan/Data.java +++ b/core/src/main/java/com/devoxx/genie/model/jan/Data.java @@ -14,6 +14,18 @@ public class Data { @JsonProperty("object") private String object; + @JsonProperty("ctx_len") + private Integer ctxLen; + + @JsonProperty("max_tokens") + private Integer maxTokens; + + @JsonProperty("top_k") + private Integer topK; + + @JsonProperty("top_p") + private Double topP; + @JsonProperty("name") private String name; diff --git a/core/src/main/java/com/devoxx/genie/model/jan/ResponseDTO.java b/core/src/main/java/com/devoxx/genie/model/jan/ResponseDTO.java index b5381672..26bbb25b 100644 --- a/core/src/main/java/com/devoxx/genie/model/jan/ResponseDTO.java +++ b/core/src/main/java/com/devoxx/genie/model/jan/ResponseDTO.java @@ -10,10 +10,9 @@ @Getter public class ResponseDTO { - @JsonProperty("object") - private String object; - @JsonProperty("data") private List data; + @JsonProperty("object") + private String object; } diff --git a/src/main/java/com/devoxx/genie/chatmodel/ChatModelFactoryProvider.java b/src/main/java/com/devoxx/genie/chatmodel/ChatModelFactoryProvider.java index b3b7b362..32adb3be 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/ChatModelFactoryProvider.java +++ b/src/main/java/com/devoxx/genie/chatmodel/ChatModelFactoryProvider.java @@ -6,6 +6,7 @@ import com.devoxx.genie.chatmodel.deepseek.DeepSeekChatModelFactory; import com.devoxx.genie.chatmodel.exo.ExoChatModelFactory; import com.devoxx.genie.chatmodel.google.GoogleChatModelFactory; +import com.devoxx.genie.chatmodel.gpt4all.GPT4AllChatModelFactory; import com.devoxx.genie.chatmodel.groq.GroqChatModelFactory; import com.devoxx.genie.chatmodel.jan.JanChatModelFactory; import com.devoxx.genie.chatmodel.jlama.JLamaChatModelFactory; @@ -55,6 +56,7 @@ private ChatModelFactoryProvider() { case "DeepSeek" -> new DeepSeekChatModelFactory(); case "Jlama" -> new JLamaChatModelFactory(); case "AzureOpenAI" -> new AzureOpenAIChatModelFactory(); + case "GPT4All" -> new GPT4AllChatModelFactory(); default -> null; }; } diff --git a/src/main/java/com/devoxx/genie/chatmodel/LocalChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/LocalChatModelFactory.java new file mode 100644 index 00000000..251b1d20 --- /dev/null +++ b/src/main/java/com/devoxx/genie/chatmodel/LocalChatModelFactory.java @@ -0,0 +1,113 @@ +package com.devoxx.genie.chatmodel; + +import com.devoxx.genie.model.ChatModel; +import com.devoxx.genie.model.LanguageModel; +import 
diff --git a/core/src/main/java/com/devoxx/genie/model/jan/Data.java b/core/src/main/java/com/devoxx/genie/model/jan/Data.java
index 9150fd0b..1ff9007c 100644
--- a/core/src/main/java/com/devoxx/genie/model/jan/Data.java
+++ b/core/src/main/java/com/devoxx/genie/model/jan/Data.java
@@ -14,6 +14,18 @@ public class Data {
     @JsonProperty("object")
     private String object;
 
+    @JsonProperty("ctx_len")
+    private Integer ctxLen;
+
+    @JsonProperty("max_tokens")
+    private Integer maxTokens;
+
+    @JsonProperty("top_k")
+    private Integer topK;
+
+    @JsonProperty("top_p")
+    private Double topP;
+
     @JsonProperty("name")
     private String name;
 
diff --git a/core/src/main/java/com/devoxx/genie/model/jan/ResponseDTO.java b/core/src/main/java/com/devoxx/genie/model/jan/ResponseDTO.java
index b5381672..26bbb25b 100644
--- a/core/src/main/java/com/devoxx/genie/model/jan/ResponseDTO.java
+++ b/core/src/main/java/com/devoxx/genie/model/jan/ResponseDTO.java
@@ -10,10 +10,9 @@
 @Getter
 public class ResponseDTO {
 
-    @JsonProperty("object")
-    private String object;
-
     @JsonProperty("data")
     private List<Data> data;
 
+    @JsonProperty("object")
+    private String object;
 }
diff --git a/src/main/java/com/devoxx/genie/chatmodel/ChatModelFactoryProvider.java b/src/main/java/com/devoxx/genie/chatmodel/ChatModelFactoryProvider.java
index b3b7b362..32adb3be 100644
--- a/src/main/java/com/devoxx/genie/chatmodel/ChatModelFactoryProvider.java
+++ b/src/main/java/com/devoxx/genie/chatmodel/ChatModelFactoryProvider.java
@@ -6,6 +6,7 @@
 import com.devoxx.genie.chatmodel.deepseek.DeepSeekChatModelFactory;
 import com.devoxx.genie.chatmodel.exo.ExoChatModelFactory;
 import com.devoxx.genie.chatmodel.google.GoogleChatModelFactory;
+import com.devoxx.genie.chatmodel.gpt4all.GPT4AllChatModelFactory;
 import com.devoxx.genie.chatmodel.groq.GroqChatModelFactory;
 import com.devoxx.genie.chatmodel.jan.JanChatModelFactory;
 import com.devoxx.genie.chatmodel.jlama.JLamaChatModelFactory;
@@ -55,6 +56,7 @@ private ChatModelFactoryProvider() {
             case "DeepSeek" -> new DeepSeekChatModelFactory();
             case "Jlama" -> new JLamaChatModelFactory();
             case "AzureOpenAI" -> new AzureOpenAIChatModelFactory();
+            case "GPT4All" -> new GPT4AllChatModelFactory();
             default -> null;
         };
     }
diff --git a/src/main/java/com/devoxx/genie/chatmodel/LocalChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/LocalChatModelFactory.java
new file mode 100644
index 00000000..251b1d20
--- /dev/null
+++ b/src/main/java/com/devoxx/genie/chatmodel/LocalChatModelFactory.java
@@ -0,0 +1,113 @@
+package com.devoxx.genie.chatmodel;
+
+import com.devoxx.genie.model.ChatModel;
+import com.devoxx.genie.model.LanguageModel;
+import com.devoxx.genie.model.enumarations.ModelProvider;
+import com.devoxx.genie.ui.util.NotificationUtil;
+import com.intellij.openapi.project.ProjectManager;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import dev.langchain4j.model.localai.LocalAiChatModel;
+import dev.langchain4j.model.localai.LocalAiStreamingChatModel;
+import org.jetbrains.annotations.NotNull;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+public abstract class LocalChatModelFactory implements ChatModelFactory {
+
+    protected final ModelProvider modelProvider;
+    protected List<LanguageModel> cachedModels = null;
+    protected static final ExecutorService executorService = Executors.newFixedThreadPool(5);
+    protected static boolean warningShown = false;
+
+    protected LocalChatModelFactory(ModelProvider modelProvider) {
+        this.modelProvider = modelProvider;
+    }
+
+    @Override
+    public abstract ChatLanguageModel createChatModel(@NotNull ChatModel chatModel);
+
+    @Override
+    public abstract StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel chatModel);
+
+    protected abstract String getModelUrl();
+
+    protected ChatLanguageModel createLocalAiChatModel(@NotNull ChatModel chatModel) {
+        return LocalAiChatModel.builder()
+                .baseUrl(getModelUrl())
+                .modelName(chatModel.getModelName())
+                .maxRetries(chatModel.getMaxRetries())
+                .temperature(chatModel.getTemperature())
+                .maxTokens(chatModel.getMaxTokens())
+                .timeout(Duration.ofSeconds(chatModel.getTimeout()))
+                .topP(chatModel.getTopP())
+                .build();
+    }
+
+    protected StreamingChatLanguageModel createLocalAiStreamingChatModel(@NotNull ChatModel chatModel) {
+        return LocalAiStreamingChatModel.builder()
+                .baseUrl(getModelUrl())
+                .modelName(chatModel.getModelName())
+                .temperature(chatModel.getTemperature())
+                .topP(chatModel.getTopP())
+                .timeout(Duration.ofSeconds(chatModel.getTimeout()))
+                .build();
+    }
+
+    @Override
+    public List<LanguageModel> getModels() {
+        if (cachedModels != null) {
+            return cachedModels;
+        }
+        List<LanguageModel> modelNames = new ArrayList<>();
+        List<CompletableFuture<Void>> futures = new ArrayList<>();
+        try {
+            Object[] models = fetchModels();
+            for (Object model : models) {
+                CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
+                    try {
+                        LanguageModel languageModel = buildLanguageModel(model);
+                        synchronized (modelNames) {
+                            modelNames.add(languageModel);
+                        }
+                    } catch (IOException e) {
+                        handleModelFetchError(model, e);
+                    }
+                }, executorService);
+                futures.add(future);
+            }
+            CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
+            cachedModels = modelNames;
+        } catch (IOException e) {
+            handleGeneralFetchError(e);
+            cachedModels = List.of();
+        }
+        return cachedModels;
+    }
+
+    protected abstract Object[] fetchModels() throws IOException;
+
+    protected abstract LanguageModel buildLanguageModel(Object model) throws IOException;
+
+    protected void handleModelFetchError(Object model, @NotNull IOException e) {
+        NotificationUtil.sendNotification(ProjectManager.getInstance().getDefaultProject(), "Error fetching model details: " + e.getMessage());
+    }
+
+    protected void handleGeneralFetchError(IOException e) {
+        if (!warningShown) {
+            NotificationUtil.sendNotification(ProjectManager.getInstance().getDefaultProject(), "Error fetching models: " + e.getMessage());
+            warningShown = true;
+        }
+    }
+
+    @Override
+    public void resetModels() {
+        cachedModels = null;
+    }
+}
\ No newline at end of file
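For reviewers, a sketch, not part of this patch, of what a new local provider now has to supply on top of LocalChatModelFactory; getModels() caching, the shared executor, and error notifications all come from the base class. The class name, URL, and model id below are invented.

import com.devoxx.genie.chatmodel.LocalChatModelFactory;
import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.model.LanguageModel;
import com.devoxx.genie.model.enumarations.ModelProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import org.jetbrains.annotations.NotNull;

import java.io.IOException;

// Hypothetical provider: everything specific below is invented for illustration.
public class ExampleLocalProviderFactory extends LocalChatModelFactory {

    public ExampleLocalProviderFactory() {
        super(ModelProvider.GPT4All); // a real subclass passes its own ModelProvider constant
    }

    @Override
    public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
        return createLocalAiChatModel(chatModel); // shared OpenAI-compatible builder from the base class
    }

    @Override
    public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel chatModel) {
        return createLocalAiStreamingChatModel(chatModel);
    }

    @Override
    protected String getModelUrl() {
        return "http://localhost:4891/v1/"; // normally read from DevoxxGenieStateService
    }

    @Override
    protected Object[] fetchModels() throws IOException {
        // A real factory queries its provider's /models endpoint here.
        return new Object[]{"example-model"};
    }

    @Override
    protected LanguageModel buildLanguageModel(Object model) {
        return LanguageModel.builder()
                .provider(modelProvider)
                .modelName((String) model)
                .displayName((String) model)
                .inputCost(0)
                .outputCost(0)
                .contextWindow(8_000)
                .apiKeyUsed(false)
                .build();
    }
}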
diff --git a/src/main/java/com/devoxx/genie/chatmodel/gpt4all/GPT4AllChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/gpt4all/GPT4AllChatModelFactory.java
index 8b6e2afa..3b19cf64 100644
--- a/src/main/java/com/devoxx/genie/chatmodel/gpt4all/GPT4AllChatModelFactory.java
+++ b/src/main/java/com/devoxx/genie/chatmodel/gpt4all/GPT4AllChatModelFactory.java
@@ -1,61 +1,56 @@
 package com.devoxx.genie.chatmodel.gpt4all;
 
-import com.devoxx.genie.chatmodel.ChatModelFactory;
+import com.devoxx.genie.chatmodel.LocalChatModelFactory;
 import com.devoxx.genie.model.ChatModel;
 import com.devoxx.genie.model.LanguageModel;
 import com.devoxx.genie.model.enumarations.ModelProvider;
+import com.devoxx.genie.model.gpt4all.Model;
+import com.devoxx.genie.service.gpt4all.GPT4AllService;
 import com.devoxx.genie.ui.settings.DevoxxGenieStateService;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
-import dev.langchain4j.model.localai.LocalAiChatModel;
-import dev.langchain4j.model.localai.LocalAiStreamingChatModel;
 import org.jetbrains.annotations.NotNull;
 
-import java.time.Duration;
-import java.util.ArrayList;
-import java.util.List;
+import java.io.IOException;
 
-public class GPT4AllChatModelFactory implements ChatModelFactory {
+public class GPT4AllChatModelFactory extends LocalChatModelFactory {
 
-    private final ModelProvider MODEL_PROVIDER = ModelProvider.GPT4All;;
+    public GPT4AllChatModelFactory() {
+        super(ModelProvider.GPT4All);
+    }
 
     @Override
     public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
-        return LocalAiChatModel.builder()
-                .baseUrl(DevoxxGenieStateService.getInstance().getGpt4allModelUrl())
-                .modelName(TEST_MODEL)
-                .maxRetries(chatModel.getMaxRetries())
-                .maxTokens(chatModel.getMaxTokens())
-                .temperature(chatModel.getTemperature())
-                .timeout(Duration.ofSeconds(chatModel.getTimeout()))
-                .topP(chatModel.getTopP())
-                .build();
+        return createLocalAiChatModel(chatModel);
     }
 
     @Override
     public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel chatModel) {
-        return LocalAiStreamingChatModel.builder()
-                .baseUrl(DevoxxGenieStateService.getInstance().getGpt4allModelUrl())
-                .modelName(TEST_MODEL)
-                .temperature(chatModel.getTemperature())
-                .topP(chatModel.getTopP())
-                .timeout(Duration.ofSeconds(chatModel.getTimeout()))
-                .build();
+        return createLocalAiStreamingChatModel(chatModel);
+    }
+
+    @Override
+    protected String getModelUrl() {
+        return DevoxxGenieStateService.getInstance().getGpt4allModelUrl();
+    }
+
+    @Override
+    protected Model[] fetchModels() throws IOException {
+        return GPT4AllService.getInstance().getModels().toArray(new Model[0]);
     }
 
     @Override
-    public List<LanguageModel> getModels() {
-        LanguageModel lmStudio = LanguageModel.builder()
-                .provider(MODEL_PROVIDER)
-                .modelName("GPT4All")
-                .inputCost(0)
-                .outputCost(0)
-                .contextWindow(8000)
-                .apiKeyUsed(false)
-                .build();
-
-        List<LanguageModel> modelNames = new ArrayList<>();
-        modelNames.add(lmStudio);
-        return modelNames;
+    protected LanguageModel buildLanguageModel(Object model) {
+        Model gpt4AllModel = (Model) model;
+        // int contextWindow = GPT4AllService.getInstance()
+        return LanguageModel.builder()
+                .provider(modelProvider)
+                .modelName(gpt4AllModel.getId())
+                .displayName(gpt4AllModel.getId())
+                .inputCost(0)
+                .outputCost(0)
+                // .contextWindow(contextWindow)
+                .apiKeyUsed(false)
+                .build();
     }
 }
diff --git a/src/main/java/com/devoxx/genie/chatmodel/jan/JanChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/jan/JanChatModelFactory.java
index ffd79efc..f65c55b4 100644
--- a/src/main/java/com/devoxx/genie/chatmodel/jan/JanChatModelFactory.java
+++ b/src/main/java/com/devoxx/genie/chatmodel/jan/JanChatModelFactory.java
@@ -1,105 +1,55 @@
 package com.devoxx.genie.chatmodel.jan;
 
-import com.devoxx.genie.chatmodel.ChatModelFactory;
+import com.devoxx.genie.chatmodel.LocalChatModelFactory;
 import com.devoxx.genie.model.ChatModel;
 import com.devoxx.genie.model.LanguageModel;
 import com.devoxx.genie.model.enumarations.ModelProvider;
 import com.devoxx.genie.model.jan.Data;
 import com.devoxx.genie.service.jan.JanService;
 import com.devoxx.genie.ui.settings.DevoxxGenieStateService;
-import com.devoxx.genie.ui.util.NotificationUtil;
-import com.intellij.openapi.project.ProjectManager;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
-import dev.langchain4j.model.localai.LocalAiChatModel;
-import dev.langchain4j.model.localai.LocalAiStreamingChatModel;
 import org.jetbrains.annotations.NotNull;
 
 import java.io.IOException;
-import java.time.Duration;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
 
-public class JanChatModelFactory implements ChatModelFactory {
+public class JanChatModelFactory extends LocalChatModelFactory {
 
-    private final ModelProvider MODEL_PROVIDER = ModelProvider.Jan;
-
-    private List<LanguageModel> cachedModels = null;
-    private static final ExecutorService executorService = Executors.newFixedThreadPool(5);
+    public JanChatModelFactory() {
+        super(ModelProvider.Jan);
+    }
 
     @Override
     public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
-        return LocalAiChatModel.builder()
-                .baseUrl(DevoxxGenieStateService.getInstance().getJanModelUrl())
-                .modelName(chatModel.getModelName())
-                .maxRetries(chatModel.getMaxRetries())
-                .temperature(chatModel.getTemperature())
-                .maxTokens(chatModel.getMaxTokens())
-                .timeout(Duration.ofSeconds(chatModel.getTimeout()))
-                .topP(chatModel.getTopP())
-                .build();
+        return createLocalAiChatModel(chatModel);
     }
 
     @Override
     public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel chatModel) {
-        return LocalAiStreamingChatModel.builder()
-                .baseUrl(DevoxxGenieStateService.getInstance().getJanModelUrl())
-                .modelName(chatModel.getModelName())
-                .temperature(chatModel.getTemperature())
-                .topP(chatModel.getTopP())
-                .timeout(Duration.ofSeconds(chatModel.getTimeout()))
-                .build();
+        return createLocalAiStreamingChatModel(chatModel);
    }
 
-    /**
-     * Get the model names from the Jan service.
-     *
-     * @return List of model names
-     */
     @Override
-    public List<LanguageModel> getModels() {
-        if (cachedModels != null) {
-            return cachedModels;
-        }
-
-        List<LanguageModel> modelNames = new ArrayList<>();
-        List<CompletableFuture<Void>> futures = new ArrayList<>();
-
-        try {
-            List<Data> models = JanService.getInstance().getModels();
-            for (Data model : models) {
-                CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
-                    LanguageModel languageModel = LanguageModel.builder()
-                            .provider(MODEL_PROVIDER)
-                            .modelName(model.getId())
-                            .displayName(model.getName())
-                            .inputCost(0)
-                            .outputCost(0)
-                            .contextWindow(model.getSettings().getCtxLen() == null ? 8_000 : model.getSettings().getCtxLen())
-                            .apiKeyUsed(false)
-                            .build();
-                    synchronized (modelNames) {
-                        modelNames.add(languageModel);
-                    }
-                }, executorService);
-                futures.add(future);
-            }
+    protected String getModelUrl() {
+        return DevoxxGenieStateService.getInstance().getJanModelUrl();
+    }
 
-            CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
-            cachedModels = modelNames;
-        } catch (IOException e) {
-            NotificationUtil.sendNotification(ProjectManager.getInstance().getDefaultProject(),
-                    "Unable to reach OpenRouter, please try again later.");
-            cachedModels = List.of();
-        }
-        return cachedModels;
+    @Override
+    protected Data[] fetchModels() throws IOException {
+        return JanService.getInstance().getModels().toArray(new Data[0]);
     }
 
     @Override
-    public void resetModels() {
-        cachedModels = null;
+    protected LanguageModel buildLanguageModel(Object model) {
+        Data janModel = (Data) model;
+        return LanguageModel.builder()
+                .provider(modelProvider)
+                .modelName(janModel.getId())
+                .displayName(janModel.getName())
+                .inputCost(0)
+                .outputCost(0)
+                .contextWindow(janModel.getCtxLen() == null ? 8_000 : janModel.getCtxLen())
+                .apiKeyUsed(false)
+                .build();
     }
 }
diff --git a/src/main/java/com/devoxx/genie/chatmodel/lmstudio/LMStudioChatModel.java b/src/main/java/com/devoxx/genie/chatmodel/lmstudio/LMStudioChatModel.java
index 6e4c4fa9..1da71d85 100644
--- a/src/main/java/com/devoxx/genie/chatmodel/lmstudio/LMStudioChatModel.java
+++ b/src/main/java/com/devoxx/genie/chatmodel/lmstudio/LMStudioChatModel.java
@@ -11,6 +11,7 @@
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
 import lombok.Builder;
+import org.jetbrains.annotations.NotNull;
 
 import java.time.Duration;
 import java.util.List;
@@ -81,9 +82,9 @@ public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
         return generate(messages, singletonList(toolSpecification), toolSpecification);
     }
 
-    private Response<AiMessage> generate(List<ChatMessage> messages,
-                                         List<ToolSpecification> toolSpecifications,
-                                         ToolSpecification toolThatMustBeExecuted
+    private @NotNull Response<AiMessage> generate(List<ChatMessage> messages,
+                                                  List<ToolSpecification> toolSpecifications,
+                                                  ToolSpecification toolThatMustBeExecuted
     ) {
         ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
                 .model(modelName)
diff --git a/src/main/java/com/devoxx/genie/chatmodel/lmstudio/LMStudioChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/lmstudio/LMStudioChatModelFactory.java
index 84ad9f4a..d711abc9 100644
--- a/src/main/java/com/devoxx/genie/chatmodel/lmstudio/LMStudioChatModelFactory.java
+++ b/src/main/java/com/devoxx/genie/chatmodel/lmstudio/LMStudioChatModelFactory.java
@@ -1,6 +1,6 @@
 package com.devoxx.genie.chatmodel.lmstudio;
 
-import com.devoxx.genie.chatmodel.ChatModelFactory;
+import com.devoxx.genie.chatmodel.LocalChatModelFactory;
 import com.devoxx.genie.model.ChatModel;
 import com.devoxx.genie.model.LanguageModel;
 import com.devoxx.genie.model.enumarations.ModelProvider;
@@ -12,100 +12,63 @@
 import com.intellij.openapi.project.ProjectManager;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
-import dev.langchain4j.model.localai.LocalAiStreamingChatModel;
 import org.jetbrains.annotations.NotNull;
 
 import java.io.IOException;
 import java.time.Duration;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
 
-public class LMStudioChatModelFactory implements ChatModelFactory {
+public class LMStudioChatModelFactory extends LocalChatModelFactory {
 
-    public static final ModelProvider MODEL_PROVIDER = ModelProvider.LMStudio;
-
-    private static final ExecutorService executorService = Executors.newFixedThreadPool(5);
-    private static boolean warningShown = false;
-    private List<LanguageModel> cachedModels = null;
     public static final int DEFAULT_CONTEXT_LENGTH = 8000;
 
+    public LMStudioChatModelFactory() {
+        super(ModelProvider.LMStudio);
+    }
+
     @Override
     public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
         return LMStudioChatModel.builder()
-                .baseUrl(DevoxxGenieStateService.getInstance().getLmstudioModelUrl())
-                .modelName(chatModel.getModelName())
-                .temperature(chatModel.getTemperature())
-                .topP(chatModel.getTopP())
-                .maxTokens(chatModel.getMaxTokens())
-                .maxRetries(chatModel.getMaxRetries())
-                .timeout(Duration.ofSeconds(chatModel.getTimeout()))
-                .build();
+                .baseUrl(getModelUrl())
+                .modelName(chatModel.getModelName())
+                .temperature(chatModel.getTemperature())
+                .topP(chatModel.getTopP())
+                .maxTokens(chatModel.getMaxTokens())
+                .maxRetries(chatModel.getMaxRetries())
+                .timeout(Duration.ofSeconds(chatModel.getTimeout()))
+                .build();
     }
 
     @Override
     public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel chatModel) {
-        return LocalAiStreamingChatModel.builder()
-                .baseUrl(DevoxxGenieStateService.getInstance().getLmstudioModelUrl())
-                .modelName(chatModel.getModelName())
-                .temperature(chatModel.getTemperature())
-                .topP(chatModel.getTopP())
-                .timeout(Duration.ofSeconds(chatModel.getTimeout()))
-                .build();
+        return createLocalAiStreamingChatModel(chatModel);
+    }
+
+    @Override
+    protected String getModelUrl() {
+        return DevoxxGenieStateService.getInstance().getLmstudioModelUrl();
     }
 
     @Override
-    public List<LanguageModel> getModels() {
+    protected LMStudioModelEntryDTO[] fetchModels() throws IOException {
         if (!LMStudioUtil.isLMStudioRunning()) {
             NotificationUtil.sendNotification(ProjectManager.getInstance().getDefaultProject(),
-                    "LMStudio is not running. Please start it and try again.");
-            return List.of();
-        }
-
-        if (cachedModels != null) {
-            return cachedModels;
-        }
-
-        List<LanguageModel> modelNames = new ArrayList<>();
-        List<CompletableFuture<Void>> futures = new ArrayList<>();
-
-        try {
-            LMStudioModelEntryDTO[] lmStudioModels = LMStudioService.getInstance().getModels();
-            for (LMStudioModelEntryDTO model : lmStudioModels) {
-                CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
-                    LanguageModel languageModel = LanguageModel.builder()
-                            .provider(MODEL_PROVIDER)
-                            .modelName(model.getId())
-                            .displayName(model.getId())
-                            .inputCost(0)
-                            .outputCost(0)
-                            .contextWindow(DEFAULT_CONTEXT_LENGTH)
-                            .apiKeyUsed(false)
-                            .build();
-                    synchronized (modelNames) {
-                        modelNames.add(languageModel);
-                    }
-                }, executorService);
-                futures.add(future);
-            }
-
-            CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
-            cachedModels = modelNames;
-        } catch (IOException e) {
-            if (!warningShown) {
-                NotificationUtil.sendNotification(ProjectManager.getInstance().getDefaultProject(),
-                        "LMStudio is not running, please start it.");
-                warningShown = true;
-            }
-            cachedModels = List.of();
+                    "LMStudio is not running. Please start it and try again.");
+            throw new IOException("LMStudio is not running");
         }
-        return cachedModels;
+        return LMStudioService.getInstance().getModels();
     }
 
     @Override
-    public void resetModels() {
-        cachedModels = null;
+    protected LanguageModel buildLanguageModel(Object model) {
+        LMStudioModelEntryDTO lmStudioModel = (LMStudioModelEntryDTO) model;
+        return LanguageModel.builder()
+                .provider(modelProvider)
+                .modelName(lmStudioModel.getId())
+                .displayName(lmStudioModel.getId())
+                .inputCost(0)
+                .outputCost(0)
+                .contextWindow(DEFAULT_CONTEXT_LENGTH)
+                .apiKeyUsed(false)
+                .build();
     }
 }
diff --git a/src/main/java/com/devoxx/genie/chatmodel/ollama/OllamaChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/ollama/OllamaChatModelFactory.java
index e8589f57..e4424cc4 100644
--- a/src/main/java/com/devoxx/genie/chatmodel/ollama/OllamaChatModelFactory.java
+++ b/src/main/java/com/devoxx/genie/chatmodel/ollama/OllamaChatModelFactory.java
@@ -1,6 +1,6 @@
 package com.devoxx.genie.chatmodel.ollama;
 
-import com.devoxx.genie.chatmodel.ChatModelFactory;
+import com.devoxx.genie.chatmodel.LocalChatModelFactory;
 import com.devoxx.genie.model.ChatModel;
 import com.devoxx.genie.model.LanguageModel;
 import com.devoxx.genie.model.enumarations.ModelProvider;
@@ -8,8 +8,6 @@
 import com.devoxx.genie.service.ollama.OllamaApiService;
 import com.devoxx.genie.service.ollama.OllamaService;
 import com.devoxx.genie.ui.settings.DevoxxGenieStateService;
-import com.devoxx.genie.ui.util.NotificationUtil;
-import com.intellij.openapi.project.ProjectManager;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
 import dev.langchain4j.model.ollama.OllamaChatModel;
@@ -18,101 +16,58 @@
 import java.io.IOException;
 import java.time.Duration;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
 
-public class OllamaChatModelFactory implements ChatModelFactory {
+public class OllamaChatModelFactory extends LocalChatModelFactory {
 
-    private final ModelProvider MODEL_PROVIDER = ModelProvider.Ollama;
-
-    private static final ExecutorService executorService = Executors.newFixedThreadPool(5);
-    private static boolean warningShown = false;
-    private List<LanguageModel> cachedModels = null;
+    public OllamaChatModelFactory() {
+        super(ModelProvider.Ollama);
+    }
 
     @Override
     public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
         return OllamaChatModel.builder()
-                .baseUrl(DevoxxGenieStateService.getInstance().getOllamaModelUrl())
-                .modelName(chatModel.getModelName())
-                .temperature(chatModel.getTemperature())
-                .topP(chatModel.getTopP())
-                .maxRetries(chatModel.getMaxRetries())
-                .timeout(Duration.ofSeconds(chatModel.getTimeout()))
-                .build();
+                .baseUrl(DevoxxGenieStateService.getInstance().getOllamaModelUrl())
+                .modelName(chatModel.getModelName())
+                .temperature(chatModel.getTemperature())
+                .topP(chatModel.getTopP())
+                .maxRetries(chatModel.getMaxRetries())
+                .timeout(Duration.ofSeconds(chatModel.getTimeout()))
+                .build();
     }
 
     @Override
     public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel chatModel) {
         return OllamaStreamingChatModel.builder()
-                .baseUrl(DevoxxGenieStateService.getInstance().getOllamaModelUrl())
-                .modelName(chatModel.getModelName())
-                .temperature(chatModel.getTemperature())
-                .topP(chatModel.getTopP())
-                .timeout(Duration.ofSeconds(chatModel.getTimeout()))
-                .build();
+                .baseUrl(DevoxxGenieStateService.getInstance().getOllamaModelUrl())
+                .modelName(chatModel.getModelName())
+                .temperature(chatModel.getTemperature())
+                .topP(chatModel.getTopP())
+                .timeout(Duration.ofSeconds(chatModel.getTimeout()))
+                .build();
     }
 
-    /**
-     * Get the model names from the Ollama service.
-     * We're currently adding a fixed number of tokens to the model size.
-     * TODO - Get the model size from the Ollama service or have the user define them in Options panel?
-     *
-     * @return List of model names
-     */
     @Override
-    public List<LanguageModel> getModels() {
-        if (cachedModels != null) {
-            return cachedModels;
-        }
-
-        List<LanguageModel> modelNames = new ArrayList<>();
-        List<CompletableFuture<Void>> futures = new ArrayList<>();
-
-        try {
-            OllamaModelEntryDTO[] ollamaModels = OllamaService.getInstance().getModels();
-            for (OllamaModelEntryDTO model : ollamaModels) {
-                CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
-                    try {
-                        int contextWindow = OllamaApiService.getModelContext(model.getName());
-                        LanguageModel languageModel = LanguageModel.builder()
-                                .provider(MODEL_PROVIDER)
-                                .modelName(model.getName())
-                                .displayName(model.getName())
-                                .inputCost(0)
-                                .outputCost(0)
-                                .contextWindow(contextWindow)
-                                .apiKeyUsed(false)
-                                .build();
-                        synchronized (modelNames) {
-                            modelNames.add(languageModel);
-                        }
-                    } catch (IOException e) {
-                        NotificationUtil.sendNotification(ProjectManager.getInstance().getDefaultProject(),
-                                "Error fetching context window for model: " + model.getName());
-                    }
-                }, executorService);
-                futures.add(future);
-            }
-
-            CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
-            cachedModels = modelNames;
-        } catch (IOException e) {
-            if (!warningShown) {
-                NotificationUtil.sendNotification(ProjectManager.getInstance().getDefaultProject(),
-                        "Ollama is not running, please start it.");
-                warningShown = true;
-            }
-            cachedModels = List.of();
-        }
-        return cachedModels;
+    protected String getModelUrl() {
+        return DevoxxGenieStateService.getInstance().getOllamaModelUrl();
     }
 
     @Override
-    public void resetModels() {
-        cachedModels = null;
+    protected OllamaModelEntryDTO[] fetchModels() throws IOException {
+        return OllamaService.getInstance().getModels();
     }
 
+    @Override
+    protected LanguageModel buildLanguageModel(Object model) throws IOException {
+        OllamaModelEntryDTO ollamaModel = (OllamaModelEntryDTO) model;
+        int contextWindow = OllamaApiService.getModelContext(ollamaModel.getName());
+        return LanguageModel.builder()
+                .provider(modelProvider)
+                .modelName(ollamaModel.getName())
+                .displayName(ollamaModel.getName())
+                .inputCost(0)
+                .outputCost(0)
+                .contextWindow(contextWindow)
+                .apiKeyUsed(false)
+                .build();
+    }
 }
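A usage sketch, not part of this patch, of the refactored factory from a caller's point of view. It assumes Lombok-generated getters on LanguageModel's builder fields and an IntelliJ application context for the underlying services.

import com.devoxx.genie.chatmodel.ollama.OllamaChatModelFactory;
import com.devoxx.genie.model.LanguageModel;

import java.util.List;

class ProviderRefreshSketch {
    static void refresh() {
        OllamaChatModelFactory factory = new OllamaChatModelFactory();
        // First call hits the local Ollama instance; each model's context window
        // is resolved in parallel on the shared executor, then cached.
        List<LanguageModel> models = factory.getModels();
        models.forEach(m -> System.out.println(m.getModelName())); // assumes a Lombok getter
        // Clears the cache, e.g. when the user presses refresh in LlmProviderPanel.
        factory.resetModels();
    }
}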
diff --git a/src/main/java/com/devoxx/genie/service/gpt4all/GPT4AllService.java b/src/main/java/com/devoxx/genie/service/gpt4all/GPT4AllService.java
new file mode 100644
index 00000000..bb73f4ba
--- /dev/null
+++ b/src/main/java/com/devoxx/genie/service/gpt4all/GPT4AllService.java
@@ -0,0 +1,55 @@
+package com.devoxx.genie.service.gpt4all;
+
+import com.devoxx.genie.model.gpt4all.Model;
+import com.devoxx.genie.model.gpt4all.ResponseDTO;
+import com.devoxx.genie.ui.settings.DevoxxGenieStateService;
+import com.google.gson.Gson;
+import com.intellij.openapi.application.ApplicationManager;
+import okhttp3.OkHttpClient;
+import okhttp3.Request;
+import okhttp3.Response;
+import org.jetbrains.annotations.NotNull;
+
+import java.io.IOException;
+import java.util.List;
+
+import static com.devoxx.genie.util.HttpUtil.ensureEndsWithSlash;
+
+public class GPT4AllService {
+
+    private final OkHttpClient client = new OkHttpClient();
+
+    @NotNull
+    public static GPT4AllService getInstance() {
+        return ApplicationManager.getApplication().getService(GPT4AllService.class);
+    }
+
+    /**
+     * Get the models from the GPT4All service.
+     *
+     * @return list of models
+     * @throws IOException if there is an error
+     */
+    public List<Model> getModels() throws IOException {
+        String baseUrl = ensureEndsWithSlash(DevoxxGenieStateService.getInstance().getGpt4allModelUrl());
+
+        Request request = new Request.Builder()
+                .url(baseUrl + "models")
+                .build();
+
+        try (Response response = client.newCall(request).execute()) {
+            if (!response.isSuccessful()) {
+                throw new UnsuccessfulRequestException("Unexpected code " + response);
+            }
+
+            assert response.body() != null;
+
+            ResponseDTO modelResponse = new Gson().fromJson(response.body().string(), ResponseDTO.class);
+            return modelResponse != null && modelResponse.getData() != null ? modelResponse.getData() : List.of();
+        }
+    }
+
+    public static class UnsuccessfulRequestException extends IOException {
+        public UnsuccessfulRequestException(String message) {
+            super(message);
+        }
+    }
+}
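A small usage sketch, not part of this patch, showing the error contract the GPT4All factory relies on: an unreachable endpoint or a non-2xx response surfaces as an IOException (UnsuccessfulRequestException extends IOException), which LocalChatModelFactory.getModels() turns into a notification and an empty model list. Requires the IntelliJ application context.

import com.devoxx.genie.model.gpt4all.Model;
import com.devoxx.genie.service.gpt4all.GPT4AllService;

import java.io.IOException;
import java.util.List;

class Gpt4AllModelsSketch {
    static void listModels() {
        try {
            // GET <gpt4allModelUrl>/models, parsed into ResponseDTO by Gson.
            List<Model> models = GPT4AllService.getInstance().getModels();
            models.forEach(m -> System.out.println(m.getId()));
        } catch (IOException e) {
            // Covers both transport failures and non-2xx responses.
            System.err.println("GPT4All not reachable: " + e.getMessage());
        }
    }
}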
diff --git a/src/main/java/com/devoxx/genie/ui/panel/LlmProviderPanel.java b/src/main/java/com/devoxx/genie/ui/panel/LlmProviderPanel.java
index a2089d9b..293bf5d2 100644
--- a/src/main/java/com/devoxx/genie/ui/panel/LlmProviderPanel.java
+++ b/src/main/java/com/devoxx/genie/ui/panel/LlmProviderPanel.java
@@ -149,7 +149,10 @@ private void refreshModels() {
             return;
         }
 
-        if (selectedProvider == ModelProvider.LMStudio || selectedProvider == ModelProvider.Ollama || selectedProvider == ModelProvider.Jan) {
+        if (selectedProvider == ModelProvider.LMStudio ||
+                selectedProvider == ModelProvider.Ollama ||
+                selectedProvider == ModelProvider.Jan ||
+                selectedProvider == ModelProvider.GPT4All) {
 
             ApplicationManager.getApplication().invokeLater(() -> {
                 refreshButton.setEnabled(false);
@@ -167,8 +170,6 @@ private void refreshModels() {
         } else {
             NotificationUtil.sendNotification(project, "Model refresh is only available for LMStudio, Ollama and Jan providers.");
         }
-
-
     }
 
     /**
diff --git a/src/main/resources/META-INF/plugin.xml b/src/main/resources/META-INF/plugin.xml
index 15ba38af..b22653a4 100644
--- a/src/main/resources/META-INF/plugin.xml
+++ b/src/main/resources/META-INF/plugin.xml
@@ -42,6 +42,7 @@
   • Fix #384: Fix web search with enable/disable of Google or Tavily feature
   • Fix #387: Fix tokens encoding for special chars
   • Fix #389: Remove action buttons tooltips
+  • Fix #392: Fixed GPT4All support
   • v0.4.4
@@ -482,6 +483,7 @@
+
diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties
index 06e50884..23188280 100644
--- a/src/main/resources/application.properties
+++ b/src/main/resources/application.properties
@@ -1,2 +1,2 @@
-#Fri Dec 13 18:54:40 CET 2024
+#Mon Dec 16 09:47:38 CET 2024
 version=0.4.5