-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactored performance measure after discretization changes #21
- Loading branch information
B96
committed
Jul 23, 2019
1 parent
36923bc
commit 8fb2714
Showing
7 changed files
with
82 additions
and
247 deletions.
There are no files selected for viewing
21 changes: 0 additions & 21 deletions
21
AutoTuning/src/main/java/LossFunctions/Accuracy/CategoricalFeature.java
This file was deleted.
Oops, something went wrong.
18 changes: 0 additions & 18 deletions
18
AutoTuning/src/main/java/LossFunctions/Accuracy/Feature.java
This file was deleted.
Oops, something went wrong.
29 changes: 0 additions & 29 deletions
29
AutoTuning/src/main/java/LossFunctions/Accuracy/MetricFeature.java
This file was deleted.
Oops, something went wrong.
104 changes: 0 additions & 104 deletions
104
AutoTuning/src/main/java/LossFunctions/Accuracy/PredictionModel.java
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
81 changes: 81 additions & 0 deletions
81
AutoTuning/src/main/java/LossFunctions/PredictionModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
package LossFunctions; | ||
|
||
import de.viadee.xai.anchor.adapter.tabular.TabularInstance; | ||
import de.viadee.xai.anchor.algorithm.AnchorResult; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Comparator; | ||
import java.util.List; | ||
|
||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
public class PredictionModel { | ||
|
||
private static final Logger LOGGER = LoggerFactory.getLogger(PredictionModel.class); | ||
|
||
/** | ||
* - if one rule applies predict the label according to rule | ||
* - if two rules apply predict according to more precise rule | ||
* - if both rules are equally precise predict randomly | ||
*/ | ||
|
||
private final List<AnchorResult<TabularInstance>> rules; | ||
|
||
public PredictionModel(List<AnchorResult<TabularInstance>> globalExplanations) { | ||
this.rules = globalExplanations; | ||
// sort by precision so higher precisions get prioritized | ||
rules.sort(Comparator.comparingDouble(AnchorResult<TabularInstance>::getPrecision).reversed()); | ||
|
||
} | ||
|
||
/** | ||
* @param instances | ||
* @return | ||
*/ | ||
public List<Integer> predict(TabularInstance[] instances) { | ||
|
||
List<Integer> predictions = new ArrayList<>(); | ||
|
||
for (TabularInstance instance : instances) { | ||
predictions.add(this.predictSingle(instance)); | ||
} | ||
return predictions; | ||
} | ||
|
||
/** | ||
* @param instance | ||
* @return | ||
*/ | ||
public int predictSingle(TabularInstance instance) { | ||
|
||
int ruleNumber = 0; | ||
|
||
for (AnchorResult<TabularInstance> rule : this.rules) { | ||
int numberMatches = 0; | ||
ruleNumber++; | ||
// System.out.println("Check rule " + ruleNumber + " with precision " + r.getPrecision() + " and label " + r.getLabel()); | ||
|
||
for (int f : rule.getCanonicalFeatures()) { | ||
|
||
double instanceValue = instance.getValue(f); | ||
double ruleValue = rule.getInstance().getValue(f); | ||
|
||
System.out.println("Feature: " + rule.getInstance().getFeatures()[f].getName() + " - Instance value: " + instanceValue + " ---- Rule value: " + ruleValue); | ||
|
||
if (instanceValue == ruleValue) { | ||
numberMatches++; | ||
} | ||
} | ||
|
||
if (numberMatches == (rule.getCanonicalFeatures().size())) { | ||
LOGGER.info("Predict " + rule.getLabel() + " for the instance based on rule " + ruleNumber + "."); | ||
return rule.getLabel(); | ||
} | ||
} | ||
|
||
LOGGER.info("No rule found to predict the instance."); | ||
return -1; | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters