Example from the pyramid project (by cheng-li) showing use of org.apache.commons.lang3.time.StopWatch: class IMLGradientBoostingTest, method spam_build.
/**
 * Builds a 2-class multi-label dataset from the single-label spam TREC data,
 * trains an IMLGradientBoosting model for 200 rounds, prints per-round and
 * final accuracy, serializes the trained model, and reports the most frequent
 * prediction-path match for class 0.
 *
 * @throws Exception if the dataset cannot be loaded or the model cannot be serialized
 */
private static void spam_build() throws Exception {
    ClfDataSet singleLabelDataSet = TRECFormat.loadClfDataSet(
            new File(DATASETS, "/spam/trec_data/train.trec"), DataSetType.CLF_DENSE, true);
    int numDataPoints = singleLabelDataSet.getNumDataPoints();
    int numFeatures = singleLabelDataSet.getNumFeatures();
    // Re-encode the single-label problem as a 2-class multi-label problem.
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numDataPoints(numDataPoints)
            .numFeatures(numFeatures)
            .numClasses(2)
            .build();
    int[] labels = singleLabelDataSet.getLabels();
    for (int i = 0; i < numDataPoints; i++) {
        dataSet.addLabel(i, labels[i]);
        // Hoist the row lookup out of the inner loop; the original called
        // getRow(i) once per feature, i.e. numDataPoints * numFeatures times.
        org.apache.mahout.math.Vector row = singleLabelDataSet.getRow(i);
        for (int j = 0; j < numFeatures; j++) {
            dataSet.setFeatureValue(i, j, row.get(j));
        }
    }
    IMLGradientBoosting boosting = new IMLGradientBoosting(2);
    IMLGBConfig trainConfig = new IMLGBConfig.Builder(dataSet)
            .numLeaves(7)
            .learningRate(0.1)
            .numSplitIntervals(50)
            .minDataPerLeaf(3)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    System.out.println(Arrays.toString(trainConfig.getActiveFeatures()));
    IMLGBTrainer trainer = new IMLGBTrainer(trainConfig, boosting);
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    for (int round = 0; round < 200; round++) {
        System.out.println("round=" + round);
        trainer.iterate();
        // Training-set accuracy after each boosting round.
        System.out.println("accuracy=" + Accuracy.accuracy(boosting, dataSet));
    }
    stopWatch.stop();
    System.out.println(stopWatch);
    System.out.println(boosting);
    // Dump per-instance true labels vs. predictions for manual inspection.
    for (int i = 0; i < numDataPoints; i++) {
        org.apache.mahout.math.Vector featureRow = dataSet.getRow(i);
        System.out.println("" + i);
        System.out.println(dataSet.getMultiLabels()[i]);
        System.out.println(boosting.predict(featureRow));
    }
    System.out.println("accuracy");
    System.out.println(Accuracy.accuracy(boosting, dataSet));
    boosting.serialize(new File(TMP, "/imlgb/boosting.ser"));
    // Find the most frequent path match in a single pass with max() instead of
    // sorting the whole entry list; Map.Entry.comparingByValue() replaces the
    // hand-rolled comparator (Double is Comparable).
    System.out.println(IMLGBInspector.countPathMatches(boosting, dataSet, 0).entrySet().stream()
            .max(Map.Entry.comparingByValue())
            .orElseThrow(() -> new IllegalStateException("no path matches for class 0")));
}
Example from the pyramid project (by cheng-li) showing use of org.apache.commons.lang3.time.StopWatch: class MLLogisticTrainerTest, method test6.
/**
 * Trains an MLLogisticRegression model on the ohsumed/3 multi-label dataset,
 * timing the training run, then reports accuracy and overlap on both the
 * training set and the held-out test set.
 */
static void test6() throws Exception {
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet heldOutSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "ohsumed/3/test.trec"), DataSetType.ML_CLF_SPARSE, true);
    // Candidate label combinations are gathered from the training data.
    List<MultiLabel> labelCombinations = DataSetUtil.gatherMultiLabels(trainSet);
    MLLogisticTrainer logisticTrainer = MLLogisticTrainer.getBuilder()
            .setGaussianPriorVariance(1)
            .build();
    StopWatch timer = new StopWatch();
    timer.start();
    MLLogisticRegression model = logisticTrainer.train(trainSet, labelCombinations);
    System.out.println(timer);
    System.out.println("training accuracy=" + Accuracy.accuracy(model, trainSet));
    System.out.println("training overlap = " + Overlap.overlap(model, trainSet));
    System.out.println("test accuracy=" + Accuracy.accuracy(model, heldOutSet));
    System.out.println("test overlap = " + Overlap.overlap(model, heldOutSet));
}
Example from the pyramid project (by cheng-li) showing use of org.apache.commons.lang3.time.StopWatch: class MLPlattScalingTest, method test1.
/**
 * Trains an IMLGradientBoosting model on ohsumed/3 for 100 rounds, then
 * fits Platt scaling on the training set and, for the first ten instances,
 * prints raw class scores, the boosting probabilities, and the Platt-scaled
 * probabilities side by side for comparison.
 */
private static void test1() throws Exception {
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    IMLGradientBoosting gb = new IMLGradientBoosting(trainSet.getNumClasses());
    // Restrict predictions to label combinations observed in the training data.
    gb.setAssignments(DataSetUtil.gatherMultiLabels(trainSet));
    IMLGBConfig gbConfig = new IMLGBConfig.Builder(trainSet)
            .numLeaves(2)
            .learningRate(0.1)
            .numSplitIntervals(1000)
            .minDataPerLeaf(2)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    IMLGBTrainer gbTrainer = new IMLGBTrainer(gbConfig, gb);
    StopWatch timer = new StopWatch();
    timer.start();
    for (int round = 0; round < 100; round++) {
        System.out.println("round=" + round);
        gbTrainer.iterate();
        // Print cumulative elapsed time after each round.
        System.out.println(timer);
    }
    MLPlattScaling plattScaling = new MLPlattScaling(trainSet, gb);
    for (int i = 0; i < 10; i++) {
        org.apache.mahout.math.Vector instance = trainSet.getRow(i);
        System.out.println(Arrays.toString(gb.predictClassScores(instance)));
        System.out.println(Arrays.toString(gb.predictClassProbs(instance)));
        System.out.println(Arrays.toString(plattScaling.predictClassProbs(instance)));
        System.out.println("======================");
    }
}
Example from the pyramid project (by cheng-li) showing use of org.apache.commons.lang3.time.StopWatch: class AdaBoostMHTest, method test1.
/**
 * Trains an AdaBoost.MH classifier on the ohsumed/3 multi-label data for 500
 * rounds, printing elapsed time each round, then reports accuracy and overlap
 * on both the training and test sets.
 */
static void test1() throws Exception {
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet evalSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "ohsumed/3/test.trec"), DataSetType.ML_CLF_SPARSE, true);
    AdaBoostMH classifier = new AdaBoostMH(trainSet.getNumClasses());
    AdaBoostMHTrainer mhTrainer = new AdaBoostMHTrainer(trainSet, classifier);
    StopWatch timer = new StopWatch();
    timer.start();
    final int totalRounds = 500;
    for (int round = 0; round < totalRounds; round++) {
        System.out.println("round=" + round);
        mhTrainer.iterate();
        // Cumulative elapsed time after each boosting round.
        System.out.println(timer);
    }
    System.out.println("training accuracy=" + Accuracy.accuracy(classifier, trainSet));
    System.out.println("training overlap = " + Overlap.overlap(classifier, trainSet));
    System.out.println("test accuracy=" + Accuracy.accuracy(classifier, evalSet));
    System.out.println("test overlap = " + Overlap.overlap(classifier, evalSet));
}
Example from the pyramid project (by cheng-li) showing use of org.apache.commons.lang3.time.StopWatch: class App2, method train.
/**
 * Trains an IMLGradientBoosting model driven entirely by the given Config:
 * loads the training (and optionally test) data, runs up to
 * train.numIterations boosting rounds with optional per-label early stopping,
 * serializes the final model, and writes per-class top-feature reports
 * (JSON and plain text) to the output folder.
 *
 * @param config configuration holding all train.* / input.* / output.* / report.* keys read below
 * @param logger destination for all progress and result messages
 * @throws Exception if data loading, model (de)serialization, or report writing fails
 */
static void train(Config config, Logger logger) throws Exception {
String output = config.getString("output.folder");
int numIterations = config.getInt("train.numIterations");
int numLeaves = config.getInt("train.numLeaves");
double learningRate = config.getDouble("train.learningRate");
int minDataPerLeaf = config.getInt("train.minDataPerLeaf");
// NOTE(review): hard-coded name says "app3" although this is App2.train — confirm intended.
String modelName = "model_app3";
// double featureSamplingRate = config.getDouble("train.featureSamplingRate");
// double dataSamplingRate = config.getDouble("train.dataSamplingRate");
StopWatch stopWatch = new StopWatch();
stopWatch.start();
MultiLabelClfDataSet dataSet = loadData(config, config.getString("input.trainData"));
// Test data is only loaded when test-set progress reporting is requested.
MultiLabelClfDataSet testSet = null;
if (config.getBoolean("train.showTestProgress")) {
testSet = loadData(config, config.getString("input.testData"));
}
int numClasses = dataSet.getNumClasses();
logger.info("number of class = " + numClasses);
IMLGBConfig imlgbConfig = new IMLGBConfig.Builder(dataSet).learningRate(learningRate).minDataPerLeaf(minDataPerLeaf).numLeaves(numLeaves).numSplitIntervals(config.getInt("train.numSplitIntervals")).usePrior(config.getBoolean("train.usePrior")).build();
IMLGradientBoosting boosting;
// Warm start resumes from a previously serialized model in the output folder.
if (config.getBoolean("train.warmStart")) {
boosting = IMLGradientBoosting.deserialize(new File(output, modelName));
} else {
boosting = new IMLGradientBoosting(numClasses);
}
logger.info("During training, the performance is reported using Hamming loss optimal predictor");
logger.info("initialing trainer");
IMLGBTrainer trainer = new IMLGBTrainer(imlgbConfig, boosting);
boolean earlyStop = config.getBoolean("train.earlyStop");
// One EarlyStopper and one Terminator per label; either one firing stops that label.
List<EarlyStopper> earlyStoppers = new ArrayList<>();
List<Terminator> terminators = new ArrayList<>();
if (earlyStop) {
for (int l = 0; l < numClasses; l++) {
EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, config.getInt("train.earlyStop.patience"));
earlyStopper.setMinimumIterations(config.getInt("train.earlyStop.minIterations"));
earlyStoppers.add(earlyStopper);
}
for (int l = 0; l < numClasses; l++) {
Terminator terminator = new Terminator();
// Terminator sees one value per progress interval, so its minIterations is scaled down accordingly.
terminator.setMaxStableIterations(config.getInt("train.earlyStop.patience")).setMinIterations(config.getInt("train.earlyStop.minIterations") / config.getInt("train.showProgress.interval")).setAbsoluteEpsilon(config.getDouble("train.earlyStop.absoluteChange")).setRelativeEpsilon(config.getDouble("train.earlyStop.relativeChange")).setOperation(Terminator.Operation.OR);
terminators.add(terminator);
}
}
logger.info("trainer initialized");
int numLabelsLeftToTrain = numClasses;
int progressInterval = config.getInt("train.showProgress.interval");
for (int i = 1; i <= numIterations; i++) {
logger.info("iteration " + i);
trainer.iterate();
// Progress is reported every progressInterval iterations and on the final iteration.
if (config.getBoolean("train.showTrainProgress") && (i % progressInterval == 0 || i == numIterations)) {
logger.info("training set performance");
logger.info(new MLMeasures(boosting, dataSet).toString());
}
if (config.getBoolean("train.showTestProgress") && (i % progressInterval == 0 || i == numIterations)) {
logger.info("test set performance");
logger.info(new MLMeasures(boosting, testSet).toString());
// Early stopping is evaluated only at test-progress checkpoints, per label.
if (earlyStop) {
for (int l = 0; l < numClasses; l++) {
EarlyStopper earlyStopper = earlyStoppers.get(l);
Terminator terminator = terminators.get(l);
// Only labels still in training are monitored.
if (!trainer.getShouldStop()[l]) {
// Per-label KL divergence on the test set is the early-stopping signal.
double kl = KL(boosting, testSet, l);
earlyStopper.add(i, kl);
terminator.add(kl);
if (earlyStopper.shouldStop() || terminator.shouldTerminate()) {
logger.info("training for label " + l + " (" + dataSet.getLabelTranslator().toExtLabel(l) + ") should stop now");
logger.info("the best number of training iterations for the label is " + earlyStopper.getBestIteration());
trainer.setShouldStop(l);
numLabelsLeftToTrain -= 1;
logger.info("the number of labels left to be trained on = " + numLabelsLeftToTrain);
}
}
}
}
}
// Stop the whole run once every label has been stopped individually.
if (numLabelsLeftToTrain == 0) {
logger.info("all label training finished");
break;
}
}
logger.info("training done");
File serializedModel = new File(output, modelName);
//todo pick best models
boosting.serialize(serializedModel);
logger.info(stopWatch.toString());
if (earlyStop) {
for (int l = 0; l < numClasses; l++) {
logger.info("----------------------------------------------------");
logger.info("test performance history for label " + l + ": " + earlyStoppers.get(l).history());
// Size excludes the prior/constant regressor, hence the -1.
logger.info("model size for label " + l + " = " + (boosting.getRegressors(l).size() - 1));
}
}
// NOTE(review): always true — the flag only documents intent; consider a config key.
boolean topFeaturesToFile = true;
if (topFeaturesToFile) {
logger.info("start writing top features");
int limit = config.getInt("report.topFeatures.limit");
List<TopFeatures> topFeaturesList = IntStream.range(0, boosting.getNumClasses()).mapToObj(k -> IMLGBInspector.topFeatures(boosting, k, limit)).collect(Collectors.toList());
ObjectMapper mapper = new ObjectMapper();
String file = "top_features.json";
mapper.writeValue(new File(output, file), topFeaturesList);
// Also build a human-readable per-label summary alongside the JSON report.
StringBuilder sb = new StringBuilder();
for (int l = 0; l < boosting.getNumClasses(); l++) {
sb.append("-------------------------").append("\n");
sb.append(dataSet.getLabelTranslator().toExtLabel(l)).append(":").append("\n");
for (Feature feature : topFeaturesList.get(l).getTopFeatures()) {
sb.append(feature.simpleString()).append(", ");
}
sb.append("\n");
}
FileUtils.writeStringToFile(new File(output, "top_features.txt"), sb.toString());
logger.info("finish writing top features");
}
}
Aggregations