Skip to content

Commit

Permalink
Merge pull request #178 from devoxx/issue-177
Browse files Browse the repository at this point in the history
Feat #177: Show Ollama language models window context
  • Loading branch information
stephanj authored Jul 23, 2024
2 parents 82e8451 + 340aa2b commit 587f702
Show file tree
Hide file tree
Showing 13 changed files with 179 additions and 92 deletions.
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ plugins {
}

group = "com.devoxx.genie"
version = "0.2.6"
version = "0.2.7"

repositories {
mavenCentral()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ private void addDirectoryToContext(Project project, @NotNull VirtualFile directo
ModelProvider selectedProvider = ModelProvider.valueOf(settings.getSelectedProvider());

ProjectContentService.getInstance()
.getDirectoryContentAndTokens(project, directory, false, selectedProvider)
.getDirectoryContentAndTokens(directory, false, selectedProvider)
.thenAccept(result -> {
int fileCount = filesToAdd.size();
int tokenCount = result.getTokenCount();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@
import com.devoxx.genie.ui.settings.DevoxxGenieStateService;
import com.devoxx.genie.ui.util.NotificationUtil;
import com.devoxx.genie.ui.util.WindowContextFormatterUtil;
import com.intellij.openapi.actionSystem.ActionUpdateThread;
import com.intellij.openapi.actionSystem.AnAction;
import com.intellij.openapi.actionSystem.AnActionEvent;
import com.intellij.openapi.actionSystem.CommonDataKeys;
import com.intellij.openapi.progress.ProgressIndicator;
import com.intellij.openapi.progress.ProgressManager;
import com.intellij.openapi.progress.Task;
import com.intellij.openapi.project.DumbAwareAction;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.vfs.VirtualFile;
import org.jetbrains.annotations.NotNull;

public class CalcTokensForDirectoryAction extends AnAction {
public class CalcTokensForDirectoryAction extends DumbAwareAction {

@Override
public void actionPerformed(@NotNull AnActionEvent e) {
Expand All @@ -33,7 +35,7 @@ public void actionPerformed(@NotNull AnActionEvent e) {
@Override
public void run(@NotNull ProgressIndicator indicator) {
ProjectContentService.getInstance()
.getDirectoryContentAndTokens(project, selectedDir, true, selectedProvider)
.getDirectoryContentAndTokens(selectedDir, true, selectedProvider)
.thenAccept(result -> {
String message = String.format("Directory '%s' contains approximately %s tokens (using %s tokenizer)",
selectedDir.getName(),
Expand All @@ -50,4 +52,14 @@ public void update(@NotNull AnActionEvent e) {
VirtualFile file = e.getData(CommonDataKeys.VIRTUAL_FILE);
e.getPresentation().setEnabledAndVisible(file != null && file.isDirectory());
}

@Override
public @NotNull ActionUpdateThread getActionUpdateThread() {
return ActionUpdateThread.BGT;
}

@Override
public boolean isDumbAware() {
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.devoxx.genie.model.LanguageModel;
import com.devoxx.genie.model.enumarations.ModelProvider;
import com.devoxx.genie.model.ollama.OllamaModelEntryDTO;
import com.devoxx.genie.service.OllamaApiService;
import com.devoxx.genie.service.OllamaService;
import com.devoxx.genie.ui.settings.DevoxxGenieStateService;
import com.devoxx.genie.ui.util.NotificationUtil;
Expand All @@ -19,9 +20,12 @@
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class OllamaChatModelFactory implements ChatModelFactory {

private static final ExecutorService executorService = Executors.newFixedThreadPool(5);
private static boolean warningShown = false;
private List<LanguageModel> cachedModels = null;

Expand Down Expand Up @@ -61,21 +65,35 @@ public List<LanguageModel> getModels() {
}

List<LanguageModel> modelNames = new ArrayList<>();
List<CompletableFuture<Void>> futures = new ArrayList<>();

try {
OllamaModelEntryDTO[] ollamaModels = OllamaService.getInstance().getModels();
for (OllamaModelEntryDTO model : ollamaModels) {
modelNames.add(
LanguageModel.builder()
.provider(ModelProvider.Ollama)
.modelName(model.getName())
.displayName(model.getName())
.inputCost(0)
.outputCost(0)
.contextWindow(8_000)
.apiKeyUsed(false)
.build()
);
CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
try {
int contextWindow = OllamaApiService.getModelContext(model.getName());
LanguageModel languageModel = LanguageModel.builder()
.provider(ModelProvider.Ollama)
.modelName(model.getName())
.displayName(model.getName())
.inputCost(0)
.outputCost(0)
.contextWindow(contextWindow)
.apiKeyUsed(false)
.build();
synchronized (modelNames) {
modelNames.add(languageModel);
}
} catch (IOException e) {
NotificationUtil.sendNotification(ProjectManager.getInstance().getDefaultProject(),
"Error fetching context window for model: " + model.getName());
}
}, executorService);
futures.add(future);
}

CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
cachedModels = modelNames;
} catch (IOException e) {
if (!warningShown) {
Expand All @@ -87,9 +105,4 @@ public List<LanguageModel> getModels() {
}
return cachedModels;
}

public void resetCache() {
cachedModels = null;
warningShown = false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -360,12 +360,22 @@ private void addMistralModels() {
.provider(ModelProvider.Mistral)
.modelName(OPEN_MIXTRAL_8x7B.toString())
.displayName("Mistral 8x7B")
.inputCost(7)
.inputCost(0.7)
.outputCost(0.7)
.contextWindow(32_000)
.apiKeyUsed(true)
.build());

models.add(LanguageModel.builder()
.provider(ModelProvider.Mistral)
.modelName(OPEN_MIXTRAL_8X22B.toString())
.displayName("Mistral 8x22b")
.inputCost(2)
.outputCost(6)
.contextWindow(64_000)
.apiKeyUsed(true)
.build());

models.add(LanguageModel.builder()
.provider(ModelProvider.Mistral)
.modelName(MISTRAL_SMALL_LATEST.toString())
Expand Down Expand Up @@ -395,6 +405,16 @@ private void addMistralModels() {
.contextWindow(32_000)
.apiKeyUsed(true)
.build());

models.add(LanguageModel.builder()
.provider(ModelProvider.Mistral)
.modelName("codestral-2405")
.displayName("Codestral")
.inputCost(1)
.outputCost(3)
.contextWindow(32_000)
.apiKeyUsed(true)
.build());
}

@Contract(value = " -> new", pure = true)
Expand Down
54 changes: 54 additions & 0 deletions src/main/java/com/devoxx/genie/service/OllamaApiService.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package com.devoxx.genie.service;

import com.google.gson.Gson;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import okhttp3.*;
import org.jetbrains.annotations.NotNull;

import java.io.IOException;

public class OllamaApiService {
private static final String OLLAMA_API_URL = "http://localhost:11434/api/show";
private static final OkHttpClient client = new OkHttpClient();
private static final Gson gson = new Gson();

public static int getModelContext(@NotNull String modelName) throws IOException {
RequestBody body = RequestBody.create(
MediaType.parse("application/json"),
"{\"name\":\"" + modelName + "\"}"
);

Request request = new Request.Builder()
.url(OLLAMA_API_URL)
.post(body)
.build();

try (Response response = client.newCall(request).execute()) {
if (!response.isSuccessful()) throw new IOException("Unexpected code " + response);

JsonObject jsonObject = gson.fromJson(response.body().string(), JsonObject.class);
return findContextLength(jsonObject);
}
}

private static int findContextLength(@NotNull JsonObject jsonObject) {
JsonElement modelInfo = jsonObject.get("model_info");
if (modelInfo != null && modelInfo.isJsonObject()) {
JsonObject modelInfoObject = modelInfo.getAsJsonObject();
for (String key : modelInfoObject.keySet()) {
if (key.endsWith(".context_length")) {
return modelInfoObject.get(key).getAsInt();
}
}
}

// Fallback: check if context_length exists directly in the root
JsonElement contextLength = jsonObject.get("context_length");
if (contextLength != null) {
return contextLength.getAsInt();
}

return -1; // Return -1 if not found
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,15 @@ public CompletableFuture<String> getDirectoryContent(Project project,
});
}

public CompletableFuture<ContentResult> getDirectoryContentAndTokens(Project project,
VirtualFile directory,
public CompletableFuture<ContentResult> getDirectoryContentAndTokens(VirtualFile directory,
boolean isTokenCalculation,
ModelProvider modelProvider) {
return CompletableFuture.supplyAsync(() -> {
AtomicLong totalTokens = new AtomicLong(0);
StringBuilder content = new StringBuilder();

Encoding encoding = getEncodingForProvider(modelProvider);
processDirectoryRecursively(project, directory, content, totalTokens, isTokenCalculation, encoding);
processDirectoryRecursively(directory, content, totalTokens, isTokenCalculation, encoding);

return new ContentResult(content.toString(), totalTokens.intValue());
});
Expand All @@ -108,7 +107,9 @@ public void calculateTokensAndCost(Project project,
getProjectContent(project, windowContext, true)
.thenAccept(projectContent -> {
int tokenCount = ENCODING.countTokens(projectContent);
String message = String.format("Project contains %s. Cost calculation is not applicable for local providers.",
String message = String.format("Project contains %s. " +
"Cost calculation is not applicable for local providers. " +
"Make sure you select a model with a big enough window context.",
WindowContextFormatterUtil.format(tokenCount, "tokens"));
NotificationUtil.sendNotification(project, message);
});
Expand Down Expand Up @@ -153,14 +154,12 @@ private Encoding getEncodingForProvider(@NotNull ModelProvider provider) {

/**
* Processes a directory recursively, calculating the number of tokens and building a content string.
* @param project The Project containing the directory to scan
* @param directory VirtualFile representing the directory to scan
* @param content StringBuilder object to hold the content of the scanned files
* @param totalTokens AtomicLong object to hold the total token count
* @param isTokenCalculation Boolean flag indicating whether to calculate tokens or not
*/
private void processDirectoryRecursively(Project project,
@NotNull VirtualFile directory,
private void processDirectoryRecursively(@NotNull VirtualFile directory,
StringBuilder content,
AtomicLong totalTokens,
boolean isTokenCalculation,
Expand All @@ -170,7 +169,7 @@ private void processDirectoryRecursively(Project project,
for (VirtualFile child : directory.getChildren()) {
if (child.isDirectory()) {
if (!settings.getExcludedDirectories().contains(child.getName())) {
processDirectoryRecursively(project, child, content, totalTokens, isTokenCalculation, encoding);
processDirectoryRecursively(child, content, totalTokens, isTokenCalculation, encoding);
}
} else if (shouldIncludeFile(child, settings)) {
String fileContent = readFileContent(child);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ public CompletableFuture<String> scanProject(Project project,
}

// Only truncate if it's not a token calculation
if (!isTokenCalculation) {
return truncateToTokens(project, fullContent.toString(), windowContext, isTokenCalculation);
} else {
if (isTokenCalculation) {
return fullContent.toString();
} else {
return truncateToTokens(project, fullContent.toString(), windowContext, isTokenCalculation);
}
}).inSmartMode(project)
.finishOnUiThread(ModalityState.defaultModalityState(), future::complete)
Expand Down
Loading

0 comments on commit 587f702

Please sign in to comment.