Use of edu.neu.ccs.pyramid.classification.PriorProbClassifier in project pyramid by cheng-li.
The class AbstractRecoverCBMOptimizer, method skipOrUpdateBinaryClassifier:
protected void skipOrUpdateBinaryClassifier(int component, int label, MultiLabelClfDataSet activeDataSet, double[] activeGammas, double totalWeight) {
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    double effectivePositives = effectivePositives(component, label);
    double nonSmoothedPositiveProb = effectivePositives / totalWeight;
    // smooth the component-wise label fraction with the global label fraction
    double positiveCount = labelMatrix.getColumn(label).getNumNonZeroElements();
    double smoothedPositiveProb = (effectivePositives + smoothingStrength * positiveCount) / (totalWeight + smoothingStrength * groundTruth.getNumDataPoints());
    StringBuilder sb = new StringBuilder();
    sb.append("for component ").append(component).append(", label ").append(label);
    sb.append(", weighted positives = ").append(effectivePositives);
    sb.append(", non-smoothed positive fraction = ").append(nonSmoothedPositiveProb);
    sb.append(", global positive fraction = ").append(positiveCount / groundTruth.getNumDataPoints());
    sb.append(", smoothed positive fraction = ").append(smoothedPositiveProb);
    // the smoothed probability can exceed 1 for numerical reasons
    if (smoothedPositiveProb >= 1) {
        smoothedPositiveProb = 1;
    }
    if (nonSmoothedPositiveProb < skipLabelThreshold || nonSmoothedPositiveProb > 1 - skipLabelThreshold) {
        double[] probs = {1 - smoothedPositiveProb, smoothedPositiveProb};
        cbm.binaryClassifiers[component][label] = new PriorProbClassifier(probs);
        sb.append(", skip, use prior = ").append(smoothedPositiveProb);
        sb.append(", time spent = ").append(stopWatch.toString());
        if (logger.isDebugEnabled()) {
            logger.debug(sb.toString());
        }
        return;
    }
    if (logger.isDebugEnabled()) {
        logger.debug(sb.toString());
    }
    updateBinaryClassifier(component, label, activeDataSet, activeGammas);
}
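The smoothing above interpolates the component-specific positive fraction with the global label fraction, with smoothingStrength acting as a pseudo-count weight. The following self-contained sketch reproduces the same arithmetic; all numbers are illustrative and not taken from the project.

public class SmoothingSketch {
    public static void main(String[] args) {
        // illustrative numbers: 3 weighted positives out of 50 weighted points in this component,
        // 200 positive instances out of 10000 data points globally, smoothing strength 1
        double effectivePositives = 3.0;
        double totalWeight = 50.0;
        double globalPositiveCount = 200.0;
        double numDataPoints = 10000.0;
        double smoothingStrength = 1.0;

        double nonSmoothed = effectivePositives / totalWeight;   // 0.06
        double smoothed = (effectivePositives + smoothingStrength * globalPositiveCount)
                / (totalWeight + smoothingStrength * numDataPoints);   // roughly 0.0202, pulled toward the global fraction
        // numerical guard, mirroring skipOrUpdateBinaryClassifier above
        if (smoothed >= 1) {
            smoothed = 1;
        }
        System.out.println("non-smoothed = " + nonSmoothed + ", smoothed = " + smoothed);
    }
}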
Use of edu.neu.ccs.pyramid.classification.PriorProbClassifier in project pyramid by cheng-li.
The class LKBoostOptimizer, method addPriors:
@Override
protected void addPriors() {
    PriorProbClassifier priorProbClassifier = new PriorProbClassifier(numClasses);
    priorProbClassifier.fit(dataSet, targetDistribution, weights);
    double[] probs = priorProbClassifier.getClassProbs();
    double[] scores = MathUtil.inverseSoftMax(probs);
    // weaken the priors by capping the scores to [-5, 5]
    for (int i = 0; i < scores.length; i++) {
        if (scores[i] > 5) {
            scores[i] = 5;
        }
        if (scores[i] < -5) {
            scores[i] = -5;
        }
    }
    for (int k = 0; k < numClasses; k++) {
        Regressor constant = new ConstantRegressor(scores[k]);
        boosting.getEnsemble(k).add(constant);
    }
}
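Adding a constant regressor at the inverse-softmax of each class prior makes the boosting ensemble start from the prior class distribution before any trees are fitted. The sketch below illustrates the idea; the inverseSoftMax method here is an assumed stand-in for MathUtil.inverseSoftMax and simply uses log-probabilities, which are one valid set of scores whose softmax recovers the input probabilities.

import java.util.Arrays;

public class PriorScoreSketch {
    // one valid inverse of softmax: log-probabilities (defined up to an additive constant)
    static double[] inverseSoftMax(double[] probs) {
        double[] scores = new double[probs.length];
        for (int k = 0; k < probs.length; k++) {
            scores[k] = Math.log(probs[k]);
        }
        return scores;
    }

    public static void main(String[] args) {
        double[] classPriors = {0.7, 0.2, 0.1};   // illustrative class prior probabilities
        double[] scores = inverseSoftMax(classPriors);
        // cap the scores to [-5, 5], as in addPriors above
        for (int k = 0; k < scores.length; k++) {
            scores[k] = Math.max(-5, Math.min(5, scores[k]));
        }
        System.out.println(Arrays.toString(scores));
    }
}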
Use of edu.neu.ccs.pyramid.classification.PriorProbClassifier in project pyramid by cheng-li.
The class GBCBMOptimizer, method updateBinaryClassifier:
@Override
protected void updateBinaryClassifier(int component, int label, MultiLabelClfDataSet activeDataset, double[] activeGammas) {
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    if (cbm.binaryClassifiers[component][label] == null || cbm.binaryClassifiers[component][label] instanceof PriorProbClassifier) {
        cbm.binaryClassifiers[component][label] = new LKBoost(2);
    }
    int[] binaryLabels = DataSetUtil.toBinaryLabels(activeDataset.getMultiLabels(), label);
    double[][] targetsDistributions = DataSetUtil.labelsToDistributions(binaryLabels, 2);
    LKBoost boost = (LKBoost) this.cbm.binaryClassifiers[component][label];
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(numLeaves);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    regTreeFactory.setLeafOutputCalculator(new LKBOutputCalculator(2));
    LKBoostOptimizer optimizer = new LKBoostOptimizer(boost, activeDataset, regTreeFactory, activeGammas, targetsDistributions);
    optimizer.setShrinkage(shrinkage);
    optimizer.initialize();
    optimizer.iterate(binaryUpdatesPerIter);
    if (logger.isDebugEnabled()) {
        logger.debug("time spent on updating component " + component + " label " + label + " = " + stopWatch);
    }
}
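Before fitting the per-label booster, the multi-label ground truth is first reduced to a 0/1 vector for the given label and then expanded to one-hot target distributions. A minimal sketch of that two-step conversion is below; the helper methods are illustrative stand-ins for DataSetUtil.toBinaryLabels and DataSetUtil.labelsToDistributions, written against plain Java collections rather than pyramid's data types.

import java.util.Arrays;
import java.util.List;
import java.util.Set;

public class BinaryTargetSketch {
    // 1 if the instance carries the given label, 0 otherwise
    static int[] toBinaryLabels(List<Set<Integer>> multiLabels, int label) {
        int[] binary = new int[multiLabels.size()];
        for (int i = 0; i < multiLabels.size(); i++) {
            binary[i] = multiLabels.get(i).contains(label) ? 1 : 0;
        }
        return binary;
    }

    // one-hot encode each binary label over numClasses columns
    static double[][] labelsToDistributions(int[] labels, int numClasses) {
        double[][] distributions = new double[labels.length][numClasses];
        for (int i = 0; i < labels.length; i++) {
            distributions[i][labels[i]] = 1.0;
        }
        return distributions;
    }

    public static void main(String[] args) {
        List<Set<Integer>> multiLabels = List.of(Set.of(0, 2), Set.of(1), Set.of(2));
        int[] binary = toBinaryLabels(multiLabels, 2);                              // [1, 0, 1]
        System.out.println(Arrays.toString(binary));
        System.out.println(Arrays.deepToString(labelsToDistributions(binary, 2)));  // [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
    }
}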
Use of edu.neu.ccs.pyramid.classification.PriorProbClassifier in project pyramid by cheng-li.
The class SparseCBMOptimzer, method updateBinaryLogisticRegression:
private void updateBinaryLogisticRegression(int componentIndex, int labelIndex) {
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    double effectivePositives = effectivePositives(componentIndex, labelIndex);
    StringBuilder sb = new StringBuilder();
    sb.append("for component ").append(componentIndex).append(", label ").append(labelIndex);
    sb.append(", effective positives = ").append(effectivePositives);
    if (effectivePositives <= 1) {
        double positiveProb = prior(componentIndex, labelIndex);
        double[] probs = {1 - positiveProb, positiveProb};
        cbm.binaryClassifiers[componentIndex][labelIndex] = new PriorProbClassifier(probs);
        sb.append(", skip, use prior = ").append(positiveProb);
        sb.append(", time spent = ").append(stopWatch.toString());
        System.out.println(sb.toString());
        return;
    }
    if (cbm.binaryClassifiers[componentIndex][labelIndex] == null || cbm.binaryClassifiers[componentIndex][labelIndex] instanceof PriorProbClassifier) {
        cbm.binaryClassifiers[componentIndex][labelIndex] = new LogisticRegression(2, dataSet.getNumFeatures());
    }
    RidgeLogisticOptimizer ridgeLogisticOptimizer;
    int[] binaryLabels = DataSetUtil.toBinaryLabels(dataSet.getMultiLabels(), labelIndex);
    // no parallelism
    ridgeLogisticOptimizer = new RidgeLogisticOptimizer((LogisticRegression) cbm.binaryClassifiers[componentIndex][labelIndex], dataSet, binaryLabels, activeGammas, priorVarianceBinary, false);
    // TODO maximum iterations
    ridgeLogisticOptimizer.getOptimizer().getTerminator().setMaxIteration(numBinaryUpdates);
    ridgeLogisticOptimizer.optimize();
    sb.append(", time spent = ").append(stopWatch.toString());
    System.out.println(sb.toString());
}
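When a component has at most one effective positive for a label, fitting a full logistic regression is not worthwhile, so the slot is filled with a prior-only classifier. A minimal sketch of building that fallback is below; the positiveProb value is illustrative, and the two-element array uses the same {P(negative), P(positive)} layout seen in the methods above.

import edu.neu.ccs.pyramid.classification.PriorProbClassifier;
import java.util.Arrays;

public class PriorFallbackSketch {
    public static void main(String[] args) {
        // illustrative prior positive probability for a rare label in one component
        double positiveProb = 0.03;
        double[] probs = {1 - positiveProb, positiveProb};   // {P(negative), P(positive)}
        // a constant classifier that always predicts these class probabilities,
        // standing in until enough effective positives justify a real binary model
        PriorProbClassifier priorOnly = new PriorProbClassifier(probs);
        System.out.println("fallback class probabilities: " + Arrays.toString(probs));
    }
}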
Use of edu.neu.ccs.pyramid.classification.PriorProbClassifier in project pyramid by cheng-li.
The class RobustLRCBMOptimizer, method updateBinaryClassifier:
@Override
protected void updateBinaryClassifier(int component, int label, MultiLabelClfDataSet activeDataset, double[] activeInstanceWeights) {
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    if (cbm.binaryClassifiers[component][label] == null || cbm.binaryClassifiers[component][label] instanceof PriorProbClassifier) {
        cbm.binaryClassifiers[component][label] = new LogisticRegression(2, activeDataset.getNumFeatures());
    }
    RidgeLogisticOptimizer ridgeLogisticOptimizer;
    int[] binaryLabels = DataSetUtil.toBinaryLabels(activeDataset.getMultiLabels(), label);
    // no parallelism
    ridgeLogisticOptimizer = new RidgeLogisticOptimizer((LogisticRegression) cbm.binaryClassifiers[component][label], activeDataset, binaryLabels, activeInstanceWeights, priorVarianceBinary, false);
    ridgeLogisticOptimizer.getOptimizer().getTerminator().setMaxIteration(binaryUpdatesPerIter);
    ridgeLogisticOptimizer.optimize();
    if (logger.isDebugEnabled()) {
        logger.debug("time spent on updating component " + component + " label " + label + " = " + stopWatch);
    }
}
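All of the optimizers above share the same lazy-replacement pattern: a binary slot that is still empty or still holds a PriorProbClassifier placeholder gets a fresh trainable model, while an already trained model is kept and warm-started across iterations. A generic sketch of that check is below; Model, PriorPlaceholder, and TrainableModel are illustrative types standing in for the project's classifier classes, not pyramid APIs.

public class LazyReplaceSketch {
    interface Model {}
    static class PriorPlaceholder implements Model {}   // stands in for PriorProbClassifier
    static class TrainableModel implements Model {}     // stands in for LogisticRegression / LKBoost

    static Model ensureTrainable(Model current) {
        // replace an empty slot or a prior-only placeholder; keep an already trained model
        if (current == null || current instanceof PriorPlaceholder) {
            return new TrainableModel();
        }
        return current;
    }

    public static void main(String[] args) {
        System.out.println(ensureTrainable(null).getClass().getSimpleName());                   // TrainableModel
        System.out.println(ensureTrainable(new PriorPlaceholder()).getClass().getSimpleName()); // TrainableModel
        TrainableModel existing = new TrainableModel();
        System.out.println(ensureTrainable(existing) == existing);                              // true
    }
}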