Skip to content

Commit

Permalink
Refactored performance measure after discretization changes #21
Browse files Browse the repository at this point in the history
  • Loading branch information
B96 committed Jul 23, 2019
1 parent 36923bc commit 8fb2714
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 247 deletions.

This file was deleted.

18 changes: 0 additions & 18 deletions AutoTuning/src/main/java/LossFunctions/Accuracy/Feature.java

This file was deleted.

29 changes: 0 additions & 29 deletions AutoTuning/src/main/java/LossFunctions/Accuracy/MetricFeature.java

This file was deleted.

104 changes: 0 additions & 104 deletions AutoTuning/src/main/java/LossFunctions/Accuracy/PredictionModel.java

This file was deleted.

74 changes: 0 additions & 74 deletions AutoTuning/src/main/java/LossFunctions/Accuracy/Rule.java

This file was deleted.

81 changes: 81 additions & 0 deletions AutoTuning/src/main/java/LossFunctions/PredictionModel.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package LossFunctions;

import de.viadee.xai.anchor.adapter.tabular.TabularInstance;
import de.viadee.xai.anchor.algorithm.AnchorResult;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class PredictionModel {

private static final Logger LOGGER = LoggerFactory.getLogger(PredictionModel.class);

/**
* - if one rule applies predict the label according to rule
* - if two rules apply predict according to more precise rule
* - if both rules are equally precise predict randomly
*/

private final List<AnchorResult<TabularInstance>> rules;

public PredictionModel(List<AnchorResult<TabularInstance>> globalExplanations) {
this.rules = globalExplanations;
// sort by precision so higher precisions get prioritized
rules.sort(Comparator.comparingDouble(AnchorResult<TabularInstance>::getPrecision).reversed());

}

/**
* @param instances
* @return
*/
public List<Integer> predict(TabularInstance[] instances) {

List<Integer> predictions = new ArrayList<>();

for (TabularInstance instance : instances) {
predictions.add(this.predictSingle(instance));
}
return predictions;
}

/**
* @param instance
* @return
*/
public int predictSingle(TabularInstance instance) {

int ruleNumber = 0;

for (AnchorResult<TabularInstance> rule : this.rules) {
int numberMatches = 0;
ruleNumber++;
// System.out.println("Check rule " + ruleNumber + " with precision " + r.getPrecision() + " and label " + r.getLabel());

for (int f : rule.getCanonicalFeatures()) {

double instanceValue = instance.getValue(f);
double ruleValue = rule.getInstance().getValue(f);

System.out.println("Feature: " + rule.getInstance().getFeatures()[f].getName() + " - Instance value: " + instanceValue + " ---- Rule value: " + ruleValue);

if (instanceValue == ruleValue) {
numberMatches++;
}
}

if (numberMatches == (rule.getCanonicalFeatures().size())) {
LOGGER.info("Predict " + rule.getLabel() + " for the instance based on rule " + ruleNumber + ".");
return rule.getLabel();
}
}

LOGGER.info("No rule found to predict the instance.");
return -1;
}

}
2 changes: 1 addition & 1 deletion AutoTuning/src/main/java/RandomSearch/RandomSearch.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package RandomSearch;

import LossFunctions.Accuracy.PredictionModel;
import LossFunctions.PredictionModel;
import LossFunctions.PerformanceMeasures;
import de.viadee.xai.anchor.adapter.tabular.AnchorTabular;
import de.viadee.xai.anchor.adapter.tabular.TabularInstance;
Expand Down

0 comments on commit 8fb2714

Please sign in to comment.