diff --git a/AutoTuning/src/main/java/RandomSearch/ContinuousParameter.java b/AutoTuning/src/main/java/RandomSearch/ContinuousParameter.java index a34a809..0179a4c 100644 --- a/AutoTuning/src/main/java/RandomSearch/ContinuousParameter.java +++ b/AutoTuning/src/main/java/RandomSearch/ContinuousParameter.java @@ -13,12 +13,22 @@ public class ContinuousParameter implements ContinuousParameterInterface { private double currentValue; public ContinuousParameter(String name, double defaultValue, double minValue, double maxValue) { + this(name, defaultValue, minValue, maxValue, 0); + } + + public ContinuousParameter(String name, double currentValue) { + this(name, -1, -1, -1 , currentValue); + } + + public ContinuousParameter(String name, double defaultValue, double minValue, double maxValue, double currentValue) { this.name = name; this.defaultValue = defaultValue; this.minValue = minValue; this.maxValue = maxValue; + this.currentValue = currentValue; } + public String getType() { return type; } diff --git a/AutoTuning/src/main/java/RandomSearch/HyperparameterSpace.java b/AutoTuning/src/main/java/RandomSearch/HyperparameterSpace.java index b298938..66b0473 100644 --- a/AutoTuning/src/main/java/RandomSearch/HyperparameterSpace.java +++ b/AutoTuning/src/main/java/RandomSearch/HyperparameterSpace.java @@ -1,7 +1,6 @@ package RandomSearch; import de.viadee.xai.anchor.adapter.tabular.TabularInstance; -import scala.reflect.internal.Trees; import java.util.ArrayList; import java.util.List; @@ -14,16 +13,34 @@ public class HyperparameterSpace { private double coverage = 0; private long runtime = 0; - private List hyperParameters = new ArrayList(); + private List hyperParameters; public HyperparameterSpace() { - this.hyperParameters.add(new IntegerParameter("beamsize", 2, 1, 30)); - this.hyperParameters.add(new ContinuousParameter("tau", 1, 0.1, 1.0)); - this.hyperParameters.add(new ContinuousParameter("delta", 0.1, 0.1, 0.5)); - this.hyperParameters.add(new ContinuousParameter("epsilon", 0.1, 0.1, 0.5)); - this.hyperParameters.add(new ContinuousParameter("tauDiscrepancy", 0.05, 0.01, 0.1)); - this.hyperParameters.add(new IntegerParameter("initSampleCount", 1, 1, 10)); + this(null); + } + + public HyperparameterSpace(List hyperParameters) { + + if (hyperParameters != null) { + this.hyperParameters = hyperParameters; + } else { + this.hyperParameters = fillWithDefaults(); + } + } + + private List fillWithDefaults() { + + List parameters = new ArrayList(); + + parameters.add(new IntegerParameter("beamsize", 2, 1, 30)); + parameters.add(new ContinuousParameter("tau", 1, 0.1, 1.0)); + parameters.add(new ContinuousParameter("delta", 0.1, 0.1, 0.5)); + parameters.add(new ContinuousParameter("epsilon", 0.1, 0.1, 0.5)); + parameters.add(new ContinuousParameter("tauDiscrepancy", 0.05, 0.01, 0.1)); + parameters.add(new IntegerParameter("initSampleCount", 1, 1, 10)); + + return parameters; } public void setRandomHyperparameterSpace() { diff --git a/AutoTuning/src/main/java/RandomSearch/IntegerParameter.java b/AutoTuning/src/main/java/RandomSearch/IntegerParameter.java index 81839ab..ab7ad25 100644 --- a/AutoTuning/src/main/java/RandomSearch/IntegerParameter.java +++ b/AutoTuning/src/main/java/RandomSearch/IntegerParameter.java @@ -10,14 +10,26 @@ public class IntegerParameter implements IntegerParameterInterface { private int maxValue; private int currentValue; + public IntegerParameter(String name, int defaultValue, int minValue, int maxValue) { + this(name, defaultValue, minValue, maxValue, -1); + } + + public IntegerParameter(String name, int currentValue) { + this(name, -1, -1, -1, currentValue); + } + + public IntegerParameter(String name, int defaultValue, int minValue, int maxValue, int currentValue) { this.name = name; this.defaultValue = defaultValue; this.minValue = minValue; this.maxValue = maxValue; + this.currentValue = currentValue; } - public String getType() { return type; } + public String getType() { + return type; + } public String getName() { return name; diff --git a/AutoTuning/src/main/java/TitanicDataset.java b/AutoTuning/src/main/java/TitanicDataset.java index bd00201..2c77fb4 100644 --- a/AutoTuning/src/main/java/TitanicDataset.java +++ b/AutoTuning/src/main/java/TitanicDataset.java @@ -1,3 +1,5 @@ +package data; + import de.viadee.xai.anchor.adapter.tabular.AnchorTabular; import de.viadee.xai.anchor.adapter.tabular.column.DoubleColumn; import de.viadee.xai.anchor.adapter.tabular.column.IntegerColumn; @@ -23,7 +25,7 @@ public class TitanicDataset { /** * @return the {@link AnchorTabular} object that contains the training data and its definitions */ - static AnchorTabular createTabularTrainingDefinition() { + public static AnchorTabular createTabularTrainingDefinition() { InputStream trainingDataStream = ClassLoader.getSystemResourceAsStream("Titanic/train.csv"); if (trainingDataStream == null) throw new RuntimeException("Could not load data"); @@ -56,7 +58,7 @@ static AnchorTabular createTabularTrainingDefinition() { /** * @return the {@link AnchorTabular} object that contains the test data and its definitions */ - static AnchorTabular createTabularTestDefinition() { + public static AnchorTabular createTabularTestDefinition() { // The following implementation is very much similar to the above method. // It is contained in an own block to increase the tutorial's readability // Main difference: no target label is included in test set data @@ -103,7 +105,7 @@ public Serializable apply(Serializable serializable) { /** * @return the labels of the test set as specified in gender_submission. */ - static int[] readTestLabels() { + public static int[] readTestLabels() { InputStream trainingDataStream = ClassLoader.getSystemResourceAsStream("Titanic/gender_submission.csv"); if (trainingDataStream == null) throw new RuntimeException("Could not load data"); diff --git a/AutoTuning/src/test/java/Initializer/FileInitializerTest.java b/AutoTuning/src/test/java/Initializer/FileInitializerTest.java new file mode 100644 index 0000000..c632efa --- /dev/null +++ b/AutoTuning/src/test/java/Initializer/FileInitializerTest.java @@ -0,0 +1,49 @@ +package Initializer; + +import org.junit.Assert; +import org.junit.Test; + +public class FileInitializerTest { + + @Test(expected = RuntimeException.class) + public void testEmptyPath() { + + // Given + String path = ""; + + // When + FileInitializer fi = new FileInitializer(path); + } + + @Test + public void testNotExistingFile() { + + // Given + String path = "test"; + + // When + FileInitializer fi = new FileInitializer(path); + fi.setExtension(); + String extension = fi.getExtension(); + + // Then + Assert.assertFalse(extension == null); + } + + @Test + public void testExcelFile() { + + // Given + String path = "src/test/resources/train.csv"; + + // When + FileInitializer fi = new FileInitializer(path); + fi.setExtension(); + String extension = fi.getExtension(); + + // Then + Assert.assertTrue(extension.equals("csv")); + + } + +} \ No newline at end of file diff --git a/AutoTuning/src/test/java/Initializer/TabularInitializerTest.java b/AutoTuning/src/test/java/Initializer/TabularInitializerTest.java new file mode 100644 index 0000000..6d3e1a2 --- /dev/null +++ b/AutoTuning/src/test/java/Initializer/TabularInitializerTest.java @@ -0,0 +1,87 @@ +package Initializer; + +import de.viadee.xai.anchor.adapter.tabular.util.CSVReader; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; + +public class TabularInitializerTest { + + @Test + public void createTabularTrainingDefinition() { + + // Given + String path = "train.csv"; + + // Then + TabularInitializer.createTabularDefinition(path, 0, true); + + } + + @Test + public void testGetColumnDatatype() { + + // Given + String path = "train.csv"; + InputStream trainingDataStream = ClassLoader.getSystemResourceAsStream(path); + Collection strings = null; + try { + strings = CSVReader.readCSV(trainingDataStream, false); + } catch (IOException e) { + + } + + // When + int[] result = TabularInitializer.getColumnDataTypes(strings); + + } + + @Test + public void testEmptyRowDatatypes() { + // Given + String[] row = {}; + Collection dataframe = new ArrayList<>(); + dataframe.add(row); + + // When + int[] result = TabularInitializer.getColumnDataTypes(dataframe); + + // Then + Assert.assertTrue(result.length == 0); + } + + @Test + public void testIntRowDatatypes() { + // Given + String[] row = {"1", "2", "3", "4"}; + Collection dataframe = new ArrayList<>(); + dataframe.add(row); + + // When + int[] result = TabularInitializer.getColumnDataTypes(dataframe); + int[] test ={0, 0, 0, 0}; + + // Then + Assert.assertTrue(Arrays.equals(result,test)); + } + + @Test + public void testAllDatatypes() { + // Given + String[] row = new String[]{"1", "Hallo", "Test123.3", "3.12"}; + Collection dataframe = new ArrayList<>(); + dataframe.add(row); + + // When + int[] result = TabularInitializer.getColumnDataTypes(dataframe); + int[] test ={0, 2, 2, 1}; + + // Then + Assert.assertTrue(Arrays.equals(result,test)); + } +} \ No newline at end of file