
Example 11 with LBFGS

Use of edu.neu.ccs.pyramid.optimization.LBFGS in project pyramid by cheng-li.

In class CBMSOptimizer, method updateBinaryLogisticRegression.

// todo pay attention to parallelism
private void updateBinaryLogisticRegression(int labelIndex) {
    // Loss for the binary classifier of this label, built from the data set, the gammas,
    // the label's current classifier, and the two prior variances
    AugmentedLRLoss loss = new AugmentedLRLoss(dataSet, labelIndex, gammas, cbms.getBinaryClassifiers()[labelIndex], priorVarianceBinary, componentWeightsVariance);
    LBFGS lbfgs = new LBFGS(loss);
    // todo
    // Cap the number of L-BFGS iterations and minimize the loss
    lbfgs.getTerminator().setMaxIteration(numBinaryParaUpdates);
    lbfgs.getTerminator().setGoal(Terminator.Goal.MINIMIZE);
    lbfgs.optimize();
}
Also used: LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS)
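To make concrete what `lbfgs.optimize()` does under an iteration cap and a minimization goal, here is a minimal, self-contained L-BFGS sketch in plain Java. It is not pyramid's implementation: the `Objective` interface, the history size `m`, the Armijo backtracking line search, and the stopping rules are assumptions chosen for illustration.

```java
import java.util.ArrayList;
import java.util.List;

// Illustrative L-BFGS sketch; this is NOT pyramid's edu.neu.ccs.pyramid.optimization.LBFGS.
final class LbfgsSketch {
    /** Hypothetical objective: returns f(x) and writes the gradient at x into grad. */
    interface Objective {
        double valueAndGradient(double[] x, double[] grad);
    }

    static double dot(double[] a, double[] b) {
        double s = 0.0;
        for (int i = 0; i < a.length; i++) s += a[i] * b[i];
        return s;
    }

    /**
     * Minimizes f from the starting point x (updated in place), keeping the last m >= 1
     * correction pairs. Stops after maxIter iterations, once ||grad|| <= tol, or when
     * the line search fails. Returns the final objective value.
     */
    static double minimize(Objective f, double[] x, int m, int maxIter, double tol) {
        if (m < 1) throw new IllegalArgumentException("m must be >= 1");
        int n = x.length;
        double[] g = new double[n];
        double fx = f.valueAndGradient(x, g);
        List<double[]> sHist = new ArrayList<>();
        List<double[]> yHist = new ArrayList<>();
        List<Double> rhoHist = new ArrayList<>();
        for (int iter = 0; iter < maxIter && Math.sqrt(dot(g, g)) > tol; iter++) {
            // Two-loop recursion: q ends up as (approximate inverse Hessian) * g.
            int k = sHist.size();
            double[] q = g.clone();
            double[] alpha = new double[k];
            for (int i = k - 1; i >= 0; i--) {
                alpha[i] = rhoHist.get(i) * dot(sHist.get(i), q);
                double[] yi = yHist.get(i);
                for (int j = 0; j < n; j++) q[j] -= alpha[i] * yi[j];
            }
            if (k > 0) { // initial scaling gamma = s'y / y'y from the newest pair
                double[] yLast = yHist.get(k - 1);
                double gamma = dot(sHist.get(k - 1), yLast) / dot(yLast, yLast);
                for (int j = 0; j < n; j++) q[j] *= gamma;
            }
            for (int i = 0; i < k; i++) {
                double beta = rhoHist.get(i) * dot(yHist.get(i), q);
                double[] si = sHist.get(i);
                for (int j = 0; j < n; j++) q[j] += si[j] * (alpha[i] - beta);
            }
            double[] d = new double[n];
            for (int j = 0; j < n; j++) d[j] = -q[j];
            double slope = dot(g, d);
            if (slope >= 0) { // not a descent direction: reset memory, use steepest descent
                for (int j = 0; j < n; j++) d[j] = -g[j];
                slope = -dot(g, g);
                sHist.clear(); yHist.clear(); rhoHist.clear();
            }
            // Backtracking line search with the Armijo sufficient-decrease condition.
            double step = 1.0, fNew = fx;
            double[] xNew = new double[n], gNew = new double[n];
            boolean accepted = false;
            for (int ls = 0; ls < 60 && !accepted; ls++) {
                for (int j = 0; j < n; j++) xNew[j] = x[j] + step * d[j];
                fNew = f.valueAndGradient(xNew, gNew);
                if (fNew <= fx + 1e-4 * step * slope) accepted = true; else step *= 0.5;
            }
            if (!accepted) break;
            // Keep the new correction pair only if it has positive curvature (s'y > 0).
            double[] sNew = new double[n], yNew = new double[n];
            for (int j = 0; j < n; j++) { sNew[j] = xNew[j] - x[j]; yNew[j] = gNew[j] - g[j]; }
            double sy = dot(sNew, yNew);
            if (sy > 1e-12) {
                if (sHist.size() == m) { sHist.remove(0); yHist.remove(0); rhoHist.remove(0); }
                sHist.add(sNew); yHist.add(yNew); rhoHist.add(1.0 / sy);
            }
            System.arraycopy(xNew, 0, x, 0, n);
            System.arraycopy(gNew, 0, g, 0, n);
            fx = fNew;
        }
        return fx;
    }
}
```

For example, minimizing f(x) = sum over i of (i + 1) * (x_i - 1)^2 from the origin with `minimize(f, x, 5, 200, 1e-8)` should drive every x_i close to 1. In this sketch `maxIter` plays roughly the role that `getTerminator().setMaxIteration(...)` plays in the snippet above.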

Example 12 with LBFGS

Use of edu.neu.ccs.pyramid.optimization.LBFGS in project pyramid by cheng-li.

In class NoiseOptimizer, method updateModel.

private void updateModel() {
    if (logger.isDebugEnabled()) {
        logger.debug("start updateModel()");
    }
    // KL-divergence loss between the CRF's predictions and the target distributions
    KLLoss klLoss = new KLLoss(crf, dataSet, targets, variance);
    // todo
    // Pick the optimizer by name; every branch either assigns opt or throws
    Optimizer opt;
    switch(optimizer) {
        case "LBFGS":
            opt = new LBFGS(klLoss);
            break;
        case "GD":
            opt = new GradientDescent(klLoss);
            break;
        default:
            throw new IllegalArgumentException("unknown optimizer: " + optimizer);
    }
    opt.optimize();
    if (logger.isDebugEnabled()) {
        logger.debug("finish updateModel()");
    }
}
Also used: LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS), Optimizer (edu.neu.ccs.pyramid.optimization.Optimizer), GradientDescent (edu.neu.ccs.pyramid.optimization.GradientDescent)
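The string switch above selects an optimizer by name. One alternative, sketched here under stated assumptions, is to map the name onto an enum so that the set of valid names lives in one place and an unknown name produces a message that includes the offending value. `NamedOptimizer`, `LbfgsStub`, `GradientDescentStub`, and `OptimizerFactory` are hypothetical stand-ins, not pyramid classes.

```java
// Hypothetical stand-ins for illustration; these are not pyramid's Optimizer, LBFGS, or GradientDescent.
interface NamedOptimizer {
    String name();
}

final class LbfgsStub implements NamedOptimizer {
    public String name() { return "LBFGS"; }
}

final class GradientDescentStub implements NamedOptimizer {
    public String name() { return "GD"; }
}

final class OptimizerFactory {
    /** Accepted names, mirroring the "LBFGS" and "GD" cases of the switch above. */
    enum Kind { LBFGS, GD }

    static NamedOptimizer create(String optimizer) {
        Kind kind;
        try {
            kind = Kind.valueOf(optimizer); // case-sensitive, like the original string switch
        } catch (IllegalArgumentException e) {
            throw new IllegalArgumentException("unknown optimizer: " + optimizer, e);
        }
        switch (kind) {
            case LBFGS: return new LbfgsStub();
            case GD: return new GradientDescentStub();
            default: throw new AssertionError("unhandled kind: " + kind);
        }
    }
}
```

Keeping the names in an enum means adding a new optimizer touches one enum constant and one switch case, and `Kind.values()` can be used to list the valid options in error messages or help text.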

Example 13 with LBFGS

Use of edu.neu.ccs.pyramid.optimization.LBFGS in project pyramid by cheng-li.

In class NoiseOptimizer, method updateModelPartial.

private void updateModelPartial(int modelIterations) {
    if (logger.isDebugEnabled()) {
        logger.debug("start updateModelPartial()");
    }
    // KL-divergence loss between the CRF's predictions and the target distributions
    KLLoss klLoss = new KLLoss(crf, dataSet, targets, variance);
    // todo
    // Pick the optimizer by name; every branch either assigns opt or throws
    Optimizer opt;
    switch(optimizer) {
        case "LBFGS":
            opt = new LBFGS(klLoss);
            break;
        case "GD":
            opt = new GradientDescent(klLoss);
            break;
        default:
            throw new IllegalArgumentException("unknown optimizer: " + optimizer);
    }
    // Partial update: stop after at most modelIterations iterations
    opt.getTerminator().setMaxIteration(modelIterations);
    opt.optimize();
    if (logger.isDebugEnabled()) {
        logger.debug("finish updateModelPartial()");
    }
}
Also used: LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS), Optimizer (edu.neu.ccs.pyramid.optimization.Optimizer), GradientDescent (edu.neu.ccs.pyramid.optimization.GradientDescent)
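`updateModelPartial` differs from `updateModel` only in capping the iteration count. Because the optimized parameters live in the model object rather than in the optimizer, a capped run can be resumed by a later call. The sketch below illustrates that warm-start idea with fixed-step gradient descent on a toy quadratic; `WarmStartSketch`, `Model`, and `runCapped` are hypothetical names, and whether pyramid's optimizers resume exactly this way is an assumption.

```java
// Hypothetical sketch: the parameters live in the model object, so a capped
// optimizer run picks up where the previous run stopped (a warm start).
final class WarmStartSketch {
    static final class Model {
        final double[] weights;
        Model(int dim) { weights = new double[dim]; }
    }

    /** Runs at most maxIterations fixed-step gradient-descent updates on f(w) = 0.5 * ||w - target||^2. */
    static void runCapped(Model model, double[] target, double stepSize, int maxIterations) {
        double[] w = model.weights; // updated in place, so progress persists across calls
        for (int it = 0; it < maxIterations; it++) {
            for (int j = 0; j < w.length; j++) {
                w[j] -= stepSize * (w[j] - target[j]); // gradient of f is (w - target)
            }
        }
    }
}
```

With a deterministic fixed-step method, two capped runs of 5 iterations leave the model in exactly the same state as one run of 10. A method that keeps internal history, such as L-BFGS, would generally not match exactly, because a fresh optimizer instance starts with an empty history.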

Example 14 with LBFGS

Use of edu.neu.ccs.pyramid.optimization.LBFGS in project pyramid by cheng-li.

In class LogRiskOptimizer, method updateModel.

private void updateModel() {
    if (logger.isDebugEnabled()) {
        logger.debug("start updateModel()");
    }
    // KL-divergence loss between the CRF's predictions and the target distributions
    KLLoss klLoss = new KLLoss(crf, dataSet, targets, variance);
    // todo
    // Pick the optimizer by name; every branch either assigns opt or throws
    Optimizer opt;
    switch(optimizer) {
        case "LBFGS":
            opt = new LBFGS(klLoss);
            break;
        case "GD":
            opt = new GradientDescent(klLoss);
            break;
        default:
            throw new IllegalArgumentException("unknown optimizer: " + optimizer);
    }
    opt.optimize();
    if (logger.isDebugEnabled()) {
        logger.debug("finish updateModel()");
    }
}
Also used: LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS), Optimizer (edu.neu.ccs.pyramid.optimization.Optimizer), GradientDescent (edu.neu.ccs.pyramid.optimization.GradientDescent)
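The examples above all minimize a `KLLoss` built from target distributions and a `variance`. A minimal sketch of what such an objective might compute is below: the KL divergence KL(p || q) = sum over c of p_c * ln(p_c / q_c), summed over data points, plus a Gaussian-prior penalty ||w||^2 / (2 * variance). Treating `variance` as a prior variance and summing (rather than averaging) over data points are assumptions; pyramid's actual `KLLoss` is not reproduced here, and `KlLossSketch` is a hypothetical name.

```java
// Hypothetical sketch; pyramid's KLLoss is not reproduced here, and the role of
// `variance` as a Gaussian-prior variance is an assumption.
final class KlLossSketch {
    /** KL(p || q) = sum_c p[c] * ln(p[c] / q[c]), using the convention 0 * ln(0 / q) = 0. */
    static double kl(double[] p, double[] q) {
        double sum = 0.0;
        for (int c = 0; c < p.length; c++) {
            if (p[c] > 0.0) sum += p[c] * Math.log(p[c] / q[c]);
        }
        return sum;
    }

    /** Sum of per-example KL divergences plus the penalty ||w||^2 / (2 * variance). */
    static double penalizedLoss(double[][] targets, double[][] predicted, double[] weights, double variance) {
        double loss = 0.0;
        for (int i = 0; i < targets.length; i++) loss += kl(targets[i], predicted[i]);
        double sq = 0.0;
        for (double w : weights) sq += w * w;
        return loss + sq / (2.0 * variance);
    }
}
```

When a target row is one-hot on combination y, KL(p || q) reduces to -ln q_y, so the KL objective then coincides with the negative log-likelihood of the observed label combination.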

Example 15 with LBFGS

Use of edu.neu.ccs.pyramid.optimization.LBFGS in project pyramid by cheng-li.

In class KLLossTest, method test1.

private static void test1() throws Exception {
    // Load the spam train/test splits (TREC format) as sparse multi-label data sets
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "spam/trec_data/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "spam/trec_data/test.trec"), DataSetType.ML_CLF_SPARSE, true);
    CMLCRF cmlcrf = new CMLCRF(dataSet);
    List<MultiLabel> support = cmlcrf.getSupportCombinations();
    // One-hot targets: each data point puts all its mass on its own label combination
    double[][] targetDistribution = new double[dataSet.getNumDataPoints()][support.size()];
    for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
        MultiLabel multiLabel = dataSet.getMultiLabels()[i];
        for (int c = 0; c < support.size(); c++) {
            if (support.get(c).equals(multiLabel)) {
                targetDistribution[i][c] = 1;
            }
        }
    }
    System.out.println("start");
    KLLoss klLoss = new KLLoss(cmlcrf, dataSet, targetDistribution, 1);
    cmlcrf.setConsiderPair(true);
    MultiLabel[] predTrain;
    MultiLabel[] predTest;
    LBFGS optimizer = new LBFGS(klLoss);
    // Run 200 single L-BFGS iterations, printing the loss and train/test metrics after each one
    for (int i = 0; i < 200; i++) {
        // System.out.print("Obj: " + optimizer.getTerminator().getLastValue());
        System.out.println("iter: " + i);
        optimizer.iterate();
        System.out.println(klLoss.getValue());
        predTrain = cmlcrf.predict(dataSet);
        predTest = cmlcrf.predict(testSet);
        System.out.print("\tTrain acc: " + Accuracy.accuracy(dataSet.getMultiLabels(), predTrain));
        System.out.print("\tTrain overlap " + Overlap.overlap(dataSet.getMultiLabels(), predTrain));
        System.out.print("\tTest acc: " + Accuracy.accuracy(testSet.getMultiLabels(), predTest));
        System.out.println("\tTest overlap " + Overlap.overlap(testSet.getMultiLabels(), predTest));
    }
}
Also used: LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS), MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel), File (java.io.File), MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)
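The nested loop in `test1` compares every data point's label set against every support combination, which costs O(N * C) `equals` calls. A sketch of the same one-hot construction with a hash lookup is below; label combinations are modeled as `Set<Integer>` instead of pyramid's `MultiLabel`, and `OneHotTargets.build` is a hypothetical helper. It assumes the support combinations are distinct and that label sets have consistent `equals`/`hashCode`.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical sketch: label combinations are modeled as Set<Integer> rather than pyramid's MultiLabel.
final class OneHotTargets {
    /**
     * Row i puts probability 1 on the support index of data point i's label set.
     * A row stays all zeros if that label set is not in the support, as in test1.
     */
    static double[][] build(List<Set<Integer>> labels, List<Set<Integer>> support) {
        Map<Set<Integer>, Integer> indexOf = new HashMap<>();
        for (int c = 0; c < support.size(); c++) {
            indexOf.putIfAbsent(support.get(c), c); // assumes support entries are distinct
        }
        double[][] targets = new double[labels.size()][support.size()];
        for (int i = 0; i < labels.size(); i++) {
            Integer c = indexOf.get(labels.get(i));
            if (c != null) targets[i][c] = 1.0;
        }
        return targets;
    }
}
```

Building the index once makes the construction O(N + C) map operations, which matters when the support contains many label combinations.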

Aggregations

LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS): 22
File (java.io.File): 12
MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 9
MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 7
CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF): 7
CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss): 7
GradientDescent (edu.neu.ccs.pyramid.optimization.GradientDescent): 7
Optimizer (edu.neu.ccs.pyramid.optimization.Optimizer): 6
RidgeLogisticOptimizer (edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer): 2
MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 2
ArrayList (java.util.ArrayList): 2
ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet): 1
LoggerContext (org.apache.logging.log4j.core.LoggerContext): 1
Configuration (org.apache.logging.log4j.core.config.Configuration): 1
LoggerConfig (org.apache.logging.log4j.core.config.LoggerConfig): 1