
Example 6 with MultiLabelPredictionAnalysis

Use of edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis in project pyramid by cheng-li.

From the class HMLGradientBoostingTest, method test3_load.

static void test3_load() throws Exception {
    // Load the single-label spam test set (TREC format) as a dense dataset.
    ClfDataSet singleLabeldataSet = TRECFormat.loadClfDataSet(new File(DATASETS, "/spam/trec_data/test.trec"), DataSetType.CLF_DENSE, true);
    int numDataPoints = singleLabeldataSet.getNumDataPoints();
    int numFeatures = singleLabeldataSet.getNumFeatures();
    // Build an empty multi-label dataset of the same shape with 4 classes.
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder().numDataPoints(numDataPoints).numFeatures(numFeatures).numClasses(4).build();
    int[] labels = singleLabeldataSet.getLabels();
    for (int i = 0; i < numDataPoints; i++) {
        // Keep the original single label.
        dataSet.addLabel(i, labels[i]);
        // Add synthetic label 2 to class-1 rows whose first feature is below 0.1.
        if (labels[i] == 1 && singleLabeldataSet.getRow(i).get(0) < 0.1) {
            dataSet.addLabel(i, 2);
        }
        // Add synthetic label 3 to class-1 rows whose second feature is below 0.1.
        if (labels[i] == 1 && singleLabeldataSet.getRow(i).get(1) < 0.1) {
            dataSet.addLabel(i, 3);
        }
        // Copy the feature values over unchanged.
        for (int j = 0; j < numFeatures; j++) {
            double value = singleLabeldataSet.getRow(i).get(j);
            dataSet.setFeatureValue(i, j, value);
        }
    }
    // External names for the four classes, listed in class-index order.
    List<String> extLabels = new ArrayList<>();
    extLabels.add("non_spam");
    extLabels.add("spam");
    extLabels.add("fake2");
    extLabels.add("fake3");
    LabelTranslator labelTranslator = new LabelTranslator(extLabels);
    dataSet.setLabelTranslator(labelTranslator);
    // Load a previously serialized model and print its accuracy on this dataset.
    HMLGradientBoosting boosting = HMLGradientBoosting.deserialize(new File(TMP, "/hmlgb/boosting.ser"));
    System.out.println(Accuracy.accuracy(boosting, dataSet));
    // Per-instance inspection; the diagnostic prints below are left disabled as in the original.
    for (int i = 0; i < numDataPoints; i++) {
        Vector featureRow = dataSet.getRow(i);
        MultiLabel label = dataSet.getMultiLabels()[i];
        MultiLabel prediction = boosting.predict(featureRow);
        // System.out.println("label="+label);
        // System.out.println(boosting.calAssignmentScores(featureRow,assignments.get(0)));
        // System.out.println(boosting.calAssignmentScores(featureRow,assignments.get(1)));
        // System.out.println("prediction="+prediction);
        // if (!MultiLabel.equivalent(label,prediction)){
        //     System.out.println(i);
        //     System.out.println("label="+label);
        //     System.out.println("prediction="+prediction);
        // }
    }
    // Build the prediction analysis (arguments 0 and 10 as in the original test) and write it out as JSON.
    MultiLabelPredictionAnalysis analysis = HMLGBInspector.analyzePrediction(boosting, dataSet, 0, 10);
    ObjectMapper mapper1 = new ObjectMapper();
    mapper1.writeValue(new File(TMP, "prediction_analysis.json"), analysis);
}
Also used: MultiLabelPredictionAnalysis (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis), ArrayList (java.util.ArrayList), File (java.io.File), Vector (org.apache.mahout.math.Vector), ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper)
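
The labeling rule in the example above can be followed without the pyramid library. The sketch below is a library-free illustration, not pyramid code: the class MultiLabelSketch and the methods buildLabels and mismatches are hypothetical names, and plain Set<Integer> values stand in for pyramid's MultiLabel. It assumes each row has at least two features. It reproduces how the test derives the synthetic labels 2 and 3, and the exact-match comparison that the commented-out diagnostic in the test appears to perform.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

// Hypothetical, library-free sketch of the label construction in test3_load.
class MultiLabelSketch {
    // Same cut-off as the literal 0.1 used in the test.
    static final double THRESHOLD = 0.1;

    // Returns the label set for one data point: the original label, plus
    // label 2 if the point is class 1 and feature 0 is below the threshold,
    // plus label 3 if the point is class 1 and feature 1 is below it.
    static Set<Integer> buildLabels(int label, double[] features) {
        Set<Integer> labels = new TreeSet<>();
        labels.add(label);
        if (label == 1 && features[0] < THRESHOLD) {
            labels.add(2);
        }
        if (label == 1 && features[1] < THRESHOLD) {
            labels.add(3);
        }
        return labels;
    }

    // Returns the indices at which the predicted label set differs from the
    // true label set (exact set equality, assumed here for illustration).
    static List<Integer> mismatches(List<Set<Integer>> truth, List<Set<Integer>> predicted) {
        List<Integer> result = new ArrayList<>();
        for (int i = 0; i < truth.size(); i++) {
            if (!truth.get(i).equals(predicted.get(i))) {
                result.add(i);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        // Made-up rows for illustration only; not taken from the spam dataset.
        int[] labels = {0, 1, 1, 1};
        double[][] rows = {{0.0, 0.0}, {0.05, 0.5}, {0.5, 0.05}, {0.01, 0.02}};
        for (int i = 0; i < labels.length; i++) {
            System.out.println(i + " -> " + buildLabels(labels[i], rows[i]));
        }
    }
}
```

Note that a class-0 row never receives labels 2 or 3, even when its first two features are below the threshold, because both conditions require the original label to be 1.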

Aggregations

MultiLabelPredictionAnalysis (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis): 6
Vector (org.apache.mahout.math.Vector): 4
Feature (edu.neu.ccs.pyramid.feature.Feature): 3
TopFeatures (edu.neu.ccs.pyramid.feature.TopFeatures): 3
ArrayList (java.util.ArrayList): 3
Collectors (java.util.stream.Collectors): 3
ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper): 2
edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset): 2
PluginPredictor (edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor): 2
edu.neu.ccs.pyramid.regression (edu.neu.ccs.pyramid.regression): 2
RegTreeInspector (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeInspector): 2
RegressionTree (edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree): 2
TreeRule (edu.neu.ccs.pyramid.regression.regression_tree.TreeRule): 2
Pair (edu.neu.ccs.pyramid.util.Pair): 2
File (java.io.File): 2
java.util (java.util): 2
IntStream (java.util.stream.IntStream): 2
JsonEncoding (com.fasterxml.jackson.core.JsonEncoding): 1
JsonFactory (com.fasterxml.jackson.core.JsonFactory): 1
JsonGenerator (com.fasterxml.jackson.core.JsonGenerator): 1