Skip to content

Commit

Permalink
Merge pull request #393 from devoxx/issue-392
Browse files Browse the repository at this point in the history
Issue #392 GPT4All regression fixed
  • Loading branch information
stephanj authored Dec 16, 2024
2 parents f66a030 + aaa3343 commit 6ed3c0b
Show file tree
Hide file tree
Showing 16 changed files with 417 additions and 274 deletions.
32 changes: 32 additions & 0 deletions core/src/main/java/com/devoxx/genie/model/gpt4all/Model.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package com.devoxx.genie.model.gpt4all;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Getter;
import lombok.Setter;

import java.util.List;

/**
 * One model entry as returned by the GPT4All local server's
 * OpenAI-compatible {@code /v1/models} endpoint.
 * <p>
 * Unknown JSON properties are ignored so that newer GPT4All server
 * versions adding fields do not break deserialization.
 */
@Getter
@Setter
@JsonIgnoreProperties(ignoreUnknown = true)
public class Model {

    // Creation timestamp — presumably Unix epoch seconds, per the OpenAI
    // models-list convention; TODO confirm against the GPT4All server.
    @JsonProperty("created")
    private long created;

    // Model identifier; used downstream as both model name and display name.
    @JsonProperty("id")
    private String id;

    // Object type discriminator (typically "model").
    @JsonProperty("object")
    private String object;

    @JsonProperty("owned_by")
    private String ownedBy;

    @JsonProperty("parent")
    private String parent;

    @JsonProperty("permissions")
    private List<ModelPermission> permissions;

    @JsonProperty("root")
    private String root;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.devoxx.genie.model.gpt4all;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Getter;
import lombok.Setter;

/**
 * Permission descriptor attached to a {@link Model} entry from the
 * GPT4All server's OpenAI-compatible {@code /v1/models} endpoint.
 * <p>
 * Unknown JSON properties are ignored for forward compatibility with
 * newer server versions.
 */
@Getter
@Setter
@JsonIgnoreProperties(ignoreUnknown = true)
public class ModelPermission {

    @JsonProperty("allow_create_engine")
    private boolean allowCreateEngine;

    @JsonProperty("allow_fine_tuning")
    private boolean allowFineTuning;

    @JsonProperty("allow_logprobs")
    private boolean allowLogprobs;

    @JsonProperty("allow_sampling")
    private boolean allowSampling;

    @JsonProperty("allow_search_indices")
    private boolean allowSearchIndices;

    @JsonProperty("allow_view")
    private boolean allowView;

    // Creation timestamp — presumably Unix epoch seconds; TODO confirm.
    @JsonProperty("created")
    private long created;

    @JsonProperty("group")
    private String group;

    @JsonProperty("id")
    private String id;

    @JsonProperty("is_blocking")
    private boolean isBlocking;

    // Object type discriminator (typically "model_permission").
    @JsonProperty("object")
    private String object;

    @JsonProperty("organization")
    private String organization;
}
17 changes: 17 additions & 0 deletions core/src/main/java/com/devoxx/genie/model/gpt4all/ResponseDTO.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.devoxx.genie.model.gpt4all;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Getter;
import lombok.Setter;

import java.util.List;

/**
 * Top-level response of the GPT4All server's OpenAI-compatible
 * {@code /v1/models} endpoint: an object-type tag plus the list of
 * available {@link Model} entries.
 * <p>
 * Unknown JSON properties are ignored for forward compatibility.
 */
@Getter
@Setter
@JsonIgnoreProperties(ignoreUnknown = true)
public class ResponseDTO {

    // The models currently served by the local GPT4All instance.
    @JsonProperty("data")
    private List<Model> data;

    // Object type discriminator (typically "list").
    @JsonProperty("object")
    private String object;
}
12 changes: 12 additions & 0 deletions core/src/main/java/com/devoxx/genie/model/jan/Data.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@ public class Data {
@JsonProperty("object")
private String object;

@JsonProperty("ctx_len")
private Integer ctxLen;

@JsonProperty("max_tokens")
private Integer maxTokens;

@JsonProperty("top_k")
private Integer topK;

@JsonProperty("top_p")
private Double topP;

@JsonProperty("name")
private String name;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@
/**
 * Response wrapper for the Jan models endpoint: the list of model
 * entries plus an object-type tag.
 * <p>
 * NOTE(review): the scraped diff showed the {@code object} field declared
 * twice (its old and new positions interleaved), which would not compile;
 * this is the reconstructed post-commit layout with a single declaration.
 */
@Getter
public class ResponseDTO {

    // The models known to the Jan server.
    @JsonProperty("data")
    private List<Data> data;

    // Object type discriminator (typically "list").
    @JsonProperty("object")
    private String object;
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.devoxx.genie.chatmodel.deepseek.DeepSeekChatModelFactory;
import com.devoxx.genie.chatmodel.exo.ExoChatModelFactory;
import com.devoxx.genie.chatmodel.google.GoogleChatModelFactory;
import com.devoxx.genie.chatmodel.gpt4all.GPT4AllChatModelFactory;
import com.devoxx.genie.chatmodel.groq.GroqChatModelFactory;
import com.devoxx.genie.chatmodel.jan.JanChatModelFactory;
import com.devoxx.genie.chatmodel.jlama.JLamaChatModelFactory;
Expand Down Expand Up @@ -55,6 +56,7 @@ private ChatModelFactoryProvider() {
case "DeepSeek" -> new DeepSeekChatModelFactory();
case "Jlama" -> new JLamaChatModelFactory();
case "AzureOpenAI" -> new AzureOpenAIChatModelFactory();
case "GPT4All" -> new GPT4AllChatModelFactory();
default -> null;
};
}
Expand Down
113 changes: 113 additions & 0 deletions src/main/java/com/devoxx/genie/chatmodel/LocalChatModelFactory.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package com.devoxx.genie.chatmodel;

import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.model.LanguageModel;
import com.devoxx.genie.model.enumarations.ModelProvider;
import com.devoxx.genie.ui.util.NotificationUtil;
import com.intellij.openapi.project.ProjectManager;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.localai.LocalAiChatModel;
import dev.langchain4j.model.localai.LocalAiStreamingChatModel;
import org.jetbrains.annotations.NotNull;

import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * Base factory for chat-model providers that run on the local machine and
 * expose a LocalAI/OpenAI-compatible HTTP endpoint (e.g. GPT4All).
 * <p>
 * Subclasses supply the endpoint URL, the raw model listing, and the
 * conversion of a raw model into a {@link LanguageModel}; this class
 * provides the shared LocalAI builders and a cached, concurrent
 * model-discovery implementation.
 */
public abstract class LocalChatModelFactory implements ChatModelFactory {

    // Provider identity passed by the concrete subclass (GPT4All, ...).
    protected final ModelProvider modelProvider;

    // Cached result of getModels(); null means "not fetched yet".
    // NOTE(review): not synchronized — assumes getModels()/resetModels()
    // are invoked from a single (UI) thread; confirm if that ever changes.
    protected List<LanguageModel> cachedModels = null;

    // Shared pool used to resolve per-model details concurrently.
    protected static final ExecutorService executorService = Executors.newFixedThreadPool(5);

    // volatile: this flag is shared across ALL local factories and may be
    // read/written from whichever thread calls getModels(); it ensures the
    // "cannot reach server" warning is shown at most once per IDE session.
    protected static volatile boolean warningShown = false;

    protected LocalChatModelFactory(ModelProvider modelProvider) {
        this.modelProvider = modelProvider;
    }

    @Override
    public abstract ChatLanguageModel createChatModel(@NotNull ChatModel chatModel);

    @Override
    public abstract StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel chatModel);

    /**
     * @return base URL of the locally running inference server.
     */
    protected abstract String getModelUrl();

    /**
     * Builds a blocking LocalAI chat model from the user's chat settings.
     */
    protected ChatLanguageModel createLocalAiChatModel(@NotNull ChatModel chatModel) {
        return LocalAiChatModel.builder()
            .baseUrl(getModelUrl())
            .modelName(chatModel.getModelName())
            .maxRetries(chatModel.getMaxRetries())
            .temperature(chatModel.getTemperature())
            .maxTokens(chatModel.getMaxTokens())
            .timeout(Duration.ofSeconds(chatModel.getTimeout()))
            .topP(chatModel.getTopP())
            .build();
    }

    /**
     * Builds a streaming LocalAI chat model from the user's chat settings.
     */
    protected StreamingChatLanguageModel createLocalAiStreamingChatModel(@NotNull ChatModel chatModel) {
        return LocalAiStreamingChatModel.builder()
            .baseUrl(getModelUrl())
            .modelName(chatModel.getModelName())
            .temperature(chatModel.getTemperature())
            .topP(chatModel.getTopP())
            .timeout(Duration.ofSeconds(chatModel.getTimeout()))
            .build();
    }

    /**
     * Returns the provider's models, fetching and converting them in
     * parallel on the first call and serving the cached list afterwards.
     * <p>
     * On a general fetch failure an empty list is cached (and a one-time
     * notification shown); call {@link #resetModels()} to force a retry.
     *
     * @return the (possibly empty) list of available language models
     */
    @Override
    public List<LanguageModel> getModels() {
        if (cachedModels != null) {
            return cachedModels;
        }
        List<LanguageModel> modelNames = new ArrayList<>();
        List<CompletableFuture<Void>> futures = new ArrayList<>();
        try {
            Object[] models = fetchModels();
            for (Object model : models) {
                CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
                    try {
                        LanguageModel languageModel = buildLanguageModel(model);
                        // ArrayList is not thread-safe; guard concurrent adds.
                        synchronized (modelNames) {
                            modelNames.add(languageModel);
                        }
                    } catch (IOException e) {
                        handleModelFetchError(model, e);
                    }
                }, executorService);
                futures.add(future);
            }
            // Wait for all per-model conversions before publishing the cache.
            CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
            cachedModels = modelNames;
        } catch (IOException e) {
            handleGeneralFetchError(e);
            cachedModels = List.of();
        }
        return cachedModels;
    }

    /**
     * @return the raw model descriptors reported by the local server.
     * @throws IOException if the server cannot be reached
     */
    protected abstract Object[] fetchModels() throws IOException;

    /**
     * Converts one raw model descriptor into a {@link LanguageModel}.
     *
     * @throws IOException if resolving model details fails
     */
    protected abstract LanguageModel buildLanguageModel(Object model) throws IOException;

    /**
     * Notifies the user that details for a single model could not be fetched.
     */
    protected void handleModelFetchError(Object model, @NotNull IOException e) {
        NotificationUtil.sendNotification(ProjectManager.getInstance().getDefaultProject(), "Error fetching model details: " + e.getMessage());
    }

    /**
     * Notifies the user (once per session, across all local factories)
     * that the model listing could not be fetched at all.
     */
    protected void handleGeneralFetchError(IOException e) {
        if (!warningShown) {
            NotificationUtil.sendNotification(ProjectManager.getInstance().getDefaultProject(), "Error fetching models: " + e.getMessage());
            warningShown = true;
        }
    }

    /**
     * Clears the model cache so the next {@link #getModels()} call
     * re-queries the local server.
     */
    @Override
    public void resetModels() {
        cachedModels = null;
    }
}
Original file line number Diff line number Diff line change
@@ -1,61 +1,56 @@
package com.devoxx.genie.chatmodel.gpt4all;

import com.devoxx.genie.chatmodel.LocalChatModelFactory;
import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.model.LanguageModel;
import com.devoxx.genie.model.enumarations.ModelProvider;
import com.devoxx.genie.model.gpt4all.Model;
import com.devoxx.genie.service.gpt4all.GPT4AllService;
import com.devoxx.genie.ui.settings.DevoxxGenieStateService;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import org.jetbrains.annotations.NotNull;

import java.io.IOException;

/**
 * Chat-model factory for a locally running GPT4All server.
 * <p>
 * Delegates model construction to {@link LocalChatModelFactory}'s shared
 * LocalAI builders and discovers available models through
 * {@link GPT4AllService}.
 * <p>
 * NOTE(review): the scraped diff interleaved the removed implementation
 * (a hard-coded single "GPT4All" model and inline builders, including a
 * stray double semicolon) with the added one; this is the reconstructed
 * post-commit class.
 */
public class GPT4AllChatModelFactory extends LocalChatModelFactory {

    public GPT4AllChatModelFactory() {
        super(ModelProvider.GPT4All);
    }

    @Override
    public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
        return createLocalAiChatModel(chatModel);
    }

    @Override
    public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel chatModel) {
        return createLocalAiStreamingChatModel(chatModel);
    }

    @Override
    protected String getModelUrl() {
        // Endpoint URL is user-configurable in the DevoxxGenie settings.
        return DevoxxGenieStateService.getInstance().getGpt4allModelUrl();
    }

    @Override
    protected Model[] fetchModels() throws IOException {
        return GPT4AllService.getInstance().getModels().toArray(new Model[0]);
    }

    @Override
    protected LanguageModel buildLanguageModel(Object model) {
        Model gpt4AllModel = (Model) model;
        // TODO(review): the context window is not set — the original left a
        // commented-out GPT4AllService lookup unfinished; confirm whether a
        // per-model context size should be resolved here.
        return LanguageModel.builder()
                .provider(modelProvider)
                .modelName(gpt4AllModel.getId())
                .displayName(gpt4AllModel.getId())
                .inputCost(0)
                .outputCost(0)
                .apiKeyUsed(false)
                .build();
    }
}
Loading

0 comments on commit 6ed3c0b

Please sign in to comment.