Search in sources:

Example 11 with PriorProbClassifier

Use of edu.neu.ccs.pyramid.classification.PriorProbClassifier in project pyramid by cheng-li.

From the class ENRecoverCBMOptimizer, method updateBinaryClassifier:

@Override
protected void updateBinaryClassifier(int component, int label, MultiLabelClfDataSet activeDataset, double[] activeGammas) {
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    if (cbm.binaryClassifiers[component][label] == null || cbm.binaryClassifiers[component][label] instanceof PriorProbClassifier) {
        cbm.binaryClassifiers[component][label] = new LogisticRegression(2, activeDataset.getNumFeatures());
    }
    int[] binaryLabels = DataSetUtil.toBinaryLabels(activeDataset.getMultiLabels(), label);
    double[][] targetsDistribution = DataSetUtil.labelsToDistributions(binaryLabels, 2);
    // no parallelism
    ElasticNetLogisticTrainer elasticNetLogisticTrainer = new ElasticNetLogisticTrainer.Builder((LogisticRegression) cbm.binaryClassifiers[component][label], activeDataset, 2, targetsDistribution, activeGammas).setRegularization(regularizationBinary).setL1Ratio(l1RatioBinary).setLineSearch(lineSearch).build();
    elasticNetLogisticTrainer.setActiveSet(activeSet);
    elasticNetLogisticTrainer.getTerminator().setMaxIteration(this.binaryUpdatesPerIter);
    elasticNetLogisticTrainer.optimize();
    if (logger.isDebugEnabled()) {
        logger.debug("time spent on updating component " + component + " label " + label + " = " + stopWatch);
    }
}
Also used : PriorProbClassifier(edu.neu.ccs.pyramid.classification.PriorProbClassifier) ElasticNetLogisticTrainer(edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) StopWatch(org.apache.commons.lang3.time.StopWatch)

Example 12 with PriorProbClassifier

Use of edu.neu.ccs.pyramid.classification.PriorProbClassifier in project pyramid by cheng-li.

From the class BRInspector, method decisionProcess:

// only show the positive class score calculation
public static ClassScoreCalculation decisionProcess(CBM cbm, LabelTranslator labelTranslator, double prob, Vector vector, int classIndex, int limit) {
    if (cbm.getBinaryClassifiers()[0][classIndex] instanceof PriorProbClassifier) {
        PriorProbClassifier priorProbClassifier = (PriorProbClassifier) cbm.getBinaryClassifiers()[0][classIndex];
        ClassScoreCalculation classScoreCalculation = new ClassScoreCalculation(classIndex, labelTranslator.toExtLabel(classIndex), priorProbClassifier.predictClassProb(vector, 1));
        classScoreCalculation.setClassProbability(prob);
        return classScoreCalculation;
    }
    LogisticRegression logisticRegression = (LogisticRegression) cbm.getBinaryClassifiers()[0][classIndex];
    ClassScoreCalculation classScoreCalculation = new ClassScoreCalculation(classIndex, labelTranslator.toExtLabel(classIndex), logisticRegression.predictClassScore(vector, 1));
    classScoreCalculation.setClassProbability(prob);
    List<LinearRule> linearRules = new ArrayList<>();
    Rule bias = new ConstantRule(logisticRegression.getWeights().getBiasForClass(1));
    classScoreCalculation.addRule(bias);
    // todo speed up using sparsity
    for (int j = 0; j < logisticRegression.getNumFeatures(); j++) {
        Feature feature = logisticRegression.getFeatureList().get(j);
        double weight = logisticRegression.getWeights().getWeightsWithoutBiasForClass(1).get(j);
        double featureValue = vector.get(j);
        double score = weight * featureValue;
        LinearRule rule = new LinearRule();
        rule.setFeature(feature);
        rule.setFeatureValue(featureValue);
        rule.setScore(score);
        rule.setWeight(weight);
        linearRules.add(rule);
    }
    Comparator<LinearRule> comparator = Comparator.comparing(decision -> Math.abs(decision.getScore()));
    List<LinearRule> sorted = linearRules.stream().sorted(comparator.reversed()).limit(limit).collect(Collectors.toList());
    for (LinearRule linearRule : sorted) {
        classScoreCalculation.addRule(linearRule);
    }
    return classScoreCalculation;
}
Also used : ArrayList(java.util.ArrayList) PriorProbClassifier(edu.neu.ccs.pyramid.classification.PriorProbClassifier) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) Feature(edu.neu.ccs.pyramid.feature.Feature)

Example 13 with PriorProbClassifier

Use of edu.neu.ccs.pyramid.classification.PriorProbClassifier in project pyramid by cheng-li.

From the class BRInspector, method decisionProcessWObias:

public static ClassScoreCalculation decisionProcessWObias(CBM cbm, LabelTranslator labelTranslator, double prob, Vector vector, int classIndex, int limit) {
    if (cbm.getBinaryClassifiers()[0][classIndex] instanceof PriorProbClassifier) {
        PriorProbClassifier priorProbClassifier = (PriorProbClassifier) cbm.getBinaryClassifiers()[0][classIndex];
        ClassScoreCalculation classScoreCalculation = new ClassScoreCalculation(classIndex, labelTranslator.toExtLabel(classIndex), priorProbClassifier.predictClassProb(vector, 1));
        classScoreCalculation.setClassProbability(prob);
        return classScoreCalculation;
    }
    LogisticRegression logisticRegression = (LogisticRegression) cbm.getBinaryClassifiers()[0][classIndex];
    ClassScoreCalculation classScoreCalculation = new ClassScoreCalculation(classIndex, labelTranslator.toExtLabel(classIndex), logisticRegression.predictClassScore(vector, 1));
    classScoreCalculation.setClassProbability(prob);
    List<LinearRule> linearRules = new ArrayList<>();
    // todo speed up using sparsity
    for (int j = 0; j < logisticRegression.getNumFeatures(); j++) {
        Feature feature = logisticRegression.getFeatureList().get(j);
        double weight = logisticRegression.getWeights().getWeightsWithoutBiasForClass(1).get(j);
        double featureValue = vector.get(j);
        double score = weight * featureValue;
        LinearRule rule = new LinearRule();
        rule.setFeature(feature);
        rule.setFeatureValue(featureValue);
        rule.setScore(score);
        rule.setWeight(weight);
        linearRules.add(rule);
    }
    Comparator<LinearRule> comparator = Comparator.comparing(decision -> Math.abs(decision.getScore()));
    List<LinearRule> sorted = linearRules.stream().sorted(comparator.reversed()).limit(limit).collect(Collectors.toList());
    for (LinearRule linearRule : sorted) {
        classScoreCalculation.addRule(linearRule);
    }
    return classScoreCalculation;
}
Also used : ArrayList(java.util.ArrayList) PriorProbClassifier(edu.neu.ccs.pyramid.classification.PriorProbClassifier) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) Feature(edu.neu.ccs.pyramid.feature.Feature)

Example 14 with PriorProbClassifier

Use of edu.neu.ccs.pyramid.classification.PriorProbClassifier in project pyramid by cheng-li.

From the class AbstractCBMOptimizer, method skipOrUpdateBinaryClassifier:

protected void skipOrUpdateBinaryClassifier(int component, int label, MultiLabelClfDataSet activeDataSet, double[] activeGammas, double totalWeight) {
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    double effectivePositives = effectivePositives(component, label);
    double nonSmoothedPositiveProb = effectivePositives / totalWeight;
    // smooth the component-wise label fraction with global label fraction
    double smoothedPositiveProb = (effectivePositives + smoothingStrength * positiveCounts[label]) / (totalWeight + smoothingStrength * dataSet.getNumDataPoints());
    StringBuilder sb = new StringBuilder();
    sb.append("for component ").append(component).append(", label ").append(label);
    sb.append(", weighted positives = ").append(effectivePositives);
    sb.append(", non-smoothed positive fraction = " + (effectivePositives / totalWeight));
    sb.append(", global positive fraction = " + ((double) positiveCounts[label] / dataSet.getNumDataPoints()));
    sb.append(", smoothed positive fraction = " + smoothedPositiveProb);
    // it be happen that p >1 for numerical reasons
    if (smoothedPositiveProb >= 1) {
        smoothedPositiveProb = 1;
    }
    // todo avoid zero probability
    if (smoothedPositiveProb < 1E-30) {
        smoothedPositiveProb = 1E-30;
    }
    if (nonSmoothedPositiveProb < skipLabelThreshold || nonSmoothedPositiveProb > 1 - skipLabelThreshold) {
        double[] probs = { 1 - smoothedPositiveProb, smoothedPositiveProb };
        cbm.binaryClassifiers[component][label] = new PriorProbClassifier(probs);
        sb.append(", skip, use prior = ").append(smoothedPositiveProb);
        sb.append(", time spent = ").append(stopWatch.toString());
        if (logger.isDebugEnabled()) {
            logger.debug(sb.toString());
        }
        return;
    }
    if (logger.isDebugEnabled()) {
        logger.debug(sb.toString());
    }
    updateBinaryClassifier(component, label, activeDataSet, activeGammas);
}
Also used : PriorProbClassifier(edu.neu.ccs.pyramid.classification.PriorProbClassifier) StopWatch(org.apache.commons.lang3.time.StopWatch)

Aggregations

PriorProbClassifier (edu.neu.ccs.pyramid.classification.PriorProbClassifier)14 StopWatch (org.apache.commons.lang3.time.StopWatch)10 LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression)8 RidgeLogisticOptimizer (edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer)4 ArrayList (java.util.ArrayList)3 ElasticNetLogisticTrainer (edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer)2 Feature (edu.neu.ccs.pyramid.feature.Feature)2 LKBOutputCalculator (edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator)1 LKBoost (edu.neu.ccs.pyramid.classification.lkboost.LKBoost)1 LKBoostOptimizer (edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer)1 BMSelector (edu.neu.ccs.pyramid.clustering.bm.BMSelector)1 edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset)1 ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet)1 ConstantRegressor (edu.neu.ccs.pyramid.regression.ConstantRegressor)1 Regressor (edu.neu.ccs.pyramid.regression.Regressor)1 RegTreeConfig (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig)1 RegTreeFactory (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory)1 TreeRule (edu.neu.ccs.pyramid.regression.regression_tree.TreeRule)1 ArgMax (edu.neu.ccs.pyramid.util.ArgMax)1 File (java.io.File)1