Search in sources :

Example 16 with MultiLabelClfDataSet

use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

the class LogRiskOptimizerTest method test1.

/**
 * Scratch driver for LogRiskOptimizer: builds a CMLCRF on the data sets returned by
 * MultiLabelSynthesizer, overwrites the first two non-bias weights of classes 0-3 with
 * fixed values, and prints MLMeasures before optimisation and after every iteration
 * until the optimizer's terminator reports that it should stop.
 */
private static void test1() {
    MultiLabelClfDataSet train = MultiLabelSynthesizer.independentNoise();
    MultiLabelClfDataSet test = MultiLabelSynthesizer.independent();
    CMLCRF cmlcrf = new CMLCRF(train);

    // Row k lists the values written to indices 0 and 1 of class k's non-bias weights.
    int[][] startingWeights = {
        { 0, 1 },
        { 1, 1 },
        { 1, 0 },
        { 1, -1 }
    };
    for (int classIndex = 0; classIndex < startingWeights.length; classIndex++) {
        for (int position = 0; position < startingWeights[classIndex].length; position++) {
            cmlcrf.getWeights().getWeightsWithoutBiasForClass(classIndex).set(position, startingWeights[classIndex][position]);
        }
    }

    MLScorer scorer = new FScorer();
    LogRiskOptimizer optimizer = new LogRiskOptimizer(train, scorer, cmlcrf, 1, false, false, 1, 1);
    InstanceF1Predictor f1Predictor = new InstanceF1Predictor(cmlcrf);

    System.out.println(cmlcrf);
    System.out.println("initial loss = " + optimizer.objective());
    printPerformance(cmlcrf, f1Predictor, train, test);

    while (!optimizer.getTerminator().shouldTerminate()) {
        System.out.println("------------");
        optimizer.iterate();
        System.out.println(optimizer.getTerminator().getLastValue());
        printPerformance(cmlcrf, f1Predictor, train, test);
    }
    System.out.println(cmlcrf);
}

/**
 * Prints MLMeasures for the CRF and then for the F1 predictor, each on the training set
 * followed by the test set, preceded by the corresponding heading line.
 */
private static void printPerformance(CMLCRF crf, InstanceF1Predictor f1Predictor, MultiLabelClfDataSet train, MultiLabelClfDataSet test) {
    System.out.println("training performance acc");
    System.out.println(new MLMeasures(crf, train));
    System.out.println("test performance acc");
    System.out.println(new MLMeasures(crf, test));
    System.out.println("training performance f1");
    System.out.println(new MLMeasures(f1Predictor, train));
    System.out.println("test performance f1");
    System.out.println(new MLMeasures(f1Predictor, test));
}
Also used : MLScorer(edu.neu.ccs.pyramid.multilabel_classification.MLScorer) FScorer(edu.neu.ccs.pyramid.multilabel_classification.FScorer) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 17 with MultiLabelClfDataSet

use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

the class CBMTest method test1.

/**
 * Scratch driver for CBM on the spam TREC data under DATASETS: builds a four-component
 * model ("lr" multi-class type, "boost" binary type, "dynamic" predict mode), initialises
 * it, then runs five optimizer iterations, printing the objective plus accuracy and
 * overlap on the training and test sets after each one.
 */
private static void test1() throws Exception {
    File trainFile = new File(DATASETS, "/spam/trec_data/train.trec");
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(trainFile, DataSetType.ML_CLF_SPARSE, true);
    File testFile = new File(DATASETS, "/spam/trec_data/test.trec");
    MultiLabelClfDataSet heldOutSet = TRECFormat.loadMultiLabelClfDataSet(testFile, DataSetType.ML_CLF_SPARSE, true);

    int componentCount = 4;
    CBM cbm = CBM.getBuilder()
            .setNumClasses(trainSet.getNumClasses())
            .setNumFeatures(trainSet.getNumFeatures())
            .setNumComponents(componentCount)
            .setMultiClassClassifierType("lr")
            .setBinaryClassifierType("boost")
            .build();
    cbm.setPredictMode("dynamic");

    CBMOptimizer trainer = new CBMOptimizer(cbm, trainSet);
    trainer.setPriorVarianceBinary(10);
    trainer.setPriorVarianceMultiClass(10);
    CBMInitializer.initialize(cbm, trainSet, trainer);
    cbm.setNumSample(100);

    System.out.println("num cluster: " + cbm.numComponents);
    System.out.println("after initialization");
    System.out.println("train acc = " + Accuracy.accuracy(cbm, trainSet));
    System.out.println("test acc = " + Accuracy.accuracy(cbm, heldOutSet));

    final int totalIterations = 5;
    for (int round = 1; round <= totalIterations; round++) {
        trainer.iterate();
        // One tab-separated report line per iteration; each metric is computed right before it is printed.
        System.out.print("iter : " + round + "\t");
        System.out.print("objective: " + trainer.getTerminator().getLastValue() + "\t");
        System.out.print("trainAcc : " + Accuracy.accuracy(cbm, trainSet) + "\t");
        System.out.print("trainOver: " + Overlap.overlap(cbm, trainSet) + "\t");
        System.out.print("testAcc  : " + Accuracy.accuracy(cbm, heldOutSet) + "\t");
        System.out.println("testOver : " + Overlap.overlap(cbm, heldOutSet) + "\t");
    }

    System.out.println("history = " + trainer.getTerminator().getHistory());
    System.out.println(cbm);
}
Also used : File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 18 with MultiLabelClfDataSet

use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

the class CBMTest method test4.

/**
 * Scratch driver for CBM on the flags data under DATASETS: builds a four-component model
 * with "lr" for both the binary and multi-class classifier types, initialises it, then
 * runs thirty optimizer iterations, printing the objective plus accuracy and overlap on
 * the training and test sets after each one.
 */
private static void test4() throws Exception {
    File trainFile = new File(DATASETS, "flags/data_sets/train");
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(trainFile, DataSetType.ML_CLF_SPARSE, true);
    File testFile = new File(DATASETS, "flags/data_sets/test");
    MultiLabelClfDataSet heldOutSet = TRECFormat.loadMultiLabelClfDataSet(testFile, DataSetType.ML_CLF_SPARSE, true);

    int componentCount = 4;
    CBM cbm = CBM.getBuilder()
            .setNumClasses(trainSet.getNumClasses())
            .setNumFeatures(trainSet.getNumFeatures())
            .setNumComponents(componentCount)
            .setBinaryClassifierType("lr")
            .setMultiClassClassifierType("lr")
            .build();
    // Disabled in this run: cbm.setPredictMode("dynamic");

    CBMOptimizer trainer = new CBMOptimizer(cbm, trainSet);
    trainer.setPriorVarianceBinary(1);
    trainer.setPriorVarianceMultiClass(1);
    CBMInitializer.initialize(cbm, trainSet, trainer);
    // Disabled in this run: cbm.setNumSample(100);

    System.out.println("num cluster: " + cbm.numComponents);
    System.out.println("after initialization");
    System.out.println("train acc = " + Accuracy.accuracy(cbm, trainSet));
    System.out.println("test acc = " + Accuracy.accuracy(cbm, heldOutSet));

    final int totalIterations = 30;
    for (int round = 1; round <= totalIterations; round++) {
        trainer.iterate();
        // One tab-separated report line per iteration; each metric is computed right before it is printed.
        System.out.print("iter : " + round + "\t");
        System.out.print("objective: " + trainer.getTerminator().getLastValue() + "\t");
        System.out.print("trainAcc : " + Accuracy.accuracy(cbm, trainSet) + "\t");
        System.out.print("trainOver: " + Overlap.overlap(cbm, trainSet) + "\t");
        System.out.print("testAcc  : " + Accuracy.accuracy(cbm, heldOutSet) + "\t");
        System.out.println("testOver : " + Overlap.overlap(cbm, heldOutSet) + "\t");
    }

    System.out.println("history = " + trainer.getTerminator().getHistory());
    System.out.println(cbm);
}
Also used : File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 19 with MultiLabelClfDataSet

use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

the class CMLCRFTest method test4.

/**
 * Scratch driver for CMLCRF on the 20newsgroup TREC data under DATASETS: minimises a
 * CRFLoss with LBFGS for fifty iterations and, after each one, prints the loss value
 * followed by accuracy and overlap of the model's predictions on the training and test sets.
 */
private static void test4() throws Exception {
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "20newsgroup/1/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet heldOutSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "20newsgroup/1/test.trec"), DataSetType.ML_CLF_SPARSE, true);

    CMLCRF crf = new CMLCRF(trainSet);
    CRFLoss loss = new CRFLoss(crf, trainSet, 1);
    LBFGS lbfgs = new LBFGS(loss);

    int round = 0;
    while (round < 50) {
        // Optional debug output: System.out.print("Obj: " + lbfgs.getTerminator().getLastValue());
        System.out.println("iter: " + round);
        lbfgs.iterate();
        System.out.println(loss.getValue());

        // Predict once per data set and reuse the arrays for both metrics.
        MultiLabel[] trainPredictions = crf.predict(trainSet);
        MultiLabel[] testPredictions = crf.predict(heldOutSet);

        System.out.print("\tTrain acc: " + Accuracy.accuracy(trainSet.getMultiLabels(), trainPredictions));
        System.out.print("\tTrain overlap " + Overlap.overlap(trainSet.getMultiLabels(), trainPredictions));
        System.out.print("\tTest acc: " + Accuracy.accuracy(heldOutSet.getMultiLabels(), testPredictions));
        System.out.println("\tTest overlap " + Overlap.overlap(heldOutSet.getMultiLabels(), testPredictions));
        // Optional debug output: System.out.println("crf = " + crf.getWeights());
        // Optional debug output: System.out.println(Arrays.toString(trainPredictions));
        round++;
    }
}
Also used : CMLCRF(edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) CRFLoss(edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 20 with MultiLabelClfDataSet

use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

the class Trec2Meka method main.

/**
 * Converts TREC-format multi-label data sets to MEKA format.
 *
 * <p>The single argument is a properties file read through {@code Config}. The keys used
 * here are {@code trec} (input paths), {@code meka} (output paths, matched to the inputs
 * by position), {@code xml} (where the label XML is written, from the first data set
 * only) and {@code data.name} (passed to {@code MekaFormat.save} for every data set).
 *
 * @param args exactly one element: the path of the properties file
 * @throws IllegalArgumentException if the argument count is not one, or if the
 *         {@code meka} list has fewer entries than the {@code trec} list
 * @throws IOException declared by this method; presumably raised by the config,
 *         load or save calls (verify against those APIs)
 * @throws ClassNotFoundException declared by this method; presumably raised while
 *         loading a data set (verify against {@code TRECFormat})
 */
public static void main(String[] args) throws IOException, ClassNotFoundException {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    System.out.println(config);
    List<String> trecs = config.getStrings("trec");
    List<String> mekas = config.getStrings("meka");
    // for label xml information
    String xmlFile = config.getString("xml");
    // Fail fast: without this check a short "meka" list only surfaces as an
    // IndexOutOfBoundsException mid-run, after earlier outputs were already written.
    if (mekas.size() < trecs.size()) {
        throw new IllegalArgumentException("Each trec input needs a meka output, but the config lists "
                + trecs.size() + " trec and " + mekas.size() + " meka entries.");
    }
    for (int i = 0; i < trecs.size(); i++) {
        String trecPath = trecs.get(i);
        MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(trecPath), DataSetType.ML_CLF_SPARSE, true);
        System.out.println(i + " -- Translating on trecs: " + trecPath);
        MekaFormat.save(dataSet, mekas.get(i), config.getString("data.name"));
        // The label XML is written once, from the first data set.
        if (i == 0) {
            MekaFormat.saveXML(dataSet, xmlFile);
        }
    }
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Aggregations

MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 48, File (java.io.File): 24, MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 23, CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF): 13, MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 12, LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS): 9, Vector (org.apache.mahout.math.Vector): 9, Config (edu.neu.ccs.pyramid.configuration.Config): 7, CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss): 7, DenseVector (org.apache.mahout.math.DenseVector): 7, MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 5, Pair (edu.neu.ccs.pyramid.util.Pair): 5, java.util (java.util): 5, Collectors (java.util.stream.Collectors): 5, IntStream (java.util.stream.IntStream): 5, DataSetUtil (edu.neu.ccs.pyramid.dataset.DataSetUtil): 4, TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat): 4, MLScorer (edu.neu.ccs.pyramid.multilabel_classification.MLScorer): 4, StopWatch (org.apache.commons.lang3.time.StopWatch): 4, AccScorer (edu.neu.ccs.pyramid.multilabel_classification.AccScorer): 3