
Example 26 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

From the class CMLCRFTest, method test1.

private static void test1() throws Exception {
    // Load the spam train/test files in TREC format as sparse multi-label data sets.
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "spam/trec_data/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "spam/trec_data/test.trec"), DataSetType.ML_CLF_SPARSE, true);
    // Build the CRF model and the CRF loss over the training set.
    CMLCRF cmlcrf = new CMLCRF(dataSet);
    CRFLoss crfLoss = new CRFLoss(cmlcrf, dataSet, 1);
    cmlcrf.setConsiderPair(true);
    MultiLabel[] predTrain;
    MultiLabel[] predTest;
    LBFGS optimizer = new LBFGS(crfLoss);
    // Run a fixed number of L-BFGS iterations, evaluating on both splits after each one.
    for (int i = 0; i < 5000; i++) {
        System.out.println("iter: " + i);
        optimizer.iterate();
        System.out.println("\tobjective: " + crfLoss.getValue());
        predTrain = cmlcrf.predict(dataSet);
        predTest = cmlcrf.predict(testSet);
        System.out.print("\tTrain acc: " + Accuracy.accuracy(dataSet.getMultiLabels(), predTrain));
        System.out.print("\tTrain overlap: " + Overlap.overlap(dataSet.getMultiLabels(), predTrain));
        System.out.print("\tTest acc: " + Accuracy.accuracy(testSet.getMultiLabels(), predTest));
        System.out.println("\tTest overlap: " + Overlap.overlap(testSet.getMultiLabels(), predTest));
    }
    // Alternative to the loop above: train until the optimizer's terminator reports convergence.
    // optimizer.getTerminator().setAbsoluteEpsilon(0.01);
    // optimizer.optimize();
    // predTrain = cmlcrf.predict(dataSet);
    // predTest = cmlcrf.predict(testSet);
    // System.out.print("Train acc: " + Accuracy.accuracy(dataSet.getMultiLabels(), predTrain));
    // System.out.print("\tTrain overlap: " + Overlap.overlap(dataSet.getMultiLabels(), predTrain));
    // System.out.print("\tTest acc: " + Accuracy.accuracy(testSet.getMultiLabels(), predTest));
    // System.out.println("\tTest overlap: " + Overlap.overlap(testSet.getMultiLabels(), predTest));
}
Also used: CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF), LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS), MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel), CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss), File (java.io.File), MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)
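The loop above reports Accuracy.accuracy and Overlap.overlap after every iteration, but their implementations are not shown here. The following is a minimal, self-contained sketch under two assumptions: that "accuracy" means exact-match (subset) accuracy and that "overlap" means the mean Jaccard similarity between true and predicted label sets. Plain Set<Integer> values stand in for MultiLabel, and the class name MultiLabelMetricsSketch is hypothetical.

```java
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class MultiLabelMetricsSketch {

    // Assumed definition of "accuracy": fraction of instances whose predicted
    // label set equals the true label set exactly.
    public static double subsetAccuracy(List<Set<Integer>> truth, List<Set<Integer>> pred) {
        checkSizes(truth, pred);
        int correct = 0;
        for (int i = 0; i < truth.size(); i++) {
            if (truth.get(i).equals(pred.get(i))) {
                correct++;
            }
        }
        return (double) correct / truth.size();
    }

    // Assumed definition of "overlap": mean of |T ∩ P| / |T ∪ P| per instance;
    // two empty sets are treated as a perfect match.
    public static double jaccardOverlap(List<Set<Integer>> truth, List<Set<Integer>> pred) {
        checkSizes(truth, pred);
        double sum = 0.0;
        for (int i = 0; i < truth.size(); i++) {
            Set<Integer> intersection = new HashSet<>(truth.get(i));
            intersection.retainAll(pred.get(i));
            Set<Integer> union = new HashSet<>(truth.get(i));
            union.addAll(pred.get(i));
            sum += union.isEmpty() ? 1.0 : (double) intersection.size() / union.size();
        }
        return sum / truth.size();
    }

    private static void checkSizes(List<Set<Integer>> truth, List<Set<Integer>> pred) {
        if (truth.isEmpty() || truth.size() != pred.size()) {
            throw new IllegalArgumentException("truth and pred must be non-empty and the same size");
        }
    }

    public static void main(String[] args) {
        List<Set<Integer>> truth = List.of(Set.of(0, 2), Set.of(1), Set.of(3));
        List<Set<Integer>> pred = List.of(Set.of(0, 2), Set.of(1, 3), Set.<Integer>of());
        System.out.println(subsetAccuracy(truth, pred)); // prints 0.3333333333333333
        System.out.println(jaccardOverlap(truth, pred)); // prints 0.5
    }
}
```

Overlap gives partial credit (0.5 for the second instance) where subset accuracy gives none, which is why the two numbers in the training log can diverge noticeably.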

Example 27 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

From the class CMLCRFTest, method test8.

private static void test8() throws Exception {
    System.out.println(config);
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.testData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    // Output directory and file name used when saving the model.
    String output = config.getString("output");
    String modelName = config.getString("modelName");
    CMLCRF cmlcrf = new CMLCRF(trainSet);
    BlockwiseCD blockwiseCD = new BlockwiseCD(cmlcrf, trainSet, config.getDouble("l1Ratio"), config.getDouble("regularization"));
    MultiLabel[] predTrain;
    MultiLabel[] predTest;
    for (int i = 0; i < 10000; i++) {
        blockwiseCD.iterate();
        predTrain = cmlcrf.predict(trainSet);
        predTest = cmlcrf.predict(testSet);
        System.out.print("iter: " + String.format("%04d", i));
        System.out.print("\tobjective: " + String.format("%.4f", blockwiseCD.getValue()));
        System.out.print("\tTrain acc: " + String.format("%.4f", Accuracy.accuracy(trainSet.getMultiLabels(), predTrain)));
        System.out.print("\tTrain overlap: " + String.format("%.4f", Overlap.overlap(trainSet.getMultiLabels(), predTrain)));
        System.out.print("\tTrain F1: " + String.format("%.4f", FMeasure.f1(trainSet.getMultiLabels(), predTrain)));
        System.out.print("\tTest acc: " + String.format("%.4f", Accuracy.accuracy(testSet.getMultiLabels(), predTest)));
        System.out.print("\tTest overlap: " + String.format("%.4f", Overlap.overlap(testSet.getMultiLabels(), predTest)));
        System.out.println("\tTest F1: " + String.format("%.4f", FMeasure.f1(testSet.getMultiLabels(), predTest)));
    }
    System.out.println();
    System.out.println();
    System.out.println("--------------------------------Results-----------------------------\n");
    MLMeasures measures = new MLMeasures(cmlcrf, trainSet);
    System.out.println("========== Train ==========\n");
    System.out.println(measures);
    System.out.println("========== Test ==========\n");
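    // Time one full prediction pass over the test set.
    long startTimePred = System.nanoTime();
    MultiLabel[] preds = cmlcrf.predict(testSet);
    long stopTimePred = System.nanoTime();
    long predTime = stopTimePred - startTimePred;
    // Report fractional seconds; TimeUnit.NANOSECONDS.toSeconds would truncate sub-second times to 0.
    System.out.println("\nprediction time: " + String.format("%.3f", predTime / 1e9) + " sec.");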
    System.out.println(new MLMeasures(cmlcrf, testSet));
    System.out.println("\n\n");
    InstanceF1Predictor pluginF1 = new InstanceF1Predictor(cmlcrf);
    System.out.println("Plugin F1");
    System.out.println(new MLMeasures(pluginF1, testSet));
    if (config.getBoolean("saveModel")) {
        (new File(output)).mkdirs();
        File serializeModel = new File(output, modelName);
        cmlcrf.serialize(serializeModel);
    }
}
Also used: CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF), MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel), MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures), File (java.io.File), MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)
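test8 also logs FMeasure.f1 and evaluates an InstanceF1Predictor, but neither implementation appears here. As a sketch only, assuming FMeasure.f1 is an instance-averaged F1, the per-instance score is 2|T ∩ P| / (|T| + |P|), averaged over instances. The class InstanceF1Sketch is hypothetical, and Set<Integer> stands in for MultiLabel.

```java
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class InstanceF1Sketch {

    // F1 for a single instance: 2 * |T ∩ P| / (|T| + |P|).
    // Both sets empty is treated as a perfect prediction (an assumed convention).
    public static double instanceF1(Set<Integer> truth, Set<Integer> pred) {
        int denominator = truth.size() + pred.size();
        if (denominator == 0) {
            return 1.0;
        }
        Set<Integer> intersection = new HashSet<>(truth);
        intersection.retainAll(pred);
        return 2.0 * intersection.size() / denominator;
    }

    // Average of the per-instance F1 scores (instance-based, not micro or macro F1).
    public static double meanInstanceF1(List<Set<Integer>> truth, List<Set<Integer>> pred) {
        if (truth.isEmpty() || truth.size() != pred.size()) {
            throw new IllegalArgumentException("truth and pred must be non-empty and the same size");
        }
        double sum = 0.0;
        for (int i = 0; i < truth.size(); i++) {
            sum += instanceF1(truth.get(i), pred.get(i));
        }
        return sum / truth.size();
    }

    public static void main(String[] args) {
        List<Set<Integer>> truth = List.of(Set.of(0, 2), Set.of(1), Set.of(3));
        List<Set<Integer>> pred = List.of(Set.of(0, 2), Set.of(1, 3), Set.<Integer>of());
        // Per-instance scores are 1, 2/3 and 0, so the mean is 5/9.
        System.out.println(meanInstanceF1(truth, pred));
    }
}
```

Micro- and macro-averaged F1 pool counts differently and generally give different numbers, so which variant pyramid reports should be checked against its FMeasure source.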

Example 28 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

From the class CMLCRFTest, method test4.

private static void test4() throws Exception {
    // Load the 20newsgroup/1 train/test files in TREC format as sparse multi-label data sets.
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "20newsgroup/1/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "20newsgroup/1/test.trec"), DataSetType.ML_CLF_SPARSE, true);
    // Build the CRF model and the CRF loss over the training set.
    CMLCRF cmlcrf = new CMLCRF(dataSet);
    CRFLoss crfLoss = new CRFLoss(cmlcrf, dataSet, 1);
    MultiLabel[] predTrain;
    MultiLabel[] predTest;
    LBFGS optimizer = new LBFGS(crfLoss);
    // Run 50 L-BFGS iterations, evaluating on both splits after each one.
    for (int i = 0; i < 50; i++) {
        System.out.println("iter: " + i);
        optimizer.iterate();
        System.out.println("\tobjective: " + crfLoss.getValue());
        predTrain = cmlcrf.predict(dataSet);
        predTest = cmlcrf.predict(testSet);
        System.out.print("\tTrain acc: " + Accuracy.accuracy(dataSet.getMultiLabels(), predTrain));
        System.out.print("\tTrain overlap: " + Overlap.overlap(dataSet.getMultiLabels(), predTrain));
        System.out.print("\tTest acc: " + Accuracy.accuracy(testSet.getMultiLabels(), predTest));
        System.out.println("\tTest overlap: " + Overlap.overlap(testSet.getMultiLabels(), predTest));
    }
}
Also used: CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF), LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS), MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel), CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss), File (java.io.File), MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)
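test4 stops after a fixed 50 iterations, while the commented-out block in test1 instead sets an absolute epsilon on the optimizer's terminator and calls optimize(). The internals of pyramid's terminator are not shown, so the following is only a generic sketch of absolute-epsilon stopping: iterate until the objective changes by less than epsilon between consecutive iterations, or until a maximum count. The IterativeOptimizer interface and the toy gradient-descent optimizer are hypothetical stand-ins, not pyramid's LBFGS.

```java
public class AbsoluteEpsilonStopSketch {

    // Hypothetical minimal optimizer interface, loosely mirroring iterate()/getValue() above.
    public interface IterativeOptimizer {
        void iterate();      // perform one optimization step
        double getValue();   // current objective value
    }

    // Iterates until |f(previous) - f(current)| < epsilon or maxIter steps have run.
    // Returns the number of iterations actually performed.
    public static int optimize(IterativeOptimizer optimizer, double epsilon, int maxIter) {
        double previous = optimizer.getValue();
        for (int i = 1; i <= maxIter; i++) {
            optimizer.iterate();
            double current = optimizer.getValue();
            if (Math.abs(previous - current) < epsilon) {
                return i;
            }
            previous = current;
        }
        return maxIter;
    }

    // Toy optimizer: gradient descent on f(x) = (x - 3)^2 with step size 0.1, starting at x = 0.
    public static class ToyGradientDescent implements IterativeOptimizer {
        public double x = 0.0;

        @Override
        public void iterate() {
            x -= 0.1 * 2.0 * (x - 3.0); // gradient of (x - 3)^2 is 2(x - 3)
        }

        @Override
        public double getValue() {
            return (x - 3.0) * (x - 3.0);
        }
    }

    public static void main(String[] args) {
        ToyGradientDescent gd = new ToyGradientDescent();
        int iterations = optimize(gd, 0.01, 5000);
        System.out.println("iterations: " + iterations); // prints iterations: 14
        System.out.println("x = " + gd.x);
    }
}
```

An absolute threshold depends on the scale of the objective: 0.01 is loose for a small loss and strict for a large one, and the toy run stops with x still about 0.13 away from the minimum. A relative threshold, or a validation-based early stopper, is a common alternative.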

Example 29 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

From the class BMSelector, method selectAll.

public static Pair<BM, double[][]> selectAll(int numClasses, MultiLabel[] multiLabels, int numClusters) {
    // Encode the label sets as a sparse binary matrix: one row per instance, one column per class.
    DataSet dataSet = DataSetBuilder.getBuilder().numDataPoints(multiLabels.length).numFeatures(numClasses).density(Density.SPARSE_RANDOM).build();
    for (int i = 0; i < multiLabels.length; i++) {
        MultiLabel multiLabel = multiLabels[i];
        for (int label : multiLabel.getMatchedLabels()) {
            dataSet.setFeatureValue(i, label, 1);
        }
    }
    // Select and fit a BM trainer with numClusters clusters on the label matrix.
    BMTrainer trainer = BMSelector.selectTrainer(dataSet, numClusters, 10);
    // Return the fitted BM together with the trainer's gammas.
    Pair<BM, double[][]> pair = new Pair<>();
    pair.setFirst(trainer.getBm());
    pair.setSecond(trainer.gammas);
    return pair;
}
Also used: MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel), DataSet (edu.neu.ccs.pyramid.dataset.DataSet), Pair (edu.neu.ccs.pyramid.util.Pair)
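The first step of selectAll turns each instance's matched labels into a row of a binary matrix with one column per class. A dense, dependency-free sketch of that encoding follows; the class name LabelIndicatorSketch is hypothetical, Set<Integer> stands in for MultiLabel, and a double[][] replaces pyramid's sparse DataSet, which matters for memory once the number of classes is large.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Set;

public class LabelIndicatorSketch {

    // Row i, column c is 1.0 when instance i carries label c, else 0.0.
    // Dense for simplicity; the original builds a sparse DataSet instead.
    public static double[][] toIndicatorMatrix(int numClasses, List<Set<Integer>> multiLabels) {
        double[][] matrix = new double[multiLabels.size()][numClasses];
        for (int i = 0; i < multiLabels.size(); i++) {
            for (int label : multiLabels.get(i)) {
                if (label < 0 || label >= numClasses) {
                    throw new IllegalArgumentException("label " + label + " out of range [0, " + numClasses + ")");
                }
                matrix[i][label] = 1.0;
            }
        }
        return matrix;
    }

    public static void main(String[] args) {
        List<Set<Integer>> labels = List.of(Set.of(0, 2), Set.of(1), Set.<Integer>of());
        double[][] matrix = toIndicatorMatrix(4, labels);
        // prints [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
        System.out.println(Arrays.deepToString(matrix));
    }
}
```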

Example 30 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

From the class CBM, method predictByMarginals.

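/**
 * Predicts a label set by sorting the per-class marginal probabilities
 * and keeping the {@code top} most probable classes.
 * @param vector feature vector of the instance to label
 * @param top number of highest-probability classes to keep
 * @return a MultiLabel containing the top classes by marginal probability
 */
public MultiLabel predictByMarginals(Vector vector, int top) {
    double[] probs = predictClassProbs(vector);
    int[] sortedIndices = ArgSort.argSortDescending(probs);
    MultiLabel prediction = new MultiLabel();
    // Clamp top so a value larger than the number of classes cannot index past the array.
    int limit = Math.min(top, sortedIndices.length);
    for (int i = 0; i < limit; i++) {
        prediction.addLabel(sortedIndices[i]);
    }
    return prediction;
}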
Also used: MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel)
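predictByMarginals relies on ArgSort.argSortDescending, whose source is not shown. A standard-library sketch of the same top-k selection, including the clamp on top, might look like the following; TopKMarginalsSketch and its methods are hypothetical names, and ties keep their original index order because Arrays.sort on objects is stable.

```java
import java.util.Arrays;

public class TopKMarginalsSketch {

    // Indices of probs ordered from the largest value to the smallest.
    public static int[] argSortDescending(double[] probs) {
        Integer[] indices = new Integer[probs.length];
        for (int i = 0; i < probs.length; i++) {
            indices[i] = i;
        }
        // Stable object sort: equal probabilities keep ascending index order.
        Arrays.sort(indices, (a, b) -> Double.compare(probs[b], probs[a]));
        return Arrays.stream(indices).mapToInt(Integer::intValue).toArray();
    }

    // The `top` most probable class indices, clamped to [0, probs.length].
    public static int[] topLabels(double[] probs, int top) {
        int[] sorted = argSortDescending(probs);
        int k = Math.min(Math.max(top, 0), sorted.length);
        return Arrays.copyOf(sorted, k);
    }

    public static void main(String[] args) {
        double[] marginals = {0.1, 0.7, 0.4, 0.9};
        System.out.println(Arrays.toString(topLabels(marginals, 2)));  // prints [3, 1]
        System.out.println(Arrays.toString(topLabels(marginals, 10))); // prints [3, 1, 2, 0]
    }
}
```

Keeping a fixed number of labels ignores how confident the marginals are; thresholding each marginal (for example at 0.5) is the usual alternative when the label count per instance varies.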

Aggregations

MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 101
Vector (org.apache.mahout.math.Vector): 22
MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 21
File (java.io.File): 14
DenseVector (org.apache.mahout.math.DenseVector): 13
CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF): 12
Pair (edu.neu.ccs.pyramid.util.Pair): 8
LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS): 7
ArrayList (java.util.ArrayList): 7
MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 6
CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss): 6
MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 5
GeneralF1Predictor (edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor): 5
Collectors (java.util.stream.Collectors): 5
EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper): 4
java.util (java.util): 4
StopWatch (org.apache.commons.lang3.time.StopWatch): 4
Config (edu.neu.ccs.pyramid.configuration.Config): 3
DataSetUtil (edu.neu.ccs.pyramid.dataset.DataSetUtil): 3
TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat): 3