Skip to content

Commit

Permalink
Rolledback fix so we can do this in a dedicated branch
Browse files Browse the repository at this point in the history
  • Loading branch information
stephanj committed Dec 12, 2024
1 parent 16ec839 commit 1937565
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public void execute(ChatMessageContext chatMessageContext,
isCancelled = false;

if (FIND_COMMAND.equals(chatMessageContext.getCommandName())) {
semanticSearch(chatMessageContext, promptOutputPanel);
semanticSearch(chatMessageContext, promptOutputPanel, enableButtons);
enableButtons.run();
return;
}
Expand Down Expand Up @@ -73,15 +73,19 @@ private void prompt(ChatMessageContext chatMessageContext,

// Add the conversation to the chat service
ApplicationManager.getApplication().getMessageBus()
.syncPublisher(AppTopics.CONVERSATION_TOPIC)
.onNewConversation(chatMessageContext);
.syncPublisher(AppTopics.CONVERSATION_TOPIC)
.onNewConversation(chatMessageContext);

promptOutputPanel.addChatResponse(chatMessageContext);
} else if (isCancelled) {
LOG.debug(">>>> Prompt execution cancelled");
promptOutputPanel.removeLastUserPrompt(chatMessageContext);
}
})
.exceptionally(throwable -> {
ErrorHandler.handleError(chatMessageContext.getProject(), throwable);
return null;
})
.whenComplete((result, throwable) -> enableButtons.run());
}

Expand All @@ -91,7 +95,8 @@ private void prompt(ChatMessageContext chatMessageContext,
* @param promptOutputPanel the prompt output panel
*/
private static void semanticSearch(ChatMessageContext chatMessageContext,
@NotNull PromptOutputPanel promptOutputPanel) {
@NotNull PromptOutputPanel promptOutputPanel,
Runnable enableButtons) {
try {
SemanticSearchService semanticSearchService = SemanticSearchService.getInstance();
Map<String, SearchResult> searchResults = semanticSearchService.search(
Expand Down Expand Up @@ -119,7 +124,6 @@ private static void semanticSearch(ChatMessageContext chatMessageContext,
*/
public void stopExecution() {
if (currentTask != null && !currentTask.isDone()) {
promptExecutionService.cancelCurrentQuery();
isCancelled = true;
currentTask.cancel(true);
}
Expand Down
76 changes: 28 additions & 48 deletions src/main/java/com/devoxx/genie/service/PromptExecutionService.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@
import lombok.Getter;
import org.jetbrains.annotations.NotNull;

import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantLock;

public class PromptExecutionService {
Expand All @@ -33,7 +31,8 @@ public class PromptExecutionService {
private CompletableFuture<Response<AiMessage>> queryFuture = null;

@Getter
private final AtomicBoolean running = new AtomicBoolean(false);
private boolean running = false;

private final ReentrantLock queryLock = new ReentrantLock();

@NotNull
Expand All @@ -52,12 +51,7 @@ static PromptExecutionService getInstance() {

queryLock.lock();
try {
if (running.get()) {
LOG.info("Another query is already running. Cancelling it.");
cancelCurrentQuery();
}

running.set(true);
if (isCanceled()) return CompletableFuture.completedFuture(null);

MessageCreationService messageCreationService = MessageCreationService.getInstance();

Expand All @@ -68,7 +62,7 @@ static PromptExecutionService getInstance() {
ChatMemoryService
.getInstance()
.add(chatMessageContext.getProject(),
new SystemMessage(DevoxxGenieStateService.getInstance().getSystemPrompt() + Constant.MARKDOWN)
new SystemMessage(DevoxxGenieStateService.getInstance().getSystemPrompt() + Constant.MARKDOWN)
);
}
}
Expand All @@ -81,45 +75,37 @@ static PromptExecutionService getInstance() {
long startTime = System.currentTimeMillis();

queryFuture = CompletableFuture
.supplyAsync(() -> {
if (Thread.currentThread().isInterrupted()) {
throw new CancellationException("Query was cancelled before execution.");
}
return processChatMessage(chatMessageContext);
}, queryExecutor)
.orTimeout(
chatMessageContext.getTimeout() == null ? 60 : chatMessageContext.getTimeout(), TimeUnit.SECONDS)
.thenApply(result -> {
chatMessageContext.setExecutionTimeMs(System.currentTimeMillis() - startTime);
return result;
})
.whenComplete((r, t) -> {
queryLock.lock();
try {
running.set(false);
} finally {
queryLock.unlock();
}
});
.supplyAsync(() -> processChatMessage(chatMessageContext), queryExecutor)
.orTimeout(
chatMessageContext.getTimeout() == null ? 60 : chatMessageContext.getTimeout(), TimeUnit.SECONDS)
.thenApply(result -> {
chatMessageContext.setExecutionTimeMs(System.currentTimeMillis() - startTime);
return result;
})
.exceptionally(throwable -> {
LOG.error("Error occurred while processing chat message", throwable);
ErrorHandler.handleError(chatMessageContext.getProject(), throwable);
return null;
});
} finally {
queryLock.unlock();
}
return queryFuture;
}

/**
* Cancels the current query if it's running.
* If the future task is not null this means we need to cancel it
*
* @return true if the task is canceled
*/
public void cancelCurrentQuery() {
queryLock.lock();
try {
if (queryFuture != null) {
queryFuture.cancel(true);
queryFuture = null;
}
} finally {
queryLock.unlock();
private boolean isCanceled() {
if (queryFuture != null && !queryFuture.isDone()) {
queryFuture.cancel(true);
running = false;
return true;
}
running = true;
return false;
}

/**
Expand All @@ -130,18 +116,12 @@ public void cancelCurrentQuery() {
*/
private @NotNull Response<AiMessage> processChatMessage(ChatMessageContext chatMessageContext) {
try {
if (Thread.currentThread().isInterrupted()) {
throw new CancellationException("Query was cancelled during execution.");
}

ChatLanguageModel chatLanguageModel = chatMessageContext.getChatLanguageModel();
Response<AiMessage> response =
chatLanguageModel
.generate(ChatMemoryService.getInstance().messages(chatMessageContext.getProject()));
chatLanguageModel
.generate(ChatMemoryService.getInstance().messages(chatMessageContext.getProject()));
ChatMemoryService.getInstance().add(chatMessageContext.getProject(), response.content());
return response;
} catch (CancellationException e) {
throw e; // Re-throw cancellation exceptions
} catch (Exception e) {
if (chatMessageContext.getLanguageModel().getProvider().equals(ModelProvider.Jan)) {
throw new ModelNotActiveException("Selected Jan model is not active. Download and make it active or add API Key in Jan settings.");
Expand Down

0 comments on commit 1937565

Please sign in to comment.