Search in sources :

Example 1 with Accuracy

use of edu.neu.ccs.pyramid.eval.Accuracy in project pyramid by cheng-li.

the class IMLGradientBoostingTest method test4.

/**
 * Fits an {@code IMLGradientBoosting} model on the {@code ohsumed/3} training file for
 * ten boosting rounds and prints diagnostics to standard output.
 *
 * <p>After training it prints accuracy and overlap on both the training and the test
 * file, the label set of the first training row, the predicted class probabilities of
 * that row for classes 1 and 17, and the assignment probabilities of that row (without
 * constraint for its own label set, with constraint for every assignment registered on
 * the model). The last line printed is the sum of the constrained probabilities over
 * all registered assignments.
 *
 * @throws Exception propagated from data loading or training
 */
static void test4() throws Exception {
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet heldOutSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "ohsumed/3/test.trec"), DataSetType.ML_CLF_SPARSE, true);

    // The model is restricted to the label combinations gathered from the training set.
    IMLGradientBoosting model = new IMLGradientBoosting(trainSet.getNumClasses());
    model.setAssignments(DataSetUtil.gatherMultiLabels(trainSet));

    IMLGBConfig config = new IMLGBConfig.Builder(trainSet)
            .numLeaves(2)
            .learningRate(0.1)
            .numSplitIntervals(1000)
            .minDataPerLeaf(2)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    IMLGBTrainer trainer = new IMLGBTrainer(config, model);

    // Report the cumulative elapsed time after every boosting round.
    StopWatch timer = new StopWatch();
    timer.start();
    int round = 0;
    while (round < 10) {
        System.out.println("round=" + round);
        trainer.iterate();
        System.out.println(timer);
        round++;
    }

    System.out.println("training accuracy=" + Accuracy.accuracy(model, trainSet));
    System.out.println("training overlap = " + Overlap.overlap(model, trainSet));
    System.out.println("test accuracy=" + Accuracy.accuracy(model, heldOutSet));
    System.out.println("test overlap = " + Overlap.overlap(model, heldOutSet));

    // Inspect the first training row in detail.
    System.out.println("label = ");
    System.out.println(trainSet.getMultiLabels()[0]);
    System.out.println("pro for 1 = " + model.predictClassProb(trainSet.getRow(0), 1));
    System.out.println("pro for 17 = " + model.predictClassProb(trainSet.getRow(0), 17));
    System.out.println(model.predictAssignmentProbWithoutConstraint(
            trainSet.getRow(0), trainSet.getMultiLabels()[0]));

    model.getAssignments().forEach(candidate -> {
        System.out.println("multilabel = " + candidate);
        System.out.println("prob = "
                + model.predictAssignmentProbWithConstraint(trainSet.getRow(0), candidate));
    });

    double total = model.getAssignments().stream()
            .mapToDouble(candidate ->
                    model.predictAssignmentProbWithConstraint(trainSet.getRow(0), candidate))
            .sum();
    System.out.println(total);
}
Also used : java.util(java.util) Arrays(java.util.Arrays) DenseVector(org.apache.mahout.math.DenseVector) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) Accuracy(edu.neu.ccs.pyramid.eval.Accuracy) ConstantRegressor(edu.neu.ccs.pyramid.regression.ConstantRegressor) Vector(org.apache.mahout.math.Vector) StopWatch(org.apache.commons.lang3.time.StopWatch) Overlap(edu.neu.ccs.pyramid.eval.Overlap) Collectors(java.util.stream.Collectors) File(java.io.File) Config(edu.neu.ccs.pyramid.configuration.Config)

Example 2 with Accuracy

use of edu.neu.ccs.pyramid.eval.Accuracy in project pyramid by cheng-li.

the class EMLevelEval method main.

/**
 * Evaluates stored regression predictions against a test set, before and after each
 * prediction is passed through {@code round(d, levels)}.
 *
 * <p>The properties file must provide {@code input.trainData}, {@code input.testData}
 * and {@code input.prediction}. The levels are the distinct label values of the
 * training set in ascending order. The program prints the configuration, the RMSE of
 * the raw and of the rounded predictions, the fraction of test points whose rounded
 * prediction equals the true label exactly, and then, for every level, the output of
 * {@code truthToPred} followed by the output of {@code predToTruth}.
 *
 * @param args exactly one element: the path of the properties file
 * @throws IllegalArgumentException if {@code args} does not hold exactly one element
 * @throws Exception propagated from configuration, data or prediction loading
 */
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    System.out.println(config);

    // Distinct training labels, sorted ascending, are the admissible levels.
    RegDataSet trainSet = TRECFormat.loadRegDataSet(
            config.getString("input.trainData"), DataSetType.REG_SPARSE, true);
    List<Double> levels = Arrays.stream(trainSet.getLabels())
            .distinct()
            .sorted()
            .boxed()
            .collect(Collectors.toList());

    RegDataSet testSet = TRECFormat.loadRegDataSet(
            config.getString("input.testData"), DataSetType.REG_SPARSE, true);
    double[] truth = testSet.getLabels();
    double[] rawPred = loadPrediction(config.getString("input.prediction"));
    double[] snappedPred = Arrays.stream(rawPred).map(d -> round(d, levels)).toArray();

    System.out.println("before rounding");
    System.out.println("rmse = " + RMSE.rmse(truth, rawPred));
    System.out.println("after rounding");
    System.out.println("rmse = " + RMSE.rmse(truth, snappedPred));

    // Exact floating-point equality between the true label and the rounded prediction.
    int numPoints = testSet.getNumDataPoints();
    long exactMatches = IntStream.range(0, numPoints)
            .filter(i -> truth[i] == snappedPred[i])
            .count();
    System.out.println("accuracy = " + exactMatches / (double) numPoints);

    System.out.println("the distribution of predicted label for a given true label");
    for (double level : levels) {
        System.out.println("for true label " + level);
        truthToPred(testSet.getLabels(), snappedPred, level, levels);
    }
    System.out.println("=============================");
    System.out.println("the distribution of true label for a given predicted label");
    for (double level : levels) {
        System.out.println("for predicted label " + level);
        predToTruth(testSet.getLabels(), snappedPred, level, levels);
    }
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Set(java.util.Set) FileUtils(org.apache.commons.io.FileUtils) IOException(java.io.IOException) Collectors(java.util.stream.Collectors) File(java.io.File) HashSet(java.util.HashSet) List(java.util.List) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) Accuracy(edu.neu.ccs.pyramid.eval.Accuracy) RMSE(edu.neu.ccs.pyramid.eval.RMSE) Config(edu.neu.ccs.pyramid.configuration.Config)

Aggregations

Config (edu.neu.ccs.pyramid.configuration.Config): 2; edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset): 2; Accuracy (edu.neu.ccs.pyramid.eval.Accuracy): 2; File (java.io.File): 2; Arrays (java.util.Arrays): 2; Collectors (java.util.stream.Collectors): 2; Overlap (edu.neu.ccs.pyramid.eval.Overlap): 1; RMSE (edu.neu.ccs.pyramid.eval.RMSE): 1; ConstantRegressor (edu.neu.ccs.pyramid.regression.ConstantRegressor): 1; IOException (java.io.IOException): 1; java.util (java.util): 1; HashSet (java.util.HashSet): 1; List (java.util.List): 1; Set (java.util.Set): 1; IntStream (java.util.stream.IntStream): 1; FileUtils (org.apache.commons.io.FileUtils): 1; StopWatch (org.apache.commons.lang3.time.StopWatch): 1; DenseVector (org.apache.mahout.math.DenseVector): 1; Vector (org.apache.mahout.math.Vector): 1