Merge pull request #407 from devoxx/issue-406
Feat #406 : Getting local models is now centralised LocalChatModelFac…
stephanj authored Dec 17, 2024
2 parents 7bc3b00 + 37ecd1c commit 90c3f10
Showing 51 changed files with 309 additions and 250 deletions.
core/src/main/java/com/devoxx/genie/model/jan/Data.java (2 changes: 1 addition & 1 deletion)
@@ -15,7 +15,7 @@ public class Data {
     private String object;

     @JsonProperty("ctx_len")
-    private Integer ctxLen;
+    private Long ctxLen;

     @JsonProperty("max_tokens")
     private Integer maxTokens;
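The only change here widens ctx_len from Integer to Long. A minimal sketch of the effect, assuming Jackson as the JSON mapper (which the @JsonProperty annotation implies) and a hypothetical payload: a context-length value beyond Integer.MAX_VALUE fails to bind to an Integer field but binds cleanly to a Long.

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;

public class CtxLenDemo {

    // Trimmed-down stand-in for the Data DTO above, keeping only the widened field.
    public static class Data {
        @JsonProperty("ctx_len")
        private Long ctxLen;

        public Long getCtxLen() {
            return ctxLen;
        }
    }

    public static void main(String[] args) throws Exception {
        ObjectMapper mapper = new ObjectMapper();
        // Hypothetical payload: with "private Integer ctxLen" this value would
        // overflow and Jackson would throw an InputCoercionException.
        Data data = mapper.readValue("{\"ctx_len\": 4294967296}", Data.class);
        System.out.println(data.getCtxLen()); // prints 4294967296
    }
}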
@@ -1,19 +1,20 @@
 package com.devoxx.genie.chatmodel;

-import com.devoxx.genie.chatmodel.anthropic.AnthropicChatModelFactory;
-import com.devoxx.genie.chatmodel.azureopenai.AzureOpenAIChatModelFactory;
-import com.devoxx.genie.chatmodel.customopenai.CustomOpenAIChatModelFactory;
-import com.devoxx.genie.chatmodel.deepinfra.DeepInfraChatModelFactory;
-import com.devoxx.genie.chatmodel.deepseek.DeepSeekChatModelFactory;
-import com.devoxx.genie.chatmodel.google.GoogleChatModelFactory;
-import com.devoxx.genie.chatmodel.gpt4all.GPT4AllChatModelFactory;
-import com.devoxx.genie.chatmodel.groq.GroqChatModelFactory;
-import com.devoxx.genie.chatmodel.jan.JanChatModelFactory;
-import com.devoxx.genie.chatmodel.lmstudio.LMStudioChatModelFactory;
-import com.devoxx.genie.chatmodel.mistral.MistralChatModelFactory;
-import com.devoxx.genie.chatmodel.ollama.OllamaChatModelFactory;
-import com.devoxx.genie.chatmodel.openai.OpenAIChatModelFactory;
-import com.devoxx.genie.chatmodel.openrouter.OpenRouterChatModelFactory;
+import com.devoxx.genie.chatmodel.cloud.anthropic.AnthropicChatModelFactory;
+import com.devoxx.genie.chatmodel.cloud.azureopenai.AzureOpenAIChatModelFactory;
+import com.devoxx.genie.chatmodel.local.customopenai.CustomOpenAIChatModelFactory;
+import com.devoxx.genie.chatmodel.cloud.deepinfra.DeepInfraChatModelFactory;
+import com.devoxx.genie.chatmodel.cloud.deepseek.DeepSeekChatModelFactory;
+import com.devoxx.genie.chatmodel.cloud.google.GoogleChatModelFactory;
+import com.devoxx.genie.chatmodel.local.gpt4all.GPT4AllChatModelFactory;
+import com.devoxx.genie.chatmodel.cloud.groq.GroqChatModelFactory;
+import com.devoxx.genie.chatmodel.local.jan.JanChatModelFactory;
+import com.devoxx.genie.chatmodel.local.llamaCPP.LlamaChatModelFactory;
+import com.devoxx.genie.chatmodel.local.lmstudio.LMStudioChatModelFactory;
+import com.devoxx.genie.chatmodel.cloud.mistral.MistralChatModelFactory;
+import com.devoxx.genie.chatmodel.local.ollama.OllamaChatModelFactory;
+import com.devoxx.genie.chatmodel.cloud.openai.OpenAIChatModelFactory;
+import com.devoxx.genie.chatmodel.cloud.openrouter.OpenRouterChatModelFactory;
 import org.jetbrains.annotations.NotNull;
 import org.jetbrains.annotations.Nullable;

@@ -50,6 +51,7 @@ private ChatModelFactoryProvider() {
             case "Groq" -> new GroqChatModelFactory();
             case "GPT4All" -> new GPT4AllChatModelFactory();
             case "Jan" -> new JanChatModelFactory();
+            case "LLaMA" -> new LlamaChatModelFactory();
             case "LMStudio" -> new LMStudioChatModelFactory();
             case "Mistral" -> new MistralChatModelFactory();
             case "Ollama" -> new OllamaChatModelFactory();
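The switch above is the single lookup point that maps a provider's display name to its factory, now including the new LLaMA entry. Below is an illustrative sketch of that string-keyed registry pattern; the type and accessor names are assumptions for illustration, not the plugin's verified API.

import java.util.Optional;

// Illustrative sketch of the string-keyed factory registry; the real
// ChatModelFactoryProvider wires up all fifteen factories shown above.
final class FactoryLookupSketch {

    interface ChatModelFactory { /* createChatModel(), getModels(), ... */ }

    static final class JanChatModelFactory implements ChatModelFactory { }
    static final class LlamaChatModelFactory implements ChatModelFactory { }

    private FactoryLookupSketch() { }

    // Hypothetical accessor, shown only to illustrate how the switch resolves.
    static Optional<ChatModelFactory> factoryFor(String provider) {
        ChatModelFactory factory = switch (provider) {
            case "Jan" -> new JanChatModelFactory();
            case "LLaMA" -> new LlamaChatModelFactory(); // new in this PR
            default -> null;
        };
        return Optional.ofNullable(factory);
    }

    public static void main(String[] args) {
        System.out.println(factoryFor("LLaMA").isPresent()); // true
    }
}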
@@ -1,6 +1,6 @@
 package com.devoxx.genie.chatmodel;

-import com.devoxx.genie.chatmodel.lmstudio.LMStudioChatModelFactory;
+import com.devoxx.genie.chatmodel.local.lmstudio.LMStudioChatModelFactory;
 import com.devoxx.genie.model.ChatModel;
 import com.devoxx.genie.model.Constant;
 import com.devoxx.genie.model.LanguageModel;
@@ -25,6 +25,8 @@ public abstract class LocalChatModelFactory implements ChatModelFactory {
     protected List<LanguageModel> cachedModels = null;
     protected static final ExecutorService executorService = Executors.newFixedThreadPool(5);
     protected static boolean warningShown = false;
+    protected boolean providerRunning = false;
+    protected boolean providerChecked = false;

     protected LocalChatModelFactory(ModelProvider modelProvider) {
         this.modelProvider = modelProvider;
@@ -62,9 +64,18 @@ protected StreamingChatLanguageModel createLocalAiStreamingChatModel(@NotNull Ch

     @Override
     public List<LanguageModel> getModels() {
-        if (cachedModels != null) {
-            return cachedModels;
+        if (!providerChecked) {
+            checkAndFetchModels();
         }
+        if (!providerRunning) {
+            NotificationUtil.sendNotification(ProjectManager.getInstance().getDefaultProject(),
+                    "LLM provider is not running. Please start it and try again.");
+            return List.of();
+        }
+        return cachedModels;
+    }
+
+    private void checkAndFetchModels() {
         List<LanguageModel> modelNames = new ArrayList<>();
         List<CompletableFuture<Void>> futures = new ArrayList<>();
         try {
@@ -77,25 +88,28 @@ public List<LanguageModel> getModels() {
                             modelNames.add(languageModel);
                         }
                     } catch (IOException e) {
-                        handleModelFetchError(model, e);
+                        handleModelFetchError(e);
                     }
                 }, executorService);
                 futures.add(future);
             }
             CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
             cachedModels = modelNames;
+            providerRunning = true;
         } catch (IOException e) {
             handleGeneralFetchError(e);
             cachedModels = List.of();
+            providerRunning = false;
+        } finally {
+            providerChecked = true;
         }
-        return cachedModels;
     }

     protected abstract Object[] fetchModels() throws IOException;

     protected abstract LanguageModel buildLanguageModel(Object model) throws IOException;

-    protected void handleModelFetchError(Object model, @NotNull IOException e) {
+    protected void handleModelFetchError(@NotNull IOException e) {
         NotificationUtil.sendNotification(ProjectManager.getInstance().getDefaultProject(), "Error fetching model details: " + e.getMessage());
     }

@@ -109,5 +123,7 @@ protected void handleGeneralFetchError(IOException e) {
     @Override
     public void resetModels() {
         cachedModels = null;
+        providerChecked = false;
+        providerRunning = false;
     }
 }
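The effect of this refactor: getModels() now probes the provider at most once, caches the fetched models, and returns an empty list (after notifying the user) when the local provider is down, until resetModels() clears the flags. A standalone sketch of that check-once pattern, with System.out standing in for the IDE notification:

import java.util.List;

// Standalone sketch of the check-once flow introduced above: probe the local
// provider on first call, cache the result, and short-circuit on failure
// until resetModels() clears the flags.
abstract class LocalFactorySketch {

    protected List<String> cachedModels;
    protected boolean providerChecked = false;
    protected boolean providerRunning = false;

    public List<String> getModels() {
        if (!providerChecked) {
            checkAndFetchModels();
        }
        if (!providerRunning) {
            // NotificationUtil in the real implementation.
            System.out.println("LLM provider is not running. Please start it and try again.");
            return List.of();
        }
        return cachedModels;
    }

    private void checkAndFetchModels() {
        try {
            cachedModels = fetchModelNames();   // e.g. an HTTP call to the local server
            providerRunning = true;
        } catch (Exception e) {
            cachedModels = List.of();
            providerRunning = false;
        } finally {
            providerChecked = true;             // don't probe again until reset
        }
    }

    public void resetModels() {
        cachedModels = null;
        providerChecked = false;
        providerRunning = false;
    }

    protected abstract List<String> fetchModelNames() throws Exception;

    public static void main(String[] args) {
        LocalFactorySketch factory = new LocalFactorySketch() {
            @Override
            protected List<String> fetchModelNames() {
                return List.of("llama3", "mistral");
            }
        };
        System.out.println(factory.getModels()); // [llama3, mistral]
        System.out.println(factory.getModels()); // served from cache, no re-probe
    }
}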
@@ -1,4 +1,4 @@
-package com.devoxx.genie.chatmodel.anthropic;
+package com.devoxx.genie.chatmodel.cloud.anthropic;

 import com.devoxx.genie.chatmodel.ChatModelFactory;
 import com.devoxx.genie.model.ChatModel;
@@ -1,4 +1,4 @@
-package com.devoxx.genie.chatmodel.azureopenai;
+package com.devoxx.genie.chatmodel.cloud.azureopenai;

 import com.devoxx.genie.chatmodel.ChatModelFactory;
 import com.devoxx.genie.model.ChatModel;
@@ -1,4 +1,4 @@
-package com.devoxx.genie.chatmodel.deepinfra;
+package com.devoxx.genie.chatmodel.cloud.deepinfra;

 import com.devoxx.genie.chatmodel.ChatModelFactory;
 import com.devoxx.genie.model.ChatModel;
@@ -1,4 +1,4 @@
-package com.devoxx.genie.chatmodel.deepseek;
+package com.devoxx.genie.chatmodel.cloud.deepseek;

 import com.devoxx.genie.chatmodel.ChatModelFactory;
 import com.devoxx.genie.model.ChatModel;
@@ -1,4 +1,4 @@
-package com.devoxx.genie.chatmodel.google;
+package com.devoxx.genie.chatmodel.cloud.google;

 import com.devoxx.genie.chatmodel.ChatModelFactory;
 import com.devoxx.genie.model.ChatModel;
@@ -1,4 +1,4 @@
-package com.devoxx.genie.chatmodel.groq;
+package com.devoxx.genie.chatmodel.cloud.groq;

 import com.devoxx.genie.chatmodel.ChatModelFactory;
 import com.devoxx.genie.model.ChatModel;
@@ -1,4 +1,4 @@
-package com.devoxx.genie.chatmodel.mistral;
+package com.devoxx.genie.chatmodel.cloud.mistral;

 import com.devoxx.genie.chatmodel.ChatModelFactory;
 import com.devoxx.genie.model.ChatModel;
@@ -1,4 +1,4 @@
-package com.devoxx.genie.chatmodel.openai;
+package com.devoxx.genie.chatmodel.cloud.openai;

 import com.devoxx.genie.chatmodel.ChatModelFactory;
 import com.devoxx.genie.model.ChatModel;
@@ -1,4 +1,4 @@
-package com.devoxx.genie.chatmodel.openrouter;
+package com.devoxx.genie.chatmodel.cloud.openrouter;

 import com.devoxx.genie.chatmodel.ChatModelFactory;
 import com.devoxx.genie.model.ChatModel;
@@ -0,0 +1,7 @@
+package com.devoxx.genie.chatmodel.local;
+
+import java.io.IOException;
+
+public interface LocalLLMProvider {
+    Object getModels() throws IOException;
+}
@@ -0,0 +1,54 @@
+package com.devoxx.genie.chatmodel.local;
+
+import com.devoxx.genie.model.lmstudio.LMStudioModelEntryDTO;
+import com.devoxx.genie.service.exception.UnsuccessfulRequestException;
+import com.devoxx.genie.ui.settings.DevoxxGenieStateService;
+import com.google.gson.Gson;
+import com.google.gson.JsonElement;
+import okhttp3.OkHttpClient;
+import okhttp3.Request;
+import okhttp3.Response;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static com.devoxx.genie.util.HttpUtil.ensureEndsWithSlash;
+
+public class LocalLLMProviderUtil {
+
+    private static final OkHttpClient client = new OkHttpClient();
+    private static final Gson gson = new Gson();
+
+    public static <T> T getModels(String baseUrlConfigKey, String endpoint, Class<T> responseType) throws IOException {
+        String configValue = DevoxxGenieStateService.getInstance().getConfigValue(baseUrlConfigKey);
+        String baseUrl = ensureEndsWithSlash(Objects.requireNonNull(configValue));
+
+        Request request = new Request.Builder()
+            .url(baseUrl + endpoint)
+            .build();
+
+        try (Response response = client.newCall(request).execute()) {
+            if (!response.isSuccessful()) {
+                throw new UnsuccessfulRequestException("Unexpected code " + response);
+            }
+
+            if (response.body() == null) {
+                throw new UnsuccessfulRequestException("Response body is null");
+            }
+
+            String json = response.body().string();
+
+            // Special handling for LM Studio
+            if (responseType.equals(LMStudioModelEntryDTO[].class)) {
+                JsonElement jsonElement = gson.fromJson(json, JsonElement.class);
+                if (jsonElement.isJsonObject() && jsonElement.getAsJsonObject().has("data")) {
+                    return gson.fromJson(jsonElement.getAsJsonObject().get("data"), responseType);
+                } else {
+                    return responseType.cast(new LMStudioModelEntryDTO[0]);
+                }
+            }
+
+            return gson.fromJson(json, responseType);
+        }
+    }
+}
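Callers pass a settings key, an endpoint suffix, and a target type; the base-URL lookup, HTTP call, JSON mapping, and LM Studio's "data" unwrapping all live here. A usage sketch; the "lmstudioModelUrl" key is an assumption by analogy with the "gpt4allModelUrl" and "janModelUrl" keys used further down:

import com.devoxx.genie.chatmodel.local.LocalLLMProviderUtil;
import com.devoxx.genie.model.lmstudio.LMStudioModelEntryDTO;

import java.io.IOException;

class LMStudioFetchSketch {

    // Fetches <baseUrl>/models and lets the helper unwrap LM Studio's
    // {"data": [...]} envelope into the DTO array.
    static LMStudioModelEntryDTO[] fetchLmStudioModels() throws IOException {
        return LocalLLMProviderUtil.getModels(
                "lmstudioModelUrl",               // assumed settings key for the base URL
                "models",                         // endpoint appended to the base URL
                LMStudioModelEntryDTO[].class);   // target type for Gson
    }
}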
@@ -1,4 +1,4 @@
-package com.devoxx.genie.chatmodel.customopenai;
+package com.devoxx.genie.chatmodel.local.customopenai;

 import com.devoxx.genie.chatmodel.ChatModelFactory;
 import com.devoxx.genie.model.ChatModel;
@@ -1,11 +1,10 @@
-package com.devoxx.genie.chatmodel.gpt4all;
+package com.devoxx.genie.chatmodel.local.gpt4all;

 import com.devoxx.genie.chatmodel.LocalChatModelFactory;
 import com.devoxx.genie.model.ChatModel;
 import com.devoxx.genie.model.LanguageModel;
 import com.devoxx.genie.model.enumarations.ModelProvider;
 import com.devoxx.genie.model.gpt4all.Model;
-import com.devoxx.genie.service.gpt4all.GPT4AllService;
 import com.devoxx.genie.ui.settings.DevoxxGenieStateService;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
@@ -36,20 +35,19 @@ protected String getModelUrl() {

     @Override
     protected Model[] fetchModels() throws IOException {
-        return GPT4AllService.getInstance().getModels().toArray(new Model[0]);
+        return GPT4AllModelService.getInstance().getModels().toArray(new Model[0]);
     }

     @Override
     protected LanguageModel buildLanguageModel(Object model) {
         Model gpt4AllModel = (Model) model;
-        // int contextWindow = GPT4AllService.getInstance()
         return LanguageModel.builder()
             .provider(modelProvider)
             .modelName(gpt4AllModel.getId())
             .displayName(gpt4AllModel.getId())
             .inputCost(0)
             .outputCost(0)
-            // .contextWindow(contextWindow)
+            // .contextWindow(contextWindow) // GPT4All does not provide context window :(
             .apiKeyUsed(false)
             .build();
     }
@@ -0,0 +1,26 @@
+package com.devoxx.genie.chatmodel.local.gpt4all;
+
+import com.devoxx.genie.chatmodel.local.LocalLLMProvider;
+import com.devoxx.genie.chatmodel.local.LocalLLMProviderUtil;
+import com.devoxx.genie.model.gpt4all.Model;
+import com.devoxx.genie.model.gpt4all.ResponseDTO;
+import com.intellij.openapi.application.ApplicationManager;
+import org.jetbrains.annotations.NotNull;
+
+import java.io.IOException;
+import java.util.List;
+
+public class GPT4AllModelService implements LocalLLMProvider {
+
+    @NotNull
+    public static GPT4AllModelService getInstance() {
+        return ApplicationManager.getApplication().getService(GPT4AllModelService.class);
+    }
+
+    @Override
+    public List<Model> getModels() throws IOException {
+        return LocalLLMProviderUtil
+            .getModels("gpt4allModelUrl", "models", ResponseDTO.class)
+            .getData();
+    }
+}
@@ -1,11 +1,10 @@
-package com.devoxx.genie.chatmodel.jan;
+package com.devoxx.genie.chatmodel.local.jan;

 import com.devoxx.genie.chatmodel.LocalChatModelFactory;
 import com.devoxx.genie.model.ChatModel;
 import com.devoxx.genie.model.LanguageModel;
 import com.devoxx.genie.model.enumarations.ModelProvider;
 import com.devoxx.genie.model.jan.Data;
-import com.devoxx.genie.service.jan.JanService;
 import com.devoxx.genie.ui.settings.DevoxxGenieStateService;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
@@ -36,7 +35,7 @@ protected String getModelUrl() {

     @Override
     protected Data[] fetchModels() throws IOException {
-        return JanService.getInstance().getModels().toArray(new Data[0]);
+        return JanModelService.getInstance().getModels().toArray(new Data[0]);
     }

     @Override
@@ -0,0 +1,26 @@
+package com.devoxx.genie.chatmodel.local.jan;
+
+import com.devoxx.genie.chatmodel.local.LocalLLMProvider;
+import com.devoxx.genie.chatmodel.local.LocalLLMProviderUtil;
+import com.devoxx.genie.model.jan.Data;
+import com.devoxx.genie.model.jan.ResponseDTO;
+import com.intellij.openapi.application.ApplicationManager;
+import org.jetbrains.annotations.NotNull;
+
+import java.io.IOException;
+import java.util.List;
+
+public class JanModelService implements LocalLLMProvider {
+
+    @NotNull
+    public static JanModelService getInstance() {
+        return ApplicationManager.getApplication().getService(JanModelService.class);
+    }
+
+    @Override
+    public List<Data> getModels() throws IOException {
+        return LocalLLMProviderUtil
+            .getModels("janModelUrl", "models", ResponseDTO.class)
+            .getData();
+    }
+}
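Both new model services follow the same recipe: an application-level service that narrows the Object return type of LocalLLMProvider and delegates the HTTP work to LocalLLMProviderUtil. As a hypothetical sketch, here is what a service for some new provider would look like; the "Acme" names and the "acmeModelUrl" key do not exist in this commit, and the class would still need to be registered as an IntelliJ application service (e.g. in plugin.xml) for getService() to resolve it.

import com.devoxx.genie.chatmodel.local.LocalLLMProvider;
import com.devoxx.genie.chatmodel.local.LocalLLMProviderUtil;
import com.intellij.openapi.application.ApplicationManager;
import org.jetbrains.annotations.NotNull;

import java.io.IOException;
import java.util.List;

public class AcmeModelService implements LocalLLMProvider {

    // Stub DTOs for the sketch; real providers generate these from the server's JSON.
    public static class AcmeModel { }

    public static class AcmeResponseDTO {
        private List<AcmeModel> data;
        public List<AcmeModel> getData() { return data; }
    }

    @NotNull
    public static AcmeModelService getInstance() {
        return ApplicationManager.getApplication().getService(AcmeModelService.class);
    }

    @Override
    public List<AcmeModel> getModels() throws IOException {
        return LocalLLMProviderUtil
                .getModels("acmeModelUrl", "models", AcmeResponseDTO.class)
                .getData();
    }
}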