Skip to content

Commit

Permalink
Split language model tests into separated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stephanj committed Jul 20, 2024
1 parent 7bba59d commit 9b02a96
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 20 deletions.
2 changes: 1 addition & 1 deletion src/main/resources/application.properties
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
#Fri Jul 19 00:52:35 EEST 2024
#Fri Jul 19 01:07:30 EEST 2024
version=0.2.5
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.devoxx.genie.chatmodel.AbstractLightPlatformTestCase;
import com.devoxx.genie.model.LanguageModel;
import com.devoxx.genie.model.enumarations.ModelProvider;
import com.devoxx.genie.model.gemini.GeminiChatModel;
import com.devoxx.genie.model.request.ChatMessageContext;
import com.devoxx.genie.ui.settings.DevoxxGenieStateService;
Expand All @@ -13,22 +14,20 @@
import dev.langchain4j.model.mistralai.MistralAiChatModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import io.github.cdimascio.dotenv.Dotenv;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.api.Test;

import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class PromptExecutionServiceIT extends AbstractLightPlatformTestCase {

private static Dotenv dotenv;
private PromptExecutionService promptExecutionService;

@BeforeAll
static void loadEnvironment() {
Expand All @@ -39,6 +38,7 @@ static void loadEnvironment() {
public void setUp() throws Exception {
super.setUp();
mockSettingsState();
promptExecutionService = new PromptExecutionService();
}

private void mockSettingsState() {
Expand All @@ -52,26 +52,77 @@ private void mockSettingsState() {
ServiceContainerUtil.replaceService(ApplicationManager.getApplication(), DevoxxGenieStateService.class, settingsStateMock, getTestRootDisposable());
}

@ParameterizedTest
@MethodSource("provideModels")
public void testExecuteQuery(LanguageModel languageModel) {
PromptExecutionService promptExecutionService = new PromptExecutionService();
@Test
public void testExecuteQueryOpenAI() {
LanguageModel model = LanguageModel.builder()
.provider(ModelProvider.OpenAI)
.modelName("gpt-3.5-turbo")
.displayName("GPT-3.5 Turbo")
.apiKeyUsed(true)
.inputCost(0.0)
.outputCost(0.0)
.contextWindow(4096)
.build();
verifyResponse(createChatModel(model), model);
}

ChatMessageContext context = ChatMessageContext.builder()
.userPrompt("What is the capital of Belgium?")
.chatLanguageModel(createChatModel(languageModel))
.languageModel(languageModel)
.project(getProject())
@Test
public void testExecuteQueryAnthropic() {
LanguageModel model = LanguageModel.builder()
.provider(ModelProvider.Anthropic)
.modelName("claude-3-5-sonnet-20240620")
.displayName("claude-3-5-sonnet-20240620")
.apiKeyUsed(true)
.inputCost(0.0)
.outputCost(0.0)
.contextWindow(100000)
.build();
verifyResponse(createChatModel(model), model);
}

@Test
public void testExecuteQueryGemini() {
LanguageModel model = LanguageModel.builder()
.provider(ModelProvider.Gemini)
.modelName("gemini-pro")
.displayName("Gemini Pro")
.apiKeyUsed(true)
.inputCost(0.0)
.outputCost(0.0)
.contextWindow(32768)
.build();
verifyResponse(createChatModel(model), model);
}

verifyResponse(promptExecutionService, context);
@Test
public void testExecuteQueryMistral() {
LanguageModel model = LanguageModel.builder()
.provider(ModelProvider.Mistral)
.modelName("mistral-medium")
.displayName("Mistral Medium")
.apiKeyUsed(true)
.inputCost(0.0)
.outputCost(0.0)
.contextWindow(32768)
.build();
verifyResponse(createChatModel(model), model);
}

private static Stream<LanguageModel> provideModels() {
return new LLMModelRegistryService().getModels().stream();
@Test
public void testExecuteQueryDeepInfra() {
LanguageModel model = LanguageModel.builder()
.provider(ModelProvider.DeepInfra)
.modelName("mistralai/Mixtral-8x7B-Instruct-v0.1")
.displayName("Mixtral 8x7B")
.apiKeyUsed(true)
.inputCost(0.0)
.outputCost(0.0)
.contextWindow(32768)
.build();
verifyResponse(createChatModel(model), model);
}

private ChatLanguageModel createChatModel(@NotNull LanguageModel languageModel) {
private ChatLanguageModel createChatModel(LanguageModel languageModel) {
return switch (languageModel.getProvider()) {
case OpenAI -> OpenAiChatModel.builder()
.apiKey(dotenv.get("OPENAI_API_KEY"))
Expand All @@ -98,8 +149,15 @@ private ChatLanguageModel createChatModel(@NotNull LanguageModel languageModel)
};
}

private static void verifyResponse(@NotNull PromptExecutionService promptExecutionService,
ChatMessageContext context) {
private void verifyResponse(ChatLanguageModel chatModel, LanguageModel languageModel) {
ChatMessageContext context = ChatMessageContext.builder()
.userPrompt("What is the capital of Belgium?")
.chatLanguageModel(chatModel)
.languageModel(languageModel)
.project(getProject())
.totalFileCount(1)
.build();

CompletableFuture<Optional<AiMessage>> response = promptExecutionService.executeQuery(context);
assertNotNull(response);

Expand Down

0 comments on commit 9b02a96

Please sign in to comment.