Search in sources:

Example 1 with Excuse

Use of `org.tribuo.Excuse` in project Tribuo by Oracle.

From class `LibLinearClassificationModel`, method `innerGetExcuse`.

/**
 * Builds an {@link Excuse} for a single example from precomputed LibLinear
 * feature weights.
 * <p>
 * The public excuse methods obtain the weights through model.getFeatureWeights,
 * which copies the weight array. Accepting them as a parameter lets getExcuses
 * avoid repeating that copy for every example. If this becomes a problem, the
 * feature weights could be cached in the model instead.
 * <p>
 * When the model has exactly two classes, one weight per feature is read. Each
 * feature's contribution (weight * value) is attributed to the first LibLinear
 * label, and its negation to the second. Otherwise the weights are read at
 * index {@code featureId * numClasses + classIndex}. Features whose ID lookup
 * returns a negative value are skipped. Each label's list is sorted from the
 * largest to the smallest contribution.
 * @param e The example.
 * @param allFeatureWeights The feature weights; only the first row is read.
 * @return An excuse for this example.
 */
@Override
protected Excuse<Label> innerGetExcuse(Example<Label> e, double[][] allFeatureWeights) {
    final de.bwaldvogel.liblinear.Model linearModel = models.get(0);
    final double[] weights = allFeatureWeights[0];
    final int[] liblinearLabels = linearModel.getLabels();
    final int classCount = linearModel.getNrClass();
    final Prediction<Label> prediction = predict(e);
    final Map<String, List<Pair<String, Double>>> contributions = new HashMap<>();

    if (classCount != 2) {
        // Multiclass: weights are laid out with one entry per (feature, class) pair.
        for (int c = 0; c < liblinearLabels.length; c++) {
            List<Pair<String, Double>> perClass = new ArrayList<>();
            for (Feature feature : e) {
                String featureName = feature.getName();
                int featureId = featureIDMap.getID(featureName);
                if (featureId < 0) {
                    continue;
                }
                perClass.add(new Pair<>(featureName, weights[featureId * classCount + c] * feature.getValue()));
            }
            perClass.sort((left, right) -> Double.compare(right.getB(), left.getB()));
            contributions.put(outputIDInfo.getOutput(liblinearLabels[c]).getLabel(), perClass);
        }
        return new Excuse<>(e, prediction, contributions);
    }

    // Binary: a single weight vector. A positive contribution favours the first
    // label, so the second label receives the negated contribution.
    List<Pair<String, Double>> firstLabelScores = new ArrayList<>();
    List<Pair<String, Double>> secondLabelScores = new ArrayList<>();
    for (Feature feature : e) {
        String featureName = feature.getName();
        int featureId = featureIDMap.getID(featureName);
        if (featureId < 0) {
            continue;
        }
        double contribution = weights[featureId] * feature.getValue();
        firstLabelScores.add(new Pair<>(featureName, contribution));
        secondLabelScores.add(new Pair<>(featureName, -contribution));
    }
    firstLabelScores.sort((left, right) -> Double.compare(right.getB(), left.getB()));
    secondLabelScores.sort((left, right) -> Double.compare(right.getB(), left.getB()));
    contributions.put(outputIDInfo.getOutput(liblinearLabels[0]).getLabel(), firstLabelScores);
    contributions.put(outputIDInfo.getOutput(liblinearLabels[1]).getLabel(), secondLabelScores);
    return new Excuse<>(e, prediction, contributions);
}
Also used: ONNXNode (org.tribuo.util.onnx.ONNXNode), FeatureNode (de.bwaldvogel.liblinear.FeatureNode), HashMap (java.util.HashMap), LinkedHashMap (java.util.LinkedHashMap), Label (org.tribuo.classification.Label), ArrayList (java.util.ArrayList), Feature (org.tribuo.Feature), List (java.util.List), Pair (com.oracle.labs.mlrg.olcut.util.Pair), Excuse (org.tribuo.Excuse)

Example 2 with Excuse

Use of `org.tribuo.Excuse` in project Tribuo by Oracle.

From class `LibLinearRegressionModel`, method `innerGetExcuse`.

/**
 * Builds an {@link Excuse} for a single example from precomputed LibLinear
 * feature weights, with one ranked list per output dimension.
 * <p>
 * The public excuse methods obtain the weights through model.getFeatureWeights,
 * which copies the weight array. Accepting them as a parameter lets getExcuses
 * avoid repeating that copy for every example. If this becomes a problem, the
 * feature weights could be cached in the model instead.
 * <p>
 * Row {@code i} of {@code allFeatureWeights} produces the list stored under
 * {@code dimensionNames[mapping[i]]}. Each entry pairs a feature name with
 * weight * value. Features whose ID lookup returns a negative value are
 * skipped. Each list is sorted from the largest to the smallest contribution.
 *
 * @param e The example.
 * @param allFeatureWeights The feature weights.
 * @return An excuse for this example.
 */
@Override
protected Excuse<Regressor> innerGetExcuse(Example<Regressor> e, double[][] allFeatureWeights) {
    Prediction<Regressor> prediction = predict(e);
    Map<String, List<Pair<String, Double>>> contributions = new HashMap<>();
    int dim = 0;
    for (double[] rowWeights : allFeatureWeights) {
        List<Pair<String, Double>> ranked = new ArrayList<>();
        for (Feature feature : e) {
            int featureId = featureIDMap.getID(feature.getName());
            if (featureId >= 0) {
                ranked.add(new Pair<>(feature.getName(), rowWeights[featureId] * feature.getValue()));
            }
        }
        ranked.sort((left, right) -> Double.compare(right.getB(), left.getB()));
        // mapping translates the weight-row index into an index in dimensionNames.
        contributions.put(dimensionNames[mapping[dim]], ranked);
        dim++;
    }
    return new Excuse<>(e, prediction, contributions);
}
Also used: HashMap (java.util.HashMap), ArrayList (java.util.ArrayList), Feature (org.tribuo.Feature), List (java.util.List), Regressor (org.tribuo.regression.Regressor), Pair (com.oracle.labs.mlrg.olcut.util.Pair), Excuse (org.tribuo.Excuse)

Example 3 with Excuse

Use of `org.tribuo.Excuse` in project Tribuo by Oracle.

From class `LibLinearAnomalyModel`, method `innerGetExcuse`.

/**
 * Builds an {@link Excuse} for a single example from precomputed LibLinear
 * feature weights.
 * <p>
 * The call to model.getFeatureWeights in the public methods copies the
 * weights array, so this inner method exists to save the copy in getExcuses.
 * If it becomes a problem then we could cache the feature weights in the
 * model.
 * <p>
 * Only the first row of {@code allFeatureWeights} is read. Each feature whose
 * ID lookup succeeds contributes weight * value to the {@code ANOMALOUS} list
 * and the negated value to the {@code EXPECTED} list. Features whose ID lookup
 * returns a negative value are skipped. Both lists are sorted from the largest
 * to the smallest contribution.
 * @param e The example.
 * @param allFeatureWeights The feature weights.
 * @return An excuse for this example.
 */
@Override
protected Excuse<Event> innerGetExcuse(Example<Event> e, double[][] allFeatureWeights) {
    // The LibLinear model itself is not needed here; the weights are passed in.
    double[] featureWeights = allFeatureWeights[0];
    Prediction<Event> prediction = predict(e);
    Map<String, List<Pair<String, Double>>> weightMap = new HashMap<>();
    List<Pair<String, Double>> posScores = new ArrayList<>();
    List<Pair<String, Double>> negScores = new ArrayList<>();
    for (Feature f : e) {
        int id = featureIDMap.getID(f.getName());
        if (id > -1) {
            double score = featureWeights[id] * f.getValue();
            posScores.add(new Pair<>(f.getName(), score));
            // The EXPECTED list mirrors the ANOMALOUS one with negated contributions.
            negScores.add(new Pair<>(f.getName(), -score));
        }
    }
    posScores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
    negScores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
    weightMap.put(Event.EventType.ANOMALOUS.toString(), posScores);
    weightMap.put(Event.EventType.EXPECTED.toString(), negScores);
    return new Excuse<>(e, prediction, weightMap);
}
Also used: FeatureNode (de.bwaldvogel.liblinear.FeatureNode), HashMap (java.util.HashMap), ArrayList (java.util.ArrayList), Feature (org.tribuo.Feature), Event (org.tribuo.anomaly.Event), List (java.util.List), Pair (com.oracle.labs.mlrg.olcut.util.Pair), Excuse (org.tribuo.Excuse)

Aggregations

Pair (com.oracle.labs.mlrg.olcut.util.Pair): 3, ArrayList (java.util.ArrayList): 3, HashMap (java.util.HashMap): 3, List (java.util.List): 3, Excuse (org.tribuo.Excuse): 3, Feature (org.tribuo.Feature): 3, FeatureNode (de.bwaldvogel.liblinear.FeatureNode): 2, LinkedHashMap (java.util.LinkedHashMap): 1, Event (org.tribuo.anomaly.Event): 1, Label (org.tribuo.classification.Label): 1, Regressor (org.tribuo.regression.Regressor): 1, ONNXNode (org.tribuo.util.onnx.ONNXNode): 1