Use of edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis in the project pyramid by cheng-li.
From the class HMLGradientBoostingTest, method test3_load.
/**
 * Rebuilds a four-class multi-label view of the spam test set, evaluates a
 * previously serialized {@code HMLGradientBoosting} model on it, and dumps a
 * prediction analysis to JSON.
 *
 * <p>Steps, in order:
 * <ol>
 *   <li>Load the single-label set from {@code DATASETS/spam/trec_data/test.trec}.</li>
 *   <li>Copy every label and feature value into a new multi-label data set;
 *       points labelled 1 additionally receive label 2 when feature 0 is below
 *       0.1 and label 3 when feature 1 is below 0.1.</li>
 *   <li>Attach the external label names non_spam / spam / fake2 / fake3.</li>
 *   <li>Deserialize the model from {@code TMP/hmlgb/boosting.ser} and print its
 *       accuracy on the rebuilt data set.</li>
 *   <li>Write the analysis returned by {@code HMLGBInspector.analyzePrediction}
 *       to {@code TMP/prediction_analysis.json}.</li>
 * </ol>
 *
 * @throws Exception if loading, deserialization, or JSON output fails
 */
static void test3_load() throws Exception {
    File trecFile = new File(DATASETS, "/spam/trec_data/test.trec");
    ClfDataSet source = TRECFormat.loadClfDataSet(trecFile, DataSetType.CLF_DENSE, true);
    final int numDataPoints = source.getNumDataPoints();
    final int numFeatures = source.getNumFeatures();

    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numDataPoints(numDataPoints)
            .numFeatures(numFeatures)
            .numClasses(4)
            .build();

    int[] labels = source.getLabels();
    for (int row = 0; row < numDataPoints; row++) {
        // The original single label is always carried over first.
        dataSet.addLabel(row, labels[row]);

        // Only points labelled 1 can pick up the synthetic labels 2 and 3.
        if (labels[row] == 1) {
            if (source.getRow(row).get(0) < 0.1) {
                dataSet.addLabel(row, 2);
            }
            if (source.getRow(row).get(1) < 0.1) {
                dataSet.addLabel(row, 3);
            }
        }

        // Feature values are copied across unchanged.
        for (int col = 0; col < numFeatures; col++) {
            dataSet.setFeatureValue(row, col, source.getRow(row).get(col));
        }
    }

    // External names, listed in class-index order 0..3.
    List<String> extLabels = new ArrayList<>();
    for (String name : new String[] {"non_spam", "spam", "fake2", "fake3"}) {
        extLabels.add(name);
    }
    dataSet.setLabelTranslator(new LabelTranslator(extLabels));

    File modelFile = new File(TMP, "/hmlgb/boosting.ser");
    HMLGradientBoosting boosting = HMLGradientBoosting.deserialize(modelFile);
    System.out.println(Accuracy.accuracy(boosting, dataSet));

    // Runs a prediction for every data point; the results are not inspected
    // here, so this only exercises the predict path.
    for (int row = 0; row < numDataPoints; row++) {
        Vector features = dataSet.getRow(row);
        MultiLabel expected = dataSet.getMultiLabels()[row];
        MultiLabel predicted = boosting.predict(features);
    }

    // NOTE(review): the trailing arguments 0 and 10 are passed through as-is;
    // their meaning is defined by HMLGBInspector.analyzePrediction — confirm there.
    MultiLabelPredictionAnalysis analysis = HMLGBInspector.analyzePrediction(boosting, dataSet, 0, 10);
    new ObjectMapper().writeValue(new File(TMP, "prediction_analysis.json"), analysis);
}
Aggregations