Search in sources :

Example 66 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

Method reportGeneral of the class CBMLR.

/**
 * Writes the predictor-independent test reports under {@code <output.dir>/test_predictions}:
 * the per-instance label probabilities ({@code label_probabilities.txt}) and the exact
 * (non-approximated) log likelihood of every ground-truth label set
 * ({@code ground_truth_log_likelihood.txt}), then prints the average log likelihood.
 *
 * <p>The average skips instances for which {@code containsNovelClass} reports a label listed in
 * {@code <output.dir>/unobserved_labels.txt}. If no instance remains (empty data set, or every
 * instance carries an unobserved label) the average is reported as {@code NaN} instead of
 * aborting with a {@link java.util.NoSuchElementException} after the files were written.
 *
 * @param config  supplies {@code output.dir} and {@code report.labelProbThreshold}
 * @param cbm     the model being evaluated
 * @param dataSet the test data set
 * @throws Exception if reading or writing any of the report files fails
 */
private static void reportGeneral(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("computing other predictor-independent metrics");
    String output = config.getString("output.dir");
    File labelProbFile = Paths.get(output, "test_predictions", "label_probabilities.txt").toFile();
    double labelProbThreshold = config.getDouble("report.labelProbThreshold");
    try (BufferedWriter br = new BufferedWriter(new FileWriter(labelProbFile))) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            br.write(CBMInspector.topLabels(cbm, dataSet.getRow(i), labelProbThreshold));
            br.newLine();
        }
    }
    System.out.println("individual label probabilities are saved to " + labelProbFile.getAbsolutePath());
    // The file holds a comma-separated list of label indices; blank entries are skipped.
    List<Integer> unobservedLabels = Arrays.stream(FileUtils.readFileToString(new File(output, "unobserved_labels.txt")).split(","))
            .map(String::trim)
            .filter(s -> !s.isEmpty())
            .map(Integer::parseInt)
            .collect(Collectors.toList());
    // Here we do not use approximation
    double[] logLikelihoods = IntStream.range(0, dataSet.getNumDataPoints()).parallel()
            .mapToDouble(i -> cbm.predictLogAssignmentProb(dataSet.getRow(i), dataSet.getMultiLabels()[i]))
            .toArray();
    // average() is empty when no instance passes the filter; report NaN rather than throwing.
    double average = IntStream.range(0, dataSet.getNumDataPoints())
            .filter(i -> !containsNovelClass(dataSet.getMultiLabels()[i], unobservedLabels))
            .mapToDouble(i -> logLikelihoods[i])
            .average()
            .orElse(Double.NaN);
    File logLikelihoodFile = Paths.get(output, "test_predictions", "ground_truth_log_likelihood.txt").toFile();
    FileUtils.writeStringToFile(logLikelihoodFile, PrintUtil.toMutipleLines(logLikelihoods));
    System.out.println("individual log likelihood of the test ground truth label set is written to " + logLikelihoodFile.getAbsolutePath());
    System.out.println("average log likelihood of the test ground truth label sets = " + average);
    if (!unobservedLabels.isEmpty()) {
        System.out.println("This is computed by ignoring test instances with new labels unobserved during training");
        System.out.println("The following labels do not actually appear in the training set and therefore cannot be learned:");
        System.out.println(ListUtil.toSimpleString(unobservedLabels));
    }
}
Also used : IntStream(java.util.stream.IntStream) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) java.util(java.util) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) FileWriter(java.io.FileWriter) File(java.io.File) BufferedWriter(java.io.BufferedWriter)

Example 67 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

Method reportHammingPrediction of the class CBMEN.

/**
 * Predicts the test set with a {@code MarginalPredictor} (announced in the output as the
 * instance Hamming loss optimal predictor), prints the resulting {@code MLMeasures}, and saves
 * both the measures and the predicted label sets with their exact assignment probabilities
 * under {@code <output.dir>/test_predictions/instance_hamming_loss_optimal}.
 *
 * @param config  supplies {@code output.dir} and {@code predict.piThreshold}
 * @param cbm     the model used for prediction
 * @param dataSet the test data set
 * @throws Exception if writing any of the report files fails
 */
private static void reportHammingPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    String separator = "============================================================";
    System.out.println(separator);
    System.out.println("Making predictions on test set with the instance Hamming loss optimal predictor");
    String outputDir = config.getString("output.dir");

    MarginalPredictor predictor = new MarginalPredictor(cbm);
    predictor.setPiThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predicted = predictor.predict(dataSet);

    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predicted);
    System.out.println("test performance with the instance Hamming loss optimal predictor");
    System.out.println(measures);
    File measuresFile = Paths.get(outputDir, "test_predictions", "instance_hamming_loss_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(measuresFile, measures.toString());
    System.out.println("test performance is saved to " + measuresFile);

    // Exact probability of each predicted label set; no approximation is used here.
    double[] probabilities = IntStream.range(0, predicted.length)
            .parallel()
            .mapToDouble(row -> cbm.predictAssignmentProb(dataSet.getRow(row), predicted[row]))
            .toArray();

    File outFile = Paths.get(outputDir, "test_predictions", "instance_hamming_loss_optimal", "predictions.txt").toFile();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(outFile))) {
        // One "<label set>:<probability>" line per data point.
        for (int row = 0; row < dataSet.getNumDataPoints(); row++) {
            writer.write(predicted[row].toString() + ":" + probabilities[row]);
            writer.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + outFile.getAbsolutePath());
    System.out.println(separator);
}
Also used : IntStream(java.util.stream.IntStream) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) java.util(java.util) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) FileWriter(java.io.FileWriter) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) File(java.io.File) BufferedWriter(java.io.BufferedWriter)

Example 68 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

Method reportGeneral of the class CBMGB.

/**
 * Writes the predictor-independent test reports under {@code <output.dir>/test_predictions}:
 * the per-instance label probabilities ({@code label_probabilities.txt}) and the exact
 * (non-approximated) log likelihood of every ground-truth label set
 * ({@code ground_truth_log_likelihood.txt}), then prints the average log likelihood.
 *
 * <p>The average skips instances for which {@code containsNovelClass} reports a label listed in
 * {@code <output.dir>/unobserved_labels.txt}. If no instance remains (empty data set, or every
 * instance carries an unobserved label) the average is reported as {@code NaN} instead of
 * aborting with a {@link java.util.NoSuchElementException} after the files were written.
 *
 * @param config  supplies {@code output.dir} and {@code report.labelProbThreshold}
 * @param cbm     the model being evaluated
 * @param dataSet the test data set
 * @throws Exception if reading or writing any of the report files fails
 */
private static void reportGeneral(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("computing other predictor-independent metrics");
    String output = config.getString("output.dir");
    File labelProbFile = Paths.get(output, "test_predictions", "label_probabilities.txt").toFile();
    double labelProbThreshold = config.getDouble("report.labelProbThreshold");
    try (BufferedWriter br = new BufferedWriter(new FileWriter(labelProbFile))) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            br.write(CBMInspector.topLabels(cbm, dataSet.getRow(i), labelProbThreshold));
            br.newLine();
        }
    }
    System.out.println("individual label probabilities are saved to " + labelProbFile.getAbsolutePath());
    // The file holds a comma-separated list of label indices; blank entries are skipped.
    List<Integer> unobservedLabels = Arrays.stream(FileUtils.readFileToString(new File(output, "unobserved_labels.txt")).split(","))
            .map(String::trim)
            .filter(s -> !s.isEmpty())
            .map(Integer::parseInt)
            .collect(Collectors.toList());
    // Here we do not use approximation
    double[] logLikelihoods = IntStream.range(0, dataSet.getNumDataPoints()).parallel()
            .mapToDouble(i -> cbm.predictLogAssignmentProb(dataSet.getRow(i), dataSet.getMultiLabels()[i]))
            .toArray();
    // average() is empty when no instance passes the filter; report NaN rather than throwing.
    double average = IntStream.range(0, dataSet.getNumDataPoints())
            .filter(i -> !containsNovelClass(dataSet.getMultiLabels()[i], unobservedLabels))
            .mapToDouble(i -> logLikelihoods[i])
            .average()
            .orElse(Double.NaN);
    File logLikelihoodFile = Paths.get(output, "test_predictions", "ground_truth_log_likelihood.txt").toFile();
    FileUtils.writeStringToFile(logLikelihoodFile, PrintUtil.toMutipleLines(logLikelihoods));
    System.out.println("individual log likelihood of the test ground truth label set is written to " + logLikelihoodFile.getAbsolutePath());
    System.out.println("average log likelihood of the test ground truth label sets = " + average);
    if (!unobservedLabels.isEmpty()) {
        System.out.println("This is computed by ignoring test instances with new labels unobserved during training");
        System.out.println("The following labels do not actually appear in the training set and therefore cannot be learned:");
        System.out.println(ListUtil.toSimpleString(unobservedLabels));
    }
}
Also used : IntStream(java.util.stream.IntStream) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) TRECFormat(edu.neu.ccs.pyramid.dataset.TRECFormat) java.util(java.util) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) DataSetUtil(edu.neu.ccs.pyramid.dataset.DataSetUtil) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) FileWriter(java.io.FileWriter) File(java.io.File) BufferedWriter(java.io.BufferedWriter)

Example 69 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

Method loadAugmentedLabelTranslator of the class App1.

/**
 * Builds a label translator that contains every external label of the training translator
 * followed by any label found in the given test documents that the training translator lacks.
 *
 * <p>The label field and the optional prefix filter are read from the configuration stored at
 * {@code <output.folder>/meta_data/saved_config_app1}, not from {@code config} itself. The
 * per-label document counts are logged, and a warning is logged when new labels are found.
 *
 * @param config               supplies {@code output.folder}
 * @param index                index queried for the label term aggregation
 * @param testIndexIds         ids of the documents to aggregate over
 * @param trainLabelTranslator translator whose labels come first in the result
 * @param logger               receives the label distribution and the new-label warning
 * @return a translator over the training labels plus the newly seen labels
 */
static LabelTranslator loadAugmentedLabelTranslator(Config config, MultiLabelIndex index, String[] testIndexIds, LabelTranslator trainLabelTranslator, Logger logger) {
    File metaDir = new File(config.getString("output.folder"), "meta_data");
    Config stored = new Config(new File(metaDir, "saved_config_app1"));

    // Start from the training labels, keeping their original order.
    List<String> allLabels = new ArrayList<>();
    for (int labelIndex = 0; labelIndex < trainLabelTranslator.getNumClasses(); labelIndex++) {
        allLabels.add(trainLabelTranslator.toExtLabel(labelIndex));
    }

    Collection<Terms.Bucket> buckets = index.termAggregation(stored.getString("train.label.field"), testIndexIds);
    if (stored.getBoolean("train.label.filter")) {
        String prefix = stored.getString("train.label.filter.prefix");
        buckets = buckets.stream()
                .filter(candidate -> candidate.getKey().startsWith(prefix))
                .collect(Collectors.toList());
    }

    List<String> unseen = new ArrayList<>();
    logger.info("label distribution in data set:");
    StringBuilder distribution = new StringBuilder();
    for (Terms.Bucket bucket : buckets) {
        String label = bucket.getKey();
        distribution.append(label).append(":").append(bucket.getDocCount()).append(", ");
        if (allLabels.contains(label)) {
            continue;
        }
        // Label is absent from the training translator: append it and remember it for the warning.
        allLabels.add(label);
        unseen.add(label);
    }
    logger.info(distribution.toString());

    if (!unseen.isEmpty()) {
        logger.warning("found new labels in data set: " + unseen);
    }
    return new LabelTranslator(allLabels);
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config) Terms(org.elasticsearch.search.aggregations.bucket.terms.Terms) File(java.io.File)

Example 70 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

Method writeConfigFile of the class ArffFormat.

/**
 * Stores the data set's number of data points and number of features in a config file named
 * {@code ARFF_CONFIG_FILE_NAME} under {@code arffFile}.
 *
 * <p>A failure to store the file is not propagated; its stack trace is printed instead.
 *
 * @param dataSet  data set whose dimensions are recorded
 * @param arffFile parent location of the config file
 */
private static void writeConfigFile(RegDataSet dataSet, File arffFile) {
    Config dimensions = new Config();
    dimensions.setInt(ARFF_CONFIG_NUM_DATA_POINTS, dataSet.getNumDataPoints());
    dimensions.setInt(ARFF_CONFIG_NUM_FEATURES, dataSet.getNumFeatures());

    File target = new File(arffFile, ARFF_CONFIG_FILE_NAME);
    try {
        dimensions.store(target);
    } catch (Exception storeFailure) {
        storeFailure.printStackTrace();
    }
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config) File(java.io.File) IOException(java.io.IOException)

Aggregations

Config (edu.neu.ccs.pyramid.configuration.Config): 119, File (java.io.File): 68, Collectors (java.util.stream.Collectors): 40, FileUtils (org.apache.commons.io.FileUtils): 40, Paths (java.nio.file.Paths): 39, IntStream (java.util.stream.IntStream): 37, Pair (edu.neu.ccs.pyramid.util.Pair): 36, Serialization (edu.neu.ccs.pyramid.util.Serialization): 35, StopWatch (org.apache.commons.lang3.time.StopWatch): 34, MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 33, BufferedWriter (java.io.BufferedWriter): 32, FileWriter (java.io.FileWriter): 32, java.util (java.util): 32, edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset): 31, MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 29, EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper): 28, PrintUtil (edu.neu.ccs.pyramid.util.PrintUtil): 26, edu.neu.ccs.pyramid.multilabel_classification.cbm (edu.neu.ccs.pyramid.multilabel_classification.cbm): 25, ListUtil (edu.neu.ccs.pyramid.util.ListUtil): 25, ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper): 22