Skip to content

Commit

Permalink
feat: support DeepSeek R1 and V3
Browse files Browse the repository at this point in the history
  • Loading branch information
carlrobertoh committed Jan 27, 2025
1 parent 89a3b66 commit 0549957
Show file tree
Hide file tree
Showing 13 changed files with 305 additions and 65 deletions.
2 changes: 1 addition & 1 deletion src/main/cpp/llama.cpp
Submodule llama.cpp updated 1079 files
14 changes: 9 additions & 5 deletions src/main/java/ee/carlrobert/codegpt/CodeGPTPlugin.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package ee.carlrobert.codegpt;

import static java.io.File.separator;
import static java.util.Objects.requireNonNull;

import com.intellij.ide.plugins.PluginManagerCore;
import com.intellij.openapi.application.PathManager;
import com.intellij.openapi.extensions.PluginId;
import com.intellij.openapi.project.Project;
import java.io.File;
import java.nio.file.Path;
import org.jetbrains.annotations.NotNull;

Expand All @@ -26,18 +26,22 @@ private CodeGPTPlugin() {
}

public static @NotNull String getPluginOptionsPath() {
return PathManager.getOptionsPath() + File.separator + "CodeGPT";
return PathManager.getOptionsPath() + separator + "CodeGPT";
}

public static @NotNull String getIndexStorePath() {
return getPluginOptionsPath() + File.separator + "indexes";
return getPluginOptionsPath() + separator + "indexes";
}

public static @NotNull String getLlamaSourcePath() {
return getPluginBasePath() + File.separator + "llama.cpp";
return getPluginBasePath() + separator + "llama.cpp";
}

public static @NotNull String getLlamaServerSourcePath() {
    // Location of the compiled llama.cpp binaries: <plugin>/llama.cpp/build/bin
    // (produced by the CMake build; the server executable is started from here).
    return String.join(separator, getPluginBasePath(), "llama.cpp", "build", "bin");
  }

public static @NotNull String getProjectIndexStorePath(@NotNull Project project) {
return getIndexStorePath() + File.separator + project.getName();
return getIndexStorePath() + separator + project.getName();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ public enum HuggingFaceModel {
DEEPSEEK_CODER_33B_Q5(33, 5, "deepseek-coder-33b-instruct-GGUF",
"deepseek-coder-33b-instruct.Q5_K_M.gguf", 23.5),

DEEPSEEK_R1_1_5B_Q6(1, 6, "DeepSeek-R1-Distill-Qwen-1.5B-GGUF",
"DeepSeek-R1-Distill-Qwen-1.5B-Q6_K.gguf", "bartowski", 1.89),
DEEPSEEK_R1_7B_Q4(7, 4, "DeepSeek-R1-Distill-Qwen-7B-GGUF",
"DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", "bartowski", 4.68),
DEEPSEEK_R1_7B_Q6(7, 6, "DeepSeek-R1-Distill-Qwen-7B-GGUF",
"DeepSeek-R1-Distill-Qwen-7B-Q6_K.gguf", "bartowski", 6.25),
DEEPSEEK_R1_14B_Q4(14, 4, "DeepSeek-R1-Distill-Qwen-14B-GGUF",
"DeepSeek-R1-Distill-Qwen-14B-Q4_K_M.gguf", "bartowski", 8.99),
DEEPSEEK_R1_14B_Q6(14, 6, "DeepSeek-R1-Distill-Qwen-14B-GGUF",
"DeepSeek-R1-Distill-Qwen-14B-Q6_K.gguf", "bartowski", 12.12),

PHIND_CODE_LLAMA_34B_Q3(34, 3, "Phind-CodeLlama-34B-v2-GGUF",
"phind-codellama-34b-v2.Q3_K_M.gguf"),
PHIND_CODE_LLAMA_34B_Q4(34, 4, "Phind-CodeLlama-34B-v2-GGUF",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,20 @@ public enum LlamaModel {
HuggingFaceModel.DEEPSEEK_CODER_33B_Q3,
HuggingFaceModel.DEEPSEEK_CODER_33B_Q4,
HuggingFaceModel.DEEPSEEK_CODER_33B_Q5)),
DEEPSEEK_R1(
"Deepseek R1",
"DeepSeek-R1-Zero, a model trained via large-scale reinforcement learning (RL) "
+ "without supervised fine-tuning (SFT) as a preliminary step, demonstrated remarkable "
+ "performance on reasoning. DeepSeek-R1 achieves performance comparable to OpenAI-o1 "
+ "across math, code, and reasoning tasks.",
PromptTemplate.DEEPSEEK_R1,
InfillPromptTemplate.DEEPSEEK_CODER,
List.of(
HuggingFaceModel.DEEPSEEK_R1_1_5B_Q6,
HuggingFaceModel.DEEPSEEK_R1_7B_Q4,
HuggingFaceModel.DEEPSEEK_R1_7B_Q6,
HuggingFaceModel.DEEPSEEK_R1_14B_Q4,
HuggingFaceModel.DEEPSEEK_R1_14B_Q6)),
PHIND_CODE_LLAMA(
"Phind Code Llama",
"This model is fine-tuned from Phind-CodeLlama-34B-v1 on an additional 1.5B tokens "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ public final class LlamaServerAgent implements Disposable {

private static final Logger LOG = Logger.getInstance(LlamaServerAgent.class);

private @Nullable OSProcessHandler makeProcessHandler;
private @Nullable OSProcessHandler makeSetupProcessHandler;
private @Nullable OSProcessHandler makeBuildProcessHandler;
private @Nullable OSProcessHandler startServerProcessHandler;
private ServerProgressPanel activeServerProgressPanel;
private boolean stoppedByUser;
Expand All @@ -49,11 +50,44 @@ public void startAgent(
stoppedByUser = false;
serverProgressPanel.displayText(
CodeGPTBundle.get("llamaServerAgent.buildingProject.description"));
makeProcessHandler = new OSProcessHandler(
getMakeCommandLine(params));
makeProcessHandler.addProcessListener(
getMakeProcessListener(params, onSuccess, onServerStopped));
makeProcessHandler.startNotify();

makeSetupProcessHandler = new OSProcessHandler(getCMakeSetupCommandLine(params));
makeSetupProcessHandler.addProcessListener(new ProcessAdapter() {
private final List<String> errorLines = new CopyOnWriteArrayList<>();

@Override
public void onTextAvailable(@NotNull ProcessEvent event, @NotNull Key outputType) {
if (ProcessOutputType.isStderr(outputType)) {
errorLines.add(event.getText());
return;
}
LOG.info(event.getText());
}

@Override
public void processTerminated(@NotNull ProcessEvent event) {
int exitCode = event.getExitCode();
LOG.info(format("CMake setup exited with code %d", exitCode));
if (stoppedByUser) {
onServerStopped.accept(activeServerProgressPanel);
return;
}
if (exitCode != 0) {
showServerError(String.join(",", errorLines), onServerStopped);
return;
}

try {
makeBuildProcessHandler = new OSProcessHandler(getCMakeBuildCommandLine(params));
makeBuildProcessHandler.addProcessListener(
getMakeProcessListener(params, onSuccess, onServerStopped));
makeBuildProcessHandler.startNotify();
} catch (ExecutionException e) {
showServerError(e.getMessage(), onServerStopped);
}
}
});
makeSetupProcessHandler.startNotify();
} catch (ExecutionException e) {
showServerError(e.getMessage(), onServerStopped);
}
Expand All @@ -62,18 +96,18 @@ public void startAgent(

public void stopAgent() {
stoppedByUser = true;
if (makeProcessHandler != null) {
makeProcessHandler.destroyProcess();
if (makeSetupProcessHandler != null) {
makeSetupProcessHandler.destroyProcess();
}
if (startServerProcessHandler != null) {
startServerProcessHandler.destroyProcess();
}
}

public boolean isServerRunning() {
return (makeProcessHandler != null
&& makeProcessHandler.isStartNotified()
&& !makeProcessHandler.isProcessTerminated())
return (makeSetupProcessHandler != null
&& makeSetupProcessHandler.isStartNotified()
&& !makeSetupProcessHandler.isProcessTerminated())
|| (startServerProcessHandler != null
&& startServerProcessHandler.isStartNotified()
&& !startServerProcessHandler.isProcessTerminated());
Expand Down Expand Up @@ -147,25 +181,14 @@ public void processTerminated(@NotNull ProcessEvent event) {

@Override
public void onTextAvailable(@NotNull ProcessEvent event, @NotNull Key outputType) {
if (ProcessOutputType.isStderr(outputType)) {
errorLines.add(event.getText());
}

if (ProcessOutputType.isStdout(outputType)) {
LOG.info(event.getText());
LOG.info(event.getText());

try {
var serverMessage = objectMapper.readValue(event.getText(), LlamaServerMessage.class);
// hack
if ("HTTP server listening".equals(serverMessage.msg())) {
LOG.info("Server up and running!");
// TODO: Use proper successful boot up validation
if (event.getText().contains("server is listening")) {
LOG.info("Server up and running!");

LlamaSettings.getCurrentState().setServerPort(port);
onSuccess.run();
}
} catch (Exception ignore) {
// ignore
}
LlamaSettings.getCurrentState().setServerPort(port);
onSuccess.run();
}
}
};
Expand All @@ -177,21 +200,33 @@ private void showServerError(String errorText, Consumer<ServerProgressPanel> onS
OverlayUtil.showClosableBalloon(errorText, MessageType.ERROR, activeServerProgressPanel);
}

private static GeneralCommandLine getMakeCommandLine(LlamaServerStartupParams params) {
GeneralCommandLine commandLine = new GeneralCommandLine().withCharset(StandardCharsets.UTF_8);
commandLine.setExePath("make");
commandLine.withWorkDirectory(CodeGPTPlugin.getLlamaSourcePath());
commandLine.addParameters("-j");
commandLine.addParameters(params.additionalBuildParameters());
commandLine.withEnvironment(params.additionalEnvironmentVariables());
commandLine.setRedirectErrorStream(false);
return commandLine;
private static GeneralCommandLine getCMakeSetupCommandLine(LlamaServerStartupParams params) {
    // Configure step for the llama.cpp CMake project: "cmake -B build",
    // executed from the checked-out llama.cpp source directory.
    var commandLine = new GeneralCommandLine().withCharset(StandardCharsets.UTF_8);
    commandLine.setExePath("cmake");
    commandLine.withWorkDirectory(CodeGPTPlugin.getLlamaSourcePath());
    commandLine.addParameters("-B", "build");
    // Extra environment variables supplied by the user (e.g. toolchain overrides).
    commandLine.withEnvironment(params.additionalEnvironmentVariables());
    // Keep stderr separate so build errors can be collected and surfaced to the user.
    commandLine.setRedirectErrorStream(false);
    return commandLine;
  }

private static GeneralCommandLine getCMakeBuildCommandLine(LlamaServerStartupParams params) {
    // Build step: compile only the "llama-server" target in Release config,
    // with 4 parallel jobs ("cmake --build build --config Release -t llama-server -j 4").
    var commandLine = new GeneralCommandLine().withCharset(StandardCharsets.UTF_8);
    commandLine.setExePath("cmake");
    commandLine.withWorkDirectory(CodeGPTPlugin.getLlamaSourcePath());
    commandLine.addParameters(
        "--build", "build", "--config", "Release", "-t", "llama-server",
        "-j", "4");
    // Extra environment variables supplied by the user (e.g. toolchain overrides).
    commandLine.withEnvironment(params.additionalEnvironmentVariables());
    // Keep stderr separate so build errors can be collected and surfaced to the user.
    commandLine.setRedirectErrorStream(false);
    return commandLine;
  }

private GeneralCommandLine getServerCommandLine(LlamaServerStartupParams params) {
GeneralCommandLine commandLine = new GeneralCommandLine().withCharset(StandardCharsets.UTF_8);
commandLine.setExePath("./server");
commandLine.withWorkDirectory(CodeGPTPlugin.getLlamaSourcePath());
commandLine.setExePath("./llama-server");
commandLine.withWorkDirectory(CodeGPTPlugin.getLlamaServerSourcePath());
commandLine.addParameters(
"-m", params.modelPath(),
"-c", String.valueOf(params.contextLength()),
Expand All @@ -210,8 +245,8 @@ public void setActiveServerProgressPanel(

@Override
public void dispose() {
if (makeProcessHandler != null && !makeProcessHandler.isProcessTerminated()) {
makeProcessHandler.destroyProcess();
if (makeSetupProcessHandler != null && !makeSetupProcessHandler.isProcessTerminated()) {
makeSetupProcessHandler.destroyProcess();
}
if (startServerProcessHandler != null && !startServerProcessHandler.isProcessTerminated()) {
startServerProcessHandler.destroyProcess();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import ee.carlrobert.codegpt.conversations.message.Message;
import java.util.List;
import java.util.stream.Collectors;

public enum PromptTemplate {

Expand Down Expand Up @@ -237,6 +238,23 @@ public String buildPrompt(String systemPrompt, String userPrompt, List<Message>
.toString();
}
},
DEEPSEEK_R1("DeepSeek R1") {
    @Override
    public String buildPrompt(String systemPrompt, String userPrompt, List<Message> history) {
      // Render prior turns as "User:/Assistant:" pairs separated by blank lines.
      var conversation = new StringBuilder();
      for (var message : history) {
        if (conversation.length() > 0) {
          conversation.append("\n\n");
        }
        var answer = message.getResponse();
        // Strip the model's chain-of-thought block so it is not fed back as context.
        // Only applied when the response begins with <think>; mid-response tags are kept.
        if (answer.startsWith("<think>")) {
          answer = answer.replaceAll("(?s)<think>.*?</think>", "").trim();
        }
        conversation.append(
            String.format("User:\n%s\n\nAssistant:\n%s", message.getPrompt(), answer));
      }

      // DeepSeek R1 chat markers; history is folded into a single user turn.
      return String.format(
          "<|begin▁of▁sentence|>%s<|User|>History:\n%s\n\nUser:\n%s<|Assistant|>",
          systemPrompt, conversation, userPrompt);
    }
  },
DEEPSEEK_CODER("DeepSeek Coder") {
@Override
public String buildPrompt(String systemPrompt, String userPrompt, List<Message> history) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,6 @@ private ResponseMessagePanel createResponseMessagePanel(ChatCompletionParameters
panel.addCopyAction(() -> CopyAction.copyToClipboard(message.getResponse()));
panel.addContent(new ChatMessageResponseBody(
project,
true,
false,
message.isWebSearchIncluded(),
fileContextIncluded || message.getDocumentationDetails() != null,
Expand Down
Loading

0 comments on commit 0549957

Please sign in to comment.