From 9ad50a8c8253a2ea1fbb3fd01f85a1f02ebcd888 Mon Sep 17 00:00:00 2001 From: idhamari Date: Tue, 3 May 2022 09:19:19 +0200 Subject: [PATCH 1/7] add method for evaluation using test subset --- .../arx/examples/Example39_subset.java | 323 ++++++++++++++++++ .../arx/ARXClassificationConfiguration.java | 16 +- .../aggregates/StatisticsClassification.java | 177 ++++++++-- 3 files changed, 490 insertions(+), 26 deletions(-) create mode 100644 src/example/org/deidentifier/arx/examples/Example39_subset.java diff --git a/src/example/org/deidentifier/arx/examples/Example39_subset.java b/src/example/org/deidentifier/arx/examples/Example39_subset.java new file mode 100644 index 000000000..ceaa8c9a6 --- /dev/null +++ b/src/example/org/deidentifier/arx/examples/Example39_subset.java @@ -0,0 +1,323 @@ +/* + * ARX: Powerful Data Anonymization + * Copyright 2012 - 2021 Fabian Prasser and contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.deidentifier.arx.examples; + +import java.io.File; +import java.io.FilenameFilter; +import java.io.IOError; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; +import java.text.ParseException; +import java.util.HashSet; +import java.util.Random; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.deidentifier.arx.ARXAnonymizer; +import org.deidentifier.arx.ARXClassificationConfiguration; +import org.deidentifier.arx.ARXConfiguration; +import org.deidentifier.arx.ARXResult; +import org.deidentifier.arx.AttributeType; +import org.deidentifier.arx.AttributeType.Hierarchy; +import org.deidentifier.arx.Data; +import org.deidentifier.arx.DataSubset; +import org.deidentifier.arx.DataType; +import org.deidentifier.arx.aggregates.ClassificationConfigurationLogisticRegression; +import org.deidentifier.arx.aggregates.ClassificationConfigurationNaiveBayes; +import org.deidentifier.arx.aggregates.ClassificationConfigurationRandomForest; +import org.deidentifier.arx.criteria.Inclusion; +import org.deidentifier.arx.criteria.KAnonymity; +import org.deidentifier.arx.io.CSVHierarchyInput; +import org.deidentifier.arx.metric.Metric; + +/** + * This class implements an example on how to compare data mining performance + * It shows how to use subset for training and different subset for testing + * @author Fabian Prasser + * @author Florian Kohlmayer + */ +public class Example39_subset extends Example { + + /** + * Loads a dataset from disk + * @param dataset + * @return + * @throws IOException + */ + public static Data createData(final String dataset) throws IOException { + + // Load data + Data data = Data.create("data/" + dataset + ".csv", StandardCharsets.UTF_8, ';'); + + // Read generalization hierarchies + FilenameFilter hierarchyFilter = new FilenameFilter() { + @Override + public boolean accept(File dir, String name) { + if (name.matches(dataset + "_hierarchy_(.)+.csv")) { + return true; + } else { + return false; + } + } + }; + + // Create definition + File testDir = new File("data/"); + File[] genHierFiles = testDir.listFiles(hierarchyFilter); + Pattern pattern = Pattern.compile("_hierarchy_(.*?).csv"); + for (File file : genHierFiles) { + Matcher matcher = pattern.matcher(file.getName()); + if (matcher.find()) { + CSVHierarchyInput hier = new CSVHierarchyInput(file, StandardCharsets.UTF_8, ';'); + String attributeName = matcher.group(1); + data.getDefinition().setAttributeType(attributeName, Hierarchy.create(hier.getHierarchy())); + } + } + + return data; + } + + public static Set getRandomDataSubsetIndices(double dataSize, Data inputData, int numRecords) { + + if (dataSize < 0d || dataSize > 1d) { + System.out.println(" data size ratio is out of range"); + throw new IOError(new Exception()); + } + + // Create a data subset via sampling based on beta + Set subsetIndices = new HashSet(); + Random random = new SecureRandom(); + for (int i = 0; i < numRecords; ++i) { + if (random.nextDouble() < dataSize) { + subsetIndices.add(i); + } + } + return subsetIndices; + } + + /** + * Entry point. + * + * @param args the arguments + * @throws ParseException + * @throws IOException + */ + public static void main(String[] args) throws ParseException, IOException { + + String[] features = new String[] { + "sex", + "age", + "race", + "education", + "native-country", + "workclass", + "occupation", + "salary-class" + }; + + String clazz = "marital-status"; + + Data data = createData("adult"); + data.getDefinition().setAttributeType("marital-status", AttributeType.INSENSITIVE_ATTRIBUTE); + data.getDefinition().setDataType("age", DataType.INTEGER); + data.getDefinition().setResponseVariable("marital-status", true); + + // Create a training subset data with a specific percentage of the original data e.g 80% + + double dataSize = 0.80; + + // Createing a view from the original dataset + Set subsetIndicesTrain = getRandomDataSubsetIndices(dataSize, data, data.getHandle().getNumRows()) ; + + System.out.println("Creating a training data subset ...."); + DataSubset datasubTrain = DataSubset.create(data.getHandle().getNumRows(), subsetIndicesTrain); + + // To create a testing subset data from the remaining data we can use this commented code + + // Set subsetIndicesTest = new HashSet(); + // for (int i = 0; i < data.getHandle().getNumRows(); ++i) { + // subsetIndicesTest.add(i); + // } + // subsetIndicesTest.removeAll(subsetIndicesTrain); + + // System.out.println("Creating a testing data subset ...."); + // DataSubset datasubTest = DataSubset.create(data.getHandle().getNumRows(), subsetIndicesTest); + + + ARXAnonymizer anonymizer = new ARXAnonymizer(); + ARXConfiguration config = ARXConfiguration.create(); + config.addPrivacyModel(new KAnonymity(5)); + config.setSuppressionLimit(1d); + config.setQualityModel(Metric.createClassificationMetric()); + + + + // Adding the data subset to the current configuration, + // this subset will be used for the anonymization, + // other records will be transformed but only suppressed, + // In the training, only the subset will be used + config.addPrivacyModel(new Inclusion (datasubTrain) ); + config.setSuppressionLimit(1d); + config.setQualityModel(Metric.createClassificationMetric()); + + // Start anonymization process + ARXResult result = anonymizer.anonymize(data, config); + boolean evaluateWithKfold = true; + + System.out.println("5-anonymous dataset (logistic regression)"); + ClassificationConfigurationLogisticRegression logisticClassifier = ARXClassificationConfiguration.createLogisticRegression(); + logisticClassifier.setEvaluateWithKfold(evaluateWithKfold); + System.out.println(result.getOutput().getStatistics().getClassificationPerformance(features, clazz, logisticClassifier)); + + System.out.println("5-anonymous dataset (naive bayes)"); + ClassificationConfigurationNaiveBayes naiveBayesClassifier = ARXClassificationConfiguration.createNaiveBayes(); + naiveBayesClassifier.setEvaluateWithKfold(evaluateWithKfold); + System.out.println(result.getOutput().getStatistics().getClassificationPerformance(features, clazz, naiveBayesClassifier)); + + System.out.println("5-anonymous dataset (random forest)"); + ClassificationConfigurationRandomForest randomForestClassifier = ARXClassificationConfiguration.createRandomForest(); + randomForestClassifier.setEvaluateWithKfold(evaluateWithKfold); + System.out.println(result.getOutput().getStatistics().getClassificationPerformance(features, clazz, randomForestClassifier)); + } +} + + +/** + * =========================================================== + * Example output with evaluateWithKfold = true; + * =========================================================== + Creating a training data subset .... +5-anonymous dataset (logistic regression) +StatisticsClassification{ + - Accuracy: + * Original: 0.6953119819640607 + * ZeroR: 0.4663152310854718 + * Output: 0.6940189642596645 + - Average error: + * Original: 0.4301467184165105 + * ZeroR: 0.5336847689145282 + * Output: 0.43041888382409704 + - Brier score: + * Original: 0.4333317703310125 + * ZeroR: 0.6572252917603948 + * Output: 0.4372894125209676 + - Number of classes: 7 + - Number of measurements: 30162 +} +5-anonymous dataset (naive bayes) +StatisticsClassification{ + - Accuracy: + * Original: 0.6447516742921557 + * ZeroR: 0.4663152310854718 + * Output: 0.6722697433857171 + - Average error: + * Original: 0.38050937350272185 + * ZeroR: 0.5336847689145282 + * Output: 0.35648745375532154 + - Brier score: + * Original: 0.5499427575714274 + * ZeroR: 0.6572252917603948 + * Output: 0.512156556610383 + - Number of classes: 7 + - Number of measurements: 30162 +} +5-anonymous dataset (random forest) +SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder". +SLF4J: Defaulting to no-operation (NOP) logger implementation +SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details. +StatisticsClassification{ + - Accuracy: + * Original: 0.603772959352828 + * ZeroR: 0.4663152310854718 + * Output: 0.6156421987931835 + - Average error: + * Original: 0.5699613869024831 + * ZeroR: 0.5336847689145282 + * Output: 0.5552447567631911 + - Brier score: + * Original: 0.5317809865104269 + * ZeroR: 0.6572252917603948 + * Output: 0.5042967130971948 + - Number of classes: 7 + - Number of measurements: 30162 +} + + + * =========================================================== + * Example output with evaluateWithKfold = false; + * =========================================================== + +Creating a training data subset .... +5-anonymous dataset (logistic regression) +StatisticsClassification{ + - Accuracy: + Original: 0.69 + ZeroR: 0.46 + Output: 0.6845 + - Average error: + Original: 0.4357585527644324 + ZeroR: 0.54 + Output: 0.4377426743661062 + - Brier score: + Original: 0.888404431332961 + ZeroR: 0.932658447872528 + Output: 0.8896271153039844 + - Number of classes: 7 + - Number of measurements: 6000 +} +5-anonymous dataset (naive bayes) +StatisticsClassification{ + - Accuracy: + Original: 0.6406666666666667 + ZeroR: 0.46 + Output: 0.6766666666666666 + - Average error: + Original: 0.3822702437028365 + ZeroR: 0.54 + Output: 0.3566992097114707 + - Brier score: + Original: 0.9110100351549149 + ZeroR: 0.932658447872528 + Output: 0.9023854777043593 + - Number of classes: 7 + - Number of measurements: 6000 +} +5-anonymous dataset (random forest) +SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder". +SLF4J: Defaulting to no-operation (NOP) logger implementation +SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details. +StatisticsClassification{ + - Accuracy: + Original: 0.5766666666666667 + ZeroR: 0.46 + Output: 0.5466666666666666 + - Average error: + Original: 0.5760646545342571 + ZeroR: 0.54 + Output: 0.5770816662108619 + - Brier score: + Original: 0.908826414086422 + ZeroR: 0.932658447872528 + Output: 0.9066104920932022 + - Number of classes: 7 + - Number of measurements: 6000 +} + + */ diff --git a/src/main/org/deidentifier/arx/ARXClassificationConfiguration.java b/src/main/org/deidentifier/arx/ARXClassificationConfiguration.java index ced510241..59b2e4f74 100644 --- a/src/main/org/deidentifier/arx/ARXClassificationConfiguration.java +++ b/src/main/org/deidentifier/arx/ARXClassificationConfiguration.java @@ -74,7 +74,8 @@ public static ClassificationConfigurationRandomForest createRandomForest() { private int vectorLength = DEFAULT_VECTOR_LENGTH; /** Modified */ private boolean modified = false; - + /** EvaluateWithKfold */ + private boolean EvaluateWithKfold = true; /** * Creates a new instance with default settings */ @@ -242,4 +243,17 @@ public T setVectorLength(int vectorLength) { } return (T)this; } + /** + * Get EvaluateWithKfold + */ + public boolean getEvaluateWithKfold() { + return this.EvaluateWithKfold; + } + + /** + * Set EvaluateWithKfold + */ + public void setEvaluateWithKfold(boolean value) { + this.EvaluateWithKfold = value; + } } diff --git a/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java b/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java index 37cbcb795..6b92d91fa 100644 --- a/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java +++ b/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Random; @@ -28,6 +29,7 @@ import org.deidentifier.arx.ARXClassificationConfiguration; import org.deidentifier.arx.ARXFeatureScaling; import org.deidentifier.arx.DataHandleInternal; +import org.deidentifier.arx.DataSubset; import org.deidentifier.arx.aggregates.classification.ClassificationDataSpecification; import org.deidentifier.arx.aggregates.classification.ClassificationMethod; import org.deidentifier.arx.aggregates.classification.ClassificationResult; @@ -361,9 +363,105 @@ private static ClassificationMethod getClassifier(WrappedBoolean interrupt, double[] zerorConfidences = new double[numSamples * ( 1 + numClasses)]; int confidencesIndex = 0; - // For each fold as a validation set - for (int evaluationFold = 0; evaluationFold < folds.size(); evaluationFold++) { - + if (config.getEvaluateWithKfold()) { + // For each fold as a validation set + for (int evaluationFold = 0; evaluationFold < folds.size(); evaluationFold++) { + + // Create classifiers + ClassificationMethod inputClassifier = getClassifier(interrupt, specification, config, inputHandle); + ClassificationMethod inputZeroR = new MultiClassZeroR(interrupt, specification); + ClassificationMethod outputClassifier = null; + if (inputHandle != outputHandle) { + outputClassifier = getClassifier(interrupt, specification, config, inputHandle); + } + + // Try + try { + + // Train with all training sets + boolean trained = false; + for (int trainingFold = 0; trainingFold < folds.size(); trainingFold++) { + if (trainingFold != evaluationFold) { + for (int index : folds.get(trainingFold)) { + checkInterrupt(); + inputClassifier.train(inputHandle, outputHandle, index); + inputZeroR.train(inputHandle, outputHandle, index); + if (outputClassifier != null && !outputHandle.isOutlier(index)) { + outputClassifier.train(outputHandle, outputHandle, index); + trained = true; + } + this.progress.value = (int)((++done) * total); + } + } + } + + // Close + inputClassifier.close(); + inputZeroR.close(); + if (outputClassifier != null && trained) { + outputClassifier.close(); + } + + // Now validate + for (int index : folds.get(evaluationFold)) { + + // Check + checkInterrupt(); + + // Classify + ClassificationResult resultInput = inputClassifier.classify(inputHandle, index); + ClassificationResult resultInputZR = inputZeroR.classify(inputHandle, index); + ClassificationResult resultOutput = outputClassifier == null || !trained ? null : outputClassifier.classify(outputHandle, index); + classifications++; + + // Correct result + String actualValue = outputHandle.getValue(index, specification.classIndex, true); + + // Maintain data about ZeroR + this.zeroRAverageError += resultInputZR.error(actualValue); + this.zeroRAccuracy += resultInputZR.correct(actualValue) ? 1d : 0d; + double[] confidences = resultInputZR.confidences(); + zerorConfidences[confidencesIndex] = index; + System.arraycopy(confidences, 0, zerorConfidences, confidencesIndex + 1, confidences.length); + + // Maintain data about input-based classifier + boolean correct = resultInput.correct(actualValue); + this.originalAverageError += resultInput.error(actualValue); + this.originalAccuracy += correct ? 1d : 0d; + confidences = resultInput.confidences(); + inputConfidences[confidencesIndex] = index; + System.arraycopy(confidences, 0, inputConfidences, confidencesIndex + 1, confidences.length); + + // Maintain data about output-based + if (resultOutput != null) { + correct = resultOutput.correct(actualValue); + this.averageError += resultOutput.error(actualValue); + this.accuracy += correct ? 1d : 0d; + confidences = resultOutput.confidences(); + outputConfidences[confidencesIndex] = index; + System.arraycopy(confidences, 0, outputConfidences, confidencesIndex + 1, confidences.length); + } + + // Next + confidencesIndex += numClasses + 1; + + this.progress.value = (int)((++done) * total); + } + } catch (Exception e) { + if (e instanceof ComputationInterruptedException) { + throw e; + } else { + throw new UnexpectedErrorException(e); + } + } + } + } else { + + // get the training data + DataSubset subsetTrain = inputHandle.getSubset(); + + // do training + // Create classifiers ClassificationMethod inputClassifier = getClassifier(interrupt, specification, config, inputHandle); ClassificationMethod inputZeroR = new MultiClassZeroR(interrupt, specification); @@ -373,42 +471,41 @@ private static ClassificationMethod getClassifier(WrappedBoolean interrupt, } // Try - try { - - // Train with all training sets - boolean trained = false; - for (int trainingFold = 0; trainingFold < folds.size(); trainingFold++) { - if (trainingFold != evaluationFold) { - for (int index : folds.get(trainingFold)) { - checkInterrupt(); - inputClassifier.train(inputHandle, outputHandle, index); - inputZeroR.train(inputHandle, outputHandle, index); - if (outputClassifier != null && !outputHandle.isOutlier(index)) { - outputClassifier.train(outputHandle, outputHandle, index); - trained = true; - } - this.progress.value = (int)((++done) * total); - } + try { + // Train with the training subset + for (int index : subsetTrain.getArray()) { + checkInterrupt(); + inputClassifier.train(inputHandle, outputHandle, index); + inputZeroR.train(inputHandle, outputHandle, index); + if (outputClassifier != null && !outputHandle.isOutlier(index)) { + outputClassifier.train(outputHandle, outputHandle, index); } + this.progress.value = (int)((++done) * total); } - + // Close inputClassifier.close(); inputZeroR.close(); - if (outputClassifier != null && trained) { + if (outputClassifier != null ) { outputClassifier.close(); } - // Now validate - for (int index : folds.get(evaluationFold)) { - + // create the testing subset indices + Set subsetIndicesTest = new HashSet(); + for (int i = 0; i < inputHandle.getNumRows(); ++i) { + if (! subsetTrain.getSet().contains(i)) { + subsetIndicesTest.add(i); + } + } + + for (int index : subsetIndicesTest ) { // Check checkInterrupt(); // Classify ClassificationResult resultInput = inputClassifier.classify(inputHandle, index); ClassificationResult resultInputZR = inputZeroR.classify(inputHandle, index); - ClassificationResult resultOutput = outputClassifier == null || !trained ? null : outputClassifier.classify(outputHandle, index); + ClassificationResult resultOutput = outputClassifier == null ? null : outputClassifier.classify(outputHandle, index); classifications++; // Correct result @@ -451,6 +548,36 @@ private static ClassificationMethod getClassifier(WrappedBoolean interrupt, throw new UnexpectedErrorException(e); } } + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + } // Maintain data about inputZR From c67a26a4c92c90dfb66e1ce08b60f8c49a98b4bc Mon Sep 17 00:00:00 2001 From: idhamari Date: Tue, 3 May 2022 09:23:51 +0200 Subject: [PATCH 2/7] formatting: remove extra space --- .../aggregates/StatisticsClassification.java | 32 +------------------ 1 file changed, 1 insertion(+), 31 deletions(-) diff --git a/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java b/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java index 6b92d91fa..013b88e08 100644 --- a/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java +++ b/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java @@ -547,37 +547,7 @@ private static ClassificationMethod getClassifier(WrappedBoolean interrupt, } else { throw new UnexpectedErrorException(e); } - } - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + } } // Maintain data about inputZR From 4a421c69c618bcdf26f68b9b682633764077ed49 Mon Sep 17 00:00:00 2001 From: idhamari Date: Mon, 9 May 2022 17:14:15 +0200 Subject: [PATCH 3/7] use methods instead of if else --- .../deidentifier/arx/examples/Example39.java | 94 ++- .../aggregates/StatisticsClassification.java | 616 ++++++++++-------- 2 files changed, 445 insertions(+), 265 deletions(-) diff --git a/src/example/org/deidentifier/arx/examples/Example39.java b/src/example/org/deidentifier/arx/examples/Example39.java index 131cbb741..102b02b51 100644 --- a/src/example/org/deidentifier/arx/examples/Example39.java +++ b/src/example/org/deidentifier/arx/examples/Example39.java @@ -19,9 +19,14 @@ import java.io.File; import java.io.FilenameFilter; +import java.io.IOError; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; import java.text.ParseException; +import java.util.HashSet; +import java.util.Random; +import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -32,16 +37,24 @@ import org.deidentifier.arx.AttributeType; import org.deidentifier.arx.AttributeType.Hierarchy; import org.deidentifier.arx.Data; +import org.deidentifier.arx.DataSubset; import org.deidentifier.arx.DataType; +import org.deidentifier.arx.aggregates.ClassificationConfigurationLogisticRegression; +import org.deidentifier.arx.aggregates.ClassificationConfigurationNaiveBayes; +import org.deidentifier.arx.aggregates.ClassificationConfigurationRandomForest; +import org.deidentifier.arx.criteria.Inclusion; import org.deidentifier.arx.criteria.KAnonymity; import org.deidentifier.arx.io.CSVHierarchyInput; import org.deidentifier.arx.metric.Metric; /** * This class implements an example on how to compare data mining performance - * + * The evaluation can be used with either K-fold cross validation (default) or with + * subset for training and different subset for testing + * * @author Fabian Prasser * @author Florian Kohlmayer + * @author Ibraheem Al-Dhamari */ public class Example39 extends Example { @@ -83,7 +96,25 @@ public boolean accept(File dir, String name) { return data; } + + public static Set getRandomDataSubsetIndices(double dataSize, Data inputData, int numRecords) { + + if (dataSize < 0d || dataSize > 1d) { + System.out.println(" data size ratio is out of range"); + throw new IOError(new Exception()); + } + // Create a data subset via sampling based on beta + Set subsetIndices = new HashSet(); + Random random = new SecureRandom(); + for (int i = 0; i < numRecords; ++i) { + if (random.nextDouble() < dataSize) { + subsetIndices.add(i); + } + } + return subsetIndices; + } + /** * Entry point. * @@ -111,18 +142,65 @@ public static void main(String[] args) throws ParseException, IOException { data.getDefinition().setDataType("age", DataType.INTEGER); data.getDefinition().setResponseVariable("marital-status", true); + ARXAnonymizer anonymizer = new ARXAnonymizer(); + ARXConfiguration config = ARXConfiguration.create(); config.addPrivacyModel(new KAnonymity(5)); config.setSuppressionLimit(1d); config.setQualityModel(Metric.createClassificationMetric()); + // Create a training subset data with a specific percentage of the original data e.g 80% + double dataSize = 0.80; + + // Creating a view from the original dataset + Set subsetIndicesTrain = getRandomDataSubsetIndices(dataSize, data, data.getHandle().getNumRows()) ; + DataSubset datasubTrain = DataSubset.create(data.getHandle().getNumRows(), subsetIndicesTrain); + + // Adding the data subset to the current configuration, + // this subset will be used for the anonymization, + // other records will be transformed but only suppressed, + // In the training, only the subset will be used + config.addPrivacyModel(new Inclusion (datasubTrain) ); + + config.setSuppressionLimit(1d); + config.setQualityModel(Metric.createClassificationMetric()); + + // Start anonymization process ARXResult result = anonymizer.anonymize(data, config); - System.out.println("5-anonymous dataset (logistic regression)"); - System.out.println(result.getOutput().getStatistics().getClassificationPerformance(features, clazz, ARXClassificationConfiguration.createLogisticRegression())); - System.out.println("5-anonymous dataset (naive bayes)"); - System.out.println(result.getOutput().getStatistics().getClassificationPerformance(features, clazz, ARXClassificationConfiguration.createNaiveBayes())); - System.out.println("5-anonymous dataset (random forest)"); - System.out.println(result.getOutput().getStatistics().getClassificationPerformance(features, clazz, ARXClassificationConfiguration.createRandomForest())); + + System.out.println("==============================================="); + System.out.println(" 5-anonymous dataset (logistic regression)"); + System.out.println("==============================================="); + ClassificationConfigurationLogisticRegression logisticClassifier = ARXClassificationConfiguration.createLogisticRegression(); + System.out.println("Evaluation using K-fold cross validation: ..............."); + logisticClassifier.setEvaluateWithKfold(true); + System.out.println(result.getOutput().getStatistics().getClassificationPerformance(features, clazz, logisticClassifier)); + System.out.println("Evaluation using testing subset: ........................"); + logisticClassifier.setEvaluateWithKfold(false); + System.out.println(result.getOutput().getStatistics().getClassificationPerformance(features, clazz, logisticClassifier)); + + System.out.println("==============================================="); + System.out.println(" 5-anonymous dataset (naive bayes)"); + System.out.println("==============================================="); + System.out.println("Evaluation using K-fold cross validation: ..............."); + logisticClassifier.setEvaluateWithKfold(true); + ClassificationConfigurationNaiveBayes naiveBayesClassifier = ARXClassificationConfiguration.createNaiveBayes(); + System.out.println(result.getOutput().getStatistics().getClassificationPerformance(features, clazz, naiveBayesClassifier)); + System.out.println("Evaluation using testing subset: ........................"); + logisticClassifier.setEvaluateWithKfold(false); + System.out.println(result.getOutput().getStatistics().getClassificationPerformance(features, clazz, naiveBayesClassifier)); + + System.out.println("==============================================="); + System.out.println(" 5-anonymous dataset (random forest)"); + System.out.println("==============================================="); + System.out.println("Evaluation using K-fold cross validation: ..............."); + logisticClassifier.setEvaluateWithKfold(true); + ClassificationConfigurationRandomForest randomForestClassifier = ARXClassificationConfiguration.createRandomForest(); + System.out.println(result.getOutput().getStatistics().getClassificationPerformance(features, clazz, randomForestClassifier)); + System.out.println("Evaluation using testing subset: ........................"); + logisticClassifier.setEvaluateWithKfold(false); + System.out.println(result.getOutput().getStatistics().getClassificationPerformance(features, clazz, randomForestClassifier)); + } -} +} \ No newline at end of file diff --git a/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java b/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java index 013b88e08..94d04ab98 100644 --- a/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java +++ b/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java @@ -348,264 +348,14 @@ private static ClassificationMethod getClassifier(WrappedBoolean interrupt, // Number of class values this.numClasses = specification.classMap.size(); - // Train and evaluate - int k = numSamples > config.getNumFolds() ? config.getNumFolds() : numSamples; - List> folds = getFolds(inputHandle.getNumRows(), numSamples, k); - - // Track - int classifications = 0; - double total = 100d / ((double)numSamples * (double)folds.size()); - double done = 0d; - - // ROC - double[] inputConfidences = new double[numSamples * ( 1 + numClasses)]; - double[] outputConfidences = (inputHandle == outputHandle) ? null : new double[numSamples * ( 1 + numClasses)]; - double[] zerorConfidences = new double[numSamples * ( 1 + numClasses)]; - int confidencesIndex = 0; - - if (config.getEvaluateWithKfold()) { - // For each fold as a validation set - for (int evaluationFold = 0; evaluationFold < folds.size(); evaluationFold++) { - - // Create classifiers - ClassificationMethod inputClassifier = getClassifier(interrupt, specification, config, inputHandle); - ClassificationMethod inputZeroR = new MultiClassZeroR(interrupt, specification); - ClassificationMethod outputClassifier = null; - if (inputHandle != outputHandle) { - outputClassifier = getClassifier(interrupt, specification, config, inputHandle); - } - - // Try - try { - - // Train with all training sets - boolean trained = false; - for (int trainingFold = 0; trainingFold < folds.size(); trainingFold++) { - if (trainingFold != evaluationFold) { - for (int index : folds.get(trainingFold)) { - checkInterrupt(); - inputClassifier.train(inputHandle, outputHandle, index); - inputZeroR.train(inputHandle, outputHandle, index); - if (outputClassifier != null && !outputHandle.isOutlier(index)) { - outputClassifier.train(outputHandle, outputHandle, index); - trained = true; - } - this.progress.value = (int)((++done) * total); - } - } - } - - // Close - inputClassifier.close(); - inputZeroR.close(); - if (outputClassifier != null && trained) { - outputClassifier.close(); - } - - // Now validate - for (int index : folds.get(evaluationFold)) { - - // Check - checkInterrupt(); - - // Classify - ClassificationResult resultInput = inputClassifier.classify(inputHandle, index); - ClassificationResult resultInputZR = inputZeroR.classify(inputHandle, index); - ClassificationResult resultOutput = outputClassifier == null || !trained ? null : outputClassifier.classify(outputHandle, index); - classifications++; - - // Correct result - String actualValue = outputHandle.getValue(index, specification.classIndex, true); - - // Maintain data about ZeroR - this.zeroRAverageError += resultInputZR.error(actualValue); - this.zeroRAccuracy += resultInputZR.correct(actualValue) ? 1d : 0d; - double[] confidences = resultInputZR.confidences(); - zerorConfidences[confidencesIndex] = index; - System.arraycopy(confidences, 0, zerorConfidences, confidencesIndex + 1, confidences.length); - - // Maintain data about input-based classifier - boolean correct = resultInput.correct(actualValue); - this.originalAverageError += resultInput.error(actualValue); - this.originalAccuracy += correct ? 1d : 0d; - confidences = resultInput.confidences(); - inputConfidences[confidencesIndex] = index; - System.arraycopy(confidences, 0, inputConfidences, confidencesIndex + 1, confidences.length); - - // Maintain data about output-based - if (resultOutput != null) { - correct = resultOutput.correct(actualValue); - this.averageError += resultOutput.error(actualValue); - this.accuracy += correct ? 1d : 0d; - confidences = resultOutput.confidences(); - outputConfidences[confidencesIndex] = index; - System.arraycopy(confidences, 0, outputConfidences, confidencesIndex + 1, confidences.length); - } - - // Next - confidencesIndex += numClasses + 1; - - this.progress.value = (int)((++done) * total); - } - } catch (Exception e) { - if (e instanceof ComputationInterruptedException) { - throw e; - } else { - throw new UnexpectedErrorException(e); - } - } - } - } else { - - // get the training data - DataSubset subsetTrain = inputHandle.getSubset(); - - // do training - - // Create classifiers - ClassificationMethod inputClassifier = getClassifier(interrupt, specification, config, inputHandle); - ClassificationMethod inputZeroR = new MultiClassZeroR(interrupt, specification); - ClassificationMethod outputClassifier = null; - if (inputHandle != outputHandle) { - outputClassifier = getClassifier(interrupt, specification, config, inputHandle); - } - - // Try - try { - // Train with the training subset - for (int index : subsetTrain.getArray()) { - checkInterrupt(); - inputClassifier.train(inputHandle, outputHandle, index); - inputZeroR.train(inputHandle, outputHandle, index); - if (outputClassifier != null && !outputHandle.isOutlier(index)) { - outputClassifier.train(outputHandle, outputHandle, index); - } - this.progress.value = (int)((++done) * total); - } - - // Close - inputClassifier.close(); - inputZeroR.close(); - if (outputClassifier != null ) { - outputClassifier.close(); - } - - // create the testing subset indices - Set subsetIndicesTest = new HashSet(); - for (int i = 0; i < inputHandle.getNumRows(); ++i) { - if (! subsetTrain.getSet().contains(i)) { - subsetIndicesTest.add(i); - } - } - - for (int index : subsetIndicesTest ) { - // Check - checkInterrupt(); - - // Classify - ClassificationResult resultInput = inputClassifier.classify(inputHandle, index); - ClassificationResult resultInputZR = inputZeroR.classify(inputHandle, index); - ClassificationResult resultOutput = outputClassifier == null ? null : outputClassifier.classify(outputHandle, index); - classifications++; - - // Correct result - String actualValue = outputHandle.getValue(index, specification.classIndex, true); - - // Maintain data about ZeroR - this.zeroRAverageError += resultInputZR.error(actualValue); - this.zeroRAccuracy += resultInputZR.correct(actualValue) ? 1d : 0d; - double[] confidences = resultInputZR.confidences(); - zerorConfidences[confidencesIndex] = index; - System.arraycopy(confidences, 0, zerorConfidences, confidencesIndex + 1, confidences.length); - - // Maintain data about input-based classifier - boolean correct = resultInput.correct(actualValue); - this.originalAverageError += resultInput.error(actualValue); - this.originalAccuracy += correct ? 1d : 0d; - confidences = resultInput.confidences(); - inputConfidences[confidencesIndex] = index; - System.arraycopy(confidences, 0, inputConfidences, confidencesIndex + 1, confidences.length); - - // Maintain data about output-based - if (resultOutput != null) { - correct = resultOutput.correct(actualValue); - this.averageError += resultOutput.error(actualValue); - this.accuracy += correct ? 1d : 0d; - confidences = resultOutput.confidences(); - outputConfidences[confidencesIndex] = index; - System.arraycopy(confidences, 0, outputConfidences, confidencesIndex + 1, confidences.length); - } - - // Next - confidencesIndex += numClasses + 1; - - this.progress.value = (int)((++done) * total); - } - } catch (Exception e) { - if (e instanceof ComputationInterruptedException) { - throw e; - } else { - throw new UnexpectedErrorException(e); - } - } - } - - // Maintain data about inputZR - this.zeroRAverageError /= (double)classifications; - this.zeroRAccuracy/= (double)classifications; - - // Maintain data about inputLR - this.originalAverageError /= (double)classifications; - this.originalAccuracy /= (double)classifications; - - // Brier score - this.zerorBrierScore = calculateBrierScore(zerorConfidences, outputHandle, specification); - this.originalBrierScore = calculateBrierScore(inputConfidences, outputHandle, specification); - if (inputHandle != outputHandle) { - this.brierScore = calculateBrierScore(outputConfidences, outputHandle, specification); - } - - // Initialize ROC curves for zeroR - for (String attr : specification.classMap.keySet()) { - zerorROC.put(attr, new ROCCurve(attr, - zerorConfidences, - numClasses, - specification.classMap.get(attr), - outputHandle, - specification.classIndex)); - } - - // Initialize ROC curves on original data - for (String attr : specification.classMap.keySet()) { - originalROC.put(attr, new ROCCurve(attr, - inputConfidences, - numClasses, - specification.classMap.get(attr), - outputHandle, - specification.classIndex)); - } - // Initialize ROC curves on anonymized data - if (inputHandle != outputHandle) { - for (String attr : specification.classMap.keySet()) { - ROC.put(attr, new ROCCurve(attr, - outputConfidences, - numClasses, - specification.classMap.get(attr), - outputHandle, - specification.classIndex)); - } - } - - // Maintain data about outputLR - if (inputHandle != outputHandle) { - this.averageError /= (double)classifications; - this.accuracy /= (double)classifications; - } else { - this.averageError = this.originalAverageError; - this.accuracy = this.originalAccuracy; + // Training and evaluation + if (config.getEvaluateWithKfold()) { + // Evaluating using K-fold cross validation + evaluateWithKFoldCrossValidation(inputHandle, outputHandle, config, specification); + } else { + // Evaluating using test subset + evaluateWithTestingSet(inputHandle, outputHandle, config, specification); } - - this.numMeasurements = classifications; } /** @@ -878,4 +628,356 @@ private int getNumSamples(int numRows, ARXClassificationConfiguration config) } return numSamples; } + + /** + * Evaluation of the Classification using 10 K-fold cross validation + * + * @param + * @param + */ + private void evaluateWithKFoldCrossValidation(DataHandleInternal inputHandle, + DataHandleInternal outputHandle, + ARXClassificationConfiguration config, + ClassificationDataSpecification specification) throws ParseException { + + // Track + int classifications = 0; + double total = numSamples; + double done = 0d; + + // ROC + double[] inputConfidences = new double[numSamples * ( 1 + numClasses)]; + double[] outputConfidences = (inputHandle == outputHandle) ? null : new double[numSamples * ( 1 + numClasses)]; + double[] zerorConfidences = new double[numSamples * ( 1 + numClasses)]; + int confidencesIndex = 0; + + // Train and evaluate + int k = numSamples > config.getNumFolds() ? config.getNumFolds() : numSamples; + List> folds = getFolds(inputHandle.getNumRows(), numSamples, k); + + // Track + total = 100d / ((double)numSamples * (double)folds.size()); + + // For each fold as a validation set + for (int evaluationFold = 0; evaluationFold < folds.size(); evaluationFold++) { + + // Create classifiers + ClassificationMethod inputClassifier = getClassifier(interrupt, specification, config, inputHandle); + ClassificationMethod inputZeroR = new MultiClassZeroR(interrupt, specification); + ClassificationMethod outputClassifier = null; + if (inputHandle != outputHandle) { + outputClassifier = getClassifier(interrupt, specification, config, inputHandle); + } + + // Try + try { + + // Train with all training sets + boolean trained = false; + for (int trainingFold = 0; trainingFold < folds.size(); trainingFold++) { + if (trainingFold != evaluationFold) { + for (int index : folds.get(trainingFold)) { + checkInterrupt(); + inputClassifier.train(inputHandle, outputHandle, index); + inputZeroR.train(inputHandle, outputHandle, index); + if (outputClassifier != null && !outputHandle.isOutlier(index)) { + outputClassifier.train(outputHandle, outputHandle, index); + trained = true; + } + this.progress.value = (int)(++done * total); + } + } + } + + // Close + inputClassifier.close(); + inputZeroR.close(); + if (outputClassifier != null && trained) { + outputClassifier.close(); + } + + // Now validate + for (int index : folds.get(evaluationFold)) { + + // Check + checkInterrupt(); + + // Classify + ClassificationResult resultInput = inputClassifier.classify(inputHandle, index); + ClassificationResult resultInputZR = inputZeroR.classify(inputHandle, index); + ClassificationResult resultOutput = outputClassifier == null || !trained ? null : outputClassifier.classify(outputHandle, index); + classifications++; + + // Correct result + String actualValue = outputHandle.getValue(index, specification.classIndex, true); + + // Maintain data about ZeroR + this.zeroRAverageError += resultInputZR.error(actualValue); + this.zeroRAccuracy += resultInputZR.correct(actualValue) ? 1d : 0d; + double[] confidences = resultInputZR.confidences(); + zerorConfidences[confidencesIndex] = index; + System.arraycopy(confidences, 0, zerorConfidences, confidencesIndex + 1, confidences.length); + + // Maintain data about input-based classifier + boolean correct = resultInput.correct(actualValue); + this.originalAverageError += resultInput.error(actualValue); + this.originalAccuracy += correct ? 1d : 0d; + confidences = resultInput.confidences(); + inputConfidences[confidencesIndex] = index; + System.arraycopy(confidences, 0, inputConfidences, confidencesIndex + 1, confidences.length); + + // Maintain data about output-based + if (resultOutput != null) { + correct = resultOutput.correct(actualValue); + this.averageError += resultOutput.error(actualValue); + this.accuracy += correct ? 1d : 0d; + confidences = resultOutput.confidences(); + outputConfidences[confidencesIndex] = index; + System.arraycopy(confidences, 0, outputConfidences, confidencesIndex + 1, confidences.length); + } + + // Next + confidencesIndex += numClasses + 1; + + this.progress.value = (int)((++done) * total); + } + } catch (Exception e) { + if (e instanceof ComputationInterruptedException) { + throw e; + } else { + throw new UnexpectedErrorException(e); + } + } + + } + // Maintain data about inputZR + this.zeroRAverageError /= (double)classifications; + this.zeroRAccuracy/= (double)classifications; + + // Maintain data about inputLR + this.originalAverageError /= (double)classifications; + this.originalAccuracy /= (double)classifications; + + // Brier score + this.zerorBrierScore = calculateBrierScore(zerorConfidences, outputHandle, specification); + this.originalBrierScore = calculateBrierScore(inputConfidences, outputHandle, specification); + if (inputHandle != outputHandle) { + this.brierScore = calculateBrierScore(outputConfidences, outputHandle, specification); + } + + // Initialize ROC curves for zeroR + for (String attr : specification.classMap.keySet()) { + zerorROC.put(attr, new ROCCurve(attr, + zerorConfidences, + numClasses, + specification.classMap.get(attr), + outputHandle, + specification.classIndex)); + } + + // Initialize ROC curves on original data + for (String attr : specification.classMap.keySet()) { + originalROC.put(attr, new ROCCurve(attr, + inputConfidences, + numClasses, + specification.classMap.get(attr), + outputHandle, + specification.classIndex)); + } + // Initialize ROC curves on anonymized data + if (inputHandle != outputHandle) { + for (String attr : specification.classMap.keySet()) { + ROC.put(attr, new ROCCurve(attr, + outputConfidences, + numClasses, + specification.classMap.get(attr), + outputHandle, + specification.classIndex)); + } + } + + // Maintain data about outputLR + if (inputHandle != outputHandle) { + this.averageError /= (double)classifications; + this.accuracy /= (double)classifications; + } else { + this.averageError = this.originalAverageError; + this.accuracy = this.originalAccuracy; + } + + this.numMeasurements = classifications; + } + + /** + * Evaluation of the Classification using a test subset + * + * @param + * @param + */ + private void evaluateWithTestingSet(DataHandleInternal inputHandle, + DataHandleInternal outputHandle, + ARXClassificationConfiguration config, + ClassificationDataSpecification specification) throws ParseException { + + + // Track + int classifications = 0; + double done = 0d; + + // ROC + double[] inputConfidences = new double[numSamples * ( 1 + numClasses)]; + double[] outputConfidences = (inputHandle == outputHandle) ? null : new double[numSamples * ( 1 + numClasses)]; + double[] zerorConfidences = new double[numSamples * ( 1 + numClasses)]; + int confidencesIndex = 0; + + double total = 100d / (double)numSamples; + + // get the training data + DataSubset subsetTrain = inputHandle.getSubset(); + + // do training + + // Create classifiers + ClassificationMethod inputClassifier = getClassifier(interrupt, specification, config, inputHandle); + ClassificationMethod inputZeroR = new MultiClassZeroR(interrupt, specification); + ClassificationMethod outputClassifier = null; + if (inputHandle != outputHandle) { + outputClassifier = getClassifier(interrupt, specification, config, inputHandle); + } + + // Try + try { + // Train with the training subset + for (int index : subsetTrain.getArray()) { + checkInterrupt(); + inputClassifier.train(inputHandle, outputHandle, index); + inputZeroR.train(inputHandle, outputHandle, index); + if (outputClassifier != null && !outputHandle.isOutlier(index)) { + outputClassifier.train(outputHandle, outputHandle, index); + } + this.progress.value = (int)((++done) * total); + } + + // Close + inputClassifier.close(); + inputZeroR.close(); + if (outputClassifier != null ) { + outputClassifier.close(); + } + + // create the testing subset indices + Set subsetIndicesTest = new HashSet(); + for (int i = 0; i < inputHandle.getNumRows(); ++i) { + if (! subsetTrain.getSet().contains(i)) { + subsetIndicesTest.add(i); + } + } + + for (int index : subsetIndicesTest ) { + // Check + checkInterrupt(); + + // Classify + ClassificationResult resultInput = inputClassifier.classify(inputHandle, index); + ClassificationResult resultInputZR = inputZeroR.classify(inputHandle, index); + ClassificationResult resultOutput = outputClassifier == null ? null : outputClassifier.classify(outputHandle, index); + classifications++; + + // Correct result + String actualValue = outputHandle.getValue(index, specification.classIndex, true); + + // Maintain data about ZeroR + this.zeroRAverageError += resultInputZR.error(actualValue); + this.zeroRAccuracy += resultInputZR.correct(actualValue) ? 1d : 0d; + double[] confidences = resultInputZR.confidences(); + zerorConfidences[confidencesIndex] = index; + System.arraycopy(confidences, 0, zerorConfidences, confidencesIndex + 1, confidences.length); + + // Maintain data about input-based classifier + boolean correct = resultInput.correct(actualValue); + this.originalAverageError += resultInput.error(actualValue); + this.originalAccuracy += correct ? 1d : 0d; + confidences = resultInput.confidences(); + inputConfidences[confidencesIndex] = index; + System.arraycopy(confidences, 0, inputConfidences, confidencesIndex + 1, confidences.length); + + // Maintain data about output-based + if (resultOutput != null) { + correct = resultOutput.correct(actualValue); + this.averageError += resultOutput.error(actualValue); + this.accuracy += correct ? 1d : 0d; + confidences = resultOutput.confidences(); + outputConfidences[confidencesIndex] = index; + System.arraycopy(confidences, 0, outputConfidences, confidencesIndex + 1, confidences.length); + } + + // Next + confidencesIndex += numClasses + 1; + this.progress.value = (int)((++done) * total); + } + } catch (Exception e) { + if (e instanceof ComputationInterruptedException) { + throw e; + } else { + throw new UnexpectedErrorException(e); + } + } + + // Maintain data about inputZR + this.zeroRAverageError /= (double)classifications; + this.zeroRAccuracy/= (double)classifications; + + // Maintain data about inputLR + this.originalAverageError /= (double)classifications; + this.originalAccuracy /= (double)classifications; + + // Brier score + this.zerorBrierScore = calculateBrierScore(zerorConfidences, outputHandle, specification); + this.originalBrierScore = calculateBrierScore(inputConfidences, outputHandle, specification); + if (inputHandle != outputHandle) { + this.brierScore = calculateBrierScore(outputConfidences, outputHandle, specification); + } + + // Initialize ROC curves for zeroR + for (String attr : specification.classMap.keySet()) { + zerorROC.put(attr, new ROCCurve(attr, + zerorConfidences, + numClasses, + specification.classMap.get(attr), + outputHandle, + specification.classIndex)); + } + + // Initialize ROC curves on original data + for (String attr : specification.classMap.keySet()) { + originalROC.put(attr, new ROCCurve(attr, + inputConfidences, + numClasses, + specification.classMap.get(attr), + outputHandle, + specification.classIndex)); + } + // Initialize ROC curves on anonymized data + if (inputHandle != outputHandle) { + for (String attr : specification.classMap.keySet()) { + ROC.put(attr, new ROCCurve(attr, + outputConfidences, + numClasses, + specification.classMap.get(attr), + outputHandle, + specification.classIndex)); + } + } + + // Maintain data about outputLR + if (inputHandle != outputHandle) { + this.averageError /= (double)classifications; + this.accuracy /= (double)classifications; + } else { + this.averageError = this.originalAverageError; + this.accuracy = this.originalAccuracy; + } + + this.numMeasurements = classifications; + } } From fc65a3243355cd85a635b9687c13f00d688d304b Mon Sep 17 00:00:00 2001 From: idhamari Date: Mon, 9 May 2022 18:36:27 +0200 Subject: [PATCH 4/7] add missing parameter in the documentation --- .../arx/aggregates/StatisticsClassification.java | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java b/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java index 94d04ab98..81388a2ba 100644 --- a/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java +++ b/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java @@ -1,4 +1,3 @@ -/* * ARX: Powerful Data Anonymization * Copyright 2012 - 2021 Fabian Prasser and contributors * @@ -632,8 +631,11 @@ private int getNumSamples(int numRows, ARXClassificationConfiguration config) /** * Evaluation of the Classification using 10 K-fold cross validation * - * @param - * @param + * @param inputHandle - The input features handle + * @param outputHandle - The output features handle + * @param config - The configuration + * @param specification - The specification + * @throws ParseException */ private void evaluateWithKFoldCrossValidation(DataHandleInternal inputHandle, DataHandleInternal outputHandle, @@ -811,8 +813,11 @@ private void evaluateWithKFoldCrossValidation(DataHandleInternal inputHandle, /** * Evaluation of the Classification using a test subset * - * @param - * @param + * @param inputHandle - The input features handle + * @param outputHandle - The output features handle + * @param config - The configuration + * @param specification - The specification + * @throws ParseException */ private void evaluateWithTestingSet(DataHandleInternal inputHandle, DataHandleInternal outputHandle, @@ -981,3 +986,4 @@ private void evaluateWithTestingSet(DataHandleInternal inputHandle, this.numMeasurements = classifications; } } + From 10c4b8f717f378eb0c0f4fb4568a8e85753b774b Mon Sep 17 00:00:00 2001 From: idhamari Date: Mon, 9 May 2022 18:37:23 +0200 Subject: [PATCH 5/7] remove extra line --- .../deidentifier/arx/aggregates/StatisticsClassification.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java b/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java index 81388a2ba..409b4333f 100644 --- a/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java +++ b/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java @@ -986,4 +986,3 @@ private void evaluateWithTestingSet(DataHandleInternal inputHandle, this.numMeasurements = classifications; } } - From be7bfec47aa956eef5ec7cf6d67231ab18f60366 Mon Sep 17 00:00:00 2001 From: idhamari Date: Wed, 25 May 2022 08:31:49 +0200 Subject: [PATCH 6/7] add gui button and actions for kfold option --- .../org/deidentifier/arx/gui/model/Model.java | 15 +- .../arx/gui/model/ModelEvent.java | 3 + .../arx/gui/resources/Resources.java | 1 + .../arx/gui/resources/crossKFold.png | Bin 0 -> 534 bytes .../arx/gui/resources/messages.properties | 3 +- .../arx/gui/resources/tickKFold.png | Bin 0 -> 517 bytes .../impl/utility/LayoutUtilityStatistics.java | 144 +++++++++++++----- 7 files changed, 130 insertions(+), 36 deletions(-) create mode 100644 src/gui/org/deidentifier/arx/gui/resources/crossKFold.png create mode 100644 src/gui/org/deidentifier/arx/gui/resources/tickKFold.png diff --git a/src/gui/org/deidentifier/arx/gui/model/Model.java b/src/gui/org/deidentifier/arx/gui/model/Model.java index 382ddfda5..45fb261d1 100644 --- a/src/gui/org/deidentifier/arx/gui/model/Model.java +++ b/src/gui/org/deidentifier/arx/gui/model/Model.java @@ -182,8 +182,11 @@ public static enum Perspective { /** Current selection. */ private String selectedAttribute = null; - /** Enable/disable. */ + /** Enable/disable visualization */ private Boolean showVisualization = true; + + /** Enable/disable KFold evaluation */ + private Boolean useKFold = true; /** Last two selections. */ private String[] pair = new String[] { null, null }; @@ -1864,6 +1867,16 @@ public void setVisualizationEnabled(boolean value){ this.setModified(); } + /** + * Sets KFold as enabled/disabled. + * + * @param value + */ + public void setKFoldEnabled(boolean value){ + this.useKFold = value; + this.setModified(); + } + /** * Converts attributes into an array ordered by occurrence in the dataset * @param set diff --git a/src/gui/org/deidentifier/arx/gui/model/ModelEvent.java b/src/gui/org/deidentifier/arx/gui/model/ModelEvent.java index a08b97d4b..a1ced0403 100644 --- a/src/gui/org/deidentifier/arx/gui/model/ModelEvent.java +++ b/src/gui/org/deidentifier/arx/gui/model/ModelEvent.java @@ -85,6 +85,9 @@ public static enum ModelPart { /** SELECTED_UTILITY_VISUALIZATION */ SELECTED_UTILITY_VISUALIZATION, + /** SELECTED_UTILITY_KFold */ + SELECTED_UTILITY_KFold, + /** ATTRIBUTE_VALUE */ ATTRIBUTE_VALUE, diff --git a/src/gui/org/deidentifier/arx/gui/resources/Resources.java b/src/gui/org/deidentifier/arx/gui/resources/Resources.java index 0274297ed..93ce501eb 100644 --- a/src/gui/org/deidentifier/arx/gui/resources/Resources.java +++ b/src/gui/org/deidentifier/arx/gui/resources/Resources.java @@ -252,6 +252,7 @@ public Image getManagedImage(final String name) { if (imageCache.containsKey(name)) { return imageCache.get(name); } else { + //System.out.println("xxxxxxxxxxxxx image: " + name); Image image = getImage(name); imageCache.put(name, image); return image; diff --git a/src/gui/org/deidentifier/arx/gui/resources/crossKFold.png b/src/gui/org/deidentifier/arx/gui/resources/crossKFold.png new file mode 100644 index 0000000000000000000000000000000000000000..eb608e25536b83545c30343713262742c8bb8785 GIT binary patch literal 534 zcmV+x0_pvUP);aJgP6fTbZJU*e-7`xB(D+72Y0yNJ?Gx@ zedk;Q(1?gb&8)|Hqb8|Xl%>dvKT6+9Z zI`p}&xyl7uTz)pIt)}$!mDI{a75r{O`3|H4m@L4H2DCO|_B3q!;~<60_t~kv=Tsi% zmSE;E2*5c<>lus(kS{~03oq}%gXx>NsC3L=b(Qt|kj}&NHhgbH5)$Z0^#$qb8|l}# zVy|@hzSe)9(bAgK*_J9L>C;JR=z_+IbY)9gnSCO?jYu=8TGhmaw0|t54rHaty^#ll z680PgFSkp(S_7dND_&7H8#X0rMDf!tE1rOGvu)=(B$0{{R307*qoM6N<$f{!!$hX4Qo literal 0 HcmV?d00001 diff --git a/src/gui/org/deidentifier/arx/gui/resources/messages.properties b/src/gui/org/deidentifier/arx/gui/resources/messages.properties index c44d7efc3..5d35ef356 100644 --- a/src/gui/org/deidentifier/arx/gui/resources/messages.properties +++ b/src/gui/org/deidentifier/arx/gui/resources/messages.properties @@ -649,7 +649,8 @@ SeparatorDialog.9=Error StatisticsView.0=Distribution StatisticsView.1=Contingency StatisticsView.2=Properties -StatisticsView.3=Enable/disable +StatisticsView.3=Enable/disable visualization +StatisticsView.31=useKFold StatisticsView.4=Distribution (table) StatisticsView.5=Contingency (table) StatisticsView.6=Summary statistics diff --git a/src/gui/org/deidentifier/arx/gui/resources/tickKFold.png b/src/gui/org/deidentifier/arx/gui/resources/tickKFold.png new file mode 100644 index 0000000000000000000000000000000000000000..bc8be08c78ed8b68d04278268f0a99c79a87f27e GIT binary patch literal 517 zcmV+g0{Z=lP)}{fkT6Iss=uH` zyMh8GA_{*n8x86i8iC|OFo+fg?na_bLPC)hC|QfbMNRYOep}r49Bn3iXLawn=li|i zd%x!xNg`r|rN)?zTr4Um1cu-+zb(P$+bXH30$84dvDaXOqIe22ahQAy0x)pA8#0F+ zaD5q`zJ#^vKtbXFG=GEp58&}fIGKXFgYab*2A@MsN*vke$P4e`YzF3kfOp49lC@TA zQ_`-alpd4(qV(dnl<1HqKD+ZUa`&WoTAIyEt4mT<*`s}!>4(d&V96BP)KUV(u$e4mGNw_wGV&MNg>8G)&WYguul1wKqe&j|bsaR2s%wYEf% z&pXfm6u{KN%^^6y9~MTTI}2ZnuFm%X3q?UJ7IQYfc_<~@B$JeOv`VLjrJPsFk&^aw zN!JIYrekj3*eiLTBoo+MCNFCIcNU@ysidt0AtU)I literal 0 HcmV?d00001 diff --git a/src/gui/org/deidentifier/arx/gui/view/impl/utility/LayoutUtilityStatistics.java b/src/gui/org/deidentifier/arx/gui/view/impl/utility/LayoutUtilityStatistics.java index af529f13a..a2eeda93e 100644 --- a/src/gui/org/deidentifier/arx/gui/view/impl/utility/LayoutUtilityStatistics.java +++ b/src/gui/org/deidentifier/arx/gui/view/impl/utility/LayoutUtilityStatistics.java @@ -17,12 +17,21 @@ package org.deidentifier.arx.gui.view.impl.utility; +import java.io.IOError; +import java.security.SecureRandom; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Random; +import java.util.Set; import java.util.Map.Entry; +import org.deidentifier.arx.Data; +import org.deidentifier.arx.DataSubset; +import org.deidentifier.arx.RowSet; +import org.deidentifier.arx.criteria.Inclusion; import org.deidentifier.arx.gui.Controller; import org.deidentifier.arx.gui.model.Model; import org.deidentifier.arx.gui.model.ModelEvent; @@ -75,13 +84,26 @@ public class LayoutUtilityStatistics implements ILayout, IView { private final ComponentTitledFolder folder; /** View */ - private final ToolItem enable; + private final ToolItem chkbtnVisualisation; /** View */ - private final Image enabled; + private ToolItem chkbtnUseKFold; + + + /** View */ + private Boolean useKFold = true; + + /** View */ + private final Image icnVisEnabled; + + /** View */ + private final Image icnVisDisabled; + + /** View */ + private final Image icnUseKFoldEnabled; /** View */ - private final Image disabled; + private final Image icnUseKFoldDisabled; /** View */ private final Map helpids = new HashMap(); @@ -108,23 +130,37 @@ public LayoutUtilityStatistics(final Composite parent, final ModelPart target, final ModelPart reset) { - this.enabled = controller.getResources().getManagedImage("tick.png"); //$NON-NLS-1$ - this.disabled = controller.getResources().getManagedImage("cross.png"); //$NON-NLS-1$ + this.icnVisEnabled = controller.getResources().getManagedImage("tick.png"); //$NON-NLS-1$ + this.icnVisDisabled = controller.getResources().getManagedImage("cross.png"); //$NON-NLS-1$ + + this.icnUseKFoldEnabled = controller.getResources().getManagedImage("tickKFold.png"); //$NON-NLS-1$ + this.icnUseKFoldDisabled = controller.getResources().getManagedImage("crossKFold.png"); //$NON-NLS-1$ + this.controller = controller; controller.addListener(ModelPart.MODEL, this); controller.addListener(ModelPart.SELECTED_UTILITY_VISUALIZATION, this); - // Create enable/disable button - final String label = Resources.getMessage("StatisticsView.3"); //$NON-NLS-1$ - ComponentTitledFolderButtonBar bar = new ComponentTitledFolderButtonBar("id-50", helpids); //$NON-NLS-1$ - bar.add(label, disabled, true, new Runnable() { @Override public void run() { - toggleEnabled(); - toggleImage(); + // Create visualization toolbar + ComponentTitledFolderButtonBar toolbarVis = new ComponentTitledFolderButtonBar("id-50", helpids); //$NON-NLS-1$ + + // Create visualization enable/disable check button + final String chkbtnVisualisationLabel = Resources.getMessage("StatisticsView.3"); //$NON-NLS-1$ + toolbarVis.add(chkbtnVisualisationLabel, icnVisEnabled, true, new Runnable() { @Override public void run() { + toggleChkbtnVisualization(); + toggleChkbtnVisualizationIcon(); }}); - + + // Create useKFold enable/disable check button + final String chkbtnUseKFoldLabel = Resources.getMessage("StatisticsView.31"); //$NON-NLS-1$ + if (target == ModelPart.OUTPUT) { + toolbarVis.add(chkbtnUseKFoldLabel, icnUseKFoldEnabled, true, new Runnable() { @Override public void run() { + toggleChkbtnUseKFold(); + toggleChkbtnUseKFoldIcon(); + }}); + } // Create the tab folder - folder = new ComponentTitledFolder(parent, controller, bar, null, false, true); + folder = new ComponentTitledFolder(parent, controller, toolbarVis, null, false, true); // Register tabs this.registerView(new ViewStatisticsSummaryTable(folder.createItem(TAB_SUMMARY, null, true), controller, target, reset), "help.utility.summary"); //$NON-NLS-1$ @@ -143,9 +179,15 @@ public LayoutUtilityStatistics(final Composite parent, // Init folder this.folder.setSelection(0); - this.enable = folder.getButtonItem(label); - this.enable.setEnabled(false); + + this.chkbtnVisualisation = folder.getButtonItem(chkbtnVisualisationLabel); + this.chkbtnVisualisation.setEnabled(false); + if ( target == ModelPart.OUTPUT ) { + this.chkbtnUseKFold = folder.getButtonItem(chkbtnUseKFoldLabel); + this.chkbtnUseKFold.setEnabled(true); + this.chkbtnUseKFold.setToolTipText("Disable K-fold evaluation!"); + }; // Set initial visibility folder.setVisibleItems(Arrays.asList(new String[] { TAB_SUMMARY, TAB_DISTRIBUTION, @@ -153,7 +195,7 @@ public LayoutUtilityStatistics(final Composite parent, TAB_CLASSES_TABLE, TAB_PROPERTIES })); } - + /** * Adds a selection listener. * @@ -188,9 +230,9 @@ public List getVisibleItems() { @Override public void reset() { model = null; - enable.setSelection(true); - enable.setImage(enabled); - enable.setEnabled(false); + chkbtnVisualisation.setSelection(true); + chkbtnVisualisation.setImage(icnVisEnabled); + chkbtnVisualisation.setEnabled(false); } /** @@ -227,13 +269,13 @@ public void update(ModelEvent event) { if (event.part == ModelPart.MODEL) { this.model = (Model)event.data; - this.enable.setEnabled(true); - this.enable.setSelection(model.isVisualizationEnabled()); - this.toggleImage(); + this.chkbtnVisualisation.setEnabled(true); + this.chkbtnVisualisation.setSelection(model.isVisualizationEnabled()); + this.toggleChkbtnVisualizationIcon(); } else if (event.part == ModelPart.SELECTED_UTILITY_VISUALIZATION) { - this.enable.setSelection(model.isVisualizationEnabled()); - this.toggleImage(); - } + this.chkbtnVisualisation.setSelection(model.isVisualizationEnabled()); + this.toggleChkbtnVisualizationIcon(); + } } /** @@ -259,19 +301,53 @@ private void registerView(ViewStatisticsBasic view, String helpid) { /** * Toggle visualization enabled. */ - private void toggleEnabled() { - this.model.setVisualizationEnabled(this.enable.getSelection()); - this.controller.update(new ModelEvent(this, ModelPart.SELECTED_UTILITY_VISUALIZATION, enable.getSelection())); + private void toggleChkbtnVisualization() { + this.model.setVisualizationEnabled(this.chkbtnVisualisation.getSelection()); + this.controller.update(new ModelEvent(this, ModelPart.SELECTED_UTILITY_VISUALIZATION, chkbtnVisualisation.getSelection())); } /** - * Toggle image. + * Toggle visualization icon. */ - private void toggleImage(){ - if (enable.getSelection()) { - enable.setImage(enabled); + private void toggleChkbtnVisualizationIcon(){ + if (chkbtnVisualisation.getSelection()) { + chkbtnVisualisation.setImage(icnVisEnabled); } else { - enable.setImage(disabled); + chkbtnVisualisation.setImage(icnVisDisabled); + } + } + + + /** + * Toggle UseKFold enabled. + */ + private void toggleChkbtnUseKFold() { + + // It should work when classification tab is active + if (folder.getSelectionIndex()==5) { + DataSubset trainingSubset = DataSubset.create(this.model.getInputConfig().getInput(), this.model.getInputConfig().getResearchSubset()); + + this.model.setVisualizationEnabled(false); + this.controller.update(new ModelEvent(this, ModelPart.SELECTED_UTILITY_VISUALIZATION, false)); + + this.model.getClassificationModel().getCurrentConfiguration().setEvaluateWithKfold(chkbtnUseKFold.getSelection()); + this.controller.update(new ModelEvent(this, ModelPart.SELECTED_UTILITY_KFold, chkbtnUseKFold.getSelection())); + + this.model.setVisualizationEnabled(true); + this.controller.update(new ModelEvent(this, ModelPart.SELECTED_UTILITY_VISUALIZATION,true)); + } + } + + /** + * Toggle UseKFold icon. + */ + private void toggleChkbtnUseKFoldIcon(){ + if (folder.getSelectionIndex()==5) { + if (this.chkbtnUseKFold.getSelection()) { + this.chkbtnUseKFold.setImage(icnUseKFoldEnabled); + } else { + this.chkbtnUseKFold.setImage(icnUseKFoldDisabled); + } } } -} +} \ No newline at end of file From 3be9d929b19aaec1375ccbaebec4c7ec94c8aea2 Mon Sep 17 00:00:00 2001 From: idhamari Date: Wed, 25 May 2022 08:42:51 +0200 Subject: [PATCH 7/7] fix inputHandle.getSubSet is null when called from GUI --- .../arx/gui/resources/Resources.java | 1 - .../aggregates/StatisticsClassification.java | 39 ++++++++++++------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/src/gui/org/deidentifier/arx/gui/resources/Resources.java b/src/gui/org/deidentifier/arx/gui/resources/Resources.java index 93ce501eb..0274297ed 100644 --- a/src/gui/org/deidentifier/arx/gui/resources/Resources.java +++ b/src/gui/org/deidentifier/arx/gui/resources/Resources.java @@ -252,7 +252,6 @@ public Image getManagedImage(final String name) { if (imageCache.containsKey(name)) { return imageCache.get(name); } else { - //System.out.println("xxxxxxxxxxxxx image: " + name); Image image = getImage(name); imageCache.put(name, image); return image; diff --git a/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java b/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java index 409b4333f..d896b7701 100644 --- a/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java +++ b/src/main/org/deidentifier/arx/aggregates/StatisticsClassification.java @@ -836,9 +836,19 @@ private void evaluateWithTestingSet(DataHandleInternal inputHandle, int confidencesIndex = 0; double total = 100d / (double)numSamples; - - // get the training data - DataSubset subsetTrain = inputHandle.getSubset(); + + // Training dataset + DataSubset subsetTrain ; + + // Complete dataset size + int dataSize = inputHandle.getNumRows(); + + // Get the training dataset + subsetTrain = inputHandle.getSubset(); + if (inputHandle.getSubset()== null) { + subsetTrain = inputHandle.getSuperset().getSubset(); + dataSize = inputHandle.getSuperset().getNumRows(); + } // do training @@ -851,34 +861,31 @@ private void evaluateWithTestingSet(DataHandleInternal inputHandle, } // Try - try { + try { // Train with the training subset - for (int index : subsetTrain.getArray()) { - checkInterrupt(); - inputClassifier.train(inputHandle, outputHandle, index); + for (int index=0; index subsetIndicesTest = new HashSet(); - for (int i = 0; i < inputHandle.getNumRows(); ++i) { + for (int i = 0; i < dataSize; i++) { if (! subsetTrain.getSet().contains(i)) { subsetIndicesTest.add(i); } } - - for (int index : subsetIndicesTest ) { + for (int index=0; index< subsetIndicesTest.size();index++ ) { // Check checkInterrupt(); @@ -897,7 +904,7 @@ private void evaluateWithTestingSet(DataHandleInternal inputHandle, double[] confidences = resultInputZR.confidences(); zerorConfidences[confidencesIndex] = index; System.arraycopy(confidences, 0, zerorConfidences, confidencesIndex + 1, confidences.length); - + // Maintain data about input-based classifier boolean correct = resultInput.correct(actualValue); this.originalAverageError += resultInput.error(actualValue); @@ -905,7 +912,7 @@ private void evaluateWithTestingSet(DataHandleInternal inputHandle, confidences = resultInput.confidences(); inputConfidences[confidencesIndex] = index; System.arraycopy(confidences, 0, inputConfidences, confidencesIndex + 1, confidences.length); - + // Maintain data about output-based if (resultOutput != null) { correct = resultOutput.correct(actualValue); @@ -921,6 +928,8 @@ private void evaluateWithTestingSet(DataHandleInternal inputHandle, this.progress.value = (int)((++done) * total); } } catch (Exception e) { + System.out.println("IA Error : " + e.getMessage() ); + if (e instanceof ComputationInterruptedException) { throw e; } else { @@ -983,6 +992,6 @@ private void evaluateWithTestingSet(DataHandleInternal inputHandle, this.accuracy = this.originalAccuracy; } - this.numMeasurements = classifications; + this.numMeasurements = classifications; } }