diff --git a/src/main/java/uk/ac/manchester/beehive/tornado/plugins/dynamicInspection/VariableInit.java b/src/main/java/uk/ac/manchester/beehive/tornado/plugins/dynamicInspection/VariableInit.java index 830d04a..afc91a3 100644 --- a/src/main/java/uk/ac/manchester/beehive/tornado/plugins/dynamicInspection/VariableInit.java +++ b/src/main/java/uk/ac/manchester/beehive/tornado/plugins/dynamicInspection/VariableInit.java @@ -27,10 +27,15 @@ public class VariableInit { - public static int parameterSize; + private static int parameterSize; + private static String tensorShapeDimension; + private static int[] tensorShapeDimensions; public static String variableInitHelper(@NotNull PsiMethod method) { parameterSize = TornadoSettingState.getInstance().parameterSize; + tensorShapeDimension = TornadoSettingState.getInstance().tensorShapeDimensions; + tensorShapeDimensions = convertShapeStringToIntArray(tensorShapeDimension); + ArrayList parametersName = new ArrayList<>(); ArrayList parametersType = new ArrayList<>(); for (PsiParameter parameter : method.getParameterList().getParameters()) { @@ -89,6 +94,7 @@ private static String lookupBoxedTypes(String type, String name, int size){ case "VectorDouble", "VectorDouble2", "VectorDouble3", "VectorDouble4", "VectorDouble8", "VectorDouble16" -> vectorInit(name, type, "Double"); case "VectorHalf", "VectorHalf2", "VectorHalf3", "VectorHalf4", "VectorHalf8", "VectorHalf16" -> vectorHalfInit(name, type); case "KernelContext" -> " = new KernelContext();"; + case "TensorByte", "TensorFP16", "TensorFP32", "TensorFP64", "TensorInt16", "TensorInt32", "TensorInt64" -> tensorInit(type); default -> ""; }; } @@ -149,6 +155,32 @@ private static String vectorHalfInit(String name, String type){ name + ".fill(new HalfFloat(" + generateValueByType("HalfFloat") + "));"; } + private static String tensorInit(String type){ + StringBuilder builder = new StringBuilder(); + builder.append(" = new ").append(type).append("(").append("new Shape("); + + for (int i = 0; i < tensorShapeDimensions.length; i++){ + builder.append(tensorShapeDimensions[i]); + if (i < tensorShapeDimensions.length - 1){ + builder.append(", "); + } + } + builder.append("));").append("\n"); + return builder.toString(); + } + + private static int[] convertShapeStringToIntArray(String shapeString){ + String[] stringArray = shapeString.split(","); + int[] numbers = new int[stringArray.length]; + + for (int i = 0; i < stringArray.length; i++) { + if (!stringArray[i].trim().isEmpty()) { + numbers[i] = Integer.parseInt(stringArray[i].trim()); + } + } + return numbers; + } + private static String generateValueByType(String type){ Random r = new Random(); return switch (type) { diff --git a/src/main/java/uk/ac/manchester/beehive/tornado/plugins/ui/settings/TornadoSettingConfiguration.java b/src/main/java/uk/ac/manchester/beehive/tornado/plugins/ui/settings/TornadoSettingConfiguration.java index f0b426c..cc2a67d 100644 --- a/src/main/java/uk/ac/manchester/beehive/tornado/plugins/ui/settings/TornadoSettingConfiguration.java +++ b/src/main/java/uk/ac/manchester/beehive/tornado/plugins/ui/settings/TornadoSettingConfiguration.java @@ -57,6 +57,7 @@ public boolean isModified() { boolean modified = !mySettingsComponent.getJdk().equals(settings.JdkPath); modified |= !mySettingsComponent.getTornadoEnvPath().equals(settings.TornadoRoot); modified |= mySettingsComponent.getMaxArraySize() != settings.parameterSize; + modified |= mySettingsComponent.getTensorShapeDimensions() != settings.tensorShapeDimensions; modified |= mySettingsComponent.isSaveFileEnabled() != settings.saveFileEnabled; modified |= !mySettingsComponent.getFileSaveLocation().equals(settings.fileSaveLocation); return modified; @@ -74,6 +75,7 @@ public void apply() throws ConfigurationException { settings.TornadoRoot = mySettingsComponent.getTornadoEnvPath(); settings.JdkPath = mySettingsComponent.getJdk(); settings.parameterSize = mySettingsComponent.getMaxArraySize(); + settings.tensorShapeDimensions = mySettingsComponent.getTensorShapeDimensions(); settings.saveFileEnabled = mySettingsComponent.isSaveFileEnabled(); settings.fileSaveLocation = mySettingsComponent.getFileSaveLocation(); } @@ -85,6 +87,7 @@ public void reset() { mySettingsComponent.setTornadoEnvPath(settings.TornadoRoot); mySettingsComponent.setMyJdk(settings.JdkPath); mySettingsComponent.setMaxArraySize(settings.parameterSize); + mySettingsComponent.setTensorShapeDimensions(settings.tensorShapeDimensions); mySettingsComponent.setSaveFileEnabled(settings.saveFileEnabled); mySettingsComponent.setFileSaveLocation(settings.fileSaveLocation); } diff --git a/src/main/java/uk/ac/manchester/beehive/tornado/plugins/ui/settings/TornadoSettingState.java b/src/main/java/uk/ac/manchester/beehive/tornado/plugins/ui/settings/TornadoSettingState.java index 3819b98..33b22f4 100644 --- a/src/main/java/uk/ac/manchester/beehive/tornado/plugins/ui/settings/TornadoSettingState.java +++ b/src/main/java/uk/ac/manchester/beehive/tornado/plugins/ui/settings/TornadoSettingState.java @@ -40,6 +40,7 @@ public class TornadoSettingState implements PersistentStateComponent
" + INNER_COMMENT + "
")) + .addLabeledComponent(new JBLabel(" "), new JLabel("
" + INNER_COMMENT + "
")) .addVerticalGap(10) .getPanel(); JPanel dynamicInspectionPanel = FormBuilder.createFormBuilder() .addLabeledComponent(new JBLabel("Max array size:"), myMaxArraySize, 1) - .addLabeledComponent(new JBLabel(" "), new JLabel("
" + MessageBundle.message("ui.settings.comment.size") + "
")) + .addLabeledComponent(new JBLabel(" "), new JLabel("
" + MessageBundle.message("ui.settings.max.array.size") + "
")) + .addLabeledComponent("Tensor shape dimensions:", tensorShapeDimensions) + .addLabeledComponent(new JBLabel(" "), new JLabel("
" + MessageBundle.message("ui.settings.tensor.shape.dimensions.doc") + "
")) .getPanel(); dynamicInspectionPanel.setBorder(IdeBorderFactory.createTitledBorder(MessageBundle.message("ui.settings.group.dynamic"))); JPanel debugPanel = FormBuilder.createFormBuilder() .addComponent(saveFileCheckbox) - .addLabeledComponent(new JBLabel(" "), new JLabel("
" + MessageBundle.message("ui.settings.comment.debug.file") + "
")) + .addLabeledComponent(new JBLabel(" "), new JLabel("
" + MessageBundle.message("ui.settings.comment.debug.file") + "
")) .addLabeledComponent(new JBLabel("Save Location:"), fileSaveLocationField) .getPanel(); @@ -144,6 +147,21 @@ public void setMaxArraySize(int size) { myMaxArraySize.setText(String.valueOf(size)); } + public String getTensorShapeDimensions() { + if (tensorShapeDimensions.getText().isEmpty() || Objects.equals(tensorShapeDimensions.getText(), "0")) { + return ""; + } + return tensorShapeDimensions.getText(); + } + + public void setTensorShapeDimensions(String size) { + tensorShapeDimensions.setText(size); + } + + public boolean isTensorShapeConfigured() { + return !tensorShapeDimensions.getText().isEmpty(); + } + public boolean isSaveFileEnabled() { return saveFileCheckbox.isSelected(); } @@ -160,11 +178,39 @@ public void setFileSaveLocation(String path) { fileSaveLocationField.setText(path); } + private static String evaluateConditionsOfUserDefinedShape(String shapeString){ + String[] stringArray = shapeString.split(","); + int[] numbers = new int[stringArray.length]; + + for (int i = 0; i < stringArray.length; i++) { + String trimmedValue = stringArray[i].trim(); + try { + numbers[i] = Integer.parseInt(trimmedValue); + // in case the input is negative + if (numbers[i] < 0) { + return MessageBundle.message("ui.settings.validation.shape.dimensions.negative"); + } + } catch (NumberFormatException e) { + // in case the input is not even a number + System.out.println("Invalid input: " + trimmedValue + " is not a number."); + return MessageBundle.message("ui.settings.validation.shape.dimensions"); + } catch (IllegalArgumentException e) { + // in case the input is a float or double + System.out.println("Invalid input: " + trimmedValue + " is a float or double number."); + return MessageBundle.message("ui.settings.validation.shape.dimensions.float"); + } + } + return ""; + } + public String isValidPath() { String path = myTornadoEnv.getText() + "/setvars.sh"; String parameterSize = myMaxArraySize.getText(); AtomicReference stringAtomicReference = new AtomicReference<>(); stringAtomicReference.set(""); + if (isTensorShapeConfigured()) { + return evaluateConditionsOfUserDefinedShape(tensorShapeDimensions.getText()); + } if (isSaveFileEnabled()) { if (StringUtil.isEmpty(path)) return MessageBundle.message("ui.settings.validation.emptyTornadovm"); diff --git a/src/main/java/uk/ac/manchester/beehive/tornado/plugins/util/TornadoTWTask.java b/src/main/java/uk/ac/manchester/beehive/tornado/plugins/util/TornadoTWTask.java index caed446..236f7b4 100644 --- a/src/main/java/uk/ac/manchester/beehive/tornado/plugins/util/TornadoTWTask.java +++ b/src/main/java/uk/ac/manchester/beehive/tornado/plugins/util/TornadoTWTask.java @@ -216,7 +216,7 @@ private static String getImportCode(PsiFile file) { PsiImportStatement[] importStatements = importList.getImportStatements(); for (PsiImportStatement importStatement : importStatements) { String importText = importStatement.getText(); - if (isJunit(importText)) { + if (statementImportsJunit(importText)) { continue; } importCodeBlock.append(importStatement.getText()); @@ -244,8 +244,8 @@ public static String getImportCodeBlock() { return importCodeBlock; } - private static boolean isJunit(String importStatement) { - return importStatement.equals("import org.junit.Test;"); + private static boolean statementImportsJunit(String importStatement) { + return importStatement.contains("import org.junit."); } } diff --git a/src/main/resources/messages/plugin_en.properties b/src/main/resources/messages/plugin_en.properties index 102eed5..98aadf7 100644 --- a/src/main/resources/messages/plugin_en.properties +++ b/src/main/resources/messages/plugin_en.properties @@ -31,11 +31,12 @@ inspection.traps.throws=TornadoVM: Incompatible thrown types Exception in functi # ui ui.settings.comment.env=The environment variable file for TornadoVM is usually \"TornadoVM/setvars.sh\". \ - This file allows the plugin to call your host's TornadoVM for further analysis of Tornado methods. -ui.settings.comment.size=Max array size specifies the length of Java variables when automatically initialized \ + This file allows the plugin to call your host's TornadoVM for further analysis of TornadoVM methods. +ui.settings.max.array.size=Max array size specifies the length of Java variables when automatically initialized \ by TornadoInsight. For example, when the parameter size is set to 32, and the type of a \ - parameter in a TornadoVM task is IntArray, it will It initialises an IntArray of length 32 \ + parameter in a TornadoVM task is IntArray, TornadoInsight creates an IntArray of length 32 \ and fills it with random values. +ui.settings.tensor.shape.dimensions.doc=Tensor shape dimensions define the dimensions of a TornadoVM tensor type as a list of integer values separated by commas. For example, to set a three-dimensional shape: 16, 1, 1. ui.settings.comment.debug.file=Saves an internally generated file for debugging purposes. This feature is not intended for regular users. ui.settings.label.tornado=TornadoVM root: ui.settings.label.java=Path to Java 21: @@ -45,6 +46,9 @@ ui.settings.group.debugging=Debug options ui.settings.group.dynamic=Dynamic Inspection ui.settings.validation.emptySize=Empty Parameter size ui.settings.validation.invalidSize=Parameter scale needs to be greater than 0 and less than 65534 +ui.settings.validation.shape.dimensions.negative=Invalid shape for a tensor. You defined at least one dimension as negative. +ui.settings.validation.shape.dimensions=Invalid shape for a tensor. Please define the shape as a list of integer values separated by commas. +ui.settings.validation.shape.dimensions.float=Invalid format of a dimension within the shape. Please define the shape as a list of integer values separated by commas. ui.settings.validation.emptyTornadovm=Empty TornadoVM path ui.settings.validation.emptyJava=Empty Java path ui.settings.validation.emptySave=Empty save location