
Example 6 with ElasticNetLogisticTrainer

Use of edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer in project pyramid by cheng-li.

From the class CBMNoiseOptimizerFixed, method updateBinaryLogisticRegressionEL.

private void updateBinaryLogisticRegressionEL(int componentIndex, int labelIndex) {
    ElasticNetLogisticTrainer elasticNetLogisticTrainer = new ElasticNetLogisticTrainer.Builder(
            (LogisticRegression) cbm.binaryClassifiers[componentIndex][labelIndex], dataSet, 2,
            binaryTargetsDistributions[labelIndex], gammasT[componentIndex])
            .setRegularization(regularizationBinary).setL1Ratio(l1RatioBinary)
            .setLineSearch(lineSearch).build();
    // TODO: maximum iterations
    elasticNetLogisticTrainer.getTerminator().setMaxIteration(10);
    elasticNetLogisticTrainer.optimize();
}
Also used: ElasticNetLogisticTrainer (edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer)
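The snippet above only configures the trainer. As a minimal, self-contained sketch of what an elastic-net logistic trainer optimizes, the block below fits a binary model with soft targets and per-instance weights. It is not pyramid's implementation: it swaps in plain proximal gradient descent (ISTA) with a fixed step size and no line search, and it assumes the penalty convention regularization * (l1Ratio * |w|_1 + (1 - l1Ratio) / 2 * |w|_2^2) with an unpenalized bias. Reading binaryTargetsDistributions[labelIndex] as soft targets and gammasT[componentIndex] as instance weights is likewise an assumption about the Builder's arguments, and the class name ElasticNetBinaryLogisticSketch is hypothetical.

```java
/**
 * Hypothetical sketch (not pyramid's code): elastic-net binary logistic regression
 * with soft targets and per-instance weights, fitted by proximal gradient descent (ISTA).
 * Assumed objective:
 *   (1 / sum(weights)) * sum_i weight_i * crossEntropy(target_i, sigmoid(w.x_i + bias))
 *   + regularization * (l1Ratio * |w|_1 + (1 - l1Ratio) / 2 * |w|_2^2), bias unpenalized.
 */
final class ElasticNetBinaryLogisticSketch {

    // Proximal operator of threshold * |v|: shrinks v toward zero and clips to exactly 0.
    static double softThreshold(double v, double threshold) {
        if (v > threshold) return v - threshold;
        if (v < -threshold) return v + threshold;
        return 0.0;
    }

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Returns {w_0, ..., w_(d-1), bias} after exactly maxIteration proximal steps. */
    static double[] fit(double[][] x, double[] targets, double[] instanceWeights,
                        double regularization, double l1Ratio,
                        double stepSize, int maxIteration) {
        int n = x.length;
        int d = x[0].length;
        double[] w = new double[d];
        double bias = 0.0;
        double weightSum = 0.0;
        for (double g : instanceWeights) weightSum += g;

        for (int iter = 0; iter < maxIteration; iter++) {
            double[] grad = new double[d];
            double gradBias = 0.0;
            for (int i = 0; i < n; i++) {
                double z = bias;
                for (int j = 0; j < d; j++) z += w[j] * x[i][j];
                // Weighted residual of the cross-entropy loss against the soft target.
                double r = instanceWeights[i] * (sigmoid(z) - targets[i]) / weightSum;
                for (int j = 0; j < d; j++) grad[j] += r * x[i][j];
                gradBias += r;
            }
            for (int j = 0; j < d; j++) {
                // Gradient step on the smooth part (loss + L2 term) ...
                double smoothGrad = grad[j] + regularization * (1.0 - l1Ratio) * w[j];
                // ... then the proximal step for the L1 term.
                w[j] = softThreshold(w[j] - stepSize * smoothGrad,
                                     stepSize * regularization * l1Ratio);
            }
            bias -= stepSize * gradBias; // the bias is not penalized
        }
        double[] model = new double[d + 1];
        System.arraycopy(w, 0, model, 0, d);
        model[d] = bias;
        return model;
    }
}
```

The soft-thresholding step sets a weight to exactly zero whenever its gradient step lands within stepSize * regularization * l1Ratio of zero; this is how the L1 share of the penalty yields sparse models, while the L2 share shrinks the remaining weights. The fixed maxIteration loop plays the role of the setMaxIteration(10) cap.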

Example 7 with ElasticNetLogisticTrainer

Use of edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer in project pyramid by cheng-li.

From the class LKTreeBoostTest, method logisticTest.

static void logisticTest() throws Exception {
    ClfDataSet dataSet = TRECFormat.loadClfDataSet(new File(DATASETS, "/spam/trec_data/train.trec"), DataSetType.CLF_SPARSE, true);
    System.out.println(dataSet.getMetaInfo());
    ClfDataSet testSet = TRECFormat.loadClfDataSet(new File(DATASETS, "/spam/trec_data/test.trec"), DataSetType.CLF_DENSE, true);
    LKBoost lkBoost = new LKBoost(2);
    LKBoostOptimizer trainer = new LKBoostOptimizer(lkBoost, dataSet);
    trainer.initialize();
    LogisticRegression logisticRegression = new LogisticRegression(dataSet.getNumClasses(), dataSet.getNumFeatures());
    ElasticNetLogisticTrainer logisticTrainer = ElasticNetLogisticTrainer.newBuilder(logisticRegression, dataSet)
            .setEpsilon(0.01).setL1Ratio(0.9).setRegularization(0.001).build();
    logisticTrainer.optimize();
    System.out.println("logistic regression accuracy = " + Accuracy.accuracy(logisticRegression, testSet));
    System.out.println("num feature used = " + LogisticRegressionInspector.numOfUsedFeaturesCombined(logisticRegression));
    //        lktbTrainer.addLogisticRegression(logisticRegression);
    System.out.println("boosting accuracy = " + Accuracy.accuracy(lkBoost, testSet));
    for (int i = 0; i < 100; i++) {
        trainer.iterate();
        System.out.println("iteration " + i);
        System.out.println("boosting accuracy = " + Accuracy.accuracy(lkBoost, testSet));
    }
}
Also used: ElasticNetLogisticTrainer (edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer), LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression), File (java.io.File)
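Example 7 reports test accuracy and the number of features the trained model uses. A minimal sketch of both quantities for a linear multiclass model: accuracy counts argmax predictions that match the labels (the softmax is monotone, so the argmax of the raw scores gives the same class), and a feature counts as used when at least one class assigns it a non-zero weight. That this matches LogisticRegressionInspector.numOfUsedFeaturesCombined is an assumption based on the method's name, and LinearModelReportSketch is a hypothetical class.

```java
/**
 * Hypothetical sketch: accuracy and "used feature" count for a linear multiclass model.
 * weights[k][j] is the weight of feature j for class k; biases[k] is the intercept of class k.
 */
final class LinearModelReportSketch {

    // Predicted class = argmax of the linear scores (ties go to the lowest class index).
    static int predict(double[][] weights, double[] biases, double[] x) {
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < weights.length; k++) {
            double score = biases[k];
            for (int j = 0; j < x.length; j++) score += weights[k][j] * x[j];
            if (score > bestScore) {
                bestScore = score;
                best = k;
            }
        }
        return best;
    }

    // Fraction of data points whose predicted class equals the label.
    static double accuracy(double[][] weights, double[] biases, double[][] xs, int[] labels) {
        int correct = 0;
        for (int i = 0; i < xs.length; i++) {
            if (predict(weights, biases, xs[i]) == labels[i]) correct++;
        }
        return (double) correct / xs.length;
    }

    // Assumed meaning: a feature is "used" if any class gives it a non-zero weight.
    static int numUsedFeaturesCombined(double[][] weights) {
        int used = 0;
        for (int j = 0; j < weights[0].length; j++) {
            for (double[] classWeights : weights) {
                if (classWeights[j] != 0.0) {
                    used++;
                    break;
                }
            }
        }
        return used;
    }
}
```

With a strong L1 ratio such as setL1Ratio(0.9) in Example 7, many columns of the weight matrix can end up entirely zero, which is what the used-feature count is meant to expose.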

Example 8 with ElasticNetLogisticTrainer

Use of edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer in project pyramid by cheng-li.

From the class CBMNoiseOptimizerFixed, method updateMultiClassEL.

private void updateMultiClassEL() {
    ElasticNetLogisticTrainer elasticNetLogisticTrainer = new ElasticNetLogisticTrainer.Builder(
            (LogisticRegression) cbm.multiClassClassifier, dataSet,
            cbm.multiClassClassifier.getNumClasses(), gammas)
            .setRegularization(regularizationMultiClass).setL1Ratio(l1RatioMultiClass)
            .setLineSearch(lineSearch).build();
    // TODO: maximum iterations
    elasticNetLogisticTrainer.getTerminator().setMaxIteration(10);
    elasticNetLogisticTrainer.optimize();
}
Also used: ElasticNetLogisticTrainer (edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer)
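Example 8 trains the multiclass classifier against gammas. As a sketch, assuming gammas holds one probability distribution over classes per data point (a soft target) and that the same elastic-net convention as in the binary case applies to every class's weight vector, the objective being minimized would look like the following. Only the objective is computed here, not the optimization; the averaging over data points and the class name SoftTargetSoftmaxObjectiveSketch are assumptions.

```java
/**
 * Hypothetical sketch: elastic-net multinomial logistic objective with soft targets.
 * weights[k][j], biases[k]: model; x[i][j]: features; targets[i][k]: class distribution of point i.
 * Assumed penalty: regularization * (l1Ratio * |W|_1 + (1 - l1Ratio) / 2 * |W|_2^2), biases unpenalized.
 */
final class SoftTargetSoftmaxObjectiveSketch {

    // Numerically stable log(sum_k exp(z_k)): shift by the maximum before exponentiating.
    static double logSumExp(double[] z) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) max = Math.max(max, v);
        double sum = 0.0;
        for (double v : z) sum += Math.exp(v - max);
        return max + Math.log(sum);
    }

    static double objective(double[][] weights, double[] biases, double[][] x,
                            double[][] targets, double regularization, double l1Ratio) {
        int numClasses = weights.length;
        double loss = 0.0;
        for (int i = 0; i < x.length; i++) {
            double[] scores = new double[numClasses];
            for (int k = 0; k < numClasses; k++) {
                scores[k] = biases[k];
                for (int j = 0; j < x[i].length; j++) scores[k] += weights[k][j] * x[i][j];
            }
            double lse = logSumExp(scores);
            for (int k = 0; k < numClasses; k++) {
                // Cross-entropy against the soft target: -q_ik * log p_ik, log p_ik = score_k - lse.
                loss -= targets[i][k] * (scores[k] - lse);
            }
        }
        loss /= x.length; // assumed: average over data points
        double l1 = 0.0;
        double l2 = 0.0;
        for (double[] row : weights) {
            for (double v : row) {
                l1 += Math.abs(v);
                l2 += v * v;
            }
        }
        return loss + regularization * (l1Ratio * l1 + 0.5 * (1.0 - l1Ratio) * l2);
    }
}
```

With all-zero weights and biases, every class receives probability 1 / numClasses, so the loss equals log(numClasses) whatever the targets are; the log-sum-exp shift keeps exp() from overflowing when scores are large.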

Example 9 with ElasticNetLogisticTrainer

Use of edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer in project pyramid by cheng-li.

From the class MLPlattScaling, method fitClassK.

private static LogisticRegression fitClassK(double[] scores, int[] labels) {
    ClfDataSet dataSet = ClfDataSetBuilder.getBuilder().numClasses(2).numDataPoints(scores.length)
            .numFeatures(1).dense(true).missingValue(false).build();
    for (int i = 0; i < scores.length; i++) {
        dataSet.setFeatureValue(i, 0, scores[i]);
        dataSet.setLabel(i, labels[i]);
    }
    LogisticRegression logisticRegression = new LogisticRegression(2, dataSet.getNumFeatures());
    ElasticNetLogisticTrainer trainer = ElasticNetLogisticTrainer.newBuilder(logisticRegression, dataSet)
            .setRegularization(1.0E-9).setL1Ratio(0).build();
    trainer.optimize();
    return logisticRegression;
}
Also used: ElasticNetLogisticTrainer (edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer), LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression)
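Example 9 is Platt scaling: a one-feature logistic regression that maps a raw score s to a calibrated probability 1 / (1 + exp(-(a * s + b))). With setRegularization(1.0E-9) and setL1Ratio(0), the penalty reduces to a negligible ridge term. The sketch below fits a and b with Newton's method instead of pyramid's trainer, uses hard 0/1 labels as the snippet does (Platt's original method smooths the targets), and assumes non-separable scores; on separable scores the near-zero ridge term lets the slope grow very large. PlattScalingSketch is a hypothetical class.

```java
/**
 * Hypothetical Platt-scaling sketch: fits P(y = 1 | s) = 1 / (1 + exp(-(a * s + b)))
 * by Newton's method on the mean negative log-likelihood plus (l2 / 2) * a^2.
 * Labels are hard 0/1 values; the bias b is not penalized.
 */
final class PlattScalingSketch {

    /** Returns {a, b}. */
    static double[] fit(double[] scores, int[] labels, double l2, int maxIteration) {
        int n = scores.length;
        double a = 0.0;
        double b = 0.0;
        for (int iter = 0; iter < maxIteration; iter++) {
            // Gradient (ga, gb) and Hessian [[haa, hab], [hab, hbb]] of the objective.
            double ga = l2 * a, gb = 0.0;
            double haa = l2, hab = 0.0, hbb = 0.0;
            for (int i = 0; i < n; i++) {
                double p = 1.0 / (1.0 + Math.exp(-(a * scores[i] + b)));
                double r = (p - labels[i]) / n;
                double w = p * (1.0 - p) / n;
                ga += r * scores[i];
                gb += r;
                haa += w * scores[i] * scores[i];
                hab += w * scores[i];
                hbb += w;
            }
            double det = haa * hbb - hab * hab;
            if (det <= 1e-12) break; // Hessian (nearly) singular: stop rather than divide
            // Newton step: (da, db) = H^{-1} (ga, gb), using the closed-form 2x2 inverse.
            double da = (hbb * ga - hab * gb) / det;
            double db = (haa * gb - hab * ga) / det;
            a -= da;
            b -= db;
            if (Math.abs(da) + Math.abs(db) < 1e-10) break; // converged
        }
        return new double[] {a, b};
    }

    /** Calibrated probability for a raw score, given {a, b} from fit. */
    static double calibrate(double[] ab, double score) {
        return 1.0 / (1.0 + Math.exp(-(ab[0] * score + ab[1])));
    }
}
```

Because there are only two parameters, the Newton system is solved in closed form and no linear-algebra library is needed. Under the penalty convention assumed for Example 6, l2 here corresponds to the regularization value when l1Ratio is 0.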

Example 10 with ElasticNetLogisticTrainer

Use of edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer in project pyramid by cheng-li.

From the class ENCBMOptimizer, method updateMultiClassClassifier.

@Override
protected void updateMultiClassClassifier() {
    if (logger.isDebugEnabled()) {
        logger.debug("start updateMultiClassClassifier");
    }
    ElasticNetLogisticTrainer elasticNetLogisticTrainer = new ElasticNetLogisticTrainer.Builder(
            (LogisticRegression) cbm.multiClassClassifier, dataSet,
            cbm.multiClassClassifier.getNumClasses(), gammas)
            .setRegularization(regularizationMultiClass).setL1Ratio(l1RatioMultiClass)
            .setLineSearch(lineSearch).build();
    elasticNetLogisticTrainer.setActiveSet(activeSet);
    elasticNetLogisticTrainer.getTerminator().setMaxIteration(this.multiclassUpdatesPerIter);
    elasticNetLogisticTrainer.optimize();
    if (logger.isDebugEnabled()) {
        logger.debug("finish updateMultiClassClassifier");
    }
}
Also used: ElasticNetLogisticTrainer (edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer)
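Examples 6, 8 and 10 cap the trainer through getTerminator().setMaxIteration(...), and Example 7 calls setEpsilon(0.01), presumably a convergence tolerance. A hypothetical stopping rule combining the two ideas stops after maxIteration iterations or once the relative change of the objective drops to epsilon. How pyramid's terminator actually measures convergence is not shown in these snippets, so the relative test and the class name IterationTerminatorSketch are assumptions.

```java
/**
 * Hypothetical stopping rule (not pyramid's Terminator): stop after maxIteration
 * iterations, or once the relative change of the objective is at most epsilon.
 */
final class IterationTerminatorSketch {
    private final int maxIteration;
    private final double epsilon;
    private int iteration = 0;
    private double previousValue = Double.NaN; // no value recorded yet

    IterationTerminatorSketch(int maxIteration, double epsilon) {
        this.maxIteration = maxIteration;
        this.epsilon = epsilon;
    }

    /** Call once per iteration with the current objective value; returns true to stop. */
    boolean shouldStop(double objectiveValue) {
        iteration++;
        boolean converged = !Double.isNaN(previousValue)
                && Math.abs(previousValue - objectiveValue) <= epsilon * Math.abs(previousValue);
        previousValue = objectiveValue;
        return converged || iteration >= maxIteration;
    }

    int getIteration() {
        return iteration;
    }
}
```

A small cap such as multiclassUpdatesPerIter in Example 10 makes sense when the trainer runs inside an outer loop and only needs to improve the classifier on each call, not converge fully.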

Aggregations

ElasticNetLogisticTrainer (edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer): 10
LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression): 4
PriorProbClassifier (edu.neu.ccs.pyramid.classification.PriorProbClassifier): 1
ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet): 1
File (java.io.File): 1
StopWatch (org.apache.commons.lang3.time.StopWatch): 1