Skip to content

Commit

Permalink
Feat #85: Search web buttons
Browse files Browse the repository at this point in the history
  • Loading branch information
stephanj committed May 29, 2024
1 parent b2f616b commit 002d7c9
Show file tree
Hide file tree
Showing 21 changed files with 469 additions and 179 deletions.
3 changes: 3 additions & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ dependencies {
implementation("dev.langchain4j:langchain4j-open-ai:0.31.0")
implementation("dev.langchain4j:langchain4j-anthropic:0.31.0")
implementation("dev.langchain4j:langchain4j-mistral-ai:0.31.0")
implementation("dev.langchain4j:langchain4j-web-search-engine-google-custom:0.31.0")
implementation("dev.langchain4j:langchain4j-web-search-engine-tavily:0.31.0")

implementation("org.commonmark:commonmark:0.22.0")

compileOnly("org.projectlombok:lombok:1.18.32")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public StreamingChatLanguageModel getStreamingChatLanguageModel(@NotNull ChatMes
* @return the chat model factory
*/
private @NotNull ChatModelFactory getFactory(@NotNull ChatMessageContext chatMessageContext) {
ModelProvider provider = ModelProvider.valueOf(chatMessageContext.getLlmProvider());
ModelProvider provider = ModelProvider.fromString(chatMessageContext.getLlmProvider());
ChatModelFactory factory = factories.get(provider);
if (factory == null) {
throw new IllegalArgumentException("No factory for provider: " + provider);
Expand Down
25 changes: 24 additions & 1 deletion src/main/java/com/devoxx/genie/model/Constant.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.devoxx.genie.model;

public class Constant {

private Constant() {
}

Expand All @@ -10,12 +11,21 @@ private Constant() {
public static final String EXPLAIN_PROMPT = "Break down the code in simple terms to help a junior developer grasp its functionality.";
public static final String CUSTOM_PROMPT = "Write a custom prompt here.";

// The Local LLM Model URLs
// The Local LLM Model URLs, these can be overridden in the settings page
public static final String OLLAMA_MODEL_URL = "http://localhost:11434/";
public static final String LMSTUDIO_MODEL_URL = "http://localhost:1234/v1/";
public static final String GPT4ALL_MODEL_URL = "http://localhost:4891/v1/";
public static final String JAN_MODEL_URL = "http://localhost:1337/v1/";

// ActionCommands
public static final String SUBMIT_ACTION = "submit";
public static final String TAVILY_SEARCH_ACTION = "tavilySearch";
public static final String GOOGLE_SEARCH_ACTION = "googleSearch";
public static final String COMBO_BOX_CHANGED = "comboBoxChanged";

// I18N file name
public static final String MESSAGES = "messages";

// The LLM Settings
public static final Double TEMPERATURE = 0.7d;
public static final Double TOP_P = 0.9d;
Expand All @@ -24,10 +34,23 @@ private Constant() {
public static final Integer TIMEOUT = 60;
public static final Integer MAX_MEMORY = 10;

// Hide Search Button
public static final Boolean HIDE_SEARCH_BUTTONS = false;

// Stream mode settings
public static final Boolean STREAM_MODE = false;

// AST settings
public static final Boolean AST_MODE = false;
public static final Boolean AST_PARENT_CLASS = true;
public static final Boolean AST_CLASS_REFERENCE = true;
public static final Boolean AST_FIELD_REFERENCE = true;

// Button tooltip texts
public static final String ADD_FILE_S_TO_PROMPT_CONTEXT = "Add file(s) to prompt context";
public static final String SUBMIT_THE_PROMPT = "Submit the prompt";
public static final String SEARCH_THE_WEB_WITH_TAVILY_FOR_AN_ANSWER = "Search the web with Tavily for an answer";
public static final String SEARCH_GOOGLE_FOR_AN_ANSWER = "Search Google for an answer";
public static final String PROMPT_IS_RUNNING_PLEASE_BE_PATIENT = "Prompt is running, please be patient...";

}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.devoxx.genie.model.enumarations;

import lombok.Getter;
import org.jetbrains.annotations.NotNull;

@Getter
public enum ModelProvider {
Expand All @@ -21,4 +22,12 @@ public enum ModelProvider {
this.name = name;
}

public static @NotNull ModelProvider fromString(String name) {
for (ModelProvider provider : ModelProvider.values()) {
if (provider.getName().equals(name)) {
return provider;
}
}
throw new IllegalArgumentException("No enum found with name: [" + name + "]");
}
}
36 changes: 27 additions & 9 deletions src/main/java/com/devoxx/genie/service/ChatPromptExecutor.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.devoxx.genie.service;

import com.devoxx.genie.model.Constant;
import com.devoxx.genie.model.request.ChatMessageContext;
import com.devoxx.genie.ui.panel.PromptOutputPanel;
import com.devoxx.genie.ui.util.NotificationUtil;
Expand All @@ -9,6 +10,7 @@
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import org.jetbrains.annotations.NotNull;

import java.awt.event.ActionEvent;
import java.util.Optional;
import java.util.concurrent.CancellationException;
import java.util.concurrent.TimeoutException;
Expand All @@ -23,7 +25,6 @@ public ChatPromptExecutor() {

/**
* Execute the prompt.
*
* @param chatMessageContext the chat message context
* @param promptOutputPanel the prompt output panel
* @param enableButtons the Enable buttons
Expand All @@ -35,19 +36,39 @@ public void executePrompt(@NotNull ChatMessageContext chatMessageContext,
new Task.Backgroundable(chatMessageContext.getProject(), "Working...", true) {
@Override
public void run(@NotNull ProgressIndicator progressIndicator) {
if (SettingsStateService.getInstance().getStreamMode()) {
setupStreaming(chatMessageContext, promptOutputPanel, enableButtons);
if (chatMessageContext.getContext().toLowerCase().contains("search")) {
webSearchPrompt(chatMessageContext, promptOutputPanel, enableButtons);
} else {
runPrompt(chatMessageContext, promptOutputPanel, enableButtons);
progressIndicator.setText("Working...");
if (SettingsStateService.getInstance().getStreamMode()) {
setupStreaming(chatMessageContext, promptOutputPanel, enableButtons);
} else {
runPrompt(chatMessageContext, promptOutputPanel, enableButtons);
}
}
}
}.queue();
}

/**
* Web search prompt.
* @param chatMessageContext the chat message context
* @param promptOutputPanel the prompt output panel
* @param enableButtons the Enable buttons
*/
private void webSearchPrompt(@NotNull ChatMessageContext chatMessageContext,
@NotNull PromptOutputPanel promptOutputPanel,
Runnable enableButtons) {
promptOutputPanel.addUserPrompt(chatMessageContext);
WebSearchService.getInstance().searchWeb(chatMessageContext)
.ifPresent(aiMessage -> {
chatMessageContext.setAiMessage(aiMessage);
promptOutputPanel.addChatResponse(chatMessageContext);
enableButtons.run();
});
}

/**
* Process possible command prompt.
*
* @param chatMessageContext the chat message context
* @param promptOutputPanel the prompt output panel
*/
Expand All @@ -59,7 +80,6 @@ public void updatePromptWithCommandIfPresent(@NotNull ChatMessageContext chatMes

/**
* Setup streaming.
*
* @param chatMessageContext the chat message context
* @param promptOutputPanel the prompt output panel
* @param enableButtons the Enable buttons
Expand Down Expand Up @@ -121,7 +141,6 @@ private Optional<String> getCommandFromPrompt(@NotNull String prompt,

/**
* Run the prompt.
*
* @param chatMessageContext the chat message context
* @param promptOutputPanel the prompt output panel
* @param enableButtons the Enable buttons
Expand All @@ -130,7 +149,6 @@ private void runPrompt(@NotNull ChatMessageContext chatMessageContext,
PromptOutputPanel promptOutputPanel,
Runnable enableButtons) {


promptExecutionService.executeQuery(chatMessageContext)
.thenAccept(aiMessageOptional -> {
enableButtons.run();
Expand Down
44 changes: 41 additions & 3 deletions src/main/java/com/devoxx/genie/service/MessageCreationService.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

import com.devoxx.genie.model.request.ChatMessageContext;
import com.devoxx.genie.model.request.EditorInfo;
import com.devoxx.genie.ui.util.NotificationUtil;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.editor.Document;
import com.intellij.openapi.fileEditor.FileDocumentManager;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.vfs.VirtualFile;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
Expand All @@ -12,6 +16,8 @@
import java.util.List;
import java.util.Optional;

import static com.devoxx.genie.action.AddSnippetAction.SELECTED_TEXT_KEY;

/**
* The message creation service for user and system messages.
* Here's where also the basic prompt "engineering" is happening, including calling the AST magic.
Expand Down Expand Up @@ -73,13 +79,45 @@ public static MessageCreationService getInstance() {
return userMessage;
}

/**
* Create user prompt with context.
* @param project the project
* @param userPrompt the user prompt
* @param files the files
* @return the user prompt with context
*/
public @NotNull String createUserPromptWithContext(Project project,
String userPrompt,
@NotNull List<VirtualFile> files) {
StringBuilder userPromptContext = new StringBuilder();
FileDocumentManager fileDocumentManager = FileDocumentManager.getInstance();
files.forEach(file -> ApplicationManager.getApplication().runReadAction(() -> {
if (file.getFileType().getName().equals("UNKNOWN")) {
userPromptContext.append("Filename: ").append(file.getName()).append("\n");
userPromptContext.append("Code Snippet: ").append(file.getUserData(SELECTED_TEXT_KEY)).append("\n");
} else {
Document document = fileDocumentManager.getDocument(file);
if (document != null) {
userPromptContext.append("Filename: ").append(file.getName()).append("\n");
String content = document.getText();
userPromptContext.append(content).append("\n");
} else {
NotificationUtil.sendNotification(project, "Error reading file: " + file.getName());
}
}
}));

userPromptContext.append(userPrompt);
return userPromptContext.toString();
}

/**
* Construct a user message with context.
* @param chatMessageContext the chat message context
* @param context the context
* @return the user message
*/
private static @NotNull UserMessage constructUserMessage(@NotNull ChatMessageContext chatMessageContext,
private @NotNull UserMessage constructUserMessage(@NotNull ChatMessageContext chatMessageContext,
String context) {
StringBuilder sb = new StringBuilder(QUESTION);

Expand Down Expand Up @@ -108,7 +146,7 @@ public static MessageCreationService getInstance() {
* @param chatMessageContext the chat message context
* @param sb the string builder
*/
private static void addASTContext(@NotNull ChatMessageContext chatMessageContext,
private void addASTContext(@NotNull ChatMessageContext chatMessageContext,
@NotNull StringBuilder sb) {
sb.append("\n\nRelated classes:\n\n");
List<VirtualFile> tempFiles = new ArrayList<>();
Expand All @@ -129,7 +167,7 @@ private static void addASTContext(@NotNull ChatMessageContext chatMessageContext
* @param sb the string builder
* @param text the text
*/
private static void appendIfNotEmpty(StringBuilder sb, String text) {
private void appendIfNotEmpty(StringBuilder sb, String text) {
if (text != null && !text.isEmpty()) {
sb.append(text).append("\n");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;

import lombok.Getter;
import org.jetbrains.annotations.NotNull;

Expand Down Expand Up @@ -54,8 +55,9 @@ static PromptExecutionService getInstance() {

ChatMemoryService.getInstance().add(userMessage);

queryFuture = CompletableFuture.supplyAsync(() -> processChatMessage(chatMessageContext), queryExecutor)
.orTimeout(chatMessageContext.getTimeout(), TimeUnit.SECONDS);
queryFuture = CompletableFuture.supplyAsync(() ->
processChatMessage(chatMessageContext), queryExecutor)
.orTimeout(chatMessageContext.getTimeout(), TimeUnit.SECONDS);
} finally {
queryLock.unlock();
}
Expand All @@ -79,7 +81,6 @@ private boolean isCanceled() {

/**
* Process the chat message.
*
* @param chatMessageContext the chat message context
* @return the AI message
*/
Expand Down
47 changes: 27 additions & 20 deletions src/main/java/com/devoxx/genie/service/SettingsStateService.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.devoxx.genie.service;

import com.devoxx.genie.model.Constant;
import com.devoxx.genie.ui.util.DoubleConverter;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.components.PersistentStateComponent;
Expand All @@ -13,6 +12,8 @@
import lombok.Setter;
import org.jetbrains.annotations.NotNull;

import static com.devoxx.genie.model.Constant.*;

@Getter
@Setter
@Service
Expand All @@ -27,10 +28,10 @@ public static SettingsStateService getInstance() {
}

// Local LLM URL fields
private String ollamaModelUrl = Constant.OLLAMA_MODEL_URL;
private String lmstudioModelUrl = Constant.LMSTUDIO_MODEL_URL;
private String gpt4allModelUrl = Constant.GPT4ALL_MODEL_URL;
private String janModelUrl = Constant.JAN_MODEL_URL;
private String ollamaModelUrl = OLLAMA_MODEL_URL;
private String lmstudioModelUrl = LMSTUDIO_MODEL_URL;
private String gpt4allModelUrl = GPT4ALL_MODEL_URL;
private String janModelUrl = JAN_MODEL_URL;

// LLM API Keys
private String openAIKey = "";
Expand All @@ -41,38 +42,44 @@ public static SettingsStateService getInstance() {
private String deepInfraKey = "";
private String geminiKey = "";

// Search API Keys
private Boolean hideSearchButtonsFlag = HIDE_SEARCH_BUTTONS;
private String googleSearchKey = "";
private String googleCSIKey = "";
private String tavilySearchKey = "";

// Prompt fields
private String testPrompt = Constant.TEST_PROMPT;
private String reviewPrompt = Constant.REVIEW_PROMPT;
private String explainPrompt = Constant.EXPLAIN_PROMPT;
private String customPrompt = Constant.CUSTOM_PROMPT;
private String testPrompt = TEST_PROMPT;
private String reviewPrompt = REVIEW_PROMPT;
private String explainPrompt = EXPLAIN_PROMPT;
private String customPrompt = CUSTOM_PROMPT;

// LLM settings
@OptionTag(converter = DoubleConverter.class)
private Double temperature = Constant.TEMPERATURE;
private Double temperature = TEMPERATURE;

@OptionTag(converter = DoubleConverter.class)
private Double topP = Constant.TOP_P;
private Double topP = TOP_P;

private Integer timeout = Constant.TIMEOUT;
private Integer maxRetries = Constant.MAX_RETRIES;
private Integer chatMemorySize = Constant.MAX_MEMORY;
private Integer timeout = TIMEOUT;
private Integer maxRetries = MAX_RETRIES;
private Integer chatMemorySize = MAX_MEMORY;

// Was unable to make it work with Integer for some unknown reason
private String maxOutputTokens = Constant.MAX_OUTPUT_TOKENS.toString();
private String maxOutputTokens = MAX_OUTPUT_TOKENS.toString();

// Last selected LLM provider and model name
private String lastSelectedProvider;
private String lastSelectedModel;

// Enable stream mode
private Boolean streamMode = Constant.STREAM_MODE;
private Boolean streamMode = STREAM_MODE;

// Enable AST mode
private Boolean astMode = Constant.AST_MODE;
private Boolean astParentClass = Constant.AST_PARENT_CLASS;
private Boolean astClassReference = Constant.AST_CLASS_REFERENCE;
private Boolean astFieldReference = Constant.AST_FIELD_REFERENCE;
private Boolean astMode = AST_MODE;
private Boolean astParentClass = AST_PARENT_CLASS;
private Boolean astClassReference = AST_CLASS_REFERENCE;
private Boolean astFieldReference = AST_FIELD_REFERENCE;

@Override
public SettingsStateService getState() {
Expand Down
Loading

0 comments on commit 002d7c9

Please sign in to comment.