Example usage of edu.neu.ccs.pyramid.optimization.EarlyStopper in the pyramid project (cheng-li): class Calibration, method trainCalibrator.
/**
 * Trains an LSBoost calibrator on the calibration set, using the validation set
 * to drive early stopping on MSE.
 *
 * @param calib        regression data the calibrator is fitted on
 * @param valid        regression data used to monitor validation MSE
 * @param monotonicity per-feature monotonicity constraints, applied to the single output
 * @return a deep-copied snapshot of the model at the best validation MSE seen
 *         (may be null only if every snapshot attempt failed to serialize)
 */
private static LSBoost trainCalibrator(RegDataSet calib, RegDataSet valid, int[] monotonicity) {
    LSBoost lsBoost = new LSBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(10).setMinDataPerLeaf(5);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSBoostOptimizer optimizer = new LSBoostOptimizer(lsBoost, calib, regTreeFactory, calib.getLabels());
    // Single output dimension; each column carries the caller-supplied constraint.
    // (The original wrapped this in a dead "if (true)" guard — removed.)
    int[][] mono = new int[1][calib.getNumFeatures()];
    mono[0] = monotonicity;
    optimizer.setMonotonicity(mono);
    optimizer.setShrinkage(0.1);
    optimizer.initialize();
    // Stop after 5 non-improving validation checks; lower MSE is better.
    EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, 5);
    LSBoost bestModel = null;
    for (int i = 1; i < 1000; i++) {
        optimizer.iterate();
        // Evaluate validation MSE every 10 iterations.
        if (i % 10 == 0) {
            double mse = MSE.mse(lsBoost, valid);
            earlyStopper.add(i, mse);
            if (earlyStopper.getBestIteration() == i) {
                try {
                    // Snapshot the current best model via serialization deep copy.
                    bestModel = (LSBoost) Serialization.deepCopy(lsBoost);
                } catch (IOException | ClassNotFoundException e) {
                    // Best-effort snapshot: keep the previous best model on failure.
                    e.printStackTrace();
                }
            }
            if (earlyStopper.shouldStop()) {
                break;
            }
        }
    }
    return bestModel;
}
Example usage of edu.neu.ccs.pyramid.optimization.EarlyStopper in the pyramid project (cheng-li): class App2, method train.
/**
 * Trains a multi-label IMLGradientBoosting model driven by the given Config,
 * with optional per-label early stopping monitored on the test set,
 * then serializes the model and writes per-label top-feature reports.
 * NOTE(review): early stopping here is monitored on input.testData — confirm
 * that this is intended rather than a separate validation split.
 */
static void train(Config config, Logger logger) throws Exception {
String output = config.getString("output.folder");
int numIterations = config.getInt("train.numIterations");
int numLeaves = config.getInt("train.numLeaves");
double learningRate = config.getDouble("train.learningRate");
int minDataPerLeaf = config.getInt("train.minDataPerLeaf");
String modelName = "model_app3";
// double featureSamplingRate = config.getDouble("train.featureSamplingRate");
// double dataSamplingRate = config.getDouble("train.dataSamplingRate");
StopWatch stopWatch = new StopWatch();
stopWatch.start();
MultiLabelClfDataSet dataSet = loadData(config, config.getString("input.trainData"));
// The test set is only loaded when test-progress reporting is enabled;
// early stopping (below) also depends on it being loaded.
MultiLabelClfDataSet testSet = null;
if (config.getBoolean("train.showTestProgress")) {
testSet = loadData(config, config.getString("input.testData"));
}
int numClasses = dataSet.getNumClasses();
logger.info("number of class = " + numClasses);
IMLGBConfig imlgbConfig = new IMLGBConfig.Builder(dataSet).learningRate(learningRate).minDataPerLeaf(minDataPerLeaf).numLeaves(numLeaves).numSplitIntervals(config.getInt("train.numSplitIntervals")).usePrior(config.getBoolean("train.usePrior")).build();
IMLGradientBoosting boosting;
// Warm start resumes from a previously serialized model in the output folder.
if (config.getBoolean("train.warmStart")) {
boosting = IMLGradientBoosting.deserialize(new File(output, modelName));
} else {
boosting = new IMLGradientBoosting(numClasses);
}
logger.info("During training, the performance is reported using Hamming loss optimal predictor");
logger.info("initialing trainer");
IMLGBTrainer trainer = new IMLGBTrainer(imlgbConfig, boosting);
boolean earlyStop = config.getBoolean("train.earlyStop");
// One EarlyStopper and one Terminator per label; training for a label stops
// when either of them fires (see the OR combination below).
List<EarlyStopper> earlyStoppers = new ArrayList<>();
List<Terminator> terminators = new ArrayList<>();
if (earlyStop) {
for (int l = 0; l < numClasses; l++) {
EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, config.getInt("train.earlyStop.patience"));
earlyStopper.setMinimumIterations(config.getInt("train.earlyStop.minIterations"));
earlyStoppers.add(earlyStopper);
}
for (int l = 0; l < numClasses; l++) {
Terminator terminator = new Terminator();
// The Terminator sees one value per progress check, so its minIterations is
// expressed in checks (minIterations / interval), unlike the EarlyStopper's.
terminator.setMaxStableIterations(config.getInt("train.earlyStop.patience")).setMinIterations(config.getInt("train.earlyStop.minIterations") / config.getInt("train.showProgress.interval")).setAbsoluteEpsilon(config.getDouble("train.earlyStop.absoluteChange")).setRelativeEpsilon(config.getDouble("train.earlyStop.relativeChange")).setOperation(Terminator.Operation.OR);
terminators.add(terminator);
}
}
logger.info("trainer initialized");
int numLabelsLeftToTrain = numClasses;
int progressInterval = config.getInt("train.showProgress.interval");
for (int i = 1; i <= numIterations; i++) {
logger.info("iteration " + i);
trainer.iterate();
if (config.getBoolean("train.showTrainProgress") && (i % progressInterval == 0 || i == numIterations)) {
logger.info("training set performance");
logger.info(new MLMeasures(boosting, dataSet).toString());
}
if (config.getBoolean("train.showTestProgress") && (i % progressInterval == 0 || i == numIterations)) {
logger.info("test set performance");
logger.info(new MLMeasures(boosting, testSet).toString());
// Early-stop bookkeeping only runs on test-progress iterations, so the
// stoppers are fed at the progress interval, not every iteration.
if (earlyStop) {
for (int l = 0; l < numClasses; l++) {
EarlyStopper earlyStopper = earlyStoppers.get(l);
Terminator terminator = terminators.get(l);
// Skip labels whose training has already been frozen.
if (!trainer.getShouldStop()[l]) {
// KL divergence on the test set is the minimized early-stop metric.
double kl = KL(boosting, testSet, l);
earlyStopper.add(i, kl);
terminator.add(kl);
if (earlyStopper.shouldStop() || terminator.shouldTerminate()) {
logger.info("training for label " + l + " (" + dataSet.getLabelTranslator().toExtLabel(l) + ") should stop now");
logger.info("the best number of training iterations for the label is " + earlyStopper.getBestIteration());
trainer.setShouldStop(l);
numLabelsLeftToTrain -= 1;
logger.info("the number of labels left to be trained on = " + numLabelsLeftToTrain);
}
}
}
}
}
// Stop the whole loop once every label has been frozen.
if (numLabelsLeftToTrain == 0) {
logger.info("all label training finished");
break;
}
}
logger.info("training done");
File serializedModel = new File(output, modelName);
//todo pick best models
// NOTE(review): the final model is serialized as-is; the per-label best
// iterations found above are logged but not used to trim the model.
boosting.serialize(serializedModel);
logger.info(stopWatch.toString());
if (earlyStop) {
for (int l = 0; l < numClasses; l++) {
logger.info("----------------------------------------------------");
logger.info("test performance history for label " + l + ": " + earlyStoppers.get(l).history());
logger.info("model size for label " + l + " = " + (boosting.getRegressors(l).size() - 1));
}
}
// Report writing: per-label top features as JSON and a plain-text summary.
boolean topFeaturesToFile = true;
if (topFeaturesToFile) {
logger.info("start writing top features");
int limit = config.getInt("report.topFeatures.limit");
List<TopFeatures> topFeaturesList = IntStream.range(0, boosting.getNumClasses()).mapToObj(k -> IMLGBInspector.topFeatures(boosting, k, limit)).collect(Collectors.toList());
ObjectMapper mapper = new ObjectMapper();
String file = "top_features.json";
mapper.writeValue(new File(output, file), topFeaturesList);
StringBuilder sb = new StringBuilder();
for (int l = 0; l < boosting.getNumClasses(); l++) {
sb.append("-------------------------").append("\n");
sb.append(dataSet.getLabelTranslator().toExtLabel(l)).append(":").append("\n");
for (Feature feature : topFeaturesList.get(l).getTopFeatures()) {
sb.append(feature.simpleString()).append(", ");
}
sb.append("\n");
}
FileUtils.writeStringToFile(new File(output, "top_features.txt"), sb.toString());
logger.info("finish writing top features");
}
}
Example usage of edu.neu.ccs.pyramid.optimization.EarlyStopper in the pyramid project (cheng-li): class CBMGB, method tune.
/**
 * Tunes the number of boosting iterations for a CBM model by training until the
 * EarlyStopper fires, monitoring the metric named by tune.targetMetric on the
 * validation set.
 *
 * @param config          run configuration (target metric, piThreshold, monitor interval, ...)
 * @param hyperParameters hyper-parameters; its iterations field is set to the best iteration found
 * @param trainSet        training data
 * @param validSet        validation data the target metric is monitored on
 * @return the hyper-parameters paired with the best validation performance
 * @throws Exception propagated from optimizer construction/initialization
 * @throws IllegalArgumentException if tune.targetMetric is not one of the four supported values
 */
private static TuneResult tune(Config config, HyperParameters hyperParameters, MultiLabelClfDataSet trainSet, MultiLabelClfDataSet validSet) throws Exception {
    List<Integer> unobservedLabels = DataSetUtil.unobservedLabels(trainSet);
    CBM cbm = newCBM(config, trainSet, hyperParameters);
    EarlyStopper earlyStopper = loadNewEarlyStopper(config);
    GBCBMOptimizer optimizer = getOptimizer(config, hyperParameters, cbm, trainSet);
    optimizer.initialize();
    // Pick the predictor matching the metric being optimized.
    MultiLabelClassifier classifier;
    String predictTarget = config.getString("tune.targetMetric");
    switch(predictTarget) {
        case "instance_set_accuracy":
            AccPredictor accPredictor = new AccPredictor(cbm);
            accPredictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
            classifier = accPredictor;
            break;
        case "instance_f1":
            PluginF1 pluginF1 = new PluginF1(cbm);
            List<MultiLabel> support = DataSetUtil.gatherMultiLabels(trainSet);
            pluginF1.setSupport(support);
            pluginF1.setPiThreshold(config.getDouble("predict.piThreshold"));
            classifier = pluginF1;
            break;
        case "instance_hamming_loss":
            MarginalPredictor marginalPredictor = new MarginalPredictor(cbm);
            marginalPredictor.setPiThreshold(config.getDouble("predict.piThreshold"));
            classifier = marginalPredictor;
            break;
        case "instance_log_likelihood":
            // acc predictor seems to be the best match for log likelihood
            AccPredictor accPredictor1 = new AccPredictor(cbm);
            accPredictor1.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
            classifier = accPredictor1;
            break;
        default:
            // Fixed: the message now lists all four supported targets
            // (instance_log_likelihood was previously omitted).
            throw new IllegalArgumentException("tune.targetMetric should be instance_set_accuracy, instance_f1, instance_hamming_loss or instance_log_likelihood");
    }
    int interval = config.getInt("tune.monitorInterval");
    for (int iter = 1; true; iter++) {
        if (VERBOSE) {
            System.out.println("iteration " + iter);
        }
        optimizer.iterate();
        // Feed the early stopper every `interval` iterations.
        if (iter % interval == 0) {
            MLMeasures validMeasures = new MLMeasures(classifier, validSet);
            if (VERBOSE) {
                System.out.println("validation performance with " + predictTarget + " optimal predictor:");
                System.out.println(validMeasures);
            }
            switch(predictTarget) {
                case "instance_set_accuracy":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getAccuracy());
                    break;
                case "instance_f1":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getF1());
                    break;
                case "instance_hamming_loss":
                    // NOTE(review): hamming loss is lower-is-better while the other metrics
                    // are higher-is-better; assumes loadNewEarlyStopper(config) sets the
                    // matching goal for this target — TODO confirm.
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getHammingLoss());
                    break;
                case "instance_log_likelihood":
                    earlyStopper.add(iter, LogLikelihood.averageLogLikelihood(cbm, validSet, unobservedLabels));
                    break;
                default:
                    // Fixed: kept in sync with the outer switch (previously listed only two targets).
                    throw new IllegalArgumentException("tune.targetMetric should be instance_set_accuracy, instance_f1, instance_hamming_loss or instance_log_likelihood");
            }
            if (earlyStopper.shouldStop()) {
                if (VERBOSE) {
                    System.out.println("Early Stopper: the training should stop now!");
                }
                break;
            }
        }
    }
    if (VERBOSE) {
        System.out.println("done!");
    }
    // Record the best iteration/performance found as the tuning result.
    hyperParameters.iterations = earlyStopper.getBestIteration();
    TuneResult tuneResult = new TuneResult();
    tuneResult.hyperParameters = hyperParameters;
    tuneResult.performance = earlyStopper.getBestValue();
    return tuneResult;
}
Example usage of edu.neu.ccs.pyramid.optimization.EarlyStopper in the pyramid project (cheng-li): class BRLREN, method train.
/**
 * Trains a CBM model with early stopping on validation instance-set accuracy,
 * serializing the best-so-far model to disk at each new best iteration, then
 * writes the label support set and per-label top-feature reports.
 */
private static void train(Config config, HyperParameters hyperParameters, MultiLabelClfDataSet trainSet, MultiLabelClfDataSet validSet, Logger logger) throws Exception {
// Labels never observed in the training set cannot be learned; record them for analysis.
List<Integer> unobservedLabels = DataSetUtil.unobservedLabels(trainSet);
if (!unobservedLabels.isEmpty()) {
logger.info("The following labels do not actually appear in the training set and therefore cannot be learned:");
logger.info(ListUtil.toSimpleString(unobservedLabels));
FileUtils.writeStringToFile(Paths.get(config.getString("output.dir"), "model_predictions", config.getString("output.modelFolder"), "analysis", "unobserved_labels.txt").toFile(), ListUtil.toSimpleString(unobservedLabels));
}
String output = config.getString("output.dir");
// loadNewEarlyStopper() maximizes with patience 5 and minimum 5 iterations.
EarlyStopper earlyStopper = loadNewEarlyStopper();
StopWatch stopWatch = new StopWatch();
stopWatch.start();
CBM cbm = newCBM(config, trainSet, hyperParameters, logger);
ENCBMOptimizer optimizer = getOptimizer(config, hyperParameters, cbm, trainSet);
logger.info("Initializing the model");
if (config.getBoolean("train.randomInitialize")) {
optimizer.randInitialize();
} else {
optimizer.initialize();
}
logger.info("Initialization done");
AccPredictor accPredictor = new AccPredictor(cbm);
accPredictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
// With interval == 1 the validation check runs after every iteration.
int interval = 1;
for (int iter = 1; true; iter++) {
logger.info("Training progress: iteration " + iter);
optimizer.iterate();
if (iter % interval == 0) {
MLMeasures validMeasures = new MLMeasures(accPredictor, validSet);
if (VERBOSE) {
logger.info("validation performance");
logger.info(validMeasures.toString());
}
// Instance-average set accuracy on the validation set is the maximized metric.
earlyStopper.add(iter, validMeasures.getInstanceAverage().getAccuracy());
// Persist a snapshot whenever this iteration is the new best; the final
// model is reloaded from this file below.
if (earlyStopper.getBestIteration() == iter) {
Serialization.serialize(cbm, Paths.get(output, "model_predictions", config.getString("output.modelFolder"), "models", "classifier"));
}
if (earlyStopper.shouldStop()) {
if (VERBOSE) {
logger.info("Early Stopper: the training should stop now!");
logger.info("Early Stopper: best iteration found = " + earlyStopper.getBestIteration());
logger.info("Early Stopper: best validation performance = " + earlyStopper.getBestValue());
}
break;
}
}
}
logger.info("training done!");
logger.info("time spent on training = " + stopWatch);
// Persist the support (observed label combinations) alongside the model.
List<MultiLabel> support = DataSetUtil.gatherMultiLabels(trainSet);
Serialization.serialize(support, Paths.get(output, "model_predictions", config.getString("output.modelFolder"), "models", "support"));
// Reload the best snapshot written during the loop (not the last-iteration model).
CBM bestModel = (CBM) Serialization.deserialize(Paths.get(output, "model_predictions", config.getString("output.modelFolder"), "models", "classifier"));
boolean topFeaturesToFile = true;
if (topFeaturesToFile) {
File analysisFolder = Paths.get(output, "model_predictions", config.getString("output.modelFolder"), "analysis").toFile();
analysisFolder.mkdirs();
logger.info("start writing top features");
// Top 100 features per label, written both as JSON and as a text summary.
List<TopFeatures> topFeaturesList = IntStream.range(0, bestModel.getNumClasses()).mapToObj(k -> topFeatures(bestModel, trainSet.getFeatureList(), trainSet.getLabelTranslator(), k, 100)).collect(Collectors.toList());
ObjectMapper mapper = new ObjectMapper();
String file = "top_features.json";
mapper.writeValue(new File(analysisFolder, file), topFeaturesList);
StringBuilder sb = new StringBuilder();
for (int l = 0; l < bestModel.getNumClasses(); l++) {
sb.append("-------------------------").append("\n");
sb.append(bestModel.getLabelTranslator().toExtLabel(l)).append(":").append("\n");
for (Feature feature : topFeaturesList.get(l).getTopFeatures()) {
sb.append(feature.simpleString()).append(", ");
}
sb.append("\n");
}
FileUtils.writeStringToFile(new File(analysisFolder, "top_features.txt"), sb.toString());
logger.info("finish writing top features");
}
}
Example usage of edu.neu.ccs.pyramid.optimization.EarlyStopper in the pyramid project (cheng-li): class BRLREN, method loadNewEarlyStopper.
/**
 * Builds a fresh EarlyStopper configured to maximize the monitored metric,
 * tolerating up to 5 non-improving checks and requiring at least 5 iterations
 * before it may fire.
 *
 * @return a newly constructed EarlyStopper
 */
private static EarlyStopper loadNewEarlyStopper() {
    final int patience = 5;
    final int minIterations = 5;
    EarlyStopper stopper = new EarlyStopper(EarlyStopper.Goal.MAXIMIZE, patience);
    stopper.setMinimumIterations(minIterations);
    return stopper;
}
Aggregations