Skip to content

Commit

Permalink
#Feat 98: Allow a streaming response to be stopped
Browse files Browse the repository at this point in the history
  • Loading branch information
stephanj committed Jun 25, 2024
1 parent 4483fb6 commit cd483c5
Show file tree
Hide file tree
Showing 31 changed files with 187 additions and 45 deletions.
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ plugins {
}

group = "com.devoxx.genie"
version = "0.1.18"
version = "0.1.19"

repositories {
mavenCentral()
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added core/build/libs/core-0.1.16.jar
Binary file not shown.
Binary file not shown.
2 changes: 2 additions & 0 deletions core/build/tmp/jar/MANIFEST.MF
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Manifest-Version: 1.0

1 change: 1 addition & 0 deletions src/main/java/com/devoxx/genie/model/Constant.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,6 @@ private Constant() {
public static final String SEARCH_THE_WEB_WITH_TAVILY_FOR_AN_ANSWER = "Search the web with Tavily for an answer";
public static final String SEARCH_GOOGLE_FOR_AN_ANSWER = "Search Google for an answer";
public static final String PROMPT_IS_RUNNING_PLEASE_BE_PATIENT = "Prompt is running, please be patient...";
public static final String STOP_STREAMING = "Stop streaming response";

}
58 changes: 37 additions & 21 deletions src/main/java/com/devoxx/genie/service/ChatPromptExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import org.jetbrains.annotations.NotNull;

import javax.swing.*;
import java.net.ConnectException;
import java.util.Optional;
import java.util.concurrent.CancellationException;
import java.util.concurrent.TimeoutException;

public class ChatPromptExecutor {

private final PromptExecutionService promptExecutionService = PromptExecutionService.getInstance();
private StreamingResponseHandler currentStreamingHandler;

public ChatPromptExecutor() {
}
Expand All @@ -38,26 +37,34 @@ public void executePrompt(@NotNull ChatMessageContext chatMessageContext,
new Task.Backgroundable(chatMessageContext.getProject(), "Working...", true) {
@Override
public void run(@NotNull ProgressIndicator progressIndicator) {
if (chatMessageContext.getContext() != null && chatMessageContext.getContext().toLowerCase().contains("search")) {
webSearchPrompt(chatMessageContext, promptOutputPanel, enableButtons);
if (isWebSearch(chatMessageContext)) {
executeWebSearch(chatMessageContext, promptOutputPanel, enableButtons);
} else if (DevoxxGenieStateService.getInstance().getStreamMode()) {
executeStreamingPrompt(chatMessageContext, promptOutputPanel, enableButtons);
} else {
if (DevoxxGenieStateService.getInstance().getStreamMode()) {
setupStreaming(chatMessageContext, promptOutputPanel, enableButtons);
} else {
runPrompt(chatMessageContext, promptOutputPanel, enableButtons);
}
executeNonStreamingPrompt(chatMessageContext, promptOutputPanel, enableButtons);
}
}
}.queue();
}

/**
* Is web search.
* @param chatMessageContext the chat message context
* @return the boolean
*/
private boolean isWebSearch(@NotNull ChatMessageContext chatMessageContext) {
return chatMessageContext.getContext() != null &&
chatMessageContext.getContext().toLowerCase().contains("search");
}

/**
* Web search prompt.
* @param chatMessageContext the chat message context
* @param promptOutputPanel the prompt output panel
* @param enableButtons the Enable buttons
*/
private void webSearchPrompt(@NotNull ChatMessageContext chatMessageContext,
private void executeWebSearch(@NotNull ChatMessageContext chatMessageContext,
@NotNull PromptOutputPanel promptOutputPanel,
Runnable enableButtons) {
promptOutputPanel.addUserPrompt(chatMessageContext);
Expand All @@ -81,16 +88,16 @@ public Optional<String> updatePromptWithCommandIfPresent(@NotNull ChatMessageCon
return commandFromPrompt;
}


/**
* Setup streaming.
* Execute streaming response.
* @param chatMessageContext the chat message context
* @param promptOutputPanel the prompt output panel
* @param enableButtons the Enable buttons
*/
private void setupStreaming(@NotNull ChatMessageContext chatMessageContext,
@NotNull PromptOutputPanel promptOutputPanel,
Runnable enableButtons) {

private void executeStreamingPrompt(@NotNull ChatMessageContext chatMessageContext,
@NotNull PromptOutputPanel promptOutputPanel,
Runnable enableButtons) {
StreamingChatLanguageModel streamingChatLanguageModel = chatMessageContext.getStreamingChatLanguageModel();
if (streamingChatLanguageModel == null) {
NotificationUtil.sendNotification(chatMessageContext.getProject(), "Streaming model not available, please select another provider.");
Expand All @@ -110,9 +117,18 @@ private void setupStreaming(@NotNull ChatMessageContext chatMessageContext,

promptOutputPanel.addUserPrompt(chatMessageContext);

streamingChatLanguageModel.generate(
chatMemoryService.messages(),
new StreamingResponseHandler(chatMessageContext, promptOutputPanel, enableButtons));
currentStreamingHandler = new StreamingResponseHandler(chatMessageContext, promptOutputPanel, enableButtons);
streamingChatLanguageModel.generate(chatMemoryService.messages(), currentStreamingHandler);
}

/**
* Stop streaming.
*/
public void stopStreaming() {
if (currentStreamingHandler != null) {
currentStreamingHandler.stop();
currentStreamingHandler = null;
}
}

/**
Expand Down Expand Up @@ -148,9 +164,9 @@ private Optional<String> getCommandFromPrompt(@NotNull String prompt,
* @param promptOutputPanel the prompt output panel
* @param enableButtons the Enable buttons
*/
private void runPrompt(@NotNull ChatMessageContext chatMessageContext,
PromptOutputPanel promptOutputPanel,
Runnable enableButtons) {
private void executeNonStreamingPrompt(@NotNull ChatMessageContext chatMessageContext,
PromptOutputPanel promptOutputPanel,
Runnable enableButtons) {

promptExecutionService.executeQuery(chatMessageContext)
.thenAccept(aiMessageOptional -> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package com.devoxx.genie.service;

import com.devoxx.genie.model.request.ChatMessageContext;
import com.devoxx.genie.service.exception.ProviderUnavailableException;
import com.devoxx.genie.ui.component.ExpandablePanel;
import com.devoxx.genie.ui.panel.ChatStreamingResponsePanel;
import com.devoxx.genie.ui.panel.PromptOutputPanel;
Expand All @@ -15,32 +14,46 @@
import java.util.concurrent.TimeoutException;

public class StreamingResponseHandler implements dev.langchain4j.model.StreamingResponseHandler<AiMessage> {

private final ChatMessageContext chatMessageContext;
private final Runnable enableButtons;
private final ChatStreamingResponsePanel streamingChatResponsePanel;
private final PromptOutputPanel promptOutputPanel;
private volatile boolean isStopped = false;

public StreamingResponseHandler(ChatMessageContext chatMessageContext,
@NotNull PromptOutputPanel promptOutputPanel,
Runnable enableButtons) {
this.chatMessageContext = chatMessageContext;
this.enableButtons = enableButtons;
this.promptOutputPanel = promptOutputPanel;
streamingChatResponsePanel = new ChatStreamingResponsePanel(chatMessageContext);
this.streamingChatResponsePanel = new ChatStreamingResponsePanel(chatMessageContext);
promptOutputPanel.addStreamResponse(streamingChatResponsePanel);
}

@Override
public void onNext(String token) {
streamingChatResponsePanel.insertToken(token);
if (!isStopped) {
streamingChatResponsePanel.insertToken(token);
}
}

@Override
public void onComplete(@NotNull Response<AiMessage> response) {
if (isStopped) {
return;
}

finalizeResponse(response);
addExpandablePanelIfNeeded();
}

private void finalizeResponse(@NotNull Response<AiMessage> response) {
chatMessageContext.setAiMessage(response.content());
ChatMemoryService.getInstance().add(response.content());
enableButtons.run();
}

private void addExpandablePanelIfNeeded() {
if (chatMessageContext.hasFiles()) {
SwingUtilities.invokeLater(() -> {
ExpandablePanel fileListPanel = new ExpandablePanel(chatMessageContext);
Expand All @@ -50,19 +63,36 @@ public void onComplete(@NotNull Response<AiMessage> response) {
}
}

public void stop() {
isStopped = true;
enableButtons.run();
}

@Override
public void onError(Throwable error) {
enableButtons.run();
handleError(error);
}

/**
* Handle the LLM error and notify user with a message.
* @param error the error
*/
private void handleError(@NotNull Throwable error) {
if (error.getCause() instanceof TimeoutException) {
NotificationUtil.sendNotification(chatMessageContext.getProject(),
"Timeout occurred. Please increase the timeout setting.");
notifyUser("Timeout occurred. Please increase the timeout setting.");
} else if (error.getCause() instanceof ConnectException) {
NotificationUtil.sendNotification(chatMessageContext.getProject(),
"LLM provider not available. Please select another provider or make sure it's running.");
notifyUser("LLM provider not available. Please select another provider or make sure it's running.");
} else {
NotificationUtil.sendNotification(chatMessageContext.getProject(),
"An error occurred. Please try again.");
notifyUser("An error occurred. Please try again.");
}
}
}

/**
* Notify the user with a message.
* @param message the message
*/
private void notifyUser(String message) {
NotificationUtil.sendNotification(chatMessageContext.getProject(), message);
}
}
113 changes: 101 additions & 12 deletions src/main/java/com/devoxx/genie/ui/panel/ActionButtonsPanel.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import java.awt.*;
import java.awt.event.ActionEvent;
import java.util.List;
import java.util.Optional;

import static com.devoxx.genie.model.Constant.*;
import static com.devoxx.genie.model.Constant.ADD_FILE_S_TO_PROMPT_CONTEXT;
Expand All @@ -55,6 +54,9 @@ public class ActionButtonsPanel extends JPanel {
private final DevoxxGenieToolWindowContent devoxxGenieToolWindowContent;
private final ChatModelProvider chatModelProvider = new ChatModelProvider();

private boolean isStreaming = false;
private ChatMessageContext currentChatMessageContext;

public ActionButtonsPanel(Project project,
PromptInputArea promptInputComponent,
PromptOutputPanel promptOutputPanel,
Expand Down Expand Up @@ -139,21 +141,103 @@ private void selectFilesForPromptContext(ActionEvent e) {
* Submit the user prompt.
*/
private void onSubmitPrompt(ActionEvent actionEvent) {
String userPromptText = isUserPromptProvided();
if (userPromptText == null) return;
if (isStreaming) {
stopStreaming();
return;
}

if (isWebSearchTriggeredAndConfigured(actionEvent)) return;
if (!validateAndPreparePrompt(actionEvent)) {
return;
}

disableSubmitBtn();
executePrompt();
}

/**
* Execute the prompt.
*/
private void executePrompt() {
disableUIForPromptExecution();

ChatMessageContext chatMessageContext =
createChatMessageContext(actionEvent, userPromptText, editorFileButtonManager.getSelectedTextEditor());
chatPromptExecutor.updatePromptWithCommandIfPresent(currentChatMessageContext, promptOutputPanel)
.ifPresentOrElse(
command -> startPromptExecution(),
this::enableButtons
);
}

/**
* Start the prompt execution.
*/
private void startPromptExecution() {
if (DevoxxGenieStateService.getInstance().getStreamMode()) {
isStreaming = true;
}
chatPromptExecutor.executePrompt(currentChatMessageContext, promptOutputPanel, this::enableButtons);
}

/**
* Stop the streaming.
*/
private void stopStreaming() {
chatPromptExecutor.stopStreaming();
isStreaming = false;
enableButtons();
}

/**
* get the user prompt text.
*/
private @Nullable String getUserPromptText() {
String userPromptText = promptInputComponent.getText();
if (userPromptText.isEmpty()) {
NotificationUtil.sendNotification(project, "Please enter a prompt.");
return null;
}
return userPromptText;
}

/**
* Check if web search is triggered and not configured, if not show Settings page.
* @param actionEvent the action event
* @return true if the web search is triggered and not configured
*/
private boolean isWebSearchTriggeredAndNotConfigured(@NotNull ActionEvent actionEvent) {
if (actionEvent.getActionCommand().toLowerCase().contains("search") && !isWebSearchEnabled()) {
SwingUtilities.invokeLater(() ->
NotificationUtil.sendNotification(project, "No Search API keys found, please add one in the settings.")
);
showSettingsDialog(project);
return true;
}
return false;
}

/**
* Disable the UI for prompt execution.
*/
private void disableUIForPromptExecution() {
disableSubmitBtn();
disableButtons();
}

/**
* Validate and prepare the prompt.
* @param actionEvent the action event
* @return true if the prompt is valid
*/
private boolean validateAndPreparePrompt(ActionEvent actionEvent) {
String userPromptText = getUserPromptText();
if (userPromptText == null) {
return false;
}

if (isWebSearchTriggeredAndNotConfigured(actionEvent)) {
return false;
}

chatPromptExecutor.updatePromptWithCommandIfPresent(chatMessageContext, promptOutputPanel)
.ifPresentOrElse(command -> chatPromptExecutor.executePrompt(chatMessageContext, promptOutputPanel, this::enableButtons),
this::enableButtons);
currentChatMessageContext = createChatMessageContext(actionEvent, userPromptText, editorFileButtonManager.getSelectedTextEditor());
return true;
}

/**
Expand All @@ -165,6 +249,7 @@ public void enableButtons() {
submitBtn.setEnabled(true);
submitBtn.setToolTipText(SUBMIT_THE_PROMPT);
promptInputComponent.setEnabled(true);
isStreaming = false;
});
}

Expand All @@ -174,10 +259,14 @@ public void enableButtons() {
private void disableSubmitBtn() {
invokeLater(() -> {
if (DevoxxGenieStateService.getInstance().getStreamMode()) {
submitBtn.setEnabled(true);
submitBtn.setIcon(StopIcon);
submitBtn.setToolTipText(STOP_STREAMING);
} else {
submitBtn.setEnabled(false);
submitBtn.setIcon(StopIcon);
submitBtn.setToolTipText(PROMPT_IS_RUNNING_PLEASE_BE_PATIENT);
}
submitBtn.setIcon(StopIcon);
submitBtn.setToolTipText(PROMPT_IS_RUNNING_PLEASE_BE_PATIENT);
});
}

Expand Down
Loading

0 comments on commit cd483c5

Please sign in to comment.