Usage of edu.neu.ccs.pyramid.eval.Accuracy in the pyramid project by cheng-li.
Class IMLGradientBoostingTest, method test4.
/**
 * Manual smoke test for {@code IMLGradientBoosting} on the ohsumed/3 split.
 *
 * <p>Trains ten boosting rounds and prints the stopwatch after each one. It then prints accuracy
 * and overlap on the training and test sets. For the first training point it prints its label,
 * two per-class probabilities, the unconstrained probability of its own label, and the
 * constrained probability of every candidate assignment together with their sum.
 *
 * @throws Exception if either data set cannot be loaded
 */
static void test4() throws Exception {
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet heldOutSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "ohsumed/3/test.trec"), DataSetType.ML_CLF_SPARSE, true);

    IMLGradientBoosting model = new IMLGradientBoosting(trainSet.getNumClasses());
    // Candidate assignments are taken from the training set's labels
    // (presumably the distinct label combinations - confirm in DataSetUtil).
    model.setAssignments(DataSetUtil.gatherMultiLabels(trainSet));

    IMLGBConfig config = new IMLGBConfig.Builder(trainSet)
            .numLeaves(2)
            .learningRate(0.1)
            .numSplitIntervals(1000)
            .minDataPerLeaf(2)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    IMLGBTrainer trainer = new IMLGBTrainer(config, model);

    StopWatch timer = new StopWatch();
    timer.start();
    int round = 0;
    while (round < 10) {
        System.out.println("round=" + round);
        trainer.iterate();
        // The stopwatch is started once and never reset, so each print reflects the whole run so far.
        System.out.println(timer);
        round++;
    }

    System.out.println("training accuracy=" + Accuracy.accuracy(model, trainSet));
    System.out.println("training overlap = " + Overlap.overlap(model, trainSet));
    System.out.println("test accuracy=" + Accuracy.accuracy(model, heldOutSet));
    System.out.println("test overlap = " + Overlap.overlap(model, heldOutSet));

    // Inspect the model's predictions for the first training point.
    System.out.println("label = ");
    System.out.println(trainSet.getMultiLabels()[0]);
    System.out.println("pro for 1 = " + model.predictClassProb(trainSet.getRow(0), 1));
    System.out.println("pro for 17 = " + model.predictClassProb(trainSet.getRow(0), 17));
    System.out.println(model.predictAssignmentProbWithoutConstraint(trainSet.getRow(0), trainSet.getMultiLabels()[0]));
    for (MultiLabel candidate : model.getAssignments()) {
        System.out.println("multilabel = " + candidate);
        System.out.println("prob = " + model.predictAssignmentProbWithConstraint(trainSet.getRow(0), candidate));
    }

    // Total constrained probability over all candidate assignments
    // (NOTE(review): presumably expected to be ~1 - confirm).
    double totalMass = model.getAssignments().stream()
            .mapToDouble(candidate -> model.predictAssignmentProbWithConstraint(trainSet.getRow(0), candidate))
            .sum();
    System.out.println(totalMass);
}
Usage of edu.neu.ccs.pyramid.eval.Accuracy in the pyramid project by cheng-li.
Class EMLevelEval, method main.
/**
 * Command-line entry point that evaluates regression predictions as if they were ordinal labels.
 *
 * <p>The distinct label values of the training set, in ascending order, form the list of
 * "levels". Each prediction read from {@code input.prediction} is mapped onto a level by
 * {@code round(double, List)}, which is defined elsewhere in this class (presumably it picks the
 * nearest level - confirm). The method then prints:
 * <ul>
 *   <li>RMSE before and after rounding;</li>
 *   <li>the exact-match accuracy of the rounded predictions;</li>
 *   <li>for every level, the distribution of predicted labels given that true label, and of true
 *       labels given that predicted label.</li>
 * </ul>
 *
 * @param args exactly one element: the path of a properties file that provides
 *     {@code input.trainData}, {@code input.testData} and {@code input.prediction}
 * @throws IllegalArgumentException if not exactly one argument is given, or if the number of
 *     predictions differs from the number of test labels
 * @throws Exception if loading the configuration, the data sets or the predictions fails
 */
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    System.out.println(config);
    RegDataSet train = TRECFormat.loadRegDataSet(config.getString("input.trainData"), DataSetType.REG_SPARSE, true);
    // Distinct training label values in ascending order; rounding snaps predictions onto these.
    List<Double> levels = Arrays.stream(train.getLabels())
            .distinct()
            .sorted()
            .boxed()
            .collect(Collectors.toList());
    RegDataSet test = TRECFormat.loadRegDataSet(config.getString("input.testData"), DataSetType.REG_SPARSE, true);
    double[] doubleTruth = test.getLabels();
    double[] doublePred = loadPrediction(config.getString("input.prediction"));
    // Without this check, a prediction file of the wrong length either crashes with an index
    // error during the accuracy computation or silently pairs predictions with the wrong points.
    if (doublePred.length != doubleTruth.length) {
        throw new IllegalArgumentException("number of predictions (" + doublePred.length
                + ") does not match number of test labels (" + doubleTruth.length + ")");
    }
    double[] roundedPred = Arrays.stream(doublePred).map(d -> round(d, levels)).toArray();
    System.out.println("before rounding");
    System.out.println("rmse = " + RMSE.rmse(doubleTruth, doublePred));
    System.out.println("after rounding");
    System.out.println("rmse = " + RMSE.rmse(doubleTruth, roundedPred));
    int numDataPoints = test.getNumDataPoints();
    // Exact equality on doubles is used here; a match is only possible when the true label equals
    // the level value that the prediction was rounded onto.
    long numCorrect = IntStream.range(0, numDataPoints).filter(i -> doubleTruth[i] == roundedPred[i]).count();
    System.out.println("accuracy = " + numCorrect / (double) numDataPoints);
    System.out.println("the distribution of predicted label for a given true label");
    for (double level : levels) {
        System.out.println("for true label " + level);
        truthToPred(doubleTruth, roundedPred, level, levels);
    }
    System.out.println("=============================");
    System.out.println("the distribution of true label for a given predicted label");
    for (double level : levels) {
        System.out.println("for predicted label " + level);
        predToTruth(doubleTruth, roundedPred, level, levels);
    }
}
Aggregations