Usage of edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer in the project pyramid by cheng-li.
Class LRCBMOptimizer, method updateMultiClassClassifier:
/**
 * Re-optimizes {@code cbm.multiClassClassifier} in place.
 *
 * <p>The classifier is cast to {@link LogisticRegression} and handed to a
 * {@link RidgeLogisticOptimizer} together with {@code dataSet}, {@code gammas} and
 * {@code priorVarianceMultiClass}. The optimizer is capped at
 * {@code multiclassUpdatesPerIter} iterations. Start and finish are logged at debug level.
 */
@Override
protected void updateMultiClassClassifier() {
    if (logger.isDebugEnabled()) {
        logger.debug("start updateMultiClassClassifier");
    }

    // NOTE(review): throws ClassCastException if the multi-class classifier is not a LogisticRegression.
    final LogisticRegression multiClassModel = (LogisticRegression) cbm.multiClassClassifier;

    // The trailing `true` argument selects the parallel variant of the optimizer.
    final RidgeLogisticOptimizer optimizer =
            new RidgeLogisticOptimizer(multiClassModel, dataSet, gammas, priorVarianceMultiClass, true);

    // Bound the work done per outer iteration.
    optimizer.getOptimizer().getTerminator().setMaxIteration(multiclassUpdatesPerIter);
    optimizer.optimize();

    if (logger.isDebugEnabled()) {
        logger.debug("finish updateMultiClassClassifier");
    }
}
Usage of edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer in the project pyramid by cheng-li.
Class SparseCBMOptimzer, method updateBinaryLogisticRegression:
/**
 * Refits the binary classifier stored at {@code cbm.binaryClassifiers[componentIndex][labelIndex]}.
 *
 * <p>If {@code effectivePositives(componentIndex, labelIndex)} is at most 1, the slot is replaced by a
 * {@link PriorProbClassifier} with probabilities {@code {1 - p, p}}, where
 * {@code p = prior(componentIndex, labelIndex)}. No optimization is run in that case.
 *
 * <p>Otherwise the slot is ensured to hold a {@link LogisticRegression}. A new one sized
 * {@code (2, dataSet.getNumFeatures())} is created when the slot is null or holds a
 * {@code PriorProbClassifier}. A non-parallel {@link RidgeLogisticOptimizer} is then run on the
 * binary labels for {@code labelIndex}, capped at {@code numBinaryUpdates} iterations.
 *
 * <p>In both cases a one-line summary, including the elapsed time, is printed to standard output.
 *
 * @param componentIndex index of the mixture component whose classifier is updated
 * @param labelIndex index of the label whose classifier is updated
 */
private void updateBinaryLogisticRegression(int componentIndex, int labelIndex) {
    // Time the whole update so the summary line can report it.
    final StopWatch timer = new StopWatch();
    timer.start();

    final double positives = effectivePositives(componentIndex, labelIndex);
    final StringBuilder report = new StringBuilder()
            .append("for component ").append(componentIndex)
            .append(", label ").append(labelIndex)
            .append(", effective positives = ").append(positives);

    if (positives <= 1) {
        // Too little positive mass to fit a model; fall back to a constant prior.
        final double priorPositive = prior(componentIndex, labelIndex);
        cbm.binaryClassifiers[componentIndex][labelIndex] =
                new PriorProbClassifier(new double[] { 1 - priorPositive, priorPositive });
        report.append(", skip, use prior = ").append(priorPositive);
    } else {
        // Replace an empty slot or a prior-only placeholder with a fresh logistic regression.
        final boolean needsFreshModel = cbm.binaryClassifiers[componentIndex][labelIndex] == null
                || cbm.binaryClassifiers[componentIndex][labelIndex] instanceof PriorProbClassifier;
        if (needsFreshModel) {
            cbm.binaryClassifiers[componentIndex][labelIndex] =
                    new LogisticRegression(2, dataSet.getNumFeatures());
        }

        final int[] targets = DataSetUtil.toBinaryLabels(dataSet.getMultiLabels(), labelIndex);

        // The trailing `false` argument disables the optimizer's internal parallelism.
        // NOTE(review): the cast throws ClassCastException if the slot holds any other classifier type.
        final RidgeLogisticOptimizer optimizer = new RidgeLogisticOptimizer(
                (LogisticRegression) cbm.binaryClassifiers[componentIndex][labelIndex],
                dataSet, targets, activeGammas, priorVarianceBinary, false);

        // TODO maximum iterations
        optimizer.getOptimizer().getTerminator().setMaxIteration(numBinaryUpdates);
        optimizer.optimize();
    }

    report.append(", time spent = ").append(timer.toString());
    System.out.println(report.toString());
}
Aggregations