
Example 1 with IMLGBTrainer

Use of edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBTrainer in project pyramid by cheng-li.

Source: class MLACPlattScalingTest, method test1.

private static void test1() throws Exception {
    // DATASETS is a base-directory constant defined elsewhere in the test class (not shown here).
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    IMLGradientBoosting boosting = new IMLGradientBoosting(dataSet.getNumClasses());
    // Register the label combinations (assignments) found in the training set with the model.
    List<MultiLabel> assignments = DataSetUtil.gatherMultiLabels(dataSet);
    boosting.setAssignments(assignments);
    // 2-leaf trees, learning rate 0.1, no data or feature subsampling.
    IMLGBConfig trainConfig = new IMLGBConfig.Builder(dataSet)
            .numLeaves(2)
            .learningRate(0.1)
            .numSplitIntervals(1000)
            .minDataPerLeaf(2)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    IMLGBTrainer trainer = new IMLGBTrainer(trainConfig, boosting);
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    // Run 10 boosting rounds, printing the cumulative elapsed time after each one.
    for (int round = 0; round < 10; round++) {
        System.out.println("round=" + round);
        trainer.iterate();
        System.out.println(stopWatch);
    }
    // Build the calibrator from the training set and the trained model.
    MLACPlattScaling plattScaling = new MLACPlattScaling(dataSet, boosting);
    // For the first 10 instances, print raw scores, the model's own probabilities, and the calibrated probabilities.
    for (int i = 0; i < 10; i++) {
        System.out.println(Arrays.toString(boosting.predictClassScores(dataSet.getRow(i))));
        System.out.println(Arrays.toString(boosting.predictClassProbs(dataSet.getRow(i))));
        System.out.println(Arrays.toString(plattScaling.predictClassProbs(dataSet.getRow(i))));
        System.out.println("======================");
    }
}
Also used: IMLGradientBoosting (edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting), IMLGBTrainer (edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBTrainer), IMLGBConfig (edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBConfig), File (java.io.File), StopWatch (org.apache.commons.lang3.time.StopWatch)

Example 2 with IMLGBTrainer

Use of edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBTrainer in project pyramid by cheng-li.

Source: class MLFlatScalingTest, method test1.

private static void test1() throws Exception {
    // DATASETS is a base-directory constant defined elsewhere in the test class (not shown here).
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    IMLGradientBoosting boosting = new IMLGradientBoosting(dataSet.getNumClasses());
    // Register the label combinations (assignments) found in the training set with the model.
    List<MultiLabel> assignments = DataSetUtil.gatherMultiLabels(dataSet);
    boosting.setAssignments(assignments);
    // 2-leaf trees, learning rate 0.1, no data or feature subsampling.
    IMLGBConfig trainConfig = new IMLGBConfig.Builder(dataSet)
            .numLeaves(2)
            .learningRate(0.1)
            .numSplitIntervals(1000)
            .minDataPerLeaf(2)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    IMLGBTrainer trainer = new IMLGBTrainer(trainConfig, boosting);
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    // Run 100 boosting rounds, printing the cumulative elapsed time after each one.
    for (int round = 0; round < 100; round++) {
        System.out.println("round=" + round);
        trainer.iterate();
        System.out.println(stopWatch);
    }
    // Build the calibrator from the training set and the trained model.
    MLFlatScaling scaling = new MLFlatScaling(dataSet, boosting);
    // For the first 10 instances, print raw scores, the model's own probabilities, and the calibrated probabilities.
    for (int i = 0; i < 10; i++) {
        System.out.println(Arrays.toString(boosting.predictClassScores(dataSet.getRow(i))));
        System.out.println(Arrays.toString(boosting.predictClassProbs(dataSet.getRow(i))));
        System.out.println(Arrays.toString(scaling.predictClassProbs(dataSet.getRow(i))));
        System.out.println("======================");
    }
}
Also used: MLFlatScaling (edu.neu.ccs.pyramid.multilabel_classification.MLFlatScaling), IMLGradientBoosting (edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting), IMLGBTrainer (edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBTrainer), IMLGBConfig (edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBConfig), File (java.io.File), StopWatch (org.apache.commons.lang3.time.StopWatch)

Example 3 with IMLGBTrainer

Use of edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBTrainer in project pyramid by cheng-li.

Source: class MLPlattScalingTest, method test1.

private static void test1() throws Exception {
    // DATASETS is a base-directory constant defined elsewhere in the test class (not shown here).
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    IMLGradientBoosting boosting = new IMLGradientBoosting(dataSet.getNumClasses());
    // Register the label combinations (assignments) found in the training set with the model.
    List<MultiLabel> assignments = DataSetUtil.gatherMultiLabels(dataSet);
    boosting.setAssignments(assignments);
    // 2-leaf trees, learning rate 0.1, no data or feature subsampling.
    IMLGBConfig trainConfig = new IMLGBConfig.Builder(dataSet)
            .numLeaves(2)
            .learningRate(0.1)
            .numSplitIntervals(1000)
            .minDataPerLeaf(2)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    IMLGBTrainer trainer = new IMLGBTrainer(trainConfig, boosting);
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    // Run 100 boosting rounds, printing the cumulative elapsed time after each one.
    for (int round = 0; round < 100; round++) {
        System.out.println("round=" + round);
        trainer.iterate();
        System.out.println(stopWatch);
    }
    // Build the calibrator from the training set and the trained model.
    MLPlattScaling plattScaling = new MLPlattScaling(dataSet, boosting);
    // For the first 10 instances, print raw scores, the model's own probabilities, and the calibrated probabilities.
    for (int i = 0; i < 10; i++) {
        System.out.println(Arrays.toString(boosting.predictClassScores(dataSet.getRow(i))));
        System.out.println(Arrays.toString(boosting.predictClassProbs(dataSet.getRow(i))));
        System.out.println(Arrays.toString(plattScaling.predictClassProbs(dataSet.getRow(i))));
        System.out.println("======================");
    }
}
Also used: IMLGradientBoosting (edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting), IMLGBTrainer (edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBTrainer), IMLGBConfig (edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBConfig), File (java.io.File), StopWatch (org.apache.commons.lang3.time.StopWatch)
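These tests compare the boosting model's raw scores and probabilities with the output of a Platt-style calibrator. For readers unfamiliar with the idea, here is a minimal, self-contained sketch of classic single-label Platt scaling. It fits P(y = 1 | s) = 1 / (1 + exp(A * s + B)) to raw scores s using Platt's smoothed targets. One difference from Platt's original method is that it uses plain batch gradient descent instead of a Newton-style optimizer. The class PlattScalingSketch and its fit/predict methods are hypothetical names for illustration. This is not pyramid's implementation, and it is an assumption that MLPlattScaling calibrates scores in a similar way.

```java
import java.util.Arrays;

// Hypothetical illustration of classic Platt scaling for ONE binary label.
// Not pyramid's API; MLPlattScaling may work differently.
class PlattScalingSketch {

    // Calibrated probability 1 / (1 + exp(A * score + B)); ab = {A, B}.
    public static double predict(double[] ab, double score) {
        return 1.0 / (1.0 + Math.exp(ab[0] * score + ab[1]));
    }

    // Fits {A, B} by batch gradient descent on the negative log-likelihood.
    // labels[i] is 1 if the label is present for instance i, 0 otherwise.
    public static double[] fit(double[] scores, int[] labels, int iterations, double learningRate) {
        int numPos = 0;
        for (int y : labels) {
            numPos += y;
        }
        int numNeg = labels.length - numPos;
        // Platt's smoothed targets reduce overfitting on small calibration sets.
        double hiTarget = (numPos + 1.0) / (numPos + 2.0);
        double loTarget = 1.0 / (numNeg + 2.0);

        double a = 0.0;
        // Start from the smoothed class prior (Platt's suggested initial B).
        double b = Math.log((numNeg + 1.0) / (numPos + 1.0));
        for (int it = 0; it < iterations; it++) {
            double gradA = 0.0;
            double gradB = 0.0;
            for (int i = 0; i < scores.length; i++) {
                double t = labels[i] == 1 ? hiTarget : loTarget;
                double p = 1.0 / (1.0 + Math.exp(a * scores[i] + b));
                // With z = a*s + b and p = 1/(1+e^z), dNLL/dz = t - p.
                gradA += (t - p) * scores[i];
                gradB += (t - p);
            }
            a -= learningRate * gradA / scores.length;
            b -= learningRate * gradB / scores.length;
        }
        return new double[]{a, b};
    }

    public static void main(String[] args) {
        // Toy raw scores for one label, and whether the label is present (1) or absent (0).
        double[] scores = {-2.0, -1.5, -1.0, -0.5, 0.5, 1.0, 1.5, 2.0};
        int[] labels = {0, 0, 0, 1, 0, 1, 1, 1};
        double[] ab = fit(scores, labels, 5000, 0.1);
        System.out.println("A, B = " + Arrays.toString(ab));
        for (double s : new double[]{-2.0, 0.0, 2.0}) {
            System.out.printf("score %.1f -> calibrated prob %.3f%n", s, predict(ab, s));
        }
    }
}
```

In a multi-label setting, one natural extension is to fit one such {A, B} pair per label. Whether pyramid's calibrators do this per label, share parameters across labels, or do something else entirely is not shown in the examples above.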

Aggregations

IMLGBConfig (edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBConfig): 3
IMLGBTrainer (edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBTrainer): 3
IMLGradientBoosting (edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting): 3
File (java.io.File): 3
StopWatch (org.apache.commons.lang3.time.StopWatch): 3
MLFlatScaling (edu.neu.ccs.pyramid.multilabel_classification.MLFlatScaling): 1