
Example 21 with LBFGS

Use of edu.neu.ccs.pyramid.optimization.LBFGS in project pyramid by cheng-li.

From class AugmentedLRLossTest, method main.

public static void main(String[] args) throws Exception {
    // Set the root log4j2 logger to DEBUG so debug-level messages are printed.
    LoggerContext ctx = (LoggerContext) LogManager.getContext(false);
    Configuration config = ctx.getConfiguration();
    LoggerConfig loggerConfig = config.getLoggerConfig(LogManager.ROOT_LOGGER_NAME);
    loggerConfig.setLevel(Level.DEBUG);
    ctx.updateLoggers();
    // Load the "scene" train/test splits as dense multi-label data sets in TREC format.
    // DATASETS is a base directory defined outside this method (not shown).
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "scene/train"), DataSetType.ML_CLF_DENSE, true);
    // testSet is loaded but not used in this example.
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "scene/test"), DataSetType.ML_CLF_DENSE, true);
    AugmentedLR augmentedLR = new AugmentedLR(dataSet.getNumFeatures(), 1);
    // One gamma value per data point (a single column), all set to 1.
    double[][] gammas = new double[dataSet.getNumDataPoints()][1];
    for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
        gammas[i][0] = 1;
    }
    AugmentedLRLoss loss = new AugmentedLRLoss(dataSet, 0, gammas, augmentedLR, 1, 1);
    // Run 100 L-BFGS iterations, printing the objective value after each one.
    LBFGS lbfgs = new LBFGS(loss);
    for (int i = 0; i < 100; i++) {
        lbfgs.iterate();
        System.out.println(loss.getValue());
    }
}
Also used: LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS), Configuration (org.apache.logging.log4j.core.config.Configuration), LoggerContext (org.apache.logging.log4j.core.LoggerContext), File (java.io.File), MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet), LoggerConfig (org.apache.logging.log4j.core.config.LoggerConfig)
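Example 21 drives the optimizer by calling `lbfgs.iterate()` in a loop and printing `loss.getValue()` after each step. For readers who want to see what one such iteration typically involves, here is a minimal, self-contained sketch on a toy quadratic: a two-loop recursion computes the search direction from the last `m` curvature pairs, and a backtracking (Armijo) line search picks the step. `Loss`, `MiniLBFGS`, and the toy objective are hypothetical names written for illustration; this is not pyramid's `LBFGS` implementation, whose line search, history size, and defaults may differ.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;

// Hypothetical loss interface for this sketch only (not pyramid's API).
interface Loss { double value(double[] x); double[] gradient(double[] x); }

// Minimal L-BFGS: two-loop recursion for the direction, backtracking (Armijo) line search.
final class MiniLBFGS {
    private final Loss loss;
    private final double[] x;  // current parameters
    private final int m;       // number of (s, y) curvature pairs kept
    private final Deque<double[]> sHist = new ArrayDeque<>(); // newest pair first
    private final Deque<double[]> yHist = new ArrayDeque<>();

    MiniLBFGS(Loss loss, double[] x0, int m) {
        this.loss = loss;
        this.x = x0.clone();
        this.m = m;
    }

    double value() { return loss.value(x); }
    double[] position() { return x.clone(); }

    // One iteration: search direction, line search along it, then update the history.
    void iterate() {
        double[] g = loss.gradient(x);
        double[] d = direction(g);
        double slope = dot(g, d);
        if (slope >= 0) { // not a descent direction: drop the history and use -g
            sHist.clear(); yHist.clear();
            d = scale(g, -1);
            slope = dot(g, d);
        }
        double f0 = loss.value(x), step = 1.0;
        double[] xNew = axpy(x, step, d);
        while (loss.value(xNew) > f0 + 1e-4 * step * slope && step > 1e-12) {
            step *= 0.5; // halve the step until the Armijo condition holds
            xNew = axpy(x, step, d);
        }
        double[] s = axpy(xNew, -1, x);
        double[] y = axpy(loss.gradient(xNew), -1, g);
        if (dot(s, y) > 1e-12) { // keep only pairs with positive curvature
            sHist.addFirst(s); yHist.addFirst(y);
            if (sHist.size() > m) { sHist.removeLast(); yHist.removeLast(); }
        }
        System.arraycopy(xNew, 0, x, 0, x.length);
    }

    // Two-loop recursion: returns -H g, where H approximates the inverse Hessian.
    private double[] direction(double[] g) {
        int k = sHist.size();
        double[] q = g.clone(), alpha = new double[k];
        Iterator<double[]> si = sHist.iterator(), yi = yHist.iterator();
        for (int i = 0; i < k; i++) { // newest to oldest
            double[] s = si.next(), y = yi.next();
            alpha[i] = dot(s, q) / dot(y, s);
            q = axpy(q, -alpha[i], y);
        }
        if (k > 0) { // initial inverse-Hessian scale s'y / y'y from the newest pair
            double[] s = sHist.peekFirst(), y = yHist.peekFirst();
            q = scale(q, dot(s, y) / dot(y, y));
        }
        Iterator<double[]> sd = sHist.descendingIterator(), yd = yHist.descendingIterator();
        for (int i = k - 1; i >= 0; i--) { // oldest to newest
            double[] s = sd.next(), y = yd.next();
            double beta = dot(y, q) / dot(y, s);
            q = axpy(q, alpha[i] - beta, s);
        }
        return scale(q, -1);
    }

    private static double dot(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) sum += a[i] * b[i];
        return sum;
    }
    private static double[] axpy(double[] a, double t, double[] b) { // a + t * b
        double[] r = new double[a.length];
        for (int i = 0; i < a.length; i++) r[i] = a[i] + t * b[i];
        return r;
    }
    private static double[] scale(double[] a, double t) { return axpy(new double[a.length], t, a); }

    // Demo on f(p) = (p0 - 3)^2 + 10 (p1 + 1)^2, whose minimum is at (3, -1).
    public static void main(String[] args) {
        Loss quad = new Loss() {
            public double value(double[] p) { return Math.pow(p[0] - 3, 2) + 10 * Math.pow(p[1] + 1, 2); }
            public double[] gradient(double[] p) { return new double[]{2 * (p[0] - 3), 20 * (p[1] + 1)}; }
        };
        MiniLBFGS opt = new MiniLBFGS(quad, new double[]{0, 0}, 5);
        for (int i = 0; i < 20; i++) {
            opt.iterate();
            System.out.println(opt.value()); // objective after each step, as in Example 21
        }
    }
}
```

Because every accepted step satisfies the Armijo condition, the printed objective should not increase from one iteration to the next on this convex example, which mirrors the per-iteration `loss.getValue()` printout in Example 21.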

Example 22 with LBFGS

Use of edu.neu.ccs.pyramid.optimization.LBFGS in project pyramid by cheng-li.

From class App6, method train.

private static void train(Config config) throws Exception {
    // Load training and test data (sequential sparse format) from the configured paths.
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.testData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    CMLCRF cmlcrf = new CMLCRF(trainSet);
    double gaussianVariance = config.getDouble("train.gaussianVariance");
    // Include label-pair terms in the CRF.
    cmlcrf.setConsiderPair(true);
    // CRF training loss, regularized with a Gaussian prior of the configured variance.
    CRFLoss crfLoss = new CRFLoss(cmlcrf, trainSet, gaussianVariance);
    int maxIteration = config.getInt("train.maxIteration");
    crfLoss.setRegularizeAll(true);
    // L-BFGS optimizer, capped at train.maxIteration iterations.
    LBFGS optimizer = new LBFGS(crfLoss);
    optimizer.getTerminator().setMaxIteration(maxIteration);
    // Choose the predictor according to the target metric (predict.target).
    PluginPredictor<CMLCRF> predictor = null;
    String predictTarget = config.getString("predict.target");
    switch(predictTarget) {
        case "subsetAccuracy":
            predictor = new SubsetAccPredictor(cmlcrf);
            break;
        case "instanceFMeasure":
            predictor = new InstanceF1Predictor(cmlcrf);
            break;
        default:
            throw new IllegalArgumentException("predict.target must be subsetAccuracy or instanceFMeasure");
    }
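    int progressInterval = config.getInt("train.showProgress.interval");
    System.out.println("start training");
    int iteration = 0;
    // Run L-BFGS one iteration at a time until the terminator says to stop.
    while (true) {
        optimizer.iterate();
        iteration += 1;
        // Every progressInterval iterations: report train/test measures and save a checkpoint.
        if (iteration % progressInterval == 0) {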
            System.out.println("iteration " + iteration);
            System.out.println("training objective = " + optimizer.getTerminator().getLastValue());
            System.out.println("training performance:");
            System.out.println(new MLMeasures(predictor, trainSet));
            System.out.println("test performance:");
            System.out.println(new MLMeasures(predictor, testSet));
            String modelName = "model_crf";
            String output = config.getString("output.folder");
            (new File(output)).mkdirs();
            File serializeModel = new File(output, modelName);
            cmlcrf.serialize(serializeModel);
        }
        // On termination, report the final measures and leave the loop.
        if (optimizer.getTerminator().shouldTerminate()) {
            System.out.println("iteration " + iteration);
            System.out.println("training objective = " + optimizer.getTerminator().getLastValue());
            System.out.println("training performance:");
            System.out.println(new MLMeasures(predictor, trainSet));
            System.out.println("test performance:");
            System.out.println(new MLMeasures(predictor, testSet));
            System.out.println("training done!");
            break;
        }
    }
    // Save the final model and write the training-set predictions to a text file.
    String modelName = "model_crf";
    String output = config.getString("output.folder");
    (new File(output)).mkdirs();
    File serializeModel = new File(output, modelName);
    cmlcrf.serialize(serializeModel);
    MultiLabel[] predictions = cmlcrf.predict(trainSet);
    File predictionFile = new File(output, "train_predictions.txt");
    FileUtils.writeStringToFile(predictionFile, PrintUtil.toMutipleLines(predictions));
    System.out.println("predictions on the training set are written to " + predictionFile.getAbsolutePath());
    if (config.getBoolean("train.generateReports")) {
        report(config, trainSet, "trainSet");
    }
}
Also used: LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS), MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures), File (java.io.File)
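The loop in Example 22 reports progress every `train.showProgress.interval` iterations, reports again when the terminator fires (so an iteration that hits both conditions is printed twice), and repeats its reporting and serialization code in several places. The sketch below is one way to structure such a loop so that each iteration is reported at most once. `SimpleTerminator`, its relative-change stopping rule, and `TrainingLoopSketch` are hypothetical stand-ins written for illustration; they are not pyramid's `Terminator` API, and pyramid's own convergence criteria may differ.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.DoubleSupplier;

// Hypothetical terminator (not pyramid's Terminator): stops after maxIteration values,
// or when the relative change between the last two values is at most epsilon.
final class SimpleTerminator {
    private final int maxIteration;
    private final double epsilon;
    private final List<Double> history = new ArrayList<>();

    SimpleTerminator(int maxIteration, double epsilon) {
        this.maxIteration = maxIteration;
        this.epsilon = epsilon;
    }

    void add(double value) { history.add(value); }

    double getLastValue() { return history.get(history.size() - 1); }

    boolean shouldTerminate() {
        int n = history.size();
        if (n >= maxIteration) return true;
        if (n < 2) return false;
        double prev = history.get(n - 2), last = history.get(n - 1);
        return Math.abs(prev - last) <= epsilon * Math.max(Math.abs(prev), 1e-12);
    }
}

final class TrainingLoopSketch {
    // Runs `step` (one optimizer iteration returning the objective) until the terminator fires.
    // Reports every progressInterval iterations and on the final one, never twice for the same iteration.
    static List<String> train(DoubleSupplier step, SimpleTerminator terminator, int progressInterval) {
        List<String> log = new ArrayList<>();
        int iteration = 0;
        while (true) {
            terminator.add(step.getAsDouble());
            iteration++;
            boolean done = terminator.shouldTerminate();
            if (iteration % progressInterval == 0 || done) {
                // In Example 22, this is where measures are printed and the model is saved.
                log.add("iteration " + iteration + ", objective = " + terminator.getLastValue());
            }
            if (done) break;
        }
        log.add("training done!");
        return log;
    }

    public static void main(String[] args) {
        // Toy optimizer: gradient descent on f(w) = (w - 2)^2 with learning rate 0.1.
        double[] w = {0.0};
        DoubleSupplier step = () -> {
            w[0] -= 0.1 * 2 * (w[0] - 2);
            return (w[0] - 2) * (w[0] - 2);
        };
        train(step, new SimpleTerminator(100, 1e-6), 10).forEach(System.out::println);
    }
}
```

Running `main` prints ten progress lines (iterations 10 through 100, because the toy objective shrinks by a constant factor and its relative change never drops below `1e-6`) followed by `training done!`.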

Aggregations (classes used in the LBFGS examples, with the number of examples in which each appears)

LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS): 22
File (java.io.File): 12
MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 9
MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 7
CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF): 7
CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss): 7
GradientDescent (edu.neu.ccs.pyramid.optimization.GradientDescent): 7
Optimizer (edu.neu.ccs.pyramid.optimization.Optimizer): 6
RidgeLogisticOptimizer (edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer): 2
MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 2
ArrayList (java.util.ArrayList): 2
ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet): 1
LoggerContext (org.apache.logging.log4j.core.LoggerContext): 1
Configuration (org.apache.logging.log4j.core.config.Configuration): 1
LoggerConfig (org.apache.logging.log4j.core.config.LoggerConfig): 1