Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat #406 : Getting local models is now centralised in LocalChatModelFactory #407

Merged
merged 1 commit into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/src/main/java/com/devoxx/genie/model/jan/Data.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public class Data {
private String object;

@JsonProperty("ctx_len")
private Integer ctxLen;
private Long ctxLen;

@JsonProperty("max_tokens")
private Integer maxTokens;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
package com.devoxx.genie.chatmodel;

import com.devoxx.genie.chatmodel.anthropic.AnthropicChatModelFactory;
import com.devoxx.genie.chatmodel.azureopenai.AzureOpenAIChatModelFactory;
import com.devoxx.genie.chatmodel.customopenai.CustomOpenAIChatModelFactory;
import com.devoxx.genie.chatmodel.deepinfra.DeepInfraChatModelFactory;
import com.devoxx.genie.chatmodel.deepseek.DeepSeekChatModelFactory;
import com.devoxx.genie.chatmodel.google.GoogleChatModelFactory;
import com.devoxx.genie.chatmodel.gpt4all.GPT4AllChatModelFactory;
import com.devoxx.genie.chatmodel.groq.GroqChatModelFactory;
import com.devoxx.genie.chatmodel.jan.JanChatModelFactory;
import com.devoxx.genie.chatmodel.lmstudio.LMStudioChatModelFactory;
import com.devoxx.genie.chatmodel.mistral.MistralChatModelFactory;
import com.devoxx.genie.chatmodel.ollama.OllamaChatModelFactory;
import com.devoxx.genie.chatmodel.openai.OpenAIChatModelFactory;
import com.devoxx.genie.chatmodel.openrouter.OpenRouterChatModelFactory;
import com.devoxx.genie.chatmodel.cloud.anthropic.AnthropicChatModelFactory;
import com.devoxx.genie.chatmodel.cloud.azureopenai.AzureOpenAIChatModelFactory;
import com.devoxx.genie.chatmodel.local.customopenai.CustomOpenAIChatModelFactory;
import com.devoxx.genie.chatmodel.cloud.deepinfra.DeepInfraChatModelFactory;
import com.devoxx.genie.chatmodel.cloud.deepseek.DeepSeekChatModelFactory;
import com.devoxx.genie.chatmodel.cloud.google.GoogleChatModelFactory;
import com.devoxx.genie.chatmodel.local.gpt4all.GPT4AllChatModelFactory;
import com.devoxx.genie.chatmodel.cloud.groq.GroqChatModelFactory;
import com.devoxx.genie.chatmodel.local.jan.JanChatModelFactory;
import com.devoxx.genie.chatmodel.local.llamaCPP.LlamaChatModelFactory;
import com.devoxx.genie.chatmodel.local.lmstudio.LMStudioChatModelFactory;
import com.devoxx.genie.chatmodel.cloud.mistral.MistralChatModelFactory;
import com.devoxx.genie.chatmodel.local.ollama.OllamaChatModelFactory;
import com.devoxx.genie.chatmodel.cloud.openai.OpenAIChatModelFactory;
import com.devoxx.genie.chatmodel.cloud.openrouter.OpenRouterChatModelFactory;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

Expand Down Expand Up @@ -50,6 +51,7 @@ private ChatModelFactoryProvider() {
case "Groq" -> new GroqChatModelFactory();
case "GPT4All" -> new GPT4AllChatModelFactory();
case "Jan" -> new JanChatModelFactory();
case "LLaMA" -> new LlamaChatModelFactory();
case "LMStudio" -> new LMStudioChatModelFactory();
case "Mistral" -> new MistralChatModelFactory();
case "Ollama" -> new OllamaChatModelFactory();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.devoxx.genie.chatmodel;

import com.devoxx.genie.chatmodel.lmstudio.LMStudioChatModelFactory;
import com.devoxx.genie.chatmodel.local.lmstudio.LMStudioChatModelFactory;
import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.model.Constant;
import com.devoxx.genie.model.LanguageModel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ public abstract class LocalChatModelFactory implements ChatModelFactory {
protected List<LanguageModel> cachedModels = null;
protected static final ExecutorService executorService = Executors.newFixedThreadPool(5);
protected static boolean warningShown = false;
protected boolean providerRunning = false;
protected boolean providerChecked = false;

protected LocalChatModelFactory(ModelProvider modelProvider) {
this.modelProvider = modelProvider;
Expand Down Expand Up @@ -62,9 +64,18 @@ protected StreamingChatLanguageModel createLocalAiStreamingChatModel(@NotNull Ch

@Override
public List<LanguageModel> getModels() {
if (cachedModels != null) {
return cachedModels;
if (!providerChecked) {
checkAndFetchModels();
}
if (!providerRunning) {
NotificationUtil.sendNotification(ProjectManager.getInstance().getDefaultProject(),
"LLM provider is not running. Please start it and try again.");
return List.of();
}
return cachedModels;
}

private void checkAndFetchModels() {
List<LanguageModel> modelNames = new ArrayList<>();
List<CompletableFuture<Void>> futures = new ArrayList<>();
try {
Expand All @@ -77,25 +88,28 @@ public List<LanguageModel> getModels() {
modelNames.add(languageModel);
}
} catch (IOException e) {
handleModelFetchError(model, e);
handleModelFetchError(e);
}
}, executorService);
futures.add(future);
}
CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
cachedModels = modelNames;
providerRunning = true;
} catch (IOException e) {
handleGeneralFetchError(e);
cachedModels = List.of();
providerRunning = false;
} finally {
providerChecked = true;
}
return cachedModels;
}

protected abstract Object[] fetchModels() throws IOException;

protected abstract LanguageModel buildLanguageModel(Object model) throws IOException;

protected void handleModelFetchError(Object model, @NotNull IOException e) {
protected void handleModelFetchError(@NotNull IOException e) {
NotificationUtil.sendNotification(ProjectManager.getInstance().getDefaultProject(), "Error fetching model details: " + e.getMessage());
}

Expand All @@ -109,5 +123,7 @@ protected void handleGeneralFetchError(IOException e) {
@Override
public void resetModels() {
cachedModels = null;
providerChecked = false;
providerRunning = false;
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.devoxx.genie.chatmodel.anthropic;
package com.devoxx.genie.chatmodel.cloud.anthropic;

import com.devoxx.genie.chatmodel.ChatModelFactory;
import com.devoxx.genie.model.ChatModel;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.devoxx.genie.chatmodel.azureopenai;
package com.devoxx.genie.chatmodel.cloud.azureopenai;

import com.devoxx.genie.chatmodel.ChatModelFactory;
import com.devoxx.genie.model.ChatModel;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.devoxx.genie.chatmodel.deepinfra;
package com.devoxx.genie.chatmodel.cloud.deepinfra;

import com.devoxx.genie.chatmodel.ChatModelFactory;
import com.devoxx.genie.model.ChatModel;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.devoxx.genie.chatmodel.deepseek;
package com.devoxx.genie.chatmodel.cloud.deepseek;

import com.devoxx.genie.chatmodel.ChatModelFactory;
import com.devoxx.genie.model.ChatModel;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.devoxx.genie.chatmodel.google;
package com.devoxx.genie.chatmodel.cloud.google;

import com.devoxx.genie.chatmodel.ChatModelFactory;
import com.devoxx.genie.model.ChatModel;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.devoxx.genie.chatmodel.groq;
package com.devoxx.genie.chatmodel.cloud.groq;

import com.devoxx.genie.chatmodel.ChatModelFactory;
import com.devoxx.genie.model.ChatModel;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.devoxx.genie.chatmodel.mistral;
package com.devoxx.genie.chatmodel.cloud.mistral;

import com.devoxx.genie.chatmodel.ChatModelFactory;
import com.devoxx.genie.model.ChatModel;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.devoxx.genie.chatmodel.openai;
package com.devoxx.genie.chatmodel.cloud.openai;

import com.devoxx.genie.chatmodel.ChatModelFactory;
import com.devoxx.genie.model.ChatModel;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.devoxx.genie.chatmodel.openrouter;
package com.devoxx.genie.chatmodel.cloud.openrouter;

import com.devoxx.genie.chatmodel.ChatModelFactory;
import com.devoxx.genie.model.ChatModel;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.devoxx.genie.chatmodel.local;

import java.io.IOException;

/**
 * Contract for services that list the models exposed by a locally running
 * LLM provider (e.g. GPT4All, Jan).
 * <p>
 * Implementations return their provider-specific model DTO collection
 * (declared as {@code Object} here; concrete services narrow the return
 * type covariantly, e.g. {@code List<Model>}).
 */
public interface LocalLLMProvider {
    /**
     * Fetches the list of models currently available from the local provider.
     *
     * @return the provider-specific model collection
     * @throws IOException if the provider endpoint cannot be reached or
     *                     returns an unsuccessful response
     */
    Object getModels() throws IOException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package com.devoxx.genie.chatmodel.local;

import com.devoxx.genie.model.lmstudio.LMStudioModelEntryDTO;
import com.devoxx.genie.service.exception.UnsuccessfulRequestException;
import com.devoxx.genie.ui.settings.DevoxxGenieStateService;
import com.google.gson.Gson;
import com.google.gson.JsonElement;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;

import java.io.IOException;
import java.util.Objects;

import static com.devoxx.genie.util.HttpUtil.ensureEndsWithSlash;

/**
 * Shared HTTP helper for fetching the model list from a locally running
 * LLM provider (GPT4All, Jan, LM Studio, ...).
 * <p>
 * The provider's base URL is looked up from {@link DevoxxGenieStateService}
 * by config key, the given endpoint is appended, and the JSON response is
 * deserialized into the requested DTO type with Gson.
 */
public class LocalLLMProviderUtil {

    private static final OkHttpClient client = new OkHttpClient();
    private static final Gson gson = new Gson();

    // Static utility holder: prevent instantiation.
    private LocalLLMProviderUtil() {
    }

    /**
     * Calls {@code <baseUrl>/<endpoint>} and deserializes the JSON body.
     *
     * @param baseUrlConfigKey key under which the provider's base URL is stored
     *                         in {@link DevoxxGenieStateService}
     * @param endpoint         relative path appended to the base URL (e.g. "models")
     * @param responseType     target DTO class for Gson deserialization
     * @param <T>              the DTO type
     * @return the deserialized response body
     * @throws UnsuccessfulRequestException if the HTTP status is not 2xx or the body is missing
     * @throws IOException                  on network failure
     * @throws NullPointerException         if no base URL is configured for the key
     */
    public static <T> T getModels(String baseUrlConfigKey, String endpoint, Class<T> responseType) throws IOException {
        String configValue = DevoxxGenieStateService.getInstance().getConfigValue(baseUrlConfigKey);
        String baseUrl = ensureEndsWithSlash(Objects.requireNonNull(configValue));

        Request request = new Request.Builder()
                .url(baseUrl + endpoint)
                .build();

        try (Response response = client.newCall(request).execute()) {
            if (!response.isSuccessful()) {
                throw new UnsuccessfulRequestException("Unexpected code " + response);
            }

            if (response.body() == null) {
                throw new UnsuccessfulRequestException("Response body is null");
            }

            String json = response.body().string();

            // Special handling for LM Studio: its /models endpoint wraps the
            // array in a {"data": [...]} envelope, so unwrap before mapping.
            if (responseType.equals(LMStudioModelEntryDTO[].class)) {
                JsonElement jsonElement = gson.fromJson(json, JsonElement.class);
                if (jsonElement.isJsonObject() && jsonElement.getAsJsonObject().has("data")) {
                    return gson.fromJson(jsonElement.getAsJsonObject().get("data"), responseType);
                } else {
                    // No "data" envelope present: treat as zero models rather than failing.
                    return responseType.cast(new LMStudioModelEntryDTO[0]);
                }
            }

            return gson.fromJson(json, responseType);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.devoxx.genie.chatmodel.customopenai;
package com.devoxx.genie.chatmodel.local.customopenai;

import com.devoxx.genie.chatmodel.ChatModelFactory;
import com.devoxx.genie.model.ChatModel;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package com.devoxx.genie.chatmodel.gpt4all;
package com.devoxx.genie.chatmodel.local.gpt4all;

import com.devoxx.genie.chatmodel.LocalChatModelFactory;
import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.model.LanguageModel;
import com.devoxx.genie.model.enumarations.ModelProvider;
import com.devoxx.genie.model.gpt4all.Model;
import com.devoxx.genie.service.gpt4all.GPT4AllService;
import com.devoxx.genie.ui.settings.DevoxxGenieStateService;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
Expand Down Expand Up @@ -36,20 +35,19 @@ protected String getModelUrl() {

@Override
protected Model[] fetchModels() throws IOException {
return GPT4AllService.getInstance().getModels().toArray(new Model[0]);
return GPT4AllModelService.getInstance().getModels().toArray(new Model[0]);
}

@Override
protected LanguageModel buildLanguageModel(Object model) {
Model gpt4AllModel = (Model) model;
// int contextWindow = GPT4AllService.getInstance()
return LanguageModel.builder()
.provider(modelProvider)
.modelName(gpt4AllModel.getId())
.displayName(gpt4AllModel.getId())
.inputCost(0)
.outputCost(0)
// .contextWindow(contextWindow)
// .contextWindow(contextWindow) // GPT4All does not provide context window :(
.apiKeyUsed(false)
.build();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.devoxx.genie.chatmodel.local.gpt4all;

import com.devoxx.genie.chatmodel.local.LocalLLMProvider;
import com.devoxx.genie.chatmodel.local.LocalLLMProviderUtil;
import com.devoxx.genie.model.gpt4all.Model;
import com.devoxx.genie.model.gpt4all.ResponseDTO;
import com.intellij.openapi.application.ApplicationManager;
import org.jetbrains.annotations.NotNull;

import java.io.IOException;
import java.util.List;

/**
 * Application service that retrieves the model list from a locally
 * running GPT4All instance via {@link LocalLLMProviderUtil}.
 */
public class GPT4AllModelService implements LocalLLMProvider {

    /**
     * @return the application-level singleton instance of this service
     */
    @NotNull
    public static GPT4AllModelService getInstance() {
        return ApplicationManager.getApplication().getService(GPT4AllModelService.class);
    }

    /**
     * Fetches the models advertised by the GPT4All "models" endpoint.
     *
     * @return the models reported by the local GPT4All server
     * @throws IOException if the server is unreachable or responds with an error
     */
    @Override
    public List<Model> getModels() throws IOException {
        ResponseDTO response = LocalLLMProviderUtil.getModels("gpt4allModelUrl", "models", ResponseDTO.class);
        return response.getData();
    }
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package com.devoxx.genie.chatmodel.jan;
package com.devoxx.genie.chatmodel.local.jan;

import com.devoxx.genie.chatmodel.LocalChatModelFactory;
import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.model.LanguageModel;
import com.devoxx.genie.model.enumarations.ModelProvider;
import com.devoxx.genie.model.jan.Data;
import com.devoxx.genie.service.jan.JanService;
import com.devoxx.genie.ui.settings.DevoxxGenieStateService;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
Expand Down Expand Up @@ -36,7 +35,7 @@ protected String getModelUrl() {

@Override
protected Data[] fetchModels() throws IOException {
return JanService.getInstance().getModels().toArray(new Data[0]);
return JanModelService.getInstance().getModels().toArray(new Data[0]);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.devoxx.genie.chatmodel.local.jan;

import com.devoxx.genie.chatmodel.local.LocalLLMProvider;
import com.devoxx.genie.chatmodel.local.LocalLLMProviderUtil;
import com.devoxx.genie.model.jan.Data;
import com.devoxx.genie.model.jan.ResponseDTO;
import com.intellij.openapi.application.ApplicationManager;
import org.jetbrains.annotations.NotNull;

import java.io.IOException;
import java.util.List;

/**
 * Application service that retrieves the model list from a locally
 * running Jan instance via {@link LocalLLMProviderUtil}.
 */
public class JanModelService implements LocalLLMProvider {

    /**
     * @return the application-level singleton instance of this service
     */
    @NotNull
    public static JanModelService getInstance() {
        return ApplicationManager.getApplication().getService(JanModelService.class);
    }

    /**
     * Fetches the models advertised by the Jan "models" endpoint.
     *
     * @return the models reported by the local Jan server
     * @throws IOException if the server is unreachable or responds with an error
     */
    @Override
    public List<Data> getModels() throws IOException {
        ResponseDTO response = LocalLLMProviderUtil.getModels("janModelUrl", "models", ResponseDTO.class);
        return response.getData();
    }
}
Loading
Loading