Use of org.apache.commons.lang3.time.StopWatch in project pyramid by cheng-li.
The class MLFlatScalingTest, method test1.
private static void test1() throws Exception {
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/train.trec"),
            DataSetType.ML_CLF_SPARSE, true);
    IMLGradientBoosting boosting = new IMLGradientBoosting(dataSet.getNumClasses());
    List<MultiLabel> assignments = DataSetUtil.gatherMultiLabels(dataSet);
    boosting.setAssignments(assignments);
    IMLGBConfig trainConfig = new IMLGBConfig.Builder(dataSet)
            .numLeaves(2).learningRate(0.1).numSplitIntervals(1000)
            .minDataPerLeaf(2).dataSamplingRate(1).featureSamplingRate(1)
            .build();
    IMLGBTrainer trainer = new IMLGBTrainer(trainConfig, boosting);
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    for (int round = 0; round < 100; round++) {
        System.out.println("round=" + round);
        trainer.iterate();
        System.out.println(stopWatch);
    }
    MLFlatScaling scaling = new MLFlatScaling(dataSet, boosting);
    for (int i = 0; i < 10; i++) {
        System.out.println(Arrays.toString(boosting.predictClassScores(dataSet.getRow(i))));
        System.out.println(Arrays.toString(boosting.predictClassProbs(dataSet.getRow(i))));
        System.out.println(Arrays.toString(scaling.predictClassProbs(dataSet.getRow(i))));
        System.out.println("======================");
    }
}
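The loop above prints the cumulative StopWatch every round. A minimal, self-contained sketch of per-round split timing with the same commons-lang3 StopWatch API (the Thread.sleep call is only a stand-in for trainer.iterate(); the class name and round count are placeholders):

import org.apache.commons.lang3.time.StopWatch;

public class SplitTimingSketch {
    public static void main(String[] args) throws InterruptedException {
        StopWatch stopWatch = new StopWatch();
        stopWatch.start();
        long previousSplit = 0;
        for (int round = 0; round < 5; round++) {
            Thread.sleep(100); // stand-in for trainer.iterate()
            stopWatch.split();
            long elapsed = stopWatch.getSplitTime(); // milliseconds elapsed since start()
            System.out.println("round=" + round
                    + " total=" + stopWatch.toSplitString()
                    + " this round=" + (elapsed - previousSplit) + "ms");
            previousSplit = elapsed;
        }
        stopWatch.stop();
        System.out.println("total: " + stopWatch);
    }
}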
Use of org.apache.commons.lang3.time.StopWatch in project pyramid by cheng-li.
The class HMLGradientBoostingTest, method test5.
static void test5() throws Exception {
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/train.trec"),
            DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/test.trec"),
            DataSetType.ML_CLF_SPARSE, true);
    List<MultiLabel> assignments = DataSetUtil.gatherMultiLabels(dataSet);
    HMLGradientBoosting boosting = new HMLGradientBoosting(dataSet.getNumClasses(), assignments);
    HMLGBConfig trainConfig = new HMLGBConfig.Builder(dataSet)
            .numLeaves(2).learningRate(0.1).numSplitIntervals(1000)
            .minDataPerLeaf(2).dataSamplingRate(1).featureSamplingRate(1)
            .build();
    HMLGBTrainer trainer = new HMLGBTrainer(trainConfig, boosting);
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    for (int round = 0; round < 100; round++) {
        System.out.println("round=" + round);
        trainer.iterate();
        System.out.println(stopWatch);
    }
    System.out.println("training accuracy=" + Accuracy.accuracy(boosting, dataSet));
    System.out.println("training overlap=" + Overlap.overlap(boosting, dataSet));
    System.out.println("test accuracy=" + Accuracy.accuracy(boosting, testSet));
    System.out.println("test overlap=" + Overlap.overlap(boosting, testSet));
}
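test5 runs a fixed 100 boosting rounds. A hedged sketch of a time-budgeted variant of the same loop, using StopWatch.getTime() and reusing the trainer built above (the 60-second budget is an arbitrary example, not something the original test uses):

// Sketch: train until either 100 rounds or a wall-clock budget is reached.
StopWatch stopWatch = new StopWatch();
stopWatch.start();
long budgetMillis = 60_000; // arbitrary example budget
int round = 0;
while (round < 100 && stopWatch.getTime() < budgetMillis) {
    System.out.println("round=" + round);
    trainer.iterate();
    System.out.println(stopWatch);
    round++;
}
stopWatch.stop();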
Use of org.apache.commons.lang3.time.StopWatch in project pyramid by cheng-li.
The class HMLGradientBoostingTest, method spam_build.
static void spam_build() throws Exception {
    ClfDataSet singleLabeldataSet = TRECFormat.loadClfDataSet(new File(DATASETS, "/spam/trec_data/train.trec"),
            DataSetType.CLF_DENSE, true);
    int numDataPoints = singleLabeldataSet.getNumDataPoints();
    int numFeatures = singleLabeldataSet.getNumFeatures();
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numDataPoints(numDataPoints).numFeatures(numFeatures).numClasses(2)
            .build();
    // copy labels and feature values from the single-label data set into the multi-label data set
    int[] labels = singleLabeldataSet.getLabels();
    for (int i = 0; i < numDataPoints; i++) {
        dataSet.addLabel(i, labels[i]);
        for (int j = 0; j < numFeatures; j++) {
            double value = singleLabeldataSet.getRow(i).get(j);
            dataSet.setFeatureValue(i, j, value);
        }
    }
    // legal label assignments: either non-spam (0) or spam (1)
    List<MultiLabel> assignments = new ArrayList<>();
    assignments.add(new MultiLabel().addLabel(0));
    assignments.add(new MultiLabel().addLabel(1));
    HMLGradientBoosting boosting = new HMLGradientBoosting(2, assignments);
    HMLGBConfig trainConfig = new HMLGBConfig.Builder(dataSet)
            .numLeaves(7).learningRate(0.1).numSplitIntervals(50)
            .minDataPerLeaf(1).dataSamplingRate(1).featureSamplingRate(1)
            .build();
    System.out.println(Arrays.toString(trainConfig.getActiveFeatures()));
    HMLGBTrainer trainer = new HMLGBTrainer(trainConfig, boosting);
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    for (int round = 0; round < 200; round++) {
        System.out.println("round=" + round);
        trainer.iterate();
        System.out.println("accuracy=" + Accuracy.accuracy(boosting, dataSet));
        // System.out.println(Arrays.toString(boosting.getGradients(0)));
        // System.out.println(Arrays.toString(boosting.getGradients(1)));
    }
    stopWatch.stop();
    System.out.println(stopWatch);
    System.out.println(boosting);
    // for (int i = 0; i < numDataPoints; i++) {
    //     FeatureRow featureRow = dataSet.getRow(i);
    //     System.out.println("label=" + dataSet.getMultiLabels()[i]);
    //     System.out.println(boosting.calAssignmentScores(featureRow, assignments.get(0)));
    //     System.out.println(boosting.calAssignmentScores(featureRow, assignments.get(1)));
    //     System.out.println(boosting.predict(featureRow));
    // }
    System.out.println("accuracy");
    System.out.println(Accuracy.accuracy(boosting, dataSet));
    System.out.println("macro-averaged:");
    System.out.println(new MacroAveragedMeasures(boosting, dataSet));
    boosting.serialize(new File(TMP, "/hmlgb/boosting.ser"));
}
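The timed loop above includes the per-round Accuracy.accuracy call in the StopWatch total. A hedged sketch of excluding evaluation time with StopWatch.suspend()/resume(), reusing the trainer, boosting, and dataSet built above:

StopWatch stopWatch = new StopWatch();
stopWatch.start();
for (int round = 0; round < 200; round++) {
    trainer.iterate();
    stopWatch.suspend(); // pause the clock while evaluating
    System.out.println("round=" + round + " accuracy=" + Accuracy.accuracy(boosting, dataSet));
    stopWatch.resume(); // resume timing before the next boosting round
}
stopWatch.stop();
System.out.println("training time (excluding evaluation): " + stopWatch);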
Use of org.apache.commons.lang3.time.StopWatch in project pyramid by cheng-li.
The class HMLGradientBoostingTest, method test3_build.
/**
 * Adds 2 fake labels to the spam data set:
 * if x is spam and x_0 < 0.1, also label it as 2;
 * if x is spam and x_1 < 0.1, also label it as 3.
 * @throws Exception
 */
static void test3_build() throws Exception {
    ClfDataSet singleLabeldataSet = TRECFormat.loadClfDataSet(new File(DATASETS, "spam/trec_data/train.trec"),
            DataSetType.CLF_DENSE, true);
    int numDataPoints = singleLabeldataSet.getNumDataPoints();
    int numFeatures = singleLabeldataSet.getNumFeatures();
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numDataPoints(numDataPoints).numFeatures(numFeatures).numClasses(4)
            .build();
    int[] labels = singleLabeldataSet.getLabels();
    for (int i = 0; i < numDataPoints; i++) {
        dataSet.addLabel(i, labels[i]);
        // add the two fake labels described in the javadoc
        if (labels[i] == 1 && singleLabeldataSet.getRow(i).get(0) < 0.1) {
            dataSet.addLabel(i, 2);
        }
        if (labels[i] == 1 && singleLabeldataSet.getRow(i).get(1) < 0.1) {
            dataSet.addLabel(i, 3);
        }
        for (int j = 0; j < numFeatures; j++) {
            double value = singleLabeldataSet.getRow(i).get(j);
            dataSet.setFeatureValue(i, j, value);
        }
    }
    List<String> extLabels = new ArrayList<>();
    extLabels.add("non_spam");
    extLabels.add("spam");
    extLabels.add("fake2");
    extLabels.add("fake3");
    LabelTranslator labelTranslator = new LabelTranslator(extLabels);
    dataSet.setLabelTranslator(labelTranslator);
    // legal label assignments for the hierarchical multi-label boosting
    List<MultiLabel> assignments = new ArrayList<>();
    assignments.add(new MultiLabel().addLabel(0));
    assignments.add(new MultiLabel().addLabel(1));
    assignments.add(new MultiLabel().addLabel(1).addLabel(2));
    assignments.add(new MultiLabel().addLabel(1).addLabel(3));
    assignments.add(new MultiLabel().addLabel(1).addLabel(2).addLabel(3));
    HMLGradientBoosting boosting = new HMLGradientBoosting(4, assignments);
    HMLGBConfig trainConfig = new HMLGBConfig.Builder(dataSet)
            .numLeaves(2).learningRate(0.1).numSplitIntervals(1000)
            .minDataPerLeaf(2).dataSamplingRate(1).featureSamplingRate(1)
            .build();
    System.out.println(Arrays.toString(trainConfig.getActiveFeatures()));
    HMLGBTrainer trainer = new HMLGBTrainer(trainConfig, boosting);
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    for (int round = 0; round < 100; round++) {
        System.out.println("round=" + round);
        trainer.iterate();
        System.out.println("accuracy=" + Accuracy.accuracy(boosting, dataSet));
        // System.out.println(Arrays.toString(boosting.getGradients(0)));
        // System.out.println(Arrays.toString(boosting.getGradients(1)));
        // System.out.println(Arrays.toString(boosting.getGradients(2)));
        // System.out.println(Arrays.toString(boosting.getGradients(3)));
    }
    stopWatch.stop();
    System.out.println(stopWatch);
    // System.out.println(boosting);
    // print every data point whose prediction disagrees with its label
    for (int i = 0; i < numDataPoints; i++) {
        Vector featureRow = dataSet.getRow(i);
        MultiLabel label = dataSet.getMultiLabels()[i];
        MultiLabel prediction = boosting.predict(featureRow);
        // System.out.println("prediction=" + prediction);
        if (!label.equals(prediction)) {
            System.out.println(i);
            System.out.println("label=" + label);
            System.out.println("prediction=" + prediction);
        }
    }
    System.out.println("accuracy");
    System.out.println(Accuracy.accuracy(boosting, dataSet));
    boosting.serialize(new File(TMP, "/hmlgb/boosting.ser"));
}
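Instead of printing every mismatch, the same inspection loop can simply count them; a minimal sketch using only the calls already shown above:

int numMismatches = 0;
for (int i = 0; i < numDataPoints; i++) {
    MultiLabel label = dataSet.getMultiLabels()[i];
    MultiLabel prediction = boosting.predict(dataSet.getRow(i));
    if (!label.equals(prediction)) {
        numMismatches++;
    }
}
System.out.println("mismatched data points: " + numMismatches + " out of " + numDataPoints);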
Use of org.apache.commons.lang3.time.StopWatch in project pyramid by cheng-li.
The class HMLGradientBoostingTest, method test4_build.
/**
 * Same as test3; the only difference is that the multi-label data set is loaded directly from TREC format.
 * @throws Exception
 */
static void test4_build() throws Exception {
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "spam/4labels/train.trec"),
            DataSetType.ML_CLF_DENSE, true);
    List<MultiLabel> assignments = new ArrayList<>();
    assignments.add(new MultiLabel().addLabel(0));
    assignments.add(new MultiLabel().addLabel(1));
    assignments.add(new MultiLabel().addLabel(1).addLabel(2));
    assignments.add(new MultiLabel().addLabel(1).addLabel(3));
    assignments.add(new MultiLabel().addLabel(1).addLabel(2).addLabel(3));
    HMLGradientBoosting boosting = new HMLGradientBoosting(4, assignments);
    HMLGBConfig trainConfig = new HMLGBConfig.Builder(dataSet)
            .numLeaves(100).learningRate(0.1).numSplitIntervals(1000)
            .minDataPerLeaf(2).dataSamplingRate(1).featureSamplingRate(1)
            .build();
    System.out.println(Arrays.toString(trainConfig.getActiveFeatures()));
    HMLGBTrainer trainer = new HMLGBTrainer(trainConfig, boosting);
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    for (int round = 0; round < 10; round++) {
        System.out.println("round=" + round);
        trainer.iterate();
        System.out.println("accuracy=" + Accuracy.accuracy(boosting, dataSet));
    }
    stopWatch.stop();
    System.out.println(stopWatch);
    // System.out.println(boosting);
    System.out.println("accuracy");
    System.out.println(Accuracy.accuracy(boosting, dataSet));
    System.out.println("macro-averaged:");
    System.out.println(new MacroAveragedMeasures(boosting, dataSet));
    boosting.serialize(new File(TMP, "/hmlgb/boosting.ser"));
}
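test4_build reports only training accuracy. A hedged sketch of held-out evaluation in the style of test5, reusing the calls shown above; the spam/4labels/test.trec path is an assumption and may not exist in the actual data set layout:

// Hypothetical held-out evaluation; the test.trec path is assumed, not confirmed by the snippets above.
MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(
        new File(DATASETS, "spam/4labels/test.trec"), DataSetType.ML_CLF_DENSE, true);
System.out.println("test accuracy=" + Accuracy.accuracy(boosting, testSet));
System.out.println("test overlap=" + Overlap.overlap(boosting, testSet));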