Search in sources:

Example 6 with GeneralF1Predictor

Use of edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor in project pyramid by cheng-li.

Method showPredictBySupport of class PluginF1.

// public MultiLabel showPredictBySampling(Vector vector){
// System.out.println("sampling procedure");
// //        List<MultiLabel> samples = cbm.samples(vector, numSamples);
// Pair<List<MultiLabel>, List<Double>> pair = cbm.samples(vector, probMassThreshold);
// List<Pair<MultiLabel, Double>> list = new ArrayList<>();
// List<MultiLabel> labels = pair.getFirst();
// List<Double> probs = pair.getSecond();
// for (int i=0;i<labels.size();i++){
// list.add(new Pair<>(labels.get(i),probs.get(i)));
// }
// Comparator<Pair<MultiLabel, Double>> comparator = Comparator.comparing(a-> a.getSecond());
// 
// System.out.println(list.stream().sorted(comparator.reversed()).collect(Collectors.toList()));
// 
// 
// 
// //        for (int i=0;i<labels.size();i++){
// //            System.out.println(labels.get(i)+": "+probs.get(i));
// //        }
// return GeneralF1Predictor.predict(cbm.getNumClasses(),pair.getFirst(), pair.getSecond());
// }
// 
// public void showPredictBySamplingNonEmpty(Vector vector){
// System.out.println("sampling procedure");
// Pair<List<MultiLabel>, List<Double>> pair = cbm.sampleNonEmptySets(vector, probMassThreshold);
// List<Pair<MultiLabel, Double>> list = new ArrayList<>();
// List<MultiLabel> labels = pair.getFirst();
// List<Double> probs = pair.getSecond();
// double[] probsArray = probs.stream().mapToDouble(a->a).toArray();
// 
// for (int i=0;i<labels.size();i++){
// list.add(new Pair<>(labels.get(i),probs.get(i)));
// }
// Comparator<Pair<MultiLabel, Double>> comparator = Comparator.comparing(a-> a.getSecond());
// 
// MultiLabel gfmPred =  GeneralF1Predictor.predict(cbm.getNumClasses(),pair.getFirst(), pair.getSecond());
// MultiLabel argmaxPre = cbm.predict(vector);
// System.out.println("expected f1 of argmax predictor= "+GeneralF1Predictor.expectedF1(labels,probsArray, argmaxPre,cbm.getNumClasses()));
// System.out.println("expected f1 of GFM predictor= "+GeneralF1Predictor.expectedF1(labels,probsArray, gfmPred,cbm.getNumClasses()));
// 
// System.out.println(list.stream().sorted(comparator.reversed()).collect(Collectors.toList()));
// }
/**
 * Predicts a label set for {@code vector} with the general F1 plug-in rule, restricted to the
 * candidate label sets held in {@code support}. It then builds an analysis of that prediction
 * against {@code truth}.
 *
 * @param vector feature vector of the instance to predict
 * @param truth  ground-truth label set handed to the analysis
 * @return the analysis built by {@code GeneralF1Predictor.showSupportPrediction}
 */
public GeneralF1Predictor.Analysis showPredictBySupport(Vector vector, MultiLabel truth) {
    // Model scores for the candidate label sets; presumably one entry per element of `support`
    // (TODO confirm against cbm.predictAssignmentProbs).
    final double[] supportProbs = cbm.predictAssignmentProbs(vector, support);
    final GeneralF1Predictor f1Rule = new GeneralF1Predictor();
    final MultiLabel predicted = f1Rule.predict(cbm.getNumClasses(), support, supportProbs);
    return GeneralF1Predictor.showSupportPrediction(
            support, supportProbs, truth, predicted, cbm.getNumClasses());
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) GeneralF1Predictor(edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor)

Example 7 with GeneralF1Predictor

Use of edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor in project pyramid by cheng-li.

Method predict of class InstanceF1Predictor.

/**
 * Predicts the label set for {@code vector} by applying the general F1 plug-in rule over the
 * support combinations reported by {@code cmlcrf}.
 *
 * @param vector feature vector of the instance to predict
 * @return the label set chosen by {@code GeneralF1Predictor.predict}
 */
@Override
public MultiLabel predict(Vector vector) {
    final List<MultiLabel> candidateSets = cmlcrf.getSupportCombinations();
    // Probabilities from the CRF; presumably aligned index-by-index with candidateSets
    // (TODO confirm against cmlcrf.predictCombinationProbs).
    final double[] candidateProbs = cmlcrf.predictCombinationProbs(vector);
    return new GeneralF1Predictor().predict(numClasses, candidateSets, candidateProbs);
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) GeneralF1Predictor(edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor)

Example 8 with GeneralF1Predictor

Use of edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor in project pyramid by cheng-li.

Method predictBySupport of class CBMSF1Predictor.

/**
 * Predicts a label set for {@code vector} by running the general F1 plug-in rule over the
 * model's probabilities for the candidate label sets in {@code support}.
 *
 * @param vector feature vector of the instance to predict
 * @return the label set chosen by {@code GeneralF1Predictor.predict}
 */
private MultiLabel predictBySupport(Vector vector) {
    final double[] supportProbs = cbm.predictAssignmentProbs(vector, support);
    final GeneralF1Predictor f1Rule = new GeneralF1Predictor();
    final int numClasses = cbm.getNumClasses();
    return f1Rule.predict(numClasses, support, supportProbs);
}
Also used : GeneralF1Predictor(edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor)

Aggregations

GeneralF1Predictor (edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor): 8; MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 6; PluginPredictor (edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor): 2; Collectors (java.util.stream.Collectors): 2; Vector (org.apache.mahout.math.Vector): 2; ConcurrentHashMultiset (com.google.common.collect.ConcurrentHashMultiset): 1; Multiset (com.google.common.collect.Multiset): 1; DataSetUtil (edu.neu.ccs.pyramid.dataset.DataSetUtil): 1; Pair (edu.neu.ccs.pyramid.util.Pair): 1; java.util (java.util): 1; Arrays (java.util.Arrays): 1; List (java.util.List): 1; Matrix (org.apache.mahout.math.Matrix): 1