Skip to content

Commit

Permalink
Merge pull request #28 from stratika/feat/tensors/support
Browse files Browse the repository at this point in the history
[feat] Support for TornadoVM Tensor types
stratika authored Sep 12, 2024
2 parents 0ef4630 + 254e554 commit ed1db50
Showing 6 changed files with 96 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -27,10 +27,15 @@

public class VariableInit {

public static int parameterSize;
private static int parameterSize;
private static String tensorShapeDimension;
private static int[] tensorShapeDimensions;

public static String variableInitHelper(@NotNull PsiMethod method) {
parameterSize = TornadoSettingState.getInstance().parameterSize;
tensorShapeDimension = TornadoSettingState.getInstance().tensorShapeDimensions;
tensorShapeDimensions = convertShapeStringToIntArray(tensorShapeDimension);

ArrayList<String> parametersName = new ArrayList<>();
ArrayList<String> parametersType = new ArrayList<>();
for (PsiParameter parameter : method.getParameterList().getParameters()) {
@@ -89,6 +94,7 @@ private static String lookupBoxedTypes(String type, String name, int size){
case "VectorDouble", "VectorDouble2", "VectorDouble3", "VectorDouble4", "VectorDouble8", "VectorDouble16" -> vectorInit(name, type, "Double");
case "VectorHalf", "VectorHalf2", "VectorHalf3", "VectorHalf4", "VectorHalf8", "VectorHalf16" -> vectorHalfInit(name, type);
case "KernelContext" -> " = new KernelContext();";
case "TensorByte", "TensorFP16", "TensorFP32", "TensorFP64", "TensorInt16", "TensorInt32", "TensorInt64" -> tensorInit(type);
default -> "";
};
}
@@ -149,6 +155,32 @@ private static String vectorHalfInit(String name, String type){
name + ".fill(new HalfFloat(" + generateValueByType("HalfFloat") + "));";
}

private static String tensorInit(String type){
StringBuilder builder = new StringBuilder();
builder.append(" = new ").append(type).append("(").append("new Shape(");

for (int i = 0; i < tensorShapeDimensions.length; i++){
builder.append(tensorShapeDimensions[i]);
if (i < tensorShapeDimensions.length - 1){
builder.append(", ");
}
}
builder.append("));").append("\n");
return builder.toString();
}

private static int[] convertShapeStringToIntArray(String shapeString){
String[] stringArray = shapeString.split(",");
int[] numbers = new int[stringArray.length];

for (int i = 0; i < stringArray.length; i++) {
if (!stringArray[i].trim().isEmpty()) {
numbers[i] = Integer.parseInt(stringArray[i].trim());
}
}
return numbers;
}

private static String generateValueByType(String type){
Random r = new Random();
return switch (type) {
Original file line number Diff line number Diff line change
@@ -57,6 +57,7 @@ public boolean isModified() {
boolean modified = !mySettingsComponent.getJdk().equals(settings.JdkPath);
modified |= !mySettingsComponent.getTornadoEnvPath().equals(settings.TornadoRoot);
modified |= mySettingsComponent.getMaxArraySize() != settings.parameterSize;
modified |= mySettingsComponent.getTensorShapeDimensions() != settings.tensorShapeDimensions;
modified |= mySettingsComponent.isSaveFileEnabled() != settings.saveFileEnabled;
modified |= !mySettingsComponent.getFileSaveLocation().equals(settings.fileSaveLocation);
return modified;
@@ -74,6 +75,7 @@ public void apply() throws ConfigurationException {
settings.TornadoRoot = mySettingsComponent.getTornadoEnvPath();
settings.JdkPath = mySettingsComponent.getJdk();
settings.parameterSize = mySettingsComponent.getMaxArraySize();
settings.tensorShapeDimensions = mySettingsComponent.getTensorShapeDimensions();
settings.saveFileEnabled = mySettingsComponent.isSaveFileEnabled();
settings.fileSaveLocation = mySettingsComponent.getFileSaveLocation();
}
@@ -85,6 +87,7 @@ public void reset() {
mySettingsComponent.setTornadoEnvPath(settings.TornadoRoot);
mySettingsComponent.setMyJdk(settings.JdkPath);
mySettingsComponent.setMaxArraySize(settings.parameterSize);
mySettingsComponent.setTensorShapeDimensions(settings.tensorShapeDimensions);
mySettingsComponent.setSaveFileEnabled(settings.saveFileEnabled);
mySettingsComponent.setFileSaveLocation(settings.fileSaveLocation);
}
Original file line number Diff line number Diff line change
@@ -40,6 +40,7 @@ public class TornadoSettingState implements PersistentStateComponent<TornadoSett
@OptionTag(converter = JdkConverter.class)
public Sdk JdkPath;
public int parameterSize;
public String tensorShapeDimensions;
public boolean isValid;
public boolean saveFileEnabled;
public String fileSaveLocation;
Original file line number Diff line number Diff line change
@@ -61,6 +61,7 @@ public class TornadoSettingsComponent {
private JdkComboBox myJdk;

private final JBTextField myMaxArraySize = new JBTextField(4);
private final JBTextField tensorShapeDimensions = new JBTextField();

public TornadoSettingsComponent() {
jdkModel = ProjectStructureConfigurable.getInstance(ProjectManager.getInstance().getDefaultProject()).getProjectJdksModel();
@@ -86,20 +87,22 @@ public TornadoSettingsComponent() {
JPanel innerGrid = FormBuilder.createFormBuilder()
.addLabeledComponent(new JBLabel("TornadoVM Root:"), myTornadoEnv)
.addLabeledComponent(new JBLabel("Java SDK:"), myJdk)
.addLabeledComponent(new JBLabel(" "), new JLabel("<html><div style='width:400px; color:gray; font-size:15px;'>" + INNER_COMMENT + "</div></html>"))
.addLabeledComponent(new JBLabel(" "), new JLabel("<html><div style='width:400px; color:gray;'>" + INNER_COMMENT + "</div></html>"))
.addVerticalGap(10)
.getPanel();

JPanel dynamicInspectionPanel = FormBuilder.createFormBuilder()
.addLabeledComponent(new JBLabel("Max array size:"), myMaxArraySize, 1)
.addLabeledComponent(new JBLabel(" "), new JLabel("<html><div style='width:400px; color:gray; font-size:15px;'>" + MessageBundle.message("ui.settings.comment.size") + "</div></html>"))
.addLabeledComponent(new JBLabel(" "), new JLabel("<html><div style='width:400px; color:gray;'>" + MessageBundle.message("ui.settings.max.array.size") + "</div></html>"))
.addLabeledComponent("Tensor shape dimensions:", tensorShapeDimensions)
.addLabeledComponent(new JBLabel(" "), new JLabel("<html><div style='width:400px; color:gray;'>" + MessageBundle.message("ui.settings.tensor.shape.dimensions.doc") + "</div></html>"))
.getPanel();

dynamicInspectionPanel.setBorder(IdeBorderFactory.createTitledBorder(MessageBundle.message("ui.settings.group.dynamic")));

JPanel debugPanel = FormBuilder.createFormBuilder()
.addComponent(saveFileCheckbox)
.addLabeledComponent(new JBLabel(" "), new JLabel("<html><div style='width:400px; color:gray; font-size:15px;'>" + MessageBundle.message("ui.settings.comment.debug.file") + "</div></html>"))
.addLabeledComponent(new JBLabel(" "), new JLabel("<html><div style='width:400px; color:gray;'>" + MessageBundle.message("ui.settings.comment.debug.file") + "</div></html>"))
.addLabeledComponent(new JBLabel("Save Location:"), fileSaveLocationField)
.getPanel();

@@ -144,6 +147,21 @@ public void setMaxArraySize(int size) {
myMaxArraySize.setText(String.valueOf(size));
}

public String getTensorShapeDimensions() {
if (tensorShapeDimensions.getText().isEmpty() || Objects.equals(tensorShapeDimensions.getText(), "0")) {
return "";
}
return tensorShapeDimensions.getText();
}

public void setTensorShapeDimensions(String size) {
tensorShapeDimensions.setText(size);
}

public boolean isTensorShapeConfigured() {
return !tensorShapeDimensions.getText().isEmpty();
}

public boolean isSaveFileEnabled() {
return saveFileCheckbox.isSelected();
}
@@ -160,11 +178,39 @@ public void setFileSaveLocation(String path) {
fileSaveLocationField.setText(path);
}

private static String evaluateConditionsOfUserDefinedShape(String shapeString){
String[] stringArray = shapeString.split(",");
int[] numbers = new int[stringArray.length];

for (int i = 0; i < stringArray.length; i++) {
String trimmedValue = stringArray[i].trim();
try {
numbers[i] = Integer.parseInt(trimmedValue);
// in case the input is negative
if (numbers[i] < 0) {
return MessageBundle.message("ui.settings.validation.shape.dimensions.negative");
}
} catch (NumberFormatException e) {
// in case the input is not even a number
System.out.println("Invalid input: " + trimmedValue + " is not a number.");
return MessageBundle.message("ui.settings.validation.shape.dimensions");
} catch (IllegalArgumentException e) {
// in case the input is a float or double
System.out.println("Invalid input: " + trimmedValue + " is a float or double number.");
return MessageBundle.message("ui.settings.validation.shape.dimensions.float");
}
}
return "";
}

public String isValidPath() {
String path = myTornadoEnv.getText() + "/setvars.sh";
String parameterSize = myMaxArraySize.getText();
AtomicReference<String> stringAtomicReference = new AtomicReference<>();
stringAtomicReference.set("");
if (isTensorShapeConfigured()) {
return evaluateConditionsOfUserDefinedShape(tensorShapeDimensions.getText());
}
if (isSaveFileEnabled()) {
if (StringUtil.isEmpty(path))
return MessageBundle.message("ui.settings.validation.emptyTornadovm");
Original file line number Diff line number Diff line change
@@ -216,7 +216,7 @@ private static String getImportCode(PsiFile file) {
PsiImportStatement[] importStatements = importList.getImportStatements();
for (PsiImportStatement importStatement : importStatements) {
String importText = importStatement.getText();
if (isJunit(importText)) {
if (statementImportsJunit(importText)) {
continue;
}
importCodeBlock.append(importStatement.getText());
@@ -244,8 +244,8 @@ public static String getImportCodeBlock() {
return importCodeBlock;
}

private static boolean isJunit(String importStatement) {
return importStatement.equals("import org.junit.Test;");
private static boolean statementImportsJunit(String importStatement) {
return importStatement.contains("import org.junit.");
}

}
10 changes: 7 additions & 3 deletions src/main/resources/messages/plugin_en.properties
Original file line number Diff line number Diff line change
@@ -31,11 +31,12 @@ inspection.traps.throws=TornadoVM: Incompatible thrown types Exception in functi

# ui
ui.settings.comment.env=The environment variable file for TornadoVM is usually \"TornadoVM/setvars.sh\". \
This file allows the plugin to call your host's TornadoVM for further analysis of Tornado methods.
ui.settings.comment.size=Max array size specifies the length of Java variables when automatically initialized \
This file allows the plugin to call your host's TornadoVM for further analysis of TornadoVM methods.
ui.settings.max.array.size=Max array size specifies the length of Java variables when automatically initialized \
by TornadoInsight. For example, when the parameter size is set to 32, and the type of a \
parameter in a TornadoVM task is IntArray, it will It initialises an IntArray of length 32 \
parameter in a TornadoVM task is IntArray, TornadoInsight creates an IntArray of length 32 \
and fills it with random values.
ui.settings.tensor.shape.dimensions.doc=Tensor shape dimensions define the dimensions of a TornadoVM tensor type as a list of integer values separated by commas. For example, to set a three-dimensional shape: 16, 1, 1.
ui.settings.comment.debug.file=Saves an internally generated file for debugging purposes. This feature is not intended for regular users.
ui.settings.label.tornado=TornadoVM root:
ui.settings.label.java=Path to Java 21:
@@ -45,6 +46,9 @@ ui.settings.group.debugging=Debug options
ui.settings.group.dynamic=Dynamic Inspection
ui.settings.validation.emptySize=Empty Parameter size
ui.settings.validation.invalidSize=Parameter scale needs to be greater than 0 and less than 65534
ui.settings.validation.shape.dimensions.negative=Invalid shape for a tensor. You defined at least one dimension as negative.
ui.settings.validation.shape.dimensions=Invalid shape for a tensor. Please define the shape as a list of integer values separated by commas.
ui.settings.validation.shape.dimensions.float=Invalid format of a dimension within the shape. Please define the shape as a list of integer values separated by commas.
ui.settings.validation.emptyTornadovm=Empty TornadoVM path
ui.settings.validation.emptyJava=Empty Java path
ui.settings.validation.emptySave=Empty save location

0 comments on commit ed1db50

Please sign in to comment.