Skip to content

Commit

Permalink
Merge pull request #79 from devoxx/issue-71
Browse files Browse the repository at this point in the history
Fix #71: Support for AST/PSI window context reflection
  • Loading branch information
stephanj authored May 27, 2024
2 parents 9010a2b + 91036d6 commit 4febe75
Show file tree
Hide file tree
Showing 37 changed files with 342 additions and 137 deletions.
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ plugins {
}

group = "com.devoxx.genie"
version = "0.1.12"
version = "0.1.13"

repositories {
mavenCentral()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import com.devoxx.genie.model.Constant;
import com.devoxx.genie.model.enumarations.ModelProvider;
import com.devoxx.genie.model.request.ChatMessageContext;
import com.devoxx.genie.ui.SettingsState;
import com.devoxx.genie.service.SettingsStateService;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import lombok.Setter;
Expand Down Expand Up @@ -83,7 +83,7 @@ public StreamingChatLanguageModel getStreamingChatLanguageModel(@NotNull ChatMes
*/
public @NotNull ChatModel initChatModel(@NotNull ChatMessageContext chatMessageContext) {
ChatModel chatModel = new ChatModel();
SettingsState settingsState = SettingsState.getInstance();
SettingsStateService settingsState = SettingsStateService.getInstance();
setMaxOutputTokens(settingsState, chatModel);

chatModel.setTemperature(settingsState.getTemperature());
Expand All @@ -101,7 +101,7 @@ public StreamingChatLanguageModel getStreamingChatLanguageModel(@NotNull ChatMes
* @param settingsState the settings state
* @param chatModel the chat model
*/
private static void setMaxOutputTokens(@NotNull SettingsState settingsState, ChatModel chatModel) {
private static void setMaxOutputTokens(@NotNull SettingsStateService settingsState, ChatModel chatModel) {
String maxOutputTokens = settingsState.getMaxOutputTokens();
if (maxOutputTokens.isBlank()) {
chatModel.setMaxTokens(Constant.MAX_OUTPUT_TOKENS);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.devoxx.genie.chatmodel;

import com.devoxx.genie.ui.SettingsState;
import com.devoxx.genie.service.SettingsStateService;
import org.jetbrains.annotations.NotNull;

import java.util.*;
Expand Down Expand Up @@ -32,7 +32,7 @@ private LLMProviderConstant() {
};

public static @NotNull List<String> getLLMProviders() {
SettingsState settingState = SettingsState.getInstance();
SettingsStateService settingState = SettingsStateService.getInstance();
Map<String, Supplier<String>> providerKeyMap = new HashMap<>();
providerKeyMap.put(OpenAI.getName(), settingState::getOpenAIKey);
providerKeyMap.put(Anthropic.getName(), settingState::getAnthropicKey);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import com.devoxx.genie.chatmodel.ChatModelFactory;
import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.ui.SettingsState;
import com.devoxx.genie.service.SettingsStateService;
import dev.langchain4j.model.anthropic.AnthropicChatModel;
import dev.langchain4j.model.anthropic.AnthropicStreamingChatModel;
import dev.langchain4j.model.chat.ChatLanguageModel;
Expand Down Expand Up @@ -40,7 +40,7 @@ public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel ch

@Override
public String getApiKey() {
return SettingsState.getInstance().getAnthropicKey();
return SettingsStateService.getInstance().getAnthropicKey();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import com.devoxx.genie.chatmodel.ChatModelFactory;
import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.ui.SettingsState;
import com.devoxx.genie.service.SettingsStateService;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
Expand Down Expand Up @@ -42,7 +42,7 @@ public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel ch

@Override
public String getApiKey() {
return SettingsState.getInstance().getDeepInfraKey();
return SettingsStateService.getInstance().getDeepInfraKey();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import com.devoxx.genie.chatmodel.ChatModelFactory;
import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.model.gemini.GeminiChatModel;
import com.devoxx.genie.ui.SettingsState;
import com.devoxx.genie.service.SettingsStateService;
import dev.langchain4j.model.chat.ChatLanguageModel;
import org.jetbrains.annotations.NotNull;

Expand All @@ -26,7 +26,7 @@ public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {

@Override
public String getApiKey() {
return SettingsState.getInstance().getGeminiKey();
return SettingsStateService.getInstance().getGeminiKey();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import com.devoxx.genie.chatmodel.ChatModelFactory;
import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.ui.SettingsState;
import com.devoxx.genie.service.SettingsStateService;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.localai.LocalAiChatModel;
Expand All @@ -16,7 +16,7 @@ public class GPT4AllChatModelFactory implements ChatModelFactory {
@Override
public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
return LocalAiChatModel.builder()
.baseUrl(SettingsState.getInstance().getGpt4allModelUrl())
.baseUrl(SettingsStateService.getInstance().getGpt4allModelUrl())
.modelName("test-model")
.maxRetries(chatModel.getMaxRetries())
.maxTokens(chatModel.getMaxTokens())
Expand All @@ -28,7 +28,7 @@ public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {

public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel chatModel) {
return LocalAiStreamingChatModel.builder()
.baseUrl(SettingsState.getInstance().getGpt4allModelUrl())
.baseUrl(SettingsStateService.getInstance().getGpt4allModelUrl())
.modelName("test-model")
.temperature(chatModel.getTemperature())
.topP(chatModel.getTopP())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import com.devoxx.genie.chatmodel.ChatModelFactory;
import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.ui.SettingsState;
import com.devoxx.genie.service.SettingsStateService;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
Expand All @@ -28,21 +28,21 @@ public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
.build();
}


/**
 * Builds a streaming chat model for the Groq provider.
 * Groq is driven through the OpenAI-compatible client: the builder is the
 * OpenAI one, configured with the Groq API key from {@link #getApiKey()}.
 *
 * @param chatModel the model configuration (name, temperature, topP, timeout)
 * @return a configured {@link StreamingChatLanguageModel}
 */
@Override
public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel chatModel) {
    Duration requestTimeout = Duration.ofSeconds(chatModel.getTimeout());
    return OpenAiStreamingChatModel.builder()
        .apiKey(getApiKey())
        .modelName(chatModel.getModelName())
        .temperature(chatModel.getTemperature())
        .topP(chatModel.getTopP())
        .timeout(requestTimeout)
        .build();
}
// Streaming gives error for Groq model provider
// @Override
// public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel chatModel) {
// return OpenAiStreamingChatModel.builder()
// .apiKey(getApiKey())
// .modelName(chatModel.getModelName())
// .temperature(chatModel.getTemperature())
// .topP(chatModel.getTopP())
// .timeout(Duration.ofSeconds(chatModel.getTimeout()))
// .build();
// }

@Override
public String getApiKey() {
return SettingsState.getInstance().getGroqKey();
return SettingsStateService.getInstance().getGroqKey();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.model.jan.Data;
import com.devoxx.genie.service.JanService;
import com.devoxx.genie.ui.SettingsState;
import com.devoxx.genie.service.SettingsStateService;
import com.devoxx.genie.ui.util.NotificationUtil;
import com.intellij.openapi.project.ProjectManager;
import dev.langchain4j.model.chat.ChatLanguageModel;
Expand All @@ -27,7 +27,7 @@ public class JanChatModelFactory implements ChatModelFactory {
@Override
public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
return LocalAiChatModel.builder()
.baseUrl(SettingsState.getInstance().getJanModelUrl())
.baseUrl(SettingsStateService.getInstance().getJanModelUrl())
.modelName(chatModel.getModelName())
.maxRetries(chatModel.getMaxRetries())
.temperature(chatModel.getTemperature())
Expand All @@ -41,7 +41,7 @@ public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
@Override
public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel chatModel) {
return LocalAiStreamingChatModel.builder()
.baseUrl(SettingsState.getInstance().getJanModelUrl())
.baseUrl(SettingsStateService.getInstance().getJanModelUrl())
.modelName(chatModel.getModelName())
.temperature(chatModel.getTemperature())
.topP(chatModel.getTopP())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import com.devoxx.genie.chatmodel.ChatModelFactory;
import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.ui.SettingsState;
import com.devoxx.genie.service.SettingsStateService;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.localai.LocalAiChatModel;
Expand All @@ -16,7 +16,7 @@ public class LMStudioChatModelFactory implements ChatModelFactory {
@Override
public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
return LocalAiChatModel.builder()
.baseUrl(SettingsState.getInstance().getLmstudioModelUrl())
.baseUrl(SettingsStateService.getInstance().getLmstudioModelUrl())
.modelName("LMStudio")
.temperature(chatModel.getTemperature())
.topP(chatModel.getTopP())
Expand All @@ -29,7 +29,7 @@ public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
@Override
public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel chatModel) {
return LocalAiStreamingChatModel.builder()
.baseUrl(SettingsState.getInstance().getLmstudioModelUrl())
.baseUrl(SettingsStateService.getInstance().getLmstudioModelUrl())
.modelName("LMStudio")
.temperature(chatModel.getTemperature())
.topP(chatModel.getTopP())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import com.devoxx.genie.chatmodel.ChatModelFactory;
import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.ui.SettingsState;
import com.devoxx.genie.service.SettingsStateService;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.mistralai.MistralAiChatModel;
Expand Down Expand Up @@ -42,7 +42,7 @@ public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel ch

@Override
public String getApiKey() {
return SettingsState.getInstance().getMistralKey();
return SettingsStateService.getInstance().getMistralKey();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.model.ollama.OllamaModelEntryDTO;
import com.devoxx.genie.service.OllamaService;
import com.devoxx.genie.ui.SettingsState;
import com.devoxx.genie.service.SettingsStateService;
import com.devoxx.genie.ui.util.NotificationUtil;
import com.intellij.openapi.project.ProjectManager;
import dev.langchain4j.model.chat.ChatLanguageModel;
Expand All @@ -23,7 +23,7 @@ public class OllamaChatModelFactory implements ChatModelFactory {
@Override
public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
return OllamaChatModel.builder()
.baseUrl(SettingsState.getInstance().getOllamaModelUrl())
.baseUrl(SettingsStateService.getInstance().getOllamaModelUrl())
.modelName(chatModel.getModelName())
.temperature(chatModel.getTemperature())
.topP(chatModel.getTopP())
Expand All @@ -35,7 +35,7 @@ public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
@Override
public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel chatModel) {
return OllamaStreamingChatModel.builder()
.baseUrl(SettingsState.getInstance().getOllamaModelUrl())
.baseUrl(SettingsStateService.getInstance().getOllamaModelUrl())
.modelName(chatModel.getModelName())
.temperature(chatModel.getTemperature())
.topP(chatModel.getTopP())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import com.devoxx.genie.chatmodel.ChatModelFactory;
import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.ui.SettingsState;
import com.devoxx.genie.service.SettingsStateService;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
Expand Down Expand Up @@ -41,7 +41,7 @@ public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel ch

@Override
public String getApiKey() {
return SettingsState.getInstance().getOpenAIKey();
return SettingsStateService.getInstance().getOpenAIKey();
}

@Override
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/com/devoxx/genie/model/Constant.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,9 @@ private Constant() {
public static final Integer MAX_MEMORY = 6;

public static final Boolean STREAM_MODE = false;

public static final Boolean AST_MODE = false;
public static final Boolean AST_PARENT_CLASS = true;
public static final Boolean AST_CLASS_REFERENCE = true;
public static final Boolean AST_FIELD_REFERENCE = true;
}
5 changes: 5 additions & 0 deletions src/main/java/com/devoxx/genie/model/request/EditorInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import lombok.Getter;
import lombok.Setter;

import java.util.ArrayList;
import java.util.List;

@Setter
Expand All @@ -20,4 +21,8 @@ public EditorInfo() {
/**
 * Creates editor info for the given file selection.
 * Stores a defensive copy of the list so that later mutation by the caller
 * does not leak into this instance — consistent with setSelectedFiles,
 * which already copies. A null argument is preserved as null to keep the
 * original constructor contract.
 *
 * @param selectedFiles the files currently selected in the editor; may be null
 */
public EditorInfo(List<VirtualFile> selectedFiles) {
    this.selectedFiles = selectedFiles == null ? null : new ArrayList<>(selectedFiles);
}

/**
 * Replaces the selected files with a defensive copy of the supplied list,
 * isolating this instance from any later changes the caller makes to it.
 * Defining this setter explicitly overrides the Lombok-generated one,
 * which would otherwise store the reference directly.
 *
 * @param selectedFiles the files currently selected in the editor; must not be null
 */
public void setSelectedFiles(List<VirtualFile> selectedFiles) {
    final List<VirtualFile> snapshot = new ArrayList<>(selectedFiles);
    this.selectedFiles = snapshot;
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package com.devoxx.genie.service;

import com.devoxx.genie.model.request.ChatMessageContext;
import com.devoxx.genie.ui.SettingsState;
import com.devoxx.genie.ui.panel.PromptOutputPanel;
import com.devoxx.genie.ui.util.NotificationUtil;
import com.intellij.openapi.progress.ProgressIndicator;
Expand All @@ -17,7 +16,7 @@
public class ChatPromptExecutor {

private final PromptExecutionService promptExecutionService = PromptExecutionService.getInstance();
private final SettingsState settingsState = SettingsState.getInstance();
private final SettingsStateService settingsState = SettingsStateService.getInstance();

public ChatPromptExecutor() {
}
Expand All @@ -36,7 +35,7 @@ public void executePrompt(@NotNull ChatMessageContext chatMessageContext,
new Task.Backgroundable(chatMessageContext.getProject(), "Working...", true) {
@Override
public void run(@NotNull ProgressIndicator progressIndicator) {
if (SettingsState.getInstance().getStreamMode()) {
if (SettingsStateService.getInstance().getStreamMode()) {
setupStreaming(chatMessageContext, promptOutputPanel, enableButtons);
} else {
runPrompt(chatMessageContext, promptOutputPanel, enableButtons);
Expand Down
3 changes: 1 addition & 2 deletions src/main/java/com/devoxx/genie/service/JanService.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import com.devoxx.genie.model.jan.Data;
import com.devoxx.genie.model.jan.ResponseDTO;
import com.devoxx.genie.ui.SettingsState;
import com.google.gson.Gson;
import okhttp3.OkHttpClient;
import okhttp3.Request;
Expand All @@ -21,7 +20,7 @@ public JanService(OkHttpClient client) {
}

public List<Data> getModels() throws IOException {
String baseUrl = ensureEndsWithSlash(SettingsState.getInstance().getJanModelUrl());
String baseUrl = ensureEndsWithSlash(SettingsStateService.getInstance().getJanModelUrl());

Request request = new Request.Builder()
.url(baseUrl + "models")
Expand Down
Loading

0 comments on commit 4febe75

Please sign in to comment.