Skip to content

Commit

Permalink
Merge pull request #331 from devoxx/issue-329
Browse files Browse the repository at this point in the history
Fix for OpenRouter calc pricing + Added support for cost table sorting
  • Loading branch information
stephanj authored Nov 8, 2024
2 parents b6c1b77 + d64621c commit 156742c
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 247 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import org.jetbrains.annotations.NotNull;

import java.io.IOException;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
Expand All @@ -27,6 +29,7 @@ public class OpenRouterChatModelFactory implements ChatModelFactory {

private static final ExecutorService executorService = Executors.newFixedThreadPool(5);
private List<LanguageModel> cachedModels = null;
private static final int PRICE_SCALING_FACTOR = 1_000_000; // To convert to per million tokens

@Override
public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
Expand Down Expand Up @@ -78,12 +81,16 @@ public List<LanguageModel> getModels() {
List<Data> models = OpenRouterService.getInstance().getModels();
for (Data model : models) {
CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
// Convert scientific notation prices to regular decimals and scale to per million tokens
double inputCost = convertAndScalePrice(model.getPricing().getPrompt());
double outputCost = convertAndScalePrice(model.getPricing().getCompletion());

LanguageModel languageModel = LanguageModel.builder()
.provider(ModelProvider.OpenRouter)
.modelName(model.getId())
.displayName(model.getName())
.inputCost(model.getPricing().getPrompt())
.outputCost(model.getPricing().getCompletion())
.inputCost(inputCost)
.outputCost(outputCost)
.contextWindow(model.getContextLength() == null ? model.getTopProvider().getContextLength() : model.getContextLength())
.apiKeyUsed(true)
.build();
Expand All @@ -103,4 +110,14 @@ public List<LanguageModel> getModels() {
}
return cachedModels;
}

private double convertAndScalePrice(double price) {
// Convert the price to BigDecimal for precise calculation
BigDecimal bd = BigDecimal.valueOf(price);
// Multiply by 1,000,000 to get price per million tokens
bd = bd.multiply(BigDecimal.valueOf(PRICE_SCALING_FACTOR));
// Round to 6 decimal places
bd = bd.setScale(6, RoundingMode.HALF_UP);
return bd.doubleValue();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -552,18 +552,20 @@ private void addDeepSeekModels() {
.build());
}

public @NotNull List<LanguageModel> getModels() {
Map<String, LanguageModel> languageModels = new HashMap<>(models);
addOpenRouterModels(languageModels);
return new ArrayList<>(languageModels.values());
}
@NotNull
public List<LanguageModel> getModels() {
// Create a copy of the current models
Map<String, LanguageModel> modelsCopy = new HashMap<>(models);

private static void addOpenRouterModels(Map<String, LanguageModel> languageModels) {
// Add OpenRouter models if API key exists
OpenRouterChatModelFactory openRouterChatModelFactory = new OpenRouterChatModelFactory();
String apiKey = openRouterChatModelFactory.getApiKey();
if (apiKey != null && !apiKey.isEmpty()) {
new OpenRouterChatModelFactory().getModels().forEach(model -> languageModels.put(ModelProvider.OpenRouter.getName() + ":" + model.getModelName(), model));
openRouterChatModelFactory.getModels().forEach(model ->
modelsCopy.put(ModelProvider.OpenRouter.getName() + ":" + model.getModelName(), model));
}

return new ArrayList<>(modelsCopy.values());
}

public void setModels(Map<String, LanguageModel> models) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,51 +151,15 @@ public void addLoadListener(Runnable listener) {
loadListeners.add(listener);
}

public void setModelCost(ModelProvider provider,
String modelName,
double inputCost,
double outputCost) {
if (DefaultLLMSettingsUtil.isApiKeyBasedProvider(provider)) {
String key = provider.getName() + ":" + modelName;
modelInputCosts.put(key, inputCost);
modelOutputCosts.put(key, outputCost);
private void initializeDefaultCostsIfEmpty() {
for (Map.Entry<DefaultLLMSettingsUtil.CostKey, Double> entry : DefaultLLMSettingsUtil.DEFAULT_INPUT_COSTS.entrySet()) {
String key = entry.getKey().provider().getName() + ":" + entry.getKey().modelName();
modelInputCosts.put(key, entry.getValue());
}
}

public double getModelInputCost(@NotNull ModelProvider provider, String modelName) {
String key = provider.getName() + ":" + modelName;
double cost = modelInputCosts.getOrDefault(key, 0.0);
if (cost == 0.0) {
DefaultLLMSettingsUtil.CostKey costKey = new DefaultLLMSettingsUtil.CostKey(provider, modelName);
cost = DefaultLLMSettingsUtil.DEFAULT_INPUT_COSTS.getOrDefault(costKey, 0.0);
if (cost == 0.0) {
// Fallback to similar model names
for (Map.Entry<DefaultLLMSettingsUtil.CostKey, Double> entry : DefaultLLMSettingsUtil.DEFAULT_INPUT_COSTS.entrySet()) {
if (entry.getKey().provider == provider && entry.getKey().modelName.startsWith(modelName.split("-")[0])) {
cost = entry.getValue();
break;
}
}
}
for (Map.Entry<DefaultLLMSettingsUtil.CostKey, Double> entry : DefaultLLMSettingsUtil.DEFAULT_OUTPUT_COSTS.entrySet()) {
String key = entry.getKey().provider().getName() + ":" + entry.getKey().modelName();
modelOutputCosts.put(key, entry.getValue());
}
return cost;
}

private void initializeDefaultCostsIfEmpty() {
// if (modelInputCosts.isEmpty()) {
// DefaultLLMSettingsUtil.initializeDefaultCosts();
for (Map.Entry<DefaultLLMSettingsUtil.CostKey, Double> entry : DefaultLLMSettingsUtil.DEFAULT_INPUT_COSTS.entrySet()) {
String key = entry.getKey().provider.getName() + ":" + entry.getKey().modelName;
modelInputCosts.put(key, entry.getValue());
}
// }
// if (modelOutputCosts.isEmpty()) {
// DefaultLLMSettingsUtil.initializeDefaultCosts();
for (Map.Entry<DefaultLLMSettingsUtil.CostKey, Double> entry : DefaultLLMSettingsUtil.DEFAULT_OUTPUT_COSTS.entrySet()) {
String key = entry.getKey().provider.getName() + ":" + entry.getKey().modelName;
modelOutputCosts.put(key, entry.getValue());
}
// }
}

public void setModelWindowContext(ModelProvider provider, String modelName, int windowContext) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,32 +1,21 @@
package com.devoxx.genie.ui.settings.costsettings;

import com.devoxx.genie.model.LanguageModel;
import com.devoxx.genie.model.enumarations.ModelProvider;
import com.devoxx.genie.service.LLMModelRegistryService;
import com.devoxx.genie.ui.settings.AbstractSettingsComponent;
import com.devoxx.genie.util.LLMProviderUtil;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.ui.ComboBox;
import com.intellij.ui.components.JBScrollPane;
import com.intellij.ui.table.JBTable;
import lombok.Getter;

import javax.swing.*;
import javax.swing.table.DefaultTableCellRenderer;
import javax.swing.table.DefaultTableModel;
import javax.swing.table.TableColumn;
import javax.swing.table.*;
import java.awt.*;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Vector;

import static com.devoxx.genie.ui.settings.costsettings.LanguageModelCostSettingsComponent.ColumnName.CONTEXT_WINDOW;
import java.util.*;
import java.util.List;

public class LanguageModelCostSettingsComponent extends AbstractSettingsComponent {

private final JTable costTable;
private final DefaultTableModel tableModel;
private final SortableTableModel tableModel;

@Getter
public enum ColumnName {
Expand All @@ -48,155 +37,56 @@ public LanguageModelCostSettingsComponent() {
.map(ColumnName::getDisplayName)
.toArray(String[]::new);

tableModel = new DefaultTableModel(columnNames, 0) {
@Override
public boolean isCellEditable(int row, int column) {
return column == ColumnName.INPUT_COST.ordinal() ||
column == ColumnName.OUTPUT_COST.ordinal() ||
column == ColumnName.CONTEXT_WINDOW.ordinal();
}

@Override
public Class<?> getColumnClass(int columnIndex) {
if (columnIndex == ColumnName.PROVIDER.ordinal()) {
return ModelProvider.class;
} else if (columnIndex == ColumnName.INPUT_COST.ordinal() || columnIndex == ColumnName.OUTPUT_COST.ordinal()) {
return Double.class;
} else if (columnIndex == ColumnName.CONTEXT_WINDOW.ordinal()) {
return Integer.class;
}
return String.class;
}
};

tableModel = new SortableTableModel(columnNames);
costTable = new JBTable(tableModel);
costTable.setSelectionMode(ListSelectionModel.SINGLE_SELECTION);

// Set custom editors for editable columns
costTable.setDefaultEditor(Double.class, new DefaultCellEditor(new JTextField()) {
@Override
public Object getCellEditorValue() {
try {
return Double.parseDouble((String) super.getCellEditorValue());
} catch (NumberFormatException e) {
return 0.0;
}
}
});

costTable.setDefaultEditor(Integer.class, new DefaultCellEditor(new JTextField()) {
@Override
public Object getCellEditorValue() {
try {
return Integer.parseInt((String) super.getCellEditorValue());
} catch (NumberFormatException e) {
return 0;
}
}
});
// Enable sorting
costTable.setAutoCreateRowSorter(true);
TableRowSorter<TableModel> sorter = new TableRowSorter<>(tableModel);
costTable.setRowSorter(sorter);

setColumnWidths();
setCustomRenderers();
// Set custom comparators for different column types
sorter.setComparator(ColumnName.PROVIDER.ordinal(), String.CASE_INSENSITIVE_ORDER);
sorter.setComparator(ColumnName.MODEL.ordinal(), String.CASE_INSENSITIVE_ORDER);
sorter.setComparator(ColumnName.INPUT_COST.ordinal(), Comparator.comparingDouble(value -> (Double) value));
sorter.setComparator(ColumnName.OUTPUT_COST.ordinal(), Comparator.comparingDouble(value -> (Double) value));
sorter.setComparator(ColumnName.CONTEXT_WINDOW.ordinal(), Comparator.comparingInt(value -> (Integer) value));

ComboBox<ModelProvider> providerComboBox = new ComboBox<>(LLMProviderUtil.getApiKeyEnabledProviders().toArray(new ModelProvider[0]));
costTable.getColumnModel().getColumn(0).setCellEditor(new DefaultCellEditor(providerComboBox));
// Sort by provider by default
List<RowSorter.SortKey> sortKeys = new ArrayList<>();
sortKeys.add(new RowSorter.SortKey(ColumnName.PROVIDER.ordinal(), SortOrder.ASCENDING));
sorter.setSortKeys(sortKeys);

setupColumns();
setCustomRenderers();
loadCurrentCosts();

JScrollPane scrollPane = new JBScrollPane(costTable);
scrollPane.setVerticalScrollBarPolicy(JScrollPane.VERTICAL_SCROLLBAR_AS_NEEDED);
panel.add(scrollPane, BorderLayout.CENTER);

JButton addButton = new JButton("Add Model");
addButton.setEnabled(false);
addButton.setToolTipText("Add a new model not fully implemented yet, we accept PR's :)");
addButton.addActionListener(e -> addNewRow());
panel.add(addButton, BorderLayout.SOUTH);

// Add window context spinner
JSpinner windowContextSpinner = new JSpinner(new SpinnerNumberModel(8000, 1000, 1000000, 1000));

JPanel contextPanel = new JPanel(new FlowLayout(FlowLayout.LEFT));
contextPanel.add(new JLabel("Default Window Context:"));
contextPanel.add(windowContextSpinner);
panel.add(contextPanel, BorderLayout.NORTH);

loadCurrentCosts();
}

public java.util.List<LanguageModel> getModifiedModels() {
java.util.List<LanguageModel> modifiedModels = new ArrayList<>();
for (int i = 0; i < tableModel.getRowCount(); i++) {
String provider = (String) tableModel.getValueAt(i, ColumnName.PROVIDER.ordinal());
String modelName = (String) tableModel.getValueAt(i, ColumnName.MODEL.ordinal());
double inputCost = (Double) tableModel.getValueAt(i, ColumnName.INPUT_COST.ordinal());
double outputCost = (Double) tableModel.getValueAt(i, ColumnName.OUTPUT_COST.ordinal());
Object contextWindowObj = tableModel.getValueAt(i, ColumnName.CONTEXT_WINDOW.ordinal());

int contextWindow = getContextWindow(contextWindowObj);

LanguageModel model = LanguageModel.builder()
.provider(ModelProvider.fromString(provider))
.modelName(modelName)
.inputCost(inputCost)
.outputCost(outputCost)
.contextWindow(contextWindow)
.apiKeyUsed(true)
.build();
modifiedModels.add(model);
}
return modifiedModels;
}

public static int getContextWindow(Object contextWindowObj) {
int contextWindow;
if (contextWindowObj instanceof Integer) {
contextWindow = (Integer) contextWindowObj;
} else if (contextWindowObj instanceof String) {
String contextWindowStr = ((String) contextWindowObj).replaceAll("[^\\d.]", "").split("\\.")[0];
contextWindow = Integer.parseInt(contextWindowStr);
} else {
// Handle unexpected type or throw an exception
throw new IllegalArgumentException("Unexpected type for context window: " + contextWindowObj.getClass());
}
return contextWindow;
private void setupColumns() {
costTable.getColumnModel().getColumn(ColumnName.PROVIDER.ordinal()).setPreferredWidth(60);
costTable.getColumnModel().getColumn(ColumnName.MODEL.ordinal()).setPreferredWidth(220);
costTable.getColumnModel().getColumn(ColumnName.INPUT_COST.ordinal()).setPreferredWidth(60);
costTable.getColumnModel().getColumn(ColumnName.OUTPUT_COST.ordinal()).setPreferredWidth(60);
costTable.getColumnModel().getColumn(ColumnName.CONTEXT_WINDOW.ordinal()).setPreferredWidth(100);
}

private void setCustomRenderers() {
DefaultTableCellRenderer rightRenderer = new DefaultTableCellRenderer();
rightRenderer.setHorizontalAlignment(JLabel.RIGHT);
costTable.getColumnModel().getColumn(CONTEXT_WINDOW.ordinal()).setCellRenderer(rightRenderer);
}

private void setColumnWidths() {
for (ColumnName columnName : ColumnName.values()) {
TableColumn column = costTable.getColumnModel().getColumn(columnName.ordinal());
switch (columnName) {
case PROVIDER, INPUT_COST, OUTPUT_COST -> column.setPreferredWidth(60);
case MODEL -> column.setPreferredWidth(220);
case CONTEXT_WINDOW -> column.setPreferredWidth(100);
}
}
}

private void addNewRow() {
Vector<Object> newRow = new Vector<>();
newRow.add(LLMProviderUtil.getApiKeyEnabledProviders().get(0)); // Default to first provider
newRow.add("");
newRow.add(0.0);
newRow.add(0.0);
newRow.add(8000);
tableModel.addRow(newRow);
ApplicationManager.getApplication().invokeLater(this::scrollToBottom);
}

private void scrollToBottom() {
int lastRowIndex = costTable.getRowCount() - 1;
if (lastRowIndex >= 0) {
Rectangle cellRect = costTable.getCellRect(lastRowIndex, 0, true);
costTable.scrollRectToVisible(cellRect);
}
// Set right alignment for numeric columns
costTable.getColumnModel().getColumn(ColumnName.CONTEXT_WINDOW.ordinal()).setCellRenderer(rightRenderer);
costTable.getColumnModel().getColumn(ColumnName.INPUT_COST.ordinal()).setCellRenderer(rightRenderer);
costTable.getColumnModel().getColumn(ColumnName.OUTPUT_COST.ordinal()).setCellRenderer(rightRenderer);
}

private void loadCurrentCosts() {
tableModel.setRowCount(0);
LLMModelRegistryService.getInstance()
.getModels()
.forEach(model -> tableModel.addRow(new Object[]{
Expand All @@ -208,8 +98,28 @@ private void loadCurrentCosts() {
}));
}

public void reset() {
tableModel.setRowCount(0);
loadCurrentCosts();
private static class SortableTableModel extends DefaultTableModel {
public SortableTableModel(String[] columnNames) {
super(columnNames, 0);
}

@Override
public Class<?> getColumnClass(int column) {
if (column == ColumnName.PROVIDER.ordinal() || column == ColumnName.MODEL.ordinal()) {
return String.class;
} else if (column == ColumnName.INPUT_COST.ordinal() || column == ColumnName.OUTPUT_COST.ordinal()) {
return Double.class;
} else if (column == ColumnName.CONTEXT_WINDOW.ordinal()) {
return Integer.class;
}
return Object.class;
}

@Override
public boolean isCellEditable(int row, int column) {
return column == ColumnName.INPUT_COST.ordinal() ||
column == ColumnName.OUTPUT_COST.ordinal() ||
column == ColumnName.CONTEXT_WINDOW.ordinal();
}
}
}
Loading

0 comments on commit 156742c

Please sign in to comment.