Usage of `edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer` in the project pyramid (by cheng-li).
Class `ENCBMOptimizer`, method `updateBinaryClassifier`:
/**
 * Refits the binary classifier of one (component, label) pair with elastic-net logistic regression.
 * <p>
 * If the slot is empty or still holds a {@link PriorProbClassifier}, it is first replaced by a fresh
 * two-class {@link LogisticRegression} sized to the active dataset's feature count. The model is then
 * optimized in place, restricted to {@code activeSet} and capped at {@code binaryUpdatesPerIter} iterations.
 *
 * @param component     index of the mixture component whose classifier is updated
 * @param label         label index; converted to binary targets via {@code DataSetUtil.toBinaryLabels}
 * @param activeDataset data the classifier is trained on
 * @param activeGammas  per-instance values handed to the trainer builder (presumably component
 *                      responsibilities used as instance weights; TODO confirm)
 */
@Override
protected void updateBinaryClassifier(int component, int label, MultiLabelClfDataSet activeDataset, double[] activeGammas) {
    StopWatch timer = new StopWatch();
    timer.start();

    boolean needsLogisticModel = cbm.binaryClassifiers[component][label] == null
            || cbm.binaryClassifiers[component][label] instanceof PriorProbClassifier;
    if (needsLogisticModel) {
        // Start from a fresh two-class model sized to the active data.
        cbm.binaryClassifiers[component][label] = new LogisticRegression(2, activeDataset.getNumFeatures());
    }

    // Two-class target distributions for this label.
    double[][] targets = DataSetUtil.labelsToDistributions(
            DataSetUtil.toBinaryLabels(activeDataset.getMultiLabels(), label), 2);

    LogisticRegression model = (LogisticRegression) cbm.binaryClassifiers[component][label];
    ElasticNetLogisticTrainer.Builder builder =
            new ElasticNetLogisticTrainer.Builder(model, activeDataset, 2, targets, activeGammas);
    ElasticNetLogisticTrainer trainer = builder
            .setRegularization(regularizationBinary)
            .setL1Ratio(l1RatioBinary)
            .setLineSearch(lineSearch)
            .build();

    trainer.setActiveSet(activeSet);
    trainer.getTerminator().setMaxIteration(this.binaryUpdatesPerIter);
    trainer.optimize();

    if (!logger.isDebugEnabled()) {
        return;
    }
    logger.debug("time spent on updating component " + component + " label " + label + " = " + timer);
}
Usage of `edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer` in the project pyramid (by cheng-li).
Class `CBMUtilityOptimizer`, method `updateBinaryLogisticRegressionEL`:
/** Iteration cap used when no explicit budget is given; equals the previously hard-coded value. */
private static final int DEFAULT_BINARY_EL_MAX_ITERATIONS = 10;

/**
 * Refits the binary logistic regression of one (component, label) pair with elastic-net
 * regularization, capped at {@value #DEFAULT_BINARY_EL_MAX_ITERATIONS} iterations.
 *
 * @param componentIndex mixture component index
 * @param labelIndex     label index
 */
private void updateBinaryLogisticRegressionEL(int componentIndex, int labelIndex) {
    updateBinaryLogisticRegressionEL(componentIndex, labelIndex, DEFAULT_BINARY_EL_MAX_ITERATIONS);
}

/**
 * Refits the binary logistic regression of one (component, label) pair with elastic-net
 * regularization ({@code regularizationBinary}, {@code l1RatioBinary}, {@code lineSearch}).
 * The model stored in {@code cbm.binaryClassifiers[componentIndex][labelIndex]} is updated in place.
 *
 * @param componentIndex mixture component index; selects the classifier and its {@code gammasT} row
 * @param labelIndex     label index; selects the classifier and its binary target distribution
 * @param maxIterations  value passed to the trainer's terminator as its maximum iteration count
 */
private void updateBinaryLogisticRegressionEL(int componentIndex, int labelIndex, int maxIterations) {
    ElasticNetLogisticTrainer trainer = new ElasticNetLogisticTrainer.Builder(
            (LogisticRegression) cbm.binaryClassifiers[componentIndex][labelIndex],
            dataSet, 2, binaryTargetsDistributions[labelIndex], gammasT[componentIndex])
            .setRegularization(regularizationBinary)
            .setL1Ratio(l1RatioBinary)
            .setLineSearch(lineSearch)
            .build();
    trainer.getTerminator().setMaxIteration(maxIterations);
    trainer.optimize();
}
Usage of `edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer` in the project pyramid (by cheng-li).
Class `CBMUtilityOptimizer`, method `updateMultiClassEL`:
/** Iteration cap used when no explicit budget is given; equals the previously hard-coded value. */
private static final int DEFAULT_MULTI_CLASS_EL_MAX_ITERATIONS = 10;

/**
 * Refits the multi-class (component-selection) logistic regression with elastic-net
 * regularization, capped at {@value #DEFAULT_MULTI_CLASS_EL_MAX_ITERATIONS} iterations.
 */
private void updateMultiClassEL() {
    updateMultiClassEL(DEFAULT_MULTI_CLASS_EL_MAX_ITERATIONS);
}

/**
 * Refits {@code cbm.multiClassClassifier} in place with elastic-net regularization
 * ({@code regularizationMultiClass}, {@code l1RatioMultiClass}, {@code lineSearch}), using
 * {@code gammas} as the builder's target input (presumably component responsibilities; TODO confirm).
 *
 * @param maxIterations value passed to the trainer's terminator as its maximum iteration count
 */
private void updateMultiClassEL(int maxIterations) {
    ElasticNetLogisticTrainer trainer = new ElasticNetLogisticTrainer.Builder(
            (LogisticRegression) cbm.multiClassClassifier,
            dataSet, cbm.multiClassClassifier.getNumClasses(), gammas)
            .setRegularization(regularizationMultiClass)
            .setL1Ratio(l1RatioMultiClass)
            .setLineSearch(lineSearch)
            .build();
    trainer.getTerminator().setMaxIteration(maxIterations);
    trainer.optimize();
}
Usage of `edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer` in the project pyramid (by cheng-li).
Class `SparkCBMOptimizer`, method `updateMultiClassEL`:
/** Iteration cap used when no explicit budget is given; equals the previously hard-coded value. */
private static final int DEFAULT_MULTI_CLASS_EL_MAX_ITERATIONS = 15;

/**
 * Refits the multi-class (component-selection) logistic regression with elastic-net
 * regularization, capped at {@value #DEFAULT_MULTI_CLASS_EL_MAX_ITERATIONS} iterations.
 */
private void updateMultiClassEL() {
    updateMultiClassEL(DEFAULT_MULTI_CLASS_EL_MAX_ITERATIONS);
}

/**
 * Refits {@code cbm.multiClassClassifier} in place with elastic-net regularization
 * ({@code regularizationMultiClass}, {@code l1RatioMultiClass}, {@code lineSearch}), using
 * {@code gammas} as the builder's target input (presumably component responsibilities; TODO confirm).
 *
 * @param maxIterations value passed to the trainer's terminator as its maximum iteration count
 */
private void updateMultiClassEL(int maxIterations) {
    ElasticNetLogisticTrainer trainer = new ElasticNetLogisticTrainer.Builder(
            (LogisticRegression) cbm.multiClassClassifier,
            dataSet, cbm.multiClassClassifier.getNumClasses(), gammas)
            .setRegularization(regularizationMultiClass)
            .setL1Ratio(l1RatioMultiClass)
            .setLineSearch(lineSearch)
            .build();
    trainer.getTerminator().setMaxIteration(maxIterations);
    trainer.optimize();
}
Usage of `edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer` in the project pyramid (by cheng-li).
Class `PlattScaling`, method `fitClassK`:
/**
 * Fits a one-feature, two-class logistic regression that maps raw scores to class probabilities.
 * <p>
 * Each score becomes the single feature of one data point, labelled with the matching entry of
 * {@code labels}. The model is trained with a tiny penalty (1e-9, L1 ratio 0), presumably only to
 * keep the fit numerically stable on separable data.
 *
 * @param scores raw classifier scores, one per data point
 * @param labels class labels aligned index-by-index with {@code scores}; each must be 0 or 1
 * @return the trained logistic regression
 * @throws IllegalArgumentException if the arrays differ in length or a label is not 0 or 1
 */
private static LogisticRegression fitClassK(double[] scores, int[] labels) {
    // Previously a longer labels array was silently truncated and a shorter one failed with a
    // bare ArrayIndexOutOfBoundsException; reject the mismatch explicitly instead.
    if (scores.length != labels.length) {
        throw new IllegalArgumentException("scores and labels must have the same length, got "
                + scores.length + " and " + labels.length);
    }
    // The dataset below is declared with exactly two classes.
    for (int i = 0; i < labels.length; i++) {
        if (labels[i] != 0 && labels[i] != 1) {
            throw new IllegalArgumentException("label at index " + i + " must be 0 or 1, got " + labels[i]);
        }
    }

    ClfDataSet dataSet = ClfDataSetBuilder.getBuilder().numClasses(2).numDataPoints(scores.length).numFeatures(1).dense(true).missingValue(false).build();
    for (int i = 0; i < scores.length; i++) {
        dataSet.setFeatureValue(i, 0, scores[i]);
        dataSet.setLabel(i, labels[i]);
    }

    LogisticRegression logisticRegression = new LogisticRegression(2, dataSet.getNumFeatures());
    ElasticNetLogisticTrainer trainer = ElasticNetLogisticTrainer.newBuilder(logisticRegression, dataSet)
            .setRegularization(1.0E-9)
            .setL1Ratio(0)
            .build();
    // The trainer updates logisticRegression in place.
    trainer.optimize();
    return logisticRegression;
}
Aggregations