
Example 6 with GradientDescent

use of edu.neu.ccs.pyramid.optimization.GradientDescent in project pyramid by cheng-li.

the class LogRiskOptimizer method updateModel.

private void updateModel() {
    if (logger.isDebugEnabled()) {
        logger.debug("start updateModel()");
    }
    // KL-divergence loss for the current CRF, targets, and variance
    KLLoss klLoss = new KLLoss(crf, dataSet, targets, variance);
    // pick the optimizer named in the configuration
    Optimizer opt;
    switch(optimizer) {
        case "LBFGS":
            opt = new LBFGS(klLoss);
            break;
        case "GD":
            opt = new GradientDescent(klLoss);
            break;
        default:
            throw new IllegalArgumentException("unknown optimizer: " + optimizer);
    }
    opt.optimize();
    if (logger.isDebugEnabled()) {
        logger.debug("finish updateModel()");
    }
}
Also used : LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) Optimizer(edu.neu.ccs.pyramid.optimization.Optimizer) GradientDescent(edu.neu.ccs.pyramid.optimization.GradientDescent)
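A note on the two call styles used across these examples: optimize() (as in Example 6) hands control to the optimizer, which repeats iterate() until its Terminator is satisfied, while calling iterate() yourself (as in Example 7 below) lets you evaluate or log between rounds. A minimal sketch against the same fields as updateModel() above; the method name and the numRounds parameter are hypothetical:

private void updateModelWithProgress(int numRounds) {
    KLLoss klLoss = new KLLoss(crf, dataSet, targets, variance);
    GradientDescent opt = new GradientDescent(klLoss);
    // instead of opt.optimize(), which loops until the Terminator is satisfied,
    // drive the iterations explicitly so progress can be logged after each round
    for (int i = 0; i < numRounds; i++) {
        opt.iterate();
        if (logger.isDebugEnabled()) {
            logger.debug("finished round " + i);
        }
    }
}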

Example 7 with GradientDescent

use of edu.neu.ccs.pyramid.optimization.GradientDescent in project pyramid by cheng-li.

the class CMLCRFTest method test2.

public static void test2() throws Exception {
    System.out.println(config);
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), DataSetType.ML_CLF_DENSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.testData"), DataSetType.ML_CLF_DENSE, true);
    double gaussianVariance = config.getDouble("gaussianVariance");
    // paths used to load or save the model
    String output = config.getString("output");
    String modelName = config.getString("modelName");
    CMLCRF cmlcrf;
    MultiLabel[] predTrain;
    MultiLabel[] predTest;
    if (config.getBoolean("train.warmStart")) {
        cmlcrf = CMLCRF.deserialize(new File(output, modelName));
        System.out.println("loading model:");
        System.out.println(cmlcrf);
    } else {
        cmlcrf = new CMLCRF(trainSet);
        CRFLoss crfLoss = new CRFLoss(cmlcrf, trainSet, gaussianVariance);
        if (config.getBoolean("isLBFGS")) {
            LBFGS optimizer = new LBFGS(crfLoss);
            optimizer.getTerminator().setAbsoluteEpsilon(0.1);
            for (int i = 0; i < config.getInt("numRounds"); i++) {
                optimizer.iterate();
                predTrain = cmlcrf.predict(trainSet);
                predTest = cmlcrf.predict(testSet);
                System.out.print("iter: " + String.format("%04d", i));
                System.out.print("\tTrain acc: " + String.format("%.4f", Accuracy.accuracy(trainSet.getMultiLabels(), predTrain)));
                System.out.print("\tTrain overlap " + String.format("%.4f", Overlap.overlap(trainSet.getMultiLabels(), predTrain)));
                System.out.print("\tTest acc: " + String.format("%.4f", Accuracy.accuracy(testSet.getMultiLabels(), predTest)));
                System.out.println("\tTest overlap " + String.format("%.4f", Overlap.overlap(testSet.getMultiLabels(), predTest)));
            }
        } else {
            GradientDescent optimizer = new GradientDescent(crfLoss);
            for (int i = 0; i < config.getInt("numRounds"); i++) {
                optimizer.iterate();
                predTrain = cmlcrf.predict(trainSet);
                predTest = cmlcrf.predict(testSet);
                System.out.print("iter: " + String.format("%04d", i));
                System.out.print("\tTrain acc: " + String.format("%.4f", Accuracy.accuracy(trainSet.getMultiLabels(), predTrain)));
                System.out.print("\tTrain overlap " + String.format("%.4f", Overlap.overlap(trainSet.getMultiLabels(), predTrain)));
                System.out.print("\tTest acc: " + String.format("%.4f", Accuracy.accuracy(testSet.getMultiLabels(), predTest)));
                System.out.println("\tTest overlap " + String.format("%.4f", Overlap.overlap(testSet.getMultiLabels(), predTest)));
            }
        }
    }
    System.out.println();
    System.out.println();
    System.out.println("--------------------------------Results-----------------------------\n");
    predTrain = cmlcrf.predict(trainSet);
    predTest = cmlcrf.predict(testSet);
    System.out.print("Train acc: " + String.format("%.4f", Accuracy.accuracy(trainSet.getMultiLabels(), predTrain)));
    System.out.print("\tTrain overlap " + String.format("%.4f", Overlap.overlap(trainSet.getMultiLabels(), predTrain)));
    System.out.print("\tTest acc: " + String.format("%.4f", Accuracy.accuracy(testSet.getMultiLabels(), predTest)));
    System.out.println("\tTest overlap " + String.format("%.4f", Overlap.overlap(testSet.getMultiLabels(), predTest)));
    if (config.getBoolean("saveModel")) {
        (new File(output)).mkdirs();
        File serializeModel = new File(output, modelName);
        cmlcrf.serialize(serializeModel);
    }
}
Also used : CMLCRF(edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) CRFLoss(edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss) GradientDescent(edu.neu.ccs.pyramid.optimization.GradientDescent) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)
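The warm-start branch in test2 only loads the saved model; a minimal sketch of continuing training from a deserialized CMLCRF, composed from the same calls that appear above (extraRounds is a placeholder, not a config key from the test):

CMLCRF cmlcrf = CMLCRF.deserialize(new File(output, modelName));
CRFLoss crfLoss = new CRFLoss(cmlcrf, trainSet, gaussianVariance);
GradientDescent optimizer = new GradientDescent(crfLoss);
for (int i = 0; i < extraRounds; i++) {
    // one optimization step on the CRF parameters per round
    optimizer.iterate();
}
// overwrite the saved model with the refined parameters
cmlcrf.serialize(new File(output, modelName));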

Example 8 with GradientDescent

use of edu.neu.ccs.pyramid.optimization.GradientDescent in project pyramid by cheng-li.

the class LogisticRegressionTest method test1.

private static void test1() throws Exception {
    // train and test sets in TREC format, loaded as sparse classification data
    ClfDataSet dataSet = TRECFormat.loadClfDataSet(new File(DATASETS, "/imdb/3/train.trec"), DataSetType.CLF_SPARSE, false);
    ClfDataSet testSet = TRECFormat.loadClfDataSet(new File(DATASETS, "/imdb/3/test.trec"), DataSetType.CLF_SPARSE, false);
    System.out.println(dataSet.getMetaInfo());
    // model sized to the data's class and feature counts
    LogisticRegression logisticRegression = new LogisticRegression(dataSet.getNumClasses(), dataSet.getNumFeatures());
    LogisticLoss function = new LogisticLoss(logisticRegression, dataSet, 1000, true);
    GradientDescent gradientDescent = new GradientDescent(function);
    // start the line search from a small initial step length
    gradientDescent.getLineSearcher().setInitialStepLength(1.0E-4);
    gradientDescent.optimize();
    System.out.println("train: " + Accuracy.accuracy(logisticRegression, dataSet));
    System.out.println("test: " + Accuracy.accuracy(logisticRegression, testSet));
}
Also used : ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) GradientDescent(edu.neu.ccs.pyramid.optimization.GradientDescent) ConjugateGradientDescent(edu.neu.ccs.pyramid.optimization.ConjugateGradientDescent) File(java.io.File)
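The import list above also mentions ConjugateGradientDescent, which test1 never calls; a sketch of the swap, assuming its constructor takes the loss function the same way GradientDescent's does:

LogisticLoss function = new LogisticLoss(logisticRegression, dataSet, 1000, true);
ConjugateGradientDescent optimizer = new ConjugateGradientDescent(function);
optimizer.optimize();
System.out.println("train: " + Accuracy.accuracy(logisticRegression, dataSet));
System.out.println("test: " + Accuracy.accuracy(logisticRegression, testSet));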

Aggregations

GradientDescent (edu.neu.ccs.pyramid.optimization.GradientDescent): 8
LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS): 7
Optimizer (edu.neu.ccs.pyramid.optimization.Optimizer): 6
RidgeLogisticOptimizer (edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer): 2
File (java.io.File): 2
ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet): 1
MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 1
MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 1
CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF): 1
CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss): 1
ConjugateGradientDescent (edu.neu.ccs.pyramid.optimization.ConjugateGradientDescent): 1