Search in sources :

Example 1 with MacroAveragedMeasures

use of edu.neu.ccs.pyramid.eval.MacroAveragedMeasures in project pyramid by cheng-li.

the class HMLGradientBoostingTest method spam_load.

/**
 * Evaluates a previously serialized model on the spam test data.
 *
 * <p>Reads {@code /spam/trec_data/test.trec} under {@code DATASETS} as a single-label
 * data set, copies its labels and feature values into a two-class multi-label data set,
 * restores the model from {@code /hmlgb/boosting.ser} under {@code TMP}, and prints the
 * accuracy and the macro-averaged measures to standard output.
 *
 * @throws Exception if any of the loading or evaluation calls fails
 */
static void spam_load() throws Exception {
    File testFile = new File(DATASETS, "/spam/trec_data/test.trec");
    ClfDataSet source = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_DENSE, true);
    int rowCount = source.getNumDataPoints();
    int columnCount = source.getNumFeatures();

    // The multi-label container mirrors the dimensions of the single-label source.
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numDataPoints(rowCount)
            .numFeatures(columnCount)
            .numClasses(2)
            .build();

    int[] sourceLabels = source.getLabels();
    for (int row = 0; row < rowCount; row++) {
        // Each data point receives exactly one label: the one it had in the source.
        dataSet.addLabel(row, sourceLabels[row]);
        for (int col = 0; col < columnCount; col++) {
            double cell = source.getRow(row).get(col);
            dataSet.setFeatureValue(row, col, cell);
        }
    }

    File modelFile = new File(TMP, "/hmlgb/boosting.ser");
    HMLGradientBoosting boosting = HMLGradientBoosting.deserialize(modelFile);

    System.out.println("accuracy=" + Accuracy.accuracy(boosting, dataSet));
    System.out.println("macro-averaged:");
    System.out.println(new MacroAveragedMeasures(boosting, dataSet));
}
Also used : File(java.io.File) MacroAveragedMeasures(edu.neu.ccs.pyramid.eval.MacroAveragedMeasures)

Example 2 with MacroAveragedMeasures

use of edu.neu.ccs.pyramid.eval.MacroAveragedMeasures in project pyramid by cheng-li.

the class HMLGradientBoostingTest method spam_build.

/**
 * Trains a model on the spam training data and serializes it.
 *
 * <p>Reads {@code /spam/trec_data/train.trec} under {@code DATASETS} as a single-label
 * data set, copies its labels and feature values into a two-class multi-label data set,
 * runs 200 boosting rounds (printing the training accuracy after each one), prints the
 * elapsed time, the model, the final accuracy and the macro-averaged measures, and
 * finally writes the model to {@code /hmlgb/boosting.ser} under {@code TMP}.
 *
 * @throws Exception if any of the loading, training or serialization calls fails
 */
static void spam_build() throws Exception {
    File trainFile = new File(DATASETS, "/spam/trec_data/train.trec");
    ClfDataSet source = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_DENSE, true);
    int rowCount = source.getNumDataPoints();
    int columnCount = source.getNumFeatures();

    // The multi-label container mirrors the dimensions of the single-label source.
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numDataPoints(rowCount)
            .numFeatures(columnCount)
            .numClasses(2)
            .build();

    int[] sourceLabels = source.getLabels();
    for (int row = 0; row < rowCount; row++) {
        // Each data point receives exactly one label: the one it had in the source.
        dataSet.addLabel(row, sourceLabels[row]);
        for (int col = 0; col < columnCount; col++) {
            double cell = source.getRow(row).get(col);
            dataSet.setFeatureValue(row, col, cell);
        }
    }

    // The two label combinations handed to the model: {0} and {1}.
    MultiLabel onlyLabelZero = new MultiLabel().addLabel(0);
    MultiLabel onlyLabelOne = new MultiLabel().addLabel(1);
    List<MultiLabel> assignments = new ArrayList<>();
    assignments.add(onlyLabelZero);
    assignments.add(onlyLabelOne);

    HMLGradientBoosting boosting = new HMLGradientBoosting(2, assignments);
    HMLGBConfig trainConfig = new HMLGBConfig.Builder(dataSet)
            .numLeaves(7)
            .learningRate(0.1)
            .numSplitIntervals(50)
            .minDataPerLeaf(1)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    System.out.println(Arrays.toString(trainConfig.getActiveFeatures()));

    HMLGBTrainer trainer = new HMLGBTrainer(trainConfig, boosting);
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    int round = 0;
    while (round < 200) {
        System.out.println("round=" + round);
        trainer.iterate();
        System.out.println("accuracy=" + Accuracy.accuracy(boosting, dataSet));
        // Debugging aid: Arrays.toString(boosting.getGradients(k)) for k = 0, 1
        // can be printed here to inspect the gradients after each round.
        round++;
    }
    stopWatch.stop();

    System.out.println(stopWatch);
    System.out.println(boosting);
    // Debugging aid: for each data point i, dataSet.getMultiLabels()[i],
    // boosting.calAssignmentScores(dataSet.getRow(i), assignments.get(k)) and
    // boosting.predict(dataSet.getRow(i)) can be printed here.
    System.out.println("accuracy");
    System.out.println(Accuracy.accuracy(boosting, dataSet));
    System.out.println("macro-averaged:");
    System.out.println(new MacroAveragedMeasures(boosting, dataSet));

    File modelFile = new File(TMP, "/hmlgb/boosting.ser");
    boosting.serialize(modelFile);
}
Also used : ArrayList(java.util.ArrayList) MacroAveragedMeasures(edu.neu.ccs.pyramid.eval.MacroAveragedMeasures) StopWatch(org.apache.commons.lang3.time.StopWatch) File(java.io.File)

Example 3 with MacroAveragedMeasures

use of edu.neu.ccs.pyramid.eval.MacroAveragedMeasures in project pyramid by cheng-li.

the class HMLGradientBoostingTest method test4_load.

/**
 * Evaluates a previously serialized model on the four-label spam test data.
 *
 * <p>Loads {@code spam/4labels/test.trec} under {@code DATASETS} directly as a
 * multi-label data set, restores the model from {@code /hmlgb/boosting.ser} under
 * {@code TMP}, and prints the accuracy and the macro-averaged measures to standard
 * output.
 *
 * @throws Exception if any of the loading or evaluation calls fails
 */
static void test4_load() throws Exception {
    File testFile = new File(DATASETS, "spam/4labels/test.trec");
    MultiLabelClfDataSet dataSet =
            TRECFormat.loadMultiLabelClfDataSet(testFile, DataSetType.ML_CLF_DENSE, true);

    File modelFile = new File(TMP, "/hmlgb/boosting.ser");
    HMLGradientBoosting boosting = HMLGradientBoosting.deserialize(modelFile);

    System.out.println("accuracy=" + Accuracy.accuracy(boosting, dataSet));
    System.out.println("macro-averaged:");
    System.out.println(new MacroAveragedMeasures(boosting, dataSet));
}
Also used : File(java.io.File) MacroAveragedMeasures(edu.neu.ccs.pyramid.eval.MacroAveragedMeasures)

Example 4 with MacroAveragedMeasures

use of edu.neu.ccs.pyramid.eval.MacroAveragedMeasures in project pyramid by cheng-li.

the class HMLGradientBoostingTest method test4_build.

/**
 * Trains a four-class model on the four-label spam training data and serializes it.
 *
 * <p>Same as test3; the only difference is that the multi-label data set is now loaded
 * directly from {@code spam/4labels/train.trec} under {@code DATASETS}. Runs 10 boosting
 * rounds (printing the training accuracy after each one), prints the elapsed time, the
 * final accuracy and the macro-averaged measures, and writes the model to
 * {@code /hmlgb/boosting.ser} under {@code TMP}.
 *
 * @throws Exception if any of the loading, training or serialization calls fails
 */
static void test4_build() throws Exception {
    File trainFile = new File(DATASETS, "spam/4labels/train.trec");
    MultiLabelClfDataSet dataSet =
            TRECFormat.loadMultiLabelClfDataSet(trainFile, DataSetType.ML_CLF_DENSE, true);

    // The label combinations handed to the model: {0}, {1}, {1,2}, {1,3}, {1,2,3}.
    MultiLabel zero = new MultiLabel().addLabel(0);
    MultiLabel one = new MultiLabel().addLabel(1);
    MultiLabel oneTwo = new MultiLabel().addLabel(1).addLabel(2);
    MultiLabel oneThree = new MultiLabel().addLabel(1).addLabel(3);
    MultiLabel oneTwoThree = new MultiLabel().addLabel(1).addLabel(2).addLabel(3);
    List<MultiLabel> assignments = new ArrayList<>();
    assignments.add(zero);
    assignments.add(one);
    assignments.add(oneTwo);
    assignments.add(oneThree);
    assignments.add(oneTwoThree);

    HMLGradientBoosting boosting = new HMLGradientBoosting(4, assignments);
    HMLGBConfig trainConfig = new HMLGBConfig.Builder(dataSet)
            .numLeaves(100)
            .learningRate(0.1)
            .numSplitIntervals(1000)
            .minDataPerLeaf(2)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    System.out.println(Arrays.toString(trainConfig.getActiveFeatures()));

    HMLGBTrainer trainer = new HMLGBTrainer(trainConfig, boosting);
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    int round = 0;
    while (round < 10) {
        System.out.println("round=" + round);
        trainer.iterate();
        System.out.println("accuracy=" + Accuracy.accuracy(boosting, dataSet));
        round++;
    }
    stopWatch.stop();

    System.out.println(stopWatch);
    // Debugging aid: the model itself can be printed here with System.out.println(boosting).
    System.out.println("accuracy");
    System.out.println(Accuracy.accuracy(boosting, dataSet));
    System.out.println("macro-averaged:");
    System.out.println(new MacroAveragedMeasures(boosting, dataSet));

    File modelFile = new File(TMP, "/hmlgb/boosting.ser");
    boosting.serialize(modelFile);
}
Also used : ArrayList(java.util.ArrayList) File(java.io.File) MacroAveragedMeasures(edu.neu.ccs.pyramid.eval.MacroAveragedMeasures) StopWatch(org.apache.commons.lang3.time.StopWatch)

Aggregations

MacroAveragedMeasures (edu.neu.ccs.pyramid.eval.MacroAveragedMeasures): 4, File (java.io.File): 4, ArrayList (java.util.ArrayList): 2, StopWatch (org.apache.commons.lang3.time.StopWatch): 2