Usage example of edu.neu.ccs.pyramid.classification.PriorProbClassifier in the project pyramid by cheng-li:
class ENRecoverCBMOptimizer, method updateBinaryClassifier.
@Override
protected void updateBinaryClassifier(int component, int label, MultiLabelClfDataSet activeDataset, double[] activeGammas) {
StopWatch stopWatch = new StopWatch();
stopWatch.start();
if (cbm.binaryClassifiers[component][label] == null || cbm.binaryClassifiers[component][label] instanceof PriorProbClassifier) {
cbm.binaryClassifiers[component][label] = new LogisticRegression(2, activeDataset.getNumFeatures());
}
int[] binaryLabels = DataSetUtil.toBinaryLabels(activeDataset.getMultiLabels(), label);
double[][] targetsDistribution = DataSetUtil.labelsToDistributions(binaryLabels, 2);
// no parallelism
ElasticNetLogisticTrainer elasticNetLogisticTrainer = new ElasticNetLogisticTrainer.Builder((LogisticRegression) cbm.binaryClassifiers[component][label], activeDataset, 2, targetsDistribution, activeGammas).setRegularization(regularizationBinary).setL1Ratio(l1RatioBinary).setLineSearch(lineSearch).build();
elasticNetLogisticTrainer.setActiveSet(activeSet);
elasticNetLogisticTrainer.getTerminator().setMaxIteration(this.binaryUpdatesPerIter);
elasticNetLogisticTrainer.optimize();
if (logger.isDebugEnabled()) {
logger.debug("time spent on updating component " + component + " label " + label + " = " + stopWatch);
}
}
Usage example of edu.neu.ccs.pyramid.classification.PriorProbClassifier in the project pyramid by cheng-li:
class BRInspector, method decisionProcess.
// only show the positive class score calculation
public static ClassScoreCalculation decisionProcess(CBM cbm, LabelTranslator labelTranslator, double prob, Vector vector, int classIndex, int limit) {
if (cbm.getBinaryClassifiers()[0][classIndex] instanceof PriorProbClassifier) {
PriorProbClassifier priorProbClassifier = (PriorProbClassifier) cbm.getBinaryClassifiers()[0][classIndex];
ClassScoreCalculation classScoreCalculation = new ClassScoreCalculation(classIndex, labelTranslator.toExtLabel(classIndex), priorProbClassifier.predictClassProb(vector, 1));
classScoreCalculation.setClassProbability(prob);
return classScoreCalculation;
}
LogisticRegression logisticRegression = (LogisticRegression) cbm.getBinaryClassifiers()[0][classIndex];
ClassScoreCalculation classScoreCalculation = new ClassScoreCalculation(classIndex, labelTranslator.toExtLabel(classIndex), logisticRegression.predictClassScore(vector, 1));
classScoreCalculation.setClassProbability(prob);
List<LinearRule> linearRules = new ArrayList<>();
Rule bias = new ConstantRule(logisticRegression.getWeights().getBiasForClass(1));
classScoreCalculation.addRule(bias);
// todo speed up using sparsity
for (int j = 0; j < logisticRegression.getNumFeatures(); j++) {
Feature feature = logisticRegression.getFeatureList().get(j);
double weight = logisticRegression.getWeights().getWeightsWithoutBiasForClass(1).get(j);
double featureValue = vector.get(j);
double score = weight * featureValue;
LinearRule rule = new LinearRule();
rule.setFeature(feature);
rule.setFeatureValue(featureValue);
rule.setScore(score);
rule.setWeight(weight);
linearRules.add(rule);
}
Comparator<LinearRule> comparator = Comparator.comparing(decision -> Math.abs(decision.getScore()));
List<LinearRule> sorted = linearRules.stream().sorted(comparator.reversed()).limit(limit).collect(Collectors.toList());
for (LinearRule linearRule : sorted) {
classScoreCalculation.addRule(linearRule);
}
return classScoreCalculation;
}
Usage example of edu.neu.ccs.pyramid.classification.PriorProbClassifier in the project pyramid by cheng-li:
class BRInspector, method decisionProcessWObias.
public static ClassScoreCalculation decisionProcessWObias(CBM cbm, LabelTranslator labelTranslator, double prob, Vector vector, int classIndex, int limit) {
if (cbm.getBinaryClassifiers()[0][classIndex] instanceof PriorProbClassifier) {
PriorProbClassifier priorProbClassifier = (PriorProbClassifier) cbm.getBinaryClassifiers()[0][classIndex];
ClassScoreCalculation classScoreCalculation = new ClassScoreCalculation(classIndex, labelTranslator.toExtLabel(classIndex), priorProbClassifier.predictClassProb(vector, 1));
classScoreCalculation.setClassProbability(prob);
return classScoreCalculation;
}
LogisticRegression logisticRegression = (LogisticRegression) cbm.getBinaryClassifiers()[0][classIndex];
ClassScoreCalculation classScoreCalculation = new ClassScoreCalculation(classIndex, labelTranslator.toExtLabel(classIndex), logisticRegression.predictClassScore(vector, 1));
classScoreCalculation.setClassProbability(prob);
List<LinearRule> linearRules = new ArrayList<>();
// todo speed up using sparsity
for (int j = 0; j < logisticRegression.getNumFeatures(); j++) {
Feature feature = logisticRegression.getFeatureList().get(j);
double weight = logisticRegression.getWeights().getWeightsWithoutBiasForClass(1).get(j);
double featureValue = vector.get(j);
double score = weight * featureValue;
LinearRule rule = new LinearRule();
rule.setFeature(feature);
rule.setFeatureValue(featureValue);
rule.setScore(score);
rule.setWeight(weight);
linearRules.add(rule);
}
Comparator<LinearRule> comparator = Comparator.comparing(decision -> Math.abs(decision.getScore()));
List<LinearRule> sorted = linearRules.stream().sorted(comparator.reversed()).limit(limit).collect(Collectors.toList());
for (LinearRule linearRule : sorted) {
classScoreCalculation.addRule(linearRule);
}
return classScoreCalculation;
}
Usage example of edu.neu.ccs.pyramid.classification.PriorProbClassifier in the project pyramid by cheng-li:
class AbstractCBMOptimizer, method skipOrUpdateBinaryClassifier.
protected void skipOrUpdateBinaryClassifier(int component, int label, MultiLabelClfDataSet activeDataSet, double[] activeGammas, double totalWeight) {
StopWatch stopWatch = new StopWatch();
stopWatch.start();
double effectivePositives = effectivePositives(component, label);
double nonSmoothedPositiveProb = effectivePositives / totalWeight;
// smooth the component-wise label fraction with global label fraction
double smoothedPositiveProb = (effectivePositives + smoothingStrength * positiveCounts[label]) / (totalWeight + smoothingStrength * dataSet.getNumDataPoints());
StringBuilder sb = new StringBuilder();
sb.append("for component ").append(component).append(", label ").append(label);
sb.append(", weighted positives = ").append(effectivePositives);
sb.append(", non-smoothed positive fraction = " + (effectivePositives / totalWeight));
sb.append(", global positive fraction = " + ((double) positiveCounts[label] / dataSet.getNumDataPoints()));
sb.append(", smoothed positive fraction = " + smoothedPositiveProb);
// it be happen that p >1 for numerical reasons
if (smoothedPositiveProb >= 1) {
smoothedPositiveProb = 1;
}
// todo avoid zero probability
if (smoothedPositiveProb < 1E-30) {
smoothedPositiveProb = 1E-30;
}
if (nonSmoothedPositiveProb < skipLabelThreshold || nonSmoothedPositiveProb > 1 - skipLabelThreshold) {
double[] probs = { 1 - smoothedPositiveProb, smoothedPositiveProb };
cbm.binaryClassifiers[component][label] = new PriorProbClassifier(probs);
sb.append(", skip, use prior = ").append(smoothedPositiveProb);
sb.append(", time spent = ").append(stopWatch.toString());
if (logger.isDebugEnabled()) {
logger.debug(sb.toString());
}
return;
}
if (logger.isDebugEnabled()) {
logger.debug(sb.toString());
}
updateBinaryClassifier(component, label, activeDataSet, activeGammas);
}
Aggregations