Skip to content

Commit

Permalink
Automated initialization of tabular data sets #13
Browse files Browse the repository at this point in the history
  • Loading branch information
B96 committed Jun 21, 2019
1 parent a1549dd commit 413a59c
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 12 deletions.
10 changes: 10 additions & 0 deletions AutoTuning/src/main/java/RandomSearch/ContinuousParameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,22 @@ public class ContinuousParameter implements ContinuousParameterInterface {
private double currentValue;

public ContinuousParameter(String name, double defaultValue, double minValue, double maxValue) {
this(name, defaultValue, minValue, maxValue, 0);
}

public ContinuousParameter(String name, double currentValue) {
this(name, -1, -1, -1 , currentValue);
}

public ContinuousParameter(String name, double defaultValue, double minValue, double maxValue, double currentValue) {
this.name = name;
this.defaultValue = defaultValue;
this.minValue = minValue;
this.maxValue = maxValue;
this.currentValue = currentValue;
}


public String getType() {
return type;
}
Expand Down
33 changes: 25 additions & 8 deletions AutoTuning/src/main/java/RandomSearch/HyperparameterSpace.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package RandomSearch;

import de.viadee.xai.anchor.adapter.tabular.TabularInstance;
import scala.reflect.internal.Trees;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -14,16 +13,34 @@ public class HyperparameterSpace {
private double coverage = 0;
private long runtime = 0;

private List<Parameter> hyperParameters = new ArrayList<Parameter>();
private List<Parameter> hyperParameters;


public HyperparameterSpace() {
this.hyperParameters.add(new IntegerParameter("beamsize", 2, 1, 30));
this.hyperParameters.add(new ContinuousParameter("tau", 1, 0.1, 1.0));
this.hyperParameters.add(new ContinuousParameter("delta", 0.1, 0.1, 0.5));
this.hyperParameters.add(new ContinuousParameter("epsilon", 0.1, 0.1, 0.5));
this.hyperParameters.add(new ContinuousParameter("tauDiscrepancy", 0.05, 0.01, 0.1));
this.hyperParameters.add(new IntegerParameter("initSampleCount", 1, 1, 10));
this(null);
}

public HyperparameterSpace(List<Parameter> hyperParameters) {

if (hyperParameters != null) {
this.hyperParameters = hyperParameters;
} else {
this.hyperParameters = fillWithDefaults();
}
}

private List<Parameter> fillWithDefaults() {

List<Parameter> parameters = new ArrayList<Parameter>();

parameters.add(new IntegerParameter("beamsize", 2, 1, 30));
parameters.add(new ContinuousParameter("tau", 1, 0.1, 1.0));
parameters.add(new ContinuousParameter("delta", 0.1, 0.1, 0.5));
parameters.add(new ContinuousParameter("epsilon", 0.1, 0.1, 0.5));
parameters.add(new ContinuousParameter("tauDiscrepancy", 0.05, 0.01, 0.1));
parameters.add(new IntegerParameter("initSampleCount", 1, 1, 10));

return parameters;
}

public void setRandomHyperparameterSpace() {
Expand Down
14 changes: 13 additions & 1 deletion AutoTuning/src/main/java/RandomSearch/IntegerParameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,26 @@ public class IntegerParameter implements IntegerParameterInterface {
private int maxValue;
private int currentValue;


public IntegerParameter(String name, int defaultValue, int minValue, int maxValue) {
this(name, defaultValue, minValue, maxValue, -1);
}

public IntegerParameter(String name, int currentValue) {
this(name, -1, -1, -1, currentValue);
}

public IntegerParameter(String name, int defaultValue, int minValue, int maxValue, int currentValue) {
this.name = name;
this.defaultValue = defaultValue;
this.minValue = minValue;
this.maxValue = maxValue;
this.currentValue = currentValue;
}

public String getType() { return type; }
public String getType() {
return type;
}

public String getName() {
return name;
Expand Down
8 changes: 5 additions & 3 deletions AutoTuning/src/main/java/TitanicDataset.java
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
package data;

import de.viadee.xai.anchor.adapter.tabular.AnchorTabular;
import de.viadee.xai.anchor.adapter.tabular.column.DoubleColumn;
import de.viadee.xai.anchor.adapter.tabular.column.IntegerColumn;
Expand All @@ -23,7 +25,7 @@ public class TitanicDataset {
/**
* @return the {@link AnchorTabular} object that contains the training data and its definitions
*/
static AnchorTabular createTabularTrainingDefinition() {
public static AnchorTabular createTabularTrainingDefinition() {
InputStream trainingDataStream = ClassLoader.getSystemResourceAsStream("Titanic/train.csv");
if (trainingDataStream == null)
throw new RuntimeException("Could not load data");
Expand Down Expand Up @@ -56,7 +58,7 @@ static AnchorTabular createTabularTrainingDefinition() {
/**
* @return the {@link AnchorTabular} object that contains the test data and its definitions
*/
static AnchorTabular createTabularTestDefinition() {
public static AnchorTabular createTabularTestDefinition() {
// The following implementation is very much similar to the above method.
// It is contained in an own block to increase the tutorial's readability
// Main difference: no target label is included in test set data
Expand Down Expand Up @@ -103,7 +105,7 @@ public Serializable apply(Serializable serializable) {
/**
* @return the labels of the test set as specified in gender_submission.
*/
static int[] readTestLabels() {
public static int[] readTestLabels() {
InputStream trainingDataStream = ClassLoader.getSystemResourceAsStream("Titanic/gender_submission.csv");
if (trainingDataStream == null)
throw new RuntimeException("Could not load data");
Expand Down
49 changes: 49 additions & 0 deletions AutoTuning/src/test/java/Initializer/FileInitializerTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package Initializer;

import org.junit.Assert;
import org.junit.Test;

public class FileInitializerTest {

@Test(expected = RuntimeException.class)
public void testEmptyPath() {

// Given
String path = "";

// When
FileInitializer fi = new FileInitializer(path);
}

@Test
public void testNotExistingFile() {

// Given
String path = "test";

// When
FileInitializer fi = new FileInitializer(path);
fi.setExtension();
String extension = fi.getExtension();

// Then
Assert.assertFalse(extension == null);
}

@Test
public void testExcelFile() {

// Given
String path = "src/test/resources/train.csv";

// When
FileInitializer fi = new FileInitializer(path);
fi.setExtension();
String extension = fi.getExtension();

// Then
Assert.assertTrue(extension.equals("csv"));

}

}
87 changes: 87 additions & 0 deletions AutoTuning/src/test/java/Initializer/TabularInitializerTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package Initializer;

import de.viadee.xai.anchor.adapter.tabular.util.CSVReader;
import org.junit.Assert;
import org.junit.Test;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;

public class TabularInitializerTest {

@Test
public void createTabularTrainingDefinition() {

// Given
String path = "train.csv";

// Then
TabularInitializer.createTabularDefinition(path, 0, true);

}

@Test
public void testGetColumnDatatype() {

// Given
String path = "train.csv";
InputStream trainingDataStream = ClassLoader.getSystemResourceAsStream(path);
Collection<String[]> strings = null;
try {
strings = CSVReader.readCSV(trainingDataStream, false);
} catch (IOException e) {

}

// When
int[] result = TabularInitializer.getColumnDataTypes(strings);

}

@Test
public void testEmptyRowDatatypes() {
// Given
String[] row = {};
Collection<String[]> dataframe = new ArrayList<>();
dataframe.add(row);

// When
int[] result = TabularInitializer.getColumnDataTypes(dataframe);

// Then
Assert.assertTrue(result.length == 0);
}

@Test
public void testIntRowDatatypes() {
// Given
String[] row = {"1", "2", "3", "4"};
Collection<String[]> dataframe = new ArrayList<>();
dataframe.add(row);

// When
int[] result = TabularInitializer.getColumnDataTypes(dataframe);
int[] test ={0, 0, 0, 0};

// Then
Assert.assertTrue(Arrays.equals(result,test));
}

@Test
public void testAllDatatypes() {
// Given
String[] row = new String[]{"1", "Hallo", "Test123.3", "3.12"};
Collection<String[]> dataframe = new ArrayList<>();
dataframe.add(row);

// When
int[] result = TabularInitializer.getColumnDataTypes(dataframe);
int[] test ={0, 2, 2, 1};

// Then
Assert.assertTrue(Arrays.equals(result,test));
}
}

0 comments on commit 413a59c

Please sign in to comment.