Search in sources :

Example 16 with MLMeasures

use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li.

the class CBMLR method reportF1Prediction.

/**
 * Predicts label sets for {@code dataSet} with a {@link PluginF1} predictor built on {@code cbm},
 * prints the resulting measures, and saves the measures and the per-instance predictions under
 * {@code <output.dir>/test_predictions/instance_f1_optimal}.
 *
 * @param config  supplies "output.dir" and "predict.piThreshold"
 * @param cbm     model wrapped by the predictor and used to score each predicted label set
 * @param dataSet data to predict on and to evaluate against
 * @throws Exception if reading the serialized support or writing either report file fails
 */
private static void reportF1Prediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    final String separator = "============================================================";
    System.out.println(separator);
    System.out.println("Making predictions on test set with the instance F1 optimal predictor");
    String outputDir = config.getString("output.dir");

    // The predictor needs the label-set support that was serialized to <output.dir>/support.
    PluginF1 predictor = new PluginF1(cbm);
    File supportFile = new File(outputDir, "support");
    List<MultiLabel> support = (List<MultiLabel>) Serialization.deserialize(supportFile);
    predictor.setSupport(support);
    double piThreshold = config.getDouble("predict.piThreshold");
    predictor.setPiThreshold(piThreshold);
    MultiLabel[] predictions = predictor.predict(dataSet);

    // Evaluate against the data set's own labels and store the textual report.
    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    System.out.println("test performance with the instance F1 optimal predictor");
    String report = measures.toString();
    System.out.println(report);
    File performanceFile = Paths.get(outputDir, "test_predictions", "instance_f1_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(performanceFile, report);
    System.out.println("test performance is saved to " + performanceFile);

    // Here we do not use approximation
    double[] setProbs = IntStream.range(0, predictions.length)
            .parallel()
            .mapToDouble(row -> cbm.predictAssignmentProb(dataSet.getRow(row), predictions[row]))
            .toArray();

    // One "<predicted set>:<probability>" line per data point.
    File predictionFile = Paths.get(outputDir, "test_predictions", "instance_f1_optimal", "predictions.txt").toFile();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(predictionFile))) {
        int row = 0;
        while (row < dataSet.getNumDataPoints()) {
            String line = predictions[row].toString() + ":" + setProbs[row];
            writer.write(line);
            writer.newLine();
            row++;
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println(separator);
}
Also used : IntStream(java.util.stream.IntStream) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) java.util(java.util) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) FileWriter(java.io.FileWriter) BufferedWriter(java.io.BufferedWriter) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures)

Example 17 with MLMeasures

use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li.

the class CBMGB method tune.

/**
 * Trains a fresh CBM on {@code trainSet} and uses an early stopper on {@code validSet} to pick the
 * number of iterations for the metric named by the "tune.targetMetric" setting.
 *
 * <p>The passed-in {@code hyperParameters} object is modified: its {@code iterations} field is set
 * to the early stopper's best iteration, and the same object is returned inside the result.
 *
 * @param config          supplies "tune.targetMetric", "predict.piThreshold" and
 *                        "tune.monitorInterval"
 * @param hyperParameters hyper-parameters used to build the model and optimizer; updated in place
 * @param trainSet        data the optimizer trains on (also the source of the label-set support
 *                        for the instance F1 predictor)
 * @param validSet        data the monitored metric is computed on
 * @return the updated hyper-parameters together with the best monitored value
 * @throws IllegalArgumentException if "tune.targetMetric" is not one of
 *                                  instance_set_accuracy, instance_f1, instance_hamming_loss
 * @throws Exception                propagated from model construction or optimization
 */
private static TuneResult tune(Config config, HyperParameters hyperParameters, MultiLabelClfDataSet trainSet, MultiLabelClfDataSet validSet) throws Exception {
    CBM cbm = newCBM(config, trainSet, hyperParameters);
    EarlyStopper earlyStopper = loadNewEarlyStopper(config);
    GBCBMOptimizer optimizer = getOptimizer(config, hyperParameters, cbm, trainSet);
    optimizer.initialize();
    // Pick the predictor that matches the metric being tuned.
    MultiLabelClassifier classifier;
    String predictTarget = config.getString("tune.targetMetric");
    switch(predictTarget) {
        case "instance_set_accuracy":
            AccPredictor accPredictor = new AccPredictor(cbm);
            accPredictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
            classifier = accPredictor;
            break;
        case "instance_f1":
            PluginF1 pluginF1 = new PluginF1(cbm);
            List<MultiLabel> support = DataSetUtil.gatherMultiLabels(trainSet);
            pluginF1.setSupport(support);
            pluginF1.setPiThreshold(config.getDouble("predict.piThreshold"));
            classifier = pluginF1;
            break;
        case "instance_hamming_loss":
            MarginalPredictor marginalPredictor = new MarginalPredictor(cbm);
            marginalPredictor.setPiThreshold(config.getDouble("predict.piThreshold"));
            classifier = marginalPredictor;
            break;
        default:
            throw new IllegalArgumentException("predictTarget should be instance_set_accuracy, instance_f1 or instance_hamming_loss");
    }
    // NOTE(review): a zero interval would make "iter % interval" below throw ArithmeticException;
    // confirm that the configuration is validated upstream.
    int interval = config.getInt("tune.monitorInterval");
    // The loop has no iteration cap; it ends only when the early stopper says so.
    for (int iter = 1; true; iter++) {
        if (VERBOSE) {
            System.out.println("iteration " + iter);
        }
        optimizer.iterate();
        // Validation is evaluated only every "interval" iterations.
        if (iter % interval == 0) {
            MLMeasures validMeasures = new MLMeasures(classifier, validSet);
            if (VERBOSE) {
                System.out.println("validation performance with " + predictTarget + " optimal predictor:");
                System.out.println(validMeasures);
            }
            switch(predictTarget) {
                case "instance_set_accuracy":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getAccuracy());
                    break;
                case "instance_f1":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getF1());
                    break;
                case "instance_hamming_loss":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getHammingLoss());
                    break;
                default:
                    // Not reachable after the validation above; the message lists the same three
                    // accepted values as the first switch so the two stay consistent.
                    throw new IllegalArgumentException("predictTarget should be instance_set_accuracy, instance_f1 or instance_hamming_loss");
            }
            if (earlyStopper.shouldStop()) {
                if (VERBOSE) {
                    System.out.println("Early Stopper: the training should stop now!");
                }
                break;
            }
        }
    }
    if (VERBOSE) {
        System.out.println("done!");
    }
    // Record the best iteration on the caller's hyper-parameters and report the best value.
    hyperParameters.iterations = earlyStopper.getBestIteration();
    TuneResult tuneResult = new TuneResult();
    tuneResult.hyperParameters = hyperParameters;
    tuneResult.performance = earlyStopper.getBestValue();
    return tuneResult;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures)

Example 18 with MLMeasures

use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li.

the class CBMEN method reportHammingPrediction.

/**
 * Predicts label sets for {@code dataSet} with a {@link MarginalPredictor} built on {@code cbm},
 * prints the resulting measures, and saves the measures and the per-instance predictions under
 * {@code <output.dir>/test_predictions/instance_hamming_loss_optimal}.
 *
 * @param config  supplies "output.dir" and "predict.piThreshold"
 * @param cbm     model wrapped by the predictor and used to score each predicted label set
 * @param dataSet data to predict on and to evaluate against
 * @throws Exception if writing either report file fails
 */
private static void reportHammingPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    final String separator = "============================================================";
    System.out.println(separator);
    System.out.println("Making predictions on test set with the instance Hamming loss optimal predictor");
    String outputDir = config.getString("output.dir");

    MarginalPredictor predictor = new MarginalPredictor(cbm);
    double piThreshold = config.getDouble("predict.piThreshold");
    predictor.setPiThreshold(piThreshold);
    MultiLabel[] predictions = predictor.predict(dataSet);

    // Evaluate against the data set's own labels and store the textual report.
    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    System.out.println("test performance with the instance Hamming loss optimal predictor");
    String report = measures.toString();
    System.out.println(report);
    File performanceFile = Paths.get(outputDir, "test_predictions", "instance_hamming_loss_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(performanceFile, report);
    System.out.println("test performance is saved to " + performanceFile);

    // Here we do not use approximation
    double[] setProbs = IntStream.range(0, predictions.length)
            .parallel()
            .mapToDouble(row -> cbm.predictAssignmentProb(dataSet.getRow(row), predictions[row]))
            .toArray();

    // One "<predicted set>:<probability>" line per data point.
    File predictionFile = Paths.get(outputDir, "test_predictions", "instance_hamming_loss_optimal", "predictions.txt").toFile();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(predictionFile))) {
        int row = 0;
        while (row < dataSet.getNumDataPoints()) {
            String line = predictions[row].toString() + ":" + setProbs[row];
            writer.write(line);
            writer.newLine();
            row++;
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println(separator);
}
Also used : IntStream(java.util.stream.IntStream) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) java.util(java.util) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) FileWriter(java.io.FileWriter) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) File(java.io.File) BufferedWriter(java.io.BufferedWriter)

Example 19 with MLMeasures

use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li.

the class SparseCBMOptimzerTest method test1.

/**
 * Loads the scene train/test sets, builds a 10-component CBM with "lr" multi-class and binary
 * classifiers, and runs three optimization rounds with {@link SparseCBMOptimzer}, printing the
 * test-set measures after each round.
 *
 * @throws Exception propagated from data loading or optimization
 */
private static void test1() throws Exception {
    File trainDir = new File(DATASETS, "scene/train");
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(trainDir, DataSetType.ML_CLF_DENSE, true);
    File testDir = new File(DATASETS, "scene/test");
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(testDir, DataSetType.ML_CLF_DENSE, true);

    int numComponents = 10;
    CBM cbm = CBM.getBuilder()
            .setNumClasses(dataSet.getNumClasses())
            .setNumFeatures(dataSet.getNumFeatures())
            .setNumComponents(numComponents)
            .setMultiClassClassifierType("lr")
            .setBinaryClassifierType("lr")
            .build();
    SparseCBMOptimzer optimzer = new SparseCBMOptimzer(cbm, dataSet);

    // First round: starts from initalizeGammaByBM() instead of updateGamma().
    optimzer.initalizeGammaByBM();
    optimzer.updateMultiClassLR();
    optimzer.updateAllBinary();
    System.out.println("test");
    System.out.println(new MLMeasures(cbm, testSet));

    // Two further rounds, each announced by its title and starting with a gamma update.
    String[] roundTitles = {"update gamma", "update gamma again"};
    for (String title : roundTitles) {
        System.out.println(title);
        optimzer.updateGamma();
        optimzer.updateMultiClassLR();
        optimzer.updateAllBinary();
        System.out.println("test");
        System.out.println(new MLMeasures(cbm, testSet));
    }
}
Also used : File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 20 with MLMeasures

use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li.

the class CMLCRFTest method test7.

/**
 * Loads the configured train/test sets, obtains a {@link CMLCRF} according to the
 * "train.warmStart" setting, prints train and test measures (including a plug-in F1 predictor),
 * and optionally serializes the model.
 *
 * <p>"train.warmStart" values:
 * <ul>
 *   <li>{@code true} - load the saved model and use it as is;</li>
 *   <li>{@code auto} - load the saved model and continue training it;</li>
 *   <li>{@code false} - create a new model and train it.</li>
 * </ul>
 *
 * @throws IllegalArgumentException if "train.warmStart" is none of the values above
 * @throws Exception                propagated from data loading, training or serialization
 */
private static void test7() throws Exception {
    System.out.println(config);
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.testData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    // loading or save model infos.
    String output = config.getString("output");
    String modelName = config.getString("modelName");
    String warmStart = config.getString("train.warmStart");
    CMLCRF cmlcrf;
    switch (warmStart) {
        case "true": {
            cmlcrf = CMLCRF.deserialize(new File(output, modelName));
            System.out.println("loading model:");
            System.out.println(cmlcrf);
            break;
        }
        case "auto": {
            cmlcrf = CMLCRF.deserialize(new File(output, modelName));
            System.out.println("retrain model:");
            CMLCRFElasticNet cmlcrfElasticNet = new CMLCRFElasticNet(cmlcrf, trainSet, config.getDouble("l1Ratio"), config.getDouble("regularization"));
            train(cmlcrfElasticNet, cmlcrf, trainSet, testSet, config);
            break;
        }
        case "false": {
            cmlcrf = new CMLCRF(trainSet);
            cmlcrf.setConsiderPair(config.getBoolean("considerLabelPair"));
            CMLCRFElasticNet cmlcrfElasticNet = new CMLCRFElasticNet(cmlcrf, trainSet, config.getDouble("l1Ratio"), config.getDouble("regularization"));
            train(cmlcrfElasticNet, cmlcrf, trainSet, testSet, config);
            break;
        }
        default:
            // Fail here with the offending value instead of hitting a NullPointerException
            // on the missing model further down.
            throw new IllegalArgumentException("train.warmStart should be true, auto or false, but was: " + warmStart);
    }
    System.out.println();
    System.out.println();
    System.out.println("--------------------------------Results-----------------------------\n");
    MLMeasures measures = new MLMeasures(cmlcrf, trainSet);
    System.out.println("========== Train ==========\n");
    System.out.println(measures);
    System.out.println("========== Test ==========\n");
    // The predictions themselves are not used; this call exists only to be timed.
    long startTimePred = System.nanoTime();
    cmlcrf.predict(testSet);
    long stopTimePred = System.nanoTime();
    long predTime = stopTimePred - startTimePred;
    System.out.println("\nprediction time: " + TimeUnit.NANOSECONDS.toSeconds(predTime) + " sec.");
    System.out.println(new MLMeasures(cmlcrf, testSet));
    System.out.println("\n\n");
    InstanceF1Predictor pluginF1 = new InstanceF1Predictor(cmlcrf);
    System.out.println("Plugin F1");
    System.out.println(new MLMeasures(pluginF1, testSet));
    if (config.getBoolean("saveModel")) {
        (new File(output)).mkdirs();
        File serializeModel = new File(output, modelName);
        cmlcrf.serialize(serializeModel);
    }
}
Also used : CMLCRF(edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Aggregations

MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 24, File (java.io.File): 17, MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 12, EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper): 12, MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 11, Config (edu.neu.ccs.pyramid.configuration.Config): 10, PrintUtil (edu.neu.ccs.pyramid.util.PrintUtil): 10, Serialization (edu.neu.ccs.pyramid.util.Serialization): 10, Paths (java.nio.file.Paths): 10, Collectors (java.util.stream.Collectors): 10, IntStream (java.util.stream.IntStream): 10, FileUtils (org.apache.commons.io.FileUtils): 10, edu.neu.ccs.pyramid.multilabel_classification.cbm (edu.neu.ccs.pyramid.multilabel_classification.cbm): 9, ListUtil (edu.neu.ccs.pyramid.util.ListUtil): 9, Pair (edu.neu.ccs.pyramid.util.Pair): 9, BufferedWriter (java.io.BufferedWriter): 9, FileWriter (java.io.FileWriter): 9, java.util (java.util): 9, StopWatch (org.apache.commons.lang3.time.StopWatch): 9, edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset): 7