use of edu.neu.ccs.pyramid.classification.PriorProbClassifier in project pyramid by cheng-li.
the class SplitterTest method test1.
/**
 * Manual check of {@code Splitter.getAllSplits}: loads the training set stored at
 * {@code imdb/train.trec} under {@code DATASETS}, uses the gradient of a fitted
 * {@link PriorProbClassifier} as the target, orders the returned splits by
 * reduction (largest first) and prints the feature indices of the first 100.
 *
 * @throws Exception propagated from loading the data set or from fitting
 */
private static void test1() throws Exception {
    ClfDataSet dataSet = TRECFormat.loadClfDataSet(new File(DATASETS, "/imdb/train.trec"), DataSetType.CLF_SPARSE, true);

    PriorProbClassifier priorProbClassifier = new PriorProbClassifier(dataSet.getNumClasses());
    priorProbClassifier.fit(dataSet);
    // second argument is presumably the class index the gradient is taken for -- TODO confirm
    double[] gradient = priorProbClassifier.getGradient(dataSet, 1);

    RegTreeConfig regTreeConfig = new RegTreeConfig();
    Comparator<SplitResult> byReduction = Comparator.comparing(SplitResult::getReduction);
    List<Integer> results = Splitter.getAllSplits(regTreeConfig, dataSet, gradient)
            .stream()
            .sorted(byReduction.reversed())
            .map(SplitResult::getFeatureIndex)
            .limit(100)
            .collect(Collectors.toList());

    // results.stream().forEach(i-> System.out.println(dataSet.getFeatureSetting(i).getFeatureName()));
    System.out.println(results);
}
use of edu.neu.ccs.pyramid.classification.PriorProbClassifier in project pyramid by cheng-li.
the class LRCBMOptimizer method updateBinaryClassifier.
/**
 * Updates the binary classifier stored at {@code cbm.binaryClassifiers[component][label]}
 * by running a {@link RidgeLogisticOptimizer} on the active data.
 * <p>
 * If the slot is empty or currently holds a {@link PriorProbClassifier}, it is first
 * replaced by a fresh {@link LogisticRegression}. The optimizer's line searcher is given
 * {@code initialStepSize} and its terminator is capped at {@code binaryUpdatesPerIter}
 * iterations before {@code optimize()} is called.
 *
 * @param component     index of the mixture component
 * @param label         index of the label
 * @param activeDataset data set the classifier is trained on
 * @param activeGammas  per-instance weights handed to the optimizer
 */
@Override
protected void updateBinaryClassifier(int component, int label, MultiLabelClfDataSet activeDataset, double[] activeGammas) {
    StopWatch timer = new StopWatch();
    timer.start();

    Object current = cbm.binaryClassifiers[component][label];
    boolean needsFreshModel = current == null || current instanceof PriorProbClassifier;
    if (needsFreshModel) {
        // first argument is presumably the number of classes (binary) -- TODO confirm
        cbm.binaryClassifiers[component][label] = new LogisticRegression(2, activeDataset.getNumFeatures());
    }

    int[] binaryLabels = DataSetUtil.toBinaryLabels(activeDataset.getMultiLabels(), label);
    LogisticRegression logisticRegression = (LogisticRegression) cbm.binaryClassifiers[component][label];

    // no parallelism
    RidgeLogisticOptimizer ridgeLogisticOptimizer = new RidgeLogisticOptimizer(
            logisticRegression, activeDataset, binaryLabels, activeGammas, priorVarianceBinary, false);
    ((LBFGS) ridgeLogisticOptimizer.getOptimizer()).getLineSearcher().setInitialStepLength(initialStepSize);
    ridgeLogisticOptimizer.getOptimizer().getTerminator().setMaxIteration(binaryUpdatesPerIter);
    ridgeLogisticOptimizer.optimize();

    if (logger.isDebugEnabled()) {
        logger.debug("time spent on updating component " + component + " label " + label + " = " + timer);
    }
}
use of edu.neu.ccs.pyramid.classification.PriorProbClassifier in project pyramid by cheng-li.
the class AbstractRobustCBMOptimizer method skipOrUpdateBinaryClassifier.
/**
 * Decides whether the binary classifier for {@code (component, label)} is trained or
 * replaced by a constant prior.
 * <p>
 * When the non-smoothed weighted positive fraction is below {@code skipLabelThreshold}
 * or above {@code 1 - skipLabelThreshold}, the slot is filled with a
 * {@link PriorProbClassifier} built from the smoothed positive fraction and training is
 * skipped. Otherwise per-instance weights {@code gammas[i][component] * noiseLabelWeights[i][label]}
 * are collected for the active indices and {@code updateBinaryClassifier} is invoked.
 * A summary line is written to the debug log in both cases.
 *
 * @param component     index of the mixture component
 * @param label         index of the label
 * @param activeIndices indices used to look up {@code gammas} and {@code noiseLabelWeights}
 * @param activeDataSet data set passed on to {@code updateBinaryClassifier}
 * @param totalWeight   denominator of the positive fractions
 */
protected void skipOrUpdateBinaryClassifier(int component, int label, List<Integer> activeIndices, MultiLabelClfDataSet activeDataSet, double totalWeight) {
    StopWatch timer = new StopWatch();
    timer.start();

    double weightedPositives = effectivePositives(component, label);
    double rawFraction = weightedPositives / totalWeight;
    // smooth the component-wise label fraction with the global label fraction
    double smoothedFraction = (weightedPositives + smoothingStrength * positiveCounts[label])
            / (totalWeight + smoothingStrength * dataSet.getNumDataPoints());

    StringBuilder message = new StringBuilder()
            .append("for component ").append(component)
            .append(", label ").append(label)
            .append(", weighted positives = ").append(weightedPositives)
            .append(", non-smoothed positive fraction = ").append(rawFraction)
            .append(", global positive fraction = ").append((double) positiveCounts[label] / dataSet.getNumDataPoints())
            .append(", smoothed positive fraction = ").append(smoothedFraction);

    boolean tooRare = rawFraction < skipLabelThreshold;
    boolean tooFrequent = rawFraction > 1 - skipLabelThreshold;

    if (!tooRare && !tooFrequent) {
        if (logger.isDebugEnabled()) {
            logger.debug(message.toString());
        }
        double[] activeInstanceWeights = new double[activeIndices.size()];
        int position = 0;
        for (int index : activeIndices) {
            activeInstanceWeights[position++] = gammas[index][component] * noiseLabelWeights[index][label];
        }
        updateBinaryClassifier(component, label, activeDataSet, activeInstanceWeights);
        return;
    }

    // it can happen that p > 1 for numerical reasons, so cap the prior at 1
    double positivePrior = smoothedFraction >= 1 ? 1 : smoothedFraction;
    cbm.binaryClassifiers[component][label] = new PriorProbClassifier(new double[] { 1 - positivePrior, positivePrior });
    message.append(", skip, use prior = ").append(positivePrior);
    message.append(", time spent = ").append(timer.toString());
    if (logger.isDebugEnabled()) {
        logger.debug(message.toString());
    }
}
use of edu.neu.ccs.pyramid.classification.PriorProbClassifier in project pyramid by cheng-li.
the class ENCBMOptimizer method updateBinaryClassifier.
/**
 * Updates the binary classifier stored at {@code cbm.binaryClassifiers[component][label]}
 * with an {@link ElasticNetLogisticTrainer}.
 * <p>
 * If the slot is empty or currently holds a {@link PriorProbClassifier}, it is first
 * replaced by a fresh {@link LogisticRegression}. Each training weight is the product of
 * the corresponding entry of {@code activeGammas} and of {@code instanceWeights}. The
 * trainer's terminator is capped at {@code binaryUpdatesPerIter} iterations.
 *
 * @param component     index of the mixture component
 * @param label         index of the label
 * @param activeDataset data set the classifier is trained on
 * @param activeGammas  per-instance weights, multiplied element-wise with {@code instanceWeights}
 */
@Override
protected void updateBinaryClassifier(int component, int label, MultiLabelClfDataSet activeDataset, double[] activeGammas) {
    StopWatch timer = new StopWatch();
    timer.start();

    Object current = cbm.binaryClassifiers[component][label];
    boolean needsFreshModel = current == null || current instanceof PriorProbClassifier;
    if (needsFreshModel) {
        // first argument is presumably the number of classes (binary) -- TODO confirm
        cbm.binaryClassifiers[component][label] = new LogisticRegression(2, activeDataset.getNumFeatures());
    }

    int[] binaryLabels = DataSetUtil.toBinaryLabels(activeDataset.getMultiLabels(), label);
    double[][] targetsDistribution = DataSetUtil.labelsToDistributions(binaryLabels, 2);

    // NOTE(review): instanceWeights is indexed by position within activeGammas here;
    // confirm it is aligned with activeDataset rather than with the full data set.
    double[] overallWeights = new double[activeGammas.length];
    int position = 0;
    for (double gamma : activeGammas) {
        overallWeights[position] = gamma * instanceWeights[position];
        position++;
    }

    LogisticRegression logisticRegression = (LogisticRegression) cbm.binaryClassifiers[component][label];
    ElasticNetLogisticTrainer elasticNetLogisticTrainer =
            new ElasticNetLogisticTrainer.Builder(logisticRegression, activeDataset, 2, targetsDistribution, overallWeights)
                    .setRegularization(regularizationBinary)
                    .setL1Ratio(l1RatioBinary)
                    .setLineSearch(lineSearch)
                    .setMaxNumLinearRegUpdates(maxNumLinearRegUpdates)
                    .build();
    elasticNetLogisticTrainer.setActiveSet(activeSet);
    elasticNetLogisticTrainer.getTerminator().setMaxIteration(this.binaryUpdatesPerIter);
    elasticNetLogisticTrainer.optimize();

    if (logger.isDebugEnabled()) {
        logger.debug("time spent on updating component " + component + " label " + label + " = " + timer);
    }
}
use of edu.neu.ccs.pyramid.classification.PriorProbClassifier in project pyramid by cheng-li.
the class LRRecoverCBMOptimizer method updateBinaryClassifier.
/**
 * Updates the binary classifier stored at {@code cbm.binaryClassifiers[component][label]}
 * by running a {@link RidgeLogisticOptimizer} on the active data.
 * <p>
 * If the slot is empty or currently holds a {@link PriorProbClassifier}, it is first
 * replaced by a fresh {@link LogisticRegression}. The optimizer's terminator is capped at
 * {@code binaryUpdatesPerIter} iterations before {@code optimize()} is called.
 *
 * @param component     index of the mixture component
 * @param label         index of the label
 * @param activeDataset data set the classifier is trained on
 * @param activeGammas  per-instance weights handed to the optimizer
 */
@Override
protected void updateBinaryClassifier(int component, int label, MultiLabelClfDataSet activeDataset, double[] activeGammas) {
    StopWatch timer = new StopWatch();
    timer.start();

    Object current = cbm.binaryClassifiers[component][label];
    boolean needsFreshModel = current == null || current instanceof PriorProbClassifier;
    if (needsFreshModel) {
        // first argument is presumably the number of classes (binary) -- TODO confirm
        cbm.binaryClassifiers[component][label] = new LogisticRegression(2, activeDataset.getNumFeatures());
    }

    int[] binaryLabels = DataSetUtil.toBinaryLabels(activeDataset.getMultiLabels(), label);
    LogisticRegression logisticRegression = (LogisticRegression) cbm.binaryClassifiers[component][label];

    // no parallelism
    RidgeLogisticOptimizer ridgeLogisticOptimizer = new RidgeLogisticOptimizer(
            logisticRegression, activeDataset, binaryLabels, activeGammas, priorVarianceBinary, false);
    ridgeLogisticOptimizer.getOptimizer().getTerminator().setMaxIteration(binaryUpdatesPerIter);
    ridgeLogisticOptimizer.optimize();

    if (logger.isDebugEnabled()) {
        logger.debug("time spent on updating component " + component + " label " + label + " = " + timer);
    }
}
Aggregations