Usage of edu.neu.ccs.pyramid.eval.Overlap in the pyramid project by cheng-li.
Method test4 of class IMLGradientBoostingTest.
/**
 * Smoke test for {@code IMLGradientBoosting} on the ohsumed/3 train/test split.
 *
 * <p>Loads the sparse multi-label train and test sets from {@code DATASETS}. It sets the model's
 * candidate assignments to the label sets gathered from the training data by
 * {@code DataSetUtil.gatherMultiLabels}. It then runs 10 trainer iterations, printing the stopwatch
 * after each one, and prints accuracy and overlap on both sets. Finally it probes the model on the
 * first training instance:
 * <ul>
 *   <li>the class probabilities for classes 1 and 17;</li>
 *   <li>the unconstrained probability of that instance's label set;</li>
 *   <li>the constrained probability of every candidate assignment, followed by their sum.</li>
 * </ul>
 *
 * @throws Exception propagated from data loading, training, or evaluation
 */
static void test4() throws Exception {
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/test.trec"), DataSetType.ML_CLF_SPARSE, true);

    IMLGradientBoosting boosting = new IMLGradientBoosting(dataSet.getNumClasses());
    // Candidate assignments are the label sets gathered from the training data.
    List<MultiLabel> assignments = DataSetUtil.gatherMultiLabels(dataSet);
    boosting.setAssignments(assignments);

    IMLGBConfig trainConfig = new IMLGBConfig.Builder(dataSet)
            .numLeaves(2)
            .learningRate(0.1)
            .numSplitIntervals(1000)
            .minDataPerLeaf(2)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    IMLGBTrainer trainer = new IMLGBTrainer(trainConfig, boosting);

    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    for (int round = 0; round < 10; round++) {
        System.out.println("round=" + round);
        trainer.iterate();
        System.out.println(stopWatch);
    }

    System.out.println("training accuracy=" + Accuracy.accuracy(boosting, dataSet));
    System.out.println("training overlap = " + Overlap.overlap(boosting, dataSet));
    System.out.println("test accuracy=" + Accuracy.accuracy(boosting, testSet));
    System.out.println("test overlap = " + Overlap.overlap(boosting, testSet));

    // Probe the model on the first training instance.
    System.out.println("label = ");
    MultiLabel firstLabel = dataSet.getMultiLabels()[0];
    System.out.println(firstLabel);
    System.out.println("pro for 1 = " + boosting.predictClassProb(dataSet.getRow(0), 1));
    System.out.println("pro for 17 = " + boosting.predictClassProb(dataSet.getRow(0), 17));
    System.out.println(boosting.predictAssignmentProbWithoutConstraint(dataSet.getRow(0), firstLabel));

    // Compute each constrained probability once and reuse it for both the printout and the total.
    // The original evaluated the model twice per assignment. Summing via DoubleStream keeps the
    // same (compensated) summation as the original stream-based sum.
    java.util.stream.DoubleStream.Builder constrainedProbs = java.util.stream.DoubleStream.builder();
    for (MultiLabel multiLabel : boosting.getAssignments()) {
        System.out.println("multilabel = " + multiLabel);
        double prob = boosting.predictAssignmentProbWithConstraint(dataSet.getRow(0), multiLabel);
        System.out.println("prob = " + prob);
        constrainedProbs.add(prob);
    }
    // Sanity check. Assumption: if the constrained probabilities are normalized over the
    // candidate assignments, this should print a value close to 1. Confirm against
    // IMLGradientBoosting.
    double sum = constrainedProbs.build().sum();
    System.out.println(sum);
}
Aggregations