Usage of `edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer` in project pyramid (by cheng-li).
Class `LRCBMOptimizer`, method `updateBinaryClassifier`:
/**
 * Refits the binary classifier for one (component, label) pair on the active subset of data.
 *
 * <p>If the slot is empty or still holds a {@code PriorProbClassifier}, it is first replaced by a
 * fresh two-class {@code LogisticRegression} sized to the active dataset's feature count. The
 * model is then optimized with a {@code RidgeLogisticOptimizer} whose line-search initial step is
 * {@code initialStepSize} and whose iteration cap is {@code binaryUpdatesPerIter}.
 *
 * @param component    mixture component index into {@code cbm.binaryClassifiers}
 * @param label        label index into {@code cbm.binaryClassifiers[component]}
 * @param activeDataset data points used for this update
 * @param activeGammas per-data-point weights passed to the optimizer
 */
@Override
protected void updateBinaryClassifier(int component, int label, MultiLabelClfDataSet activeDataset, double[] activeGammas) {
    StopWatch timer = new StopWatch();
    timer.start();

    boolean needsFreshModel = cbm.binaryClassifiers[component][label] == null
            || cbm.binaryClassifiers[component][label] instanceof PriorProbClassifier;
    if (needsFreshModel) {
        cbm.binaryClassifiers[component][label] = new LogisticRegression(2, activeDataset.getNumFeatures());
    }

    // Targets are computed before the model cast, matching the original evaluation order.
    int[] labelsForTask = DataSetUtil.toBinaryLabels(activeDataset.getMultiLabels(), label);
    LogisticRegression model = (LogisticRegression) cbm.binaryClassifiers[component][label];

    // Trailing 'false' = no parallelism.
    RidgeLogisticOptimizer optimizer = new RidgeLogisticOptimizer(model, activeDataset, labelsForTask, activeGammas, priorVarianceBinary, false);
    // NOTE(review): assumes the underlying optimizer is LBFGS; a different optimizer would throw ClassCastException here.
    LBFGS lbfgs = (LBFGS) optimizer.getOptimizer();
    lbfgs.getLineSearcher().setInitialStepLength(initialStepSize);
    optimizer.getOptimizer().getTerminator().setMaxIteration(binaryUpdatesPerIter);
    optimizer.optimize();

    if (logger.isDebugEnabled()) {
        logger.debug("time spent on updating component " + component + " label " + label + " = " + timer);
    }
}
Usage of `edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer` in project pyramid (by cheng-li).
Class `SparkCBMOptimizer`, method `updateBinaryLogisticRegression`:
// //todo pay attention to parallelism
// private void updateBinaryClassifiers(int component){
// String type = cbm.getBinaryClassifierType();
// switch (type){
// case "lr":
// IntStream.range(0, cbm.numLabels).parallel().forEach(l-> updateBinaryLogisticRegression(component,l));
// break;
// case "boost":
// // no parallel for boosting
// IntStream.range(0, cbm.numLabels).forEach(l -> updateBinaryBoosting(component, l));
// break;
// case "elasticnet":
// IntStream.range(0, cbm.numLabels).parallel().forEach(l-> updateBinaryLogisticRegressionEL(component,l));
// break;
// default:
// throw new IllegalArgumentException("unknown type: " + cbm.getBinaryClassifierType());
// }
// }
// private void updateBinaryBoosting(int componentIndex, int labelIndex){
// int numIterations = numIterationsBinary;
// double shrinkage = shrinkageBinary;
// LKBoost boost = (LKBoost)this.cbm.binaryClassifiers[componentIndex][labelIndex];
// RegTreeConfig regTreeConfig = new RegTreeConfig()
// .setMaxNumLeaves(numLeavesBinary);
// RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
// regTreeFactory.setLeafOutputCalculator(new LKBOutputCalculator(2));
// LKBoostOptimizer optimizer = new LKBoostOptimizer(boost,dataSet, regTreeFactory,
// gammasT[componentIndex],targetsDistributions[labelIndex]);
// optimizer.setShrinkage(shrinkage);
// optimizer.initialize();
// optimizer.iterate(numIterations);
// }
private static BinaryTaskResult updateBinaryLogisticRegression(int componentIndex, int labelIndex, LogisticRegression logisticRegression, MultiLabelClfDataSet dataSet, double[] weights, double[][] targets, double variance) {
RidgeLogisticOptimizer ridgeLogisticOptimizer;
// no parallelism
ridgeLogisticOptimizer = new RidgeLogisticOptimizer(logisticRegression, dataSet, weights, targets, variance, false);
// TODO maximum iterations
ridgeLogisticOptimizer.getOptimizer().getTerminator().setMaxIteration(15);
ridgeLogisticOptimizer.optimize();
return new BinaryTaskResult(componentIndex, labelIndex, logisticRegression);
// if (logger.isDebugEnabled()){
// logger.debug("for cluster "+clusterIndex+" label "+labelIndex+" history= "+ridgeLogisticOptimizer.getOptimizer().getTerminator().getHistory());
// }
}
Usage of `edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer` in project pyramid (by cheng-li).
Class `SparkCBMOptimizer`, method `updateMultiClassLR`:
/**
 * Refits {@code cbm.multiClassClassifier} (as a {@code LogisticRegression}) on {@code dataSet}
 * with targets {@code gammas}, ridge prior variance {@code priorVarianceMultiClass}, and at most
 * 15 optimizer iterations.
 */
private void updateMultiClassLR() {
    LogisticRegression multiClassModel = (LogisticRegression) cbm.multiClassClassifier;
    // Trailing 'true' = parallel optimization.
    RidgeLogisticOptimizer optimizer = new RidgeLogisticOptimizer(multiClassModel, dataSet, gammas, priorVarianceMultiClass, true);
    // TODO make the iteration cap configurable instead of hard-coding 15
    optimizer.getOptimizer().getTerminator().setMaxIteration(15);
    optimizer.optimize();
}
Usage of `edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer` in project pyramid (by cheng-li).
Class `CBMNoiseOptimizerFixed`, method `updateBinaryLogisticRegression`:
/**
 * Refits the binary logistic regression stored at
 * {@code cbm.binaryClassifiers[componentIndex][labelIndex]}.
 *
 * <p>Uses weights {@code gammasT[componentIndex]}, targets
 * {@code binaryTargetsDistributions[labelIndex]}, ridge prior variance
 * {@code priorVarianceBinary}, single-threaded optimization, and at most 10 iterations.
 *
 * @param componentIndex mixture component index
 * @param labelIndex     label index
 */
private void updateBinaryLogisticRegression(int componentIndex, int labelIndex) {
    LogisticRegression binaryModel = (LogisticRegression) cbm.binaryClassifiers[componentIndex][labelIndex];
    // Trailing 'false' = no parallelism.
    RidgeLogisticOptimizer optimizer = new RidgeLogisticOptimizer(binaryModel, dataSet, gammasT[componentIndex], binaryTargetsDistributions[labelIndex], priorVarianceBinary, false);
    // TODO make the iteration cap configurable instead of hard-coding 10
    optimizer.getOptimizer().getTerminator().setMaxIteration(10);
    optimizer.optimize();
}
Usage of `edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer` in project pyramid (by cheng-li).
Class `SparseCBMOptimzer` (name as spelled in the project), method `updateMultiClassLR`:
/**
 * Refits {@code cbm.multiClassClassifier} (as a {@code LogisticRegression}) on {@code dataSet}.
 *
 * <p>Uses targets {@code gammas}, ridge prior variance {@code priorVarianceMultiClass},
 * parallel optimization, and at most {@code numMulticlassUpdates} optimizer iterations.
 */
public void updateMultiClassLR() {
    LogisticRegression multiClassModel = (LogisticRegression) cbm.multiClassClassifier;
    // Trailing 'true' = parallel optimization.
    RidgeLogisticOptimizer optimizer = new RidgeLogisticOptimizer(multiClassModel, dataSet, gammas, priorVarianceMultiClass, true);
    optimizer.getOptimizer().getTerminator().setMaxIteration(numMulticlassUpdates);
    optimizer.optimize();
}
Aggregations