
Example 11 with Pair

Use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

The class MultinomialNaiveBayesModel, method getExcuse.

@Override
public Optional<Excuse<Label>> getExcuse(Example<Label> example) {
    Map<String, List<Pair<String, Double>>> explanation = new HashMap<>();
    for (Pair<Integer, Label> label : outputIDInfo) {
        List<Pair<String, Double>> scores = new ArrayList<>();
        for (Feature f : example) {
            int id = featureIDMap.getID(f.getName());
            if (id > -1) {
                scores.add(new Pair<>(f.getName(), labelWordProbs.getRow(label.getA()).get(id)));
            }
        }
        explanation.put(label.getB().getLabel(), scores);
    }
    return Optional.of(new Excuse<>(example, predict(example), explanation));
}
Also used: HashMap (java.util.HashMap), LinkedHashMap (java.util.LinkedHashMap), Label (org.tribuo.classification.Label), ArrayList (java.util.ArrayList), Feature (org.tribuo.Feature), List (java.util.List), Pair (com.oracle.labs.mlrg.olcut.util.Pair)
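
The excuse maps each label name to the per-feature scores drawn from that label's word-probability row. A minimal consumption sketch, assuming a trained Model<Label> variable named nbModel and an Example<Label> variable named example (both hypothetical), and assuming Excuse exposes the score map via getScores():

// Hypothetical usage: inspect the per-label feature scores returned by getExcuse.
Optional<Excuse<Label>> excuse = nbModel.getExcuse(example);
excuse.ifPresent(ex -> {
    for (Map.Entry<String, List<Pair<String, Double>>> entry : ex.getScores().entrySet()) {
        System.out.println("Label: " + entry.getKey());
        for (Pair<String, Double> p : entry.getValue()) {
            // Pair.getA() is the feature name, Pair.getB() is its score under this label.
            System.out.println("  " + p.getA() + " = " + p.getB());
        }
    }
});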

Example 12 with Pair

Use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

The class MultinomialNaiveBayesModel, method getTopFeatures.

@Override
public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
    int maxFeatures = n < 0 ? featureIDMap.size() : n;
    Map<String, List<Pair<String, Double>>> topFeatures = new HashMap<>();
    for (Pair<Integer, Label> label : outputIDInfo) {
        List<Pair<String, Double>> features = new ArrayList<>(labelWordProbs.numActiveElements(label.getA()));
        for (VectorTuple vt : labelWordProbs.getRow(label.getA())) {
            features.add(new Pair<>(featureIDMap.get(vt.index).getName(), vt.value));
        }
        features.sort(Comparator.comparing(x -> -x.getB()));
        if (maxFeatures < featureIDMap.size()) {
            features = features.subList(0, maxFeatures);
        }
        topFeatures.put(label.getB().getLabel(), features);
    }
    return topFeatures;
}
Also used: Example (org.tribuo.Example), ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap), VectorTuple (org.tribuo.math.la.VectorTuple), ModelProvenance (org.tribuo.provenance.ModelProvenance), Prediction (org.tribuo.Prediction), ImmutableOutputInfo (org.tribuo.ImmutableOutputInfo), Model (org.tribuo.Model), HashMap (java.util.HashMap), Pair (com.oracle.labs.mlrg.olcut.util.Pair), DenseVector (org.tribuo.math.la.DenseVector), ArrayList (java.util.ArrayList), DenseSparseMatrix (org.tribuo.math.la.DenseSparseMatrix), ExpNormalizer (org.tribuo.math.util.ExpNormalizer), LinkedHashMap (java.util.LinkedHashMap), Feature (org.tribuo.Feature), List (java.util.List), Map (java.util.Map), Excuse (org.tribuo.Excuse), Optional (java.util.Optional), Comparator (java.util.Comparator), Label (org.tribuo.classification.Label), SparseVector (org.tribuo.math.la.SparseVector), VectorNormalizer (org.tribuo.math.util.VectorNormalizer)
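
getTopFeatures(int) is part of Tribuo's public Model API, so the per-class feature rankings built from these Pair lists can be inspected directly. A minimal sketch, assuming a trained model variable named nbModel (hypothetical):

// Hypothetical usage: print the five highest-scoring features for each label.
Map<String, List<Pair<String, Double>>> top = nbModel.getTopFeatures(5);
for (Map.Entry<String, List<Pair<String, Double>>> entry : top.entrySet()) {
    System.out.println(entry.getKey());
    for (Pair<String, Double> p : entry.getValue()) {
        // Features are already sorted in descending score order by getTopFeatures.
        System.out.println("  " + p.getA() + " -> " + p.getB());
    }
}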

Example 13 with Pair

Use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

The class LIMEBase, method explainWithSamples.

protected Pair<LIMEExplanation, List<Example<Regressor>>> explainWithSamples(Example<Label> example) {
    // Predict using the full model, and generate a new example containing that prediction.
    Prediction<Label> prediction = innerModel.predict(example);
    Example<Regressor> labelledExample = new ArrayExample<>(transformOutput(prediction), example, 1.0f);
    // Sample a dataset.
    List<Example<Regressor>> sample = sampleData(example);
    // Generate a sparse model on the sampled data.
    SparseModel<Regressor> model = trainExplainer(labelledExample, sample);
    // Test the sparse model against the predictions of the real model.
    List<Prediction<Regressor>> predictions = new ArrayList<>(model.predict(sample));
    predictions.add(model.predict(labelledExample));
    RegressionEvaluation evaluation = evaluator.evaluate(model, predictions, new SimpleDataSourceProvenance("LIMEColumnar sampled data", regressionFactory));
    return new Pair<>(new LIMEExplanation(model, prediction, evaluation), sample);
}
Also used: SimpleDataSourceProvenance (org.tribuo.provenance.SimpleDataSourceProvenance), Prediction (org.tribuo.Prediction), Label (org.tribuo.classification.Label), ArrayList (java.util.ArrayList), ArrayExample (org.tribuo.impl.ArrayExample), RegressionEvaluation (org.tribuo.regression.evaluation.RegressionEvaluation), Example (org.tribuo.Example), Regressor (org.tribuo.regression.Regressor), Pair (com.oracle.labs.mlrg.olcut.util.Pair)
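
Here the Pair simply bundles the fitted explanation with the synthetic examples it was trained on. Since explainWithSamples is protected, it is only reachable from LIMEBase or a subclass; a minimal unpacking sketch under that assumption, with illustrative variable names:

// Hypothetical: unpack the Pair returned by explainWithSamples (called from a subclass).
Pair<LIMEExplanation, List<Example<Regressor>>> result = explainWithSamples(example);
LIMEExplanation explanation = result.getA();          // surrogate model, prediction and evaluation
List<Example<Regressor>> sampledData = result.getB(); // the perturbed examples the surrogate was fit on
System.out.println("Surrogate model fit on " + sampledData.size() + " sampled examples");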

Example 14 with Pair

Use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

The class IndependentRegressionTreeModel, method getExcuse.

@Override
public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
    SparseVector vec = SparseVector.createSparseVector(example, featureIDMap, false);
    if (vec.numActiveElements() == 0) {
        return Optional.empty();
    }
    List<String> list = new ArrayList<>();
    List<Prediction<Regressor>> predList = new ArrayList<>();
    Map<String, List<Pair<String, Double>>> map = new HashMap<>();
    for (Map.Entry<String, Node<Regressor>> e : roots.entrySet()) {
        list.clear();
        // 
        // Ensures we handle collisions correctly
        Node<Regressor> oldNode = e.getValue();
        Node<Regressor> curNode = e.getValue();
        while (curNode != null) {
            oldNode = curNode;
            if (oldNode instanceof SplitNode) {
                SplitNode<?> node = (SplitNode<?>) curNode;
                list.add(featureIDMap.get(node.getFeatureID()).getName());
            }
            curNode = oldNode.getNextNode(vec);
        }
        // 
        // oldNode must be a LeafNode.
        predList.add(((LeafNode<Regressor>) oldNode).getPrediction(vec.numActiveElements(), example));
        List<Pair<String, Double>> pairs = new ArrayList<>();
        int i = list.size() + 1;
        for (String s : list) {
            pairs.add(new Pair<>(s, i + 0.0));
            i--;
        }
        map.put(e.getKey(), pairs);
    }
    Prediction<Regressor> combinedPrediction = combine(predList);
    return Optional.of(new Excuse<>(example, combinedPrediction, map));
}
Also used: HashMap (java.util.HashMap), Prediction (org.tribuo.Prediction), SplitNode (org.tribuo.common.tree.SplitNode), LeafNode (org.tribuo.common.tree.LeafNode), Node (org.tribuo.common.tree.Node), ArrayList (java.util.ArrayList), SparseVector (org.tribuo.math.la.SparseVector), LinkedList (java.util.LinkedList), List (java.util.List), Regressor (org.tribuo.regression.Regressor), ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap), Map (java.util.Map), Pair (com.oracle.labs.mlrg.olcut.util.Pair)
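
The excuse weights in this method are positional rather than learned: the features tested along the root-to-leaf path receive scores from list.size() + 1 down to 2, so splits nearer the root rank higher. A minimal standalone sketch of that weighting scheme, using a hypothetical split order:

// Illustrative: turn an ordered list of split feature names into ranked Pair scores,
// mirroring the loop in getExcuse above (earlier splits get larger scores).
List<String> path = Arrays.asList("age", "income", "height"); // hypothetical split order
List<Pair<String, Double>> pairs = new ArrayList<>();
int i = path.size() + 1;
for (String s : path) {
    pairs.add(new Pair<>(s, i + 0.0));
    i--;
}
// pairs is now [(age, 4.0), (income, 3.0), (height, 2.0)]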

Example 15 with Pair

Use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

The class LibLinearRegressionModel, method innerGetExcuse.

/**
 * The call to model.getFeatureWeights in the public methods copies the
 * weights array so this inner method exists to save the copy in getExcuses.
 * <p>
 * If it becomes a problem then we could cache the feature weights in the
 * model.
 *
 * @param e The example.
 * @param allFeatureWeights The feature weights.
 * @return An excuse for this example.
 */
@Override
protected Excuse<Regressor> innerGetExcuse(Example<Regressor> e, double[][] allFeatureWeights) {
    Prediction<Regressor> prediction = predict(e);
    Map<String, List<Pair<String, Double>>> weightMap = new HashMap<>();
    for (int i = 0; i < allFeatureWeights.length; i++) {
        List<Pair<String, Double>> scores = new ArrayList<>();
        for (Feature f : e) {
            int id = featureIDMap.getID(f.getName());
            if (id > -1) {
                double score = allFeatureWeights[i][id] * f.getValue();
                scores.add(new Pair<>(f.getName(), score));
            }
        }
        scores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
        weightMap.put(dimensionNames[mapping[i]], scores);
    }
    return new Excuse<>(e, prediction, weightMap);
}
Also used: HashMap (java.util.HashMap), ArrayList (java.util.ArrayList), Feature (org.tribuo.Feature), List (java.util.List), Regressor (org.tribuo.regression.Regressor), Pair (com.oracle.labs.mlrg.olcut.util.Pair), Excuse (org.tribuo.Excuse)
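
Each score is the dimension's weight for a feature multiplied by that feature's value in the example, with the resulting pairs sorted so the largest contributions come first. A minimal standalone sketch of that computation for a single output dimension, using illustrative names, weights, and values:

// Illustrative: per-feature contributions for one regression dimension,
// computed as weight * featureValue and sorted descending, as in innerGetExcuse.
String[] names = {"f0", "f1", "f2"};  // hypothetical feature names
double[] weights = {0.5, -1.2, 2.0};  // hypothetical weights for this dimension
double[] values = {1.0, 3.0, 0.5};    // hypothetical feature values from the example
List<Pair<String, Double>> scores = new ArrayList<>();
for (int id = 0; id < names.length; id++) {
    scores.add(new Pair<>(names[id], weights[id] * values[id]));
}
scores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
// scores is now [(f2, 1.0), (f0, 0.5), (f1, -3.6)]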

Aggregations

Pair (com.oracle.labs.mlrg.olcut.util.Pair): 59
ArrayList (java.util.ArrayList): 27
List (java.util.List): 21
HashMap (java.util.HashMap): 18
MutableDataset (org.tribuo.MutableDataset): 17
SimpleDataSourceProvenance (org.tribuo.provenance.SimpleDataSourceProvenance): 16
Label (org.tribuo.classification.Label): 14
Feature (org.tribuo.Feature): 11
Regressor (org.tribuo.regression.Regressor): 11
Prediction (org.tribuo.Prediction): 10
DenseVector (org.tribuo.math.la.DenseVector): 10
SparseVector (org.tribuo.math.la.SparseVector): 10
SGDVector (org.tribuo.math.la.SGDVector): 9
Map (java.util.Map): 7
Example (org.tribuo.Example): 7
ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap): 7
PriorityQueue (java.util.PriorityQueue): 6
Excuse (org.tribuo.Excuse): 5
Model (org.tribuo.Model): 5
LabelFactory (org.tribuo.classification.LabelFactory): 5