Use of edu.neu.ccs.pyramid.optimization.EarlyStopper in project pyramid by cheng-li.
The class RerankerTrainer, method train.
public Reranker train(RegDataSet regDataSet, double[] instanceWeights, MultiLabelClassifier.ClassProbEstimator classProbEstimator,
                      PredictionFeatureExtractor predictionFeatureExtractor, LabelCalibrator labelCalibrator, RegDataSet validation) {
    LSBoost lsBoost = new LSBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig()
            .setMaxNumLeaves(numLeaves)
            .setMinDataPerLeaf(minDataPerLeaf)
            .setMonotonicityType(monotonicityType);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSBoostOptimizer optimizer = new LSBoostOptimizer(lsBoost, regDataSet, regTreeFactory, instanceWeights, regDataSet.getLabels());
    if (!monotonicityType.equals("none")) {
        int[][] mono = new int[1][regDataSet.getNumFeatures()];
        mono[0] = predictionFeatureExtractor.featureMonotonicity();
        optimizer.setMonotonicity(mono);
    }
    optimizer.setShrinkage(shrinkage);
    optimizer.initialize();
    // stop once the validation MSE has gone 5 checks without improving
    EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, 5);
    LSBoost bestModel = null;
    for (int i = 1; i <= maxIter; i++) {
        optimizer.iterate();
        // evaluate on the validation set every 10 iterations
        if (i % 10 == 0 || i == maxIter) {
            double mse = MSE.mse(lsBoost, validation);
            earlyStopper.add(i, mse);
            if (earlyStopper.getBestIteration() == i) {
                try {
                    // snapshot the best model seen so far
                    bestModel = (LSBoost) Serialization.deepCopy(lsBoost);
                } catch (IOException | ClassNotFoundException e) {
                    e.printStackTrace();
                }
            }
            if (earlyStopper.shouldStop()) {
                break;
            }
        }
    }
    // System.out.println("best iteration = " + earlyStopper.getBestIteration());
    return new Reranker(bestModel, classProbEstimator, numCandidates, predictionFeatureExtractor, labelCalibrator);
}
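Distilled from the method above, a minimal sketch of the checkpoint-and-early-stop pattern. Only the EarlyStopper calls (the constructor, add, getBestIteration, shouldStop, getBestValue) come from the pyramid API as used here; the training step and the validationLoss metric are hypothetical placeholders.

// A minimal sketch of the pattern above, assuming only the EarlyStopper
// calls shown in RerankerTrainer.train; the training step and
// validationLoss are hypothetical stand-ins for the real optimizer and metric.
import edu.neu.ccs.pyramid.optimization.EarlyStopper;

public class EarlyStopSketch {

    // hypothetical validation metric: improves quickly, then plateaus
    static double validationLoss(int iter) {
        return Math.max(0.1, 1.0 / iter);
    }

    public static void main(String[] args) {
        // MINIMIZE the metric; stop after 5 evaluations without improvement
        EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, 5);
        String bestSnapshot = null;       // stand-in for a deep-copied model
        for (int i = 1; i <= 1000; i++) {
            // ... one training iteration would run here ...
            if (i % 10 == 0) {            // evaluate periodically, as above
                earlyStopper.add(i, validationLoss(i));
                if (earlyStopper.getBestIteration() == i) {
                    // RerankerTrainer snapshots via Serialization.deepCopy
                    bestSnapshot = "model@" + i;
                }
                if (earlyStopper.shouldStop()) {
                    break;
                }
            }
        }
        System.out.println("best iteration = " + earlyStopper.getBestIteration());
        System.out.println("best value     = " + earlyStopper.getBestValue());
        System.out.println("snapshot       = " + bestSnapshot);
    }
}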
Use of edu.neu.ccs.pyramid.optimization.EarlyStopper in project pyramid by cheng-li.
The class BRGB, method train.
static void train(Config config, Logger logger) throws Exception {
    String output = config.getString("output.folder");
    int numIterations = config.getInt("train.numIterations");
    int numLeaves = config.getInt("train.numLeaves");
    double learningRate = config.getDouble("train.learningRate");
    int minDataPerLeaf = config.getInt("train.minDataPerLeaf");
    int randomSeed = config.getInt("train.randomSeed");
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    MultiLabelClfDataSet allTrainData = loadData(config, config.getString("input.trainData"));
    double[] instanceWeights = new double[allTrainData.getNumDataPoints()];
    Arrays.fill(instanceWeights, 1.0);
    if (config.getBoolean("train.useInstanceWeights")) {
        instanceWeights = loadInstanceWeights(config);
    }
    MultiLabelClfDataSet trainSetForEval = minibatch(allTrainData, instanceWeights, config.getInt("train.showProgress.sampleSize"), 0 + randomSeed).getFirst();
    MultiLabelClfDataSet validSet = loadData(config, config.getString("input.validData"));
    List<MultiLabel> support = DataSetUtil.gatherMultiLabels(allTrainData);
    Serialization.serialize(support, Paths.get(output, "model_predictions", config.getString("output.modelFolder"), "models", "support"));
    int numClasses = allTrainData.getNumClasses();
    logger.info("number of classes = " + numClasses);
    IMLGradientBoosting boosting;
    List<EarlyStopper> earlyStoppers;
    List<Terminator> terminators;
    boolean[] shouldStop;
    int numLabelsLeftToTrain;
    int startIter;
    List<Pair<Integer, Double>> trainingTime;
    List<Pair<Integer, Double>> accuracy;
    double startTime = 0;
    boolean earlyStop = config.getBoolean("train.earlyStop");
    CheckPoint checkPoint;
    if (config.getBoolean("train.warmStart")) {
        // resume all training state from the saved checkpoint
        checkPoint = (CheckPoint) Serialization.deserialize(Paths.get(output, "model_predictions", config.getString("output.modelFolder"), "models", "checkpoint"));
        boosting = checkPoint.boosting;
        earlyStoppers = checkPoint.earlyStoppers;
        terminators = checkPoint.terminators;
        shouldStop = checkPoint.shouldStop;
        numLabelsLeftToTrain = checkPoint.numLabelsLeftToTrain;
        startIter = checkPoint.lastIter + 1;
        trainingTime = checkPoint.trainingTime;
        accuracy = checkPoint.accuracy;
        startTime = checkPoint.trainingTime.get(trainingTime.size() - 1).getSecond();
    } else {
        boosting = new IMLGradientBoosting(numClasses);
        earlyStoppers = new ArrayList<>();
        terminators = new ArrayList<>();
        trainingTime = new ArrayList<>();
        accuracy = new ArrayList<>();
        if (earlyStop) {
            // one early stopper and one terminator per label
            for (int l = 0; l < numClasses; l++) {
                EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, config.getInt("train.earlyStop.patience"));
                earlyStopper.setMinimumIterations(config.getInt("train.earlyStop.minIterations"));
                earlyStoppers.add(earlyStopper);
            }
            for (int l = 0; l < numClasses; l++) {
                Terminator terminator = new Terminator();
                terminator.setMaxStableIterations(config.getInt("train.earlyStop.patience"))
                        .setMinIterations(config.getInt("train.earlyStop.minIterations") / config.getInt("train.showProgress.interval"))
                        .setAbsoluteEpsilon(config.getDouble("train.earlyStop.absoluteChange"))
                        .setRelativeEpsilon(config.getDouble("train.earlyStop.relativeChange"))
                        .setOperation(Terminator.Operation.OR);
                terminators.add(terminator);
            }
        }
        shouldStop = new boolean[allTrainData.getNumClasses()];
        numLabelsLeftToTrain = numClasses;
        checkPoint = new CheckPoint();
        checkPoint.boosting = boosting;
        checkPoint.earlyStoppers = earlyStoppers;
        checkPoint.terminators = terminators;
        checkPoint.shouldStop = shouldStop;
        // numLabelsLeftToTrain is a primitive copy, not a reference;
        // the field must be re-assigned whenever the local variable changes
        checkPoint.numLabelsLeftToTrain = numLabelsLeftToTrain;
        checkPoint.lastIter = 0;
        checkPoint.trainingTime = trainingTime;
        checkPoint.accuracy = accuracy;
        startIter = 1;
    }
    logger.info("During training, the performance is reported using the Hamming loss optimal predictor. The performance is computed approximately with " + config.getInt("train.showProgress.sampleSize") + " instances.");
    int progressInterval = config.getInt("train.showProgress.interval");
    int interval = config.getInt("train.fullScanInterval");
    int minibatchLifeSpan = config.getInt("train.minibatchLifeSpan");
    int numActiveFeatures = config.getInt("train.numActiveFeatures");
    int numofLabels = allTrainData.getNumClasses();
    List<Integer>[] activeFeaturesLists = new ArrayList[numofLabels];
    for (int labelnum = 0; labelnum < numofLabels; labelnum++) {
        activeFeaturesLists[labelnum] = new ArrayList<>();
    }
    MultiLabelClfDataSet trainBatch = null;
    IMLGBTrainer trainer = null;
    StopWatch timeWatch = new StopWatch();
    timeWatch.start();
    for (int i = startIter; i <= numIterations; i++) {
        logger.info("iteration " + i);
        // draw a fresh minibatch when the current one expires
        if (i % minibatchLifeSpan == 1 || i == startIter) {
            Pair<MultiLabelClfDataSet, double[]> sampled = minibatch(allTrainData, instanceWeights, config.getInt("train.batchSize"), i + randomSeed);
            trainBatch = sampled.getFirst();
            IMLGBConfig imlgbConfig = new IMLGBConfig.Builder(trainBatch)
                    .learningRate(learningRate)
                    .minDataPerLeaf(minDataPerLeaf)
                    .numLeaves(numLeaves)
                    .numSplitIntervals(config.getInt("train.numSplitIntervals"))
                    .usePrior(config.getBoolean("train.usePrior"))
                    .numActiveFeatures(numActiveFeatures)
                    .build();
            trainer = new IMLGBTrainer(imlgbConfig, boosting, shouldStop);
            trainer.setInstanceWeights(sampled.getSecond());
        }
        // do a full feature scan every "interval" iterations
        if (i % interval == 1) {
            trainer.iterate(activeFeaturesLists, true);
        } else {
            trainer.iterate(activeFeaturesLists, false);
        }
        checkPoint.lastIter += 1;
        if (earlyStop && (i % progressInterval == 0 || i == numIterations)) {
            for (int l = 0; l < numClasses; l++) {
                EarlyStopper earlyStopper = earlyStoppers.get(l);
                Terminator terminator = terminators.get(l);
                if (!shouldStop[l]) {
                    double kl = KL(boosting, validSet, l);
                    earlyStopper.add(i, kl);
                    terminator.add(kl);
                    if (earlyStopper.shouldStop() || terminator.shouldTerminate()) {
                        logger.info("training for label " + l + " (" + allTrainData.getLabelTranslator().toExtLabel(l) + ") should stop now");
                        logger.info("the best number of training iterations for the label is " + earlyStopper.getBestIteration());
                        if (i != earlyStopper.getBestIteration()) {
                            boosting.cutTail(l, earlyStopper.getBestIteration());
                            logger.info("roll back the model for this label to iteration " + earlyStopper.getBestIteration());
                        }
                        shouldStop[l] = true;
                        numLabelsLeftToTrain -= 1;
                        checkPoint.numLabelsLeftToTrain = numLabelsLeftToTrain;
                        logger.info("the number of labels left to be trained on = " + numLabelsLeftToTrain);
                    }
                }
            }
        }
        if (config.getBoolean("train.showTrainProgress") && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("training set performance (computed approximately with the Hamming loss predictor on " + config.getInt("train.showProgress.sampleSize") + " instances).");
            logger.info(new MLMeasures(boosting, trainSetForEval).toString());
        }
        if (config.getBoolean("train.showValidProgress") && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("validation set performance (computed approximately with the Hamming loss predictor)");
            MLMeasures validPerformance = new MLMeasures(boosting, validSet);
            logger.info(validPerformance.toString());
            accuracy.add(new Pair<>(i, validPerformance.getInstanceAverage().getF1()));
        }
        trainingTime.add(new Pair<>(i, startTime + timeWatch.getTime() / 1000.0));
        Serialization.serialize(checkPoint, Paths.get(output, "model_predictions", config.getString("output.modelFolder"), "models", "checkpoint"));
        Serialization.serialize(boosting, Paths.get(output, "model_predictions", config.getString("output.modelFolder"), "models", "classifier"));
        if (numLabelsLeftToTrain == 0) {
            logger.info("all label training finished");
            break;
        }
    }
    logger.info("training done");
    logger.info(stopWatch.toString());
    File analysisFolder = Paths.get(output, "model_predictions", config.getString("output.modelFolder"), "analysis").toFile();
    if (true) {
        // always true: decision rules are always exported
        ObjectMapper objectMapper = new ObjectMapper();
        List<LabelModel> labelModels = IMLGBInspector.getAllRules(boosting);
        new File(analysisFolder, "decision_rules").mkdirs();
        for (int l = 0; l < boosting.getNumClasses(); l++) {
            objectMapper.writeValue(Paths.get(analysisFolder.toString(), "decision_rules", l + ".json").toFile(), labelModels.get(l));
        }
    }
    boolean topFeaturesToFile = true;
    if (topFeaturesToFile) {
        logger.info("start writing top features");
        List<TopFeatures> topFeaturesList = IntStream.range(0, boosting.getNumClasses())
                .mapToObj(k -> IMLGBInspector.topFeatures(boosting, k, Integer.MAX_VALUE))
                .collect(Collectors.toList());
        ObjectMapper mapper = new ObjectMapper();
        String file = "top_features.json";
        mapper.writeValue(new File(analysisFolder, file), topFeaturesList);
        StringBuilder sb = new StringBuilder();
        for (int l = 0; l < boosting.getNumClasses(); l++) {
            sb.append("-------------------------").append("\n");
            sb.append(allTrainData.getLabelTranslator().toExtLabel(l)).append(":").append("\n");
            for (Feature feature : topFeaturesList.get(l).getTopFeatures()) {
                sb.append(feature.simpleString()).append(", ");
            }
            sb.append("\n");
        }
        FileUtils.writeStringToFile(new File(analysisFolder, "top_features.txt"), sb.toString());
        logger.info("finish writing top features");
    }
}
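A hedged sketch of the per-label bookkeeping above: one EarlyStopper per label plus a shouldStop flag array, so labels that have converged are skipped and training ends once none remain. numLabels, patience, and klForLabel are illustrative placeholders, and the rollback BRGB performs with boosting.cutTail is only noted in a comment.

// A hedged sketch of BRGB.train's per-label early stopping; all names
// other than the EarlyStopper calls are illustrative placeholders.
import java.util.ArrayList;
import java.util.List;
import edu.neu.ccs.pyramid.optimization.EarlyStopper;

public class PerLabelEarlyStopSketch {

    // toy per-label validation metric standing in for KL(boosting, validSet, l);
    // it improves for a while, then plateaus so the stoppers can fire
    static double klForLabel(int iter, int label) {
        return Math.max(0.05 * (label + 1), 1.0 / iter);
    }

    public static void main(String[] args) {
        int numLabels = 3, patience = 5, maxIter = 200, progressInterval = 10;
        List<EarlyStopper> stoppers = new ArrayList<>();
        for (int l = 0; l < numLabels; l++) {
            EarlyStopper s = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, patience);
            s.setMinimumIterations(20);   // as with train.earlyStop.minIterations
            stoppers.add(s);
        }
        boolean[] shouldStop = new boolean[numLabels];
        int labelsLeft = numLabels;
        for (int i = 1; i <= maxIter && labelsLeft > 0; i++) {
            // ... one boosting iteration over the still-active labels ...
            if (i % progressInterval != 0) {
                continue;                 // only evaluate periodically
            }
            for (int l = 0; l < numLabels; l++) {
                if (shouldStop[l]) {
                    continue;             // this label has already converged
                }
                stoppers.get(l).add(i, klForLabel(i, l));
                if (stoppers.get(l).shouldStop()) {
                    // BRGB would roll back here:
                    // boosting.cutTail(l, stoppers.get(l).getBestIteration())
                    shouldStop[l] = true;
                    labelsLeft--;
                }
            }
        }
        System.out.println("labels still training at the end = " + labelsLeft);
    }
}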
Use of edu.neu.ccs.pyramid.optimization.EarlyStopper in project pyramid by cheng-li.
The class CBMEN, method tune.
private static TuneResult tune(Config config, HyperParameters hyperParameters, MultiLabelClfDataSet trainSet, MultiLabelClfDataSet validSet) throws Exception {
    CBM cbm = newCBM(config, trainSet, hyperParameters);
    EarlyStopper earlyStopper = loadNewEarlyStopper(config);
    ENCBMOptimizer optimizer = getOptimizer(config, hyperParameters, cbm, trainSet);
    if (config.getBoolean("train.randomInitialize")) {
        optimizer.randInitialize();
    } else {
        optimizer.initialize();
    }
    // choose the plug-in predictor that matches the tuning metric
    MultiLabelClassifier classifier;
    String predictTarget = config.getString("tune.targetMetric");
    switch (predictTarget) {
        case "instance_set_accuracy":
            AccPredictor accPredictor = new AccPredictor(cbm);
            accPredictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
            classifier = accPredictor;
            break;
        case "instance_f1":
            PluginF1 pluginF1 = new PluginF1(cbm);
            List<MultiLabel> support = DataSetUtil.gatherMultiLabels(trainSet);
            pluginF1.setSupport(support);
            pluginF1.setPiThreshold(config.getDouble("predict.piThreshold"));
            classifier = pluginF1;
            break;
        case "instance_hamming_loss":
            MarginalPredictor marginalPredictor = new MarginalPredictor(cbm);
            marginalPredictor.setPiThreshold(config.getDouble("predict.piThreshold"));
            classifier = marginalPredictor;
            break;
        case "label_map":
            AccPredictor accPredictor2 = new AccPredictor(cbm);
            accPredictor2.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
            classifier = accPredictor2;
            break;
        default:
            throw new IllegalArgumentException("predictTarget should be instance_set_accuracy, instance_f1, instance_hamming_loss or label_map");
    }
    int interval = config.getInt("tune.monitorInterval");
    // run until the early stopper fires
    for (int iter = 1; true; iter++) {
        if (VERBOSE) {
            System.out.println("iteration " + iter);
        }
        optimizer.iterate();
        if (iter % interval == 0) {
            MLMeasures validMeasures = new MLMeasures(classifier, validSet);
            if (VERBOSE) {
                System.out.println("validation performance with " + predictTarget + " optimal predictor:");
                System.out.println(validMeasures);
            }
            // feed the validation score that matches the tuning target
            switch (predictTarget) {
                case "instance_set_accuracy":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getAccuracy());
                    break;
                case "instance_f1":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getF1());
                    break;
                case "instance_hamming_loss":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getHammingLoss());
                    break;
                case "label_map":
                    List<MultiLabel> support = DataSetUtil.gatherMultiLabels(trainSet);
                    double map = MAP.mapBySupport(cbm, validSet, support);
                    earlyStopper.add(iter, map);
                    break;
                default:
                    throw new IllegalArgumentException("predictTarget should be instance_set_accuracy, instance_f1, instance_hamming_loss or label_map");
            }
            if (earlyStopper.shouldStop()) {
                if (VERBOSE) {
                    System.out.println("Early Stopper: the training should stop now!");
                }
                break;
            }
        }
    }
    if (VERBOSE) {
        System.out.println("done!");
    }
    // record the tuned number of iterations and the best validation score
    hyperParameters.iterations = earlyStopper.getBestIteration();
    TuneResult tuneResult = new TuneResult();
    tuneResult.hyperParameters = hyperParameters;
    tuneResult.performance = earlyStopper.getBestValue();
    return tuneResult;
}
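The tuning loop reduces to the sketch below: periodically score the validation set, feed the score to the early stopper, and read back the best iteration and value once it fires. Goal.MAXIMIZE is assumed to exist alongside the Goal.MINIMIZE used in the other examples (loadNewEarlyStopper presumably picks the goal to match the metric); validF1 is a toy placeholder for scoring with MLMeasures.

// A minimal sketch of CBMEN.tune's monitoring loop. Goal.MAXIMIZE is an
// assumption (only Goal.MINIMIZE appears in the snippets above), and
// validF1 is a hypothetical stand-in for the real validation metric.
import edu.neu.ccs.pyramid.optimization.EarlyStopper;

public class TuneSketch {

    // toy validation metric that keeps improving, then saturates
    static double validF1(int iter) {
        return Math.min(0.9, 1.0 - 1.0 / iter);
    }

    public static void main(String[] args) {
        // higher is better; stop after 5 evaluations without improvement
        EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MAXIMIZE, 5);
        int interval = 10;                // as with tune.monitorInterval
        for (int iter = 1; true; iter++) {
            // ... optimizer.iterate() would run here ...
            if (iter % interval == 0) {
                earlyStopper.add(iter, validF1(iter));
                if (earlyStopper.shouldStop()) {
                    break;
                }
            }
        }
        // CBMEN.tune stores these in hyperParameters.iterations
        // and tuneResult.performance
        System.out.println("best iteration = " + earlyStopper.getBestIteration());
        System.out.println("best value     = " + earlyStopper.getBestValue());
    }
}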