Skip to content

Commit

Permalink
Feat #327 : POC of TDG
Browse files Browse the repository at this point in the history
  • Loading branch information
stephanj committed Nov 8, 2024
1 parent 156742c commit d9c5c5c
Show file tree
Hide file tree
Showing 26 changed files with 260 additions and 47 deletions.
1 change: 1 addition & 0 deletions src/main/java/com/devoxx/genie/model/Constant.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ private Constant() {
public static final String TEST_PROMPT = "Write a unit test for this code using JUnit.";
public static final String REVIEW_PROMPT = "Review the selected code, can it be improved or are there any bugs?";
public static final String EXPLAIN_PROMPT = "Break down the code in simple terms to help a junior developer grasp its functionality.";
public static final String TDG_PROMPT = "You are a professional Java developer. Give me a SINGLE FILE COMPLETE java implementation that will pass this test. Do not respond with a test. Give me only complete code and no snippets. Include imports and use the right package.";

// The Local LLM Model URLs, these can be overridden in the settings page
public static final String OLLAMA_MODEL_URL = "http://localhost:11434/";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
@Builder
public class ChatMessageContext {
private final LocalDateTime createdOn = LocalDateTime.now();
private String name;
private String id;
private Project project;
private Integer timeout;
private String userPrompt;
Expand All @@ -30,6 +30,7 @@ public class ChatMessageContext {
private int totalFileCount;
private long executionTimeMs;
private TokenUsage tokenUsage;
private String commandName; // Custom command name for the prompt, for example /test, /review etc.
private double cost;

@Builder.Default
Expand Down
21 changes: 11 additions & 10 deletions src/main/java/com/devoxx/genie/service/ChatPromptExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import com.intellij.openapi.vfs.VirtualFile;
import org.jetbrains.annotations.NotNull;

import javax.swing.*;
import java.util.Arrays;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -38,7 +37,6 @@ public ChatPromptExecutor(PromptInputArea promptInputArea) {

/**
* Execute the prompt.
*
* @param chatMessageContext the chat message context
* @param promptOutputPanel the prompt output panel
* @param enableButtons the Enable buttons
Expand Down Expand Up @@ -92,13 +90,12 @@ public void run(@NotNull ProgressIndicator progressIndicator) {

/**
* Process possible command prompt.
*
* @param chatMessageContext the chat message context
* @param promptOutputPanel the prompt output panel
*/
public Optional<String> updatePromptWithCommandIfPresent(@NotNull ChatMessageContext chatMessageContext,
PromptOutputPanel promptOutputPanel) {
Optional<String> commandFromPrompt = getCommandFromPrompt(chatMessageContext.getUserPrompt().trim(), promptOutputPanel);
Optional<String> commandFromPrompt = getCommandFromPrompt(chatMessageContext, promptOutputPanel);
chatMessageContext.setUserPrompt(commandFromPrompt.orElse(chatMessageContext.getUserPrompt()));

// Ensure that EditorInfo is set in the ChatMessageContext
Expand All @@ -109,6 +106,11 @@ public Optional<String> updatePromptWithCommandIfPresent(@NotNull ChatMessageCon
return commandFromPrompt;
}

/**
* Get the editor info.
* @param project the project
* @return the editor info
*/
private @NotNull EditorInfo getEditorInfo(Project project) {
EditorInfo editorInfo = new EditorInfo();
FileEditorManager fileEditorManager = FileEditorManager.getInstance(project);
Expand All @@ -132,7 +134,6 @@ public Optional<String> updatePromptWithCommandIfPresent(@NotNull ChatMessageCon

/**
* Stop streaming or the non-streaming prompt execution
*
* @param project the project
*/
public void stopPromptExecution(Project project) {
Expand All @@ -145,21 +146,21 @@ public void stopPromptExecution(Project project) {

/**
* Get the command from the prompt.
*
* @param prompt the prompt
* @param chatMessageContext the chat message context
* @param promptOutputPanel the prompt output panel
* @return the command
*/
private Optional<String> getCommandFromPrompt(@NotNull String prompt,
private Optional<String> getCommandFromPrompt(@NotNull ChatMessageContext chatMessageContext,
PromptOutputPanel promptOutputPanel) {
String prompt = chatMessageContext.getUserPrompt().trim();
if (prompt.startsWith("/")) {
DevoxxGenieSettingsService settings = DevoxxGenieStateService.getInstance();

// Check for custom prompts
for (CustomPrompt customPrompt : settings.getCustomPrompts()) {
if (prompt.equalsIgnoreCase("/" + customPrompt.getName())) {
prompt = customPrompt.getPrompt();
return Optional.of(prompt);
chatMessageContext.setCommandName(customPrompt.getName());
return Optional.of(customPrompt.getPrompt());
}
}
promptOutputPanel.showHelpText();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
import dev.langchain4j.model.output.Response;
import org.jetbrains.annotations.NotNull;

import javax.swing.*;

public class StreamingResponseHandler implements dev.langchain4j.model.StreamingResponseHandler<AiMessage> {
private final ChatMessageContext chatMessageContext;
private final Runnable enableButtons;
Expand Down Expand Up @@ -70,7 +68,7 @@ private void addExpandablePanelIfNeeded() {
ApplicationManager.getApplication().invokeLater(() -> {
ExpandablePanel fileListPanel =
new ExpandablePanel(chatMessageContext, FileListManager.getInstance().getFiles());
fileListPanel.setName(chatMessageContext.getName());
fileListPanel.setName(chatMessageContext.getId());
promptOutputPanel.addStreamFileReferencesResponse(fileListPanel);
});
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.devoxx.genie.service.tdg;

public class ClassNameNotFoundException extends Exception {
public ClassNameNotFoundException(String message) {
super(message);
}
}
54 changes: 54 additions & 0 deletions src/main/java/com/devoxx/genie/service/tdg/CodeContainer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package com.devoxx.genie.service.tdg;

import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.ToString;

import java.util.regex.Matcher;
import java.util.regex.Pattern;

@ToString
@EqualsAndHashCode
@Setter
@Getter
public final class CodeContainer {

private final String content;
private final String fileName;
private final String packageName;
private final int attempts;

public CodeContainer(String content) throws ClassNameNotFoundException {
this(content, 1);
}

public CodeContainer(String content, int attempts) throws ClassNameNotFoundException {
this.content = content;
this.fileName = extractClassName() + ".java";
this.packageName = extractPackageName();
this.attempts = attempts;
}

// TODO: first look for 'public class' and then for 'class
private String extractClassName() throws ClassNameNotFoundException {
// matches "public" (optional) followed by "class" and then the class name
String regex = "\\b(?:public\\s+)?class\\s+(\\w+)\\b";
Matcher matcher = Pattern.compile(regex).matcher(content);
if (matcher.find()) {
return matcher.group(1);
} else {
throw new ClassNameNotFoundException("Class name not found in: " + content);
}
}

private String extractPackageName() {
String regex = "package\\s+(\\w+(\\.\\w+)*)";
Matcher matcher = Pattern.compile(regex).matcher(content);
if (matcher.find()) {
return matcher.group(1);
} else {
return "";
}
}
}
128 changes: 128 additions & 0 deletions src/main/java/com/devoxx/genie/service/tdg/CodeGeneratorService.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package com.devoxx.genie.service.tdg;

import com.devoxx.genie.model.request.ChatMessageContext;
import com.devoxx.genie.ui.util.NotificationUtil;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.application.ModalityState;
import com.intellij.openapi.fileEditor.FileEditorManager;
import com.intellij.openapi.module.Module;
import com.intellij.openapi.module.ModuleManager;
import com.intellij.openapi.progress.ProgressIndicator;
import com.intellij.openapi.progress.Task;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.roots.ModuleRootManager;
import com.intellij.openapi.vfs.VirtualFile;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.io.IOException;
import java.nio.charset.StandardCharsets;

public class CodeGeneratorService {

public static void createClassFromCodeSnippet(@NotNull ChatMessageContext chatMessageContext,
String selectedText) {
Project project = chatMessageContext.getProject();

new Task.Backgroundable(project, "Creating java class", false) {
@Override
public void run(@NotNull ProgressIndicator indicator) {
try {
indicator.setIndeterminate(true);
indicator.setText("Parsing code...");

CodeContainer codeContainer = new CodeContainer(selectedText);
String packageName = codeContainer.getPackageName();
String fileName = codeContainer.getFileName();

indicator.setText("Creating class file...");

ApplicationManager.getApplication().invokeAndWait(() ->
ApplicationManager.getApplication().runWriteAction(() ->
createFile(packageName, fileName, project, selectedText)),
ModalityState.defaultModalityState());

} catch (Exception e) {
NotificationUtil.sendNotification(project,
"Error creating class: " + e.getMessage());
}
}
}.queue();
}

private static void createFile(String packageName,
String fileName,
Project project,
String selectedText) {
try {
// Find the proper source root for Java files
VirtualFile sourceRoot = findSourceRoot(project);
if (sourceRoot == null) {
NotificationUtil.sendNotification(project,
"Error: Could not find source root directory");
return;
}

VirtualFile packageDir = createPackageDirectories(sourceRoot, packageName);
VirtualFile existingFile = packageDir.findChild(fileName);
VirtualFile javaFile;

if (existingFile != null) {
existingFile.setBinaryContent(
selectedText.getBytes(StandardCharsets.UTF_8));
javaFile = existingFile;
NotificationUtil.sendNotification(project,
"Class updated successfully");
} else {
javaFile = packageDir.createChildData(null, fileName);
javaFile.setBinaryContent(
selectedText.getBytes(StandardCharsets.UTF_8));
NotificationUtil.sendNotification(project,
"Class created successfully");
}

FileEditorManager.getInstance(project).openFile(javaFile, true);

} catch (IOException e) {
NotificationUtil.sendNotification(project,
"Error creating class: " + e.getMessage());
}
}

private static @Nullable VirtualFile findSourceRoot(Project project) {
ModuleManager moduleManager = ModuleManager.getInstance(project);
for (Module module : moduleManager.getModules()) {
ModuleRootManager rootManager = ModuleRootManager.getInstance(module);
// Get source roots for the module
for (VirtualFile root : rootManager.getSourceRoots(false)) {
// Look for the main source root, typically ending with "src/main/java"
if (root.getPath().endsWith("src/main/java")) {
return root;
}
}
// Fallback to first source root if we can't find main/java
VirtualFile[] sourceRoots = rootManager.getSourceRoots(false);
if (sourceRoots.length > 0) {
return sourceRoots[0];
}
}
return null;
}

private static VirtualFile createPackageDirectories(@NotNull VirtualFile sourceRoot,
@NotNull String packageName) throws IOException {
if (packageName.isEmpty()) {
return sourceRoot;
}

VirtualFile currentDir = sourceRoot;
for (String part : packageName.split("\\.")) {
VirtualFile subDir = currentDir.findChild(part);
if (subDir == null) {
subDir = currentDir.createChildDirectory(null, part);
}
currentDir = subDir;
}
return currentDir;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ private void initializeCommands() {
commands.add("/test");
commands.add("/explain");
commands.add("/review");
commands.add("/tdg");
commands.add("/help");

DevoxxGenieSettingsService stateService = DevoxxGenieStateService.getInstance();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class ChatResponsePanel extends BackgroundPanel {
* @param chatMessageContext the chat message context
*/
public ChatResponsePanel(@NotNull ChatMessageContext chatMessageContext) {
super(chatMessageContext.getName());
super(chatMessageContext.getId());
setLayout(new BoxLayout(this, BoxLayout.Y_AXIS));

this.chatMessageContext = chatMessageContext;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public class ChatStreamingResponsePanel extends BackgroundPanel {
* @param chatMessageContext the chat message context
*/
public ChatStreamingResponsePanel(@NotNull ChatMessageContext chatMessageContext) {
super(chatMessageContext.getName());
super(chatMessageContext.getId());

setLayout(new BoxLayout(this, BoxLayout.Y_AXIS));
add(new ResponseHeaderPanel(chatMessageContext));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,14 @@ public void addUserPrompt(ChatMessageContext chatMessageContext) {
userPromptPanel.add(waitingPanel, BorderLayout.SOUTH);
}

addFiller(chatMessageContext.getName());
addFiller(chatMessageContext.getId());
container.add(userPromptPanel);
scrollToBottom();
}

public void addChatResponse(@NotNull ChatMessageContext chatMessageContext) {
waitingPanel.hideMsg();
addFiller(chatMessageContext.getName());
addFiller(chatMessageContext.getId());
container.add(new ChatResponsePanel(chatMessageContext));
scrollToBottom();
}
Expand All @@ -108,7 +108,7 @@ public void addStreamFileReferencesResponse(ExpandablePanel fileListPanel) {

public void removeLastUserPrompt(ChatMessageContext chatMessageContext) {
for (Component component : container.getComponents()) {
if (component instanceof UserPromptPanel && component.getName().equals(chatMessageContext.getName())) {
if (component instanceof UserPromptPanel && component.getName().equals(chatMessageContext.getId())) {
container.remove(component);
break;
}
Expand Down Expand Up @@ -153,7 +153,7 @@ private ChatMessageContext createChatMessageContext(Project project,
@NotNull Conversation conversation,
@NotNull ChatMessage message) {
return ChatMessageContext.builder()
.name(conversation.getId())
.id(conversation.getId())
.project(project)
.userPrompt(message.isUser() ? message.getContent() : "")
.aiMessage(message.isUser() ? null : AiMessage.aiMessage(message.getContent()))
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/com/devoxx/genie/ui/panel/UserPromptPanel.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class UserPromptPanel extends BackgroundPanel {
*/
public UserPromptPanel(JPanel container,
@NotNull ChatMessageContext chatMessageContext) {
super(chatMessageContext.getName());
super(chatMessageContext.getId());
this.container = container;
setLayout(new BorderLayout());

Expand Down Expand Up @@ -77,7 +77,7 @@ public UserPromptPanel(JPanel container,
* @param chatMessageContext the chat message context
*/
private void removeChat(@NotNull ChatMessageContext chatMessageContext) {
String nameToRemove = chatMessageContext.getName();
String nameToRemove = chatMessageContext.getId();
java.util.List<Component> componentsToRemove = new ArrayList<>();

for (Component c : container.getComponents()) {
Expand Down
Loading

0 comments on commit d9c5c5c

Please sign in to comment.