-
Notifications
You must be signed in to change notification settings - Fork 37
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #393 from devoxx/issue-392
Issue #392 GPT4All regression fixed
- Loading branch information
Showing
16 changed files
with
417 additions
and
274 deletions.
There are no files selected for viewing
32 changes: 32 additions & 0 deletions
32
core/src/main/java/com/devoxx/genie/model/gpt4all/Model.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
package com.devoxx.genie.model.gpt4all; | ||
|
||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
import lombok.Getter; | ||
import lombok.Setter; | ||
|
||
import java.util.List; | ||
|
||
@Getter | ||
@Setter | ||
public class Model { | ||
@JsonProperty("created") | ||
private long created; | ||
|
||
@JsonProperty("id") | ||
private String id; | ||
|
||
@JsonProperty("object") | ||
private String object; | ||
|
||
@JsonProperty("owned_by") | ||
private String ownedBy; | ||
|
||
@JsonProperty("parent") | ||
private String parent; | ||
|
||
@JsonProperty("permissions") | ||
private List<ModelPermission> permissions; | ||
|
||
@JsonProperty("root") | ||
private String root; | ||
} |
46 changes: 46 additions & 0 deletions
46
core/src/main/java/com/devoxx/genie/model/gpt4all/ModelPermission.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
package com.devoxx.genie.model.gpt4all; | ||
|
||
|
||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
import lombok.Getter; | ||
import lombok.Setter; | ||
|
||
@Getter | ||
@Setter | ||
public class ModelPermission { | ||
@JsonProperty("allow_create_engine") | ||
private boolean allowCreateEngine; | ||
|
||
@JsonProperty("allow_fine_tuning") | ||
private boolean allowFineTuning; | ||
|
||
@JsonProperty("allow_logprobs") | ||
private boolean allowLogprobs; | ||
|
||
@JsonProperty("allow_sampling") | ||
private boolean allowSampling; | ||
|
||
@JsonProperty("allow_search_indices") | ||
private boolean allowSearchIndices; | ||
|
||
@JsonProperty("allow_view") | ||
private boolean allowView; | ||
|
||
@JsonProperty("created") | ||
private long created; | ||
|
||
@JsonProperty("group") | ||
private String group; | ||
|
||
@JsonProperty("id") | ||
private String id; | ||
|
||
@JsonProperty("is_blocking") | ||
private boolean isBlocking; | ||
|
||
@JsonProperty("object") | ||
private String object; | ||
|
||
@JsonProperty("organization") | ||
private String organization; | ||
} |
17 changes: 17 additions & 0 deletions
17
core/src/main/java/com/devoxx/genie/model/gpt4all/ResponseDTO.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
package com.devoxx.genie.model.gpt4all; | ||
|
||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
import lombok.Getter; | ||
import lombok.Setter; | ||
|
||
import java.util.List; | ||
|
||
@Getter | ||
@Setter | ||
public class ResponseDTO { | ||
@JsonProperty("data") | ||
private List<Model> data; | ||
|
||
@JsonProperty("object") | ||
private String object; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
113 changes: 113 additions & 0 deletions
113
src/main/java/com/devoxx/genie/chatmodel/LocalChatModelFactory.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
package com.devoxx.genie.chatmodel; | ||
|
||
import com.devoxx.genie.model.ChatModel; | ||
import com.devoxx.genie.model.LanguageModel; | ||
import com.devoxx.genie.model.enumarations.ModelProvider; | ||
import com.devoxx.genie.ui.util.NotificationUtil; | ||
import com.intellij.openapi.project.ProjectManager; | ||
import dev.langchain4j.model.chat.ChatLanguageModel; | ||
import dev.langchain4j.model.chat.StreamingChatLanguageModel; | ||
import dev.langchain4j.model.localai.LocalAiChatModel; | ||
import dev.langchain4j.model.localai.LocalAiStreamingChatModel; | ||
import org.jetbrains.annotations.NotNull; | ||
|
||
import java.io.IOException; | ||
import java.time.Duration; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.concurrent.CompletableFuture; | ||
import java.util.concurrent.ExecutorService; | ||
import java.util.concurrent.Executors; | ||
|
||
public abstract class LocalChatModelFactory implements ChatModelFactory { | ||
|
||
protected final ModelProvider modelProvider; | ||
protected List<LanguageModel> cachedModels = null; | ||
protected static final ExecutorService executorService = Executors.newFixedThreadPool(5); | ||
protected static boolean warningShown = false; | ||
|
||
protected LocalChatModelFactory(ModelProvider modelProvider) { | ||
this.modelProvider = modelProvider; | ||
} | ||
|
||
@Override | ||
public abstract ChatLanguageModel createChatModel(@NotNull ChatModel chatModel); | ||
|
||
@Override | ||
public abstract StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel chatModel); | ||
|
||
protected abstract String getModelUrl(); | ||
|
||
protected ChatLanguageModel createLocalAiChatModel(@NotNull ChatModel chatModel) { | ||
return LocalAiChatModel.builder() | ||
.baseUrl(getModelUrl()) | ||
.modelName(chatModel.getModelName()) | ||
.maxRetries(chatModel.getMaxRetries()) | ||
.temperature(chatModel.getTemperature()) | ||
.maxTokens(chatModel.getMaxTokens()) | ||
.timeout(Duration.ofSeconds(chatModel.getTimeout())) | ||
.topP(chatModel.getTopP()) | ||
.build(); | ||
} | ||
|
||
protected StreamingChatLanguageModel createLocalAiStreamingChatModel(@NotNull ChatModel chatModel) { | ||
return LocalAiStreamingChatModel.builder() | ||
.baseUrl(getModelUrl()) | ||
.modelName(chatModel.getModelName()) | ||
.temperature(chatModel.getTemperature()) | ||
.topP(chatModel.getTopP()) | ||
.timeout(Duration.ofSeconds(chatModel.getTimeout())) | ||
.build(); | ||
} | ||
|
||
@Override | ||
public List<LanguageModel> getModels() { | ||
if (cachedModels != null) { | ||
return cachedModels; | ||
} | ||
List<LanguageModel> modelNames = new ArrayList<>(); | ||
List<CompletableFuture<Void>> futures = new ArrayList<>(); | ||
try { | ||
Object[] models = fetchModels(); | ||
for (Object model : models) { | ||
CompletableFuture<Void> future = CompletableFuture.runAsync(() -> { | ||
try { | ||
LanguageModel languageModel = buildLanguageModel(model); | ||
synchronized (modelNames) { | ||
modelNames.add(languageModel); | ||
} | ||
} catch (IOException e) { | ||
handleModelFetchError(model, e); | ||
} | ||
}, executorService); | ||
futures.add(future); | ||
} | ||
CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); | ||
cachedModels = modelNames; | ||
} catch (IOException e) { | ||
handleGeneralFetchError(e); | ||
cachedModels = List.of(); | ||
} | ||
return cachedModels; | ||
} | ||
|
||
protected abstract Object[] fetchModels() throws IOException; | ||
|
||
protected abstract LanguageModel buildLanguageModel(Object model) throws IOException; | ||
|
||
protected void handleModelFetchError(Object model, @NotNull IOException e) { | ||
NotificationUtil.sendNotification(ProjectManager.getInstance().getDefaultProject(), "Error fetching model details: " + e.getMessage()); | ||
} | ||
|
||
protected void handleGeneralFetchError(IOException e) { | ||
if (!warningShown) { | ||
NotificationUtil.sendNotification(ProjectManager.getInstance().getDefaultProject(), "Error fetching models: " + e.getMessage()); | ||
warningShown = true; | ||
} | ||
} | ||
|
||
@Override | ||
public void resetModels() { | ||
cachedModels = null; | ||
} | ||
} |
69 changes: 32 additions & 37 deletions
69
src/main/java/com/devoxx/genie/chatmodel/gpt4all/GPT4AllChatModelFactory.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,61 +1,56 @@ | ||
package com.devoxx.genie.chatmodel.gpt4all; | ||
|
||
import com.devoxx.genie.chatmodel.ChatModelFactory; | ||
import com.devoxx.genie.chatmodel.LocalChatModelFactory; | ||
import com.devoxx.genie.model.ChatModel; | ||
import com.devoxx.genie.model.LanguageModel; | ||
import com.devoxx.genie.model.enumarations.ModelProvider; | ||
import com.devoxx.genie.model.gpt4all.Model; | ||
import com.devoxx.genie.service.gpt4all.GPT4AllService; | ||
import com.devoxx.genie.ui.settings.DevoxxGenieStateService; | ||
import dev.langchain4j.model.chat.ChatLanguageModel; | ||
import dev.langchain4j.model.chat.StreamingChatLanguageModel; | ||
import dev.langchain4j.model.localai.LocalAiChatModel; | ||
import dev.langchain4j.model.localai.LocalAiStreamingChatModel; | ||
import org.jetbrains.annotations.NotNull; | ||
|
||
import java.time.Duration; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.io.IOException; | ||
|
||
public class GPT4AllChatModelFactory implements ChatModelFactory { | ||
public class GPT4AllChatModelFactory extends LocalChatModelFactory { | ||
|
||
private final ModelProvider MODEL_PROVIDER = ModelProvider.GPT4All;; | ||
public GPT4AllChatModelFactory() { | ||
super(ModelProvider.GPT4All); | ||
} | ||
|
||
@Override | ||
public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) { | ||
return LocalAiChatModel.builder() | ||
.baseUrl(DevoxxGenieStateService.getInstance().getGpt4allModelUrl()) | ||
.modelName(TEST_MODEL) | ||
.maxRetries(chatModel.getMaxRetries()) | ||
.maxTokens(chatModel.getMaxTokens()) | ||
.temperature(chatModel.getTemperature()) | ||
.timeout(Duration.ofSeconds(chatModel.getTimeout())) | ||
.topP(chatModel.getTopP()) | ||
.build(); | ||
return createLocalAiChatModel(chatModel); | ||
} | ||
|
||
@Override | ||
public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel chatModel) { | ||
return LocalAiStreamingChatModel.builder() | ||
.baseUrl(DevoxxGenieStateService.getInstance().getGpt4allModelUrl()) | ||
.modelName(TEST_MODEL) | ||
.temperature(chatModel.getTemperature()) | ||
.topP(chatModel.getTopP()) | ||
.timeout(Duration.ofSeconds(chatModel.getTimeout())) | ||
.build(); | ||
return createLocalAiStreamingChatModel(chatModel); | ||
} | ||
|
||
@Override | ||
protected String getModelUrl() { | ||
return DevoxxGenieStateService.getInstance().getGpt4allModelUrl(); | ||
} | ||
|
||
@Override | ||
protected Model[] fetchModels() throws IOException { | ||
return GPT4AllService.getInstance().getModels().toArray(new Model[0]); | ||
} | ||
|
||
@Override | ||
public List<LanguageModel> getModels() { | ||
LanguageModel lmStudio = LanguageModel.builder() | ||
.provider(MODEL_PROVIDER) | ||
.modelName("GPT4All") | ||
.inputCost(0) | ||
.outputCost(0) | ||
.contextWindow(8000) | ||
.apiKeyUsed(false) | ||
.build(); | ||
|
||
List<LanguageModel> modelNames = new ArrayList<>(); | ||
modelNames.add(lmStudio); | ||
return modelNames; | ||
protected LanguageModel buildLanguageModel(Object model) { | ||
Model gpt4AllModel = (Model) model; | ||
// int contextWindow = GPT4AllService.getInstance() | ||
return LanguageModel.builder() | ||
.provider(modelProvider) | ||
.modelName(gpt4AllModel.getId()) | ||
.displayName(gpt4AllModel.getId()) | ||
.inputCost(0) | ||
.outputCost(0) | ||
// .contextWindow(contextWindow) | ||
.apiKeyUsed(false) | ||
.build(); | ||
} | ||
} |
Oops, something went wrong.