Usage of `org.apache.commons.lang3.time.StopWatch` in project pyramid by cheng-li.
Class `MLPlattScalingTest`, method `test1`:
/**
 * Trains an {@code IMLGradientBoosting} model on the ohsumed/3 training split, then wraps it in
 * {@code MLPlattScaling} and prints model outputs for the first few training rows.
 *
 * <p>During training, the round number and the {@code StopWatch} (its {@code toString}) are
 * printed after every round. After training, each of the first ten rows gets printed in this
 * order: the raw class scores, the boosting model's class probabilities, the Platt-scaled class
 * probabilities, and a separator line.
 *
 * @throws Exception propagated from dataset loading or training
 */
private static void test1() throws Exception {
    final int numRounds = 100;
    final int numRowsToShow = 10;

    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);

    // Register the label sets gathered from the training data with the model
    // (presumably the label combinations it may predict -- confirm against IMLGradientBoosting).
    IMLGradientBoosting model = new IMLGradientBoosting(trainSet.getNumClasses());
    List<MultiLabel> labelCombinations = DataSetUtil.gatherMultiLabels(trainSet);
    model.setAssignments(labelCombinations);

    IMLGBConfig config = new IMLGBConfig.Builder(trainSet)
            .numLeaves(2)
            .learningRate(0.1)
            .numSplitIntervals(1000)
            .minDataPerLeaf(2)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    IMLGBTrainer imlgbTrainer = new IMLGBTrainer(config, model);

    StopWatch timer = new StopWatch();
    timer.start();
    for (int iteration = 0; iteration < numRounds; iteration++) {
        System.out.println("round=" + iteration);
        imlgbTrainer.iterate();
        // StopWatch.toString() renders the time elapsed since start().
        System.out.println(timer);
    }

    MLPlattScaling calibrator = new MLPlattScaling(trainSet, model);
    for (int row = 0; row < numRowsToShow; row++) {
        System.out.println(Arrays.toString(model.predictClassScores(trainSet.getRow(row))));
        System.out.println(Arrays.toString(model.predictClassProbs(trainSet.getRow(row))));
        System.out.println(Arrays.toString(calibrator.predictClassProbs(trainSet.getRow(row))));
        System.out.println("======================");
    }
}
Usage of `org.apache.commons.lang3.time.StopWatch` in project pyramid by cheng-li.
Class `AdaBoostMHTest`, method `test1`:
/**
 * Trains an {@code AdaBoostMH} model on the ohsumed/3 training split for 500 rounds and reports
 * accuracy and overlap on both the training and the test split.
 *
 * <p>After every round, the round number and the {@code StopWatch} (its {@code toString}) are
 * printed. The four metrics are printed in this order: training accuracy, training overlap,
 * test accuracy, test overlap.
 *
 * @throws Exception propagated from dataset loading or training
 */
static void test1() throws Exception {
    final int numRounds = 500;

    File trainFile = new File(DATASETS, "ohsumed/3/train.trec");
    File testFile = new File(DATASETS, "ohsumed/3/test.trec");
    MultiLabelClfDataSet trainSet =
            TRECFormat.loadMultiLabelClfDataSet(trainFile, DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet heldOutSet =
            TRECFormat.loadMultiLabelClfDataSet(testFile, DataSetType.ML_CLF_SPARSE, true);

    AdaBoostMH model = new AdaBoostMH(trainSet.getNumClasses());
    AdaBoostMHTrainer adaTrainer = new AdaBoostMHTrainer(trainSet, model);

    StopWatch timer = new StopWatch();
    timer.start();
    for (int iteration = 0; iteration < numRounds; iteration++) {
        System.out.println("round=" + iteration);
        adaTrainer.iterate();
        // StopWatch.toString() renders the time elapsed since start().
        System.out.println(timer);
    }

    // Metrics on the data the model was trained on.
    System.out.println("training accuracy=" + Accuracy.accuracy(model, trainSet));
    System.out.println("training overlap = " + Overlap.overlap(model, trainSet));
    // Metrics on the held-out split.
    System.out.println("test accuracy=" + Accuracy.accuracy(model, heldOutSet));
    System.out.println("test overlap = " + Overlap.overlap(model, heldOutSet));
}
Aggregations