use of org.apache.commons.lang3.time.StopWatch in project pyramid by cheng-li.
From the class GBCBMOptimizer, method updateBinaryClassifier.
@Override
protected void updateBinaryClassifier(int component, int label, MultiLabelClfDataSet activeDataset, double[] activeGammas) {
    // Time this component/label update; the elapsed time is reported at debug level.
    StopWatch timer = new StopWatch();
    timer.start();

    // Lazily create the binary booster; a PriorProbClassifier placeholder is also replaced.
    boolean needsFreshBooster = cbm.binaryClassifiers[component][label] == null
            || cbm.binaryClassifiers[component][label] instanceof PriorProbClassifier;
    if (needsFreshBooster) {
        cbm.binaryClassifiers[component][label] = new LKBoost(2);
    }

    // Project the multi-label targets onto {0,1} for this label, then expand to
    // per-class target distributions for the two-class booster.
    int[] binaryLabels = DataSetUtil.toBinaryLabels(activeDataset.getMultiLabels(), label);
    double[][] targetDistributions = DataSetUtil.labelsToDistributions(binaryLabels, 2);

    LKBoost booster = (LKBoost) this.cbm.binaryClassifiers[component][label];

    // Regression trees serve as the weak learners; leaf outputs come from LKB calculus.
    RegTreeConfig treeConfig = new RegTreeConfig().setMaxNumLeaves(numLeaves);
    RegTreeFactory treeFactory = new RegTreeFactory(treeConfig);
    treeFactory.setLeafOutputCalculator(new LKBOutputCalculator(2));

    LKBoostOptimizer boosterOptimizer =
            new LKBoostOptimizer(booster, activeDataset, treeFactory, activeGammas, targetDistributions);
    boosterOptimizer.setShrinkage(shrinkage);
    boosterOptimizer.initialize();
    boosterOptimizer.iterate(binaryUpdatesPerIter);

    if (logger.isDebugEnabled()) {
        logger.debug("time spent on updating component " + component + " label " + label + " = " + timer);
    }
}
use of org.apache.commons.lang3.time.StopWatch in project pyramid by cheng-li.
From the class SparseCBMOptimzer, method updateBinaryLogisticRegression.
private void updateBinaryLogisticRegression(int componentIndex, int labelIndex) {
    // Time the whole update; the elapsed time is included in the console report.
    StopWatch timer = new StopWatch();
    timer.start();

    double effectivePositives = effectivePositives(componentIndex, labelIndex);

    StringBuilder report = new StringBuilder()
            .append("for component ").append(componentIndex)
            .append(", label ").append(labelIndex)
            .append(", effective positives = ").append(effectivePositives);

    // Too few (weighted) positive examples: fall back to the label's prior probability
    // instead of fitting a logistic regression.
    if (effectivePositives <= 1) {
        double positiveProb = prior(componentIndex, labelIndex);
        cbm.binaryClassifiers[componentIndex][labelIndex] =
                new PriorProbClassifier(new double[]{1 - positiveProb, positiveProb});
        report.append(", skip, use prior = ").append(positiveProb)
                .append(", time spent = ").append(timer.toString());
        System.out.println(report.toString());
        return;
    }

    // Replace a missing or prior-only classifier with a trainable logistic regression.
    if (cbm.binaryClassifiers[componentIndex][labelIndex] == null
            || cbm.binaryClassifiers[componentIndex][labelIndex] instanceof PriorProbClassifier) {
        cbm.binaryClassifiers[componentIndex][labelIndex] =
                new LogisticRegression(2, dataSet.getNumFeatures());
    }

    int[] binaryLabels = DataSetUtil.toBinaryLabels(dataSet.getMultiLabels(), labelIndex);
    // no parallelism
    RidgeLogisticOptimizer ridgeLogisticOptimizer = new RidgeLogisticOptimizer(
            (LogisticRegression) cbm.binaryClassifiers[componentIndex][labelIndex],
            dataSet, binaryLabels, activeGammas, priorVarianceBinary, false);
    //TODO maximum iterations
    ridgeLogisticOptimizer.getOptimizer().getTerminator().setMaxIteration(numBinaryUpdates);
    ridgeLogisticOptimizer.optimize();

    report.append(", time spent = ").append(timer.toString());
    System.out.println(report.toString());
}
use of org.apache.commons.lang3.time.StopWatch in project pyramid by cheng-li.
From the class IMLGradientBoostingTest, method test2_build.
/**
* add a fake label in spam data set, if x=spam and x_0<0.1, also label it as 2
* @throws Exception
*/
static void test2_build() throws Exception {
    // Build a 3-class multi-label data set from the binary spam data:
    // every point keeps its original label; spam points (label 1) whose
    // feature 0 is below 0.1 additionally receive label 2.
    ClfDataSet singleLabeldataSet = TRECFormat.loadClfDataSet(
            new File(DATASETS, "/spam/trec_data/train.trec"), DataSetType.CLF_DENSE, true);
    int numDataPoints = singleLabeldataSet.getNumDataPoints();
    int numFeatures = singleLabeldataSet.getNumFeatures();
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numDataPoints(numDataPoints)
            .numFeatures(numFeatures)
            .numClasses(3)
            .build();
    int[] labels = singleLabeldataSet.getLabels();
    for (int row = 0; row < numDataPoints; row++) {
        dataSet.addLabel(row, labels[row]);
        if (labels[row] == 1 && singleLabeldataSet.getRow(row).get(0) < 0.1) {
            dataSet.addLabel(row, 2);
        }
        for (int col = 0; col < numFeatures; col++) {
            dataSet.setFeatureValue(row, col, singleLabeldataSet.getRow(row).get(col));
        }
    }

    IMLGradientBoosting boosting = new IMLGradientBoosting(3);
    IMLGBConfig trainConfig = new IMLGBConfig.Builder(dataSet)
            .numLeaves(60)
            .learningRate(0.1)
            .numSplitIntervals(1000)
            .minDataPerLeaf(2)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    System.out.println(Arrays.toString(trainConfig.getActiveFeatures()));
    IMLGBTrainer trainer = new IMLGBTrainer(trainConfig, boosting);

    // Train for 20 boosting rounds, printing training accuracy each round.
    StopWatch timer = new StopWatch();
    timer.start();
    for (int round = 0; round < 20; round++) {
        System.out.println("round=" + round);
        trainer.iterate();
        System.out.println("accuracy=" + Accuracy.accuracy(boosting, dataSet));
    }
    timer.stop();
    System.out.println(timer);
    System.out.println(boosting);

    // Run prediction over every training point (kept from the original code;
    // the per-point results are not asserted here).
    for (int row = 0; row < numDataPoints; row++) {
        Vector featureRow = dataSet.getRow(row);
        MultiLabel label = dataSet.getMultiLabels()[row];
        MultiLabel prediction = boosting.predict(featureRow);
    }

    System.out.println("accuracy");
    System.out.println(Accuracy.accuracy(boosting, dataSet));
    boosting.serialize(new File(TMP, "imlgb/boosting.ser"));
}
use of org.apache.commons.lang3.time.StopWatch in project pyramid by cheng-li.
From the class IMLGradientBoostingTest, method test3_build.
/**
* add 2 fake labels in spam data set,
* if x=spam and x_0<0.1, also label it as 2
* if x=spam and x_1<0.1, also label it as 3
* @throws Exception
*/
static void test3_build() throws Exception {
    // Build a 4-class multi-label data set from the binary spam data:
    // every point keeps its original label; spam points (label 1) additionally
    // receive label 2 when feature 0 < 0.1 and label 3 when feature 1 < 0.1.
    ClfDataSet singleLabeldataSet = TRECFormat.loadClfDataSet(
            new File(DATASETS, "spam/trec_data/train.trec"), DataSetType.CLF_DENSE, true);
    int numDataPoints = singleLabeldataSet.getNumDataPoints();
    int numFeatures = singleLabeldataSet.getNumFeatures();
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numDataPoints(numDataPoints)
            .numFeatures(numFeatures)
            .numClasses(4)
            .build();
    int[] labels = singleLabeldataSet.getLabels();
    for (int row = 0; row < numDataPoints; row++) {
        dataSet.addLabel(row, labels[row]);
        if (labels[row] == 1 && singleLabeldataSet.getRow(row).get(0) < 0.1) {
            dataSet.addLabel(row, 2);
        }
        if (labels[row] == 1 && singleLabeldataSet.getRow(row).get(1) < 0.1) {
            dataSet.addLabel(row, 3);
        }
        for (int col = 0; col < numFeatures; col++) {
            dataSet.setFeatureValue(row, col, singleLabeldataSet.getRow(row).get(col));
        }
    }

    IMLGradientBoosting boosting = new IMLGradientBoosting(4);
    // Restrict predictions to label combinations observed in the training data.
    List<MultiLabel> assignments = DataSetUtil.gatherMultiLabels(dataSet);
    boosting.setAssignments(assignments);
    IMLGBConfig trainConfig = new IMLGBConfig.Builder(dataSet)
            .numLeaves(10)
            .learningRate(0.1)
            .numSplitIntervals(1000)
            .minDataPerLeaf(2)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    System.out.println(Arrays.toString(trainConfig.getActiveFeatures()));
    IMLGBTrainer trainer = new IMLGBTrainer(trainConfig, boosting);

    // Train for 100 boosting rounds, printing training accuracy each round.
    StopWatch timer = new StopWatch();
    timer.start();
    for (int round = 0; round < 100; round++) {
        System.out.println("round=" + round);
        trainer.iterate();
        System.out.println("accuracy=" + Accuracy.accuracy(boosting, dataSet));
    }
    timer.stop();
    System.out.println(timer);

    // Report every training point the model still misclassifies.
    for (int row = 0; row < numDataPoints; row++) {
        Vector featureRow = dataSet.getRow(row);
        MultiLabel label = dataSet.getMultiLabels()[row];
        MultiLabel prediction = boosting.predict(featureRow);
        if (!label.equals(prediction)) {
            System.out.println(row);
            System.out.println("label=" + label);
            System.out.println("prediction=" + prediction);
        }
    }

    System.out.println("accuracy");
    System.out.println(Accuracy.accuracy(boosting, dataSet));
    System.out.println("overlap = " + Overlap.overlap(boosting, dataSet));
    boosting.serialize(new File(TMP, "/imlgb/boosting.ser"));
}
use of org.apache.commons.lang3.time.StopWatch in project pyramid by cheng-li.
From the class IMLLogisticRegressionTest, method test1.
static void test1() throws Exception {
    // Load the ohsumed/3 train and test splits in sparse multi-label TREC format.
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "ohsumed/3/test.trec"), DataSetType.ML_CLF_SPARSE, true);

    // Candidate label combinations are those observed in the training data.
    List<MultiLabel> assignments = DataSetUtil.gatherMultiLabels(dataSet);

    IMLLogisticTrainer trainer = IMLLogisticTrainer.getBuilder()
            .setEpsilon(0.01)
            .setGaussianPriorVariance(1)
            .setHistory(5)
            .build();

    // Train, timing the fit, then report accuracy/overlap on both splits.
    StopWatch timer = new StopWatch();
    timer.start();
    IMLLogisticRegression logisticRegression = trainer.train(dataSet, assignments);
    System.out.println(timer);
    System.out.println("training accuracy=" + Accuracy.accuracy(logisticRegression, dataSet));
    System.out.println("training overlap = " + Overlap.overlap(logisticRegression, dataSet));
    System.out.println("test accuracy=" + Accuracy.accuracy(logisticRegression, testSet));
    System.out.println("test overlap = " + Overlap.overlap(logisticRegression, testSet));
}
Aggregations