Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature #175 Select which LLM providers you want to enable/disable #358

Merged
merged 1 commit into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions core/src/main/java/com/devoxx/genie/model/LanguageModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import com.devoxx.genie.model.enumarations.ModelProvider;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.jetbrains.annotations.NotNull;

import java.util.Comparator;
Expand All @@ -21,7 +20,7 @@ public class LanguageModel implements Comparable<LanguageModel> {
private int contextWindow;

public LanguageModel() {
this(ModelProvider.OpenAI, "", "", false, 0.0, 0.0, 0);
this(ModelProvider.OPENAI, "", "", false, 0.0, 0.0, 0);
}

public LanguageModel(ModelProvider provider,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,28 @@

@Getter
public enum ModelProvider {
Ollama("Ollama", Type.LOCAL),
LMStudio("LMStudio", Type.LOCAL),
GPT4All("GPT4All", Type.LOCAL),
Jan("Jan", Type.LOCAL),
OpenAI("OpenAI", Type.CLOUD),
Anthropic("Anthropic", Type.CLOUD),
Mistral("Mistral", Type.CLOUD),
Groq("Groq", Type.CLOUD),
DeepInfra("DeepInfra", Type.CLOUD),
Google("Google", Type.CLOUD),
Exo("Exo (Experimental)", Type.LOCAL),
LLaMA("LLaMA.c++", Type.LOCAL),
OpenRouter("OpenRouter", Type.CLOUD),
DeepSeek("DeepSeek", Type.CLOUD),
Jlama("Jlama (Experimental /w REST API)", Type.LOCAL),
AzureOpenAI("AzureOpenAI", Type.OPTIONAL);
OPENAI("OpenAI", Type.CLOUD),
ANTHROPIC("Anthropic", Type.CLOUD),
MISTRAL("Mistral", Type.CLOUD),
GROQ("Groq", Type.CLOUD),
DEEP_INFRA("DeepInfra", Type.CLOUD),
GOOGLE("Google", Type.CLOUD),
LLAMA("LLaMA.c++", Type.LOCAL),
OPEN_ROUTER("OpenRouter", Type.CLOUD),
DEEP_SEEK("DeepSeek", Type.CLOUD),
AZURE_OPEN_AI("AzureOpenAI", Type.CLOUD),
OLLAMA("Ollama", Type.LOCAL),
LMSTUDIO("LMStudio", Type.LOCAL),
GPT_4_ALL("GPT4All", Type.LOCAL),
JAN("Jan", Type.LOCAL),
JLAMA("Jlama (Experimental /w REST API)", Type.LOCAL),
EXO("Exo (Experimental)", Type.LOCAL),
CUSTOM_OPEN_AI("CustomOpenAI", Type.OPTIONAL);

public enum Type {
LOCAL, // Local Providers
CLOUD, // Cloud Providers
OPTIONAL // Optional Providers(Need to be enabled from settings, due to inconvenient setup)
OPTIONAL // Optional Providers (Need to be enabled from settings, due to inconvenient setup)
}

private final String name;
Expand All @@ -39,6 +40,7 @@ public enum Type {
this.type = type;
}

@Override
public String toString() {
return name;
}
Expand Down
14 changes: 7 additions & 7 deletions src/main/java/com/devoxx/genie/chatmodel/ChatModelProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
@Setter
public class ChatModelProvider {

private static final ModelProvider DEFAULT_PROVIDER = ModelProvider.OpenAI; // Choose an appropriate default
private static final ModelProvider DEFAULT_PROVIDER = ModelProvider.OPENAI; // Choose an appropriate default

public ChatLanguageModel getChatLanguageModel(@NotNull ChatMessageContext chatMessageContext) {
ChatModel chatModel = initChatModel(chatMessageContext);
Expand Down Expand Up @@ -72,22 +72,22 @@ private void setLocalBaseUrl(@NotNull LanguageModel languageModel,
DevoxxGenieSettingsService stateService) {
// Set base URL for local providers
switch (languageModel.getProvider()) {
case LMStudio:
case LMSTUDIO:
chatModel.setBaseUrl(stateService.getLmstudioModelUrl());
break;
case Ollama:
case OLLAMA:
chatModel.setBaseUrl(stateService.getOllamaModelUrl());
break;
case GPT4All:
case GPT_4_ALL:
chatModel.setBaseUrl(stateService.getGpt4allModelUrl());
break;
case Exo:
case EXO:
chatModel.setBaseUrl(stateService.getExoModelUrl());
break;
case LLaMA:
case LLAMA:
chatModel.setBaseUrl(stateService.getLlamaCPPUrl());
break;
case Jlama:
case JLAMA:
chatModel.setBaseUrl(stateService.getJlamaUrl());
break;
// Add other local providers as needed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.Anthropic);
return getModels(ModelProvider.ANTHROPIC);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public String getApiKey() {
@Override
public List<LanguageModel> getModels() {
return List.of(LanguageModel.builder()
.provider(ModelProvider.AzureOpenAI)
.provider(ModelProvider.AZURE_OPEN_AI)
.modelName(DevoxxGenieStateService.getInstance().getAzureOpenAIDeployment())
.displayName(DevoxxGenieStateService.getInstance().getAzureOpenAIDeployment())
.inputCost(0.0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.DeepInfra);
return getModels(ModelProvider.DEEP_INFRA);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.DeepSeek);
return getModels(ModelProvider.DEEP_SEEK);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public List<LanguageModel> getModels() {
.modelName("llama-3.1-405b")
.displayName("Llama 3.1 405B")
.apiKeyUsed(false)
.provider(ModelProvider.Exo)
.provider(ModelProvider.EXO)
.outputCost(0)
.inputCost(0)
.contextWindow(131_000)
Expand All @@ -65,7 +65,7 @@ public List<LanguageModel> getModels() {
.modelName("llama-3.1-8b")
.displayName("Llama 3.1 8B")
.apiKeyUsed(false)
.provider(ModelProvider.Exo)
.provider(ModelProvider.EXO)
.outputCost(0)
.inputCost(0)
.contextWindow(8_000)
Expand All @@ -75,7 +75,7 @@ public List<LanguageModel> getModels() {
.modelName("llama-3.1-70b")
.displayName("Llama 3.1 70B")
.apiKeyUsed(false)
.provider(ModelProvider.Exo)
.provider(ModelProvider.EXO)
.outputCost(0)
.inputCost(0)
.contextWindow(131_000)
Expand All @@ -85,7 +85,7 @@ public List<LanguageModel> getModels() {
.modelName("llama-3-8b")
.displayName("Llama 3 8B")
.apiKeyUsed(false)
.provider(ModelProvider.Exo)
.provider(ModelProvider.EXO)
.outputCost(0)
.inputCost(0)
.contextWindow(8_000)
Expand All @@ -96,7 +96,7 @@ public List<LanguageModel> getModels() {
.modelName("mistral-nemo")
.displayName("Mistral Nemo")
.apiKeyUsed(false)
.provider(ModelProvider.Exo)
.provider(ModelProvider.EXO)
.outputCost(0)
.inputCost(0)
.contextWindow(8_000)
Expand All @@ -107,7 +107,7 @@ public List<LanguageModel> getModels() {
.modelName("mistral-large")
.displayName("Mistral Large")
.apiKeyUsed(false)
.provider(ModelProvider.Exo)
.provider(ModelProvider.EXO)
.outputCost(0)
.inputCost(0)
.contextWindow(8_000)
Expand All @@ -118,7 +118,7 @@ public List<LanguageModel> getModels() {
.modelName("deepseek-coder-v2-lite")
.displayName("Deepseek Coder V2 Lite")
.apiKeyUsed(false)
.provider(ModelProvider.Exo)
.provider(ModelProvider.EXO)
.outputCost(0)
.inputCost(0)
.contextWindow(8_000)
Expand All @@ -129,7 +129,7 @@ public List<LanguageModel> getModels() {
.modelName("llava-1.5-7b-hf")
.displayName("Llava 1.5 7B HF")
.apiKeyUsed(false)
.provider(ModelProvider.Exo)
.provider(ModelProvider.EXO)
.outputCost(0)
.inputCost(0)
.contextWindow(8_000)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.Google);
return getModels(ModelProvider.GOOGLE);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel ch
@Override
public List<LanguageModel> getModels() {
LanguageModel lmStudio = LanguageModel.builder()
.provider(ModelProvider.GPT4All)
.provider(ModelProvider.GPT_4_ALL)
.modelName("GPT4All")
.inputCost(0)
.outputCost(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.Groq);
return getModels(ModelProvider.GROQ);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public List<LanguageModel> getModels() {
for (Data model : models) {
CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
LanguageModel languageModel = LanguageModel.builder()
.provider(ModelProvider.Jan)
.provider(ModelProvider.JAN)
.modelName(model.getId())
.displayName(model.getName())
.inputCost(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public List<LanguageModel> getModels() {
for (LMStudioModelEntryDTO model : lmStudioModels) {
CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
LanguageModel languageModel = LanguageModel.builder()
.provider(ModelProvider.LMStudio)
.provider(ModelProvider.LMSTUDIO)
.modelName(model.getId())
.displayName(model.getId())
.inputCost(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.Mistral);
return getModels(ModelProvider.MISTRAL);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public List<LanguageModel> getModels() {
try {
int contextWindow = OllamaApiService.getModelContext(model.getName());
LanguageModel languageModel = LanguageModel.builder()
.provider(ModelProvider.Ollama)
.provider(ModelProvider.OLLAMA)
.modelName(model.getName())
.displayName(model.getName())
.inputCost(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.OpenAI);
return getModels(ModelProvider.OPENAI);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public List<LanguageModel> getModels() {
double outputCost = convertAndScalePrice(model.getPricing().getCompletion());

LanguageModel languageModel = LanguageModel.builder()
.provider(ModelProvider.OpenRouter)
.provider(ModelProvider.OPEN_ROUTER)
.modelName(model.getId())
.displayName(model.getName())
.inputCost(inputCost)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,10 @@ private boolean validateAndPreparePrompt(String actionCommand,
private LanguageModel createDefaultLanguageModel(@NotNull DevoxxGenieSettingsService stateService) {
ModelProvider selectedProvider = (ModelProvider) modelProviderComboBox.getSelectedItem();
if (selectedProvider != null &&
(selectedProvider.equals(ModelProvider.LMStudio) ||
selectedProvider.equals(ModelProvider.GPT4All) ||
selectedProvider.equals(ModelProvider.Jlama) ||
selectedProvider.equals(ModelProvider.LLaMA))) {
(selectedProvider.equals(ModelProvider.LMSTUDIO) ||
selectedProvider.equals(ModelProvider.GPT_4_ALL) ||
selectedProvider.equals(ModelProvider.JLAMA) ||
selectedProvider.equals(ModelProvider.LLAMA))) {
return LanguageModel.builder()
.provider(selectedProvider)
.apiKeyUsed(false)
Expand All @@ -169,7 +169,7 @@ private LanguageModel createDefaultLanguageModel(@NotNull DevoxxGenieSettingsSer
} else {
String modelName = stateService.getSelectedLanguageModel(project.getLocationHash());
return LanguageModel.builder()
.provider(selectedProvider != null ? selectedProvider : ModelProvider.OpenAI)
.provider(selectedProvider != null ? selectedProvider : ModelProvider.OPENAI)
.modelName(modelName)
.apiKeyUsed(false)
.inputCost(0)
Expand Down
Loading
Loading