Usage of edu.neu.ccs.pyramid.optimization.EarlyStopper in project pyramid by cheng-li.
Class RerankerTrainer, method trainWithSigmoid.
/**
 * Trains an {@link LSLogisticBoost} model with early stopping on validation MSE and wraps the
 * best snapshot in a {@link Reranker}.
 *
 * <p>Boosting runs for at most {@code maxIter} iterations. Validation MSE is measured every
 * 10 iterations and also after the final iteration, so trailing iterations are never wasted.
 * Whenever the {@link EarlyStopper} (goal MINIMIZE, patience 5) reports the current iteration
 * as best, the model is deep-copied. Training stops as soon as the stopper says so.
 *
 * @param regDataSet training data; it and its labels are passed to the optimizer
 * @param instanceWeights weights forwarded to the optimizer
 * @param classProbEstimator forwarded to the returned {@link Reranker}
 * @param predictionFeatureExtractor supplies feature monotonicity (when
 *     {@code monotonicityType} is not {@code "none"}) and is forwarded to the {@link Reranker}
 * @param labelCalibrator forwarded to the returned {@link Reranker}
 * @param validation data set whose MSE drives early stopping and model selection
 * @param noiseRates0 forwarded to the optimizer (presumably noise rates for label 0 — confirm)
 * @param noiseRates1 forwarded to the optimizer (presumably noise rates for label 1 — confirm)
 * @return a reranker around the best validated model, or around the final (untrained-beyond-
 *     initialization) model when no validation check ran, i.e. {@code maxIter <= 0}
 * @throws IllegalStateException if the best model cannot be deep-copied
 */
public Reranker trainWithSigmoid(RegDataSet regDataSet, double[] instanceWeights, MultiLabelClassifier.ClassProbEstimator classProbEstimator, PredictionFeatureExtractor predictionFeatureExtractor, LabelCalibrator labelCalibrator, RegDataSet validation, double[] noiseRates0, double[] noiseRates1) {
    LSLogisticBoost lsLogisticBoost = new LSLogisticBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig()
            .setMaxNumLeaves(numLeaves)
            .setMinDataPerLeaf(minDataPerLeaf)
            .setMonotonicityType(monotonicityType);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSLogisticBoostOptimizer optimizer = new LSLogisticBoostOptimizer(lsLogisticBoost, regDataSet, regTreeFactory, instanceWeights, regDataSet.getLabels(), noiseRates0, noiseRates1);
    if (!monotonicityType.equals("none")) {
        // setMonotonicity takes a 2-D array; pass a single row holding the feature monotonicity.
        int[][] mono = {predictionFeatureExtractor.featureMonotonicity()};
        optimizer.setMonotonicity(mono);
    }
    optimizer.setShrinkage(shrinkage);
    optimizer.initialize();

    final int evalInterval = 10;
    EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, 5);
    LSLogisticBoost bestModel = null;
    for (int i = 1; i <= maxIter; i++) {
        optimizer.iterate();
        // Evaluate on the regular schedule, and also on the last iteration so that the
        // final few iterations (when maxIter is not a multiple of evalInterval) are considered.
        if (i % evalInterval != 0 && i != maxIter) {
            continue;
        }
        double mse = MSE.mse(lsLogisticBoost, validation);
        earlyStopper.add(i, mse);
        if (earlyStopper.getBestIteration() == i) {
            // lsLogisticBoost keeps changing in later iterations, so snapshot it now.
            try {
                bestModel = (LSLogisticBoost) Serialization.deepCopy(lsLogisticBoost);
            } catch (IOException | ClassNotFoundException e) {
                throw new IllegalStateException("failed to copy the best LSLogisticBoost model at iteration " + i, e);
            }
        }
        if (earlyStopper.shouldStop()) {
            break;
        }
    }
    if (bestModel == null) {
        // No validation check ran (maxIter <= 0); fall back to the model as trained.
        bestModel = lsLogisticBoost;
    }
    return new Reranker(bestModel, classProbEstimator, numCandidates, predictionFeatureExtractor, labelCalibrator);
}
Usage of edu.neu.ccs.pyramid.optimization.EarlyStopper in project pyramid by cheng-li.
Class CBMEN, method loadNewEarlyStopper.
/**
 * Builds an {@link EarlyStopper} from the {@code tune.*} configuration.
 *
 * <p>The optimization goal follows {@code tune.targetMetric}. Accuracy, F1 and label MAP are
 * maximized; Hamming loss is minimized. Patience comes from {@code tune.earlyStop.patience}, and
 * the minimum iteration count from {@code tune.earlyStop.minIterations}.
 *
 * @param config configuration holding the tuning settings
 * @return a configured early stopper
 * @throws IllegalArgumentException if {@code tune.targetMetric} is not supported
 */
private static EarlyStopper loadNewEarlyStopper(Config config) {
    final String metric = config.getString("tune.targetMetric");
    final int patienceSetting = config.getInt("tune.earlyStop.patience");
    final EarlyStopper.Goal goal;
    switch (metric) {
        case "instance_set_accuracy":
        case "instance_f1":
        case "label_map":
            goal = EarlyStopper.Goal.MAXIMIZE;
            break;
        case "instance_hamming_loss":
            goal = EarlyStopper.Goal.MINIMIZE;
            break;
        default:
            throw new IllegalArgumentException("unsupported tune.targetMetric " + metric);
    }
    final EarlyStopper stopper = new EarlyStopper(goal, patienceSetting);
    stopper.setMinimumIterations(config.getInt("tune.earlyStop.minIterations"));
    return stopper;
}
Usage of edu.neu.ccs.pyramid.optimization.EarlyStopper in project pyramid by cheng-li.
Class CBMGB, method loadNewEarlyStopper.
/**
 * Builds an {@link EarlyStopper} from the {@code tune.*} configuration.
 *
 * <p>The optimization goal follows {@code tune.targetMetric}. Set accuracy, F1 and log likelihood
 * are maximized; Hamming loss is minimized. Patience comes from {@code tune.earlyStop.patience},
 * and the minimum iteration count from {@code tune.earlyStop.minIterations}.
 *
 * @param config configuration holding the tuning settings
 * @return a configured early stopper
 * @throws IllegalArgumentException if {@code tune.targetMetric} is not supported
 */
private static EarlyStopper loadNewEarlyStopper(Config config) {
    final String metric = config.getString("tune.targetMetric");
    final int patienceSetting = config.getInt("tune.earlyStop.patience");
    final EarlyStopper.Goal goal;
    switch (metric) {
        case "instance_set_accuracy":
        case "instance_f1":
        case "instance_log_likelihood":
            goal = EarlyStopper.Goal.MAXIMIZE;
            break;
        case "instance_hamming_loss":
            goal = EarlyStopper.Goal.MINIMIZE;
            break;
        default:
            throw new IllegalArgumentException("unsupported tune.targetMetric " + metric);
    }
    final EarlyStopper stopper = new EarlyStopper(goal, patienceSetting);
    stopper.setMinimumIterations(config.getInt("tune.earlyStop.minIterations"));
    return stopper;
}
Usage of edu.neu.ccs.pyramid.optimization.EarlyStopper in project pyramid by cheng-li.
Class CBMLR, method tune.
/**
 * Tunes the number of training iterations of an LR-based CBM using early stopping on a
 * validation set.
 *
 * <p>The optimizer is advanced one iteration at a time. Every {@code tune.monitorInterval}
 * iterations, the metric named by {@code tune.targetMetric} is measured on {@code validSet}.
 * Each metric uses a matching predictor, with {@code predict.piThreshold}. The value is fed to
 * the {@link EarlyStopper}, and the loop ends when the stopper says so. There is no other
 * iteration cap.
 *
 * @param config configuration; reads {@code tune.targetMetric}, {@code tune.monitorInterval} and
 *     {@code predict.piThreshold} here, plus whatever {@code newCBM}, {@code getOptimizer} and
 *     {@code loadNewEarlyStopper} read
 * @param hyperParameters hyper-parameters to train with; its {@code iterations} field is
 *     overwritten with the best iteration found
 * @param trainSet training data
 * @param validSet validation data used for early stopping
 * @return the (mutated) hyper-parameters together with the best validation value of the metric
 * @throws IllegalArgumentException if the target metric is unsupported or
 *     {@code tune.monitorInterval} is not positive
 */
private static TuneResult tune(Config config, HyperParameters hyperParameters, MultiLabelClfDataSet trainSet, MultiLabelClfDataSet validSet) throws Exception {
    final String unsupportedTarget = "predictTarget should be instance_set_accuracy, instance_f1, instance_hamming_loss or instance_log_likelihood";
    List<Integer> unobservedLabels = DataSetUtil.unobservedLabels(trainSet);
    CBM cbm = newCBM(config, trainSet, hyperParameters);
    EarlyStopper earlyStopper = loadNewEarlyStopper(config);
    LRCBMOptimizer optimizer = getOptimizer(config, hyperParameters, cbm, trainSet);
    optimizer.initialize();

    // Pick the predictor that is optimal (or the closest match) for the target metric.
    MultiLabelClassifier classifier;
    String predictTarget = config.getString("tune.targetMetric");
    switch (predictTarget) {
        case "instance_set_accuracy":
            AccPredictor accPredictor = new AccPredictor(cbm);
            accPredictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
            classifier = accPredictor;
            break;
        case "instance_f1":
            PluginF1 pluginF1 = new PluginF1(cbm);
            List<MultiLabel> support = DataSetUtil.gatherMultiLabels(trainSet);
            pluginF1.setSupport(support);
            pluginF1.setPiThreshold(config.getDouble("predict.piThreshold"));
            classifier = pluginF1;
            break;
        case "instance_hamming_loss":
            MarginalPredictor marginalPredictor = new MarginalPredictor(cbm);
            marginalPredictor.setPiThreshold(config.getDouble("predict.piThreshold"));
            classifier = marginalPredictor;
            break;
        case "instance_log_likelihood":
            // acc predictor seems to be the best match for log likelihood
            AccPredictor accPredictor1 = new AccPredictor(cbm);
            accPredictor1.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
            classifier = accPredictor1;
            break;
        default:
            throw new IllegalArgumentException(unsupportedTarget);
    }

    int interval = config.getInt("tune.monitorInterval");
    if (interval <= 0) {
        // Fail before training: a zero interval would otherwise throw ArithmeticException in the loop.
        throw new IllegalArgumentException("tune.monitorInterval should be positive, got " + interval);
    }

    // NOTE: no iteration cap; termination relies entirely on the early stopper.
    for (int iter = 1; ; iter++) {
        if (VERBOSE) {
            System.out.println("iteration " + iter);
        }
        optimizer.iterate();
        if (iter % interval != 0) {
            continue;
        }
        MLMeasures validMeasures = new MLMeasures(classifier, validSet);
        if (VERBOSE) {
            System.out.println("validation performance with " + predictTarget + " optimal predictor:");
            System.out.println(validMeasures);
        }
        double score;
        switch (predictTarget) {
            case "instance_set_accuracy":
                score = validMeasures.getInstanceAverage().getAccuracy();
                break;
            case "instance_f1":
                score = validMeasures.getInstanceAverage().getF1();
                break;
            case "instance_hamming_loss":
                score = validMeasures.getInstanceAverage().getHammingLoss();
                break;
            case "instance_log_likelihood":
                score = LogLikelihood.averageLogLikelihood(cbm, validSet, unobservedLabels);
                break;
            default:
                throw new IllegalArgumentException(unsupportedTarget);
        }
        earlyStopper.add(iter, score);
        if (earlyStopper.shouldStop()) {
            if (VERBOSE) {
                System.out.println("Early Stopper: the training should stop now!");
            }
            break;
        }
    }
    if (VERBOSE) {
        System.out.println("done!");
    }
    hyperParameters.iterations = earlyStopper.getBestIteration();
    TuneResult tuneResult = new TuneResult();
    tuneResult.hyperParameters = hyperParameters;
    tuneResult.performance = earlyStopper.getBestValue();
    return tuneResult;
}
Usage of edu.neu.ccs.pyramid.optimization.EarlyStopper in project pyramid by cheng-li.
Class CBMLR, method loadNewEarlyStopper.
/**
 * Creates the {@link EarlyStopper} used during tuning, configured from {@code tune.*} settings.
 *
 * <p>Instance set accuracy, instance F1 and instance log likelihood are maximized; instance
 * Hamming loss is minimized. Patience is read from {@code tune.earlyStop.patience}, and the
 * minimum number of iterations from {@code tune.earlyStop.minIterations}.
 *
 * @param config configuration holding the tuning settings
 * @return a configured early stopper
 * @throws IllegalArgumentException if {@code tune.targetMetric} is not supported
 */
private static EarlyStopper loadNewEarlyStopper(Config config) {
    final String metric = config.getString("tune.targetMetric");
    final int patienceSetting = config.getInt("tune.earlyStop.patience");
    final EarlyStopper.Goal goal;
    switch (metric) {
        case "instance_hamming_loss":
            goal = EarlyStopper.Goal.MINIMIZE;
            break;
        case "instance_set_accuracy":
        case "instance_f1":
        case "instance_log_likelihood":
            goal = EarlyStopper.Goal.MAXIMIZE;
            break;
        default:
            throw new IllegalArgumentException("unsupported tune.targetMetric " + metric);
    }
    final EarlyStopper stopper = new EarlyStopper(goal, patienceSetting);
    stopper.setMinimumIterations(config.getInt("tune.earlyStop.minIterations"));
    return stopper;
}
Aggregations