Search in sources:

Example 36 with MultiLabelClfDataSet

use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

the class AugmentedLRLossTest method main.

/**
 * Loads the dense "scene" train and test sets, builds an {@code AugmentedLRLoss}
 * over the training data with every gamma set to 1, and runs 100 LBFGS
 * iterations, printing the loss value after each one.
 *
 * @param args ignored
 * @throws Exception if loading the data sets or optimizing fails
 */
public static void main(String[] args) throws Exception {
    // Switch the root log4j2 logger to DEBUG and push the change to existing loggers.
    LoggerContext loggerContext = (LoggerContext) LogManager.getContext(false);
    loggerContext.getConfiguration()
            .getLoggerConfig(LogManager.ROOT_LOGGER_NAME)
            .setLevel(Level.DEBUG);
    loggerContext.updateLoggers();

    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "scene/train"), DataSetType.ML_CLF_DENSE, true);
    // The test split is loaded but nothing below reads it.
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "scene/test"), DataSetType.ML_CLF_DENSE, true);

    AugmentedLR model = new AugmentedLR(dataSet.getNumFeatures(), 1);

    // One single-entry row per data point, each entry fixed at 1.
    double[][] gammas = new double[dataSet.getNumDataPoints()][1];
    for (double[] row : gammas) {
        row[0] = 1;
    }

    AugmentedLRLoss loss = new AugmentedLRLoss(dataSet, 0, gammas, model, 1, 1);
    LBFGS optimizer = new LBFGS(loss);

    int step = 0;
    while (step < 100) {
        optimizer.iterate();
        System.out.println(loss.getValue());
        step++;
    }
}
Also used : LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) Configuration(org.apache.logging.log4j.core.config.Configuration) LoggerContext(org.apache.logging.log4j.core.LoggerContext) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) LoggerConfig(org.apache.logging.log4j.core.config.LoggerConfig)

Example 37 with MultiLabelClfDataSet

use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

the class CBMInitializerTest method test1.

/**
 * Loads the sparse "flags" training set, builds a two-component CBM whose
 * binary and multi-class classifier types are both "lr", and runs
 * {@code CBMInitializer.initialize} on it.
 *
 * <p>No prior variance is set on the optimizer in this method.
 *
 * @throws Exception if loading the data set or initializing the model fails
 */
private static void test1() throws Exception {
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "/flags/train"), DataSetType.ML_CLF_SPARSE, true);
    int numClusters = 2;
    CBM cbm = CBM.getBuilder()
            .setNumClasses(trainSet.getNumClasses())
            .setNumFeatures(trainSet.getNumFeatures())
            .setNumComponents(numClusters)
            .setBinaryClassifierType("lr")
            .setMultiClassClassifierType("lr")
            .build();
    CBMOptimizer optimizer = new CBMOptimizer(cbm, trainSet);
    CBMInitializer.initialize(cbm, trainSet, optimizer);
}
Also used : File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 38 with MultiLabelClfDataSet

use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

the class CBMInspectorTest method test1.

/**
 * Deserializes a CBM from {@code TMP/model}, prints its accuracy on the sparse
 * meka_imdb test set, and then reports every data point where the regular
 * prediction equals the true label, differs from the marginal-based prediction,
 * and the marginal-based prediction has at least one matched label. For each
 * such point the covariance inspection is invoked.
 *
 * @throws Exception if loading the data set or the model fails
 */
private static void test1() throws Exception {
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "meka_imdb/1/data_sets/test"), DataSetType.ML_CLF_SPARSE, true);
    // Named in lowerCamelCase so the local no longer obscures the CBM class.
    CBM cbm = (CBM) Serialization.deserialize(new File(TMP, "model"));
    System.out.println(Accuracy.accuracy(cbm, testSet));
    for (int i = 0; i < testSet.getNumDataPoints(); i++) {
        MultiLabel trueLabel = testSet.getMultiLabels()[i];
        MultiLabel pred = cbm.predict(testSet.getRow(i));
        MultiLabel expectation = cbm.predictByMarginals(testSet.getRow(i));
        // Skip points that are not "correct, yet disagreeing with a non-empty marginal prediction".
        if (!pred.equals(trueLabel)
                || pred.equals(expectation)
                || expectation.getMatchedLabels().size() <= 0) {
            continue;
        }
        System.out.println("==============================");
        System.out.println("data point " + i);
        System.out.println("prediction = " + pred);
        System.out.println("expectation = " + expectation);
        CBMInspector.covariance(cbm, testSet.getRow(i), testSet.getLabelTranslator());
    }
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 39 with MultiLabelClfDataSet

use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

the class CBMTest method test2.

/**
 * Trains a four-component boosted CBM on synthetic data and prints the
 * objective plus train/test accuracy for three optimizer iterations, followed
 * by the model's string form.
 *
 * <p>The training set has 1000 points and the test set 100; both have two
 * features and four classes and are filled by {@link #fillWithNoisyBits}.
 *
 * @throws Exception if model initialization or optimization fails
 */
private static void test2() throws Exception {
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder().numFeatures(2).numClasses(4).numDataPoints(1000).build();
    BernoulliDistribution bernoulliDistribution = new BernoulliDistribution(0.5);
    fillWithNoisyBits(dataSet, bernoulliDistribution);

    MultiLabelClfDataSet testSet = MLClfDataSetBuilder.getBuilder().numFeatures(2).numClasses(4).numDataPoints(100).build();
    fillWithNoisyBits(testSet, bernoulliDistribution);

    int numComponents = 4;
    CBM cbm = CBM.getBuilder()
            .setNumClasses(dataSet.getNumClasses())
            .setNumFeatures(dataSet.getNumFeatures())
            .setNumComponents(numComponents)
            .setBinaryClassifierType("boost")
            .setMultiClassClassifierType("boost")
            .build();
    cbm.setPredictMode("dynamic");
    CBMOptimizer optimizer = new CBMOptimizer(cbm, dataSet);
    optimizer.setPriorVarianceBinary(10);
    optimizer.setPriorVarianceMultiClass(10);
    CBMInitializer.initialize(cbm, dataSet, optimizer);
    for (int i = 0; i < 3; i++) {
        optimizer.iterate();
        System.out.print("i: " + i + "\t");
        System.out.print("objective: " + optimizer.getTerminator().getLastValue() + "\t");
        System.out.print("trainAcc: " + Accuracy.accuracy(cbm, dataSet) + "\t");
        System.out.println("testAcc: " + Accuracy.accuracy(cbm, testSet));
    }
    System.out.println(cbm.toString());
}

/**
 * Fills every feature of every data point with a sampled bit and attaches one
 * label per feature derived from a noisy copy of that bit.
 *
 * <p>For feature 0 the label is 0 or 1; for any other feature it is 2 or 3.
 * The lower label of each pair is used when the noisy bit is 0. The noisy bit
 * is the sampled bit, inverted whenever {@code Math.random() < 0.1}.
 *
 * @param target    data set whose feature values and labels are written
 * @param bitSource distribution the feature bits are sampled from
 */
private static void fillWithNoisyBits(MultiLabelClfDataSet target, BernoulliDistribution bitSource) {
    for (int n = 0; n < target.getNumDataPoints(); n++) {
        for (int m = 0; m < target.getNumFeatures(); m++) {
            int bit = bitSource.sample();
            // The feature keeps the clean bit; only the label sees the flip.
            int noisyBit = (Math.random() < 0.1) ? 1 - bit : bit;
            target.setFeatureValue(n, m, bit);
            int lowerLabel = (m == 0) ? 0 : 2;
            target.addLabel(n, noisyBit == 0 ? lowerLabel : lowerLabel + 1);
        }
    }
}
Also used : BernoulliDistribution(edu.neu.ccs.pyramid.util.BernoulliDistribution) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 40 with MultiLabelClfDataSet

use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

the class CMLCRFTest method test5.

/**
 * Trains a CMLCRF on the sparse ohsumed split in two phases and prints progress
 * after every LBFGS iteration: 5 iterations with pair consideration switched
 * off, then 50 iterations (with a fresh loss and optimizer) with it switched on.
 *
 * @throws Exception if loading the data sets or optimizing fails
 */
private static void test5() throws Exception {
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "ohsumed/3/test.trec"), DataSetType.ML_CLF_SPARSE, true);
    CMLCRF cmlcrf = new CMLCRF(dataSet);

    // Each loss is constructed before the flag is changed, as in the original sequence.
    CRFLoss crfLoss = new CRFLoss(cmlcrf, dataSet, 1);
    cmlcrf.setConsiderPair(false);
    LBFGS optimizer = new LBFGS(crfLoss);
    optimizeAndReport(optimizer, crfLoss, cmlcrf, dataSet, testSet, 5, false);

    CRFLoss crfLoss2 = new CRFLoss(cmlcrf, dataSet, 1);
    cmlcrf.setConsiderPair(true);
    LBFGS optimizer2 = new LBFGS(crfLoss2);
    optimizeAndReport(optimizer2, crfLoss2, cmlcrf, dataSet, testSet, 50, true);
}

/**
 * Runs a fixed number of optimizer iterations; after each one prints the loss
 * value and the accuracy and overlap of the model's predictions on both sets.
 *
 * @param optimizer     optimizer to step
 * @param loss          loss whose value is printed after each step
 * @param model         model used to predict on both data sets
 * @param trainSet      training data and its true labels
 * @param testSet       test data and its true labels
 * @param numIterations number of optimizer steps to run
 * @param announcePairs when true, prints "consider pairs" before each iteration header
 */
private static void optimizeAndReport(LBFGS optimizer, CRFLoss loss, CMLCRF model,
                                      MultiLabelClfDataSet trainSet, MultiLabelClfDataSet testSet,
                                      int numIterations, boolean announcePairs) {
    for (int i = 0; i < numIterations; i++) {
        if (announcePairs) {
            System.out.println("consider pairs");
        }
        System.out.println("iter: " + i);
        optimizer.iterate();
        System.out.println(loss.getValue());
        MultiLabel[] predTrain = model.predict(trainSet);
        MultiLabel[] predTest = model.predict(testSet);
        System.out.print("\tTrain acc: " + Accuracy.accuracy(trainSet.getMultiLabels(), predTrain));
        System.out.print("\tTrain overlap " + Overlap.overlap(trainSet.getMultiLabels(), predTrain));
        System.out.print("\tTest acc: " + Accuracy.accuracy(testSet.getMultiLabels(), predTest));
        System.out.println("\tTest overlap " + Overlap.overlap(testSet.getMultiLabels(), predTest));
    }
}
Also used : CMLCRF(edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) CRFLoss(edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Aggregations

MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 48, File (java.io.File): 24, MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 23, CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF): 13, MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 12, LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS): 9, Vector (org.apache.mahout.math.Vector): 9, Config (edu.neu.ccs.pyramid.configuration.Config): 7, CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss): 7, DenseVector (org.apache.mahout.math.DenseVector): 7, MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 5, Pair (edu.neu.ccs.pyramid.util.Pair): 5, java.util (java.util): 5, Collectors (java.util.stream.Collectors): 5, IntStream (java.util.stream.IntStream): 5, DataSetUtil (edu.neu.ccs.pyramid.dataset.DataSetUtil): 4, TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat): 4, MLScorer (edu.neu.ccs.pyramid.multilabel_classification.MLScorer): 4, StopWatch (org.apache.commons.lang3.time.StopWatch): 4, AccScorer (edu.neu.ccs.pyramid.multilabel_classification.AccScorer): 3