Skip to content

Commit

Permalink
feat: add input field for llama server build parameters and improve e…
Browse files Browse the repository at this point in the history
…rror handling (#481)
  • Loading branch information
PhilKes authored Apr 20, 2024
1 parent 67dc425 commit c8181a6
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 49 deletions.
1 change: 1 addition & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ tasks {
runIde {
enabled = true
environment("ENVIRONMENT", "LOCAL")
autoReloadPlugins.set(false) // is triggered when building llama server
}

test {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.components.Service;
import com.intellij.openapi.diagnostic.Logger;
import com.intellij.openapi.ui.MessageType;
import com.intellij.openapi.util.Key;
import ee.carlrobert.codegpt.CodeGPTBundle;
import ee.carlrobert.codegpt.CodeGPTPlugin;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.llama.form.ServerProgressPanel;
import ee.carlrobert.codegpt.ui.OverlayUtil;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Consumer;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

Expand All @@ -32,65 +35,94 @@ public final class LlamaServerAgent implements Disposable {

private @Nullable OSProcessHandler makeProcessHandler;
private @Nullable OSProcessHandler startServerProcessHandler;
private ServerProgressPanel activeServerProgressPanel;
private boolean stoppedByUser;

public void startAgent(
LlamaServerStartupParams params,
ServerProgressPanel serverProgressPanel,
Runnable onSuccess,
Runnable onServerTerminated) {
Consumer<ServerProgressPanel> onServerTerminated) {
this.activeServerProgressPanel = serverProgressPanel;
ApplicationManager.getApplication().invokeLater(() -> {
try {
serverProgressPanel.updateText(
stoppedByUser = false;
serverProgressPanel.displayText(
CodeGPTBundle.get("llamaServerAgent.buildingProject.description"));
makeProcessHandler = new OSProcessHandler(getMakeCommandLinde());
makeProcessHandler = new OSProcessHandler(
getMakeCommandLine(params.additionalBuildParameters()));
makeProcessHandler.addProcessListener(
getMakeProcessListener(params, serverProgressPanel, onSuccess, onServerTerminated));
getMakeProcessListener(params, onSuccess, onServerTerminated));
makeProcessHandler.startNotify();
} catch (ExecutionException e) {
throw new RuntimeException(e);
showServerError(e.getMessage(), onServerTerminated);
}
});
}

public void stopAgent() {
stoppedByUser = true;
if (makeProcessHandler != null) {
makeProcessHandler.destroyProcess();
}
if (startServerProcessHandler != null) {
startServerProcessHandler.destroyProcess();
}
}

public boolean isServerRunning() {
return startServerProcessHandler != null
return (makeProcessHandler != null
&& makeProcessHandler.isStartNotified()
&& !makeProcessHandler.isProcessTerminated())
|| (startServerProcessHandler != null
&& startServerProcessHandler.isStartNotified()
&& !startServerProcessHandler.isProcessTerminated();
&& !startServerProcessHandler.isProcessTerminated());
}

private ProcessListener getMakeProcessListener(
LlamaServerStartupParams params,
ServerProgressPanel serverProgressPanel,
Runnable onSuccess,
Runnable onServerTerminated) {
Consumer<ServerProgressPanel> onServerTerminated) {
LOG.info("Building llama project");

return new ProcessAdapter() {

private final List<String> errorLines = new CopyOnWriteArrayList<>();

@Override
public void onTextAvailable(@NotNull ProcessEvent event, @NotNull Key outputType) {
if (ProcessOutputType.isStderr(outputType)) {
errorLines.add(event.getText());
return;
}
LOG.info(event.getText());
}

@Override
public void processTerminated(@NotNull ProcessEvent event) {
int exitCode = event.getExitCode();
LOG.info(format("Server build exited with code %d", exitCode));
if (stoppedByUser) {
onServerTerminated.accept(activeServerProgressPanel);
return;
}
if (exitCode != 0) {
showServerError(String.join(",", errorLines), onServerTerminated);
return;
}

try {
LOG.info("Booting up llama server");

serverProgressPanel.updateText(
activeServerProgressPanel.displayText(
CodeGPTBundle.get("llamaServerAgent.serverBootup.description"));
startServerProcessHandler = new OSProcessHandler.Silent(getServerCommandLine(params));
startServerProcessHandler.addProcessListener(
getProcessListener(params.port(), onSuccess, onServerTerminated));
getProcessListener(params.port(), onSuccess,
onServerTerminated));
startServerProcessHandler.startNotify();
} catch (ExecutionException ex) {
LOG.error("Unable to start llama server", ex);
throw new RuntimeException(ex);
showServerError(ex.getMessage(), onServerTerminated);
}
}
};
Expand All @@ -99,27 +131,25 @@ public void processTerminated(@NotNull ProcessEvent event) {
private ProcessListener getProcessListener(
int port,
Runnable onSuccess,
Runnable onServerTerminated) {
Consumer<ServerProgressPanel> onServerTerminated) {
return new ProcessAdapter() {
private final ObjectMapper objectMapper = new ObjectMapper();
private final List<String> errorLines = new CopyOnWriteArrayList<>();

@Override
public void processTerminated(@NotNull ProcessEvent event) {
if (errorLines.isEmpty()) {
LOG.info(format("Server terminated with code %d", event.getExitCode()));
LOG.info(format("Server terminated with code %d", event.getExitCode()));
if (stoppedByUser) {
onServerTerminated.accept(activeServerProgressPanel);
} else {
LOG.info(String.join("", errorLines));
showServerError(String.join(",", errorLines), onServerTerminated);
}

onServerTerminated.run();
}

@Override
public void onTextAvailable(@NotNull ProcessEvent event, @NotNull Key outputType) {
if (ProcessOutputType.isStderr(outputType)) {
errorLines.add(event.getText());
return;
}

if (ProcessOutputType.isStdout(outputType)) {
Expand All @@ -141,11 +171,18 @@ public void onTextAvailable(@NotNull ProcessEvent event, @NotNull Key outputType
};
}

private static GeneralCommandLine getMakeCommandLinde() {
private void showServerError(String errorText, Consumer<ServerProgressPanel> onServerTerminated) {
onServerTerminated.accept(activeServerProgressPanel);
LOG.info("Unable to start llama server:\n" + errorText);
OverlayUtil.showClosableBalloon(errorText, MessageType.ERROR, activeServerProgressPanel);
}

private static GeneralCommandLine getMakeCommandLine(List<String> additionalCompileParameters) {
GeneralCommandLine commandLine = new GeneralCommandLine().withCharset(StandardCharsets.UTF_8);
commandLine.setExePath("make");
commandLine.withWorkDirectory(CodeGPTPlugin.getLlamaSourcePath());
commandLine.addParameters("-j");
commandLine.addParameters(additionalCompileParameters);
commandLine.setRedirectErrorStream(false);
return commandLine;
}
Expand All @@ -159,11 +196,16 @@ private GeneralCommandLine getServerCommandLine(LlamaServerStartupParams params)
"-c", String.valueOf(params.contextLength()),
"--port", String.valueOf(params.port()),
"-t", String.valueOf(params.threads()));
commandLine.addParameters(params.additionalParameters());
commandLine.addParameters(params.additionalRunParameters());
commandLine.setRedirectErrorStream(false);
return commandLine;
}

public void setActiveServerProgressPanel(
ServerProgressPanel activeServerProgressPanel) {
this.activeServerProgressPanel = activeServerProgressPanel;
}

@Override
public void dispose() {
if (makeProcessHandler != null && !makeProcessHandler.isProcessTerminated()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
import java.util.List;

public record LlamaServerStartupParams(String modelPath, int contextLength, int threads, int port,
List<String> additionalParameters) {
List<String> additionalRunParameters,
List<String> additionalBuildParameters) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public class LlamaSettingsState {
private int contextSize = 2048;
private int threads = 8;
private String additionalParameters = "";
private String additionalBuildParameters = "";
private int topK = 40;
private double topP = 0.9;
private double minP = 0.05;
Expand Down Expand Up @@ -138,6 +139,14 @@ public void setAdditionalParameters(String additionalParameters) {
this.additionalParameters = additionalParameters;
}

public String getAdditionalBuildParameters() {
return additionalBuildParameters;
}

public void setAdditionalBuildParameters(String additionalBuildParameters) {
this.additionalBuildParameters = additionalBuildParameters;
}

public int getTopK() {
return topK;
}
Expand Down Expand Up @@ -220,6 +229,7 @@ public boolean equals(Object o) {
&& Objects.equals(baseHost, that.baseHost)
&& Objects.equals(serverPort, that.serverPort)
&& Objects.equals(additionalParameters, that.additionalParameters)
&& Objects.equals(additionalBuildParameters, that.additionalBuildParameters)
&& codeCompletionsEnabled == that.codeCompletionsEnabled
&& codeCompletionMaxTokens == that.codeCompletionMaxTokens;
}
Expand All @@ -229,7 +239,7 @@ public int hashCode() {
return Objects.hash(runLocalServer, useCustomModel, customLlamaModelPath, huggingFaceModel,
localModelPromptTemplate, remoteModelPromptTemplate, localModelInfillPromptTemplate,
remoteModelInfillPromptTemplate, baseHost, serverPort, contextSize, threads,
additionalParameters, topK, topP, minP, repeatPenalty, codeCompletionsEnabled,
codeCompletionMaxTokens);
additionalParameters, additionalBuildParameters, topK, topP, minP, repeatPenalty,
codeCompletionsEnabled, codeCompletionMaxTokens);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public class LlamaServerPreferencesForm {
private final IntegerField maxTokensField;
private final IntegerField threadsField;
private final JBTextField additionalParametersField;
private final JBTextField additionalBuildParametersField;
private final ChatPromptTemplatePanel remotePromptTemplatePanel;
private final InfillPromptTemplatePanel infillPromptTemplatePanel;

Expand All @@ -79,6 +80,9 @@ public LlamaServerPreferencesForm(LlamaSettingsState settings) {
additionalParametersField = new JBTextField(settings.getAdditionalParameters(), 30);
additionalParametersField.setEnabled(!serverRunning);

additionalBuildParametersField = new JBTextField(settings.getAdditionalBuildParameters(), 30);
additionalBuildParametersField.setEnabled(!serverRunning);

baseHostField = new JBTextField(settings.getBaseHost(), 30);
apiKeyField = new JBPasswordField();
apiKeyField.setColumns(30);
Expand Down Expand Up @@ -124,6 +128,7 @@ public void resetForm(LlamaSettingsState state) {
maxTokensField.setValue(state.getContextSize());
threadsField.setValue(state.getThreads());
additionalParametersField.setText(state.getAdditionalParameters());
additionalBuildParametersField.setText(state.getAdditionalBuildParameters());
remotePromptTemplatePanel.setPromptTemplate(state.getRemoteModelPromptTemplate()); // ?
infillPromptTemplatePanel.setPromptTemplate(state.getRemoteModelInfillPromptTemplate());
apiKeyField.setText(CredentialsStore.INSTANCE.getCredential(LLAMA_API_KEY));
Expand Down Expand Up @@ -184,9 +189,17 @@ public JComponent createRunLocalServerForm(LlamaServerAgent llamaServerAgent) {
createComment("settingsConfigurable.service.llama.threads.comment"))
.addLabeledComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.additionalParameters.label"),
additionalParametersField)
.addComponentToRightColumn(
createComment("settingsConfigurable.service.llama.additionalParameters.comment"))
additionalParametersField)
.addComponentToRightColumn(
createComment(
"settingsConfigurable.service.llama.additionalParameters.comment"))
.addLabeledComponent(
CodeGPTBundle.get(
"settingsConfigurable.service.llama.additionalBuildParameters.label"),
additionalBuildParametersField)
.addComponentToRightColumn(
createComment(
"settingsConfigurable.service.llama.additionalBuildParameters.comment"))
.addVerticalGap(4)
.addComponentFillVertically(new JPanel(), 0)
.getPanel()))
Expand All @@ -196,6 +209,7 @@ public JComponent createRunLocalServerForm(LlamaServerAgent llamaServerAgent) {
private JButton getServerButton(
LlamaServerAgent llamaServerAgent,
ServerProgressPanel serverProgressPanel) {
llamaServerAgent.setActiveServerProgressPanel(serverProgressPanel);
var serverRunning = llamaServerAgent.isServerRunning();
var serverButton = new JButton();
serverButton.setText(serverRunning
Expand All @@ -218,7 +232,9 @@ private JButton getServerButton(
getContextSize(),
getThreads(),
getServerPort(),
getListOfAdditionalParameters()),
getListOfAdditionalParameters(),
getListOfAdditionalBuildParameters()
),
serverProgressPanel,
() -> {
setFormEnabled(false);
Expand All @@ -227,12 +243,12 @@ private JButton getServerButton(
Actions.Checked,
SwingConstants.LEADING));
},
() -> {
(activeServerProgressPanel) -> {
setFormEnabled(true);
serverButton.setText(
CodeGPTBundle.get("settingsConfigurable.service.llama.startServer.label"));
serverButton.setIcon(Actions.Execute);
serverProgressPanel.displayComponent(new JBLabel(
activeServerProgressPanel.displayComponent(new JBLabel(
CodeGPTBundle.get("settingsConfigurable.service.llama.progress.serverTerminated"),
Actions.Cancel,
SwingConstants.LEADING));
Expand Down Expand Up @@ -282,7 +298,7 @@ private void enableForm(JButton serverButton, ServerProgressPanel progressPanel)
serverButton.setText(
CodeGPTBundle.get("settingsConfigurable.service.llama.startServer.label"));
serverButton.setIcon(Actions.Execute);
progressPanel.updateText(
progressPanel.displayText(
CodeGPTBundle.get("settingsConfigurable.service.llama.progress.stoppingServer"));
}

Expand All @@ -291,7 +307,7 @@ private void disableForm(JButton serverButton, ServerProgressPanel progressPanel
serverButton.setText(
CodeGPTBundle.get("settingsConfigurable.service.llama.stopServer.label"));
serverButton.setIcon(Actions.Suspend);
progressPanel.startProgress(
progressPanel.displayText(
CodeGPTBundle.get("settingsConfigurable.service.llama.progress.startingServer"));
}

Expand All @@ -301,6 +317,7 @@ private void setFormEnabled(boolean enabled) {
maxTokensField.setEnabled(enabled);
threadsField.setEnabled(enabled);
additionalParametersField.setEnabled(enabled);
additionalBuildParametersField.setEnabled(enabled);
}

public boolean isRunLocalServer() {
Expand Down Expand Up @@ -337,9 +354,20 @@ public String getAdditionalParameters() {

public List<String> getListOfAdditionalParameters() {
return Arrays.stream(additionalParametersField.getText().split(","))
.map(String::trim)
.filter(s -> !s.isBlank())
.toList();
.map(String::trim)
.filter(s -> !s.isBlank())
.toList();
}

public String getAdditionalBuildParameters() {
return additionalBuildParametersField.getText();
}

public List<String> getListOfAdditionalBuildParameters() {
return Arrays.stream(additionalBuildParametersField.getText().split(","))
.map(String::trim)
.filter(s -> !s.isBlank())
.toList();
}

public PromptTemplate getPromptTemplate() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public LlamaSettingsState getCurrentState() {
state.setContextSize(llamaServerPreferencesForm.getContextSize());
state.setThreads(llamaServerPreferencesForm.getThreads());
state.setAdditionalParameters(llamaServerPreferencesForm.getAdditionalParameters());
state.setAdditionalBuildParameters(llamaServerPreferencesForm.getAdditionalBuildParameters());

var modelPreferencesForm = llamaServerPreferencesForm.getLlamaModelPreferencesForm();
state.setCustomLlamaModelPath(modelPreferencesForm.getCustomLlamaModelPath());
Expand Down
Loading

0 comments on commit c8181a6

Please sign in to comment.