diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index fd9ceec4..0fd5fbe6 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -1,2 +1,2 @@ -#Fri Jul 19 00:52:35 EEST 2024 +#Fri Jul 19 01:07:30 EEST 2024 version=0.2.5 diff --git a/src/test/java/com/devoxx/genie/service/PromptExecutionServiceIT.java b/src/test/java/com/devoxx/genie/service/PromptExecutionServiceIT.java index 28205bb0..dc3986d9 100644 --- a/src/test/java/com/devoxx/genie/service/PromptExecutionServiceIT.java +++ b/src/test/java/com/devoxx/genie/service/PromptExecutionServiceIT.java @@ -2,6 +2,7 @@ import com.devoxx.genie.chatmodel.AbstractLightPlatformTestCase; import com.devoxx.genie.model.LanguageModel; +import com.devoxx.genie.model.enumarations.ModelProvider; import com.devoxx.genie.model.gemini.GeminiChatModel; import com.devoxx.genie.model.request.ChatMessageContext; import com.devoxx.genie.ui.settings.DevoxxGenieStateService; @@ -13,15 +14,12 @@ import dev.langchain4j.model.mistralai.MistralAiChatModel; import dev.langchain4j.model.openai.OpenAiChatModel; import io.github.cdimascio.dotenv.Dotenv; -import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.api.Test; import java.util.Optional; import java.util.concurrent.CompletableFuture; -import java.util.stream.Stream; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -29,6 +27,7 @@ public class PromptExecutionServiceIT extends AbstractLightPlatformTestCase { private static Dotenv dotenv; + private PromptExecutionService promptExecutionService; @BeforeAll static void loadEnvironment() { @@ -39,6 +38,7 @@ static void loadEnvironment() { public void setUp() throws Exception { super.setUp(); mockSettingsState(); + promptExecutionService = new PromptExecutionService(); } private void mockSettingsState() { @@ -52,26 +52,77 @@ private void mockSettingsState() { ServiceContainerUtil.replaceService(ApplicationManager.getApplication(), DevoxxGenieStateService.class, settingsStateMock, getTestRootDisposable()); } - @ParameterizedTest - @MethodSource("provideModels") - public void testExecuteQuery(LanguageModel languageModel) { - PromptExecutionService promptExecutionService = new PromptExecutionService(); + @Test + public void testExecuteQueryOpenAI() { + LanguageModel model = LanguageModel.builder() + .provider(ModelProvider.OpenAI) + .modelName("gpt-3.5-turbo") + .displayName("GPT-3.5 Turbo") + .apiKeyUsed(true) + .inputCost(0.0) + .outputCost(0.0) + .contextWindow(4096) + .build(); + verifyResponse(createChatModel(model), model); + } - ChatMessageContext context = ChatMessageContext.builder() - .userPrompt("What is the capital of Belgium?") - .chatLanguageModel(createChatModel(languageModel)) - .languageModel(languageModel) - .project(getProject()) + @Test + public void testExecuteQueryAnthropic() { + LanguageModel model = LanguageModel.builder() + .provider(ModelProvider.Anthropic) + .modelName("claude-3-5-sonnet-20240620") + .displayName("claude-3-5-sonnet-20240620") + .apiKeyUsed(true) + .inputCost(0.0) + .outputCost(0.0) + .contextWindow(100000) + .build(); + verifyResponse(createChatModel(model), model); + } + + @Test + public void testExecuteQueryGemini() { + LanguageModel model = LanguageModel.builder() + .provider(ModelProvider.Gemini) + .modelName("gemini-pro") + .displayName("Gemini Pro") + .apiKeyUsed(true) + .inputCost(0.0) + .outputCost(0.0) + .contextWindow(32768) .build(); + verifyResponse(createChatModel(model), model); + } - verifyResponse(promptExecutionService, context); + @Test + public void testExecuteQueryMistral() { + LanguageModel model = LanguageModel.builder() + .provider(ModelProvider.Mistral) + .modelName("mistral-medium") + .displayName("Mistral Medium") + .apiKeyUsed(true) + .inputCost(0.0) + .outputCost(0.0) + .contextWindow(32768) + .build(); + verifyResponse(createChatModel(model), model); } - private static Stream provideModels() { - return new LLMModelRegistryService().getModels().stream(); + @Test + public void testExecuteQueryDeepInfra() { + LanguageModel model = LanguageModel.builder() + .provider(ModelProvider.DeepInfra) + .modelName("mistralai/Mixtral-8x7B-Instruct-v0.1") + .displayName("Mixtral 8x7B") + .apiKeyUsed(true) + .inputCost(0.0) + .outputCost(0.0) + .contextWindow(32768) + .build(); + verifyResponse(createChatModel(model), model); } - private ChatLanguageModel createChatModel(@NotNull LanguageModel languageModel) { + private ChatLanguageModel createChatModel(LanguageModel languageModel) { return switch (languageModel.getProvider()) { case OpenAI -> OpenAiChatModel.builder() .apiKey(dotenv.get("OPENAI_API_KEY")) @@ -98,8 +149,15 @@ private ChatLanguageModel createChatModel(@NotNull LanguageModel languageModel) }; } - private static void verifyResponse(@NotNull PromptExecutionService promptExecutionService, - ChatMessageContext context) { + private void verifyResponse(ChatLanguageModel chatModel, LanguageModel languageModel) { + ChatMessageContext context = ChatMessageContext.builder() + .userPrompt("What is the capital of Belgium?") + .chatLanguageModel(chatModel) + .languageModel(languageModel) + .project(getProject()) + .totalFileCount(1) + .build(); + CompletableFuture> response = promptExecutionService.executeQuery(context); assertNotNull(response);