use of edu.neu.ccs.pyramid.eval.MacroAveragedMeasures in project pyramid by cheng-li.
the class HMLGradientBoostingTest method spam_load.
/**
 * Evaluates the previously serialized boosting model on the spam test split.
 *
 * <p>Reads {@code /spam/trec_data/test.trec} under {@code DATASETS} as a dense
 * single-label dataset, copies every label and feature value into a two-class
 * multi-label dataset, and prints the accuracy and macro-averaged measures of the
 * model deserialized from {@code /hmlgb/boosting.ser} under {@code TMP}.
 *
 * @throws Exception whatever the loading, deserialization or evaluation calls throw
 */
static void spam_load() throws Exception {
    File testFile = new File(DATASETS, "/spam/trec_data/test.trec");
    ClfDataSet source = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_DENSE, true);
    int rowCount = source.getNumDataPoints();
    int columnCount = source.getNumFeatures();

    // Target container: same dimensions as the source, two classes.
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numDataPoints(rowCount)
            .numFeatures(columnCount)
            .numClasses(2)
            .build();

    // Copy label first, then every feature value, one data point at a time.
    int[] sourceLabels = source.getLabels();
    for (int row = 0; row < rowCount; row++) {
        dataSet.addLabel(row, sourceLabels[row]);
        for (int col = 0; col < columnCount; col++) {
            double cell = source.getRow(row).get(col);
            dataSet.setFeatureValue(row, col, cell);
        }
    }

    File modelFile = new File(TMP, "/hmlgb/boosting.ser");
    HMLGradientBoosting boosting = HMLGradientBoosting.deserialize(modelFile);

    System.out.println("accuracy=" + Accuracy.accuracy(boosting, dataSet));
    System.out.println("macro-averaged:");
    MacroAveragedMeasures measures = new MacroAveragedMeasures(boosting, dataSet);
    System.out.println(measures);
}
use of edu.neu.ccs.pyramid.eval.MacroAveragedMeasures in project pyramid by cheng-li.
the class HMLGradientBoostingTest method spam_build.
/**
 * Trains a boosting model on the spam training split and serializes it.
 *
 * <p>Reads {@code /spam/trec_data/train.trec} under {@code DATASETS} as a dense
 * single-label dataset, copies it into a two-class multi-label dataset, runs 200
 * training iterations (printing the training accuracy after each one), prints the
 * elapsed time, the model, the final accuracy and macro-averaged measures, and
 * writes the model to {@code /hmlgb/boosting.ser} under {@code TMP}.
 *
 * @throws Exception whatever the loading, training or serialization calls throw
 */
static void spam_build() throws Exception {
    File trainFile = new File(DATASETS, "/spam/trec_data/train.trec");
    ClfDataSet source = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_DENSE, true);
    int rowCount = source.getNumDataPoints();
    int columnCount = source.getNumFeatures();

    // Target container: same dimensions as the source, two classes.
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numDataPoints(rowCount)
            .numFeatures(columnCount)
            .numClasses(2)
            .build();

    // Copy label first, then every feature value, one data point at a time.
    int[] sourceLabels = source.getLabels();
    for (int row = 0; row < rowCount; row++) {
        dataSet.addLabel(row, sourceLabels[row]);
        for (int col = 0; col < columnCount; col++) {
            double cell = source.getRow(row).get(col);
            dataSet.setFeatureValue(row, col, cell);
        }
    }

    // Legal assignments: one singleton label set per class, {0} then {1}.
    List<MultiLabel> assignments = new ArrayList<>();
    for (int label = 0; label < 2; label++) {
        assignments.add(new MultiLabel().addLabel(label));
    }
    HMLGradientBoosting boosting = new HMLGradientBoosting(2, assignments);

    HMLGBConfig trainConfig = new HMLGBConfig.Builder(dataSet)
            .numLeaves(7)
            .learningRate(0.1)
            .numSplitIntervals(50)
            .minDataPerLeaf(1)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    System.out.println(Arrays.toString(trainConfig.getActiveFeatures()));

    HMLGBTrainer trainer = new HMLGBTrainer(trainConfig, boosting);
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    int round = 0;
    while (round < 200) {
        System.out.println("round=" + round);
        trainer.iterate();
        System.out.println("accuracy=" + Accuracy.accuracy(boosting, dataSet));
        // Debugging aid: dump boosting.getGradients(0) / boosting.getGradients(1) here.
        round++;
    }
    stopWatch.stop();
    System.out.println(stopWatch);
    System.out.println(boosting);

    // Debugging aid: per data point, print dataSet.getMultiLabels()[i], the
    // calAssignmentScores of each assignment and boosting.predict(row).

    System.out.println("accuracy");
    System.out.println(Accuracy.accuracy(boosting, dataSet));
    System.out.println("macro-averaged:");
    MacroAveragedMeasures measures = new MacroAveragedMeasures(boosting, dataSet);
    System.out.println(measures);

    File modelFile = new File(TMP, "/hmlgb/boosting.ser");
    boosting.serialize(modelFile);
}
use of edu.neu.ccs.pyramid.eval.MacroAveragedMeasures in project pyramid by cheng-li.
the class HMLGradientBoostingTest method test4_load.
/**
 * Evaluates the previously serialized boosting model on the 4-label test split.
 *
 * <p>Loads {@code spam/4labels/test.trec} under {@code DATASETS} directly as a dense
 * multi-label dataset, deserializes the model from {@code /hmlgb/boosting.ser} under
 * {@code TMP}, and prints its accuracy and macro-averaged measures.
 *
 * @throws Exception whatever the loading, deserialization or evaluation calls throw
 */
static void test4_load() throws Exception {
    File testFile = new File(DATASETS, "spam/4labels/test.trec");
    MultiLabelClfDataSet dataSet =
            TRECFormat.loadMultiLabelClfDataSet(testFile, DataSetType.ML_CLF_DENSE, true);

    File modelFile = new File(TMP, "/hmlgb/boosting.ser");
    HMLGradientBoosting boosting = HMLGradientBoosting.deserialize(modelFile);

    System.out.println("accuracy=" + Accuracy.accuracy(boosting, dataSet));
    System.out.println("macro-averaged:");
    MacroAveragedMeasures measures = new MacroAveragedMeasures(boosting, dataSet);
    System.out.println(measures);
}
use of edu.neu.ccs.pyramid.eval.MacroAveragedMeasures in project pyramid by cheng-li.
the class HMLGradientBoostingTest method test4_build.
/**
 * Same as test3; the only difference is that the multi-label data is now loaded
 * directly from the TREC file.
 *
 * <p>Loads {@code spam/4labels/train.trec} under {@code DATASETS}, trains a 4-class
 * boosting model restricted to the assignments {0}, {1}, {1,2}, {1,3} and {1,2,3}
 * for 10 iterations (printing the training accuracy after each one), prints the
 * elapsed time, the final accuracy and macro-averaged measures, and serializes the
 * model to {@code /hmlgb/boosting.ser} under {@code TMP}.
 *
 * @throws Exception whatever the loading, training or serialization calls throw
 */
static void test4_build() throws Exception {
    File trainFile = new File(DATASETS, "spam/4labels/train.trec");
    MultiLabelClfDataSet dataSet =
            TRECFormat.loadMultiLabelClfDataSet(trainFile, DataSetType.ML_CLF_DENSE, true);

    // Each inner array is one legal assignment; labels are added in the listed order.
    int[][] labelSets = {{0}, {1}, {1, 2}, {1, 3}, {1, 2, 3}};
    List<MultiLabel> assignments = new ArrayList<>();
    for (int[] labelSet : labelSets) {
        MultiLabel assignment = new MultiLabel();
        for (int label : labelSet) {
            assignment = assignment.addLabel(label);
        }
        assignments.add(assignment);
    }
    HMLGradientBoosting boosting = new HMLGradientBoosting(4, assignments);

    HMLGBConfig trainConfig = new HMLGBConfig.Builder(dataSet)
            .numLeaves(100)
            .learningRate(0.1)
            .numSplitIntervals(1000)
            .minDataPerLeaf(2)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    System.out.println(Arrays.toString(trainConfig.getActiveFeatures()));

    HMLGBTrainer trainer = new HMLGBTrainer(trainConfig, boosting);
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    int round = 0;
    while (round < 10) {
        System.out.println("round=" + round);
        trainer.iterate();
        System.out.println("accuracy=" + Accuracy.accuracy(boosting, dataSet));
        round++;
    }
    stopWatch.stop();
    System.out.println(stopWatch);
    // Debugging aid: print the boosting model itself here.

    System.out.println("accuracy");
    System.out.println(Accuracy.accuracy(boosting, dataSet));
    System.out.println("macro-averaged:");
    MacroAveragedMeasures measures = new MacroAveragedMeasures(boosting, dataSet);
    System.out.println(measures);

    File modelFile = new File(TMP, "/hmlgb/boosting.ser");
    boosting.serialize(modelFile);
}
Aggregations