Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #359 #360

Merged
merged 1 commit into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public class LanguageModel implements Comparable<LanguageModel> {
private int contextWindow;

public LanguageModel() {
this(ModelProvider.OPENAI, "", "", false, 0.0, 0.0, 0);
this(ModelProvider.OpenAI, "", "", false, 0.0, 0.0, 0);
}

public LanguageModel(ModelProvider provider,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,28 @@

@Getter
public enum ModelProvider {
OPENAI("OpenAI", Type.CLOUD),
ANTHROPIC("Anthropic", Type.CLOUD),
MISTRAL("Mistral", Type.CLOUD),
GROQ("Groq", Type.CLOUD),
DEEP_INFRA("DeepInfra", Type.CLOUD),
GOOGLE("Google", Type.CLOUD),
LLAMA("LLaMA.c++", Type.LOCAL),
OPEN_ROUTER("OpenRouter", Type.CLOUD),
DEEP_SEEK("DeepSeek", Type.CLOUD),
AZURE_OPEN_AI("AzureOpenAI", Type.CLOUD),
OLLAMA("Ollama", Type.LOCAL),
LMSTUDIO("LMStudio", Type.LOCAL),
GPT_4_ALL("GPT4All", Type.LOCAL),
JAN("Jan", Type.LOCAL),
JLAMA("Jlama (Experimental /w REST API)", Type.LOCAL),
EXO("Exo (Experimental)", Type.LOCAL),
CUSTOM_OPEN_AI("CustomOpenAI", Type.OPTIONAL);
Ollama("Ollama", Type.LOCAL),
LMStudio("LMStudio", Type.LOCAL),
GPT4All("GPT4All", Type.LOCAL),
Jan("Jan", Type.LOCAL),
OpenAI("OpenAI", Type.CLOUD),
Anthropic("Anthropic", Type.CLOUD),
Mistral("Mistral", Type.CLOUD),
Groq("Groq", Type.CLOUD),
DeepInfra("DeepInfra", Type.CLOUD),
Google("Google", Type.CLOUD),
Exo("Exo (Experimental)", Type.LOCAL),
LLaMA("LLaMA.c++", Type.LOCAL),
OpenRouter("OpenRouter", Type.CLOUD),
DeepSeek("DeepSeek", Type.CLOUD),
Jlama("Jlama (Experimental /w REST API)", Type.LOCAL),
AzureOpenAI("AzureOpenAI", Type.OPTIONAL),
CustomOpenAI("CustomOpenAI", Type.OPTIONAL);

public enum Type {
LOCAL, // Local Providers
CLOUD, // Cloud Providers
OPTIONAL // Optional Providers (Need to be enabled from settings, due to inconvenient setup)
OPTIONAL // Optional Providers(Need to be enabled from settings, due to inconvenient setup)
}

private final String name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@

public class ChatModelFactoryProvider {

private ChatModelFactoryProvider() {
throw new IllegalStateException("Utility class");
}

private static final Map<String, ChatModelFactory> factoryCache = new ConcurrentHashMap<>();

public static @NotNull Optional<ChatModelFactory> getFactoryByProvider(@NotNull String modelProvider) {
Expand Down
14 changes: 7 additions & 7 deletions src/main/java/com/devoxx/genie/chatmodel/ChatModelProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
@Setter
public class ChatModelProvider {

private static final ModelProvider DEFAULT_PROVIDER = ModelProvider.OPENAI; // Choose an appropriate default
private static final ModelProvider DEFAULT_PROVIDER = ModelProvider.OpenAI; // Choose an appropriate default

public ChatLanguageModel getChatLanguageModel(@NotNull ChatMessageContext chatMessageContext) {
ChatModel chatModel = initChatModel(chatMessageContext);
Expand Down Expand Up @@ -72,22 +72,22 @@ private void setLocalBaseUrl(@NotNull LanguageModel languageModel,
DevoxxGenieSettingsService stateService) {
// Set base URL for local providers
switch (languageModel.getProvider()) {
case LMSTUDIO:
case LMStudio:
chatModel.setBaseUrl(stateService.getLmstudioModelUrl());
break;
case OLLAMA:
case Ollama:
chatModel.setBaseUrl(stateService.getOllamaModelUrl());
break;
case GPT_4_ALL:
case GPT4All:
chatModel.setBaseUrl(stateService.getGpt4allModelUrl());
break;
case EXO:
case Exo:
chatModel.setBaseUrl(stateService.getExoModelUrl());
break;
case LLAMA:
case LLaMA:
chatModel.setBaseUrl(stateService.getLlamaCPPUrl());
break;
case JLAMA:
case Jlama:
chatModel.setBaseUrl(stateService.getJlamaUrl());
break;
// Add other local providers as needed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.ANTHROPIC);
return getModels(ModelProvider.Anthropic);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public String getApiKey() {
@Override
public List<LanguageModel> getModels() {
return List.of(LanguageModel.builder()
.provider(ModelProvider.AZURE_OPEN_AI)
.provider(ModelProvider.AzureOpenAI)
.modelName(DevoxxGenieStateService.getInstance().getAzureOpenAIDeployment())
.displayName(DevoxxGenieStateService.getInstance().getAzureOpenAIDeployment())
.inputCost(0.0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.DEEP_INFRA);
return getModels(ModelProvider.DeepInfra);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.DEEP_SEEK);
return getModels(ModelProvider.DeepSeek);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public List<LanguageModel> getModels() {
.modelName("llama-3.1-405b")
.displayName("Llama 3.1 405B")
.apiKeyUsed(false)
.provider(ModelProvider.EXO)
.provider(ModelProvider.Exo)
.outputCost(0)
.inputCost(0)
.contextWindow(131_000)
Expand All @@ -65,7 +65,7 @@ public List<LanguageModel> getModels() {
.modelName("llama-3.1-8b")
.displayName("Llama 3.1 8B")
.apiKeyUsed(false)
.provider(ModelProvider.EXO)
.provider(ModelProvider.Exo)
.outputCost(0)
.inputCost(0)
.contextWindow(8_000)
Expand All @@ -75,7 +75,7 @@ public List<LanguageModel> getModels() {
.modelName("llama-3.1-70b")
.displayName("Llama 3.1 70B")
.apiKeyUsed(false)
.provider(ModelProvider.EXO)
.provider(ModelProvider.Exo)
.outputCost(0)
.inputCost(0)
.contextWindow(131_000)
Expand All @@ -85,7 +85,7 @@ public List<LanguageModel> getModels() {
.modelName("llama-3-8b")
.displayName("Llama 3 8B")
.apiKeyUsed(false)
.provider(ModelProvider.EXO)
.provider(ModelProvider.Exo)
.outputCost(0)
.inputCost(0)
.contextWindow(8_000)
Expand All @@ -96,7 +96,7 @@ public List<LanguageModel> getModels() {
.modelName("mistral-nemo")
.displayName("Mistral Nemo")
.apiKeyUsed(false)
.provider(ModelProvider.EXO)
.provider(ModelProvider.Exo)
.outputCost(0)
.inputCost(0)
.contextWindow(8_000)
Expand All @@ -107,7 +107,7 @@ public List<LanguageModel> getModels() {
.modelName("mistral-large")
.displayName("Mistral Large")
.apiKeyUsed(false)
.provider(ModelProvider.EXO)
.provider(ModelProvider.Exo)
.outputCost(0)
.inputCost(0)
.contextWindow(8_000)
Expand All @@ -118,7 +118,7 @@ public List<LanguageModel> getModels() {
.modelName("deepseek-coder-v2-lite")
.displayName("Deepseek Coder V2 Lite")
.apiKeyUsed(false)
.provider(ModelProvider.EXO)
.provider(ModelProvider.Exo)
.outputCost(0)
.inputCost(0)
.contextWindow(8_000)
Expand All @@ -129,7 +129,7 @@ public List<LanguageModel> getModels() {
.modelName("llava-1.5-7b-hf")
.displayName("Llava 1.5 7B HF")
.apiKeyUsed(false)
.provider(ModelProvider.EXO)
.provider(ModelProvider.Exo)
.outputCost(0)
.inputCost(0)
.contextWindow(8_000)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.GOOGLE);
return getModels(ModelProvider.Google);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
.build();
}

@Override
public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel chatModel) {
return LocalAiStreamingChatModel.builder()
.baseUrl(DevoxxGenieStateService.getInstance().getGpt4allModelUrl())
Expand All @@ -43,7 +44,7 @@ public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel ch
@Override
public List<LanguageModel> getModels() {
LanguageModel lmStudio = LanguageModel.builder()
.provider(ModelProvider.GPT_4_ALL)
.provider(ModelProvider.GPT4All)
.modelName("GPT4All")
.inputCost(0)
.outputCost(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.GROQ);
return getModels(ModelProvider.Groq);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public List<LanguageModel> getModels() {
for (Data model : models) {
CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
LanguageModel languageModel = LanguageModel.builder()
.provider(ModelProvider.JAN)
.provider(ModelProvider.Jan)
.modelName(model.getId())
.displayName(model.getName())
.inputCost(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public List<LanguageModel> getModels() {
for (LMStudioModelEntryDTO model : lmStudioModels) {
CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
LanguageModel languageModel = LanguageModel.builder()
.provider(ModelProvider.LMSTUDIO)
.provider(ModelProvider.LMStudio)
.modelName(model.getId())
.displayName(model.getId())
.inputCost(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.MISTRAL);
return getModels(ModelProvider.Mistral);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public List<LanguageModel> getModels() {
try {
int contextWindow = OllamaApiService.getModelContext(model.getName());
LanguageModel languageModel = LanguageModel.builder()
.provider(ModelProvider.OLLAMA)
.provider(ModelProvider.Ollama)
.modelName(model.getName())
.displayName(model.getName())
.inputCost(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ public String getApiKey() {

@Override
public List<LanguageModel> getModels() {
return getModels(ModelProvider.OPENAI);
return getModels(ModelProvider.OpenAI);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public List<LanguageModel> getModels() {
double outputCost = convertAndScalePrice(model.getPricing().getCompletion());

LanguageModel languageModel = LanguageModel.builder()
.provider(ModelProvider.OPEN_ROUTER)
.provider(ModelProvider.OpenRouter)
.modelName(model.getId())
.displayName(model.getName())
.inputCost(inputCost)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,10 @@ private boolean validateAndPreparePrompt(String actionCommand,
private LanguageModel createDefaultLanguageModel(@NotNull DevoxxGenieSettingsService stateService) {
ModelProvider selectedProvider = (ModelProvider) modelProviderComboBox.getSelectedItem();
if (selectedProvider != null &&
(selectedProvider.equals(ModelProvider.LMSTUDIO) ||
selectedProvider.equals(ModelProvider.GPT_4_ALL) ||
selectedProvider.equals(ModelProvider.JLAMA) ||
selectedProvider.equals(ModelProvider.LLAMA))) {
(selectedProvider.equals(ModelProvider.LMStudio) ||
selectedProvider.equals(ModelProvider.GPT4All) ||
selectedProvider.equals(ModelProvider.Jlama) ||
selectedProvider.equals(ModelProvider.LLaMA))) {
return LanguageModel.builder()
.provider(selectedProvider)
.apiKeyUsed(false)
Expand All @@ -169,7 +169,7 @@ private LanguageModel createDefaultLanguageModel(@NotNull DevoxxGenieSettingsSer
} else {
String modelName = stateService.getSelectedLanguageModel(project.getLocationHash());
return LanguageModel.builder()
.provider(selectedProvider != null ? selectedProvider : ModelProvider.OPENAI)
.provider(selectedProvider != null ? selectedProvider : ModelProvider.OpenAI)
.modelName(modelName)
.apiKeyUsed(false)
.inputCost(0)
Expand Down
Loading
Loading