Use of edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBConfig in the project pyramid by cheng-li.
In class MLACPlattScalingTest, method test1:
/**
 * Manual smoke test for {@link MLACPlattScaling}.
 *
 * <p>Loads the ohsumed/3 training split (sparse multi-label format) and fits an
 * {@link IMLGradientBoosting} model for 10 boosting rounds, printing the round number and
 * the stop-watch after each round. It then wraps the model in an {@link MLACPlattScaling}
 * built from the same data. For each of the first 10 training rows it prints three arrays:
 * the model's raw class scores, the model's class probabilities and the calibrated class
 * probabilities, followed by a separator line. Nothing is asserted; the output is meant for
 * eyeballing.
 *
 * @throws Exception whatever the loader, trainer or calibrator throws
 */
private static void test1() throws Exception {
    File trainFile = new File(DATASETS, "ohsumed/3/train.trec");
    MultiLabelClfDataSet data =
            TRECFormat.loadMultiLabelClfDataSet(trainFile, DataSetType.ML_CLF_SPARSE, true);

    IMLGradientBoosting model = new IMLGradientBoosting(data.getNumClasses());
    // Hand the booster every label set that occurs in the training data.
    model.setAssignments(DataSetUtil.gatherMultiLabels(data));

    IMLGBConfig config = new IMLGBConfig.Builder(data)
            .numLeaves(2)
            .learningRate(0.1)
            .numSplitIntervals(1000)
            .minDataPerLeaf(2)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    IMLGBTrainer trainer = new IMLGBTrainer(config, model);

    StopWatch timer = new StopWatch();
    timer.start();
    int round = 0;
    while (round < 10) {
        System.out.println("round=" + round);
        trainer.iterate();
        // The stop-watch is never reset, so this presumably reports time since training began.
        System.out.println(timer);
        round++;
    }

    MLACPlattScaling calibrated = new MLACPlattScaling(data, model);
    String separator = "======================";
    int rowIndex = 0;
    while (rowIndex < 10) {
        System.out.println(Arrays.toString(model.predictClassScores(data.getRow(rowIndex))));
        System.out.println(Arrays.toString(model.predictClassProbs(data.getRow(rowIndex))));
        System.out.println(Arrays.toString(calibrated.predictClassProbs(data.getRow(rowIndex))));
        System.out.println(separator);
        rowIndex++;
    }
}
Use of edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBConfig in the project pyramid by cheng-li.
In class MLFlatScalingTest, method test1:
/**
 * Manual smoke test for {@link MLFlatScaling}.
 *
 * <p>Loads the ohsumed/3 training split (sparse multi-label format) and fits an
 * {@link IMLGradientBoosting} model for 100 boosting rounds, printing the round number and
 * the stop-watch after each round. It then wraps the model in an {@link MLFlatScaling}
 * built from the same data. For each of the first 10 training rows it prints three arrays:
 * the model's raw class scores, the model's class probabilities and the scaled class
 * probabilities, followed by a separator line. Nothing is asserted; the output is meant for
 * eyeballing.
 *
 * @throws Exception whatever the loader, trainer or scaler throws
 */
private static void test1() throws Exception {
    File trainFile = new File(DATASETS, "ohsumed/3/train.trec");
    MultiLabelClfDataSet data =
            TRECFormat.loadMultiLabelClfDataSet(trainFile, DataSetType.ML_CLF_SPARSE, true);

    IMLGradientBoosting model = new IMLGradientBoosting(data.getNumClasses());
    // Hand the booster every label set that occurs in the training data.
    model.setAssignments(DataSetUtil.gatherMultiLabels(data));

    IMLGBConfig config = new IMLGBConfig.Builder(data)
            .numLeaves(2)
            .learningRate(0.1)
            .numSplitIntervals(1000)
            .minDataPerLeaf(2)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    IMLGBTrainer trainer = new IMLGBTrainer(config, model);

    StopWatch timer = new StopWatch();
    timer.start();
    int round = 0;
    while (round < 100) {
        System.out.println("round=" + round);
        trainer.iterate();
        // The stop-watch is never reset, so this presumably reports time since training began.
        System.out.println(timer);
        round++;
    }

    MLFlatScaling scaler = new MLFlatScaling(data, model);
    String separator = "======================";
    int rowIndex = 0;
    while (rowIndex < 10) {
        System.out.println(Arrays.toString(model.predictClassScores(data.getRow(rowIndex))));
        System.out.println(Arrays.toString(model.predictClassProbs(data.getRow(rowIndex))));
        System.out.println(Arrays.toString(scaler.predictClassProbs(data.getRow(rowIndex))));
        System.out.println(separator);
        rowIndex++;
    }
}
Use of edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBConfig in the project pyramid by cheng-li.
In class MLPlattScalingTest, method test1:
/**
 * Manual smoke test for {@link MLPlattScaling}.
 *
 * <p>Loads the ohsumed/3 training split (sparse multi-label format) and fits an
 * {@link IMLGradientBoosting} model for 100 boosting rounds, printing the round number and
 * the stop-watch after each round. It then wraps the model in an {@link MLPlattScaling}
 * built from the same data. For each of the first 10 training rows it prints three arrays:
 * the model's raw class scores, the model's class probabilities and the Platt-scaled class
 * probabilities, followed by a separator line. Nothing is asserted; the output is meant for
 * eyeballing.
 *
 * @throws Exception whatever the loader, trainer or calibrator throws
 */
private static void test1() throws Exception {
    File trainFile = new File(DATASETS, "ohsumed/3/train.trec");
    MultiLabelClfDataSet data =
            TRECFormat.loadMultiLabelClfDataSet(trainFile, DataSetType.ML_CLF_SPARSE, true);

    IMLGradientBoosting model = new IMLGradientBoosting(data.getNumClasses());
    // Hand the booster every label set that occurs in the training data.
    model.setAssignments(DataSetUtil.gatherMultiLabels(data));

    IMLGBConfig config = new IMLGBConfig.Builder(data)
            .numLeaves(2)
            .learningRate(0.1)
            .numSplitIntervals(1000)
            .minDataPerLeaf(2)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    IMLGBTrainer trainer = new IMLGBTrainer(config, model);

    StopWatch timer = new StopWatch();
    timer.start();
    int round = 0;
    while (round < 100) {
        System.out.println("round=" + round);
        trainer.iterate();
        // The stop-watch is never reset, so this presumably reports time since training began.
        System.out.println(timer);
        round++;
    }

    MLPlattScaling calibrated = new MLPlattScaling(data, model);
    String separator = "======================";
    int rowIndex = 0;
    while (rowIndex < 10) {
        System.out.println(Arrays.toString(model.predictClassScores(data.getRow(rowIndex))));
        System.out.println(Arrays.toString(model.predictClassProbs(data.getRow(rowIndex))));
        System.out.println(Arrays.toString(calibrated.predictClassProbs(data.getRow(rowIndex))));
        System.out.println(separator);
        rowIndex++;
    }
}
Aggregations