Search in sources:

Example 1 with LogisticRegression

Use of edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression in project pyramid by cheng-li.

The class ENCBMOptimizer, method updateBinaryClassifier.

/**
 * Retrains the binary classifier stored at {@code cbm.binaryClassifiers[component][label]}
 * on {@code activeDataset} using an elastic-net logistic trainer.
 *
 * <p>If the slot is empty ({@code null}) or holds a {@link PriorProbClassifier}, it is first
 * replaced by a fresh two-class {@link LogisticRegression} sized to the dataset's feature count.
 * The trainer is configured with {@code regularizationBinary}, {@code l1RatioBinary},
 * {@code lineSearch} and {@code activeSet}, and is capped at {@code binaryUpdatesPerIter}
 * iterations.
 *
 * @param component    index of the mixture component whose classifier is updated
 * @param label        index of the label whose one-vs-rest classifier is updated
 * @param activeDataset data the classifier is trained on
 * @param activeGammas per-instance values handed to the trainer builder
 *                     (presumably instance weights — TODO confirm)
 */
@Override
protected void updateBinaryClassifier(int component, int label, MultiLabelClfDataSet activeDataset, double[] activeGammas) {
    final StopWatch timer = new StopWatch();
    timer.start();

    // An empty slot or a prior-only placeholder cannot be trained, so swap in a real model.
    final boolean slotIsEmpty = cbm.binaryClassifiers[component][label] == null;
    if (slotIsEmpty || cbm.binaryClassifiers[component][label] instanceof PriorProbClassifier) {
        cbm.binaryClassifiers[component][label] = new LogisticRegression(2, activeDataset.getNumFeatures());
    }

    final int[] labelsForThisClass = DataSetUtil.toBinaryLabels(activeDataset.getMultiLabels(), label);
    final double[][] targets = DataSetUtil.labelsToDistributions(labelsForThisClass, 2);
    final LogisticRegression model = (LogisticRegression) cbm.binaryClassifiers[component][label];

    final ElasticNetLogisticTrainer trainer =
            new ElasticNetLogisticTrainer.Builder(model, activeDataset, 2, targets, activeGammas)
                    .setRegularization(regularizationBinary)
                    .setL1Ratio(l1RatioBinary)
                    .setLineSearch(lineSearch)
                    .build();
    trainer.setActiveSet(activeSet);
    trainer.getTerminator().setMaxIteration(this.binaryUpdatesPerIter);
    trainer.optimize();

    // The watch is never stopped; it is logged while still running.
    if (logger.isDebugEnabled()) {
        logger.debug("time spent on updating component " + component + " label " + label + " = " + timer);
    }
}
Also used: PriorProbClassifier (edu.neu.ccs.pyramid.classification.PriorProbClassifier), ElasticNetLogisticTrainer (edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer), LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression), StopWatch (org.apache.commons.lang3.time.StopWatch)

Example 2 with LogisticRegression

Use of edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression in project pyramid by cheng-li.

The class LRCBMOptimizer, method updateBinaryClassifier.

/**
 * Retrains the binary classifier stored at {@code cbm.binaryClassifiers[component][label]}
 * on {@code activeDataset} using a ridge (Gaussian-prior) logistic optimizer.
 *
 * <p>If the slot is empty ({@code null}) or holds a {@link PriorProbClassifier}, it is first
 * replaced by a fresh two-class {@link LogisticRegression} sized to the dataset's feature count.
 * The optimizer uses {@code priorVarianceBinary}, runs without parallelism, and is capped at
 * {@code binaryUpdatesPerIter} iterations.
 *
 * @param component    index of the mixture component whose classifier is updated
 * @param label        index of the label whose one-vs-rest classifier is updated
 * @param activeDataset data the classifier is trained on
 * @param activeGammas per-instance values handed to the optimizer
 *                     (presumably instance weights — TODO confirm)
 */
@Override
protected void updateBinaryClassifier(int component, int label, MultiLabelClfDataSet activeDataset, double[] activeGammas) {
    StopWatch watch = new StopWatch();
    watch.start();

    boolean needsFreshModel = cbm.binaryClassifiers[component][label] == null
            || cbm.binaryClassifiers[component][label] instanceof PriorProbClassifier;
    if (needsFreshModel) {
        cbm.binaryClassifiers[component][label] = new LogisticRegression(2, activeDataset.getNumFeatures());
    }

    int[] targets = DataSetUtil.toBinaryLabels(activeDataset.getMultiLabels(), label);
    LogisticRegression model = (LogisticRegression) cbm.binaryClassifiers[component][label];

    // Final constructor argument false: no parallelism.
    RidgeLogisticOptimizer optimizer =
            new RidgeLogisticOptimizer(model, activeDataset, targets, activeGammas, priorVarianceBinary, false);
    optimizer.getOptimizer().getTerminator().setMaxIteration(binaryUpdatesPerIter);
    optimizer.optimize();

    if (logger.isDebugEnabled()) {
        logger.debug("time spent on updating component " + component + " label " + label + " = " + watch);
    }
}
Also used: PriorProbClassifier (edu.neu.ccs.pyramid.classification.PriorProbClassifier), LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression), RidgeLogisticOptimizer (edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer), StopWatch (org.apache.commons.lang3.time.StopWatch)

Example 3 with LogisticRegression

Use of edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression in project pyramid by cheng-li.

The class PlattScaling, method fitClassK.

/**
 * Fits a two-class logistic regression whose single feature is the raw score, mapping
 * {@code scores} to {@code labels} (presumably the per-class step of Platt scaling).
 *
 * <p>Training uses {@link ElasticNetLogisticTrainer} with regularization {@code 1.0E-9} and
 * L1 ratio {@code 0}.
 *
 * @param scores one raw score per data point; stored as feature 0 of a dense dataset
 * @param labels one class label per data point, aligned index-by-index with {@code scores}
 *               (the dataset is built with 2 classes, so presumably 0/1 — TODO confirm)
 * @return the trained model
 * @throws NullPointerException     if {@code scores} or {@code labels} is {@code null}
 * @throws IllegalArgumentException if {@code scores} and {@code labels} differ in length
 */
private static LogisticRegression fitClassK(double[] scores, int[] labels) {
    java.util.Objects.requireNonNull(scores, "scores");
    java.util.Objects.requireNonNull(labels, "labels");
    // Previously a longer labels array was silently truncated and a shorter one failed
    // mid-loop with ArrayIndexOutOfBoundsException; reject the mismatch up front instead.
    if (scores.length != labels.length) {
        throw new IllegalArgumentException(
                "scores and labels must have the same length: " + scores.length + " != " + labels.length);
    }
    ClfDataSet dataSet = ClfDataSetBuilder.getBuilder().numClasses(2).numDataPoints(scores.length).numFeatures(1).dense(true).missingValue(false).build();
    for (int i = 0; i < scores.length; i++) {
        dataSet.setFeatureValue(i, 0, scores[i]);
        dataSet.setLabel(i, labels[i]);
    }
    LogisticRegression logisticRegression = new LogisticRegression(2, dataSet.getNumFeatures());
    ElasticNetLogisticTrainer trainer = ElasticNetLogisticTrainer.newBuilder(logisticRegression, dataSet).setRegularization(1.0E-9).setL1Ratio(0).build();
    // The trainer updates logisticRegression itself; the same instance is returned.
    trainer.optimize();
    return logisticRegression;
}
Also used: ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet), ElasticNetLogisticTrainer (edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer), LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression)

Example 4 with LogisticRegression

Use of edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression in project pyramid by cheng-li.

The class CBMNoiseOptimizerFixed, method updateBinaryLogisticRegression.

/**
 * Refits the binary logistic regression stored at
 * {@code cbm.binaryClassifiers[componentIndex][labelIndex]} with a {@link RidgeLogisticOptimizer}.
 *
 * <p>The optimizer receives {@code dataSet}, the per-instance values {@code gammasT[componentIndex]}
 * (presumably component responsibilities used as instance weights — TODO confirm), the target
 * distributions {@code binaryTargetsDistributions[labelIndex]} and {@code priorVarianceBinary}.
 * It runs without parallelism and for at most 10 iterations. Nothing is returned; the model is
 * presumably updated in place by the optimizer.
 *
 * <p>NOTE(review): unlike the CBM optimizers' {@code updateBinaryClassifier}, this method casts
 * the slot directly and does not replace a {@code null} or {@code PriorProbClassifier} entry.
 * Confirm callers guarantee a {@link LogisticRegression} is present.
 *
 * @param componentIndex index of the mixture component
 * @param labelIndex     index of the label
 */
private void updateBinaryLogisticRegression(int componentIndex, int labelIndex) {
    // TODO maximum iterations: make this cap configurable instead of hard-coding it.
    final int maxIterations = 10;
    LogisticRegression model = (LogisticRegression) cbm.binaryClassifiers[componentIndex][labelIndex];
    // Final constructor argument false: no parallelism.
    RidgeLogisticOptimizer ridgeLogisticOptimizer = new RidgeLogisticOptimizer(model, dataSet, gammasT[componentIndex], binaryTargetsDistributions[labelIndex], priorVarianceBinary, false);
    ridgeLogisticOptimizer.getOptimizer().getTerminator().setMaxIteration(maxIterations);
    ridgeLogisticOptimizer.optimize();
}
Also used: LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression), RidgeLogisticOptimizer (edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer)

Example 5 with LogisticRegression

Use of edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression in project pyramid by cheng-li.

The class ClassificationSynthesizerTest, method test4.

/**
 * Smoke test for the synthetic "multivariate line" classification data.
 *
 * <p>Generates separate train and test sets (1000 points, 3 features, noise SD 1e-8) and saves
 * both in TREC format under {@code TMP/line4/}. It then trains a ridge logistic regression
 * (Gaussian prior variance 1) on the training set. Finally it prints train accuracy, test
 * accuracy and the weights for class 0 to standard output.
 */
private static void test4() {
    ClassificationSynthesizer synthesizer = ClassificationSynthesizer.getBuilder()
            .setNumDataPoints(1000)
            .setNumFeatures(3)
            .setNoiseSD(0.00000001)
            .build();

    // Two independent draws from the same generator.
    ClfDataSet train = synthesizer.multivarLine();
    ClfDataSet test = synthesizer.multivarLine();
    TRECFormat.save(train, new File(TMP, "line4/train.trec"));
    TRECFormat.save(test, new File(TMP, "line4/test.trec"));

    LogisticRegression model = RidgeLogisticTrainer.getBuilder()
            .setGaussianPriorVariance(1)
            .build()
            .train(train);

    // Train accuracy first, then test accuracy.
    for (ClfDataSet evaluated : new ClfDataSet[] {train, test}) {
        System.out.println(Accuracy.accuracy(model, evaluated));
    }
    System.out.println(model.getWeights().getWeightsForClass(0));
}
Also used: ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet), RidgeLogisticTrainer (edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticTrainer), LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression), File (java.io.File)

Aggregations

Usage counts across examples:
LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression): 19
ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet): 7
File (java.io.File): 7
RidgeLogisticTrainer (edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticTrainer): 6
ElasticNetLogisticTrainer (edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer): 5
RidgeLogisticOptimizer (edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer): 5
Vector (org.apache.mahout.math.Vector): 5
PriorProbClassifier (edu.neu.ccs.pyramid.classification.PriorProbClassifier): 3
Classifier (edu.neu.ccs.pyramid.classification.Classifier): 2
MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 2
ArrayList (java.util.ArrayList): 2
List (java.util.List): 2
IntStream (java.util.stream.IntStream): 2
StopWatch (org.apache.commons.lang3.time.StopWatch): 2
DenseVector (org.apache.mahout.math.DenseVector): 2
ClassProbability (edu.neu.ccs.pyramid.classification.ClassProbability): 1
PredictionAnalysis (edu.neu.ccs.pyramid.classification.PredictionAnalysis): 1
LKBOutputCalculator (edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator): 1
LKBoost (edu.neu.ccs.pyramid.classification.lkboost.LKBoost): 1
LKBoostOptimizer (edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer): 1