Skip to content

Commit

Permalink
Merge pull request #160 from devoxx/issue-159
Browse files Browse the repository at this point in the history
Feat #159: Use correct tokenizer encoding calc based on selected LLM
  • Loading branch information
stephanj authored Jul 5, 2024
2 parents ab83df6 + b15afae commit 9517522
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 23 deletions.
11 changes: 7 additions & 4 deletions src/main/java/com/devoxx/genie/action/AddDirectoryAction.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.devoxx.genie.action;

import com.devoxx.genie.model.enumarations.ModelProvider;
import com.devoxx.genie.service.FileListManager;
import com.devoxx.genie.service.ProjectContentService;
import com.devoxx.genie.ui.settings.DevoxxGenieStateService;
Expand Down Expand Up @@ -53,14 +54,16 @@ private void addDirectoryToContext(Project project, @NotNull VirtualFile directo
if (!filesToAdd.isEmpty()) {
fileListManager.addFiles(filesToAdd);

// Get the content and token count
ProjectContentService.getInstance().getDirectoryContentAndTokens(project, directory, Integer.MAX_VALUE, false)
ModelProvider selectedProvider = ModelProvider.valueOf(settings.getSelectedProvider());

ProjectContentService.getInstance()
.getDirectoryContentAndTokens(project, directory, false, selectedProvider)
.thenAccept(result -> {
int fileCount = filesToAdd.size();
int tokenCount = result.getTokenCount();
NotificationUtil.sendNotification(project,
String.format("Added %d files from directory: %s (Approximately %s tokens)",
fileCount, directory.getName(), WindowContextFormatterUtil.format(tokenCount)));
String.format("Added %d files from directory: %s (Approximately %s tokens using %s tokenizer)",
fileCount, directory.getName(), WindowContextFormatterUtil.format(tokenCount), selectedProvider.getName()));
});
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package com.devoxx.genie.action;

import com.devoxx.genie.model.enumarations.ModelProvider;
import com.devoxx.genie.service.ProjectContentService;
import com.devoxx.genie.ui.settings.DevoxxGenieStateService;
import com.devoxx.genie.ui.util.NotificationUtil;
import com.devoxx.genie.ui.util.WindowContextFormatterUtil;
import com.intellij.openapi.actionSystem.AnAction;
Expand All @@ -24,15 +26,19 @@ public void actionPerformed(@NotNull AnActionEvent e) {
return;
}

DevoxxGenieStateService stateService = DevoxxGenieStateService.getInstance();
ModelProvider selectedProvider = ModelProvider.valueOf(stateService.getSelectedProvider());

ProgressManager.getInstance().run(new Task.Backgroundable(project, "Calculating Tokens", false) {
@Override
public void run(@NotNull ProgressIndicator indicator) {
ProjectContentService.getInstance()
.getDirectoryContentAndTokens(project, selectedDir, Integer.MAX_VALUE, true)
.getDirectoryContentAndTokens(project, selectedDir, true, selectedProvider)
.thenAccept(result -> {
String message = String.format("Directory '%s' contains approximately %s tokens",
String message = String.format("Directory '%s' contains approximately %s tokens (using %s tokenizer)",
selectedDir.getName(),
WindowContextFormatterUtil.format(result.getTokenCount()));
WindowContextFormatterUtil.format(result.getTokenCount()),
selectedProvider.getName());
NotificationUtil.sendNotification(project, message);
});
}
Expand Down
36 changes: 22 additions & 14 deletions src/main/java/com/devoxx/genie/service/ProjectContentService.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,22 +78,16 @@ public CompletableFuture<String> getDirectoryContent(Project project,
});
}

/**
* Retrieves and processes the content of a specified directory within a Project.
* Also calculates number of tokens in this content if required by user settings or provider configurations.
* @param project The Project containing the directory to be scanned
* @param directory VirtualFile representing the directory to scan for content
* @return ContentResult object holding both content and token count, optionally copied to clipboard based on flag
*/
public CompletableFuture<ContentResult> getDirectoryContentAndTokens(Project project,
VirtualFile directory,
int tokenLimit,
boolean isTokenCalculation) {
boolean isTokenCalculation,
ModelProvider modelProvider) {
return CompletableFuture.supplyAsync(() -> {
AtomicLong totalTokens = new AtomicLong(0);
StringBuilder content = new StringBuilder();

processDirectoryRecursively(project, directory, content, totalTokens, isTokenCalculation);
Encoding encoding = getEncodingForProvider(modelProvider);
processDirectoryRecursively(project, directory, content, totalTokens, isTokenCalculation, encoding);

return new ContentResult(content.toString(), totalTokens.intValue());
});
Expand Down Expand Up @@ -136,6 +130,19 @@ public void calculateTokensAndCost(Project project,
});
}

private Encoding getEncodingForProvider(@NotNull ModelProvider provider) {
return switch (provider) {
case OpenAI, Anthropic, Gemini ->
Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE);
case Mistral, DeepInfra, Groq ->
// These often use the Llama tokenizer or similar
Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.R50K_BASE);
default ->
// Default to cl100k_base for unknown providers
Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE);
};
}

/**
* Processes a directory recursively, calculating the number of tokens and building a content string.
* @param project The Project containing the directory to scan
Expand All @@ -145,24 +152,25 @@ public void calculateTokensAndCost(Project project,
* @param isTokenCalculation Boolean flag indicating whether to calculate tokens or not
*/
private void processDirectoryRecursively(Project project,
VirtualFile directory,
@NotNull VirtualFile directory,
StringBuilder content,
AtomicLong totalTokens,
boolean isTokenCalculation) {
boolean isTokenCalculation,
Encoding encoding) {
DevoxxGenieStateService settings = DevoxxGenieStateService.getInstance();

for (VirtualFile child : directory.getChildren()) {
if (child.isDirectory()) {
if (!settings.getExcludedDirectories().contains(child.getName())) {
processDirectoryRecursively(project, child, content, totalTokens, isTokenCalculation);
processDirectoryRecursively(project, child, content, totalTokens, isTokenCalculation, encoding);
}
} else if (shouldIncludeFile(child, settings)) {
String fileContent = readFileContent(child);
if (!isTokenCalculation) {
content.append("File: ").append(child.getPath()).append("\n");
content.append(fileContent).append("\n\n");
}
totalTokens.addAndGet(ENCODING.countTokens(fileContent));
totalTokens.addAndGet(encoding.countTokens(fileContent));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,15 +321,17 @@ private void processModelNameSelection(@NotNull ActionEvent e) {
* Set the model provider and update the model names.
*/
private void handleModelProviderSelectionChange(@NotNull ActionEvent e) {
if (!e.getActionCommand()
.equals(Constant.COMBO_BOX_CHANGED) || !isInitializationComplete || isUpdatingModelNames) return;
if (!e.getActionCommand().equals(Constant.COMBO_BOX_CHANGED) || !isInitializationComplete || isUpdatingModelNames) return;

isUpdatingModelNames = true;

try {
JComboBox<?> comboBox = (JComboBox<?>) e.getSource();
ModelProvider modelProvider = (ModelProvider) comboBox.getSelectedItem();
if (modelProvider != null) {
// Update the selectedProvider in DevoxxGenieStateService
DevoxxGenieStateService.getInstance().setSelectedProvider(modelProvider.getName());

updateModelNamesComboBox(modelProvider.getName());
modelNameComboBox.setRenderer(new ModelInfoRenderer()); // Re-apply the renderer
modelNameComboBox.revalidate();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ public static DevoxxGenieStateService getInstance() {
private Integer maxSearchResults = MAX_SEARCH_RESULTS;

// Last selected language model
@Getter
@Setter
private String selectedProvider;
private String selectedLanguageModel;

Expand Down Expand Up @@ -160,4 +162,5 @@ public void setModelWindowContext(ModelProvider provider, String modelName, int
public void setLanguageModels(List<LanguageModel> models) {
this.languageModels = new ArrayList<>(models);
}

}
1 change: 1 addition & 0 deletions src/main/resources/META-INF/plugin.xml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
<LI>Feat #148: Create custom commands</LI>
<LI>Feat #157: Calc tokens for directory</LI>
<LI>Fix #153: Use the "Copy Project" settings when using "Add Directory to Context Window"</LI>
<LI>Feat #159: Introduce variable TokenCalculator based on selected LLM Provider</LI>
</UL>
<h2>v0.2.2</h2>
<UL>
Expand Down

0 comments on commit 9517522

Please sign in to comment.