Search in sources:

Example 1 with Pair

Use of com.oracle.labs.mlrg.olcut.util.Pair in project Tribuo by Oracle.

Class HdbscanModel, method getClusters.

/**
 * Gets every cluster exemplar together with its features, expressed by feature name.
 * <p>
 * Prefer this over {@link #getClusterExemplars()} in most cases, because it
 * translates Tribuo's internal feature ids into the externally visible
 * feature names.
 * @return One pair per cluster exemplar: the exemplar's label and its named features.
 */
public List<Pair<Integer, List<Feature>>> getClusters() {
    List<Pair<Integer, List<Feature>>> result = new ArrayList<>(clusterExemplars.size());
    for (HdbscanTrainer.ClusterExemplar exemplar : clusterExemplars) {
        // Presize to the number of active (non-zero) entries in the exemplar's feature vector.
        List<Feature> namedFeatures = new ArrayList<>(exemplar.getFeatures().numActiveElements());
        for (VectorTuple tuple : exemplar.getFeatures()) {
            // Map the internal feature index back to its public name.
            String featureName = featureIDMap.get(tuple.index).getName();
            namedFeatures.add(new Feature(featureName, tuple.value));
        }
        result.add(new Pair<>(exemplar.getLabel(), namedFeatures));
    }
    return result;
}
Also used : ArrayList(java.util.ArrayList) VectorTuple(org.tribuo.math.la.VectorTuple) Feature(org.tribuo.Feature) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Example 2 with Pair

Use of com.oracle.labs.mlrg.olcut.util.Pair in project Tribuo by Oracle.

Class SparseLinearModel, method getExcuse.

/**
 * Explains a prediction by listing, for every output dimension, the contribution
 * of each active feature (that dimension's weight multiplied by the feature value).
 * Each dimension's list is sorted from the largest contribution to the smallest.
 * @param example The example to explain.
 * @return An excuse holding the example, its prediction and the per-dimension contributions.
 */
@Override
public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
    Prediction<Regressor> prediction = predict(example);
    Map<String, List<Pair<String, Double>>> contributionsByDimension = new HashMap<>();
    SparseVector featureVector = createFeatures(example);
    for (int dim = 0; dim < dimensions.length; dim++) {
        List<Pair<String, Double>> contributions = new ArrayList<>();
        for (VectorTuple tuple : featureVector) {
            String featureName = featureIDMap.get(tuple.index).getName();
            double contribution = weights[dim].get(tuple.index) * tuple.value;
            contributions.add(new Pair<>(featureName, contribution));
        }
        // Descending order of contribution.
        contributions.sort((left, right) -> Double.compare(right.getB(), left.getB()));
        contributionsByDimension.put(dimensions[dim], contributions);
    }
    return Optional.of(new Excuse<>(example, prediction, contributionsByDimension));
}
Also used : HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) SparseVector(org.tribuo.math.la.SparseVector) List(java.util.List) VectorTuple(org.tribuo.math.la.VectorTuple) Regressor(org.tribuo.regression.Regressor) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Example 3 with Pair

Use of com.oracle.labs.mlrg.olcut.util.Pair in project Tribuo by Oracle.

Class SquaredLoss, method lossAndGradient.

/**
 * Computes half the sum of squared residuals, {@code sum_i 0.5 * (truth_i - prediction_i)^2}.
 * <p>
 * The second element of the returned pair is the residual vector
 * {@code truth - prediction}. NOTE(review): this is the negative of
 * d(loss)/d(prediction); presumably the optimiser expects that sign
 * convention. Confirm against the caller.
 * @param truth The ground-truth values.
 * @param prediction The predicted values.
 * @return A pair of (loss, residual vector {@code truth - prediction}).
 */
@Override
public Pair<Double, SGDVector> lossAndGradient(DenseVector truth, SGDVector prediction) {
    DenseVector residual = truth.subtract(prediction);
    // Map each residual r to 0.5 * r * r and accumulate the results with Double::sum.
    double halfSquaredError = residual.reduce(0.0, (r) -> 0.5 * r * r, Double::sum);
    Pair<Double, SGDVector> lossAndResidual = new Pair<>(halfSquaredError, residual);
    return lossAndResidual;
}
Also used : DenseVector(org.tribuo.math.la.DenseVector) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Example 4 with Pair

Use of com.oracle.labs.mlrg.olcut.util.Pair in project Tribuo by Oracle.

Class AbsoluteLoss, method lossAndGradient.

/**
 * Computes the absolute (L1) loss {@code sum_i |truth_i - prediction_i|}.
 * <p>
 * The second element of the returned pair is the element-wise sign of
 * {@code truth - prediction}, as produced by {@link Double#compare}: -1, 0 or +1.
 * Note that {@code Double.compare} maps -0.0 to -1 and NaN to +1.
 * @param truth The ground-truth values.
 * @param prediction The predicted values.
 * @return A pair of (loss, element-wise sign of {@code truth - prediction}).
 */
@Override
public Pair<Double, SGDVector> lossAndGradient(DenseVector truth, SGDVector prediction) {
    DenseVector difference = truth.subtract(prediction);
    DenseVector absoluteDifference = difference.copy();
    absoluteDifference.foreachInPlace(Math::abs);
    // The L1 loss is exactly the sum of absolute residuals. It has no constant
    // offset; an offset would make a perfect prediction report a negative loss.
    double loss = absoluteDifference.sum();
    // Replace each residual with its sign to form the sub-gradient direction.
    difference.foreachInPlace((a) -> Double.compare(a, 0.0));
    return new Pair<>(loss, difference);
}
Also used : DenseVector(org.tribuo.math.la.DenseVector) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Example 5 with Pair

Use of com.oracle.labs.mlrg.olcut.util.Pair in project Tribuo by Oracle.

Class EvaluationAggregationTests, method summarizeF1AcrossDatasets_v2.

/**
 * Checks that two ways of aggregating metrics across several datasets agree:
 * <ul>
 * <li>letting {@link EvaluationAggregator} evaluate the model on each dataset itself;</li>
 * <li>evaluating each dataset first and then summarising the list of evaluations.</li>
 * </ul>
 * It also checks that a macro-averaged F1 summary is present in the aggregate.
 */
@Test
public void summarizeF1AcrossDatasets_v2() {
    Pair<Dataset<Label>, Dataset<Label>> pair = LabelledDataGenerator.denseTrainTest(-0.3);
    Model<Label> model = DummyClassifierTrainer.createMostFrequentTrainer().train(pair.getA());
    List<Dataset<Label>> datasets = Arrays.asList(
            LabelledDataGenerator.denseTrainTest(-1.0).getB(),
            LabelledDataGenerator.denseTrainTest(-0.5).getB(),
            LabelledDataGenerator.denseTrainTest(-0.1).getB());
    Evaluator<Label, LabelEvaluation> evaluator = factory.getEvaluator();
    // Path 1: the aggregator runs the evaluations itself.
    Map<MetricID<Label>, DescriptiveStats> summaries = EvaluationAggregator.summarize(evaluator, model, datasets);
    MetricID<Label> macroF1 = LabelMetrics.F1.forTarget(MetricTarget.macroAverageTarget()).getID();
    DescriptiveStats summary = summaries.get(macroF1);
    // Previously `summary` was computed and never used, so the test never checked that
    // F1 was summarised at all. Only assertEquals is imported, hence the boolean form.
    assertEquals(true, summary != null, "expected a macro-averaged F1 summary");
    // Path 2 ("Can also do this"): evaluate each dataset, then summarise the evaluations.
    List<LabelEvaluation> evals = datasets.stream().map(dataset -> evaluator.evaluate(model, dataset)).collect(Collectors.toList());
    Map<MetricID<Label>, DescriptiveStats> summaries2 = EvaluationAggregator.summarize(evals);
    assertEquals(summaries, summaries2);
}
Also used : MetricTarget(org.tribuo.evaluation.metrics.MetricTarget) Arrays(java.util.Arrays) Evaluator(org.tribuo.evaluation.Evaluator) Prediction(org.tribuo.Prediction) Model(org.tribuo.Model) EvaluationAggregator(org.tribuo.evaluation.EvaluationAggregator) Pair(com.oracle.labs.mlrg.olcut.util.Pair) Collectors(java.util.stream.Collectors) MetricID(org.tribuo.evaluation.metrics.MetricID) System.out(java.lang.System.out) ArrayList(java.util.ArrayList) Dataset(org.tribuo.Dataset) Test(org.junit.jupiter.api.Test) Trainer(org.tribuo.Trainer) DummyClassifierTrainer(org.tribuo.classification.baseline.DummyClassifierTrainer) List(java.util.List) LabelFactory(org.tribuo.classification.LabelFactory) Map(java.util.Map) DescriptiveStats(org.tribuo.evaluation.DescriptiveStats) LabelledDataGenerator(org.tribuo.classification.example.LabelledDataGenerator) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) Comparator(java.util.Comparator) Label(org.tribuo.classification.Label) CrossValidation(org.tribuo.evaluation.CrossValidation)

Aggregations

Pair (com.oracle.labs.mlrg.olcut.util.Pair)59 ArrayList (java.util.ArrayList)27 List (java.util.List)21 HashMap (java.util.HashMap)18 MutableDataset (org.tribuo.MutableDataset)17 SimpleDataSourceProvenance (org.tribuo.provenance.SimpleDataSourceProvenance)16 Label (org.tribuo.classification.Label)14 Feature (org.tribuo.Feature)11 Regressor (org.tribuo.regression.Regressor)11 Prediction (org.tribuo.Prediction)10 DenseVector (org.tribuo.math.la.DenseVector)10 SparseVector (org.tribuo.math.la.SparseVector)10 SGDVector (org.tribuo.math.la.SGDVector)9 Map (java.util.Map)7 Example (org.tribuo.Example)7 ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap)7 PriorityQueue (java.util.PriorityQueue)6 Excuse (org.tribuo.Excuse)5 Model (org.tribuo.Model)5 LabelFactory (org.tribuo.classification.LabelFactory)5