Skip to content

Commit

Permalink
Added LLaMA.c++ as local option
Browse files Browse the repository at this point in the history
  • Loading branch information
stephanj committed Aug 2, 2024
1 parent c0243f6 commit 6dd8e73
Show file tree
Hide file tree
Showing 10 changed files with 27 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import com.devoxx.genie.chatmodel.anthropic.AnthropicChatModelFactory;
import com.devoxx.genie.chatmodel.deepinfra.DeepInfraChatModelFactory;
import com.devoxx.genie.chatmodel.exo.ExoChatModelFactory;
import com.devoxx.genie.chatmodel.gemini.GeminiChatModelFactory;
import com.devoxx.genie.chatmodel.google.GoogleChatModelFactory;
import com.devoxx.genie.chatmodel.groq.GroqChatModelFactory;
import com.devoxx.genie.chatmodel.mistral.MistralChatModelFactory;
import com.devoxx.genie.chatmodel.ollama.OllamaChatModelFactory;
Expand All @@ -28,7 +28,7 @@ public class ChatModelFactoryProvider {
ModelProvider.Mistral.getName(), MistralChatModelFactory::new,
ModelProvider.Groq.getName(), GroqChatModelFactory::new,
ModelProvider.DeepInfra.getName(), DeepInfraChatModelFactory::new,
ModelProvider.Google.getName(), GeminiChatModelFactory::new
ModelProvider.Google.getName(), GoogleChatModelFactory::new
);

/**
Expand Down
30 changes: 3 additions & 27 deletions src/main/java/com/devoxx/genie/chatmodel/ChatModelProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import com.devoxx.genie.chatmodel.anthropic.AnthropicChatModelFactory;
import com.devoxx.genie.chatmodel.exo.ExoChatModelFactory;
import com.devoxx.genie.chatmodel.gemini.GeminiChatModelFactory;
import com.devoxx.genie.chatmodel.google.GoogleChatModelFactory;
import com.devoxx.genie.chatmodel.gpt4all.GPT4AllChatModelFactory;
import com.devoxx.genie.chatmodel.groq.GroqChatModelFactory;
import com.devoxx.genie.chatmodel.llama.LlamaChatModelFactory;
Expand All @@ -19,9 +19,7 @@
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import lombok.Setter;
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.HashMap;
import java.util.Map;
Expand All @@ -43,9 +41,10 @@ public ChatModelProvider() {
factories.put(ModelProvider.Mistral, new MistralChatModelFactory());
factories.put(ModelProvider.Anthropic, new AnthropicChatModelFactory());
factories.put(ModelProvider.Groq, new GroqChatModelFactory());
factories.put(ModelProvider.Google, new GeminiChatModelFactory());
factories.put(ModelProvider.Google, new GoogleChatModelFactory());
factories.put(ModelProvider.Exo, new ExoChatModelFactory());
factories.put(ModelProvider.LLaMA, new LlamaChatModelFactory());

// TODO Currently broken by latest Jan! version
// factories.put(ModelProvider.Jan, new JanChatModelFactory());
}
Expand Down Expand Up @@ -114,29 +113,6 @@ private void setLocalBaseUrl(@NotNull LanguageModel languageModel,
}
}

// // TODO: This method is duplicated in multiple places. Consider moving it to a common utility class.
// private String getModelName(@Nullable LanguageModel languageModel) {
// if (languageModel == null) {
// return getDefaultModelName(null);
// }
//
// return Optional.ofNullable(languageModel.getModelName())
// .orElseGet(() -> getDefaultModelName(languageModel.getProvider()));
// }
//
// @Contract(pure = true)
// private @NotNull String getDefaultModelName(@Nullable ModelProvider provider) {
// if (provider == null) {
// return TEST_MODEL;
// }
// return switch (provider) {
// case LMStudio -> "LMStudio";
// case GPT4All -> "GPT4All";
// case LLaMA -> "LLaMA";
// default -> "DefaultModel";
// };
// }

private static void setMaxOutputTokens(@NotNull DevoxxGenieStateService settingsState,
@NotNull ChatModel chatModel) {
Integer maxOutputTokens = settingsState.getMaxOutputTokens();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.devoxx.genie.chatmodel.gemini;
package com.devoxx.genie.chatmodel.google;

import com.devoxx.genie.chatmodel.ChatModelFactory;
import com.devoxx.genie.model.ChatModel;
Expand All @@ -12,7 +12,7 @@
import java.time.Duration;
import java.util.List;

public class GeminiChatModelFactory implements ChatModelFactory {
public class GoogleChatModelFactory implements ChatModelFactory {

@Override
public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
Expand Down
14 changes: 14 additions & 0 deletions src/main/java/com/devoxx/genie/model/GenericOpenAIProvider.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.devoxx.genie.model;

import lombok.Data;

/**
 * Configuration holder for a user-defined, OpenAI-API-compatible LLM provider
 * (e.g. a local server exposing an OpenAI-style endpoint).
 * <p>
 * Lombok {@code @Data} generates the getters, setters, {@code equals},
 * {@code hashCode} and {@code toString} for all fields; instances are
 * mutable JavaBeans.
 * <p>
 * NOTE(review): field semantics below are inferred from names and the
 * surrounding commit (adding LLaMA.c++ local support) — confirm against callers.
 */
@Data
public class GenericOpenAIProvider {
    // Display name of the provider as shown to the user.
    private String name;
    // Base URL of the OpenAI-compatible endpoint (presumably including scheme/port — verify).
    private String baseUrl;
    // Identifier of the model to request from the endpoint.
    private String modelName;
    // API key sent to the endpoint; may be unused for local servers — confirm.
    private String apiKey;
    // Cost per input token/unit; units not established by this file — confirm.
    private Double inputCost;
    // Cost per output token/unit; units not established by this file — confirm.
    private Double outputCost;
    // Maximum context window size (token count) advertised for the model.
    private Integer contextWindow;
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public enum ModelProvider {
DeepInfra("DeepInfra"),
Google("Google"),
Exo("Exo (Experimental)"),
LLaMA("LLaMA.c++"),;
LLaMA("LLaMA.c++");

private final String name;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -446,15 +446,6 @@ public void setModels(List<LanguageModel> models) {
this.models = new ArrayList<>(models);
}

public void addModel(LanguageModel model) {
models.add(model);
}

public void updateModel(LanguageModel updateModel) {
models.removeIf(model -> model.getModelName().equals(updateModel.getModelName()));
models.add(updateModel);
}

@Override
public LLMModelRegistryService getState() {
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,6 @@ public void setModelCost(ModelProvider provider,
}
}

// public double getModelInputCost(ModelProvider provider, String modelName) {
// if (DefaultLLMSettingsUtil.isApiBasedProvider(provider)) {
// String key = provider.getName() + ":" + modelName;
// return modelInputCosts.getOrDefault(key,
// DefaultLLMSettingsUtil.DEFAULT_INPUT_COSTS.getOrDefault(new DefaultLLMSettingsUtil.CostKey(provider, modelName), 0.0));
// }
// return 0.0;
// }

public double getModelInputCost(@NotNull ModelProvider provider, String modelName) {
String key = provider.getName() + ":" + modelName;
double cost = modelInputCosts.getOrDefault(key, 0.0);
Expand All @@ -165,15 +156,6 @@ public double getModelInputCost(@NotNull ModelProvider provider, String modelNam
return cost;
}

public double getModelOutputCost(ModelProvider provider, String modelName) {
if (DefaultLLMSettingsUtil.isApiBasedProvider(provider)) {
String key = provider.getName() + ":" + modelName;
return modelOutputCosts.getOrDefault(key,
DefaultLLMSettingsUtil.DEFAULT_OUTPUT_COSTS.getOrDefault(new DefaultLLMSettingsUtil.CostKey(provider, modelName), 0.0));
}
return 0.0;
}

private void initializeDefaultCostsIfEmpty() {
if (modelInputCosts.isEmpty()) {
for (Map.Entry<DefaultLLMSettingsUtil.CostKey, Double> entry : DefaultLLMSettingsUtil.DEFAULT_INPUT_COSTS.entrySet()) {
Expand Down
1 change: 1 addition & 0 deletions src/main/resources/META-INF/plugin.xml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
<h2>v0.2.10</h2>
<UL>
<LI>Fix #184 - Input panel has bigger min/preferred height size</LI>
                        <LI>Feat #186 - Support for local LLaMA.c++ HTTP server</LI>
</UL>
<h2>v0.2.9</h2>
<UL>
Expand Down
2 changes: 1 addition & 1 deletion src/main/resources/application.properties
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
#Thu Aug 01 20:57:17 CEST 2024
#Fri Aug 02 17:59:31 CEST 2024
version=0.2.10
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.devoxx.genie.chatmodel.gemini;
package com.devoxx.genie.chatmodel.google;

import com.devoxx.genie.chatmodel.AbstractLightPlatformTestCase;
import com.devoxx.genie.model.ChatModel;
Expand Down Expand Up @@ -37,7 +37,7 @@ public void setUp() throws Exception {
@Test
public void createChatModel() {
// Instance of the class containing the method to be tested
var factory = new GeminiChatModelFactory();
var factory = new GoogleChatModelFactory();

// Create a dummy ChatModel
ChatModel chatModel = new ChatModel();
Expand All @@ -54,7 +54,7 @@ public void createChatModel() {

@Test
public void testModelNames() {
GeminiChatModelFactory factory = new GeminiChatModelFactory();
GoogleChatModelFactory factory = new GoogleChatModelFactory();
Assertions.assertThat(factory.getModels()).isNotEmpty();

List<LanguageModel> modelNames = factory.getModels();
Expand Down

0 comments on commit 6dd8e73

Please sign in to comment.