Use of edu.neu.ccs.pyramid.regression.ls_logistic_boost.LSLogisticBoost in project pyramid by cheng-li.
Class RerankerTrainer, method trainWithSigmoid.
/**
 * Trains an {@link LSLogisticBoost} model on {@code regDataSet} and wraps the selected model
 * in a {@link Reranker}.
 *
 * <p>Boosting runs for at most {@code maxIter} iterations. Every 10th iteration the MSE on
 * {@code validation} is reported to an {@link EarlyStopper} configured to minimize it. Each
 * time the current iteration is the best one seen so far, a deep copy of the model is kept,
 * and training stops as soon as the stopper says so. The last kept copy is the model handed
 * to the reranker.
 *
 * <p>If no validation checkpoint was ever reached (for example {@code maxIter < 10}), the
 * model as trained so far is used instead of passing {@code null} to the reranker.
 *
 * <p>NOTE(review): {@code numLeaves}, {@code minDataPerLeaf}, {@code monotonicityType},
 * {@code shrinkage}, {@code maxIter} and {@code numCandidates} are not declared in this
 * method; presumably they are fields of the enclosing class.
 *
 * @param regDataSet training data; its labels are passed to the optimizer as targets
 * @param instanceWeights weights handed to the optimizer together with {@code regDataSet}
 * @param classProbEstimator passed through unchanged to the returned reranker
 * @param predictionFeatureExtractor supplies the feature monotonicity constraints (when
 *     {@code monotonicityType} is not {@code "none"}) and is passed to the returned reranker
 * @param labelCalibrator passed through unchanged to the returned reranker
 * @param validation data set on which the MSE used for early stopping is computed
 * @param noiseRates0 passed to the optimizer
 * @param noiseRates1 passed to the optimizer
 * @return a reranker built around the selected model
 * @throws IllegalStateException if the best-so-far model cannot be deep-copied
 */
public Reranker trainWithSigmoid(RegDataSet regDataSet, double[] instanceWeights, MultiLabelClassifier.ClassProbEstimator classProbEstimator, PredictionFeatureExtractor predictionFeatureExtractor, LabelCalibrator labelCalibrator, RegDataSet validation, double[] noiseRates0, double[] noiseRates1) {
    LSLogisticBoost lsLogisticBoost = new LSLogisticBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig()
            .setMaxNumLeaves(numLeaves)
            .setMinDataPerLeaf(minDataPerLeaf)
            .setMonotonicityType(monotonicityType);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSLogisticBoostOptimizer optimizer = new LSLogisticBoostOptimizer(lsLogisticBoost, regDataSet, regTreeFactory, instanceWeights, regDataSet.getLabels(), noiseRates0, noiseRates1);

    if (!monotonicityType.equals("none")) {
        // A single row of constraints, taken directly from the feature extractor.
        int[][] mono = {predictionFeatureExtractor.featureMonotonicity()};
        optimizer.setMonotonicity(mono);
    }
    optimizer.setShrinkage(shrinkage);
    optimizer.initialize();

    // Validation MSE is only computed every this many boosting iterations.
    final int validationInterval = 10;
    EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, 5);
    LSLogisticBoost bestModel = null;

    for (int i = 1; i <= maxIter; i++) {
        optimizer.iterate();
        if (i % validationInterval != 0) {
            continue;
        }

        double mse = MSE.mse(lsLogisticBoost, validation);
        earlyStopper.add(i, mse);

        if (earlyStopper.getBestIteration() == i) {
            // The optimizer keeps mutating lsLogisticBoost, so the best model must be copied.
            try {
                bestModel = (LSLogisticBoost) Serialization.deepCopy(lsLogisticBoost);
            } catch (IOException | ClassNotFoundException e) {
                // Continuing would silently return a stale (or null) model; fail loudly instead.
                throw new IllegalStateException("Failed to snapshot the best model at iteration " + i, e);
            }
        }

        if (earlyStopper.shouldStop()) {
            break;
        }
    }

    if (bestModel == null) {
        // No validation checkpoint was recorded; use the model trained so far rather than null.
        bestModel = lsLogisticBoost;
    }
    return new Reranker(bestModel, classProbEstimator, numCandidates, predictionFeatureExtractor, labelCalibrator);
}
Aggregations