Skip to content

Commit

Permalink
Feature #175 Select which LLM providers you want to enable/disable
Browse files Browse the repository at this point in the history
  • Loading branch information
stephanj committed Dec 10, 2024
1 parent e9adb07 commit 74bb006
Show file tree
Hide file tree
Showing 34 changed files with 446 additions and 244 deletions.
3 changes: 1 addition & 2 deletions core/src/main/java/com/devoxx/genie/model/LanguageModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import com.devoxx.genie.model.enumarations.ModelProvider;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.jetbrains.annotations.NotNull;

import java.util.Comparator;
Expand All @@ -21,7 +20,7 @@ public class LanguageModel implements Comparable<LanguageModel> {
private int contextWindow;

public LanguageModel() {
this(ModelProvider.OpenAI, "", "", false, 0.0, 0.0, 0);
this(ModelProvider.OPENAI, "", "", false, 0.0, 0.0, 0);
}

public LanguageModel(ModelProvider provider,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,28 @@

@Getter
public enum ModelProvider {
Ollama("Ollama", Type.LOCAL),
LMStudio("LMStudio", Type.LOCAL),
GPT4All("GPT4All", Type.LOCAL),
Jan("Jan", Type.LOCAL),
OpenAI("OpenAI", Type.CLOUD),
Anthropic("Anthropic", Type.CLOUD),
Mistral("Mistral", Type.CLOUD),
Groq("Groq", Type.CLOUD),
DeepInfra("DeepInfra", Type.CLOUD),
Google("Google", Type.CLOUD),
Exo("Exo (Experimental)", Type.LOCAL),
LLaMA("LLaMA.c++", Type.LOCAL),
OpenRouter("OpenRouter", Type.CLOUD),
DeepSeek("DeepSeek", Type.CLOUD),
Jlama("Jlama (Experimental /w REST API)", Type.LOCAL),
AzureOpenAI("AzureOpenAI", Type.OPTIONAL);
OPENAI("OpenAI", Type.CLOUD),
ANTHROPIC("Anthropic", Type.CLOUD),
MISTRAL("Mistral", Type.CLOUD),
GROQ("Groq", Type.CLOUD),
DEEP_INFRA("DeepInfra", Type.CLOUD),
GOOGLE("Google", Type.CLOUD),
LLAMA("LLaMA.c++", Type.LOCAL),
OPEN_ROUTER("OpenRouter", Type.CLOUD),
DEEP_SEEK("DeepSeek", Type.CLOUD),
AZURE_OPEN_AI("AzureOpenAI", Type.CLOUD),
OLLAMA("Ollama", Type.LOCAL),
LMSTUDIO("LMStudio", Type.LOCAL),
GPT_4_ALL("GPT4All", Type.LOCAL),
JAN("Jan", Type.LOCAL),
JLAMA("Jlama (Experimental /w REST API)", Type.LOCAL),
EXO("Exo (Experimental)", Type.LOCAL),
CUSTOM_OPEN_AI("CustomOpenAI", Type.OPTIONAL);

public enum Type {
LOCAL, // Local Providers
CLOUD, // Cloud Providers
OPTIONAL // Optional Providers(Need to be enabled from settings, due to inconvenient setup)
OPTIONAL // Optional Providers (Need to be enabled from settings, due to inconvenient setup)
}

private final String name;
Expand All @@ -39,6 +40,7 @@ public enum Type {
this.type = type;
}

@Override
public String toString() {
return name;
}
Expand Down
14 changes: 7 additions & 7 deletions src/main/java/com/devoxx/genie/chatmodel/ChatModelProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
@Setter
public class ChatModelProvider {

private static final ModelProvider DEFAULT_PROVIDER = ModelProvider.OpenAI; // Choose an appropriate default
private static final ModelProvider DEFAULT_PROVIDER = ModelProvider.OPENAI; // Choose an appropriate default

public ChatLanguageModel getChatLanguageModel(@NotNull ChatMessageContext chatMessageContext) {
ChatModel chatModel = initChatModel(chatMessageContext);
Expand Down Expand Up @@ -72,22 +72,22 @@ private void setLocalBaseUrl(@NotNull LanguageModel languageModel,
DevoxxGenieSettingsService stateService) {
// Set base URL for local providers
switch (languageModel.getProvider()) {
case LMStudio:
case LMSTUDIO:
chatModel.setBaseUrl(stateService.getLmstudioModelUrl());
break;
case Ollama:
case OLLAMA:
chatModel.setBaseUrl(stateService.getOllamaModelUrl());
break;
case GPT4All:
case GPT_4_ALL:
chatModel.setBaseUrl(stateService.getGpt4allModelUrl());
break;
case Exo:
case EXO:
chatModel.setBaseUrl(stateService.getExoModelUrl());
break;
case LLaMA:
case LLAMA:
chatModel.setBaseUrl(stateService.getLlamaCPPUrl());
break;
case Jlama:
case JLAMA:
chatModel.setBaseUrl(stateService.getJlamaUrl());
break;
// Add other local providers as needed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.Anthropic);
return getModels(ModelProvider.ANTHROPIC);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public String getApiKey() {
@Override
public List<LanguageModel> getModels() {
return List.of(LanguageModel.builder()
.provider(ModelProvider.AzureOpenAI)
.provider(ModelProvider.AZURE_OPEN_AI)
.modelName(DevoxxGenieStateService.getInstance().getAzureOpenAIDeployment())
.displayName(DevoxxGenieStateService.getInstance().getAzureOpenAIDeployment())
.inputCost(0.0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.DeepInfra);
return getModels(ModelProvider.DEEP_INFRA);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.DeepSeek);
return getModels(ModelProvider.DEEP_SEEK);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public List<LanguageModel> getModels() {
.modelName("llama-3.1-405b")
.displayName("Llama 3.1 405B")
.apiKeyUsed(false)
.provider(ModelProvider.Exo)
.provider(ModelProvider.EXO)
.outputCost(0)
.inputCost(0)
.contextWindow(131_000)
Expand All @@ -65,7 +65,7 @@ public List<LanguageModel> getModels() {
.modelName("llama-3.1-8b")
.displayName("Llama 3.1 8B")
.apiKeyUsed(false)
.provider(ModelProvider.Exo)
.provider(ModelProvider.EXO)
.outputCost(0)
.inputCost(0)
.contextWindow(8_000)
Expand All @@ -75,7 +75,7 @@ public List<LanguageModel> getModels() {
.modelName("llama-3.1-70b")
.displayName("Llama 3.1 70B")
.apiKeyUsed(false)
.provider(ModelProvider.Exo)
.provider(ModelProvider.EXO)
.outputCost(0)
.inputCost(0)
.contextWindow(131_000)
Expand All @@ -85,7 +85,7 @@ public List<LanguageModel> getModels() {
.modelName("llama-3-8b")
.displayName("Llama 3 8B")
.apiKeyUsed(false)
.provider(ModelProvider.Exo)
.provider(ModelProvider.EXO)
.outputCost(0)
.inputCost(0)
.contextWindow(8_000)
Expand All @@ -96,7 +96,7 @@ public List<LanguageModel> getModels() {
.modelName("mistral-nemo")
.displayName("Mistral Nemo")
.apiKeyUsed(false)
.provider(ModelProvider.Exo)
.provider(ModelProvider.EXO)
.outputCost(0)
.inputCost(0)
.contextWindow(8_000)
Expand All @@ -107,7 +107,7 @@ public List<LanguageModel> getModels() {
.modelName("mistral-large")
.displayName("Mistral Large")
.apiKeyUsed(false)
.provider(ModelProvider.Exo)
.provider(ModelProvider.EXO)
.outputCost(0)
.inputCost(0)
.contextWindow(8_000)
Expand All @@ -118,7 +118,7 @@ public List<LanguageModel> getModels() {
.modelName("deepseek-coder-v2-lite")
.displayName("Deepseek Coder V2 Lite")
.apiKeyUsed(false)
.provider(ModelProvider.Exo)
.provider(ModelProvider.EXO)
.outputCost(0)
.inputCost(0)
.contextWindow(8_000)
Expand All @@ -129,7 +129,7 @@ public List<LanguageModel> getModels() {
.modelName("llava-1.5-7b-hf")
.displayName("Llava 1.5 7B HF")
.apiKeyUsed(false)
.provider(ModelProvider.Exo)
.provider(ModelProvider.EXO)
.outputCost(0)
.inputCost(0)
.contextWindow(8_000)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.Google);
return getModels(ModelProvider.GOOGLE);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel ch
@Override
public List<LanguageModel> getModels() {
LanguageModel lmStudio = LanguageModel.builder()
.provider(ModelProvider.GPT4All)
.provider(ModelProvider.GPT_4_ALL)
.modelName("GPT4All")
.inputCost(0)
.outputCost(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.Groq);
return getModels(ModelProvider.GROQ);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public List<LanguageModel> getModels() {
for (Data model : models) {
CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
LanguageModel languageModel = LanguageModel.builder()
.provider(ModelProvider.Jan)
.provider(ModelProvider.JAN)
.modelName(model.getId())
.displayName(model.getName())
.inputCost(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public List<LanguageModel> getModels() {
for (LMStudioModelEntryDTO model : lmStudioModels) {
CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
LanguageModel languageModel = LanguageModel.builder()
.provider(ModelProvider.LMStudio)
.provider(ModelProvider.LMSTUDIO)
.modelName(model.getId())
.displayName(model.getId())
.inputCost(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.Mistral);
return getModels(ModelProvider.MISTRAL);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public List<LanguageModel> getModels() {
try {
int contextWindow = OllamaApiService.getModelContext(model.getName());
LanguageModel languageModel = LanguageModel.builder()
.provider(ModelProvider.Ollama)
.provider(ModelProvider.OLLAMA)
.modelName(model.getName())
.displayName(model.getName())
.inputCost(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.OpenAI);
return getModels(ModelProvider.OPENAI);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public List<LanguageModel> getModels() {
double outputCost = convertAndScalePrice(model.getPricing().getCompletion());

LanguageModel languageModel = LanguageModel.builder()
.provider(ModelProvider.OpenRouter)
.provider(ModelProvider.OPEN_ROUTER)
.modelName(model.getId())
.displayName(model.getName())
.inputCost(inputCost)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,10 @@ private boolean validateAndPreparePrompt(String actionCommand,
private LanguageModel createDefaultLanguageModel(@NotNull DevoxxGenieSettingsService stateService) {
ModelProvider selectedProvider = (ModelProvider) modelProviderComboBox.getSelectedItem();
if (selectedProvider != null &&
(selectedProvider.equals(ModelProvider.LMStudio) ||
selectedProvider.equals(ModelProvider.GPT4All) ||
selectedProvider.equals(ModelProvider.Jlama) ||
selectedProvider.equals(ModelProvider.LLaMA))) {
(selectedProvider.equals(ModelProvider.LMSTUDIO) ||
selectedProvider.equals(ModelProvider.GPT_4_ALL) ||
selectedProvider.equals(ModelProvider.JLAMA) ||
selectedProvider.equals(ModelProvider.LLAMA))) {
return LanguageModel.builder()
.provider(selectedProvider)
.apiKeyUsed(false)
Expand All @@ -169,7 +169,7 @@ private LanguageModel createDefaultLanguageModel(@NotNull DevoxxGenieSettingsSer
} else {
String modelName = stateService.getSelectedLanguageModel(project.getLocationHash());
return LanguageModel.builder()
.provider(selectedProvider != null ? selectedProvider : ModelProvider.OpenAI)
.provider(selectedProvider != null ? selectedProvider : ModelProvider.OPENAI)
.modelName(modelName)
.apiKeyUsed(false)
.inputCost(0)
Expand Down
Loading

0 comments on commit 74bb006

Please sign in to comment.