Skip to content

Commit

Permalink
Merge pull request #285 from devoxx/issue-244
Browse files Browse the repository at this point in the history
Fix #244 Load Jan models and list them for usage
  • Loading branch information
stephanj authored Sep 9, 2024
2 parents 51fce65 + 6b8ee9e commit d50934a
Show file tree
Hide file tree
Showing 14 changed files with 129 additions and 57 deletions.
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ plugins {
}

group = "com.devoxx.genie"
version = "0.2.18"
version = "0.2.19"

repositories {
mavenCentral()
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/java/com/devoxx/genie/model/jan/Data.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.devoxx.genie.model.jan;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.gson.annotations.SerializedName;
import lombok.Getter;
import lombok.Setter;

Expand All @@ -25,6 +26,7 @@ public class Data {
@JsonProperty("format")
private String format;

@SerializedName("settings")
@JsonProperty("settings")
private Settings settings;

Expand Down
5 changes: 4 additions & 1 deletion core/src/main/java/com/devoxx/genie/model/jan/Settings.java
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
package com.devoxx.genie.model.jan;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.gson.annotations.SerializedName;
import lombok.Getter;
import lombok.Setter;

@Setter
@Getter
public class Settings {

@SerializedName("ctx_len")
@JsonProperty("ctx_len")
private int ctxLen;
private Integer ctxLen;

@JsonProperty("prompt_template")
private String promptTemplate;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.devoxx.genie.chatmodel.exo.ExoChatModelFactory;
import com.devoxx.genie.chatmodel.google.GoogleChatModelFactory;
import com.devoxx.genie.chatmodel.groq.GroqChatModelFactory;
import com.devoxx.genie.chatmodel.jan.JanChatModelFactory;
import com.devoxx.genie.chatmodel.lmstudio.LMStudioChatModelFactory;
import com.devoxx.genie.chatmodel.mistral.MistralChatModelFactory;
import com.devoxx.genie.chatmodel.ollama.OllamaChatModelFactory;
Expand Down Expand Up @@ -35,6 +36,7 @@ public class ChatModelFactoryProvider {
private static @Nullable ChatModelFactory createFactory(@NotNull String modelProvider) {
return switch (modelProvider) {
case "Ollama" -> new OllamaChatModelFactory();
case "Jan" -> new JanChatModelFactory();
case "OpenRouter" -> new OpenRouterChatModelFactory();
case "LMStudio" -> new LMStudioChatModelFactory();
case "Exo" -> new ExoChatModelFactory();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class JanChatModelFactory implements ChatModelFactory {
private List<LanguageModel> cachedModels = null;
private static final ExecutorService executorService = Executors.newFixedThreadPool(5);

@Override
public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
Expand Down Expand Up @@ -54,28 +59,40 @@ public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel ch
*/
@Override
public List<LanguageModel> getModels() {
if (cachedModels != null) {
return cachedModels;
}

List<LanguageModel> modelNames = new ArrayList<>();
List<CompletableFuture<Void>> futures = new ArrayList<>();

try {
List<Data> models = new JanService().getModels();
List<Data> models = JanService.getInstance().getModels();
for (Data model : models) {
int ctxLen = model.getSettings().getCtxLen();
modelNames.add(
LanguageModel.builder()
CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
LanguageModel languageModel = LanguageModel.builder()
.provider(ModelProvider.Jan)
.modelName(model.getName())
.modelName(model.getId())
.displayName(model.getName())
.contextWindow(ctxLen)
.apiKeyUsed(false)
.inputCost(0)
.outputCost(0)
.build()
);
.contextWindow(model.getSettings().getCtxLen() == null ? 8_000 : model.getSettings().getCtxLen())
.apiKeyUsed(false)
.build();
synchronized (modelNames) {
modelNames.add(languageModel);
}
}, executorService);
futures.add(future);
}

CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
cachedModels = modelNames;
} catch (IOException e) {
NotificationUtil.sendNotification(ProjectManager.getInstance().getDefaultProject(),
"Jan is not running, please start it.");
return List.of();
"Unable to reach OpenRouter, please try again later.");
cachedModels = List.of();
}
return modelNames;
return cachedModels;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public static LLMProviderService getInstance() {
}

public List<ModelProvider> getLocalModelProviders() {
return List.of(GPT4All, LMStudio, Ollama, Exo, LLaMA);
return List.of(GPT4All, LMStudio, Ollama, Exo, LLaMA, Jan);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import com.devoxx.genie.error.ErrorHandler;
import com.devoxx.genie.model.Constant;
import com.devoxx.genie.model.enumarations.ModelProvider;
import com.devoxx.genie.model.request.ChatMessageContext;
import com.devoxx.genie.service.exception.ModelNotActiveException;
import com.devoxx.genie.service.exception.ProviderUnavailableException;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.diagnostic.Logger;
Expand Down Expand Up @@ -117,6 +119,9 @@ private boolean isCanceled() {
ChatMemoryService.getInstance().add(chatMessageContext.getProject(), response.content());
return response;
} catch (Exception e) {
if (chatMessageContext.getLanguageModel().getProvider().equals(ModelProvider.Jan)) {
throw new ModelNotActiveException("Selected Jan model is not active. Download and make it active or add API Key in Jan settings.");
}
ChatMemoryService.getInstance().removeLast(chatMessageContext.getProject());
throw new ProviderUnavailableException(e.getMessage());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.devoxx.genie.service.exception;

public class ModelNotActiveException extends RuntimeException {

public ModelNotActiveException(String message) {
super(message);
}
}
6 changes: 5 additions & 1 deletion src/main/java/com/devoxx/genie/service/jan/JanService.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import com.devoxx.genie.model.jan.ResponseDTO;
import com.devoxx.genie.service.DevoxxGenieSettingsServiceProvider;
import com.google.gson.Gson;
import com.intellij.openapi.application.ApplicationManager;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import org.jetbrains.annotations.NotNull;

import java.io.IOException;
import java.util.List;
Expand All @@ -16,7 +18,9 @@
public class JanService {
private final OkHttpClient client = new OkHttpClient();

public JanService() {
@NotNull
public static JanService getInstance() {
return ApplicationManager.getApplication().getService(JanService.class);
}

public List<Data> getModels() throws IOException {
Expand Down
5 changes: 5 additions & 0 deletions src/main/resources/META-INF/plugin.xml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
]]></description>

<change-notes><![CDATA[
<h2>v0.2.19</h2>
<UL>
<LI>Feat #244 : Fix for Jan 👋🏼</LI>
</UL>
<h2>v0.2.18</h2>
<UL>
<LI>Feat #225 : Support for OpenRouter</LI>
Expand Down Expand Up @@ -378,6 +382,7 @@
<applicationService serviceImplementation="com.devoxx.genie.service.TokenCalculationService"/>
<applicationService serviceImplementation="com.devoxx.genie.service.lmstudio.LMStudioService"/>
<applicationService serviceImplementation="com.devoxx.genie.service.openrouter.OpenRouterService"/>
<applicationService serviceImplementation="com.devoxx.genie.service.jan.JanService"/>
</extensions>

<extensions defaultExtensionNs="com.intellij">
Expand Down
4 changes: 2 additions & 2 deletions src/main/resources/application.properties
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
#Mon Sep 09 09:17:52 CEST 2024
version=0.2.18
#Mon Sep 09 15:46:08 CEST 2024
version=0.2.19
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,23 @@ void testCreateChatModel() {
assertThat(result).isNotNull();
}
}

@Test
void testHelloChat() {
try (MockedStatic<DevoxxGenieSettingsServiceProvider> mockedSettings = Mockito.mockStatic(DevoxxGenieSettingsServiceProvider.class)) {
// Setup the mock for SettingsState
DevoxxGenieStateService mockSettingsState = mock(DevoxxGenieStateService.class);
when(DevoxxGenieSettingsServiceProvider.getInstance()).thenReturn(mockSettingsState);
when(mockSettingsState.getJanModelUrl()).thenReturn("http://localhost:1337/v1/");

// Instance of the class containing the method to be tested
JanChatModelFactory factory = new JanChatModelFactory();

ChatModel chatModel = new ChatModel();
chatModel.setModelName("mistral-ins-7b-q4");
ChatLanguageModel chatLanguageModel = factory.createChatModel(chatModel);
String hello = chatLanguageModel.generate("Hello");
assertThat(hello).isNotNull();
}
}
}
39 changes: 0 additions & 39 deletions src/test/java/com/devoxx/genie/model/GeminiClientTest.java

This file was deleted.

46 changes: 46 additions & 0 deletions src/test/java/com/devoxx/genie/service/jan/JanServiceTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.devoxx.genie.service.jan;

import com.devoxx.genie.chatmodel.AbstractLightPlatformTestCase;
import com.devoxx.genie.model.jan.Data;
import com.devoxx.genie.ui.settings.DevoxxGenieStateService;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.testFramework.ServiceContainerUtil;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.util.List;

public class JanServiceTest extends AbstractLightPlatformTestCase {

@BeforeEach
public void setUp() throws Exception {
super.setUp();
// Mock SettingsState
DevoxxGenieStateService settingsStateMock = mock(DevoxxGenieStateService.class);
when(settingsStateMock.getJanModelUrl()).thenReturn("http://localhost:1337/v1/");

// Replace the service instance with the mock
ServiceContainerUtil.replaceService(ApplicationManager.getApplication(), DevoxxGenieStateService.class, settingsStateMock, getTestRootDisposable());
}

@Test
public void testGetModels() throws IOException {
JanService janService = new JanService();
List<Data> models = janService.getModels();
assertThat(models).isNotEmpty();

models.forEach(model -> {
assertThat(model).isNotNull();
assertThat(model.getId()).isNotNull();
assertThat(model.getName()).isNotNull();
assertThat(model.getDescription()).isNotNull();
assertThat(model.getSettings()).isNotNull();
assertThat(model.getSettings().getCtxLen()).isNotNull();
});
}
}

0 comments on commit d50934a

Please sign in to comment.