From 6b8ee9edad1a42e03e7ac88d693e483b27358149 Mon Sep 17 00:00:00 2001 From: Stephan Janssen Date: Mon, 9 Sep 2024 15:53:30 +0200 Subject: [PATCH] Fix #244 Load Jan models and list them for usage --- build.gradle.kts | 2 +- .../java/com/devoxx/genie/model/jan/Data.java | 2 + .../com/devoxx/genie/model/jan/Settings.java | 5 +- .../chatmodel/ChatModelFactoryProvider.java | 2 + .../chatmodel/jan/JanChatModelFactory.java | 41 ++++++++++++----- .../genie/service/LLMProviderService.java | 2 +- .../genie/service/PromptExecutionService.java | 5 ++ .../exception/ModelNotActiveException.java | 8 ++++ .../devoxx/genie/service/jan/JanService.java | 6 ++- src/main/resources/META-INF/plugin.xml | 5 ++ src/main/resources/application.properties | 4 +- .../jan/JanChatModelFactoryTest.java | 19 ++++++++ .../devoxx/genie/model/GeminiClientTest.java | 39 ---------------- .../genie/service/jan/JanServiceTest.java | 46 +++++++++++++++++++ 14 files changed, 129 insertions(+), 57 deletions(-) create mode 100644 src/main/java/com/devoxx/genie/service/exception/ModelNotActiveException.java delete mode 100644 src/test/java/com/devoxx/genie/model/GeminiClientTest.java create mode 100644 src/test/java/com/devoxx/genie/service/jan/JanServiceTest.java diff --git a/build.gradle.kts b/build.gradle.kts index eef88ecc..f02b2d8e 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -7,7 +7,7 @@ plugins { } group = "com.devoxx.genie" -version = "0.2.18" +version = "0.2.19" repositories { mavenCentral() diff --git a/core/src/main/java/com/devoxx/genie/model/jan/Data.java b/core/src/main/java/com/devoxx/genie/model/jan/Data.java index 80240660..9150fd0b 100644 --- a/core/src/main/java/com/devoxx/genie/model/jan/Data.java +++ b/core/src/main/java/com/devoxx/genie/model/jan/Data.java @@ -1,6 +1,7 @@ package com.devoxx.genie.model.jan; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.gson.annotations.SerializedName; import lombok.Getter; import lombok.Setter; @@ -25,6 +26,7 @@ public class Data { @JsonProperty("format") private String format; + @SerializedName("settings") @JsonProperty("settings") private Settings settings; diff --git a/core/src/main/java/com/devoxx/genie/model/jan/Settings.java b/core/src/main/java/com/devoxx/genie/model/jan/Settings.java index 2e014c1b..a6ed5501 100644 --- a/core/src/main/java/com/devoxx/genie/model/jan/Settings.java +++ b/core/src/main/java/com/devoxx/genie/model/jan/Settings.java @@ -1,14 +1,17 @@ package com.devoxx.genie.model.jan; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.gson.annotations.SerializedName; import lombok.Getter; import lombok.Setter; @Setter @Getter public class Settings { + + @SerializedName("ctx_len") @JsonProperty("ctx_len") - private int ctxLen; + private Integer ctxLen; @JsonProperty("prompt_template") private String promptTemplate; diff --git a/src/main/java/com/devoxx/genie/chatmodel/ChatModelFactoryProvider.java b/src/main/java/com/devoxx/genie/chatmodel/ChatModelFactoryProvider.java index 11bb184f..41837a01 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/ChatModelFactoryProvider.java +++ b/src/main/java/com/devoxx/genie/chatmodel/ChatModelFactoryProvider.java @@ -6,6 +6,7 @@ import com.devoxx.genie.chatmodel.exo.ExoChatModelFactory; import com.devoxx.genie.chatmodel.google.GoogleChatModelFactory; import com.devoxx.genie.chatmodel.groq.GroqChatModelFactory; +import com.devoxx.genie.chatmodel.jan.JanChatModelFactory; import com.devoxx.genie.chatmodel.lmstudio.LMStudioChatModelFactory; import com.devoxx.genie.chatmodel.mistral.MistralChatModelFactory; import com.devoxx.genie.chatmodel.ollama.OllamaChatModelFactory; @@ -35,6 +36,7 @@ public class ChatModelFactoryProvider { private static @Nullable ChatModelFactory createFactory(@NotNull String modelProvider) { return switch (modelProvider) { case "Ollama" -> new OllamaChatModelFactory(); + case "Jan" -> new JanChatModelFactory(); case "OpenRouter" -> new OpenRouterChatModelFactory(); case "LMStudio" -> new LMStudioChatModelFactory(); case "Exo" -> new ExoChatModelFactory(); diff --git a/src/main/java/com/devoxx/genie/chatmodel/jan/JanChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/jan/JanChatModelFactory.java index 2c0527a8..63985407 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/jan/JanChatModelFactory.java +++ b/src/main/java/com/devoxx/genie/chatmodel/jan/JanChatModelFactory.java @@ -19,8 +19,13 @@ import java.time.Duration; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; public class JanChatModelFactory implements ChatModelFactory { + private List cachedModels = null; + private static final ExecutorService executorService = Executors.newFixedThreadPool(5); @Override public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) { @@ -54,28 +59,40 @@ public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel ch */ @Override public List getModels() { + if (cachedModels != null) { + return cachedModels; + } + List modelNames = new ArrayList<>(); + List> futures = new ArrayList<>(); + try { - List models = new JanService().getModels(); + List models = JanService.getInstance().getModels(); for (Data model : models) { - int ctxLen = model.getSettings().getCtxLen(); - modelNames.add( - LanguageModel.builder() + CompletableFuture future = CompletableFuture.runAsync(() -> { + LanguageModel languageModel = LanguageModel.builder() .provider(ModelProvider.Jan) - .modelName(model.getName()) + .modelName(model.getId()) .displayName(model.getName()) - .contextWindow(ctxLen) - .apiKeyUsed(false) .inputCost(0) .outputCost(0) - .build() - ); + .contextWindow(model.getSettings().getCtxLen() == null ? 8_000 : model.getSettings().getCtxLen()) + .apiKeyUsed(false) + .build(); + synchronized (modelNames) { + modelNames.add(languageModel); + } + }, executorService); + futures.add(future); } + + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); + cachedModels = modelNames; } catch (IOException e) { NotificationUtil.sendNotification(ProjectManager.getInstance().getDefaultProject(), - "Jan is not running, please start it."); - return List.of(); + "Unable to reach OpenRouter, please try again later."); + cachedModels = List.of(); } - return modelNames; + return cachedModels; } } diff --git a/src/main/java/com/devoxx/genie/service/LLMProviderService.java b/src/main/java/com/devoxx/genie/service/LLMProviderService.java index 25d2e4ad..4d822b73 100644 --- a/src/main/java/com/devoxx/genie/service/LLMProviderService.java +++ b/src/main/java/com/devoxx/genie/service/LLMProviderService.java @@ -34,7 +34,7 @@ public static LLMProviderService getInstance() { } public List getLocalModelProviders() { - return List.of(GPT4All, LMStudio, Ollama, Exo, LLaMA); + return List.of(GPT4All, LMStudio, Ollama, Exo, LLaMA, Jan); } /** diff --git a/src/main/java/com/devoxx/genie/service/PromptExecutionService.java b/src/main/java/com/devoxx/genie/service/PromptExecutionService.java index 077e4e45..4339dfe3 100644 --- a/src/main/java/com/devoxx/genie/service/PromptExecutionService.java +++ b/src/main/java/com/devoxx/genie/service/PromptExecutionService.java @@ -2,7 +2,9 @@ import com.devoxx.genie.error.ErrorHandler; import com.devoxx.genie.model.Constant; +import com.devoxx.genie.model.enumarations.ModelProvider; import com.devoxx.genie.model.request.ChatMessageContext; +import com.devoxx.genie.service.exception.ModelNotActiveException; import com.devoxx.genie.service.exception.ProviderUnavailableException; import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.diagnostic.Logger; @@ -117,6 +119,9 @@ private boolean isCanceled() { ChatMemoryService.getInstance().add(chatMessageContext.getProject(), response.content()); return response; } catch (Exception e) { + if (chatMessageContext.getLanguageModel().getProvider().equals(ModelProvider.Jan)) { + throw new ModelNotActiveException("Selected Jan model is not active. Download and make it active or add API Key in Jan settings."); + } ChatMemoryService.getInstance().removeLast(chatMessageContext.getProject()); throw new ProviderUnavailableException(e.getMessage()); } diff --git a/src/main/java/com/devoxx/genie/service/exception/ModelNotActiveException.java b/src/main/java/com/devoxx/genie/service/exception/ModelNotActiveException.java new file mode 100644 index 00000000..8f065bbd --- /dev/null +++ b/src/main/java/com/devoxx/genie/service/exception/ModelNotActiveException.java @@ -0,0 +1,8 @@ +package com.devoxx.genie.service.exception; + +public class ModelNotActiveException extends RuntimeException { + + public ModelNotActiveException(String message) { + super(message); + } +} diff --git a/src/main/java/com/devoxx/genie/service/jan/JanService.java b/src/main/java/com/devoxx/genie/service/jan/JanService.java index d330cd1e..36dbb179 100644 --- a/src/main/java/com/devoxx/genie/service/jan/JanService.java +++ b/src/main/java/com/devoxx/genie/service/jan/JanService.java @@ -4,9 +4,11 @@ import com.devoxx.genie.model.jan.ResponseDTO; import com.devoxx.genie.service.DevoxxGenieSettingsServiceProvider; import com.google.gson.Gson; +import com.intellij.openapi.application.ApplicationManager; import okhttp3.OkHttpClient; import okhttp3.Request; import okhttp3.Response; +import org.jetbrains.annotations.NotNull; import java.io.IOException; import java.util.List; @@ -16,7 +18,9 @@ public class JanService { private final OkHttpClient client = new OkHttpClient(); - public JanService() { + @NotNull + public static JanService getInstance() { + return ApplicationManager.getApplication().getService(JanService.class); } public List getModels() throws IOException { diff --git a/src/main/resources/META-INF/plugin.xml b/src/main/resources/META-INF/plugin.xml index 4ee605ba..8238f14b 100644 --- a/src/main/resources/META-INF/plugin.xml +++ b/src/main/resources/META-INF/plugin.xml @@ -35,6 +35,10 @@ ]]> v0.2.19 +
    +
  • Feat #244 : Fix for Jan 👋🏼
  • +

v0.2.18

  • Feat #225 : Support for OpenRouter
  • @@ -378,6 +382,7 @@ + diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index ff597bde..5a404aeb 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -1,2 +1,2 @@ -#Mon Sep 09 09:17:52 CEST 2024 -version=0.2.18 +#Mon Sep 09 15:46:08 CEST 2024 +version=0.2.19 diff --git a/src/test/java/com/devoxx/genie/chatmodel/jan/JanChatModelFactoryTest.java b/src/test/java/com/devoxx/genie/chatmodel/jan/JanChatModelFactoryTest.java index e6ddeae8..d35cb907 100644 --- a/src/test/java/com/devoxx/genie/chatmodel/jan/JanChatModelFactoryTest.java +++ b/src/test/java/com/devoxx/genie/chatmodel/jan/JanChatModelFactoryTest.java @@ -35,4 +35,23 @@ void testCreateChatModel() { assertThat(result).isNotNull(); } } + + @Test + void testHelloChat() { + try (MockedStatic mockedSettings = Mockito.mockStatic(DevoxxGenieSettingsServiceProvider.class)) { + // Setup the mock for SettingsState + DevoxxGenieStateService mockSettingsState = mock(DevoxxGenieStateService.class); + when(DevoxxGenieSettingsServiceProvider.getInstance()).thenReturn(mockSettingsState); + when(mockSettingsState.getJanModelUrl()).thenReturn("http://localhost:1337/v1/"); + + // Instance of the class containing the method to be tested + JanChatModelFactory factory = new JanChatModelFactory(); + + ChatModel chatModel = new ChatModel(); + chatModel.setModelName("mistral-ins-7b-q4"); + ChatLanguageModel chatLanguageModel = factory.createChatModel(chatModel); + String hello = chatLanguageModel.generate("Hello"); + assertThat(hello).isNotNull(); + } + } } diff --git a/src/test/java/com/devoxx/genie/model/GeminiClientTest.java b/src/test/java/com/devoxx/genie/model/GeminiClientTest.java deleted file mode 100644 index 5d976193..00000000 --- a/src/test/java/com/devoxx/genie/model/GeminiClientTest.java +++ /dev/null @@ -1,39 +0,0 @@ -package com.devoxx.genie.model; - -import com.devoxx.genie.chatmodel.AbstractLightPlatformTestCase; -import com.devoxx.genie.ui.settings.DevoxxGenieStateService; -import com.intellij.openapi.application.ApplicationManager; -import com.intellij.testFramework.ServiceContainerUtil; -import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.googleai.GoogleAiGeminiChatModel; -import org.junit.Ignore; -import org.junit.jupiter.api.BeforeEach; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -public class GeminiClientTest extends AbstractLightPlatformTestCase { - - @BeforeEach - public void setUp() throws Exception { - super.setUp(); - // Mock SettingsState - DevoxxGenieStateService settingsStateMock = mock(DevoxxGenieStateService.class); - when(settingsStateMock.getGeminiKey()).thenReturn("dummy-key"); - - // Replace the service instance with the mock - ServiceContainerUtil.replaceService(ApplicationManager.getApplication(), DevoxxGenieStateService.class, settingsStateMock, getTestRootDisposable()); - } - - @Ignore - public void testGeminiRequest() { - ChatLanguageModel gemini = GoogleAiGeminiChatModel.builder() - .apiKey(System.getenv("GEMINI_AI_KEY")) - .modelName("gemini-1.5-flash") - .build(); - - String response = gemini.generate("Konnichiwa Gemini!"); - assertThat(response).isNotNull(); - } -} diff --git a/src/test/java/com/devoxx/genie/service/jan/JanServiceTest.java b/src/test/java/com/devoxx/genie/service/jan/JanServiceTest.java new file mode 100644 index 00000000..887e4504 --- /dev/null +++ b/src/test/java/com/devoxx/genie/service/jan/JanServiceTest.java @@ -0,0 +1,46 @@ +package com.devoxx.genie.service.jan; + +import com.devoxx.genie.chatmodel.AbstractLightPlatformTestCase; +import com.devoxx.genie.model.jan.Data; +import com.devoxx.genie.ui.settings.DevoxxGenieStateService; +import com.intellij.openapi.application.ApplicationManager; +import com.intellij.testFramework.ServiceContainerUtil; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.List; + +public class JanServiceTest extends AbstractLightPlatformTestCase { + + @BeforeEach + public void setUp() throws Exception { + super.setUp(); + // Mock SettingsState + DevoxxGenieStateService settingsStateMock = mock(DevoxxGenieStateService.class); + when(settingsStateMock.getJanModelUrl()).thenReturn("http://localhost:1337/v1/"); + + // Replace the service instance with the mock + ServiceContainerUtil.replaceService(ApplicationManager.getApplication(), DevoxxGenieStateService.class, settingsStateMock, getTestRootDisposable()); + } + + @Test + public void testGetModels() throws IOException { + JanService janService = new JanService(); + List models = janService.getModels(); + assertThat(models).isNotEmpty(); + + models.forEach(model -> { + assertThat(model).isNotNull(); + assertThat(model.getId()).isNotNull(); + assertThat(model.getName()).isNotNull(); + assertThat(model.getDescription()).isNotNull(); + assertThat(model.getSettings()).isNotNull(); + assertThat(model.getSettings().getCtxLen()).isNotNull(); + }); + } +} +