Usage of org.tribuo.Excuse in project Tribuo by Oracle:
class LibLinearClassificationModel, method innerGetExcuse.
/**
 * The call to model.getFeatureWeights in the public methods copies the
 * weights array so this inner method exists to save the copy in getExcuses.
 * <p>
 * If it becomes a problem then we could cache the feature weights in the
 * model.
 * <p>
 * A feature counts only if it is present in the example and known to the feature map; other
 * features are ignored. Each counted feature contributes {@code weight * value} to a label.
 * The contributions for each label are sorted in descending order.
 * <p>
 * liblinear stores its weights in one of two layouts:
 * <ul>
 * <li>A single weight vector (two classes, most solvers). The contributions are reported for
 * the first label, and their negations for the second label.</li>
 * <li>One weight vector per class, interleaved feature-major. The weight for feature
 * {@code id} and class index {@code i} lives at {@code id * numWeightVectors + i}.</li>
 * </ul>
 * The number of stored vectors is derived from the array length rather than inferred from the
 * class count. Some two-class models keep one vector per class (e.g. the Crammer-Singer
 * multi-class SVM solver), and those are therefore read with the interleaved layout instead of
 * being misread as a single vector.
 * @param e The example.
 * @param allFeatureWeights The feature weights.
 * @return An excuse for this example.
 */
@Override
protected Excuse<Label> innerGetExcuse(Example<Label> e, double[][] allFeatureWeights) {
    de.bwaldvogel.liblinear.Model model = models.get(0);
    double[] featureWeights = allFeatureWeights[0];
    int[] labels = model.getLabels();
    int numClasses = model.getNrClass();
    // Length of one weight vector: one weight per feature, plus one for the bias term when the
    // model carries a non-negative bias (liblinear's convention for "bias enabled").
    int vectorLength = model.getNrFeature() + (model.getBias() >= 0 ? 1 : 0);
    // How many weight vectors are actually stored (interleaved) in the weight array.
    int numWeightVectors = Math.max(1, featureWeights.length / Math.max(1, vectorLength));
    Prediction<Label> prediction = predict(e);
    Map<String, List<Pair<String, Double>>> weightMap = new HashMap<>();
    if (numClasses == 2 && numWeightVectors == 1) {
        // Single decision vector: positive contributions favour labels[0], so labels[1]
        // receives the negated contributions.
        List<Pair<String, Double>> posScores = new ArrayList<>();
        List<Pair<String, Double>> negScores = new ArrayList<>();
        for (Feature f : e) {
            int id = featureIDMap.getID(f.getName());
            if (id > -1) {
                double score = featureWeights[id] * f.getValue();
                posScores.add(new Pair<>(f.getName(), score));
                negScores.add(new Pair<>(f.getName(), -score));
            }
        }
        posScores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
        negScores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
        weightMap.put(outputIDInfo.getOutput(labels[0]).getLabel(), posScores);
        weightMap.put(outputIDInfo.getOutput(labels[1]).getLabel(), negScores);
    } else {
        // One weight vector per class, interleaved feature-major with stride numWeightVectors.
        for (int i = 0; i < labels.length; i++) {
            List<Pair<String, Double>> classScores = new ArrayList<>();
            for (Feature f : e) {
                int id = featureIDMap.getID(f.getName());
                if (id > -1) {
                    double score = featureWeights[id * numWeightVectors + i] * f.getValue();
                    classScores.add(new Pair<>(f.getName(), score));
                }
            }
            classScores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
            weightMap.put(outputIDInfo.getOutput(labels[i]).getLabel(), classScores);
        }
    }
    return new Excuse<>(e, prediction, weightMap);
}
Usage of org.tribuo.Excuse in project Tribuo by Oracle:
class LibLinearRegressionModel, method innerGetExcuse.
/**
 * Builds an excuse without re-copying the weights: model.getFeatureWeights
 * in the public methods already copies the weight arrays, so they are handed
 * to this method directly for use in getExcuses.
 * <p>
 * Caching the feature weights in the model is an option if this copying
 * ever becomes a bottleneck.
 * <p>
 * One ranking is produced per weight array, keyed by
 * {@code dimensionNames[mapping[index]]}. Each ranking lists, in descending
 * order, {@code weight * value} for every feature of the example that the
 * feature map knows about. Unknown features are skipped.
 *
 * @param e The example.
 * @param allFeatureWeights The feature weights.
 * @return An excuse for this example.
 */
@Override
protected Excuse<Regressor> innerGetExcuse(Example<Regressor> e, double[][] allFeatureWeights) {
    final Prediction<Regressor> prediction = predict(e);
    final Map<String, List<Pair<String, Double>>> weightMap = new HashMap<>();
    int dim = 0;
    for (double[] dimWeights : allFeatureWeights) {
        final List<Pair<String, Double>> contributions = new ArrayList<>();
        for (Feature feature : e) {
            final String featureName = feature.getName();
            final int featureId = featureIDMap.getID(featureName);
            if (featureId < 0) {
                // Feature was not seen at training time.
                continue;
            }
            contributions.add(new Pair<>(featureName, dimWeights[featureId] * feature.getValue()));
        }
        // Largest contribution first.
        contributions.sort((left, right) -> right.getB().compareTo(left.getB()));
        weightMap.put(dimensionNames[mapping[dim]], contributions);
        dim++;
    }
    return new Excuse<>(e, prediction, weightMap);
}
Usage of org.tribuo.Excuse in project Tribuo by Oracle:
class LibLinearAnomalyModel, method innerGetExcuse.
/**
 * The call to model.getFeatureWeights in the public methods copies the
 * weights array so this inner method exists to save the copy in getExcuses.
 * <p>
 * If it becomes a problem then we could cache the feature weights in the
 * model.
 * <p>
 * Only the first weight vector ({@code allFeatureWeights[0]}) is used. A feature counts only
 * if it is present in the example and known to the feature map; other features are ignored.
 * Each counted feature contributes {@code weight * value}. These contributions are reported
 * under {@code ANOMALOUS}, and their negations under {@code EXPECTED}. Both lists are sorted
 * in descending order.
 * @param e The example.
 * @param allFeatureWeights The feature weights.
 * @return An excuse for this example.
 */
@Override
protected Excuse<Event> innerGetExcuse(Example<Event> e, double[][] allFeatureWeights) {
    double[] featureWeights = allFeatureWeights[0];
    Prediction<Event> prediction = predict(e);
    List<Pair<String, Double>> posScores = new ArrayList<>();
    List<Pair<String, Double>> negScores = new ArrayList<>();
    for (Feature f : e) {
        int id = featureIDMap.getID(f.getName());
        if (id > -1) {
            double score = featureWeights[id] * f.getValue();
            posScores.add(new Pair<>(f.getName(), score));
            negScores.add(new Pair<>(f.getName(), -score));
        }
    }
    posScores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
    negScores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
    // NOTE(review): positive contributions are attributed to ANOMALOUS here. liblinear's
    // one-class solver is documented as labelling inliers +1 (positive decision value). Confirm
    // that this orientation matches how predict() maps decision values to EventType.
    Map<String, List<Pair<String, Double>>> weightMap = new HashMap<>();
    weightMap.put(Event.EventType.ANOMALOUS.toString(), posScores);
    weightMap.put(Event.EventType.EXPECTED.toString(), negScores);
    return new Excuse<>(e, prediction, weightMap);
}
Aggregations