Search in sources:

Example 31 with MultiLabelClfDataSet

use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

The class ShortCircuitPosteriorTest, method main.

/**
 * Ad-hoc diagnostic that compares the exact CBM posterior with the short-circuit approximation
 * for a single test instance and dumps the per-component and per-label quantities involved.
 *
 * @param args optional overrides: {@code args[0]} = test data directory, {@code args[1]} =
 *             serialized CBM model file, {@code args[2]} = index of the test instance to inspect.
 *             Omitted arguments fall back to the original hard-coded values.
 */
public static void main(String[] args) throws Exception {
    // Sanity check of softmax on widely spread logits.
    double[] s = { -40, -40, -20, 0 };
    System.out.println(Arrays.toString(MathUtil.softmax(s)));
    String testPath = args.length > 0 ? args[0] : "/Users/chengli/tmp/mlc_data_pyramid/rcv1subset_topics_1/train_test_split/test";
    String modelPath = args.length > 1 ? args[1] : "/Users/chengli/tmp/model";
    MultiLabelClfDataSet test = TRECFormat.loadMultiLabelClfDataSet(testPath, DataSetType.ML_CLF_SEQ_SPARSE, true);
    int dataIndex = args.length > 2 ? Integer.parseInt(args[2]) : 190;
    // NOTE(review): Java-native deserialization; only point this at trusted model files.
    CBM cbm = (CBM) Serialization.deserialize(modelPath);
    BMDistribution distribution = cbm.computeBM(test.getRow(dataIndex));
    System.out.println("pi");
    System.out.println(Arrays.toString(cbm.getMultiClassClassifier().predictClassProbs(test.getRow(dataIndex))));
    System.out.println("posterior");
    // Computed once and reused in the per-component loop below.
    double[] posterior = distribution.posteriorMembership(test.getMultiLabels()[dataIndex]);
    System.out.println(Arrays.toString(posterior));
    System.out.println("approximate posterior = ");
    System.out.println(Arrays.toString(new ShortCircuitPosterior(cbm, test.getRow(dataIndex), test.getMultiLabels()[dataIndex]).posteriorMembership()));
    // logClassProbs is double[][][]; Arrays.toString would only print the nested arrays' identity strings.
    double[][][] logClassProbs = distribution.getLogClassProbs();
    System.out.println(Arrays.deepToString(logClassProbs));
    int numComponents = cbm.getNumComponents();
    // Log mixture weights do not depend on k, so compute them once.
    double[] logComponentProbs = cbm.getMultiClassClassifier().predictLogClassProbs(test.getRow(dataIndex));
    for (int k = 0; k < numComponents; k++) {
        System.out.println("k=" + k);
        System.out.println(logComponentProbs[k]);
        System.out.println(distribution.logYGivenComponentByDefault(test.getMultiLabels()[dataIndex], k));
        System.out.println(posterior[k]);
    }
    // For each label, the largest log-probability of the label being "on" (index 1) over all components.
    for (int l = 0; l < test.getNumClasses(); l++) {
        final int label = l;
        double max = IntStream.range(0, numComponents).mapToDouble(k -> logClassProbs[k][label][1]).max().getAsDouble();
        System.out.println("label " + l);
        System.out.println("max = " + max);
    }
    System.out.println(distribution.logProbability(test.getMultiLabels()[dataIndex]));
    for (int k = 0; k < numComponents; k++) {
        System.out.println(distribution.logYGivenComponentByDefault(test.getMultiLabels()[dataIndex], k));
    }
}
Also used : MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 32 with MultiLabelClfDataSet

use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

The class SparseCBMOptimzerTest, method test1.

/**
 * Trains a 10-component CBM on the "scene" data set with SparseCBMOptimzer for three rounds
 * and prints test-set measures after each round. Round 0 initialises gamma from BM; every
 * later round re-estimates gamma before refitting the multi-class and binary models.
 */
private static void test1() throws Exception {
    MultiLabelClfDataSet trainData = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "scene/train"), DataSetType.ML_CLF_DENSE, true);
    MultiLabelClfDataSet testData = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "scene/test"), DataSetType.ML_CLF_DENSE, true);
    int components = 10;
    CBM cbm = CBM.getBuilder()
            .setNumClasses(trainData.getNumClasses())
            .setNumFeatures(trainData.getNumFeatures())
            .setNumComponents(components)
            .setMultiClassClassifierType("lr")
            .setBinaryClassifierType("lr")
            .build();
    SparseCBMOptimzer optimizer = new SparseCBMOptimzer(cbm, trainData);
    // Banner printed before each gamma re-estimation round (rounds 1 and 2).
    String[] updateBanners = { "update gamma", "update gamma again" };
    for (int round = 0; round <= updateBanners.length; round++) {
        if (round == 0) {
            optimizer.initalizeGammaByBM();
        } else {
            System.out.println(updateBanners[round - 1]);
            optimizer.updateGamma();
        }
        optimizer.updateMultiClassLR();
        optimizer.updateAllBinary();
        System.out.println("test");
        System.out.println(new MLMeasures(cbm, testData));
    }
}
Also used : File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 33 with MultiLabelClfDataSet

use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

The class CMLCRFTest, method test2.

/**
 * Trains (or, with {@code train.warmStart}, loads) a CMLCRF on dense data, using either LBFGS
 * or gradient descent as selected by {@code isLBFGS}. Train/test subset accuracy and overlap are
 * printed after every round and once more at the end. The model is saved when {@code saveModel}
 * is set.
 */
public static void test2() throws Exception {
    System.out.println(config);
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), DataSetType.ML_CLF_DENSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.testData"), DataSetType.ML_CLF_DENSE, true);
    double gaussianVariance = config.getDouble("gaussianVariance");
    // loading or save model infos.
    String output = config.getString("output");
    String modelName = config.getString("modelName");
    CMLCRF cmlcrf;
    if (config.getBoolean("train.warmStart")) {
        cmlcrf = CMLCRF.deserialize(new File(output, modelName));
        System.out.println("loading model:");
        System.out.println(cmlcrf);
    } else {
        cmlcrf = new CMLCRF(trainSet);
        CRFLoss crfLoss = new CRFLoss(cmlcrf, trainSet, gaussianVariance);
        // Both optimizers run exactly numRounds iterations; the terminator is not consulted.
        int numRounds = config.getInt("numRounds");
        if (config.getBoolean("isLBFGS")) {
            LBFGS optimizer = new LBFGS(crfLoss);
            optimizer.getTerminator().setAbsoluteEpsilon(0.1);
            for (int i = 0; i < numRounds; i++) {
                optimizer.iterate();
                System.out.println("iter: " + String.format("%04d", i) + "\t" + formatTrainTestScores(cmlcrf, trainSet, testSet));
            }
        } else {
            GradientDescent optimizer = new GradientDescent(crfLoss);
            for (int i = 0; i < numRounds; i++) {
                optimizer.iterate();
                System.out.println("iter: " + String.format("%04d", i) + "\t" + formatTrainTestScores(cmlcrf, trainSet, testSet));
            }
        }
    }
    System.out.println();
    System.out.println();
    System.out.println("--------------------------------Results-----------------------------\n");
    System.out.println(formatTrainTestScores(cmlcrf, trainSet, testSet));
    if (config.getBoolean("saveModel")) {
        (new File(output)).mkdirs();
        File serializeModel = new File(output, modelName);
        cmlcrf.serialize(serializeModel);
    }
}

/**
 * Predicts on both data sets with {@code cmlcrf} and formats the resulting subset accuracy and
 * overlap as one tab-separated line, without a trailing newline.
 */
private static String formatTrainTestScores(CMLCRF cmlcrf, MultiLabelClfDataSet trainSet, MultiLabelClfDataSet testSet) {
    MultiLabel[] predTrain = cmlcrf.predict(trainSet);
    MultiLabel[] predTest = cmlcrf.predict(testSet);
    return "Train acc: " + String.format("%.4f", Accuracy.accuracy(trainSet.getMultiLabels(), predTrain))
            + "\tTrain overlap " + String.format("%.4f", Overlap.overlap(trainSet.getMultiLabels(), predTrain))
            + "\tTest acc: " + String.format("%.4f", Accuracy.accuracy(testSet.getMultiLabels(), predTest))
            + "\tTest overlap " + String.format("%.4f", Overlap.overlap(testSet.getMultiLabels(), predTest));
}
Also used : CMLCRF(edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) CRFLoss(edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss) GradientDescent(edu.neu.ccs.pyramid.optimization.GradientDescent) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 34 with MultiLabelClfDataSet

use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

The class CMLCRFTest, method test7.

/**
 * Trains, retrains or loads a CMLCRF with elastic-net regularisation on sparse data, then prints
 * train/test measures, the test prediction time and plug-in F1 measures.
 *
 * The {@code train.warmStart} setting selects the mode:
 * "true" loads a saved model; "auto" loads a saved model and continues training it;
 * "false" trains a new model.
 *
 * @throws IllegalArgumentException if {@code train.warmStart} is not one of the three values
 */
private static void test7() throws Exception {
    System.out.println(config);
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.testData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    // loading or save model infos.
    String output = config.getString("output");
    String modelName = config.getString("modelName");
    String warmStart = config.getString("train.warmStart");
    CMLCRF cmlcrf;
    if ("true".equals(warmStart)) {
        cmlcrf = CMLCRF.deserialize(new File(output, modelName));
        System.out.println("loading model:");
        System.out.println(cmlcrf);
    } else if ("auto".equals(warmStart)) {
        cmlcrf = CMLCRF.deserialize(new File(output, modelName));
        System.out.println("retrain model:");
        CMLCRFElasticNet cmlcrfElasticNet = new CMLCRFElasticNet(cmlcrf, trainSet, config.getDouble("l1Ratio"), config.getDouble("regularization"));
        train(cmlcrfElasticNet, cmlcrf, trainSet, testSet, config);
    } else if ("false".equals(warmStart)) {
        cmlcrf = new CMLCRF(trainSet);
        cmlcrf.setConsiderPair(config.getBoolean("considerLabelPair"));
        CMLCRFElasticNet cmlcrfElasticNet = new CMLCRFElasticNet(cmlcrf, trainSet, config.getDouble("l1Ratio"), config.getDouble("regularization"));
        train(cmlcrfElasticNet, cmlcrf, trainSet, testSet, config);
    } else {
        // Previously an unknown value left cmlcrf null and failed later with an opaque NPE.
        throw new IllegalArgumentException("train.warmStart must be one of \"true\", \"auto\", \"false\" but was: " + warmStart);
    }
    System.out.println();
    System.out.println();
    System.out.println("--------------------------------Results-----------------------------\n");
    MLMeasures measures = new MLMeasures(cmlcrf, trainSet);
    System.out.println("========== Train ==========\n");
    System.out.println(measures);
    System.out.println("========== Test ==========\n");
    // Prediction is run only to time it; the predictions themselves are not used below.
    long startTimePred = System.nanoTime();
    cmlcrf.predict(testSet);
    long stopTimePred = System.nanoTime();
    long predTime = stopTimePred - startTimePred;
    System.out.println("\nprediction time: " + TimeUnit.NANOSECONDS.toSeconds(predTime) + " sec.");
    System.out.println(new MLMeasures(cmlcrf, testSet));
    System.out.println("\n\n");
    InstanceF1Predictor pluginF1 = new InstanceF1Predictor(cmlcrf);
    System.out.println("Plugin F1");
    System.out.println(new MLMeasures(pluginF1, testSet));
    if (config.getBoolean("saveModel")) {
        (new File(output)).mkdirs();
        File serializeModel = new File(output, modelName);
        cmlcrf.serialize(serializeModel);
    }
}
Also used : CMLCRF(edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 35 with MultiLabelClfDataSet

use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

The class NoiseOptimizerTest, method test1.

/**
 * Builds synthetic CRF data, saves it under TMP and seeds a CMLCRF with hand-picked weights.
 * It then fits the model with LogRiskOptimizer (accuracy scorer) and NoiseOptimizer, printing
 * train/test measures for the argmax and plug-in predictors along the way.
 */
private static void test1() {
    MultiLabelClfDataSet train = MultiLabelSynthesizer.crfArgmaxDrop();
    MultiLabelClfDataSet test = MultiLabelSynthesizer.crfArgmax();
    TRECFormat.save(train, new File(TMP, "train"));
    TRECFormat.save(test, new File(TMP, "test"));
    CMLCRF cmlcrf = new CMLCRF(train);
    // Row = class index, column = feature index (bias excluded).
    double[][] seedWeights = {
            { 0, 10 },
            { 10, 10 },
            { 10, 0 },
            { -10, -10 }
    };
    for (int classIndex = 0; classIndex < seedWeights.length; classIndex++) {
        for (int featureIndex = 0; featureIndex < seedWeights[classIndex].length; featureIndex++) {
            cmlcrf.getWeights().getWeightsWithoutBiasForClass(classIndex).set(featureIndex, seedWeights[classIndex][featureIndex]);
        }
    }
    MLScorer accScorer = new AccScorer();
    SubsetAccPredictor plugInAcc = new SubsetAccPredictor(cmlcrf);
    InstanceF1Predictor plugInF1 = new InstanceF1Predictor(cmlcrf);
    // Prints train/test measures of the raw CRF and of the F1 plug-in predictor.
    Runnable reportAccAndF1 = () -> {
        System.out.println("training performance acc");
        System.out.println(new MLMeasures(cmlcrf, train));
        System.out.println("test performance acc");
        System.out.println(new MLMeasures(cmlcrf, test));
        System.out.println("training performance f1");
        System.out.println(new MLMeasures(plugInF1, train));
        System.out.println("test performance f1");
        System.out.println(new MLMeasures(plugInF1, test));
    };
    System.out.println(cmlcrf);
    reportAccAndF1.run();
    LogRiskOptimizer accOptimizer = new LogRiskOptimizer(train, accScorer, cmlcrf, 1, false, false, 1, 1);
    accOptimizer.iterate();
    System.out.println("after ML estimation");
    System.out.println("training with Acc predictor");
    System.out.println(new MLMeasures(plugInAcc, train));
    System.out.println("training with F1 predictor");
    System.out.println(new MLMeasures(plugInF1, train));
    System.out.println("test with Acc predictor");
    System.out.println(new MLMeasures(plugInAcc, test));
    System.out.println("test with F1 predictor");
    System.out.println(new MLMeasures(plugInF1, test));
    System.out.println(cmlcrf);
    NoiseOptimizer noiseOptimizer = new NoiseOptimizer(train, cmlcrf, 1);
    while (!noiseOptimizer.getTerminator().shouldTerminate()) {
        System.out.println("------------");
        noiseOptimizer.iterate();
        System.out.println(noiseOptimizer.objectiveDetail());
        reportAccAndF1.run();
        System.out.println(cmlcrf);
    }
}
Also used : MLScorer(edu.neu.ccs.pyramid.multilabel_classification.MLScorer) AccScorer(edu.neu.ccs.pyramid.multilabel_classification.AccScorer) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Aggregations

MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 48; File (java.io.File): 24; MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 23; CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF): 13; MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 12; LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS): 9; Vector (org.apache.mahout.math.Vector): 9; Config (edu.neu.ccs.pyramid.configuration.Config): 7; CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss): 7; DenseVector (org.apache.mahout.math.DenseVector): 7; MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 5; Pair (edu.neu.ccs.pyramid.util.Pair): 5; java.util (java.util): 5; Collectors (java.util.stream.Collectors): 5; IntStream (java.util.stream.IntStream): 5; DataSetUtil (edu.neu.ccs.pyramid.dataset.DataSetUtil): 4; TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat): 4; MLScorer (edu.neu.ccs.pyramid.multilabel_classification.MLScorer): 4; StopWatch (org.apache.commons.lang3.time.StopWatch): 4; AccScorer (edu.neu.ccs.pyramid.multilabel_classification.AccScorer): 3